Skip to content

Commit

Permalink
Update EstimatorV2 run return type, fix some typos
Browse files Browse the repository at this point in the history
  • Loading branch information
chriseclectic committed Jan 16, 2024
1 parent dea7b3b commit a5d1e55
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 10 deletions.
14 changes: 12 additions & 2 deletions qiskit/primitives/base/base_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,9 +166,17 @@
from qiskit.quantum_info.operators.base_operator import BaseOperator
from qiskit.utils.deprecation import deprecate_func

from ..containers import make_data_bin, DataBin, EstimatorPub, EstimatorPubLike
from ..containers import (
make_data_bin,
DataBin,
EstimatorPub,
EstimatorPubLike,
PrimitiveResult,
PubResult,
)
from . import validation
from .base_primitive import BasePrimitive
from .base_primitive_job import BasePrimitiveJob

T = TypeVar("T", bound=Job)

Expand Down Expand Up @@ -351,7 +359,9 @@ def _make_data_bin(pub: EstimatorPub) -> DataBin:
return make_data_bin((("evs", NDArray[np.float]), ("stds", NDArray[np.float])), pub.shape)

@abstractmethod
def run(self, pubs: Iterable[EstimatorPubLike], precision: float | None = None) -> Job:
def run(
self, pubs: Iterable[EstimatorPubLike], precision: float | None = None
) -> BasePrimitiveJob[PrimitiveResult[PubResult]]:
"""Estimate expectation values for each provided pub (Primitive Unified Bloc).
Args:
Expand Down
19 changes: 11 additions & 8 deletions qiskit/primitives/containers/estimator_pub.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def __init__(
observables: ObservablesArray,
parameter_values: BindingsArray | None = None,
precision: float | None = None,
validate: bool = False,
validate: bool = True,
):
"""Initialize an estimator pub.
Expand All @@ -62,10 +62,7 @@ def __init__(
self._observables = observables
self._parameter_values = parameter_values or BindingsArray()
self._precision = precision

# For ShapedMixin
self._shape = np.broadcast_shapes(self.observables.shape, self.parameter_values.shape)

if validate:
self.validate()

Expand Down Expand Up @@ -101,14 +98,20 @@ def coerce(cls, pub: EstimatorPubLike, precision: float | None = None) -> Estima
Returns:
An estimator pub.
"""
# Validate precision kwarg if provided
if precision is not None:
if not isinstance(precision, Real):
raise TypeError(f"precision must be a real number, not {type(precision)}.")
if precision < 0:
raise ValueError("precision must be non-negative")
if isinstance(pub, EstimatorPub):
if pub / precision is None and precision is not None:
if pub.precision is None and precision is not None:
cls(
circuit=pub.circuit,
observables=pub.observables,
parameter_values=pub.parameter_values,
precision=precision,
validate=False,
validate=False, # Assume Pub is already validated
)
return pub
if len(pub) not in [2, 3, 4]:
Expand All @@ -125,7 +128,7 @@ def coerce(cls, pub: EstimatorPubLike, precision: float | None = None) -> Estima
observables=observables,
parameter_values=parameter_values,
precision=precision,
validate=False,
validate=True,
)

def validate(self):
Expand All @@ -140,7 +143,7 @@ def validate(self):
if not isinstance(self.precision, Real):
raise TypeError(f"precision must be a real number, not {type(self.precision)}.")
if self.precision < 0:
raise ValueError("precisions must be non-negative.")
raise ValueError("precision must be non-negative.")

# Cross validate circuits and observables
for i, observable in enumerate(self.observables):
Expand Down

0 comments on commit a5d1e55

Please sign in to comment.