Skip to content

Commit

Permalink
Remove default for checkpoint argument.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 476332142
  • Loading branch information
Marvin182 authored and copybara-github committed Sep 24, 2022
1 parent bc60709 commit a093b4a
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 4 deletions.
2 changes: 1 addition & 1 deletion vmoe/evaluate/evaluator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def _create_dataset_and_expected_state(cls):
# 0 or loss[i].
sum_loss=tf.reduce_sum(loss * valid).numpy(),
rngs={})
return TfDatasetIterator(dataset), expected_eval_state
return TfDatasetIterator(dataset, checkpoint=False), expected_eval_state

def test_evaluate_dataset(self):
# Create random test dataset.
Expand Down
9 changes: 6 additions & 3 deletions vmoe/evaluate/fewshot_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,12 @@ def setUp(self):
'label': labels,
fewshot.VALID_KEY: valid,
})
self.mock_get_dataset = self.enter_context(mock.patch.object(
fewshot.vmoe.data.input_pipeline, 'get_dataset',
side_effect=lambda *a, **kw: clu.data.TfDatasetIterator(dataset)))
self.mock_get_dataset = self.enter_context(
mock.patch.object(
fewshot.vmoe.data.input_pipeline,
'get_dataset',
side_effect=lambda *a, **kw: clu.data.TfDatasetIterator( # pylint: disable=g-long-lambda
dataset, checkpoint=False)))

@classmethod
def _apply_fn(cls, variables, images, rngs=None):
Expand Down

0 comments on commit a093b4a

Please sign in to comment.