Skip to content

Commit

Permalink
Abstract out attribute copy during clone_to (#2288)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2288

Adds a method BaseTrial._update_trial_attrs_on_clone that copies over trial attributes upon trial clone.

Reviewed By: Balandat

Differential Revision: D55024638

fbshipit-source-id: 6f35af84501a165ae0bb6a0696025d2aee72a542
  • Loading branch information
Bernie Beckerman authored and facebook-github-bot committed Mar 19, 2024
1 parent 54996ed commit cc50ce8
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 12 deletions.
18 changes: 18 additions & 0 deletions ax/core/base_trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from __future__ import annotations

from abc import ABC, abstractmethod, abstractproperty
from copy import deepcopy
from datetime import datetime, timedelta
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
Expand Down Expand Up @@ -865,3 +866,20 @@ def _validate_can_attach_data(self) -> None:
f"Trial {self.index} has been marked {self.status.name}, so it "
"no longer expects data."
)

def _update_trial_attrs_on_clone(
self,
new_trial: BaseTrial,
) -> None:
"""Updates attributes of the trial that are not copied over when cloning
a trial.
Args:
new_trial: The cloned trial.
new_experiment: The experiment that the cloned trial belongs to.
new_status: The new status of the cloned trial.
"""
new_trial._run_metadata = deepcopy(self._run_metadata)
new_trial._stop_metadata = deepcopy(self._stop_metadata)
new_trial._num_arms_created = self._num_arms_created
new_trial.runner = self._runner.clone() if self._runner else None
6 changes: 1 addition & 5 deletions ax/core/batch_trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import warnings

from collections import defaultdict, OrderedDict
from copy import deepcopy
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
Expand Down Expand Up @@ -603,10 +602,7 @@ def clone_to(
self._status_quo.clone(),
weight=sq_weight,
)
new_trial.runner = self._runner.clone() if self._runner else None
new_trial._run_metadata = deepcopy(self._run_metadata)
new_trial._stop_metadata = deepcopy(self._stop_metadata)
new_trial._num_arms_created = self._num_arms_created
self._update_trial_attrs_on_clone(new_trial=new_trial)
return new_trial

def attach_batch_trial_data(
Expand Down
8 changes: 1 addition & 7 deletions ax/core/trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@

from __future__ import annotations

from copy import deepcopy

from functools import partial

from logging import Logger
Expand Down Expand Up @@ -351,9 +349,5 @@ def clone_to(
)
if self.generator_run is not None:
new_trial.add_generator_run(self.generator_run.clone())
new_trial._run_metadata = deepcopy(self._run_metadata)
new_trial._stop_metadata = deepcopy(self._stop_metadata)
new_trial._num_arms_created = self._num_arms_created
new_trial.runner = self._runner.clone() if self._runner else None

self._update_trial_attrs_on_clone(new_trial=new_trial)
return new_trial

0 comments on commit cc50ce8

Please sign in to comment.