@@ -4750,6 +4750,43 @@ def __init__(
47504750 self .all_layers .extend ( [self .outputs ] )
47514751 self .all_params .extend ( variables )
47524752
4753+ ## Estimator layer
4754+ class EstimatorLayer (Layer ):
4755+ """
4756+ The :class:`EstimatorLayer` class accepts ``model_fn`` that described the model.
4757+ It is similar with :class:`KerasLayer`, see `tutorial_keras.py <https://github.com/zsdonghao/tensorlayer/blob/master/example/tutorial_keras.py>`_
4758+
4759+ Parameters
4760+ ----------
4761+ layer : a :class:`Layer` instance
4762+ The `Layer` class feeding into this layer.
4763+ model_fn : a function that described the model.
4764+ args : dictionary
4765+ The arguments for the model_fn.
4766+ name : a string or None
4767+ An optional name to attach to this layer.
4768+ """
4769+ def __init__ (
4770+ self ,
4771+ layer = None ,
4772+ model_fn = None ,
4773+ args = {},
4774+ name = 'estimator_layer' ,
4775+ ):
4776+ Layer .__init__ (self , name = name )
4777+ assert layer is not None
4778+ assert model_fn is not None
4779+ self .inputs = layer .outputs
4780+ print (" [TL] EstimatorLayer %s: %s" % (self .name , model_fn ))
4781+ with tf .variable_scope (name ) as vs :
4782+ self .outputs = model_fn (self .inputs , ** args )
4783+ variables = tf .get_collection (TF_GRAPHKEYS_VARIABLES , scope = vs .name )
4784+ self .all_layers = list (layer .all_layers )
4785+ self .all_params = list (layer .all_params )
4786+ self .all_drop = dict (layer .all_drop )
4787+ self .all_layers .extend ( [self .outputs ] )
4788+ self .all_params .extend ( variables )
4789+
47534790## Special activation
47544791class PReluLayer (Layer ):
47554792 """
0 commit comments