diff --git a/qiskit/pulse/library/symbolic_pulses.py b/qiskit/pulse/library/symbolic_pulses.py index 8c90777eae01..ca3c6a516d2b 100644 --- a/qiskit/pulse/library/symbolic_pulses.py +++ b/qiskit/pulse/library/symbolic_pulses.py @@ -265,6 +265,17 @@ class SymbolicPulse(Pulse): which greatly reduces memory footprint during the program generation. + .. _symbolic_pulse_validation: + + .. rubric:: Pulse validation + + When a symbolic pulse is instantiated, the method :meth:`.validate_parameters` is called, + and performs validation of the pulse. The validation process involves testing the constraint + functions and the maximal amplitude of the pulse (see below). While the validation process + will improve code stability, it will reduce performance and might create + compatibility issues (particularly with JAX). Therefore, it is possible to disable the + validation by setting the class attribute :attr:`.disable_validation` to ``True``. + .. _symbolic_pulse_constraints: .. rubric:: Constraint functions @@ -390,6 +401,8 @@ def Sawtooth(duration, amp, freq, name): "_valid_amp_conditions", ) + disable_validation = False + # Lambdify caches keyed on sympy expressions. Returns the corresponding callable. _envelope_lam = LambdifiedExpression("_envelope") _constraints_lam = LambdifiedExpression("_constraints") @@ -440,6 +453,8 @@ def __init__( self._envelope = envelope self._constraints = constraints self._valid_amp_conditions = valid_amp_conditions + if not self.__class__.disable_validation: + self.validate_parameters() def __getattr__(self, item): # Get pulse parameters with attribute-like access. @@ -774,7 +789,7 @@ def __new__( consts_expr = _sigma > 0 valid_amp_conditions_expr = sym.Abs(_amp) <= 1.0 - instance = ScalableSymbolicPulse( + return ScalableSymbolicPulse( pulse_type=cls.alias, duration=duration, amp=amp, @@ -786,9 +801,6 @@ def __new__( constraints=consts_expr, valid_amp_conditions=valid_amp_conditions_expr, ) - instance.validate_parameters() - - return instance class GaussianSquare(metaclass=_PulseType): @@ -902,7 +914,7 @@ def __new__( consts_expr = sym.And(_sigma > 0, _width >= 0, _duration >= _width) valid_amp_conditions_expr = sym.Abs(_amp) <= 1.0 - instance = ScalableSymbolicPulse( + return ScalableSymbolicPulse( pulse_type=cls.alias, duration=duration, amp=amp, @@ -914,9 +926,6 @@ def __new__( constraints=consts_expr, valid_amp_conditions=valid_amp_conditions_expr, ) - instance.validate_parameters() - - return instance def GaussianSquareDrag( @@ -1051,7 +1060,7 @@ def GaussianSquareDrag( consts_expr = sym.And(_sigma > 0, _width >= 0, _duration >= _width) valid_amp_conditions_expr = sym.And(sym.Abs(_amp) <= 1.0, sym.Abs(_beta) < _sigma) - instance = ScalableSymbolicPulse( + return ScalableSymbolicPulse( pulse_type="GaussianSquareDrag", duration=duration, amp=amp, @@ -1063,9 +1072,6 @@ def GaussianSquareDrag( constraints=consts_expr, valid_amp_conditions=valid_amp_conditions_expr, ) - instance.validate_parameters() - - return instance def gaussian_square_echo( @@ -1142,6 +1148,7 @@ def gaussian_square_echo( name: Display name for this pulse envelope. limit_amplitude: If ``True``, then limit the amplitude of the waveform to 1. The default is ``True`` and the amplitude is constrained to 1. + Returns: ScalableSymbolicPulse instance. Raises: @@ -1254,7 +1261,7 @@ def gaussian_square_echo( # Check validity of amplitudes valid_amp_conditions_expr = sym.And(sym.Abs(_amp) + sym.Abs(_active_amp) <= 1.0) - instance = SymbolicPulse( + return SymbolicPulse( pulse_type="gaussian_square_echo", duration=duration, parameters=parameters, @@ -1264,9 +1271,6 @@ def gaussian_square_echo( constraints=consts_expr, valid_amp_conditions=valid_amp_conditions_expr, ) - instance.validate_parameters() - - return instance def GaussianDeriv( @@ -1318,7 +1322,7 @@ def GaussianDeriv( consts_expr = _sigma > 0 valid_amp_conditions_expr = sym.Abs(_amp / _sigma) <= sym.exp(1 / 2) - instance = ScalableSymbolicPulse( + return ScalableSymbolicPulse( pulse_type="GaussianDeriv", duration=duration, amp=amp, @@ -1330,9 +1334,6 @@ def GaussianDeriv( constraints=consts_expr, valid_amp_conditions=valid_amp_conditions_expr, ) - instance.validate_parameters() - - return instance class Drag(metaclass=_PulseType): @@ -1418,7 +1419,7 @@ def __new__( consts_expr = _sigma > 0 valid_amp_conditions_expr = sym.And(sym.Abs(_amp) <= 1.0, sym.Abs(_beta) < _sigma) - instance = ScalableSymbolicPulse( + return ScalableSymbolicPulse( pulse_type="Drag", duration=duration, amp=amp, @@ -1430,9 +1431,6 @@ def __new__( constraints=consts_expr, valid_amp_conditions=valid_amp_conditions_expr, ) - instance.validate_parameters() - - return instance class Constant(metaclass=_PulseType): @@ -1486,7 +1484,7 @@ def __new__( valid_amp_conditions_expr = sym.Abs(_amp) <= 1.0 - instance = ScalableSymbolicPulse( + return ScalableSymbolicPulse( pulse_type="Constant", duration=duration, amp=amp, @@ -1496,9 +1494,6 @@ def __new__( envelope=envelope_expr, valid_amp_conditions=valid_amp_conditions_expr, ) - instance.validate_parameters() - - return instance def Sin( @@ -1551,7 +1546,7 @@ def Sin( # This might fail for waves shorter than a single cycle valid_amp_conditions_expr = sym.Abs(_amp) <= 1.0 - instance = ScalableSymbolicPulse( + return ScalableSymbolicPulse( pulse_type="Sin", duration=duration, amp=amp, @@ -1563,9 +1558,6 @@ def Sin( constraints=consts_expr, valid_amp_conditions=valid_amp_conditions_expr, ) - instance.validate_parameters() - - return instance def Cos( @@ -1618,7 +1610,7 @@ def Cos( # This might fail for waves shorter than a single cycle valid_amp_conditions_expr = sym.Abs(_amp) <= 1.0 - instance = ScalableSymbolicPulse( + return ScalableSymbolicPulse( pulse_type="Cos", duration=duration, amp=amp, @@ -1630,9 +1622,6 @@ def Cos( constraints=consts_expr, valid_amp_conditions=valid_amp_conditions_expr, ) - instance.validate_parameters() - - return instance def Sawtooth( @@ -1689,7 +1678,7 @@ def Sawtooth( # This might fail for waves shorter than a single cycle valid_amp_conditions_expr = sym.Abs(_amp) <= 1.0 - instance = ScalableSymbolicPulse( + return ScalableSymbolicPulse( pulse_type="Sawtooth", duration=duration, amp=amp, @@ -1701,9 +1690,6 @@ def Sawtooth( constraints=consts_expr, valid_amp_conditions=valid_amp_conditions_expr, ) - instance.validate_parameters() - - return instance def Triangle( @@ -1760,7 +1746,7 @@ def Triangle( # This might fail for waves shorter than a single cycle valid_amp_conditions_expr = sym.Abs(_amp) <= 1.0 - instance = ScalableSymbolicPulse( + return ScalableSymbolicPulse( pulse_type="Triangle", duration=duration, amp=amp, @@ -1772,9 +1758,6 @@ def Triangle( constraints=consts_expr, valid_amp_conditions=valid_amp_conditions_expr, ) - instance.validate_parameters() - - return instance def Square( @@ -1833,7 +1816,7 @@ def Square( # This might fail for waves shorter than a single cycle valid_amp_conditions_expr = sym.Abs(_amp) <= 1.0 - instance = ScalableSymbolicPulse( + return ScalableSymbolicPulse( pulse_type="Square", duration=duration, amp=amp, @@ -1845,9 +1828,6 @@ def Square( constraints=consts_expr, valid_amp_conditions=valid_amp_conditions_expr, ) - instance.validate_parameters() - - return instance def Sech( @@ -1912,7 +1892,7 @@ def Sech( valid_amp_conditions_expr = sym.Abs(_amp) <= 1.0 - instance = ScalableSymbolicPulse( + return ScalableSymbolicPulse( pulse_type="Sech", duration=duration, amp=amp, @@ -1924,9 +1904,6 @@ def Sech( constraints=consts_expr, valid_amp_conditions=valid_amp_conditions_expr, ) - instance.validate_parameters() - - return instance def SechDeriv( @@ -1978,7 +1955,7 @@ def SechDeriv( valid_amp_conditions_expr = sym.Abs(_amp) / _sigma <= 2.0 - instance = ScalableSymbolicPulse( + return ScalableSymbolicPulse( pulse_type="SechDeriv", duration=duration, amp=amp, @@ -1990,6 +1967,3 @@ def SechDeriv( constraints=consts_expr, valid_amp_conditions=valid_amp_conditions_expr, ) - instance.validate_parameters() - - return instance diff --git a/releasenotes/notes/symbolic-pulse-disable-validation-19cd8506b3a839b6.yaml b/releasenotes/notes/symbolic-pulse-disable-validation-19cd8506b3a839b6.yaml new file mode 100644 index 000000000000..0a7c6c599000 --- /dev/null +++ b/releasenotes/notes/symbolic-pulse-disable-validation-19cd8506b3a839b6.yaml @@ -0,0 +1,14 @@ +--- +upgrade: + - | + Validation of :class:`qiskit.pulse.SymbolicPulse` objects can now be disabled. By setting + the class attribute :attr:`qiskit.pulse.SymbolicPulse.disable_validation` to ``False`` + the method :meth:`validate_parameters` will not be triggered for all `SymbolicPulse` objects. + The automatic validation hindered JAX compatibility of the symbolic pulse library, and this + upgrade will make it easier to use Qiskit Pulse with JAX. + + Note that all library pulses automatically called :meth:`validate_parameters`. However, as part + of the upgrade the call was moved directly to the initialization process of + :class:`qiskit.pulse.SymbolicPulse`. While this doesn't change the behaviour of library pulses, + custom symbolic pulses which did not call :meth:`validate_parameters` will now trigger the + method. The new class attribute will allow to easily disable this. diff --git a/test/python/pulse/test_pulse_lib.py b/test/python/pulse/test_pulse_lib.py index 575360f81758..524861020d22 100644 --- a/test/python/pulse/test_pulse_lib.py +++ b/test/python/pulse/test_pulse_lib.py @@ -672,17 +672,43 @@ def test_param_validation(self): with self.assertRaises(PulseError): Drag(duration=25, amp=0.5, sigma=-7.8, beta=4, angle=np.pi / 3) - def test_gaussian_limit_amplitude(self): - """Test that the check for amplitude less than or equal to 1 can be disabled.""" + def test_class_level_limit_amplitude(self): + """Test that the check for amplitude less than or equal to 1 can + be disabled on the class level. + + Tests for representative examples. + """ with self.assertRaises(PulseError): Gaussian(duration=100, sigma=1.0, amp=1.7, angle=np.pi * 1.1) with patch("qiskit.pulse.library.pulse.Pulse.limit_amplitude", new=False): waveform = Gaussian(duration=100, sigma=1.0, amp=1.7, angle=np.pi * 1.1) self.assertGreater(np.abs(waveform.amp), 1.0) + waveform = GaussianSquare(duration=100, sigma=1.0, amp=1.5, width=10, angle=np.pi / 5) + self.assertGreater(np.abs(waveform.amp), 1.0) + waveform = GaussianSquareDrag(duration=100, sigma=1.0, amp=1.1, beta=0.1, width=10) + self.assertGreater(np.abs(waveform.amp), 1.0) + + def test_class_level_disable_validation(self): + """Test that pulse validation can be disabled on the class level. + + Tests for representative examples. + """ + with self.assertRaises(PulseError): + Gaussian(duration=100, sigma=-1.0, amp=0.5, angle=np.pi * 1.1) + + with patch( + "qiskit.pulse.library.symbolic_pulses.SymbolicPulse.disable_validation", new=True + ): + waveform = Gaussian(duration=100, sigma=-1.0, amp=0.5, angle=np.pi * 1.1) + self.assertLess(waveform.sigma, 0) + waveform = GaussianSquare(duration=100, sigma=1.0, amp=0.5, width=1000, angle=np.pi / 5) + self.assertGreater(waveform.width, waveform.duration) + waveform = GaussianSquareDrag(duration=100, sigma=1.0, amp=1.1, beta=0.1, width=-1) + self.assertLess(waveform.width, 0) def test_gaussian_limit_amplitude_per_instance(self): - """Test that the check for amplitude per instance.""" + """Test limit amplitude option per Gaussian instance.""" with self.assertRaises(PulseError): Gaussian(duration=100, sigma=1.0, amp=1.6, angle=np.pi / 2.5) @@ -691,17 +717,8 @@ def test_gaussian_limit_amplitude_per_instance(self): ) self.assertGreater(np.abs(waveform.amp), 1.0) - def test_gaussian_square_limit_amplitude(self): - """Test that the check for amplitude less than or equal to 1 can be disabled.""" - with self.assertRaises(PulseError): - GaussianSquare(duration=100, sigma=1.0, amp=1.5, width=10, angle=np.pi / 5) - - with patch("qiskit.pulse.library.pulse.Pulse.limit_amplitude", new=False): - waveform = GaussianSquare(duration=100, sigma=1.0, amp=1.5, width=10, angle=np.pi / 5) - self.assertGreater(np.abs(waveform.amp), 1.0) - def test_gaussian_square_limit_amplitude_per_instance(self): - """Test that the check for amplitude per instance.""" + """Test limit amplitude option per GaussianSquare instance.""" with self.assertRaises(PulseError): GaussianSquare(duration=100, sigma=1.0, amp=1.5, width=10, angle=np.pi / 3) @@ -710,17 +727,8 @@ def test_gaussian_square_limit_amplitude_per_instance(self): ) self.assertGreater(np.abs(waveform.amp), 1.0) - def test_gaussian_square_drag_limit_amplitude(self): - """Test that the check for amplitude less than or equal to 1 can be disabled.""" - with self.assertRaises(PulseError): - GaussianSquareDrag(duration=100, sigma=1.0, amp=1.1, beta=0.1, width=10) - - with patch("qiskit.pulse.library.pulse.Pulse.limit_amplitude", new=False): - waveform = GaussianSquareDrag(duration=100, sigma=1.0, amp=1.1, beta=0.1, width=10) - self.assertGreater(np.abs(waveform.amp), 1.0) - def test_gaussian_square_drag_limit_amplitude_per_instance(self): - """Test that the check for amplitude per instance.""" + """Test limit amplitude option per GaussianSquareDrag instance.""" with self.assertRaises(PulseError): GaussianSquareDrag(duration=100, sigma=1.0, amp=1.1, beta=0.1, width=10) @@ -729,17 +737,8 @@ def test_gaussian_square_drag_limit_amplitude_per_instance(self): ) self.assertGreater(np.abs(waveform.amp), 1.0) - def test_gaussian_square_echo_limit_amplitude(self): - """Test that the check for amplitude less than or equal to 1 can be disabled.""" - with self.assertRaises(PulseError): - gaussian_square_echo(duration=1000, sigma=4.0, amp=1.01, width=100) - - with patch("qiskit.pulse.library.pulse.Pulse.limit_amplitude", new=False): - waveform = gaussian_square_echo(duration=100, sigma=1.0, amp=1.1, width=10) - self.assertGreater(np.abs(waveform.amp), 1.0) - def test_gaussian_square_echo_limit_amplitude_per_instance(self): - """Test that the check for amplitude per instance.""" + """Test limit amplitude option per GaussianSquareEcho instance.""" with self.assertRaises(PulseError): gaussian_square_echo(duration=1000, sigma=4.0, amp=1.01, width=100) @@ -748,17 +747,8 @@ def test_gaussian_square_echo_limit_amplitude_per_instance(self): ) self.assertGreater(np.abs(waveform.amp), 1.0) - def test_drag_limit_amplitude(self): - """Test that the check for amplitude less than or equal to 1 can be disabled.""" - with self.assertRaises(PulseError): - Drag(duration=100, sigma=1.0, beta=1.0, amp=1.8, angle=np.pi * 0.3) - - with patch("qiskit.pulse.library.pulse.Pulse.limit_amplitude", new=False): - waveform = Drag(duration=100, sigma=1.0, beta=1.0, amp=1.8, angle=np.pi * 0.3) - self.assertGreater(np.abs(waveform.amp), 1.0) - def test_drag_limit_amplitude_per_instance(self): - """Test that the check for amplitude per instance.""" + """Test limit amplitude option per DRAG instance.""" with self.assertRaises(PulseError): Drag(duration=100, sigma=1.0, beta=1.0, amp=1.8, angle=np.pi * 0.3) @@ -767,136 +757,64 @@ def test_drag_limit_amplitude_per_instance(self): ) self.assertGreater(np.abs(waveform.amp), 1.0) - def test_constant_limit_amplitude(self): - """Test that the check for amplitude less than or equal to 1 can be disabled.""" - with self.assertRaises(PulseError): - Constant(duration=100, amp=1.3, angle=0.1) - - with patch("qiskit.pulse.library.pulse.Pulse.limit_amplitude", new=False): - waveform = Constant(duration=100, amp=1.3, angle=0.1) - self.assertGreater(np.abs(waveform.amp), 1.0) - def test_constant_limit_amplitude_per_instance(self): - """Test that the check for amplitude per instance.""" + """Test limit amplitude option per Constant instance.""" with self.assertRaises(PulseError): Constant(duration=100, amp=1.6, angle=0.5) waveform = Constant(duration=100, amp=1.6, angle=0.5, limit_amplitude=False) self.assertGreater(np.abs(waveform.amp), 1.0) - def test_sin_limit_amplitude(self): - """Test that the check for amplitude less than or equal to 1 can be disabled.""" - with self.assertRaises(PulseError): - Sin(duration=100, amp=1.1, phase=0) - - with patch("qiskit.pulse.library.pulse.Pulse.limit_amplitude", new=False): - waveform = Sin(duration=100, amp=1.1, phase=0) - self.assertGreater(np.abs(waveform.amp), 1.0) - def test_sin_limit_amplitude_per_instance(self): - """Test that the check for amplitude per instance.""" + """Test limit amplitude option per Sin instance.""" with self.assertRaises(PulseError): Sin(duration=100, amp=1.1, phase=0) waveform = Sin(duration=100, amp=1.1, phase=0, limit_amplitude=False) self.assertGreater(np.abs(waveform.amp), 1.0) - def test_sawtooth_limit_amplitude(self): - """Test that the check for amplitude less than or equal to 1 can be disabled.""" - with self.assertRaises(PulseError): - Sawtooth(duration=100, amp=1.1, phase=0) - - with patch("qiskit.pulse.library.pulse.Pulse.limit_amplitude", new=False): - waveform = Sawtooth(duration=100, amp=1.1, phase=0) - self.assertGreater(np.abs(waveform.amp), 1.0) - def test_sawtooth_limit_amplitude_per_instance(self): - """Test that the check for amplitude per instance.""" + """Test limit amplitude option per Sawtooth instance.""" with self.assertRaises(PulseError): Sawtooth(duration=100, amp=1.1, phase=0) waveform = Sawtooth(duration=100, amp=1.1, phase=0, limit_amplitude=False) self.assertGreater(np.abs(waveform.amp), 1.0) - def test_triangle_limit_amplitude(self): - """Test that the check for amplitude less than or equal to 1 can be disabled.""" - with self.assertRaises(PulseError): - Triangle(duration=100, amp=1.1, phase=0) - - with patch("qiskit.pulse.library.pulse.Pulse.limit_amplitude", new=False): - waveform = Triangle(duration=100, amp=1.1, phase=0) - self.assertGreater(np.abs(waveform.amp), 1.0) - def test_triangle_limit_amplitude_per_instance(self): - """Test that the check for amplitude per instance.""" + """Test limit amplitude option per Triangle instance.""" with self.assertRaises(PulseError): Triangle(duration=100, amp=1.1, phase=0) waveform = Triangle(duration=100, amp=1.1, phase=0, limit_amplitude=False) self.assertGreater(np.abs(waveform.amp), 1.0) - def test_square_limit_amplitude(self): - """Test that the check for amplitude less than or equal to 1 can be disabled.""" - with self.assertRaises(PulseError): - Square(duration=100, amp=1.1, phase=0) - - with patch("qiskit.pulse.library.pulse.Pulse.limit_amplitude", new=False): - waveform = Square(duration=100, amp=1.1, phase=0) - self.assertGreater(np.abs(waveform.amp), 1.0) - def test_square_limit_amplitude_per_instance(self): - """Test that the check for amplitude per instance.""" + """Test limit amplitude option per Square instance.""" with self.assertRaises(PulseError): Square(duration=100, amp=1.1, phase=0) waveform = Square(duration=100, amp=1.1, phase=0, limit_amplitude=False) self.assertGreater(np.abs(waveform.amp), 1.0) - def test_gaussian_deriv_limit_amplitude(self): - """Test that the check for amplitude less than or equal to 1 can be disabled.""" - with self.assertRaises(PulseError): - GaussianDeriv(duration=100, amp=5, sigma=1) - - with patch("qiskit.pulse.library.pulse.Pulse.limit_amplitude", new=False): - waveform = GaussianDeriv(duration=100, amp=5, sigma=1) - self.assertGreater(np.abs(waveform.amp / waveform.sigma), np.exp(0.5)) - def test_gaussian_deriv_limit_amplitude_per_instance(self): - """Test that the check for amplitude per instance.""" + """Test limit amplitude option per GaussianDeriv instance.""" with self.assertRaises(PulseError): GaussianDeriv(duration=100, amp=5, sigma=1) waveform = GaussianDeriv(duration=100, amp=5, sigma=1, limit_amplitude=False) self.assertGreater(np.abs(waveform.amp / waveform.sigma), np.exp(0.5)) - def test_sech_limit_amplitude(self): - """Test that the check for amplitude less than or equal to 1 can be disabled.""" - with self.assertRaises(PulseError): - Sech(duration=100, amp=5, sigma=1) - - with patch("qiskit.pulse.library.pulse.Pulse.limit_amplitude", new=False): - waveform = Sech(duration=100, amp=5, sigma=1) - self.assertGreater(np.abs(waveform.amp), 1.0) - def test_sech_limit_amplitude_per_instance(self): - """Test that the check for amplitude per instance.""" + """Test limit amplitude option per Sech instance.""" with self.assertRaises(PulseError): Sech(duration=100, amp=5, sigma=1) waveform = Sech(duration=100, amp=5, sigma=1, limit_amplitude=False) self.assertGreater(np.abs(waveform.amp), 1.0) - def test_sech_deriv_limit_amplitude(self): - """Test that the check for amplitude less than or equal to 1 can be disabled.""" - with self.assertRaises(PulseError): - SechDeriv(duration=100, amp=5, sigma=1) - - with patch("qiskit.pulse.library.pulse.Pulse.limit_amplitude", new=False): - waveform = SechDeriv(duration=100, amp=5, sigma=1) - self.assertGreater(np.abs(waveform.amp) / waveform.sigma, 2.0) - def test_sech_deriv_limit_amplitude_per_instance(self): - """Test that the check for amplitude per instance.""" + """Test limit amplitude option per SechDeriv instance.""" with self.assertRaises(PulseError): SechDeriv(duration=100, amp=5, sigma=1)