Skip to content

Commit

Permalink
Merge pull request #4347 from nitnelave/python/layer_dict
Browse files Browse the repository at this point in the history
[pycaffe] add layer_dict to the python interface
  • Loading branch information
shelhamer authored Feb 17, 2017
2 parents 28ffe9c + 5417f10 commit c510e33
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 0 deletions.
11 changes: 11 additions & 0 deletions python/caffe/pycaffe.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,16 @@ def _Net_blob_loss_weights(self):
self._blob_loss_weights))
return self._blob_loss_weights_dict

@property
def _Net_layer_dict(self):
"""
An OrderedDict (bottom to top, i.e., input to output) of network
layers indexed by name
"""
if not hasattr(self, '_layer_dict'):
self._layer_dict = OrderedDict(zip(self._layer_names, self.layers))
return self._layer_dict


@property
def _Net_params(self):
Expand Down Expand Up @@ -321,6 +331,7 @@ def get_id_name(self):
# Attach methods to Net.
Net.blobs = _Net_blobs
Net.blob_loss_weights = _Net_blob_loss_weights
Net.layer_dict = _Net_layer_dict
Net.params = _Net_params
Net.forward = _Net_forward
Net.backward = _Net_backward
Expand Down
7 changes: 7 additions & 0 deletions python/caffe/test/test_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,13 @@ def test_memory(self):
for bl in blobs:
total += bl.data.sum() + bl.diff.sum()

def test_layer_dict(self):
layer_dict = self.net.layer_dict
self.assertEqual(list(layer_dict.keys()), list(self.net._layer_names))
for i, name in enumerate(self.net._layer_names):
self.assertEqual(layer_dict[name].type,
self.net.layers[i].type)

def test_forward_backward(self):
self.net.forward()
self.net.backward()
Expand Down

0 comments on commit c510e33

Please sign in to comment.