Skip to content

Commit

Permalink
Update docstring for GenerationNode.gen & fit (facebook#2245)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: facebook#2245

--

Reviewed By: lena-kashtelyan

Differential Revision: D54506927

fbshipit-source-id: ef1f0bcb0b4b38f597344d7cae45691ed9c3c31d
  • Loading branch information
saitcakmak authored and facebook-github-bot committed Mar 6, 2024
1 parent 85768e8 commit e8018a2
Showing 1 changed file with 21 additions and 8 deletions.
29 changes: 21 additions & 8 deletions ax/modelbridge/generation_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,8 +256,15 @@ def fit(
the model kwargs set on each corresponding model spec and the kwargs
passed to this method.
NOTE: Local kwargs take precedence over the ones stored in
``ModelSpec.model_kwargs``.
Args:
experiment: The experiment to fit the model to.
data: The experiment data used to fit the model.
search_space: An optional overwrite for the experiment search space.
optimization_config: An optional overwrite for the experiment
optimization config.
kwargs: Additional keyword arguments to pass to the model's
``fit`` method. NOTE: Local kwargs take precedence over the ones
stored in ``ModelSpec.model_kwargs``.
"""
self._model_spec_to_gen_from = None
for model_spec in self.model_specs:
Expand All @@ -283,21 +290,27 @@ def gen(
alongside any kwargs passed in to this function (with local kwargs)
taking precedent.
NOTE: Models must have been fit prior to calling ``gen``.
NOTE: Some underlying models may ignore the ``n`` argument and produce a
model-determined number of arms. In that case this method will also output
a generator run with number of arms that may differ from ``n``.
Args:
n: Optional nteger representing how many arms should be in the generator
n: Optional integer representing how many arms should be in the generator
run produced by this method. When this is ``None``, ``n`` will be
determined by the ``ModelSpec`` that we are generating from.
pending_observations: A map from metric name to pending
observations for that metric, used by some models to avoid
resuggesting points that are currently being evaluated.
max_gen_draws_for_deduplication: TODO
max_gen_draws_for_deduplication: Maximum number of attempts for generating
new candidates without duplicates. If non-duplicate candidates are not
generated with these attempts, a ``GenerationStrategyRepeatedPoints``
exception will be raised.
model_gen_kwargs: Keyword arguments, passed through to ``ModelSpec.gen``;
these override any pre-specified in ``ModelSpec.model_gen_kwargs``.
NOTE: Models must have been fit prior to calling ``gen``.
NOTE: Some underlying models may ignore the ``n`` argument and produce a
model-determined number of arms. In that case this method will also output
a generator run with number of arms (that can differ from ``n``).
Returns:
A ``GeneratorRun`` containing the newly generated candidates.
"""
model_spec = self.model_spec_to_gen_from
should_generate_run = True
Expand Down

0 comments on commit e8018a2

Please sign in to comment.