feat: Add report_progress to TrainContext #9826

Merged · 13 commits · Aug 19, 2024
@@ -181,7 +181,7 @@ trial ID in the checkpoint and use it to distinguish the two types of continues.
.. literalinclude:: ../../../../examples/tutorials/core_api/2_checkpoints.py
:language: python
:start-at: def main
:end-at: for batch in range(starting_batch, 100)
:end-at: for batch in range(starting_batch, max_length)

#. You can checkpoint your model as frequently as you like. For this exercise, save a checkpoint
after each training report, and check for a preemption signal after each checkpoint:
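A minimal sketch of that save-then-check-preemption pattern, assuming the usual core_context from det.core.init() (save_state and the metric value are placeholders, not part of this PR):

def train_with_checkpoints(core_context, starting_batch, max_length, save_state):
    for batch in range(starting_batch, max_length):
        steps_completed = batch + 1
        core_context.train.report_training_metrics(
            steps_completed=steps_completed, metrics={"loss": 0.0}
        )
        # Save a checkpoint after each training report.
        with core_context.checkpoint.store_path({"steps_completed": steps_completed}) as (
            path,
            storage_id,
        ):
            save_state(path, steps_completed)
        # Check for a preemption signal after each checkpoint.
        if core_context.preempt.should_preempt():
            return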
5 changes: 3 additions & 2 deletions examples/features/unmanaged/1_singleton.py
@@ -16,8 +16,8 @@ def main():
# project="...",
),
)

for i in range(100):
max_length = 100
for i in range(max_length):
print(f"training loss: {random.random()}")

core_v2.train.report_training_metrics(steps_completed=i, metrics={"loss": random.random()})
@@ -28,6 +28,7 @@ def main():
core_v2.train.report_validation_metrics(
steps_completed=i, metrics={"loss": random.random()}
)
core_v2.train.report_progress(i / float(max_length))

core_v2.close()

4 changes: 3 additions & 1 deletion examples/features/unmanaged/2_checkpoints.py
@@ -21,6 +21,7 @@ def main():

latest_checkpoint = core_v2.info.latest_checkpoint
initial_i = 0
max_length = 100
if latest_checkpoint is not None:
with core_v2.checkpoint.restore_path(latest_checkpoint) as path:
with (path / "state").open() as fin:
@@ -32,11 +33,12 @@

print("determined experiment id: ", core_v2.info._trial_info.experiment_id)
print("initial step:", initial_i)
for i in range(initial_i, initial_i + 100):
for i in range(initial_i, initial_i + max_length):
core_v2.train.report_training_metrics(steps_completed=i, metrics={"loss": random.random()})
if (i + 1) % 10 == 0:
loss = random.random()
core_v2.train.report_validation_metrics(steps_completed=i, metrics={"loss": loss})
core_v2.train.report_progress((i - initial_i) / float(max_length))

with core_v2.checkpoint.store_path({"steps_completed": i}) as (path, uuid):
with (path / "state").open("w") as fout:
5 changes: 4 additions & 1 deletion examples/tutorials/core_api/1_metrics.py
Contributor:

  1. include this change in 2_checkpoints.py, too. they're meant to be incremental tutorials, hence the "# NEW: ..."
  2. also update detached mode tutorials (and make sure this works in detached mode. it should, but just in case)

Contributor Author:

I don't think this function would work with detached mode out of the box: in the existing method, we retrieve the experiment from experiment.ExperimentRegistry, but we currently don't include unmanaged experiments in ExperimentRegistry. So we can either include unmanaged experiments in ExperimentRegistry, or retrieve the unmanaged experiment from the DB instead.

@@ -14,7 +14,8 @@

def main(core_context, increment_by):
x = 0
for batch in range(100):
max_length = 100
for batch in range(max_length):
x += increment_by
steps_completed = batch + 1
time.sleep(0.1)
Expand All @@ -24,6 +25,8 @@ def main(core_context, increment_by):
core_context.train.report_training_metrics(
steps_completed=steps_completed, metrics={"x": x}
)
# NEW: report training progress.
core_context.train.report_progress(steps_completed / float(max_length))
# NEW: report a "validation" metric at the end.
core_context.train.report_validation_metrics(steps_completed=steps_completed, metrics={"x": x})

5 changes: 3 additions & 2 deletions examples/tutorials/core_api/2_checkpoints.py
@@ -41,14 +41,14 @@ def load_state(trial_id, checkpoint_directory):

def main(core_context, latest_checkpoint, trial_id, increment_by):
x = 0

max_length = 100
# NEW: load a checkpoint if one was provided.
starting_batch = 0
if latest_checkpoint is not None:
with core_context.checkpoint.restore_path(latest_checkpoint) as path:
x, starting_batch = load_state(trial_id, path)

for batch in range(starting_batch, 100):
for batch in range(starting_batch, max_length):
x += increment_by
steps_completed = batch + 1
time.sleep(0.1)
@@ -57,6 +57,7 @@ def main(core_context, latest_checkpoint, trial_id, increment_by):
core_context.train.report_training_metrics(
steps_completed=steps_completed, metrics={"x": x}
)
core_context.train.report_progress(steps_completed / float(max_length))

# NEW: write checkpoints at regular intervals to limit lost progress
# in case of a crash during training.
43 changes: 39 additions & 4 deletions harness/determined/common/api/bindings.py

(Generated file; diff not rendered by default.)

2 changes: 1 addition & 1 deletion harness/determined/core/_searcher.py
@@ -86,7 +86,7 @@ def report_progress(self, length: float) -> None:
logger.debug(f"op.report_progress({length})")
self._session.post(
f"/api/v1/trials/{self._trial_id}/progress",
data=det.util.json_encode(length),
data=det.util.json_encode({"progress": length}),
)

def report_completed(self, searcher_metric: Any) -> None:
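For contrast with the new TrainContext.report_progress added below, a rough sketch of reporting progress through this Searcher API path (train_one_step and validate are hypothetical callables; note that op.report_progress takes completed units, not a 0-1 fraction):

def run_searcher_ops(core_context, train_one_step, validate):
    steps_completed = 0
    for op in core_context.searcher.operations():
        # op.length is the cumulative number of units this trial should train to.
        while steps_completed < op.length:
            train_one_step()
            steps_completed += 1
            op.report_progress(steps_completed)
        op.report_completed(validate())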
27 changes: 27 additions & 0 deletions harness/determined/core/_train.py
@@ -260,6 +260,30 @@ def report_early_exit(self, reason: EarlyExitReason) -> None:
if r.status_code == 400:
logger.warn("early exit has already been reported for this trial, ignoring new value")

def report_progress(self, progress: float) -> None:
Contributor:

This method also needs to be in DummyTrainContext; otherwise local training will not work.

"""
Report training progress to the master.

This is optional for training, but will be used by the WebUI to render completion status.

Progress must be reported as a float between 0 and 1.0, where 1.0 is 100% completion. It
should represent the current iteration step as a fraction of maximum training steps
(i.e.: `report_progress(step_num / max_steps)`).

Note that for hyperparameter search, progress should be reported through
``SearcherOperation.report_progress()`` in the Searcher API instead.

Arguments:
progress (float): completion progress in the range [0, 1.0].
"""
Contributor:

nit: wording and style. I know we're not consistent with docstring style in this class, but let's do it for new methods; we mostly try to follow the Google style guide.

suggestion: (the docstring text as applied above)

logger.debug(f"report_progress with progress={progress}")
if progress < 0 or progress > 1:
raise ValueError(f"Progress should be between 0 and 1, not {progress}")
self._session.post(
f"/api/v1/trials/{self._trial_id}/progress",
data=det.util.json_encode({"progress": progress, "is_raw": True}),
)

def get_experiment_best_validation(self) -> Optional[float]:
"""
Get the best reported validation metric reported so far, across the whole experiment.
@@ -312,6 +336,9 @@ def upload_tensorboard_files(
def report_early_exit(self, reason: EarlyExitReason) -> None:
logger.info(f"report_early_exit({reason})")

def report_progress(self, progress: float) -> None:
logger.info(f"report_progress with progress={progress}")

def get_experiment_best_validation(self) -> Optional[float]:
return None

8 changes: 7 additions & 1 deletion master/internal/api_trials.go
@@ -1393,19 +1393,25 @@ func (a *apiServer) ReportTrialProgress(
experiment.AuthZProvider.Get().CanEditExperiment); err != nil {
return nil, err
}

eID, rID, err := a.m.db.TrialExperimentAndRequestID(int(req.TrialId))
if err != nil {
return nil, err
}

e, ok := experiment.ExperimentRegistry.Load(eID)
if !ok {
return nil, api.NotFoundErrs("experiment", strconv.Itoa(eID), true)
// Unmanaged experiments are not included in the ExperimentRegistry.
if err := a.m.db.SaveExperimentProgress(eID, &req.Progress); err != nil {
return nil, err
}
return &apiv1.ReportTrialProgressResponse{}, nil
}

msg := experiment.TrialReportProgress{
RequestID: rID,
Progress: searcher.PartialUnits(req.Progress),
Contributor:

If it's not too much effort, it'd be great if we could get rid of this searcher.PartialUnits type; it's just a float anyway. We'll need to do it sooner or later, but it's not strictly necessary as part of this PR.

Contributor Author:

Since it touches multiple files and isn't directly related to this PR, I created a ticket for it.

IsRaw: req.IsRaw,
}
if err := e.TrialReportProgress(msg); err != nil {
return nil, err
3 changes: 3 additions & 0 deletions master/internal/db/postgres_experiments.go
@@ -808,6 +808,9 @@ EXISTS(

// SaveExperimentProgress stores the progress for an experiment in the database.
func (db *PgDB) SaveExperimentProgress(id int, progress *float64) error {
if progress != nil && (*progress < 0 || *progress > 1) {
return errors.Errorf("invalid progress value: %f. Progress value should be between 0 and 1", *progress)
}
res, err := db.sql.Exec(`UPDATE experiments SET progress = $1 WHERE id = $2`, progress, id)
if err != nil {
return errors.Wrap(err, "saving experiment progress")
8 changes: 6 additions & 2 deletions master/internal/experiment.go
@@ -356,8 +356,12 @@ func (e *internalExperiment) TrialReportProgress(msg experiment.TrialReportProgress
e.mu.Lock()
defer e.mu.Unlock()

e.searcher.SetTrialProgress(msg.RequestID, msg.Progress)
progress := e.searcher.Progress()
progress := float64(msg.Progress)
if !msg.IsRaw {
e.searcher.SetTrialProgress(msg.RequestID, msg.Progress)
progress = e.searcher.Progress()
}

if err := e.db.SaveExperimentProgress(e.ID, &progress); err != nil {
e.syslog.WithError(err).Error("failed to save experiment progress")
}
1 change: 1 addition & 0 deletions master/internal/experiment/experiment_iface.go
@@ -30,6 +30,7 @@ type (
TrialReportProgress struct {
RequestID model.RequestID
Progress searcher.PartialUnits
IsRaw bool
}

// UserInitiatedEarlyTrialExit is a user-injected message, provided through the early exit API. It
5 changes: 5 additions & 0 deletions master/internal/trials/postgres_trials.go
@@ -333,6 +333,11 @@ func UpdateUnmanagedExperimentStatesTx(
endTime = ptrs.Ptr(time.Now())
}
exp.EndTime = endTime

if exp.State == model.CompletedState {
columns = append(columns, "progress")
exp.Progress = ptrs.Ptr(1.0)
}
}

if _, err := tx.NewUpdate().Model(exp).Column(columns...).WherePK().Exec(ctx); err != nil {