From dab096d31a669a153bce0b8a119023345f4a7482 Mon Sep 17 00:00:00 2001 From: Thibault Lestang Date: Fri, 11 Dec 2020 16:11:43 +0000 Subject: [PATCH 1/7] Add methods getstate and setstate to avoid including method _evaluate to pickle --- pybamm/expression_tree/operations/evaluate.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/pybamm/expression_tree/operations/evaluate.py b/pybamm/expression_tree/operations/evaluate.py index a18010ae71..08b40eb605 100644 --- a/pybamm/expression_tree/operations/evaluate.py +++ b/pybamm/expression_tree/operations/evaluate.py @@ -485,6 +485,7 @@ def __init__(self, symbol): python_str = python_str + "\nself._evaluate = evaluate" self._python_str = python_str + self._result_var = result_var self._symbol = symbol # compile and run the generated python code, @@ -507,6 +508,23 @@ def evaluate(self, t=None, y=None, y_dot=None, inputs=None, known_evals=None): else: return result + def __getstate__(self): + # Control the state of instances of EvaluatorPython + # before pickling. Method "_evaluate" cannot be pickled. + # See https://github.com/pybamm-team/PyBaMM/issues/1283 + state = self.__dict__.copy() + del state["_evaluate"] + return state + + def __setstate__(self, state): + # Restore pickled attributes and + # compile code from "python_str" + # Execution of bytecode (re)adds attribute + # "_method" + self.__dict__.update(state) + compiled_function = compile(self._python_str, self._result_var, "exec") + exec(compiled_function) + class EvaluatorJax: """ From 35b26ddb56a8e9b7ba1c071de8ac09599593ba30 Mon Sep 17 00:00:00 2001 From: Thibault Lestang Date: Mon, 14 Dec 2020 11:47:28 +0000 Subject: [PATCH 2/7] Test solving multiple inputs using multiprocessing with model in both python and casadi mode --- tests/unit/test_solvers/test_scipy_solver.py | 64 ++++++++++---------- 1 file changed, 31 insertions(+), 33 deletions(-) diff --git a/tests/unit/test_solvers/test_scipy_solver.py b/tests/unit/test_solvers/test_scipy_solver.py index 4f9d2a0a47..02dd106274 100644 --- a/tests/unit/test_solvers/test_scipy_solver.py +++ b/tests/unit/test_solvers/test_scipy_solver.py @@ -218,36 +218,37 @@ def test_model_solver_with_inputs(self): np.testing.assert_allclose(solution.y[0], np.exp(-0.1 * solution.t)) def test_model_solver_multiple_inputs_happy_path(self): - # Create model - model = pybamm.BaseModel() - # Covert to casadi instead of python to avoid pickling of - # "EvaluatorPython" objects. - model.convert_to_format = "casadi" - domain = ["negative electrode", "separator", "positive electrode"] - var = pybamm.Variable("var", domain=domain) - model.rhs = {var: -pybamm.InputParameter("rate") * var} - model.initial_conditions = {var: 1} - # No need to set parameters; can use base discretisation (no spatial - # operators) - # create discretisation - mesh = get_mesh_for_testing() - spatial_methods = {"macroscale": pybamm.FiniteVolume()} - disc = pybamm.Discretisation(mesh, spatial_methods) - disc.process_model(model) - - solver = pybamm.ScipySolver(rtol=1e-8, atol=1e-8, method="RK45") - t_eval = np.linspace(0, 10, 100) - ninputs = 8 - inputs_list = [{"rate": 0.01 * (i + 1)} for i in range(ninputs)] + for convert_to_format in ["python", "casadi"]: + # Create model + model = pybamm.BaseModel() + # Covert to casadi instead of python to avoid pickling of + # "EvaluatorPython" objects. + model.convert_to_format = convert_to_format + domain = ["negative electrode", "separator", "positive electrode"] + var = pybamm.Variable("var", domain=domain) + model.rhs = {var: -pybamm.InputParameter("rate") * var} + model.initial_conditions = {var: 1} + # No need to set parameters; can use base discretisation (no spatial + # operators) + # create discretisation + mesh = get_mesh_for_testing() + spatial_methods = {"macroscale": pybamm.FiniteVolume()} + disc = pybamm.Discretisation(mesh, spatial_methods) + disc.process_model(model) - solutions = solver.solve(model, t_eval, inputs=inputs_list, nproc=2) - for i in range(ninputs): - with self.subTest(i=i): - solution = solutions[i] - np.testing.assert_array_equal(solution.t, t_eval) - np.testing.assert_allclose( - solution.y[0], np.exp(-0.01 * (i + 1) * solution.t) - ) + solver = pybamm.ScipySolver(rtol=1e-8, atol=1e-8, method="RK45") + t_eval = np.linspace(0, 10, 100) + ninputs = 8 + inputs_list = [{"rate": 0.01 * (i + 1)} for i in range(ninputs)] + + solutions = solver.solve(model, t_eval, inputs=inputs_list, nproc=2) + for i in range(ninputs): + with self.subTest(i=i): + solution = solutions[i] + np.testing.assert_array_equal(solution.t, t_eval) + np.testing.assert_allclose( + solution.y[0], np.exp(-0.01 * (i + 1) * solution.t) + ) def test_model_solver_multiple_inputs_discontinuity_error(self): # Create model @@ -313,10 +314,7 @@ def test_model_solver_multiple_inputs_initial_conditions_error(self): with self.assertRaisesRegex( pybamm.SolverError, - ( - "Input parameters cannot appear in expression " - "for initial conditions." - ), + ("Input parameters cannot appear in expression " "for initial conditions."), ): solver.solve(model, t_eval, inputs=inputs_list, nproc=2) From 35db29dddc66200fb1d31e98dd9dc88ff42f994c Mon Sep 17 00:00:00 2001 From: Thibault Lestang Date: Mon, 14 Dec 2020 11:59:23 +0000 Subject: [PATCH 3/7] Raise SolverError if attempt to solve for list of inputs with model in jax format. --- pybamm/solvers/base_solver.py | 6 ++++ tests/unit/test_solvers/test_scipy_solver.py | 32 ++++++++++++++++++++ 2 files changed, 38 insertions(+) diff --git a/pybamm/solvers/base_solver.py b/pybamm/solvers/base_solver.py index 927c3569c9..01402ee1fc 100644 --- a/pybamm/solvers/base_solver.py +++ b/pybamm/solvers/base_solver.py @@ -545,6 +545,12 @@ def solve( """ pybamm.logger.info("Start solving {} with {}".format(model.name, self.name)) + # Cannot use multiprocessing with model in "jax" format + if(len(inputs) > 1) and model.convert_to_format == "jax": + raise pybamm.SolverError( + "Cannot solve list of inputs with multiprocessing " + "when model in format \"jax\"." + ) # Make sure model isn't empty if len(model.rhs) == 0 and len(model.algebraic) == 0: if not isinstance(self, pybamm.DummySolver): diff --git a/tests/unit/test_solvers/test_scipy_solver.py b/tests/unit/test_solvers/test_scipy_solver.py index 02dd106274..2af2aa803b 100644 --- a/tests/unit/test_solvers/test_scipy_solver.py +++ b/tests/unit/test_solvers/test_scipy_solver.py @@ -318,6 +318,38 @@ def test_model_solver_multiple_inputs_initial_conditions_error(self): ): solver.solve(model, t_eval, inputs=inputs_list, nproc=2) + def test_model_solver_multiple_inputs_jax_format_error(self): + # Create model + model = pybamm.BaseModel() + # Covert to casadi instead of python to avoid pickling of + # "EvaluatorPython" objects. + model.convert_to_format = "jax" + domain = ["negative electrode", "separator", "positive electrode"] + var = pybamm.Variable("var", domain=domain) + model.rhs = {var: -pybamm.InputParameter("rate") * var} + model.initial_conditions = {var: 2 * pybamm.InputParameter("rate")} + # No need to set parameters; can use base discretisation (no spatial + # operators) + # create discretisation + mesh = get_mesh_for_testing() + spatial_methods = {"macroscale": pybamm.FiniteVolume()} + disc = pybamm.Discretisation(mesh, spatial_methods) + disc.process_model(model) + + solver = pybamm.ScipySolver(rtol=1e-8, atol=1e-8, method="RK45") + t_eval = np.linspace(0, 10, 100) + ninputs = 8 + inputs_list = [{"rate": 0.01 * (i + 1)} for i in range(ninputs)] + + with self.assertRaisesRegex( + pybamm.SolverError, + ( + "Cannot solve list of inputs with multiprocessing " + 'when model in format "jax".' + ), + ): + solver.solve(model, t_eval, inputs=inputs_list, nproc=2) + def test_model_solver_with_external(self): # Create model model = pybamm.BaseModel() From d3e132ae9da1e50dc2b6558c8ecb1a324f293937 Mon Sep 17 00:00:00 2001 From: Thibault Lestang Date: Mon, 14 Dec 2020 15:31:54 +0000 Subject: [PATCH 4/7] relocate check for jax format So that `inputs_list` is defined --- pybamm/solvers/base_solver.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/pybamm/solvers/base_solver.py b/pybamm/solvers/base_solver.py index 01402ee1fc..f38a972eb7 100644 --- a/pybamm/solvers/base_solver.py +++ b/pybamm/solvers/base_solver.py @@ -545,12 +545,6 @@ def solve( """ pybamm.logger.info("Start solving {} with {}".format(model.name, self.name)) - # Cannot use multiprocessing with model in "jax" format - if(len(inputs) > 1) and model.convert_to_format == "jax": - raise pybamm.SolverError( - "Cannot solve list of inputs with multiprocessing " - "when model in format \"jax\"." - ) # Make sure model isn't empty if len(model.rhs) == 0 and len(model.algebraic) == 0: if not isinstance(self, pybamm.DummySolver): @@ -595,6 +589,13 @@ def solve( for inputs in inputs_list ] + # Cannot use multiprocessing with model in "jax" format + if(len(inputs_list) > 1) and model.convert_to_format == "jax": + raise pybamm.SolverError( + "Cannot solve list of inputs with multiprocessing " + "when model in format \"jax\"." + ) + # Set up timer = pybamm.Timer() From 43da47a336511347e1e29e1cc2fc4321eac6aaf3 Mon Sep 17 00:00:00 2001 From: Thibault Lestang Date: Tue, 5 Jan 2021 20:30:29 +0000 Subject: [PATCH 5/7] Add call to close() and join() for coverage to pick up processes forkred by Pool --- pybamm/solvers/base_solver.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pybamm/solvers/base_solver.py b/pybamm/solvers/base_solver.py index f38a972eb7..bb3418922e 100644 --- a/pybamm/solvers/base_solver.py +++ b/pybamm/solvers/base_solver.py @@ -738,6 +738,8 @@ def solve( ext_and_inputs_list, ), ) + p.close() + p.join() # Setting the solve time for each segment. # pybamm.Solution.append assumes attribute # solve_time. From 7729e87d32546a1ce4623580c25de29b89d2bda7 Mon Sep 17 00:00:00 2001 From: Thibault Lestang Date: Wed, 6 Jan 2021 07:38:28 +0000 Subject: [PATCH 6/7] Enable multiprocessing mode for coverage --- tox.ini | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/tox.ini b/tox.ini index 4e90a958c3..b7f094fcee 100644 --- a/tox.ini +++ b/tox.ini @@ -50,6 +50,11 @@ deps = scikits.odes commands = coverage run run-tests.py --nosub + # Some tests make use of multiple processes through + # multiprocessing. Coverage data is then generated for each + # process separately and data must then be combined into one + # single coverage data file. + coverage combine coverage xml [testenv:docs] @@ -114,4 +119,7 @@ ignore= W605, [coverage:run] -source = pybamm \ No newline at end of file +source = pybamm +# By default coverage data isn't collected in forked processes, see +# https://coverage.readthedocs.io/en/coverage-5.3.1/subprocess.html +concurrency = multiprocessing From e5ae0a1c2e7b6e7ac50f33571695fe6fbe48de59 Mon Sep 17 00:00:00 2001 From: Thibault Lestang Date: Wed, 6 Jan 2021 09:34:32 +0000 Subject: [PATCH 7/7] Remove out of date comments --- tests/unit/test_solvers/test_scipy_solver.py | 17 +---------------- 1 file changed, 1 insertion(+), 16 deletions(-) diff --git a/tests/unit/test_solvers/test_scipy_solver.py b/tests/unit/test_solvers/test_scipy_solver.py index 2af2aa803b..b3851c6f0f 100644 --- a/tests/unit/test_solvers/test_scipy_solver.py +++ b/tests/unit/test_solvers/test_scipy_solver.py @@ -200,10 +200,9 @@ def test_model_solver_with_inputs(self): var = pybamm.Variable("var", domain=domain) model.rhs = {var: -pybamm.InputParameter("rate") * var} model.initial_conditions = {var: 1} - model.events = [pybamm.Event("var=0.5", pybamm.min(var - 0.5))] # No need to set parameters; can use base discretisation (no spatial # operators) - + model.events = [pybamm.Event("var=0.5", pybamm.min(var - 0.5))] # create discretisation mesh = get_mesh_for_testing() spatial_methods = {"macroscale": pybamm.FiniteVolume()} @@ -221,15 +220,11 @@ def test_model_solver_multiple_inputs_happy_path(self): for convert_to_format in ["python", "casadi"]: # Create model model = pybamm.BaseModel() - # Covert to casadi instead of python to avoid pickling of - # "EvaluatorPython" objects. model.convert_to_format = convert_to_format domain = ["negative electrode", "separator", "positive electrode"] var = pybamm.Variable("var", domain=domain) model.rhs = {var: -pybamm.InputParameter("rate") * var} model.initial_conditions = {var: 1} - # No need to set parameters; can use base discretisation (no spatial - # operators) # create discretisation mesh = get_mesh_for_testing() spatial_methods = {"macroscale": pybamm.FiniteVolume()} @@ -253,15 +248,11 @@ def test_model_solver_multiple_inputs_happy_path(self): def test_model_solver_multiple_inputs_discontinuity_error(self): # Create model model = pybamm.BaseModel() - # Covert to casadi instead of python to avoid pickling of - # "EvaluatorPython" objects. model.convert_to_format = "casadi" domain = ["negative electrode", "separator", "positive electrode"] var = pybamm.Variable("var", domain=domain) model.rhs = {var: -pybamm.InputParameter("rate") * var} model.initial_conditions = {var: 1} - # No need to set parameters; can use base discretisation (no spatial - # operators) # create discretisation mesh = get_mesh_for_testing() spatial_methods = {"macroscale": pybamm.FiniteVolume()} @@ -292,15 +283,11 @@ def test_model_solver_multiple_inputs_discontinuity_error(self): def test_model_solver_multiple_inputs_initial_conditions_error(self): # Create model model = pybamm.BaseModel() - # Covert to casadi instead of python to avoid pickling of - # "EvaluatorPython" objects. model.convert_to_format = "casadi" domain = ["negative electrode", "separator", "positive electrode"] var = pybamm.Variable("var", domain=domain) model.rhs = {var: -pybamm.InputParameter("rate") * var} model.initial_conditions = {var: 2 * pybamm.InputParameter("rate")} - # No need to set parameters; can use base discretisation (no spatial - # operators) # create discretisation mesh = get_mesh_for_testing() spatial_methods = {"macroscale": pybamm.FiniteVolume()} @@ -321,8 +308,6 @@ def test_model_solver_multiple_inputs_initial_conditions_error(self): def test_model_solver_multiple_inputs_jax_format_error(self): # Create model model = pybamm.BaseModel() - # Covert to casadi instead of python to avoid pickling of - # "EvaluatorPython" objects. model.convert_to_format = "jax" domain = ["negative electrode", "separator", "positive electrode"] var = pybamm.Variable("var", domain=domain)