 import mindspore as ms
 import mindspore.ops as P
 from mindspore import context
+from mindspore.ops.primitive import constexpr
 from mindspore.nn.cell import Cell
 from mindspore._checkparam import Rel
 from mindspore.ops import functional as F
 from mindspore.communication.management import get_group_size, get_rank
 from mindspore.ops.operations import LayerNorm
 import mindspore.numpy as np
+from mindspore.common.parameter import ParameterTuple
+from mindspore.nn.layer.rnns import _DynamicRNN
 import warnings
 import math

@@ -833,7 +836,9 @@ def __init__(self, ksize, strides, padding, data_format=None):
         self.data_format, self.padding = preprocess_2d_format(data_format=data_format, padding=padding)
         ms_ksize = ksize[1]
         ms_strides = strides[1]
-        self.avgpool = P.AvgPool(kernel_size=ms_ksize, strides=ms_strides, pad_mode=padding, data_format=self.data_format)
+        self.avgpool = P.AvgPool(
+            kernel_size=ms_ksize, strides=ms_strides, pad_mode=padding, data_format=self.data_format
+        )

     def construct(self, inputs):
         outputs = self.avgpool(inputs)
@@ -930,7 +935,7 @@ def __init__(self, ksize, strides, padding, data_format='NCDHW'):
         if data_format == 'NCDHW':
             strides = (strides[2], strides[3], strides[4])
         print(ksize, strides, padding)
-        self.avg_pool = P.AvgPool3D(kernel_size=ksize, strides=strides, pad_mode=padding, data_format=data_format)
+        self.avg_pool = P.AvgPool3D(kernel_size=ksize, strides=strides, pad_mode=padding, data_format=data_format)

     def __call__(self, inputs):
         return self.avg_pool(inputs)
@@ -1838,15 +1843,12 @@ def __init__(self, weight_ih, weight_hh, bias_ih, bias_hh, act):
         self.bias_ih = bias_ih
         self.bias_hh = bias_hh
         self.act_fn = P.ReLU() if act == 'relu' else P.Tanh()
-        self.transpose = P.Transpose()

     def construct(self, input, h):
-        self.weight_ih = self.transpose(self.weight_ih, (1, 0))
-        i2h = P.matmul(input, self.weight_ih)
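+        # MatMul(transpose_a=False, transpose_b=True) computes input @ weight_ih.T on the fly,
+        # so the weight Parameter is no longer overwritten with its transpose on every step
+        # (the same change is applied to the LSTM and GRU cells below).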
+        i2h = P.MatMul(False, True)(input, self.weight_ih)
         if self.bias_ih is not None:
             i2h += self.bias_ih
-        self.weight_hh = self.transpose(self.weight_hh, (1, 0))
-        h2h = P.matmul(h, self.weight_hh)
+        h2h = P.MatMul(False, True)(h, self.weight_hh)
         if self.bias_hh is not None:
             h2h += self.bias_hh
         h = self.act_fn(i2h + h2h)
@@ -1863,17 +1865,14 @@ def __init__(self, weight_ih, weight_hh, bias_ih, bias_hh):
         self.bias_hh = bias_hh
         self.gate_act_fn = P.Sigmoid()
         self.act_fn = P.Tanh()
-        self.transpose = P.Transpose()
         self.split = P.Split(axis=-1, output_num=4)

     def construct(self, input, h, c):

-        self.weight_ih = self.transpose(self.weight_ih, (1, 0))
-        gates = P.matmul(input, self.weight_ih)
+        gates = P.MatMul(False, True)(input, self.weight_ih)
         if self.bias_ih is not None:
             gates += self.bias_ih
-        self.weight_hh = self.transpose(self.weight_hh, (1, 0))
-        gates += P.matmul(h, self.weight_hh)
+        gates += P.MatMul(False, True)(h, self.weight_hh)
         if self.bias_hh is not None:
             gates += self.bias_hh

@@ -1902,12 +1901,10 @@ def __init__(self, weight_ih, weight_hh, bias_ih, bias_hh):

     def construct(self, input, h):

-        self.weight_ih = self.transpose(self.weight_ih, (1, 0))
-        x_gates = P.matmul(input, self.weight_ih)
+        x_gates = P.MatMul(False, True)(input, self.weight_ih)
         if self.bias_ih is not None:
             x_gates += self.bias_ih
-        self.weight_hh = self.transpose(self.weight_hh, (1, 0))
-        h_gates = P.matmul(h, self.weight_hh)
+        h_gates = P.MatMul(False, True)(h, self.weight_hh)
         if self.bias_hh is not None:
             h_gates += self.bias_hh

@@ -1922,6 +1919,19 @@ def construct(self, input, h):
         return h, h


+@constexpr
+def _init_state(shape, dtype, is_lstm):
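+    # Zero-filled initial hidden state; for LSTM a matching zero cell state is returned as well.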
+    hx = ms.Tensor(np.zeros(shape), dtype)
+    cx = ms.Tensor(np.zeros(shape), dtype)
+    if is_lstm:
+        return (hx, cx)
+    return hx
+
+@constexpr
+def _check_input_dtype_same_and_valid(args_name, args_value, valid_values, cls_name):
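+    # `validator` is assumed to be mindspore._checkparam.Validator, imported elsewhere in this module.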
+    args = {args_name[i]: args_value[i] for i in range(len(args_value))}
+    validator.check_types_same_and_valid(args, valid_values, cls_name)
+
 class rnnbase(Cell):

     def __init__(
@@ -1935,47 +1945,159 @@ def __init__(
         dropout,
         bidirectional,
         is_train,
+        w_ih,
+        w_hh,
+        b_ih,
+        b_hh,
     ):
         super(rnnbase, self).__init__()
+        if not 0 <= dropout < 1:
+            raise ValueError("dropout should be a number in range [0, 1).")
+        if dropout > 0 and num_layers == 1:
+            raise ValueError(
+                "dropout option adds dropout after all but last "
+                "recurrent layer, so non-zero dropout expects "
+                "num_layers greater than 1, but got dropout={} and "
+                "num_layers={}".format(dropout, num_layers)
+            )
         self.mode = mode
+        self.reverse = P.ReverseV2([0])
+        self.reverse_sequence = P.ReverseSequence(0, 1)
         self.input_size = input_size
         self.hidden_size = hidden_size
         self.num_layers = num_layers
-        self.bidirect = 2 if bidirectional else 1
+        self.dropout = dropout
+        self.dropout_op = ms.nn.Dropout(float(1 - dropout))
+        self.has_bias = bias
+        self.bidirectional = bidirectional
         self.batch_first = batch_first
-        if mode == 'LSTM':
-            self.lstm = ms.nn.LSTM(
-                input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, has_bias=bias,
-                batch_first=batch_first, dropout=dropout, bidirectional=bidirectional
-            )
-        elif mode == 'GRU':
-
-            raise NotImplementedError
-
-        elif mode == 'RNN_TANH':
-
-            raise NotImplementedError
-
-        elif mode == 'RNN_RELU':
-
-            raise NotImplementedError
+        self.train = is_train
+        self.w_ih_list = ParameterTuple(w_ih)
+        self.w_hh_list = ParameterTuple(w_hh)
+        self.b_ih_list = ParameterTuple(b_ih)
+        self.b_hh_list = ParameterTuple(b_hh)
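+        # _DynamicRNN (mindspore.nn.layer.rnns) runs one direction of a single recurrent layer
+        # over the whole sequence for the configured mode ('RNN_TANH', 'RNN_RELU', 'GRU' or 'LSTM').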
+        self.rnn = _DynamicRNN(mode)
+        self.is_lstm = mode == "LSTM"

         self.zeros = P.Zeros()

-    def construct(self, input, states):
-        input_shape = input.shape
-        input_dtype = input.dtype
-        if self.mode == 'LSTM':
-            if self.batch_first:
-                batch_size = input_shape[0]
+    def _stacked_bi_dynamic_rnn(self, x, h, seq_length):
+        """stacked bidirectional dynamic_rnn"""
+        pre_layer = x
+        h_n = ()
+        c_n = ()
+        output = 0
+        for i in range(self.num_layers):
+            offset = i * 2
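+            # Parameters are laid out as [layer0_fwd, layer0_bwd, layer1_fwd, layer1_bwd, ...],
+            # so layer i uses index 2*i for the forward cell and 2*i + 1 for the backward cell.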
+            if self.has_bias:
+                w_f_ih, w_f_hh, b_f_ih, b_f_hh = \
+                    self.w_ih_list[offset], self.w_hh_list[offset], \
+                    self.b_ih_list[offset], self.b_hh_list[offset]
+                w_b_ih, w_b_hh, b_b_ih, b_b_hh = \
+                    self.w_ih_list[offset + 1], self.w_hh_list[offset + 1], \
+                    self.b_ih_list[offset + 1], self.b_hh_list[offset + 1]
+            else:
+                w_f_ih, w_f_hh = self.w_ih_list[offset], self.w_hh_list[offset]
+                w_b_ih, w_b_hh = self.w_ih_list[offset + 1], self.w_hh_list[offset + 1]
+                b_f_ih, b_f_hh, b_b_ih, b_b_hh = None, None, None, None
+            if self.is_lstm:
+                h_f_i = (h[0][offset], h[1][offset])
+                h_b_i = (h[0][offset + 1], h[1][offset + 1])
             else:
-                batch_size = input_shape[1]
-            if states is None:
-                h = self.zeros((self.bidirect * self.num_layers, batch_size, self.hidden_size), input_dtype)
-                c = self.zeros((self.bidirect * self.num_layers, batch_size, self.hidden_size), input_dtype)
-                states = (h, c)
-            output, (h, c) = self.lstm(input, states)
-            return output, (h, c)
+                h_f_i = h[offset]
+                h_b_i = h[offset + 1]
+            if seq_length is None:
+                x_b = self.reverse(pre_layer)
+            else:
+                x_b = self.reverse_sequence(pre_layer, seq_length)
+            output_f, h_t_f = self.rnn(pre_layer, h_f_i, seq_length, w_f_ih, w_f_hh, b_f_ih, b_f_hh)
+            output_b, h_t_b = self.rnn(x_b, h_b_i, seq_length, w_b_ih, w_b_hh, b_b_ih, b_b_hh)
+            if seq_length is None:
+                output_b = self.reverse(output_b)
+            else:
+                output_b = self.reverse_sequence(output_b, seq_length)
+            output = P.Concat(2)((output_f, output_b))
+            pre_layer = self.dropout_op(output) if (self.dropout != 0 and i < self.num_layers - 1) else output
+            if self.is_lstm:
+                h_n += (
+                    h_t_f[0],
+                    h_t_b[0],
+                )
+                c_n += (
+                    h_t_f[1],
+                    h_t_b[1],
+                )
+            else:
+                h_n += (
+                    h_t_f,
+                    h_t_b,
+                )
+        if self.is_lstm:
+            h_n = P.Concat(0)(h_n)
+            c_n = P.Concat(0)(c_n)
+            h_n = h_n.view(h[0].shape)
+            c_n = c_n.view(h[1].shape)
+            return output, (h_n.view(h[0].shape), c_n.view(h[1].shape))
+        h_n = P.Concat(0)(h_n)
+        return output, h_n.view(h.shape)
+
+    def _stacked_dynamic_rnn(self, x, h, seq_length):
+        """stacked multi-layer dynamic_rnn"""
+        pre_layer = x
+        h_n = ()
+        c_n = ()
+        output = 0
+        for i in range(self.num_layers):
+            if self.has_bias:
+                w_ih, w_hh, b_ih, b_hh = self.w_ih_list[i], self.w_hh_list[i], self.b_ih_list[i], self.b_hh_list[i]
+            else:
+                w_ih, w_hh = self.w_ih_list[i], self.w_hh_list[i]
+                b_ih, b_hh = None, None
+            if self.is_lstm:
+                h_i = (h[0][i], h[1][i])
+            else:
+                h_i = h[i]
+            output, h_t = self.rnn(pre_layer, h_i, seq_length, w_ih, w_hh, b_ih, b_hh)
+            pre_layer = self.dropout_op(output) if (self.dropout != 0 and i < self.num_layers - 1) else output
+            if self.is_lstm:
+                h_n += (h_t[0],)
+                c_n += (h_t[1],)
+            else:
+                h_n += (h_t,)
+        if self.is_lstm:
+            h_n = P.Concat(0)(h_n)
+            c_n = P.Concat(0)(c_n)
+            h_n = h_n.view(h[0].shape)
+            c_n = c_n.view(h[1].shape)
+            return output, (h_n.view(h[0].shape), c_n.view(h[1].shape))
+        h_n = P.Concat(0)(h_n)
+        return output, h_n.view(h.shape)
+
+    def construct(self, x, hx=None, seq_length=None):
+        '''Run the stacked (optionally bidirectional) recurrent layers over the input sequence.'''
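+        # When hx is not given, zero initial states are built for num_layers * num_directions;
+        # batch-first input is transposed to time-major before the stacked RNN runs.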
+        max_batch_size = x.shape[0] if self.batch_first else x.shape[1]
+        num_directions = 2 if self.bidirectional else 1
+        x_dtype = x.dtype
+        if hx is not None:
+            if not self.is_lstm:
+                _check_input_dtype_same_and_valid(['x', 'hx'], [x_dtype, hx.dtype], \
+                                                  [ms.float32, ms.float16], self.cls_name)
+            else:
+                _check_input_dtype_same_and_valid(['x', 'hx[0]', 'hx[1]'], [x_dtype, hx[0].dtype, hx[1].dtype], \
+                                                  [ms.float32, ms.float16], self.cls_name)
+        else:
+            hx = _init_state((self.num_layers * num_directions, max_batch_size, self.hidden_size), x_dtype, self.is_lstm)
+        if self.batch_first:
+            x = P.Transpose()(x, (1, 0, 2))
+        if self.bidirectional:
+            x_n, hx_n = self._stacked_bi_dynamic_rnn(x, hx, seq_length)
+        else:
+            x_n, hx_n = self._stacked_dynamic_rnn(x, hx, seq_length)
+        if self.batch_first:
+            x_n = P.Transpose()(x_n, (1, 0, 2))
+        if not self.is_lstm:
+            return x_n.astype(x_dtype), hx_n.astype(x_dtype)
+        return x_n.astype(x_dtype), (hx_n[0].astype(x_dtype), hx_n[1].astype(x_dtype))


 class layernorm(Cell):