Skip to content

Commit

Permalink
Add disable_validation option to SymbolicPulse (#11029)
Browse files Browse the repository at this point in the history
* Add disable_check option

* Remove per instance option, add release notes

* Release notes correction
  • Loading branch information
TsafrirA authored Oct 23, 2023
1 parent d781dcc commit 5de1a06
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 179 deletions.
86 changes: 30 additions & 56 deletions qiskit/pulse/library/symbolic_pulses.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,17 @@ class SymbolicPulse(Pulse):
which greatly reduces memory footprint during the program generation.
.. _symbolic_pulse_validation:
.. rubric:: Pulse validation
When a symbolic pulse is instantiated, the method :meth:`.validate_parameters` is called,
and performs validation of the pulse. The validation process involves testing the constraint
functions and the maximal amplitude of the pulse (see below). While the validation process
will improve code stability, it will reduce performance and might create
compatibility issues (particularly with JAX). Therefore, it is possible to disable the
validation by setting the class attribute :attr:`.disable_validation` to ``True``.
.. _symbolic_pulse_constraints:
.. rubric:: Constraint functions
Expand Down Expand Up @@ -390,6 +401,8 @@ def Sawtooth(duration, amp, freq, name):
"_valid_amp_conditions",
)

disable_validation = False

# Lambdify caches keyed on sympy expressions. Returns the corresponding callable.
_envelope_lam = LambdifiedExpression("_envelope")
_constraints_lam = LambdifiedExpression("_constraints")
Expand Down Expand Up @@ -440,6 +453,8 @@ def __init__(
self._envelope = envelope
self._constraints = constraints
self._valid_amp_conditions = valid_amp_conditions
if not self.__class__.disable_validation:
self.validate_parameters()

def __getattr__(self, item):
# Get pulse parameters with attribute-like access.
Expand Down Expand Up @@ -774,7 +789,7 @@ def __new__(
consts_expr = _sigma > 0
valid_amp_conditions_expr = sym.Abs(_amp) <= 1.0

instance = ScalableSymbolicPulse(
return ScalableSymbolicPulse(
pulse_type=cls.alias,
duration=duration,
amp=amp,
Expand All @@ -786,9 +801,6 @@ def __new__(
constraints=consts_expr,
valid_amp_conditions=valid_amp_conditions_expr,
)
instance.validate_parameters()

return instance


class GaussianSquare(metaclass=_PulseType):
Expand Down Expand Up @@ -902,7 +914,7 @@ def __new__(
consts_expr = sym.And(_sigma > 0, _width >= 0, _duration >= _width)
valid_amp_conditions_expr = sym.Abs(_amp) <= 1.0

instance = ScalableSymbolicPulse(
return ScalableSymbolicPulse(
pulse_type=cls.alias,
duration=duration,
amp=amp,
Expand All @@ -914,9 +926,6 @@ def __new__(
constraints=consts_expr,
valid_amp_conditions=valid_amp_conditions_expr,
)
instance.validate_parameters()

return instance


def GaussianSquareDrag(
Expand Down Expand Up @@ -1051,7 +1060,7 @@ def GaussianSquareDrag(
consts_expr = sym.And(_sigma > 0, _width >= 0, _duration >= _width)
valid_amp_conditions_expr = sym.And(sym.Abs(_amp) <= 1.0, sym.Abs(_beta) < _sigma)

instance = ScalableSymbolicPulse(
return ScalableSymbolicPulse(
pulse_type="GaussianSquareDrag",
duration=duration,
amp=amp,
Expand All @@ -1063,9 +1072,6 @@ def GaussianSquareDrag(
constraints=consts_expr,
valid_amp_conditions=valid_amp_conditions_expr,
)
instance.validate_parameters()

return instance


def gaussian_square_echo(
Expand Down Expand Up @@ -1142,6 +1148,7 @@ def gaussian_square_echo(
name: Display name for this pulse envelope.
limit_amplitude: If ``True``, then limit the amplitude of the
waveform to 1. The default is ``True`` and the amplitude is constrained to 1.
Returns:
ScalableSymbolicPulse instance.
Raises:
Expand Down Expand Up @@ -1254,7 +1261,7 @@ def gaussian_square_echo(
# Check validity of amplitudes
valid_amp_conditions_expr = sym.And(sym.Abs(_amp) + sym.Abs(_active_amp) <= 1.0)

instance = SymbolicPulse(
return SymbolicPulse(
pulse_type="gaussian_square_echo",
duration=duration,
parameters=parameters,
Expand All @@ -1264,9 +1271,6 @@ def gaussian_square_echo(
constraints=consts_expr,
valid_amp_conditions=valid_amp_conditions_expr,
)
instance.validate_parameters()

return instance


def GaussianDeriv(
Expand Down Expand Up @@ -1318,7 +1322,7 @@ def GaussianDeriv(
consts_expr = _sigma > 0
valid_amp_conditions_expr = sym.Abs(_amp / _sigma) <= sym.exp(1 / 2)

instance = ScalableSymbolicPulse(
return ScalableSymbolicPulse(
pulse_type="GaussianDeriv",
duration=duration,
amp=amp,
Expand All @@ -1330,9 +1334,6 @@ def GaussianDeriv(
constraints=consts_expr,
valid_amp_conditions=valid_amp_conditions_expr,
)
instance.validate_parameters()

return instance


class Drag(metaclass=_PulseType):
Expand Down Expand Up @@ -1418,7 +1419,7 @@ def __new__(
consts_expr = _sigma > 0
valid_amp_conditions_expr = sym.And(sym.Abs(_amp) <= 1.0, sym.Abs(_beta) < _sigma)

instance = ScalableSymbolicPulse(
return ScalableSymbolicPulse(
pulse_type="Drag",
duration=duration,
amp=amp,
Expand All @@ -1430,9 +1431,6 @@ def __new__(
constraints=consts_expr,
valid_amp_conditions=valid_amp_conditions_expr,
)
instance.validate_parameters()

return instance


class Constant(metaclass=_PulseType):
Expand Down Expand Up @@ -1486,7 +1484,7 @@ def __new__(

valid_amp_conditions_expr = sym.Abs(_amp) <= 1.0

instance = ScalableSymbolicPulse(
return ScalableSymbolicPulse(
pulse_type="Constant",
duration=duration,
amp=amp,
Expand All @@ -1496,9 +1494,6 @@ def __new__(
envelope=envelope_expr,
valid_amp_conditions=valid_amp_conditions_expr,
)
instance.validate_parameters()

return instance


def Sin(
Expand Down Expand Up @@ -1551,7 +1546,7 @@ def Sin(
# This might fail for waves shorter than a single cycle
valid_amp_conditions_expr = sym.Abs(_amp) <= 1.0

instance = ScalableSymbolicPulse(
return ScalableSymbolicPulse(
pulse_type="Sin",
duration=duration,
amp=amp,
Expand All @@ -1563,9 +1558,6 @@ def Sin(
constraints=consts_expr,
valid_amp_conditions=valid_amp_conditions_expr,
)
instance.validate_parameters()

return instance


def Cos(
Expand Down Expand Up @@ -1618,7 +1610,7 @@ def Cos(
# This might fail for waves shorter than a single cycle
valid_amp_conditions_expr = sym.Abs(_amp) <= 1.0

instance = ScalableSymbolicPulse(
return ScalableSymbolicPulse(
pulse_type="Cos",
duration=duration,
amp=amp,
Expand All @@ -1630,9 +1622,6 @@ def Cos(
constraints=consts_expr,
valid_amp_conditions=valid_amp_conditions_expr,
)
instance.validate_parameters()

return instance


def Sawtooth(
Expand Down Expand Up @@ -1689,7 +1678,7 @@ def Sawtooth(
# This might fail for waves shorter than a single cycle
valid_amp_conditions_expr = sym.Abs(_amp) <= 1.0

instance = ScalableSymbolicPulse(
return ScalableSymbolicPulse(
pulse_type="Sawtooth",
duration=duration,
amp=amp,
Expand All @@ -1701,9 +1690,6 @@ def Sawtooth(
constraints=consts_expr,
valid_amp_conditions=valid_amp_conditions_expr,
)
instance.validate_parameters()

return instance


def Triangle(
Expand Down Expand Up @@ -1760,7 +1746,7 @@ def Triangle(
# This might fail for waves shorter than a single cycle
valid_amp_conditions_expr = sym.Abs(_amp) <= 1.0

instance = ScalableSymbolicPulse(
return ScalableSymbolicPulse(
pulse_type="Triangle",
duration=duration,
amp=amp,
Expand All @@ -1772,9 +1758,6 @@ def Triangle(
constraints=consts_expr,
valid_amp_conditions=valid_amp_conditions_expr,
)
instance.validate_parameters()

return instance


def Square(
Expand Down Expand Up @@ -1833,7 +1816,7 @@ def Square(
# This might fail for waves shorter than a single cycle
valid_amp_conditions_expr = sym.Abs(_amp) <= 1.0

instance = ScalableSymbolicPulse(
return ScalableSymbolicPulse(
pulse_type="Square",
duration=duration,
amp=amp,
Expand All @@ -1845,9 +1828,6 @@ def Square(
constraints=consts_expr,
valid_amp_conditions=valid_amp_conditions_expr,
)
instance.validate_parameters()

return instance


def Sech(
Expand Down Expand Up @@ -1912,7 +1892,7 @@ def Sech(

valid_amp_conditions_expr = sym.Abs(_amp) <= 1.0

instance = ScalableSymbolicPulse(
return ScalableSymbolicPulse(
pulse_type="Sech",
duration=duration,
amp=amp,
Expand All @@ -1924,9 +1904,6 @@ def Sech(
constraints=consts_expr,
valid_amp_conditions=valid_amp_conditions_expr,
)
instance.validate_parameters()

return instance


def SechDeriv(
Expand Down Expand Up @@ -1978,7 +1955,7 @@ def SechDeriv(

valid_amp_conditions_expr = sym.Abs(_amp) / _sigma <= 2.0

instance = ScalableSymbolicPulse(
return ScalableSymbolicPulse(
pulse_type="SechDeriv",
duration=duration,
amp=amp,
Expand All @@ -1990,6 +1967,3 @@ def SechDeriv(
constraints=consts_expr,
valid_amp_conditions=valid_amp_conditions_expr,
)
instance.validate_parameters()

return instance
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
---
upgrade:
- |
Validation of :class:`qiskit.pulse.SymbolicPulse` objects can now be disabled. By setting
the class attribute :attr:`qiskit.pulse.SymbolicPulse.disable_validation` to ``False``
the method :meth:`validate_parameters` will not be triggered for all `SymbolicPulse` objects.
The automatic validation hindered JAX compatibility of the symbolic pulse library, and this
upgrade will make it easier to use Qiskit Pulse with JAX.
Note that all library pulses automatically called :meth:`validate_parameters`. However, as part
of the upgrade the call was moved directly to the initialization process of
:class:`qiskit.pulse.SymbolicPulse`. While this doesn't change the behaviour of library pulses,
custom symbolic pulses which did not call :meth:`validate_parameters` will now trigger the
method. The new class attribute will allow to easily disable this.
Loading

0 comments on commit 5de1a06

Please sign in to comment.