Skip to content

Commit

Permalink
tmp
Browse files Browse the repository at this point in the history
  • Loading branch information
NihalHarish committed Jun 5, 2020
1 parent d867a9b commit 5fd3a74
Showing 1 changed file with 24 additions and 20 deletions.
44 changes: 24 additions & 20 deletions smdebug/tensorflow/keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,66 +352,70 @@ def _save_inputs(self, check_before_write=True):
# TODO
pass

def _add_metric(self, metric_name, metric_value: tf.Tensor = None):
if metric_name in self.tensor_to_collections:
def _save_tensor(self, t_name: str, t_value: tf.Tensor, collection: CollectionKeys):
if t_name in self.tensor_to_collections:
return
coll = self.collection_manager.get(collection)
if isinstance(t_value, tf.Tensor):
coll.set_tensor_ref(TensorRef.from_non_graph_var(t_name))
else:
coll.set_tensor_ref(TensorRef.from_non_graph_var(t_name))
self.tensor_to_collections[t_name] = {coll}
self._initialize_writers(only_initialize_if_missing=True)
self._save_for_tensor(t_name, t_value, check_before_write=False)

def _add_metric(self, metric_name, metric_value: tf.Tensor = None):
if metric_name in ["loss", "val_loss"]:
coll_name = CollectionKeys.LOSSES
else:
coll_name = CollectionKeys.METRICS
coll = self.collection_manager.get(coll_name)
if metric_value:
coll.set_tensor_ref(metric_value, metric_name)
else:
coll.set_tensor_ref(TensorRef.from_non_graph_var(metric_name))
self.tensor_to_collections[metric_name] = {coll}
self._save_tensor(metric_name, metric_value, coll_name)

def _save_model_outputs(self, logs):
if logs is None:
return

if self._is_collection_being_saved_for_step(CollectionKeys.OUTPUTS):
export_names = {
ModelOutput.Y_PRED: "train_output/y_pred",
ModelOutput.Y: "train_output/y",
}
self._initialize_writers(only_initialize_if_missing=True)
output_collection = self.collection_manager.get(CollectionKeys.OUTPUTS)
for key in logs:
if key in [ModelOutput.Y, ModelOutput.Y_PRED]:
tensor_ref = TensorRef.from_non_graph_var(export_names[key])
output_collection.set_tensor_ref(tensor_ref)
self.tensor_to_collections[export_names[key]] = {output_collection}
self._save_for_tensor(export_names[key], logs[key], check_before_write=False)
self._save_tensor(export_names[key], logs[key], CollectionKeys.OUTPUTS)

def _save_metrics(self, batch, logs, force_save=False):
# if force_save is True, doesn't check whether collection needs to be saved for steps
if logs is None:
return

if force_save or self._is_collection_being_saved_for_step(CollectionKeys.METRICS):
self._initialize_writers(only_initialize_if_missing=True)
logs["batch"] = batch
for key in logs:
if key in ["loss", "val_loss", "outputs", ModelOutput.Y, ModelOutput.Y_PRED]:
# outputs is saved differently through outputs collection
continue
self._add_metric(metric_name=key)
self._save_for_tensor(key, logs[key], check_before_write=False)
self._add_metric(metric_name=key, metric_value=logs[key])

if force_save or self._is_collection_being_saved_for_step(CollectionKeys.LOSSES):
self._initialize_writers(only_initialize_if_missing=True)
for key in ["loss", "val_loss"]:
if key in logs:
self._add_metric(metric_name=key)
self._save_for_tensor(key, logs[key], check_before_write=False)
self._add_metric(metric_name=key, metric_value=logs[key])

def _save_gradients(self, logs):
if logs is None:
return

if self._is_collection_being_saved_for_step(CollectionKeys.METRICS):
if "gradients" in logs:
tensor_ref = TensorRef.from_non_graph_var(export_names[key])

def _save_tensors_post_step(self, batch, logs):
# some tensors available as value from within hook are saved here
# weights, metrics
self._save_metrics(batch, logs)
self._save_model_outputs(logs)
self._save_gradients(logs)

if is_tf_version_2x() and tf.executing_eagerly():
for tensor_ref in self.tensor_refs_to_save_this_step:
Expand Down

0 comments on commit 5fd3a74

Please sign in to comment.