-
-
Notifications
You must be signed in to change notification settings - Fork 543
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Enable pickling of models in "python" format #1298
Changes from all commits
dab096d
35b26dd
35db29d
d3e132a
43da47a
7729e87
e5ae0a1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -589,6 +589,13 @@ def solve( | |
for inputs in inputs_list | ||
] | ||
|
||
# Cannot use multiprocessing with model in "jax" format | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Note that Jax is designed to support parallel solution of models with different inputs, so this should be doable. But it can be left for a PR later on There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Okay good, would be better to rely on Jax's native behavior anyway :) |
||
if(len(inputs_list) > 1) and model.convert_to_format == "jax": | ||
raise pybamm.SolverError( | ||
"Cannot solve list of inputs with multiprocessing " | ||
"when model in format \"jax\"." | ||
) | ||
|
||
# Set up | ||
timer = pybamm.Timer() | ||
|
||
|
@@ -731,6 +738,8 @@ def solve( | |
ext_and_inputs_list, | ||
), | ||
) | ||
p.close() | ||
p.join() | ||
# Setting the solve time for each segment. | ||
# pybamm.Solution.append assumes attribute | ||
# solve_time. | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks great @tlestang, do you think we can use the same trick for the Jax evaluator?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I naively tried the same trick with the Jax evaluator but couldn't get it to work as smoothly.
Something in
jax.jit
can't be pickle, but I don't know what.I tried not pickling attributes
jit_evaluate
andjac_evaluate
in the same way, i.e.but still get the
AttributeError
..