diff --git a/ax/core/base_trial.py b/ax/core/base_trial.py index 97846b79458..93d2b628e8d 100644 --- a/ax/core/base_trial.py +++ b/ax/core/base_trial.py @@ -9,6 +9,7 @@ from __future__ import annotations from abc import ABC, abstractmethod, abstractproperty +from copy import deepcopy from datetime import datetime, timedelta from enum import Enum from typing import Any, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING, Union @@ -865,3 +866,20 @@ def _validate_can_attach_data(self) -> None: f"Trial {self.index} has been marked {self.status.name}, so it " "no longer expects data." ) + + def _update_trial_attrs_on_clone( + self, + new_trial: BaseTrial, + ) -> None: + """Updates attributes of the trial that are not copied over when cloning + a trial. + + Args: + new_trial: The cloned trial. + new_experiment: The experiment that the cloned trial belongs to. + new_status: The new status of the cloned trial. + """ + new_trial._run_metadata = deepcopy(self._run_metadata) + new_trial._stop_metadata = deepcopy(self._stop_metadata) + new_trial._num_arms_created = self._num_arms_created + new_trial.runner = self._runner.clone() if self._runner else None diff --git a/ax/core/batch_trial.py b/ax/core/batch_trial.py index f63f1ffd1a3..985e57ca6c7 100644 --- a/ax/core/batch_trial.py +++ b/ax/core/batch_trial.py @@ -11,7 +11,6 @@ import warnings from collections import defaultdict, OrderedDict -from copy import deepcopy from dataclasses import dataclass from datetime import datetime from enum import Enum @@ -603,10 +602,7 @@ def clone_to( self._status_quo.clone(), weight=sq_weight, ) - new_trial.runner = self._runner.clone() if self._runner else None - new_trial._run_metadata = deepcopy(self._run_metadata) - new_trial._stop_metadata = deepcopy(self._stop_metadata) - new_trial._num_arms_created = self._num_arms_created + self._update_trial_attrs_on_clone(new_trial=new_trial) return new_trial def attach_batch_trial_data( diff --git a/ax/core/trial.py b/ax/core/trial.py index d79f903bb91..74a49bd9b10 100644 --- a/ax/core/trial.py +++ b/ax/core/trial.py @@ -8,8 +8,6 @@ from __future__ import annotations -from copy import deepcopy - from functools import partial from logging import Logger @@ -351,9 +349,5 @@ def clone_to( ) if self.generator_run is not None: new_trial.add_generator_run(self.generator_run.clone()) - new_trial._run_metadata = deepcopy(self._run_metadata) - new_trial._stop_metadata = deepcopy(self._stop_metadata) - new_trial._num_arms_created = self._num_arms_created - new_trial.runner = self._runner.clone() if self._runner else None - + self._update_trial_attrs_on_clone(new_trial=new_trial) return new_trial