Skip to content

Commit

Permalink
Fixed an error when an invalid rate source is provided for a state. (#…
Browse files Browse the repository at this point in the history
…761)

* Fixed an error raised during configure_states_introspection due to ODE outputs being passed in rather than the ODE itself.
  • Loading branch information
robfalck authored Jun 15, 2022
1 parent 91dc2f4 commit e795dcd
Show file tree
Hide file tree
Showing 5 changed files with 92 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -847,3 +847,82 @@ def test_gl(self):

def test_radau(self):
self._test_transcription(transcription=dm.Radau)


@use_tempdirs
class TestInvalidStateRateSource(unittest.TestCase):

def test_brach_invalid_state_rate_source(self):

class _BrachODE(om.ExplicitComponent):

def initialize(self):
self.options.declare('num_nodes', types=int)

def setup(self):
nn = self.options['num_nodes']

# Inputs
self.add_input('v', val=np.zeros(nn), desc='velocity', units='m/s')
self.add_input('g', val=9.80665 * np.ones(nn), desc='grav. acceleration', units='m/s/s')
self.add_input('theta', val=np.ones(nn), desc='angle of wire', units='rad')

self.add_output('xdot', val=np.zeros(nn), desc='velocity component in x', units='m/s')

self.add_output('ydot', val=np.zeros(nn), desc='velocity component in y', units='m/s')

self.add_output('vdot', val=np.zeros(nn), desc='acceleration magnitude', units='m/s**2')

self.add_output('check', val=np.zeros(nn), desc='check solution: v/sin(theta) = constant',
units='m/s')

self.declare_coloring(wrt='*', method='cs')

def compute(self, inputs, outputs):
theta = inputs['theta']
cos_theta = np.cos(theta)
sin_theta = np.sin(theta)
g = inputs['g']
v = inputs['v']

outputs['vdot'] = g * cos_theta
outputs['xdot'] = v * sin_theta
outputs['ydot'] = -v * cos_theta
outputs['check'] = v / sin_theta

p = om.Problem(model=om.Group())

p.driver = om.ScipyOptimizeDriver()

t = dm.Radau(num_segments=10, order=3)

traj = dm.Trajectory()
phase = dm.Phase(ode_class=_BrachODE, transcription=t)
p.model.add_subsystem('traj0', traj)
traj.add_phase('phase0', phase)

phase.set_time_options(fix_initial=True, duration_bounds=(.5, 10))

phase.add_state('x', fix_initial=True, fix_final=False, rate_source='xdot')
phase.add_state('y', fix_initial=True, fix_final=False, rate_source='ydot')

# Intentionally incorrect rate source to trigger an error during configure.
phase.add_state('v', fix_initial=True, fix_final=False, rate_source='vel_dot')

phase.add_control('theta',
continuity=True, rate_continuity=True,
units='deg', lower=0.01, upper=179.9)

phase.add_parameter('g', units='m/s**2')

phase.add_boundary_constraint('x', loc='final', equals=10)
phase.add_boundary_constraint('y', loc='final', equals=5)

# Minimize time at the end of the phase
phase.add_objective('time_phase', loc='final', scaler=10)

with self.assertRaises(RuntimeError) as ctx:
p.setup()

expected = 'Error during configure_states_introspection in phase traj0.phases.phase0.'
self.assertEqual(str(ctx.exception), expected)
10 changes: 7 additions & 3 deletions dymos/phase/phase.py
Original file line number Diff line number Diff line change
Expand Up @@ -1549,9 +1549,13 @@ def configure(self):
raise ValueError(f'Invalid parameter in phase `{self.pathname}`.\n{str(e)}') from e

self.configure_state_discovery()
configure_states_introspection(self.state_options, self.time_options, self.control_options,
self.parameter_options, self.polynomial_control_options,
ode)

try:
configure_states_introspection(self.state_options, self.time_options, self.control_options,
self.parameter_options, self.polynomial_control_options,
ode)
except RuntimeError as val_err:
raise RuntimeError(f'Error during configure_states_introspection in phase {self.pathname}.') from val_err

transcription.configure_time(self)
transcription.configure_controls(self)
Expand Down
2 changes: 1 addition & 1 deletion dymos/trajectory/test/test_trajectory.py
Original file line number Diff line number Diff line change
Expand Up @@ -1153,7 +1153,7 @@ def test_invalid_linkage_variable(self):
# Finish Problem Setup
p.model.linear_solver = om.DirectSolver()

with self.assertRaises(ValueError) as e:
with self.assertRaises(RuntimeError) as e:
p.setup(check=True)

self.assertEqual(str(e.exception), 'Error in linking bar from burn1 to bar in burn2: '
Expand Down
6 changes: 3 additions & 3 deletions dymos/trajectory/trajectory.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,9 +544,9 @@ def _update_linkage_options_configure(self, linkage_options):
shapes[i], units[i] = get_source_metadata(phases[i]._get_subsystem(rhs_source),
vars[i], user_units=units[i],
user_shape=_unspecified)
except ValueError:
raise ValueError(f'{info_str}: Unable to find variable \'{vars[i]}\' in '
f'phase \'{phases[i].pathname}\' or its ODE.')
except RuntimeError as e:
raise RuntimeError(f'{info_str}: Unable to find variable \'{vars[i]}\' in '
f'phase \'{phases[i].pathname}\' or its ODE.')

linkage_options._src_a = sources['a']
linkage_options._src_b = sources['b']
Expand Down
4 changes: 2 additions & 2 deletions dymos/utils/introspection.py
Original file line number Diff line number Diff line change
Expand Up @@ -951,7 +951,7 @@ def _get_targets_metadata(ode, name, user_targets=_unspecified, user_units=_unsp

def get_source_metadata(ode, src, user_units, user_shape):
"""
Return the targets of a state variable in a given ODE system.
Return the units and shape of output src in the given ODE.
If the targets of the state is _unspecified, and the state name is a top level input name
in the ODE, then the state values are automatically connected to that top-level input.
Expand Down Expand Up @@ -989,7 +989,7 @@ def get_source_metadata(ode, src, user_units, user_shape):
ode_outputs = ode if isinstance(ode, dict) else get_promoted_vars(ode, iotypes='output')

if src not in ode_outputs:
raise ValueError(f'Unable to find the source {src} in the ODE at {ode.pathname}.')
raise RuntimeError(f'Unable to find the source {src} in the ODE.')

if user_units in {None, _unspecified}:
units = ode_outputs[src]['units']
Expand Down

0 comments on commit e795dcd

Please sign in to comment.