Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding conda environment.yml to manage dependencies versions #11

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,15 @@ The model is taken from the P2D model outlined in the [paper](http://web.mit.edu

## Getting started

### pip method (may not work on recent linux versions)

* `pip install numpy scipy` (or get these from your OS)
* `pip install jax[cpu]`
* `pip install scikit-umfpack`

Then run `examples/main_reorder.py`.
### conda method (tested on linux)

* install conda according to your system : <https://www.anaconda.com/download>
* use the environment.yml to create a suitable conda environment : `conda env create -f environment.yml`

Then run `examples/main_reorder.py` or `examples/reorder_test.py`
4 changes: 2 additions & 2 deletions decoupled/p2d_main_fast_fn.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
video_name = 'video.avi'
from utils.precompute_c import precompute
from model.p2d_param import get_battery_sections

from functools import partial



Expand Down Expand Up @@ -71,7 +71,7 @@ def form_c2_n_jit(temp, j, T):
val=vmap(fn,(0,None,0),1)(j,temp,Deff_vec)
return val

@jax.partial(jax.jit, static_argnums=(2, 3,))
@partial(jax.jit, static_argnums=(2, 3,))
def combine_c(cII, cI_vec, M,N):
return np.reshape(cII, [M * (N + 2)], order="F") + cI_vec

Expand Down
16 changes: 16 additions & 0 deletions environment.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
name: p2d_fast_solver-env
channels:
- defaults
- conda-forge
dependencies:
- python=3.10
- jax=0.3.0
- scikit-umfpack=0.3.3
- gcc=12.1.0
- suitesparse=5.10.1
- numpy=1.22.3
- matplotlib=3.7.1
- numba=0.57.1
## Optionals
# - nose=1.3.7
# - jupyter=1.0.0
2 changes: 1 addition & 1 deletion examples/reorder_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -753,7 +753,7 @@ def compute_der(U, Uold, cs_pe1, cs_ne1):

diff0 = abs(Joutput - Jab)
diff0_matrix = np.zeros_like(Joutput)
diff0_matrix[Jab.nonzero()] = diff0[Jab.nonzero()]/abs(Jab[Jab.nonzero()])
diff0_matrix.at[Jab.nonzero()].set(diff0[Jab.nonzero()]/abs(Jab[Jab.nonzero()]))
plt.figure()
plt.imshow(diff0_matrix);
plt.colorbar()
Expand Down