Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SATRN.py의 코드 구현 오류 #5

Open
chjin0725 opened this issue May 29, 2021 · 1 comment
Open

SATRN.py의 코드 구현 오류 #5

chjin0725 opened this issue May 29, 2021 · 1 comment

Comments

@chjin0725
Copy link
Contributor

chjin0725 commented May 29, 2021

SATRN.py의 TransformerDecoderLayer의 forward()부분에서

att = self.attention_layer(tgt, src, src)

이 부분의 tgt를 out으로 바꿔야 할 것 같습니다.
한번 바꿔보고 실험 해보겠습니다.

<전체 코드>

def forward(self, tgt, tgt_prev, src, tgt_mask):
        if tgt_prev == None:  # Train
            att = self.self_attention_layer(tgt, tgt, tgt, tgt_mask)
            out = self.self_attention_norm(att + tgt)

            att = self.attention_layer(tgt, src, src) <- 문제가 있어 보이는 부분
            out = self.attention_norm(att + out)

            ff = self.feedforward_layer(out)
            out = self.feedforward_norm(ff + out)
        else:
            tgt_prev = torch.cat([tgt_prev, tgt], 1)
            att = self.self_attention_layer(tgt, tgt_prev, tgt_prev, tgt_mask)
            out = self.self_attention_norm(att + tgt)

            att = self.attention_layer(tgt, src, src) <- 문제가 있어 보이는 부분
            out = self.attention_norm(att + out)

            ff = self.feedforward_layer(out)
            out = self.feedforward_norm(ff + out)
        return out

기존의 Transfomer의 decoder에서는 masked self attention을 한 결과로 얻은 output으로 encoder-decoder attention을 합니다.
근데 여기서는 masked self attention과 encoder-decoder attention을 따로 하고 그 결과들을 더해서 attention_norm을 하고 있습니다.
이렇게 해도 self attention으로 얻은 정보와 encdoer의 input에 대한 정보를 둘 다 사용할 수는 있겠으나 encoder-decoder attention을 할 때 self attention으로 얻은 정보를 활용하지 못하고 있으므로 이를 수정해 주면 더 좋은 성능을 기대할 수 있을 것 같습니다.
아래 그림의 왼쪽이 기존 코드에서의 동작을 나타내고 있고 오른쪽이 수정 후의 동작입니다.
image

@chjin0725
Copy link
Contributor Author

image

  • v1 : 기존 코드, v2 : 수정한 코드
  • validation sentence accuracy가 0.3정도 증가.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant