diff --git a/pybamm/expression_tree/operations/evaluate.py b/pybamm/expression_tree/operations/evaluate.py index a18010ae71..08b40eb605 100644 --- a/pybamm/expression_tree/operations/evaluate.py +++ b/pybamm/expression_tree/operations/evaluate.py @@ -485,6 +485,7 @@ def __init__(self, symbol): python_str = python_str + "\nself._evaluate = evaluate" self._python_str = python_str + self._result_var = result_var self._symbol = symbol # compile and run the generated python code, @@ -507,6 +508,23 @@ def evaluate(self, t=None, y=None, y_dot=None, inputs=None, known_evals=None): else: return result + def __getstate__(self): + # Control the state of instances of EvaluatorPython + # before pickling. Method "_evaluate" cannot be pickled. + # See https://github.com/pybamm-team/PyBaMM/issues/1283 + state = self.__dict__.copy() + del state["_evaluate"] + return state + + def __setstate__(self, state): + # Restore pickled attributes and + # compile code from "python_str" + # Execution of bytecode (re)adds attribute + # "_method" + self.__dict__.update(state) + compiled_function = compile(self._python_str, self._result_var, "exec") + exec(compiled_function) + class EvaluatorJax: """ diff --git a/pybamm/solvers/base_solver.py b/pybamm/solvers/base_solver.py index 927c3569c9..bb3418922e 100644 --- a/pybamm/solvers/base_solver.py +++ b/pybamm/solvers/base_solver.py @@ -589,6 +589,13 @@ def solve( for inputs in inputs_list ] + # Cannot use multiprocessing with model in "jax" format + if(len(inputs_list) > 1) and model.convert_to_format == "jax": + raise pybamm.SolverError( + "Cannot solve list of inputs with multiprocessing " + "when model in format \"jax\"." + ) + # Set up timer = pybamm.Timer() @@ -731,6 +738,8 @@ def solve( ext_and_inputs_list, ), ) + p.close() + p.join() # Setting the solve time for each segment. # pybamm.Solution.append assumes attribute # solve_time. diff --git a/tests/unit/test_solvers/test_scipy_solver.py b/tests/unit/test_solvers/test_scipy_solver.py index 4f9d2a0a47..b3851c6f0f 100644 --- a/tests/unit/test_solvers/test_scipy_solver.py +++ b/tests/unit/test_solvers/test_scipy_solver.py @@ -200,10 +200,9 @@ def test_model_solver_with_inputs(self): var = pybamm.Variable("var", domain=domain) model.rhs = {var: -pybamm.InputParameter("rate") * var} model.initial_conditions = {var: 1} - model.events = [pybamm.Event("var=0.5", pybamm.min(var - 0.5))] # No need to set parameters; can use base discretisation (no spatial # operators) - + model.events = [pybamm.Event("var=0.5", pybamm.min(var - 0.5))] # create discretisation mesh = get_mesh_for_testing() spatial_methods = {"macroscale": pybamm.FiniteVolume()} @@ -218,17 +217,42 @@ def test_model_solver_with_inputs(self): np.testing.assert_allclose(solution.y[0], np.exp(-0.1 * solution.t)) def test_model_solver_multiple_inputs_happy_path(self): + for convert_to_format in ["python", "casadi"]: + # Create model + model = pybamm.BaseModel() + model.convert_to_format = convert_to_format + domain = ["negative electrode", "separator", "positive electrode"] + var = pybamm.Variable("var", domain=domain) + model.rhs = {var: -pybamm.InputParameter("rate") * var} + model.initial_conditions = {var: 1} + # create discretisation + mesh = get_mesh_for_testing() + spatial_methods = {"macroscale": pybamm.FiniteVolume()} + disc = pybamm.Discretisation(mesh, spatial_methods) + disc.process_model(model) + + solver = pybamm.ScipySolver(rtol=1e-8, atol=1e-8, method="RK45") + t_eval = np.linspace(0, 10, 100) + ninputs = 8 + inputs_list = [{"rate": 0.01 * (i + 1)} for i in range(ninputs)] + + solutions = solver.solve(model, t_eval, inputs=inputs_list, nproc=2) + for i in range(ninputs): + with self.subTest(i=i): + solution = solutions[i] + np.testing.assert_array_equal(solution.t, t_eval) + np.testing.assert_allclose( + solution.y[0], np.exp(-0.01 * (i + 1) * solution.t) + ) + + def test_model_solver_multiple_inputs_discontinuity_error(self): # Create model model = pybamm.BaseModel() - # Covert to casadi instead of python to avoid pickling of - # "EvaluatorPython" objects. model.convert_to_format = "casadi" domain = ["negative electrode", "separator", "positive electrode"] var = pybamm.Variable("var", domain=domain) model.rhs = {var: -pybamm.InputParameter("rate") * var} model.initial_conditions = {var: 1} - # No need to set parameters; can use base discretisation (no spatial - # operators) # create discretisation mesh = get_mesh_for_testing() spatial_methods = {"macroscale": pybamm.FiniteVolume()} @@ -240,27 +264,30 @@ def test_model_solver_multiple_inputs_happy_path(self): ninputs = 8 inputs_list = [{"rate": 0.01 * (i + 1)} for i in range(ninputs)] - solutions = solver.solve(model, t_eval, inputs=inputs_list, nproc=2) - for i in range(ninputs): - with self.subTest(i=i): - solution = solutions[i] - np.testing.assert_array_equal(solution.t, t_eval) - np.testing.assert_allclose( - solution.y[0], np.exp(-0.01 * (i + 1) * solution.t) - ) + model.events = [ + pybamm.Event( + "discontinuity", + pybamm.Scalar(t_eval[-1] / 2), + event_type=pybamm.EventType.DISCONTINUITY, + ) + ] + with self.assertRaisesRegex( + pybamm.SolverError, + ( + "Cannot solve for a list of input parameters" + " sets with discontinuities" + ), + ): + solver.solve(model, t_eval, inputs=inputs_list, nproc=2) - def test_model_solver_multiple_inputs_discontinuity_error(self): + def test_model_solver_multiple_inputs_initial_conditions_error(self): # Create model model = pybamm.BaseModel() - # Covert to casadi instead of python to avoid pickling of - # "EvaluatorPython" objects. model.convert_to_format = "casadi" domain = ["negative electrode", "separator", "positive electrode"] var = pybamm.Variable("var", domain=domain) model.rhs = {var: -pybamm.InputParameter("rate") * var} - model.initial_conditions = {var: 1} - # No need to set parameters; can use base discretisation (no spatial - # operators) + model.initial_conditions = {var: 2 * pybamm.InputParameter("rate")} # create discretisation mesh = get_mesh_for_testing() spatial_methods = {"macroscale": pybamm.FiniteVolume()} @@ -272,28 +299,16 @@ def test_model_solver_multiple_inputs_discontinuity_error(self): ninputs = 8 inputs_list = [{"rate": 0.01 * (i + 1)} for i in range(ninputs)] - model.events = [ - pybamm.Event( - "discontinuity", - pybamm.Scalar(t_eval[-1] / 2), - event_type=pybamm.EventType.DISCONTINUITY, - ) - ] with self.assertRaisesRegex( pybamm.SolverError, - ( - "Cannot solve for a list of input parameters" - " sets with discontinuities" - ), + ("Input parameters cannot appear in expression " "for initial conditions."), ): solver.solve(model, t_eval, inputs=inputs_list, nproc=2) - def test_model_solver_multiple_inputs_initial_conditions_error(self): + def test_model_solver_multiple_inputs_jax_format_error(self): # Create model model = pybamm.BaseModel() - # Covert to casadi instead of python to avoid pickling of - # "EvaluatorPython" objects. - model.convert_to_format = "casadi" + model.convert_to_format = "jax" domain = ["negative electrode", "separator", "positive electrode"] var = pybamm.Variable("var", domain=domain) model.rhs = {var: -pybamm.InputParameter("rate") * var} @@ -314,8 +329,8 @@ def test_model_solver_multiple_inputs_initial_conditions_error(self): with self.assertRaisesRegex( pybamm.SolverError, ( - "Input parameters cannot appear in expression " - "for initial conditions." + "Cannot solve list of inputs with multiprocessing " + 'when model in format "jax".' ), ): solver.solve(model, t_eval, inputs=inputs_list, nproc=2) diff --git a/tox.ini b/tox.ini index 4e90a958c3..b7f094fcee 100644 --- a/tox.ini +++ b/tox.ini @@ -50,6 +50,11 @@ deps = scikits.odes commands = coverage run run-tests.py --nosub + # Some tests make use of multiple processes through + # multiprocessing. Coverage data is then generated for each + # process separately and data must then be combined into one + # single coverage data file. + coverage combine coverage xml [testenv:docs] @@ -114,4 +119,7 @@ ignore= W605, [coverage:run] -source = pybamm \ No newline at end of file +source = pybamm +# By default coverage data isn't collected in forked processes, see +# https://coverage.readthedocs.io/en/coverage-5.3.1/subprocess.html +concurrency = multiprocessing