
update token counter and print trainable params
Luodian committed Sep 12, 2023
1 parent 3d6f4bc commit 5a9612f
Showing 3 changed files with 7 additions and 4 deletions.
5 changes: 4 additions & 1 deletion pipeline/mimicit_utils/mimicit_dataset.py
@@ -527,12 +527,15 @@ def process_image_text_pair(self, index):
             raise NotImplementedError(f"Error: The task {cur_train_id} is not supported!")

         all_text = self.tokenizer(
-            f"{all_texts}",
+            all_texts,
             return_tensors="pt",
             add_special_tokens=False,
             truncation=True,
             max_length=self.max_seq_len,  # for current 2k mpt/llama model, setting to 2048 causes error (2042 works)
         )
+        num_tokens = all_text["input_ids"].shape[1]
+        if num_tokens == self.max_seq_len:
+            print("The number of tokens in all_texts reaches the max_seq_len.")

         all_item = all_text["input_ids"].squeeze(0)
         all_item_mask = all_text["attention_mask"].squeeze(0)
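Dropping the f-string wrapper passes all_texts through unchanged, and the new check flags samples whose tokenized length hits the limit, which almost always means truncation. A minimal standalone sketch of the same check; the GPT-2 tokenizer and the max_seq_len value are stand-ins, not from this commit:

from transformers import AutoTokenizer

# Stand-in tokenizer and limit; the dataset uses its own tokenizer
# and self.max_seq_len. Both values here are illustrative only.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
max_seq_len = 16

all_texts = "a long instruction-response pair " * 20
all_text = tokenizer(
    all_texts,
    return_tensors="pt",
    add_special_tokens=False,
    truncation=True,
    max_length=max_seq_len,
)

# If the tokenized length equals max_length, the input was almost
# certainly cut off, so the dataset logs a warning.
num_tokens = all_text["input_ids"].shape[1]
if num_tokens == max_seq_len:
    print("The number of tokens in all_texts reaches the max_seq_len.")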
2 changes: 1 addition & 1 deletion shared_scripts/Otter_MPT7B_Train_Decoder_4K.json
@@ -17,7 +17,7 @@
"attn_config": {
"alibi": true,
"alibi_bias_max": 8,
"attn_impl": "torch",
"attn_impl": "triton",
"attn_pdrop": 0,
"attn_type": "multihead_attention",
"attn_uses_sequence_id": false,
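This switch selects the Triton attention kernel for the 4K MPT decoder. In MPT-style attn_config blocks, attn_impl usually accepts "torch", "flash", or "triton"; "triton" uses the fused Triton flash-attention kernel and needs the triton package installed at train time. A small sketch that parses the edited fragment (unrelated keys elided), just to show where the switch lives:

import json

# The attn_config fragment changed above, with other keys elided.
attn_config = json.loads("""
{
    "alibi": true,
    "alibi_bias_max": 8,
    "attn_impl": "triton",
    "attn_pdrop": 0,
    "attn_type": "multihead_attention",
    "attn_uses_sequence_id": false
}
""")

# "torch", "flash", and "triton" are the values MPT-style configs
# typically recognize here.
print(attn_config["attn_impl"])  # -> triton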
4 changes: 2 additions & 2 deletions src/otter_ai/models/otter/modeling_otter.py
@@ -842,7 +842,7 @@ def get_lang_encoder(self) -> nn.Module:

     def init_weights(self):
         # Freeze all parameters in self.model if train_vision_encoder is False or train_lang_encoder is False
-        if not ("train_full_model" in self.config.__dict__ and self.config.train_full_model is False):
+        if not ("train_full_model" in self.config.__dict__ and self.config.train_full_model is True):
             for param in self.parameters():
                 param.requires_grad = False

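This is the substantive fix: the old guard froze the model unless train_full_model was explicitly False, so requesting full-model training with True still froze every parameter. The corrected guard freezes everything unless the flag is explicitly True. A minimal sketch of the fixed logic with a stand-in config and model (both hypothetical; in the commit this runs inside init_weights over self.parameters()):

from types import SimpleNamespace

import torch.nn as nn

# Stand-in config object and tiny model for illustration only.
config = SimpleNamespace(train_full_model=False)
model = nn.Linear(4, 4)

# Freeze everything unless full-model training is explicitly requested.
if not ("train_full_model" in config.__dict__ and config.train_full_model is True):
    for param in model.parameters():
        param.requires_grad = False

print(any(p.requires_grad for p in model.parameters()))  # -> False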
@@ -881,7 +881,7 @@ def init_weights(self):
         for name, param in self.named_parameters():
             if param.requires_grad:
                 total_params += param.numel()
-                # print(f"Parameter: {name}, Size: {param.numel() / 1e6:.6f} M")
+                print(f"Parameter: {name}, Size: {param.numel() / 1e6:.6f} M")
         print(f"Total Trainable param: {total_params / 1e9:.6f} B")

     def forward(
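The second hunk un-comments the per-parameter log line, so each trainable tensor is printed alongside the running total. The same accounting as a standalone helper; print_trainable_params is a hypothetical name, not from the repo:

import torch.nn as nn

def print_trainable_params(model: nn.Module) -> None:
    # Mirrors the logging enabled in this commit: one line per trainable
    # tensor (in millions of parameters), then the total (in billions).
    total_params = 0
    for name, param in model.named_parameters():
        if param.requires_grad:
            total_params += param.numel()
            print(f"Parameter: {name}, Size: {param.numel() / 1e6:.6f} M")
    print(f"Total Trainable param: {total_params / 1e9:.6f} B")

print_trainable_params(nn.Linear(1024, 1024))
# Parameter: weight, Size: 1.048576 M
# Parameter: bias, Size: 0.001024 M
# Total Trainable param: 0.001050 B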
