Skip to content

Commit 10b48da

Browse files
committed
[layers] default rnn cell is None, for TF12
1 parent 8bd7eb7 commit 10b48da

File tree

1 file changed

+27
-12
lines changed

1 file changed

+27
-12
lines changed

tensorlayer/layers.py

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2851,7 +2851,7 @@ class RNNLayer(Layer):
28512851
----------
28522852
layer : a :class:`Layer` instance
28532853
The `Layer` class feeding into this layer.
2854-
cell_fn : a TensorFlow's core RNN cell as follow.
2854+
cell_fn : a TensorFlow's core RNN cell as follow (Note TF1.0+ is different).
28552855
- see `RNN Cells in TensorFlow <https://www.tensorflow.org/api_docs/python/rnn_cell/>`_
28562856
- class ``tf.nn.rnn_cell.BasicRNNCell``
28572857
- class ``tf.nn.rnn_cell.BasicLSTMCell``
@@ -2996,7 +2996,7 @@ class RNNLayer(Layer):
29962996
def __init__(
29972997
self,
29982998
layer = None,
2999-
cell_fn = tf.nn.rnn_cell.BasicRNNCell,
2999+
cell_fn = None,#tf.nn.rnn_cell.BasicRNNCell,
30003000
cell_init_args = {},
30013001
n_hidden = 100,
30023002
initializer = tf.random_uniform_initializer(-0.1, 0.1),
@@ -3008,6 +3008,9 @@ def __init__(
30083008
name = 'rnn_layer',
30093009
):
30103010
Layer.__init__(self, name=name)
3011+
if cell_fn is None:
3012+
raise Exception("Please put in cell_fn")
3013+
30113014
self.inputs = layer.outputs
30123015

30133016
print(" tensorlayer:Instantiate RNNLayer %s: n_hidden:%d, n_steps:%d, in_dim:%d %s, cell_fn:%s " % (self.name, n_hidden,
@@ -3101,7 +3104,7 @@ class BiRNNLayer(Layer):
31013104
----------
31023105
layer : a :class:`Layer` instance
31033106
The `Layer` class feeding into this layer.
3104-
cell_fn : a TensorFlow's core RNN cell as follow.
3107+
cell_fn : a TensorFlow's core RNN cell as follow (Note TF1.0+ is different).
31053108
- see `RNN Cells in TensorFlow <https://www.tensorflow.org/api_docs/python/rnn_cell/>`_
31063109
- class ``tf.nn.rnn_cell.BasicRNNCell``
31073110
- class ``tf.nn.rnn_cell.BasicLSTMCell``
@@ -3169,7 +3172,7 @@ class BiRNNLayer(Layer):
31693172
def __init__(
31703173
self,
31713174
layer = None,
3172-
cell_fn = tf.nn.rnn_cell.LSTMCell,
3175+
cell_fn = None, #tf.nn.rnn_cell.LSTMCell,
31733176
cell_init_args = {'use_peepholes':True, 'state_is_tuple':True},
31743177
n_hidden = 100,
31753178
initializer = tf.random_uniform_initializer(-0.1, 0.1),
@@ -3183,6 +3186,8 @@ def __init__(
31833186
name = 'birnn_layer',
31843187
):
31853188
Layer.__init__(self, name=name)
3189+
if cell_fn is None:
3190+
raise Exception("Please put in cell_fn")
31863191
self.inputs = layer.outputs
31873192

31883193
print(" tensorlayer:Instantiate BiRNNLayer %s: n_hidden:%d, n_steps:%d, in_dim:%d %s, cell_fn:%s, dropout:%s, n_layer:%d " % (self.name, n_hidden,
@@ -3409,7 +3414,7 @@ class DynamicRNNLayer(Layer):
34093414
----------
34103415
layer : a :class:`Layer` instance
34113416
The `Layer` class feeding into this layer.
3412-
cell_fn : a TensorFlow's core RNN cell as follow.
3417+
cell_fn : a TensorFlow's core RNN cell as follow (Note TF1.0+ is different).
34133418
- see `RNN Cells in TensorFlow <https://www.tensorflow.org/api_docs/python/rnn_cell/>`_
34143419
- class ``tf.nn.rnn_cell.BasicRNNCell``
34153420
- class ``tf.nn.rnn_cell.BasicLSTMCell``
@@ -3499,7 +3504,7 @@ class DynamicRNNLayer(Layer):
34993504
def __init__(
35003505
self,
35013506
layer = None,
3502-
cell_fn = tf.nn.rnn_cell.LSTMCell,
3507+
cell_fn = None,#tf.nn.rnn_cell.LSTMCell,
35033508
cell_init_args = {'state_is_tuple' : True},
35043509
n_hidden = 256,
35053510
initializer = tf.random_uniform_initializer(-0.1, 0.1),
@@ -3512,6 +3517,8 @@ def __init__(
35123517
name = 'dyrnn_layer',
35133518
):
35143519
Layer.__init__(self, name=name)
3520+
if cell_fn is None:
3521+
raise Exception("Please put in cell_fn")
35153522
self.inputs = layer.outputs
35163523

35173524
print(" tensorlayer:Instantiate DynamicRNNLayer %s: n_hidden:%d, in_dim:%d %s, cell_fn:%s, dropout:%s, n_layer:%d" % (self.name, n_hidden,
@@ -3631,7 +3638,7 @@ class BiDynamicRNNLayer(Layer):
36313638
----------
36323639
layer : a :class:`Layer` instance
36333640
The `Layer` class feeding into this layer.
3634-
cell_fn : a TensorFlow's core RNN cell as follow.
3641+
cell_fn : a TensorFlow's core RNN cell as follow (Note TF1.0+ is different).
36353642
- see `RNN Cells in TensorFlow <https://www.tensorflow.org/api_docs/python/rnn_cell/>`_\n
36363643
- class ``tf.nn.rnn_cell.BasicRNNCell``
36373644
- class ``tf.nn.rnn_cell.BasicLSTMCell``
@@ -3703,7 +3710,7 @@ class BiDynamicRNNLayer(Layer):
37033710
def __init__(
37043711
self,
37053712
layer = None,
3706-
cell_fn = tf.nn.rnn_cell.LSTMCell,
3713+
cell_fn = None,#tf.nn.rnn_cell.LSTMCell,
37073714
cell_init_args = {'state_is_tuple':True},
37083715
n_hidden = 256,
37093716
initializer = tf.random_uniform_initializer(-0.1, 0.1),
@@ -3717,6 +3724,8 @@ def __init__(
37173724
name = 'bi_dyrnn_layer',
37183725
):
37193726
Layer.__init__(self, name=name)
3727+
if cell_fn is None:
3728+
raise Exception("Please put in cell_fn")
37203729
self.inputs = layer.outputs
37213730

37223731
print(" tensorlayer:Instantiate BiDynamicRNNLayer %s: n_hidden:%d, in_dim:%d %s, cell_fn:%s, dropout:%s, n_layer:%d" %
@@ -3843,7 +3852,7 @@ class Seq2Seq(Layer):
38433852
Encode sequences, [batch_size, None, n_features].
38443853
net_decode_in : a :class:`Layer` instance
38453854
Decode sequences, [batch_size, None, n_features].
3846-
cell_fn : a TensorFlow's core RNN cell as follow.
3855+
cell_fn : a TensorFlow's core RNN cell as follow (Note TF1.0+ is different)
38473856
- see `RNN Cells in TensorFlow <https://www.tensorflow.org/api_docs/python/rnn_cell/>`_\n
38483857
- class ``tf.nn.rnn_cell.BasicRNNCell``
38493858
- class ``tf.nn.rnn_cell.BasicLSTMCell``
@@ -3929,7 +3938,7 @@ def __init__(
39293938
self,
39303939
net_encode_in = None,
39313940
net_decode_in = None,
3932-
cell_fn = tf.nn.rnn_cell.LSTMCell,
3941+
cell_fn = None,#tf.nn.rnn_cell.LSTMCell,
39333942
cell_init_args = {'state_is_tuple':True},
39343943
n_hidden = 256,
39353944
initializer = tf.random_uniform_initializer(-0.1, 0.1),
@@ -3943,6 +3952,8 @@ def __init__(
39433952
name = 'seq2seq',
39443953
):
39453954
Layer.__init__(self, name=name)
3955+
if cell_fn is None:
3956+
raise Exception("Please put in cell_fn")
39463957
# self.inputs = layer.outputs
39473958
print(" tensorlayer:Instantiate Seq2Seq %s: n_hidden:%d, cell_fn:%s, dropout:%s, n_layer:%d" %
39483959
(self.name, n_hidden, cell_fn.__name__, dropout, n_layer))
@@ -4003,7 +4014,7 @@ def __init__(
40034014
self,
40044015
net_encode_in = None,
40054016
net_decode_in = None,
4006-
cell_fn = tf.nn.rnn_cell.LSTMCell,
4017+
cell_fn = None,#tf.nn.rnn_cell.LSTMCell,
40074018
cell_init_args = {'state_is_tuple':True},
40084019
n_hidden = 256,
40094020
initializer = tf.random_uniform_initializer(-0.1, 0.1),
@@ -4017,6 +4028,8 @@ def __init__(
40174028
name = 'peeky_seq2seq',
40184029
):
40194030
Layer.__init__(self, name=name)
4031+
if cell_fn is None:
4032+
raise Exception("Please put in cell_fn")
40204033
# self.inputs = layer.outputs
40214034
print(" tensorlayer:Instantiate PeekySeq2seq %s: n_hidden:%d, cell_fn:%s, dropout:%s, n_layer:%d" %
40224035
(self.name, n_hidden, cell_fn.__name__, dropout, n_layer))
@@ -4032,7 +4045,7 @@ def __init__(
40324045
self,
40334046
net_encode_in = None,
40344047
net_decode_in = None,
4035-
cell_fn = tf.nn.rnn_cell.LSTMCell,
4048+
cell_fn = None,#tf.nn.rnn_cell.LSTMCell,
40364049
cell_init_args = {'state_is_tuple':True},
40374050
n_hidden = 256,
40384051
initializer = tf.random_uniform_initializer(-0.1, 0.1),
@@ -4046,6 +4059,8 @@ def __init__(
40464059
name = 'attention_seq2seq',
40474060
):
40484061
Layer.__init__(self, name=name)
4062+
if cell_fn is None:
4063+
raise Exception("Please put in cell_fn")
40494064
# self.inputs = layer.outputs
40504065
print(" tensorlayer:Instantiate PeekySeq2seq %s: n_hidden:%d, cell_fn:%s, dropout:%s, n_layer:%d" %
40514066
(self.name, n_hidden, cell_fn.__name__, dropout, n_layer))

0 commit comments

Comments
 (0)