Commit 9a267a0

Merge branch 'main' of github.com:tensorlayer/TensorLayerX into main
2 parents 5a055c7 + 4ecb07b commit 9a267a0

8 files changed (+517, -390 lines)


tensorlayerx/backend/ops/mindspore_nn.py

Lines changed: 168 additions & 46 deletions
@@ -6,6 +6,7 @@
 import mindspore as ms
 import mindspore.ops as P
 from mindspore import context
+from mindspore.ops.primitive import constexpr
 from mindspore.nn.cell import Cell
 from mindspore._checkparam import Rel
 from mindspore.ops import functional as F
@@ -17,6 +18,8 @@
 from mindspore.communication.management import get_group_size, get_rank
 from mindspore.ops.operations import LayerNorm
 import mindspore.numpy as np
+from mindspore.common.parameter import ParameterTuple
+from mindspore.nn.layer.rnns import _DynamicRNN
 import warnings
 import math

@@ -833,7 +836,9 @@ def __init__(self, ksize, strides, padding, data_format=None):
         self.data_format, self.padding = preprocess_2d_format(data_format=data_format, padding=padding)
         ms_ksize = ksize[1]
         ms_strides = strides[1]
-        self.avgpool = P.AvgPool(kernel_size=ms_ksize, strides=ms_strides, pad_mode=padding, data_format=self.data_format)
+        self.avgpool = P.AvgPool(
+            kernel_size=ms_ksize, strides=ms_strides, pad_mode=padding, data_format=self.data_format
+        )

     def construct(self, inputs):
         outputs = self.avgpool(inputs)
@@ -930,7 +935,7 @@ def __init__(self, ksize, strides, padding, data_format='NCDHW'):
         if data_format == 'NCDHW':
             strides = (strides[2], strides[3], strides[4])
         print(ksize, strides, padding)
-        self.avg_pool = P.AvgPool3D(kernel_size=ksize, strides = strides, pad_mode=padding, data_format=data_format)
+        self.avg_pool = P.AvgPool3D(kernel_size=ksize, strides=strides, pad_mode=padding, data_format=data_format)

     def __call__(self, inputs):
         return self.avg_pool(inputs)
@@ -1838,15 +1843,12 @@ def __init__(self, weight_ih, weight_hh, bias_ih, bias_hh, act):
         self.bias_ih = bias_ih
         self.bias_hh = bias_hh
         self.act_fn = P.ReLU() if act == 'relu' else P.Tanh()
-        self.transpose = P.Transpose()

     def construct(self, input, h):
-        self.weight_ih = self.transpose(self.weight_ih, (1, 0))
-        i2h = P.matmul(input, self.weight_ih)
+        i2h = P.MatMul(False, True)(input, self.weight_ih)
         if self.bias_ih is not None:
             i2h += self.bias_ih
-        self.weight_hh = self.transpose(self.weight_hh, (1, 0))
-        h2h = P.matmul(h, self.weight_hh)
+        h2h = P.MatMul(False, True)(h, self.weight_hh)
         if self.bias_hh is not None:
             h2h += self.bias_hh
         h = self.act_fn(i2h + h2h)
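Note on this and the two following cell hunks: P.MatMul takes its transpose flags at construction, so P.MatMul(False, True) multiplies by the transposed weight inside the kernel rather than overwriting self.weight_ih with a separate P.Transpose call on every step. A minimal sketch of the pattern, not part of the commit; the shapes are illustrative assumptions:

import mindspore as ms
import mindspore.numpy as mnp
import mindspore.ops as P

x = mnp.ones((4, 8), ms.float32)      # (batch, input_size), illustrative sizes
w_ih = mnp.ones((16, 8), ms.float32)  # (hidden_size, input_size) layout, as in the cells above

# transpose_b=True applies w_ih.T inside the matmul, so the weight
# parameter itself is never reassigned inside construct().
i2h = P.MatMul(False, True)(x, w_ih)  # result shape (4, 16)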
@@ -1863,17 +1865,14 @@ def __init__(self, weight_ih, weight_hh, bias_ih, bias_hh):
         self.bias_hh = bias_hh
         self.gate_act_fn = P.Sigmoid()
         self.act_fn = P.Tanh()
-        self.transpose = P.Transpose()
         self.split = P.Split(axis=-1, output_num=4)

     def construct(self, input, h, c):

-        self.weight_ih = self.transpose(self.weight_ih, (1, 0))
-        gates = P.matmul(input, self.weight_ih)
+        gates = P.MatMul(False, True)(input, self.weight_ih)
         if self.bias_ih is not None:
             gates += self.bias_ih
-        self.weight_hh = self.transpose(self.weight_hh, (1, 0))
-        gates += P.matmul(h, self.weight_hh)
+        gates += P.MatMul(False, True)(h, self.weight_hh)
         if self.bias_hh is not None:
             gates += self.bias_hh

@@ -1902,12 +1901,10 @@ def __init__(self, weight_ih, weight_hh, bias_ih, bias_hh):

     def construct(self, input, h):

-        self.weight_ih = self.transpose(self.weight_ih, (1, 0))
-        x_gates = P.matmul(input, self.weight_ih)
+        x_gates = P.MatMul(False, True)(input, self.weight_ih)
         if self.bias_ih is not None:
             x_gates += self.bias_ih
-        self.weight_hh = self.transpose(self.weight_hh, (1, 0))
-        h_gates = P.matmul(h, self.weight_hh)
+        h_gates = P.MatMul(False, True)(h, self.weight_hh)
         if self.bias_hh is not None:
             h_gates += self.bias_hh

@@ -1922,6 +1919,19 @@ def construct(self, input, h):
         return h, h


+@constexpr
+def _init_state(shape, dtype, is_lstm):
+    hx = ms.Tensor(np.zeros(shape), dtype)
+    cx = ms.Tensor(np.zeros(shape), dtype)
+    if is_lstm:
+        return (hx, cx)
+    return hx
+
+@constexpr
+def _check_input_dtype_same_and_valid(args_name, args_value, valid_values, cls_name):
+    args = {args_name[i]: args_value[i] for i in range(len(args_value))}
+    validator.check_types_same_and_valid(args, valid_values, cls_name)
+
 class rnnbase(Cell):

     def __init__(
@@ -1935,47 +1945,159 @@ def __init__(
         dropout,
         bidirectional,
         is_train,
+        w_ih,
+        w_hh,
+        b_ih,
+        b_hh,
     ):
         super(rnnbase, self).__init__()
+        if not 0 <= dropout < 1:
+            raise ValueError("dropout should be a number in range [0, 1).")
+        if dropout > 0 and num_layers == 1:
+            raise ValueError(
+                "dropout option adds dropout after all but last "
+                "recurrent layer, so non-zero dropout expects "
+                "num_layers greater than 1, but got dropout={} and "
+                "num_layers={}".format(dropout, num_layers)
+            )
         self.mode = mode
+        self.reverse = P.ReverseV2([0])
+        self.reverse_sequence = P.ReverseSequence(0, 1)
         self.input_size = input_size
         self.hidden_size = hidden_size
         self.num_layers = num_layers
-        self.bidirect = 2 if bidirectional else 1
+        self.dropout = dropout
+        self.dropout_op = ms.nn.Dropout(float(1 - dropout))
+        self.has_bias = bias
+        self.bidirectional = bidirectional
         self.batch_first = batch_first
-        if mode == 'LSTM':
-            self.lstm = ms.nn.LSTM(
-                input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, has_bias=bias,
-                batch_first=batch_first, dropout=dropout, bidirectional=bidirectional
-            )
-        elif mode == 'GRU':
-
-            raise NotImplementedError
-
-        elif mode == 'RNN_TANH':
-
-            raise NotImplementedError
-
-        elif mode == 'RNN_RELU':
-
-            raise NotImplementedError
+        self.train = is_train
+        self.w_ih_list = ParameterTuple(w_ih)
+        self.w_hh_list = ParameterTuple(w_hh)
+        self.b_ih_list = ParameterTuple(b_ih)
+        self.b_hh_list = ParameterTuple(b_hh)
+        self.rnn = _DynamicRNN(mode)
+        self.is_lstm = mode == "LSTM"

         self.zeros = P.Zeros()

-    def construct(self, input, states):
-        input_shape = input.shape
-        input_dtype = input.dtype
-        if self.mode == 'LSTM':
-            if self.batch_first:
-                batch_size = input_shape[0]
+    def _stacked_bi_dynamic_rnn(self, x, h, seq_length):
+        """stacked bidirectional dynamic_rnn"""
+        pre_layer = x
+        h_n = ()
+        c_n = ()
+        output = 0
+        for i in range(self.num_layers):
+            offset = i * 2
+            if self.has_bias:
+                w_f_ih, w_f_hh, b_f_ih, b_f_hh = \
+                    self.w_ih_list[offset], self.w_hh_list[offset], \
+                    self.b_ih_list[offset], self.b_hh_list[offset]
+                w_b_ih, w_b_hh, b_b_ih, b_b_hh = \
+                    self.w_ih_list[offset + 1], self.w_hh_list[offset + 1], \
+                    self.b_ih_list[offset + 1], self.b_hh_list[offset + 1]
+            else:
+                w_f_ih, w_f_hh = self.w_ih_list[offset], self.w_hh_list[offset]
+                w_b_ih, w_b_hh = self.w_ih_list[offset + 1], self.w_hh_list[offset + 1]
+                b_f_ih, b_f_hh, b_b_ih, b_b_hh = None, None, None, None
+            if self.is_lstm:
+                h_f_i = (h[0][offset], h[1][offset])
+                h_b_i = (h[0][offset + 1], h[1][offset + 1])
             else:
-                batch_size = input_shape[1]
-            if states is None:
-                h = self.zeros((self.bidirect * self.num_layers, batch_size, self.hidden_size), input_dtype)
-                c = self.zeros((self.bidirect * self.num_layers, batch_size, self.hidden_size), input_dtype)
-                states = (h, c)
-            output, (h, c) = self.lstm(input, states)
-            return output, (h, c)
+                h_f_i = h[offset]
+                h_b_i = h[offset + 1]
+            if seq_length is None:
+                x_b = self.reverse(pre_layer)
+            else:
+                x_b = self.reverse_sequence(pre_layer, seq_length)
+            output_f, h_t_f = self.rnn(pre_layer, h_f_i, seq_length, w_f_ih, w_f_hh, b_f_ih, b_f_hh)
+            output_b, h_t_b = self.rnn(x_b, h_b_i, seq_length, w_b_ih, w_b_hh, b_b_ih, b_b_hh)
+            if seq_length is None:
+                output_b = self.reverse(output_b)
+            else:
+                output_b = self.reverse_sequence(output_b, seq_length)
+            output = P.Concat(2)((output_f, output_b))
+            pre_layer = self.dropout_op(output) if (self.dropout != 0 and i < self.num_layers - 1) else output
+            if self.is_lstm:
+                h_n += (
+                    h_t_f[0],
+                    h_t_b[0],
+                )
+                c_n += (
+                    h_t_f[1],
+                    h_t_b[1],
+                )
+            else:
+                h_n += (
+                    h_t_f,
+                    h_t_b,
+                )
+        if self.is_lstm:
+            h_n = P.Concat(0)(h_n)
+            c_n = P.Concat(0)(c_n)
+            h_n = h_n.view(h[0].shape)
+            c_n = c_n.view(h[1].shape)
+            return output, (h_n.view(h[0].shape), c_n.view(h[1].shape))
+        h_n = P.Concat(0)(h_n)
+        return output, h_n.view(h.shape)
+
+    def _stacked_dynamic_rnn(self, x, h, seq_length):
+        """stacked mutil_layer dynamic_rnn"""
+        pre_layer = x
+        h_n = ()
+        c_n = ()
+        output = 0
+        for i in range(self.num_layers):
+            if self.has_bias:
+                w_ih, w_hh, b_ih, b_hh = self.w_ih_list[i], self.w_hh_list[i], self.b_ih_list[i], self.b_hh_list[i]
+            else:
+                w_ih, w_hh = self.w_ih_list[i], self.w_hh_list[i]
+                b_ih, b_hh = None, None
+            if self.is_lstm:
+                h_i = (h[0][i], h[1][i])
+            else:
+                h_i = h[i]
+            output, h_t = self.rnn(pre_layer, h_i, seq_length, w_ih, w_hh, b_ih, b_hh)
+            pre_layer = self.dropout_op(output) if (self.dropout != 0 and i < self.num_layers - 1) else output
+            if self.is_lstm:
+                h_n += (h_t[0], )
+                c_n += (h_t[1], )
+            else:
+                h_n += (h_t, )
+        if self.is_lstm:
+            h_n = P.Concat(0)(h_n)
+            c_n = P.Concat(0)(c_n)
+            h_n = h_n.view(h[0].shape)
+            c_n = c_n.view(h[1].shape)
+            return output, (h_n.view(h[0].shape), c_n.view(h[1].shape))
+        h_n = P.Concat(0)(h_n)
+        return output, h_n.view(h.shape)
+
+    def construct(self, x, hx=None, seq_length=None):
+        '''Defines the RNN like operators performed'''
+        max_batch_size = x.shape[0] if self.batch_first else x.shape[1]
+        num_directions = 2 if self.bidirectional else 1
+        x_dtype = x.dtype
+        if hx is not None:
+            if not self.is_lstm:
+                _check_input_dtype_same_and_valid(['x', 'hx'], [x_dtype, hx.dtype], \
+                    [ms.float32, ms.float16], self.cls_name)
+            else:
+                _check_input_dtype_same_and_valid(['x', 'hx[0]', 'hx[1]'], [x_dtype, hx[0].dtype, hx[1].dtype], \
+                    [ms.float32, ms.float16], self.cls_name)
+        else:
+            hx = _init_state((self.num_layers * num_directions, max_batch_size, self.hidden_size), x_dtype, self.is_lstm)
+        if self.batch_first:
+            x = P.Transpose()(x, (1, 0, 2))
+        if self.bidirectional:
+            x_n, hx_n = self._stacked_bi_dynamic_rnn(x, hx, seq_length)
+        else:
+            x_n, hx_n = self._stacked_dynamic_rnn(x, hx, seq_length)
+        if self.batch_first:
+            x_n = P.Transpose()(x_n, (1, 0, 2))
+        if not self.is_lstm:
+            return x_n.astype(x_dtype), hx_n.astype(x_dtype)
+        return x_n.astype(x_dtype), (hx_n[0].astype(x_dtype), hx_n[1].astype(x_dtype))


 class layernorm(Cell):
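For readers of the new rnnbase.construct path above: when hx is None, the @constexpr helper _init_state builds zero-filled states at graph-compile time, shaped (num_layers * num_directions, batch, hidden_size). A minimal sketch of that default-state shape, not part of the commit; the sizes are illustrative assumptions:

import mindspore as ms
import mindspore.numpy as mnp

num_layers, bidirectional = 2, True
max_batch_size, hidden_size = 4, 16
num_directions = 2 if bidirectional else 1

# Same shape that construct() passes to _init_state when no state is supplied.
shape = (num_layers * num_directions, max_batch_size, hidden_size)
hx = mnp.zeros(shape, ms.float32)
print(hx.shape)  # (4, 4, 16)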

0 commit comments
