Skip to content

Commit

Permalink
Fix device of masks in tests (#27887)
Browse files Browse the repository at this point in the history
fix device of mask in tests
  • Loading branch information
fxmarty authored Dec 7, 2023
1 parent fc71e81 commit c99f254
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion tests/models/llama/test_modeling_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def prepare_config_and_inputs(self):

input_mask = None
if self.use_input_mask:
input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length))
input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length)).to(torch_device)

token_type_ids = None
if self.use_token_type_ids:
Expand Down
2 changes: 1 addition & 1 deletion tests/models/mistral/test_modeling_mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def prepare_config_and_inputs(self):

input_mask = None
if self.use_input_mask:
input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length))
input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length)).to(torch_device)

token_type_ids = None
if self.use_token_type_ids:
Expand Down
2 changes: 1 addition & 1 deletion tests/models/persimmon/test_modeling_persimmon.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def prepare_config_and_inputs(self):

input_mask = None
if self.use_input_mask:
input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length))
input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length)).to(torch_device)

token_type_ids = None
if self.use_token_type_ids:
Expand Down

0 comments on commit c99f254

Please sign in to comment.