Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Acquisition metadata (different approach) #767

Draft
wants to merge 11 commits into
base: develop
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
221 changes: 219 additions & 2 deletions trieste/acquisition/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Callable, Generic, Mapping, Optional
from typing import Any, Callable, Generic, Mapping, Optional

from ..data import Dataset
from ..models.interfaces import ProbabilisticModelType
Expand Down Expand Up @@ -67,14 +67,31 @@ def prepare_acquisition_function(
:return: An acquisition function.
"""

def prepare_acquisition_function_with_metadata(
self,
models: Mapping[Tag, ProbabilisticModelType],
datasets: Optional[Mapping[Tag, Dataset]] = None,
metadata: Optional[Mapping[str, Any]] = None,
) -> AcquisitionFunction:
"""
Prepare an acquisition function using additional metadata. By default, this is just
dropped, but you can override this method to use the metadata during acquisition.

:param models: The models for each tag.
:param datasets: The data from the observer (optional).
:param metadata: Metadata from the observer (optional).
:return: An acquisition function.
"""
return self.prepare_acquisition_function(models, datasets=datasets)

def update_acquisition_function(
self,
function: AcquisitionFunction,
models: Mapping[Tag, ProbabilisticModelType],
datasets: Optional[Mapping[Tag, Dataset]] = None,
) -> AcquisitionFunction:
"""
Update an acquisition function. By default this generates a new acquisition function each
Update an acquisition function. By default, this generates a new acquisition function each
time. However, if the function is decorated with `@tf.function`, then you can override
this method to update its variables instead and avoid retracing the acquisition function on
every optimization loop.
Expand All @@ -86,6 +103,25 @@ def update_acquisition_function(
"""
return self.prepare_acquisition_function(models, datasets=datasets)

def update_acquisition_function_with_metadata(
self,
function: AcquisitionFunction,
models: Mapping[Tag, ProbabilisticModelType],
datasets: Optional[Mapping[Tag, Dataset]] = None,
metadata: Optional[Mapping[str, Any]] = None,
) -> AcquisitionFunction:
"""
Update an acquisition function. By default, this is just
dropped, but you can override this method to use the metadata during acquisition.

:param function: The acquisition function to update.
:param models: The models for each tag.
:param datasets: The data from the observer (optional).
:param metadata: Metadata from the observer (optional).
:return: The updated acquisition function.
"""
return self.update_acquisition_function(function, models, datasets=datasets)


class SingleModelAcquisitionBuilder(Generic[ProbabilisticModelType], ABC):
"""
Expand Down Expand Up @@ -115,6 +151,18 @@ def prepare_acquisition_function(
models[tag], dataset=None if datasets is None else datasets[tag]
)

def prepare_acquisition_function_with_metadata(
self,
models: Mapping[Tag, ProbabilisticModelType],
datasets: Optional[Mapping[Tag, Dataset]] = None,
metadata: Optional[Mapping[str, Any]] = None,
) -> AcquisitionFunction:
return self.single_builder.prepare_acquisition_function_with_metadata(
models[tag],
dataset=None if datasets is None else datasets[tag],
metadata=metadata,
)

def update_acquisition_function(
self,
function: AcquisitionFunction,
Expand All @@ -125,6 +173,20 @@ def update_acquisition_function(
function, models[tag], dataset=None if datasets is None else datasets[tag]
)

def update_acquisition_function_with_metadata(
self,
function: AcquisitionFunction,
models: Mapping[Tag, ProbabilisticModelType],
datasets: Optional[Mapping[Tag, Dataset]] = None,
metadata: Optional[Mapping[str, Any]] = None,
) -> AcquisitionFunction:
return self.single_builder.update_acquisition_function_with_metadata(
function,
models[tag],
dataset=None if datasets is None else datasets[tag],
metadata=metadata,
)

def __repr__(self) -> str:
return f"{self.single_builder!r} using tag {tag!r}"

Expand All @@ -142,6 +204,20 @@ def prepare_acquisition_function(
:return: An acquisition function.
"""

def prepare_acquisition_function_with_metadata(
self,
model: ProbabilisticModelType,
dataset: Optional[Dataset] = None,
metadata: Optional[Mapping[str, Any]] = None,
) -> AcquisitionFunction:
"""
:param model: The model.
:param dataset: The data to use to build the acquisition function (optional).
:param metadata: Metadata from the observer (optional).
:return: An acquisition function.
"""
return self.prepare_acquisition_function(model, dataset=dataset)

def update_acquisition_function(
self,
function: AcquisitionFunction,
Expand All @@ -156,6 +232,21 @@ def update_acquisition_function(
"""
return self.prepare_acquisition_function(model, dataset=dataset)

def update_acquisition_function_with_metadata(
self,
function: AcquisitionFunction,
model: ProbabilisticModelType,
dataset: Optional[Dataset] = None,
metadata: Optional[Mapping[str, Any]] = None,
) -> AcquisitionFunction:
"""
:param function: The acquisition function to update.
:param model: The model.
:param dataset: The data from the observer (optional).
:return: The updated acquisition function.
"""
return self.update_acquisition_function(function, model, dataset=dataset)


class GreedyAcquisitionFunctionBuilder(Generic[ProbabilisticModelType], ABC):
"""
Expand Down Expand Up @@ -187,6 +278,20 @@ def prepare_acquisition_function(
:return: An acquisition function.
"""

def prepare_acquisition_function_with_metadata(
self,
models: Mapping[Tag, ProbabilisticModelType],
datasets: Optional[Mapping[Tag, Dataset]] = None,
pending_points: Optional[TensorType] = None,
metadata: Optional[Mapping[str, Any]] = None,
) -> AcquisitionFunction:
"""
Same as prepare_acquisition_function but accepts additional metadata argument.
"""
return self.prepare_acquisition_function(
models, datasets=datasets, pending_points=pending_points
)

def update_acquisition_function(
self,
function: AcquisitionFunction,
Expand Down Expand Up @@ -215,6 +320,26 @@ def update_acquisition_function(
models, datasets=datasets, pending_points=pending_points
)

def update_acquisition_function_with_metadata(
self,
function: AcquisitionFunction,
models: Mapping[Tag, ProbabilisticModelType],
datasets: Optional[Mapping[Tag, Dataset]] = None,
pending_points: Optional[TensorType] = None,
new_optimization_step: bool = True,
metadata: Optional[Mapping[str, Any]] = None,
) -> AcquisitionFunction:
"""
Same as update_acquisition_function but accepts additional metadata argument.
"""
return self.update_acquisition_function(
function,
models,
datasets=datasets,
pending_points=pending_points,
new_optimization_step=new_optimization_step,
)


class SingleModelGreedyAcquisitionBuilder(Generic[ProbabilisticModelType], ABC):
"""
Expand Down Expand Up @@ -247,6 +372,20 @@ def prepare_acquisition_function(
pending_points=pending_points,
)

def prepare_acquisition_function_with_metadata(
self,
models: Mapping[Tag, ProbabilisticModelType],
datasets: Optional[Mapping[Tag, Dataset]] = None,
pending_points: Optional[TensorType] = None,
metadata: Optional[Mapping[str, Any]] = None,
) -> AcquisitionFunction:
return self.single_builder.prepare_acquisition_function_with_metadata(
models[tag],
dataset=None if datasets is None else datasets[tag],
pending_points=pending_points,
metadata=metadata,
)

def update_acquisition_function(
self,
function: AcquisitionFunction,
Expand All @@ -263,6 +402,24 @@ def update_acquisition_function(
new_optimization_step=new_optimization_step,
)

def update_acquisition_function_with_metadata(
self,
function: AcquisitionFunction,
models: Mapping[Tag, ProbabilisticModelType],
datasets: Optional[Mapping[Tag, Dataset]] = None,
pending_points: Optional[TensorType] = None,
new_optimization_step: bool = True,
metadata: Optional[Mapping[str, Any]] = None,
) -> AcquisitionFunction:
return self.single_builder.update_acquisition_function_with_metadata(
function,
models[tag],
dataset=None if datasets is None else datasets[tag],
pending_points=pending_points,
new_optimization_step=new_optimization_step,
metadata=metadata,
)

def __repr__(self) -> str:
return f"{self.single_builder!r} using tag {tag!r}"

Expand All @@ -283,6 +440,20 @@ def prepare_acquisition_function(
:return: An acquisition function.
"""

def prepare_acquisition_function_with_metadata(
self,
model: ProbabilisticModelType,
dataset: Optional[Dataset] = None,
pending_points: Optional[TensorType] = None,
metadata: Optional[Mapping[str, Any]] = None,
) -> AcquisitionFunction:
"""
Same as prepare_acquisition_function but accepts additional metadata argument.
"""
return self.prepare_acquisition_function(
model, dataset=dataset, pending_points=pending_points
)

def update_acquisition_function(
self,
function: AcquisitionFunction,
Expand All @@ -308,6 +479,26 @@ def update_acquisition_function(
pending_points=pending_points,
)

def update_acquisition_function_with_metadata(
self,
function: AcquisitionFunction,
model: ProbabilisticModelType,
dataset: Optional[Dataset] = None,
pending_points: Optional[TensorType] = None,
new_optimization_step: bool = True,
metadata: Optional[Mapping[str, Any]] = None,
) -> AcquisitionFunction:
"""
Same as prepare_acquisition_function but accepts additional metadata argument.
"""
return self.update_acquisition_function(
function,
model,
dataset=dataset,
pending_points=pending_points,
new_optimization_step=new_optimization_step,
)


class VectorizedAcquisitionFunctionBuilder(AcquisitionFunctionBuilder[ProbabilisticModelType]):
"""
Expand Down Expand Up @@ -349,6 +540,18 @@ def prepare_acquisition_function(
models[tag], dataset=None if datasets is None else datasets[tag]
)

def prepare_acquisition_function_with_metadata(
self,
models: Mapping[Tag, ProbabilisticModelType],
datasets: Optional[Mapping[Tag, Dataset]] = None,
metadata: Optional[Mapping[str, Any]] = None,
) -> AcquisitionFunction:
return self.single_builder.prepare_acquisition_function_with_metadata(
models[tag],
dataset=None if datasets is None else datasets[tag],
metadata=metadata,
)

def update_acquisition_function(
self,
function: AcquisitionFunction,
Expand All @@ -359,6 +562,20 @@ def update_acquisition_function(
function, models[tag], dataset=None if datasets is None else datasets[tag]
)

def update_acquisition_function_with_metadata(
self,
function: AcquisitionFunction,
models: Mapping[Tag, ProbabilisticModelType],
datasets: Optional[Mapping[Tag, Dataset]] = None,
metadata: Optional[Mapping[str, Any]] = None,
) -> AcquisitionFunction:
return self.single_builder.update_acquisition_function_with_metadata(
function,
models[tag],
dataset=None if datasets is None else datasets[tag],
metadata=metadata,
)

def __repr__(self) -> str:
return f"{self.single_builder!r} using tag {tag!r}"

Expand Down
Loading