Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

train.py의 id_to_string method 오류 #7

Open
jhj9109 opened this issue Jun 4, 2021 · 0 comments
Open

train.py의 id_to_string method 오류 #7

jhj9109 opened this issue Jun 4, 2021 · 0 comments

Comments

@jhj9109
Copy link
Member

jhj9109 commented Jun 4, 2021

토론 글을 통해서 제시된 오류.
Score를 계산하는 변환 시에 <EOS>를 만나면 종료되는 코드 추가하여 해결

원본 코드

def id_to_string(tokens, data_loader,do_eval=0):
    result = []
    if do_eval:
        special_ids = [data_loader.dataset.token_to_id["<PAD>"], data_loader.dataset.token_to_id["<SOS>"],
                       data_loader.dataset.token_to_id["<EOS>"]]

    for example in tokens:
        string = ""
        if do_eval:
            for token in example:
                token = token.item()
                if token not in special_ids:
                    if token != -1:
                        string += data_loader.dataset.id_to_token[token] + " "
        else:
            for token in example:
                token = token.item()
                if token != -1:
                    string += data_loader.dataset.id_to_token[token] + " "

문제점

<EOS> 이후에도 스페셜 토큰이 아닌 어떤 값(id)이 결과로 남아있을때, 불필요하게 이를 고려하여 score에 영향을 끼치게 된다.

e.g. teacher forcing이 적용될 때. <EOS>를 예측했지만, 올바른 출력이 아니라면
=> 올바른 출력이 다시 입력으로 들어가 예측에 활용되어 <EOS> 이후에도 맞지 않은 토큰 값들이 출력된다.
e.g.
gt: <SOS> { 1 } + { 2 } = { 3 } <EOS> <PAD> <PAD> ...
pred: <SOS> { 1 } <EOS> { 2 } = { 3 } <EOS> <PAD> ...

  • 첫번째 <EOS> 예측 이후 예측된 { 2 } = { 3 } 을 score 계산에 반영하는 것은 올바르지 않다고 판단된다.

수정 코드

def id_to_string(tokens, data_loader,do_eval=0): # 0 Preds 1 -1 -1....
    result = []
    if do_eval:
        eos_id =  data_loader.dataset.token_to_id["<EOS>"]
        pad_id = data_loader.dataset.token_to_id["<PAD>"]
        sos_id = data_loader.dataset.token_to_id["<SOS>"]
        pad_id2 = -1
        ignore_ids = {
            pad_id : 1,
            sos_id : 1,
            pad_id2 : 1,
        }
    for example in tokens:
        string = ""
        if do_eval:  # 계산 용도 => score 와 관련이 있다.
            for token in example:
                token = token.item()
                if token == eos_id: # <EOS>만나면 종료한다.
                    break
                if token not in ignore_ids: # eos 외 무시할 id들을 체크한다.
                    string += data_loader.dataset.id_to_token[token] + " "
        else: # display 용도.
            for token in example:
                token = token.item()
                if token != -1: # 길이 채우기 위한 -1만 무시한다.
                    string += data_loader.dataset.id_to_token[token] + " "

        result.append(string)
    return result
  • Hotfix - eos_id ~ ignore_ids 선언하는 부분 잘 못 된 코드 수정
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant