Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable pickling of models in "python" format #1298

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions pybamm/expression_tree/operations/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,6 +485,7 @@ def __init__(self, symbol):
python_str = python_str + "\nself._evaluate = evaluate"

self._python_str = python_str
self._result_var = result_var
self._symbol = symbol

# compile and run the generated python code,
Expand All @@ -507,6 +508,23 @@ def evaluate(self, t=None, y=None, y_dot=None, inputs=None, known_evals=None):
else:
return result

def __getstate__(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks great @tlestang, do you think we can use the same trick for the Jax evaluator?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I naively tried the same trick with the Jax evaluator but couldn't get it to work as smoothly.
Something in jax.jit can't be pickle, but I don't know what.

AttributeError: Can't pickle local object 'jit.<locals>.f_jitted'

I tried not pickling attributes jit_evaluate and jac_evaluate in the same way, i.e.

def __getstate__(self):
    state = self.__dict__.copy()
    del state["_evaluate_jax"]
    del state["_jit_evaluate"]
    del state["_jac_evaluate"]
    return state

but still get the AttributeError..

# Control the state of instances of EvaluatorPython
# before pickling. Method "_evaluate" cannot be pickled.
# See https://github.com/pybamm-team/PyBaMM/issues/1283
state = self.__dict__.copy()
del state["_evaluate"]
return state

def __setstate__(self, state):
# Restore pickled attributes and
# compile code from "python_str"
# Execution of bytecode (re)adds attribute
# "_method"
self.__dict__.update(state)
compiled_function = compile(self._python_str, self._result_var, "exec")
exec(compiled_function)
tlestang marked this conversation as resolved.
Show resolved Hide resolved


class EvaluatorJax:
"""
Expand Down
9 changes: 9 additions & 0 deletions pybamm/solvers/base_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -589,6 +589,13 @@ def solve(
for inputs in inputs_list
]

# Cannot use multiprocessing with model in "jax" format
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that Jax is designed to support parallel solution of models with different inputs, so this should be doable. But it can be left for a PR later on

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay good, would be better to rely on Jax's native behavior anyway :)

if(len(inputs_list) > 1) and model.convert_to_format == "jax":
raise pybamm.SolverError(
"Cannot solve list of inputs with multiprocessing "
"when model in format \"jax\"."
)

# Set up
timer = pybamm.Timer()

Expand Down Expand Up @@ -731,6 +738,8 @@ def solve(
ext_and_inputs_list,
),
)
p.close()
p.join()
# Setting the solve time for each segment.
# pybamm.Solution.append assumes attribute
# solve_time.
Expand Down
89 changes: 52 additions & 37 deletions tests/unit/test_solvers/test_scipy_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,10 +200,9 @@ def test_model_solver_with_inputs(self):
var = pybamm.Variable("var", domain=domain)
model.rhs = {var: -pybamm.InputParameter("rate") * var}
model.initial_conditions = {var: 1}
model.events = [pybamm.Event("var=0.5", pybamm.min(var - 0.5))]
# No need to set parameters; can use base discretisation (no spatial
# operators)

model.events = [pybamm.Event("var=0.5", pybamm.min(var - 0.5))]
# create discretisation
mesh = get_mesh_for_testing()
spatial_methods = {"macroscale": pybamm.FiniteVolume()}
Expand All @@ -218,17 +217,42 @@ def test_model_solver_with_inputs(self):
np.testing.assert_allclose(solution.y[0], np.exp(-0.1 * solution.t))

def test_model_solver_multiple_inputs_happy_path(self):
for convert_to_format in ["python", "casadi"]:
# Create model
model = pybamm.BaseModel()
model.convert_to_format = convert_to_format
domain = ["negative electrode", "separator", "positive electrode"]
var = pybamm.Variable("var", domain=domain)
model.rhs = {var: -pybamm.InputParameter("rate") * var}
model.initial_conditions = {var: 1}
# create discretisation
mesh = get_mesh_for_testing()
spatial_methods = {"macroscale": pybamm.FiniteVolume()}
disc = pybamm.Discretisation(mesh, spatial_methods)
disc.process_model(model)

solver = pybamm.ScipySolver(rtol=1e-8, atol=1e-8, method="RK45")
t_eval = np.linspace(0, 10, 100)
ninputs = 8
inputs_list = [{"rate": 0.01 * (i + 1)} for i in range(ninputs)]

solutions = solver.solve(model, t_eval, inputs=inputs_list, nproc=2)
for i in range(ninputs):
with self.subTest(i=i):
solution = solutions[i]
np.testing.assert_array_equal(solution.t, t_eval)
np.testing.assert_allclose(
solution.y[0], np.exp(-0.01 * (i + 1) * solution.t)
)

def test_model_solver_multiple_inputs_discontinuity_error(self):
# Create model
model = pybamm.BaseModel()
# Covert to casadi instead of python to avoid pickling of
# "EvaluatorPython" objects.
model.convert_to_format = "casadi"
domain = ["negative electrode", "separator", "positive electrode"]
var = pybamm.Variable("var", domain=domain)
model.rhs = {var: -pybamm.InputParameter("rate") * var}
model.initial_conditions = {var: 1}
# No need to set parameters; can use base discretisation (no spatial
# operators)
# create discretisation
mesh = get_mesh_for_testing()
spatial_methods = {"macroscale": pybamm.FiniteVolume()}
Expand All @@ -240,27 +264,30 @@ def test_model_solver_multiple_inputs_happy_path(self):
ninputs = 8
inputs_list = [{"rate": 0.01 * (i + 1)} for i in range(ninputs)]

solutions = solver.solve(model, t_eval, inputs=inputs_list, nproc=2)
for i in range(ninputs):
with self.subTest(i=i):
solution = solutions[i]
np.testing.assert_array_equal(solution.t, t_eval)
np.testing.assert_allclose(
solution.y[0], np.exp(-0.01 * (i + 1) * solution.t)
)
model.events = [
pybamm.Event(
"discontinuity",
pybamm.Scalar(t_eval[-1] / 2),
event_type=pybamm.EventType.DISCONTINUITY,
)
]
with self.assertRaisesRegex(
pybamm.SolverError,
(
"Cannot solve for a list of input parameters"
" sets with discontinuities"
),
):
solver.solve(model, t_eval, inputs=inputs_list, nproc=2)

def test_model_solver_multiple_inputs_discontinuity_error(self):
def test_model_solver_multiple_inputs_initial_conditions_error(self):
# Create model
model = pybamm.BaseModel()
# Covert to casadi instead of python to avoid pickling of
# "EvaluatorPython" objects.
model.convert_to_format = "casadi"
domain = ["negative electrode", "separator", "positive electrode"]
var = pybamm.Variable("var", domain=domain)
model.rhs = {var: -pybamm.InputParameter("rate") * var}
model.initial_conditions = {var: 1}
# No need to set parameters; can use base discretisation (no spatial
# operators)
model.initial_conditions = {var: 2 * pybamm.InputParameter("rate")}
# create discretisation
mesh = get_mesh_for_testing()
spatial_methods = {"macroscale": pybamm.FiniteVolume()}
Expand All @@ -272,28 +299,16 @@ def test_model_solver_multiple_inputs_discontinuity_error(self):
ninputs = 8
inputs_list = [{"rate": 0.01 * (i + 1)} for i in range(ninputs)]

model.events = [
pybamm.Event(
"discontinuity",
pybamm.Scalar(t_eval[-1] / 2),
event_type=pybamm.EventType.DISCONTINUITY,
)
]
with self.assertRaisesRegex(
pybamm.SolverError,
(
"Cannot solve for a list of input parameters"
" sets with discontinuities"
),
("Input parameters cannot appear in expression " "for initial conditions."),
):
solver.solve(model, t_eval, inputs=inputs_list, nproc=2)

def test_model_solver_multiple_inputs_initial_conditions_error(self):
def test_model_solver_multiple_inputs_jax_format_error(self):
# Create model
model = pybamm.BaseModel()
# Covert to casadi instead of python to avoid pickling of
# "EvaluatorPython" objects.
model.convert_to_format = "casadi"
model.convert_to_format = "jax"
domain = ["negative electrode", "separator", "positive electrode"]
var = pybamm.Variable("var", domain=domain)
model.rhs = {var: -pybamm.InputParameter("rate") * var}
Expand All @@ -314,8 +329,8 @@ def test_model_solver_multiple_inputs_initial_conditions_error(self):
with self.assertRaisesRegex(
pybamm.SolverError,
(
"Input parameters cannot appear in expression "
"for initial conditions."
"Cannot solve list of inputs with multiprocessing "
'when model in format "jax".'
),
):
solver.solve(model, t_eval, inputs=inputs_list, nproc=2)
Expand Down
10 changes: 9 additions & 1 deletion tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,11 @@ deps =
scikits.odes
commands =
coverage run run-tests.py --nosub
# Some tests make use of multiple processes through
# multiprocessing. Coverage data is then generated for each
# process separately and data must then be combined into one
# single coverage data file.
coverage combine
coverage xml

[testenv:docs]
Expand Down Expand Up @@ -114,4 +119,7 @@ ignore=
W605,

[coverage:run]
source = pybamm
source = pybamm
# By default coverage data isn't collected in forked processes, see
# https://coverage.readthedocs.io/en/coverage-5.3.1/subprocess.html
concurrency = multiprocessing