Skip to content

Commit

Permalink
Make steps_per_execution parameters settable.
Browse files Browse the repository at this point in the history
In the case of trying to tune with a custom `steps_per_execution` initial heuristic, it is helpful to be able to set to a certain value.

PiperOrigin-RevId: 539729291
  • Loading branch information
grasskin authored and tensorflower-gardener committed Jun 22, 2023
1 parent 5f9d052 commit b15a522
Show file tree
Hide file tree
Showing 17 changed files with 137 additions and 18 deletions.
4 changes: 4 additions & 0 deletions keras/api/golden/v1/tensorflow.keras.-model.pbtxt
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,10 @@ tf_class {
name: "stateful"
mtype: "<type \'property\'>"
}
member {
name: "steps_per_execution"
mtype: "<type \'property\'>"
}
member {
name: "submodules"
mtype: "<type \'property\'>"
Expand Down
4 changes: 4 additions & 0 deletions keras/api/golden/v1/tensorflow.keras.-sequential.pbtxt
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,10 @@ tf_class {
name: "stateful"
mtype: "<type \'property\'>"
}
member {
name: "steps_per_execution"
mtype: "<type \'property\'>"
}
member {
name: "submodules"
mtype: "<type \'property\'>"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,10 @@ tf_class {
name: "stateful"
mtype: "<type \'property\'>"
}
member {
name: "steps_per_execution"
mtype: "<type \'property\'>"
}
member {
name: "submodules"
mtype: "<type \'property\'>"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,10 @@ tf_class {
name: "stateful"
mtype: "<type \'property\'>"
}
member {
name: "steps_per_execution"
mtype: "<type \'property\'>"
}
member {
name: "submodules"
mtype: "<type \'property\'>"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,10 @@ tf_class {
name: "stateful"
mtype: "<type \'property\'>"
}
member {
name: "steps_per_execution"
mtype: "<type \'property\'>"
}
member {
name: "submodules"
mtype: "<type \'property\'>"
Expand Down
4 changes: 4 additions & 0 deletions keras/api/golden/v1/tensorflow.keras.models.-model.pbtxt
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,10 @@ tf_class {
name: "stateful"
mtype: "<type \'property\'>"
}
member {
name: "steps_per_execution"
mtype: "<type \'property\'>"
}
member {
name: "submodules"
mtype: "<type \'property\'>"
Expand Down
4 changes: 4 additions & 0 deletions keras/api/golden/v1/tensorflow.keras.models.-sequential.pbtxt
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,10 @@ tf_class {
name: "stateful"
mtype: "<type \'property\'>"
}
member {
name: "steps_per_execution"
mtype: "<type \'property\'>"
}
member {
name: "submodules"
mtype: "<type \'property\'>"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,10 @@ tf_class {
name: "stateful"
mtype: "<type \'property\'>"
}
member {
name: "steps_per_execution"
mtype: "<type \'property\'>"
}
member {
name: "submodules"
mtype: "<type \'property\'>"
Expand Down
4 changes: 4 additions & 0 deletions keras/api/golden/v2/tensorflow.keras.-model.pbtxt
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,10 @@ tf_class {
name: "stateful"
mtype: "<type \'property\'>"
}
member {
name: "steps_per_execution"
mtype: "<type \'property\'>"
}
member {
name: "submodules"
mtype: "<type \'property\'>"
Expand Down
4 changes: 4 additions & 0 deletions keras/api/golden/v2/tensorflow.keras.-sequential.pbtxt
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,10 @@ tf_class {
name: "stateful"
mtype: "<type \'property\'>"
}
member {
name: "steps_per_execution"
mtype: "<type \'property\'>"
}
member {
name: "submodules"
mtype: "<type \'property\'>"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,10 @@ tf_class {
name: "stateful"
mtype: "<type \'property\'>"
}
member {
name: "steps_per_execution"
mtype: "<type \'property\'>"
}
member {
name: "submodules"
mtype: "<type \'property\'>"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,10 @@ tf_class {
name: "stateful"
mtype: "<type \'property\'>"
}
member {
name: "steps_per_execution"
mtype: "<type \'property\'>"
}
member {
name: "submodules"
mtype: "<type \'property\'>"
Expand Down
4 changes: 4 additions & 0 deletions keras/api/golden/v2/tensorflow.keras.models.-model.pbtxt
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,10 @@ tf_class {
name: "stateful"
mtype: "<type \'property\'>"
}
member {
name: "steps_per_execution"
mtype: "<type \'property\'>"
}
member {
name: "submodules"
mtype: "<type \'property\'>"
Expand Down
4 changes: 4 additions & 0 deletions keras/api/golden/v2/tensorflow.keras.models.-sequential.pbtxt
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,10 @@ tf_class {
name: "stateful"
mtype: "<type \'property\'>"
}
member {
name: "steps_per_execution"
mtype: "<type \'property\'>"
}
member {
name: "submodules"
mtype: "<type \'property\'>"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,10 @@ tf_class {
name: "stateful"
mtype: "<type \'property\'>"
}
member {
name: "steps_per_execution"
mtype: "<type \'property\'>"
}
member {
name: "submodules"
mtype: "<type \'property\'>"
Expand Down
56 changes: 40 additions & 16 deletions keras/engine/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,8 @@ def __init__(self, *args, **kwargs):
self._checkpoint = tf.train.Checkpoint(root=weakref.ref(self))

self._steps_per_execution = None
self._enable_tune_steps_per_execution = False
self._steps_per_execution_tuner = None
self._autotune_steps_per_execution = False

self._layout_map = layout_map_lib.get_current_layout_map()

Expand Down Expand Up @@ -803,12 +804,14 @@ def compile(
)

if steps_per_execution == "auto":
self._configure_steps_per_execution(1)
if self._steps_per_execution is None:
self._configure_steps_per_execution(1)
self._steps_per_execution_tuner = (
steps_per_execution_tuning.StepsPerExecutionTuner(
self.optimizer, self._steps_per_execution
)
)
self._autotune_steps_per_execution = True
else:
self._configure_steps_per_execution(steps_per_execution or 1)

Expand Down Expand Up @@ -1006,12 +1009,33 @@ def run_eagerly(self, value):
self._run_eagerly = value

@property
def enable_tune_steps_per_execution(self):
return self._enable_tune_steps_per_execution
def autotune_steps_per_execution(self):
"""Settable property to enable tuning for steps_per_execution"""
return self._autotune_steps_per_execution

@autotune_steps_per_execution.setter
def autotune_steps_per_execution(self, value):
self._autotune_steps_per_execution = value
if value and self._steps_per_execution_tuner is None:
if self._steps_per_execution is None:
self._configure_steps_per_execution(1)
self._steps_per_execution_tuner = (
steps_per_execution_tuning.StepsPerExecutionTuner(
self.optimizer, self._steps_per_execution
)
)

@enable_tune_steps_per_execution.setter
def enable_tune_steps_per_execution(self, value):
self._enable_tune_steps_per_execution = value
@property
def steps_per_execution(self):
"""Settable `steps_per_execution variable. Requires a compiled model."""
return self._steps_per_execution

@steps_per_execution.setter
def steps_per_execution(self, value):
if self._steps_per_execution is None:
self._configure_steps_per_execution(value)
else:
self._steps_per_execution.assign(value)

@property
def jit_compile(self):
Expand Down Expand Up @@ -1376,7 +1400,7 @@ def run_step(data):
if (
self._steps_per_execution is None
or self._steps_per_execution.numpy().item() == 1
and not self.enable_tune_steps_per_execution
and not self.autotune_steps_per_execution
):

def train_function(iterator):
Expand Down Expand Up @@ -1759,7 +1783,7 @@ def fit(
self._train_counter.assign(0)
callbacks.on_train_begin()
training_logs = None
if self.enable_tune_steps_per_execution:
if self.autotune_steps_per_execution:
self._steps_per_execution_tuner.start()
# Handle fault-tolerance for multi-worker.
# TODO(omalleyt): Fix the ordering issues that mean this has to
Expand Down Expand Up @@ -1867,7 +1891,7 @@ def fit(
# If eval data_handler exists, delete it after all epochs are done.
if getattr(self, "_eval_data_handler", None) is not None:
del self._eval_data_handler
if self.enable_tune_steps_per_execution:
if self.autotune_steps_per_execution:
self._steps_per_execution_tuner.stop()
callbacks.on_train_end(logs=training_logs)
return self.history
Expand Down Expand Up @@ -2041,7 +2065,7 @@ def run_step(data):
if (
self._steps_per_execution is None
or self._steps_per_execution.numpy().item() == 1
and not self.enable_tune_steps_per_execution
and not self.autotune_steps_per_execution
):

def test_function(iterator):
Expand Down Expand Up @@ -2263,7 +2287,7 @@ def evaluate(
test_function_runner = self._get_test_function_runner(callbacks)
self._test_counter.assign(0)
callbacks.on_test_begin()
if self.enable_tune_steps_per_execution:
if self.autotune_steps_per_execution:
self._steps_per_execution_tuner.start()
for (
_,
Expand All @@ -2289,7 +2313,7 @@ def evaluate(
logs = self._aggregate_exact_metrics(logs)
else:
logs = self._validate_and_get_metrics_result(logs)
if self.enable_tune_steps_per_execution:
if self.autotune_steps_per_execution:
self._steps_per_execution_tuner.stop()
callbacks.on_test_end(logs=logs)

Expand Down Expand Up @@ -2415,7 +2439,7 @@ def run_step(data):
if (
self._steps_per_execution is None
or self._steps_per_execution.numpy().item() == 1
and not self.enable_tune_steps_per_execution
and not self.autotune_steps_per_execution
):

def predict_function(iterator):
Expand Down Expand Up @@ -2628,7 +2652,7 @@ def predict(
self.predict_function = self.make_predict_function()
self._predict_counter.assign(0)
callbacks.on_predict_begin()
if self.enable_tune_steps_per_execution:
if self.autotune_steps_per_execution:
self._steps_per_execution_tuner.start()
batch_outputs = None
for _, iterator in data_handler.enumerate_epochs(): # Single epoch.
Expand Down Expand Up @@ -2668,7 +2692,7 @@ def predict(
"information of where went wrong, or file a "
"issue/bug to `tf.keras`."
)
if self.enable_tune_steps_per_execution:
if self.autotune_steps_per_execution:
self._steps_per_execution_tuner.stop()
callbacks.on_predict_end()
all_outputs = tf.__internal__.nest.map_structure_up_to(
Expand Down
39 changes: 37 additions & 2 deletions keras/engine/training_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2472,9 +2472,44 @@ def test_spe_tune_compile_fit_then_false_predict(self):
x, y = np.ones((10, 1)), np.ones((10, 1))
model.fit(x, y, epochs=2)
model.evaluate(x, y)
model.enable_tune_steps_per_execution = False
model.autotune_steps_per_execution = False
model.predict(x)
assert model.enable_tune_steps_per_execution == False
assert model.autotune_steps_per_execution == False

@test_combinations.run_all_keras_modes(always_skip_v1=True)
def test_spe_tune_set_after_compile(self):
model = sequential.Sequential([layers_module.Dense(1)])
model.compile(
"sgd",
loss="mse",
run_eagerly=False,
jit_compile=True,
steps_per_execution=5,
)
x, y = np.ones((10, 1)), np.ones((10, 1))
model.fit(x, y, epochs=2)
assert model._steps_per_execution_tuner is None
model.autotune_steps_per_execution = True
model.fit(x, y, epochs=2)
assert model.steps_per_execution.numpy().item() == 5
assert model._steps_per_execution_tuner

@test_combinations.run_all_keras_modes(always_skip_v1=True)
def test_spe_tune_set_before_compile(self):
model = sequential.Sequential([layers_module.Dense(1)])
model.steps_per_execution = 5
model.compile(
"sgd",
loss="mse",
run_eagerly=False,
jit_compile=True,
steps_per_execution="auto",
)
assert model.steps_per_execution.numpy().item() == 5
assert model._steps_per_execution_tuner

x, y = np.ones((10, 1)), np.ones((10, 1))
model.fit(x, y, epochs=2)


class TestExceptionsAndWarnings(test_combinations.TestCase):
Expand Down

0 comments on commit b15a522

Please sign in to comment.