From a5d1e55c26dcb47c7b225a87a49b107c256f5269 Mon Sep 17 00:00:00 2001 From: "Christopher J. Wood" Date: Tue, 16 Jan 2024 16:56:26 -0500 Subject: [PATCH] Update EstimatorV2 run return type, fix some typos --- qiskit/primitives/base/base_estimator.py | 14 ++++++++++++-- qiskit/primitives/containers/estimator_pub.py | 19 +++++++++++-------- 2 files changed, 23 insertions(+), 10 deletions(-) diff --git a/qiskit/primitives/base/base_estimator.py b/qiskit/primitives/base/base_estimator.py index 1078091304ba..51c206d66979 100644 --- a/qiskit/primitives/base/base_estimator.py +++ b/qiskit/primitives/base/base_estimator.py @@ -166,9 +166,17 @@ from qiskit.quantum_info.operators.base_operator import BaseOperator from qiskit.utils.deprecation import deprecate_func -from ..containers import make_data_bin, DataBin, EstimatorPub, EstimatorPubLike +from ..containers import ( + make_data_bin, + DataBin, + EstimatorPub, + EstimatorPubLike, + PrimitiveResult, + PubResult, +) from . import validation from .base_primitive import BasePrimitive +from .base_primitive_job import BasePrimitiveJob T = TypeVar("T", bound=Job) @@ -351,7 +359,9 @@ def _make_data_bin(pub: EstimatorPub) -> DataBin: return make_data_bin((("evs", NDArray[np.float]), ("stds", NDArray[np.float])), pub.shape) @abstractmethod - def run(self, pubs: Iterable[EstimatorPubLike], precision: float | None = None) -> Job: + def run( + self, pubs: Iterable[EstimatorPubLike], precision: float | None = None + ) -> BasePrimitiveJob[PrimitiveResult[PubResult]]: """Estimate expectation values for each provided pub (Primitive Unified Bloc). Args: diff --git a/qiskit/primitives/containers/estimator_pub.py b/qiskit/primitives/containers/estimator_pub.py index 351f0f50564d..281b1fb5c80c 100644 --- a/qiskit/primitives/containers/estimator_pub.py +++ b/qiskit/primitives/containers/estimator_pub.py @@ -46,7 +46,7 @@ def __init__( observables: ObservablesArray, parameter_values: BindingsArray | None = None, precision: float | None = None, - validate: bool = False, + validate: bool = True, ): """Initialize an estimator pub. @@ -62,10 +62,7 @@ def __init__( self._observables = observables self._parameter_values = parameter_values or BindingsArray() self._precision = precision - - # For ShapedMixin self._shape = np.broadcast_shapes(self.observables.shape, self.parameter_values.shape) - if validate: self.validate() @@ -101,14 +98,20 @@ def coerce(cls, pub: EstimatorPubLike, precision: float | None = None) -> Estima Returns: An estimator pub. """ + # Validate precision kwarg if provided + if precision is not None: + if not isinstance(precision, Real): + raise TypeError(f"precision must be a real number, not {type(precision)}.") + if precision < 0: + raise ValueError("precision must be non-negative") if isinstance(pub, EstimatorPub): - if pub / precision is None and precision is not None: + if pub.precision is None and precision is not None: cls( circuit=pub.circuit, observables=pub.observables, parameter_values=pub.parameter_values, precision=precision, - validate=False, + validate=False, # Assume Pub is already validated ) return pub if len(pub) not in [2, 3, 4]: @@ -125,7 +128,7 @@ def coerce(cls, pub: EstimatorPubLike, precision: float | None = None) -> Estima observables=observables, parameter_values=parameter_values, precision=precision, - validate=False, + validate=True, ) def validate(self): @@ -140,7 +143,7 @@ def validate(self): if not isinstance(self.precision, Real): raise TypeError(f"precision must be a real number, not {type(self.precision)}.") if self.precision < 0: - raise ValueError("precisions must be non-negative.") + raise ValueError("precision must be non-negative.") # Cross validate circuits and observables for i, observable in enumerate(self.observables):