Skip to content

Commit

Permalink
pass model as list to pass CI
Browse files Browse the repository at this point in the history
  • Loading branch information
aklife97 authored May 8, 2023
1 parent bb0c90c commit c0c098e
Showing 1 changed file with 2 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,7 @@ def training_step(self, dataloader_iter, batch_idx):
losses_reduced_per_micro_batch = fwd_bwd_function(
forward_step_func=self.get_forward_output_and_loss_func(),
data_iterator=self._make_data_iterator_list(dataloader_iter),
model=self.model,
model=self.model if isinstance(self.model, list) else [self.model],
num_microbatches=get_num_microbatches(),
forward_only=False,
tensor_shape=tensor_shape,
Expand Down Expand Up @@ -700,7 +700,7 @@ def validation_step(self, dataloader_iter, batch_idx):
losses_reduced_per_micro_batch = fwd_bwd_function(
forward_step_func=self.get_forward_output_and_loss_func(validation_step=True),
data_iterator=self._make_data_iterator_list(dataloader_iter),
model=self.model,
model=self.model if isinstance(self.model, list) else [self.model],
num_microbatches=get_num_microbatches(),
forward_only=True,
tensor_shape=tensor_shape,
Expand Down

0 comments on commit c0c098e

Please sign in to comment.