Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed an error when an invalid rate source is provided for a state. #761

Merged
merged 4 commits into from
Jun 15, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -847,3 +847,82 @@ def test_gl(self):

def test_radau(self):
self._test_transcription(transcription=dm.Radau)


@use_tempdirs
class TestInvalidStateRateSource(unittest.TestCase):

def test_brach_invalid_state_rate_source(self):

class _BrachODE(om.ExplicitComponent):

def initialize(self):
self.options.declare('num_nodes', types=int)

def setup(self):
nn = self.options['num_nodes']

# Inputs
self.add_input('v', val=np.zeros(nn), desc='velocity', units='m/s')
self.add_input('g', val=9.80665 * np.ones(nn), desc='grav. acceleration', units='m/s/s')
self.add_input('theta', val=np.ones(nn), desc='angle of wire', units='rad')

self.add_output('xdot', val=np.zeros(nn), desc='velocity component in x', units='m/s')

self.add_output('ydot', val=np.zeros(nn), desc='velocity component in y', units='m/s')

self.add_output('vdot', val=np.zeros(nn), desc='acceleration magnitude', units='m/s**2')

self.add_output('check', val=np.zeros(nn), desc='check solution: v/sin(theta) = constant',
units='m/s')

self.declare_coloring(wrt='*', method='cs')

def compute(self, inputs, outputs):
theta = inputs['theta']
cos_theta = np.cos(theta)
sin_theta = np.sin(theta)
g = inputs['g']
v = inputs['v']

outputs['vdot'] = g * cos_theta
outputs['xdot'] = v * sin_theta
outputs['ydot'] = -v * cos_theta
outputs['check'] = v / sin_theta

p = om.Problem(model=om.Group())

p.driver = om.ScipyOptimizeDriver()

t = dm.Radau(num_segments=10, order=3)

traj = dm.Trajectory()
phase = dm.Phase(ode_class=_BrachODE, transcription=t)
p.model.add_subsystem('traj0', traj)
traj.add_phase('phase0', phase)

phase.set_time_options(fix_initial=True, duration_bounds=(.5, 10))

phase.add_state('x', fix_initial=True, fix_final=False, rate_source='xdot')
phase.add_state('y', fix_initial=True, fix_final=False, rate_source='ydot')

# Intentionally incorrect rate source to trigger an error during configure.
phase.add_state('v', fix_initial=True, fix_final=False, rate_source='vel_dot')

phase.add_control('theta',
continuity=True, rate_continuity=True,
units='deg', lower=0.01, upper=179.9)

phase.add_parameter('g', units='m/s**2')

phase.add_boundary_constraint('x', loc='final', equals=10)
phase.add_boundary_constraint('y', loc='final', equals=5)

# Minimize time at the end of the phase
phase.add_objective('time_phase', loc='final', scaler=10)

with self.assertRaises(RuntimeError) as ctx:
p.setup()

expected = 'Error during configure_states_introspection in phase traj0.phases.phase0.'
self.assertEqual(str(ctx.exception), expected)
10 changes: 7 additions & 3 deletions dymos/phase/phase.py
Original file line number Diff line number Diff line change
Expand Up @@ -1549,9 +1549,13 @@ def configure(self):
raise ValueError(f'Invalid parameter in phase `{self.pathname}`.\n{str(e)}') from e

self.configure_state_discovery()
configure_states_introspection(self.state_options, self.time_options, self.control_options,
self.parameter_options, self.polynomial_control_options,
ode)

try:
configure_states_introspection(self.state_options, self.time_options, self.control_options,
self.parameter_options, self.polynomial_control_options,
ode)
except RuntimeError as val_err:
raise RuntimeError(f'Error during configure_states_introspection in phase {self.pathname}.') from val_err

transcription.configure_time(self)
transcription.configure_controls(self)
Expand Down
2 changes: 1 addition & 1 deletion dymos/trajectory/test/test_trajectory.py
Original file line number Diff line number Diff line change
Expand Up @@ -1153,7 +1153,7 @@ def test_invalid_linkage_variable(self):
# Finish Problem Setup
p.model.linear_solver = om.DirectSolver()

with self.assertRaises(ValueError) as e:
with self.assertRaises(RuntimeError) as e:
p.setup(check=True)

self.assertEqual(str(e.exception), 'Error in linking bar from burn1 to bar in burn2: '
Expand Down
6 changes: 3 additions & 3 deletions dymos/trajectory/trajectory.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,9 +544,9 @@ def _update_linkage_options_configure(self, linkage_options):
shapes[i], units[i] = get_source_metadata(phases[i]._get_subsystem(rhs_source),
vars[i], user_units=units[i],
user_shape=_unspecified)
except ValueError:
raise ValueError(f'{info_str}: Unable to find variable \'{vars[i]}\' in '
f'phase \'{phases[i].pathname}\' or its ODE.')
except RuntimeError as e:
raise RuntimeError(f'{info_str}: Unable to find variable \'{vars[i]}\' in '
f'phase \'{phases[i].pathname}\' or its ODE.')

linkage_options._src_a = sources['a']
linkage_options._src_b = sources['b']
Expand Down
4 changes: 2 additions & 2 deletions dymos/utils/introspection.py
Original file line number Diff line number Diff line change
Expand Up @@ -951,7 +951,7 @@ def _get_targets_metadata(ode, name, user_targets=_unspecified, user_units=_unsp

def get_source_metadata(ode, src, user_units, user_shape):
"""
Return the targets of a state variable in a given ODE system.
Return the units and shape of output src in the given ODE.

If the targets of the state is _unspecified, and the state name is a top level input name
in the ODE, then the state values are automatically connected to that top-level input.
Expand Down Expand Up @@ -989,7 +989,7 @@ def get_source_metadata(ode, src, user_units, user_shape):
ode_outputs = ode if isinstance(ode, dict) else get_promoted_vars(ode, iotypes='output')

if src not in ode_outputs:
raise ValueError(f'Unable to find the source {src} in the ODE at {ode.pathname}.')
raise RuntimeError(f'Unable to find the source {src} in the ODE.')

if user_units in {None, _unspecified}:
units = ode_outputs[src]['units']
Expand Down