diff --git a/dymos/examples/brachistochrone/test/test_state_rate_introspection.py b/dymos/examples/brachistochrone/test/test_state_rate_introspection.py index 8456851bd..d35f4ca47 100644 --- a/dymos/examples/brachistochrone/test/test_state_rate_introspection.py +++ b/dymos/examples/brachistochrone/test/test_state_rate_introspection.py @@ -847,3 +847,82 @@ def test_gl(self): def test_radau(self): self._test_transcription(transcription=dm.Radau) + + +@use_tempdirs +class TestInvalidStateRateSource(unittest.TestCase): + + def test_brach_invalid_state_rate_source(self): + + class _BrachODE(om.ExplicitComponent): + + def initialize(self): + self.options.declare('num_nodes', types=int) + + def setup(self): + nn = self.options['num_nodes'] + + # Inputs + self.add_input('v', val=np.zeros(nn), desc='velocity', units='m/s') + self.add_input('g', val=9.80665 * np.ones(nn), desc='grav. acceleration', units='m/s/s') + self.add_input('theta', val=np.ones(nn), desc='angle of wire', units='rad') + + self.add_output('xdot', val=np.zeros(nn), desc='velocity component in x', units='m/s') + + self.add_output('ydot', val=np.zeros(nn), desc='velocity component in y', units='m/s') + + self.add_output('vdot', val=np.zeros(nn), desc='acceleration magnitude', units='m/s**2') + + self.add_output('check', val=np.zeros(nn), desc='check solution: v/sin(theta) = constant', + units='m/s') + + self.declare_coloring(wrt='*', method='cs') + + def compute(self, inputs, outputs): + theta = inputs['theta'] + cos_theta = np.cos(theta) + sin_theta = np.sin(theta) + g = inputs['g'] + v = inputs['v'] + + outputs['vdot'] = g * cos_theta + outputs['xdot'] = v * sin_theta + outputs['ydot'] = -v * cos_theta + outputs['check'] = v / sin_theta + + p = om.Problem(model=om.Group()) + + p.driver = om.ScipyOptimizeDriver() + + t = dm.Radau(num_segments=10, order=3) + + traj = dm.Trajectory() + phase = dm.Phase(ode_class=_BrachODE, transcription=t) + p.model.add_subsystem('traj0', traj) + traj.add_phase('phase0', phase) + + phase.set_time_options(fix_initial=True, duration_bounds=(.5, 10)) + + phase.add_state('x', fix_initial=True, fix_final=False, rate_source='xdot') + phase.add_state('y', fix_initial=True, fix_final=False, rate_source='ydot') + + # Intentionally incorrect rate source to trigger an error during configure. + phase.add_state('v', fix_initial=True, fix_final=False, rate_source='vel_dot') + + phase.add_control('theta', + continuity=True, rate_continuity=True, + units='deg', lower=0.01, upper=179.9) + + phase.add_parameter('g', units='m/s**2') + + phase.add_boundary_constraint('x', loc='final', equals=10) + phase.add_boundary_constraint('y', loc='final', equals=5) + + # Minimize time at the end of the phase + phase.add_objective('time_phase', loc='final', scaler=10) + + with self.assertRaises(RuntimeError) as ctx: + p.setup() + + expected = 'Error during configure_states_introspection in phase traj0.phases.phase0.' + self.assertEqual(str(ctx.exception), expected) diff --git a/dymos/phase/phase.py b/dymos/phase/phase.py index a8256001e..1f622b851 100644 --- a/dymos/phase/phase.py +++ b/dymos/phase/phase.py @@ -1549,9 +1549,13 @@ def configure(self): raise ValueError(f'Invalid parameter in phase `{self.pathname}`.\n{str(e)}') from e self.configure_state_discovery() - configure_states_introspection(self.state_options, self.time_options, self.control_options, - self.parameter_options, self.polynomial_control_options, - ode) + + try: + configure_states_introspection(self.state_options, self.time_options, self.control_options, + self.parameter_options, self.polynomial_control_options, + ode) + except RuntimeError as val_err: + raise RuntimeError(f'Error during configure_states_introspection in phase {self.pathname}.') from val_err transcription.configure_time(self) transcription.configure_controls(self) diff --git a/dymos/trajectory/test/test_trajectory.py b/dymos/trajectory/test/test_trajectory.py index aa3d54528..528d66b64 100644 --- a/dymos/trajectory/test/test_trajectory.py +++ b/dymos/trajectory/test/test_trajectory.py @@ -1153,7 +1153,7 @@ def test_invalid_linkage_variable(self): # Finish Problem Setup p.model.linear_solver = om.DirectSolver() - with self.assertRaises(ValueError) as e: + with self.assertRaises(RuntimeError) as e: p.setup(check=True) self.assertEqual(str(e.exception), 'Error in linking bar from burn1 to bar in burn2: ' diff --git a/dymos/trajectory/trajectory.py b/dymos/trajectory/trajectory.py index 45ad36526..1a7d774d7 100644 --- a/dymos/trajectory/trajectory.py +++ b/dymos/trajectory/trajectory.py @@ -544,9 +544,9 @@ def _update_linkage_options_configure(self, linkage_options): shapes[i], units[i] = get_source_metadata(phases[i]._get_subsystem(rhs_source), vars[i], user_units=units[i], user_shape=_unspecified) - except ValueError: - raise ValueError(f'{info_str}: Unable to find variable \'{vars[i]}\' in ' - f'phase \'{phases[i].pathname}\' or its ODE.') + except RuntimeError as e: + raise RuntimeError(f'{info_str}: Unable to find variable \'{vars[i]}\' in ' + f'phase \'{phases[i].pathname}\' or its ODE.') linkage_options._src_a = sources['a'] linkage_options._src_b = sources['b'] diff --git a/dymos/utils/introspection.py b/dymos/utils/introspection.py index 7ed26c37c..2edfa79d5 100644 --- a/dymos/utils/introspection.py +++ b/dymos/utils/introspection.py @@ -951,7 +951,7 @@ def _get_targets_metadata(ode, name, user_targets=_unspecified, user_units=_unsp def get_source_metadata(ode, src, user_units, user_shape): """ - Return the targets of a state variable in a given ODE system. + Return the units and shape of output src in the given ODE. If the targets of the state is _unspecified, and the state name is a top level input name in the ODE, then the state values are automatically connected to that top-level input. @@ -989,7 +989,7 @@ def get_source_metadata(ode, src, user_units, user_shape): ode_outputs = ode if isinstance(ode, dict) else get_promoted_vars(ode, iotypes='output') if src not in ode_outputs: - raise ValueError(f'Unable to find the source {src} in the ODE at {ode.pathname}.') + raise RuntimeError(f'Unable to find the source {src} in the ODE.') if user_units in {None, _unspecified}: units = ode_outputs[src]['units']