Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix parent_block.var(name) error in static mode for RNN #41162

Merged
merged 1 commit into from
Mar 31, 2022

Conversation

0x45f
Copy link
Contributor

@0x45f 0x45f commented Mar 30, 2022

PR types

Bug fixes

PR changes

Others

Describe

修复rnn在控制流中使用时,rnn调用parent_block.var(name)报错的问题。
问题描述:在静态图下网络参数都会在block0中,rnn的静态图逻辑中会调用parent_block.var去父block中找param,但是如果父block不是block0则会报错。用户提供了如下的动转静代码,改用_find_var_recursive后动转静可以正常导出:

import paddle
from paddle import nn
import paddle.tensor as tensor
import paddle.nn.functional as F
import paddle.nn.initializer as I

class LSTMCell(nn.RNNCellBase):
    """A single LSTM cell whose parameters are drawn from a Normal distribution.

    Contract matches ``paddle.nn.LSTMCell``: ``forward`` maps
    ``(inputs, (h, c))`` to ``(h_new, (h_new, c_new))``.
    NOTE(review): the ``guass_*`` parameter names are kept as-is (typo included)
    because they are part of the public signature callers may use by keyword.
    """

    def __init__(self,
                input_size: int,
                hidden_size: int,
                activation="tanh",
                weight_ih_attr=None,
                weight_hh_attr=None,
                bias_ih_attr=None,
                bias_hh_attr=None,
                guass_mean=0.0,
                guass_std=0.02,
                name=None):
        super().__init__()
        # Input-to-hidden and hidden-to-hidden weights; the leading 4x factor
        # packs the i/f/g/o gates into a single fused matmul.
        self.weight_ih = self.create_parameter(
            (4 * hidden_size, input_size),
            weight_ih_attr,
            default_initializer=I.Normal(guass_mean, guass_std))
        self.weight_hh = self.create_parameter(
            (4 * hidden_size, hidden_size),
            weight_hh_attr,
            default_initializer=I.Normal(guass_mean, guass_std))
        self.bias_ih = self.create_parameter(
            (4 * hidden_size, ),
            bias_ih_attr,
            is_bias=True,
            default_initializer=I.Normal(guass_mean, guass_std))
        self.bias_hh = self.create_parameter(
            (4 * hidden_size, ),
            bias_hh_attr,
            is_bias=True,
            default_initializer=I.Normal(guass_mean, guass_std))

        self.hidden_size = hidden_size
        self.input_size = input_size
        self.gate_activation = F.sigmoid
        supported = {'tanh': paddle.tanh, 'relu': F.relu, 'gelu': F.gelu}
        if activation not in supported:
            raise RuntimeError(f"{activation} is not supported in LSTMCell")
        self.activation = supported[activation]

    def forward(self, inputs, states=None):
        """Run one LSTM step and return ``(h, (h, c))``."""
        if states is None:
            states = self.get_initial_states(inputs, self.state_shape)
        h_prev, c_prev = states

        # Fused pre-activations: x @ W_ih^T (+ b_ih) + h @ W_hh^T (+ b_hh).
        gates = paddle.matmul(inputs, self.weight_ih, transpose_y=True)
        if self.bias_ih is not None:
            gates = gates + self.bias_ih
        gates += paddle.matmul(h_prev, self.weight_hh, transpose_y=True)
        if self.bias_hh is not None:
            gates = gates + self.bias_hh

        # Split the fused tensor back into the four gate pre-activations.
        pre_i, pre_f, pre_g, pre_o = paddle.split(
            gates, num_or_sections=4, axis=-1)

        i = self.gate_activation(pre_i)
        f = self.gate_activation(pre_f)
        g = self.activation(pre_g)
        o = self.gate_activation(pre_o)

        c = f * c_prev + i * g
        h = o * self.activation(c)
        return h, (h, c)

    @property
    def state_shape(self):
        r"""
        The `state_shape` of LSTMCell is a tuple with two shapes:
        `((hidden_size, ), (hidden_size,))`. (-1 for batch size would be
        automatically inserted into shape). These two shapes correspond
        to :math:`h_{t-1}` and :math:`c_{t-1}` separately.
        """
        return ((self.hidden_size, ), (self.hidden_size, ))

    def extra_repr(self):
        # Mirrors paddle.nn.LSTMCell's repr format.
        return '{input_size}, {hidden_size}'.format(**self.__dict__)

class Decoder(nn.Layer):
    """Minimal decoder that runs its RNN at most once, guarded by a tensor
    counter (``self.idx``) so dy2static turns the guard into a cond op —
    the repro for the ``parent_block.var(name)`` lookup error.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.cell = LSTMCell(input_size, hidden_size)
        self.rnn = nn.RNN(self.cell)
        # Start-of-sequence input: (batch=1, seq_len=1, features=2).
        self.sos = paddle.ones(shape=[1, 1, 2], dtype='float32')
        self.init_states = (paddle.zeros(shape=[1, 1, 4], dtype='float32'),
                        paddle.zeros(shape=[1, 1, 4], dtype='float32'))
        # Kept as a tensor (not a Python int) so `self.idx < 1` in forward
        # is traced as a conditional in static mode.
        self.idx = paddle.zeros(shape=[1], dtype='int32')
        # BUG FIX: nn.RNN returns (outputs, final_states); the original code
        # stored the whole tuple in self.states, which would unpack wrongly
        # when used as (h, c) states in forward.
        _, self.states = self.rnn(self.sos, self.init_states)
        self.step = 0

    def forward(self, inputs, hidden=None, cell=None):
        # Fall back to the states computed in __init__ when the caller does
        # not supply (hidden, cell).
        if hidden is None:
            states = self.states
        else:
            states = (hidden, cell)

        # BUG FIX: `outs` was previously unbound (NameError) whenever the
        # guarded branch was skipped (self.idx >= 1); default to the raw
        # inputs so the function always returns a defined value.
        outs = inputs
        if self.idx < 1:
            outs, states = self.rnn(inputs, states)
            self.idx += 1

        final_states = states
        return outs, final_states

    def export(self):
        """Convert this layer to a static-graph program via paddle.jit.

        The three InputSpecs match forward's (inputs, hidden, cell).
        """
        static_model = paddle.jit.to_static(
            self,
            input_spec=[
                paddle.static.InputSpec(
                    shape=[1, 1, 2], dtype='float32'),
                paddle.static.InputSpec(
                    shape=[1, 1, 4], dtype='float32'),
                paddle.static.InputSpec(
                    shape=[1, 1, 4], dtype='float32')
            ]
        )

        return static_model

# Build the repro model (input_size=2, hidden_size=4) and export it with
# paddle.jit — this export is what triggered the parent_block.var(name)
# error before the switch to _find_var_recursive.
model = Decoder(2, 4)

# Switch to inference mode before tracing, then save the static program.
model.eval()
static_model = model.export()
paddle.jit.save(static_model, "test_model")

Copy link
Contributor

@Aurelius84 Aurelius84 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Copy link

@iclementine iclementine left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Copy link
Contributor

@guoshengCS guoshengCS left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@Aurelius84 Aurelius84 merged commit a54ec5a into PaddlePaddle:develop Mar 31, 2022
@0x45f 0x45f deleted the dy2st_cond_with_rnn branch March 31, 2022 08:24
@0x45f 0x45f changed the title Fix parent_block.var(name) error in static mode Fix parent_block.var(name) error in static mode for RNN Apr 24, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants