Skip to content

Commit

Permalink
Update to use [IMG]
Browse files Browse the repository at this point in the history
  • Loading branch information
kohjingyu committed Jul 4, 2023
1 parent 3d9161b commit 53fdcf2
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ def main_worker(gpu, ngpus_per_node, args):
model_args.retrieval_token_idx.append(ret_token_idx[0])
args.retrieval_token_idx.append(ret_token_idx[0])

# Add [GEN] tokens to the vocabulary.
# Add [IMG] tokens to the vocabulary.
model_args.gen_token_idx = model_args.retrieval_token_idx
args.gen_token_idx = args.retrieval_token_idx

Expand Down Expand Up @@ -575,7 +575,7 @@ def train(train_loader, model, tokenizer, criterion, optimizer, epoch, scheduler

# Update weights
if ((i + 1) % args.grad_accumulation_steps == 0) or (i == args.steps_per_epoch - 1):
# Zero out gradients of the embedding matrix outside of [GEN].
# Zero out gradients of the embedding matrix outside of [IMG].
for param in model.module.model.input_embeddings.parameters():
assert param.grad.shape[0] == len(tokenizer)
# Keep other embeddings frozen.
Expand Down

0 comments on commit 53fdcf2

Please sign in to comment.