Skip to content

Commit 104d051

Browse files
committed
[layer] EstimatorLayer
1 parent 88e7e94 commit 104d051

File tree

2 files changed

+42
-3
lines changed

2 files changed

+42
-3
lines changed

docs/modules/layers.rst

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -324,6 +324,7 @@ Layer list
324324
ExpandDimsLayer
325325
TileLayer
326326

327+
EstimatorLayer
327328
SlimNetsLayer
328329
KerasLayer
329330

@@ -375,7 +376,7 @@ Input layer
375376
.. autoclass:: InputLayer
376377
:members:
377378

378-
One-Hot layer
379+
One-hot layer
379380
----------------
380381
.. autoclass:: OneHotInputLayer
381382

@@ -637,8 +638,9 @@ Tile layer
637638
^^^^^^^^^^^^^^^^^^^^
638639
.. autoclass:: TileLayer
639640

640-
641-
641+
Estimator layer
642+
------------------
643+
.. autoclass:: EstimatorLayer
642644

643645
Connect TF-Slim
644646
------------------

tensorlayer/layers.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4750,6 +4750,43 @@ def __init__(
47504750
self.all_layers.extend( [self.outputs] )
47514751
self.all_params.extend( variables )
47524752

4753+
## Estimator layer
4754+
class EstimatorLayer(Layer):
4755+
"""
4756+
The :class:`EstimatorLayer` class accepts ``model_fn`` that described the model.
4757+
It is similar with :class:`KerasLayer`, see `tutorial_keras.py <https://github.com/zsdonghao/tensorlayer/blob/master/example/tutorial_keras.py>`_
4758+
4759+
Parameters
4760+
----------
4761+
layer : a :class:`Layer` instance
4762+
The `Layer` class feeding into this layer.
4763+
model_fn : a function that described the model.
4764+
args : dictionary
4765+
The arguments for the model_fn.
4766+
name : a string or None
4767+
An optional name to attach to this layer.
4768+
"""
4769+
def __init__(
4770+
self,
4771+
layer = None,
4772+
model_fn = None,
4773+
args = {},
4774+
name ='estimator_layer',
4775+
):
4776+
Layer.__init__(self, name=name)
4777+
assert layer is not None
4778+
assert model_fn is not None
4779+
self.inputs = layer.outputs
4780+
print(" [TL] EstimatorLayer %s: %s" % (self.name, model_fn))
4781+
with tf.variable_scope(name) as vs:
4782+
self.outputs = model_fn(self.inputs, **args)
4783+
variables = tf.get_collection(TF_GRAPHKEYS_VARIABLES, scope=vs.name)
4784+
self.all_layers = list(layer.all_layers)
4785+
self.all_params = list(layer.all_params)
4786+
self.all_drop = dict(layer.all_drop)
4787+
self.all_layers.extend( [self.outputs] )
4788+
self.all_params.extend( variables )
4789+
47534790
## Special activation
47544791
class PReluLayer(Layer):
47554792
"""

0 commit comments

Comments
 (0)