Skip to content

Commit

Permalink
update type hints and self describe docs
Browse files Browse the repository at this point in the history
  • Loading branch information
scivision committed Jun 5, 2024
1 parent 6e5608f commit b764060
Show file tree
Hide file tree
Showing 16 changed files with 216 additions and 64 deletions.
3 changes: 1 addition & 2 deletions example/plot_Efield.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,12 @@
from pathlib import Path
import argparse

import xarray
import matplotlib as mpl

import gemini3d.read as read


def plotVmaxx1it(ax: mpl.axes.Axes, V: xarray.DataArray) -> None:
def plotVmaxx1it(ax, V) -> None:
ax.set_title("Vmaxx1it: Potential")
if V.ndim == 1:
ax.plot(dat["mlat"], V)
Expand Down
12 changes: 10 additions & 2 deletions example/plotdata.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,21 @@
#!/usr/bin/env python3
from argparse import ArgumentParser

import xarray
from matplotlib.pyplot import figure, show

import gemini3d.read as read


def plotplasma(dat: xarray.Dataset):
def plotplasma(dat):
"""
Parameters
----------
dat: xarray.Dataset
plasma state data to plot
"""
for k, v in dat.items():
if k == "t":
t = v
Expand Down
12 changes: 8 additions & 4 deletions src/gemini3d/compare/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,14 @@
from datetime import datetime
import typing

import xarray
import matplotlib as mpl

from .. import read


def plotdiff(
A: xarray.DataArray,
B: xarray.DataArray,
A,
B,
time: datetime,
new_dir: Path,
ref_dir: Path,
Expand All @@ -24,6 +23,11 @@ def plotdiff(
Parameters
----------
A: xarray.DataArray
new data
B: xarray.DataArray
reference data
imax: int, optional
index of maximum difference
"""
Expand Down Expand Up @@ -63,7 +67,7 @@ def plotdiff(
plotdiff(A[i], B[i], time, new_dir, ref_dir, name=f"{name}-{i}")
elif is3d:
# pick x2 and x3 slice at maximum difference
im: dict[str, xarray.DataArray] = abs(A - B).argmax(dim=...) # type: ignore
im = abs(A - B).argmax(dim=...)
ix2: int = im["x2"].data
ix3: int = im["x3"].data
plotdiff(
Expand Down
21 changes: 17 additions & 4 deletions src/gemini3d/compare/utils.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,26 @@
from __future__ import annotations
import json

import xarray

from ..utils import get_pkg_file


def err_pct(a: xarray.DataArray, b: xarray.DataArray) -> float:
"""compute maximum error percent"""
def err_pct(a, b) -> float:
"""compute maximum error percent
Parameters
----------
a: xarray.DataArray
new data
b: xarray.DataArray
reference data
Returns
-------
float
maximum error percent
"""

return (abs(a - b).max() / abs(b).max()).item() * 100

Expand Down
4 changes: 1 addition & 3 deletions src/gemini3d/conductivity.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,9 +149,7 @@ def collisions3D(atmos, Ts, ns, vsx1, ms):
T = Ts[:, :, :, lsp - 1]
nusn[:, :, :, lsp - 1, 0] = 8.9e-11 * (1 + 5.7e-4 * T) * (T**0.5) * nO * 1e-6
nusn[:, :, :, lsp - 1, 1] = 2.33e-11 * (1 - 1.21e-4 * T) * (T) * nN2 * 1e-6
nusn[:, :, :, lsp - 1, 2] = (
1.82e-10 * (1 + 3.6e-2 * (T**0.5)) * (T**0.5) * nO2 * 1e-6
)
nusn[:, :, :, lsp - 1, 2] = 1.82e-10 * (1 + 3.6e-2 * (T**0.5)) * (T**0.5) * nO2 * 1e-6
nusn[:, :, :, lsp - 1, 3] = 4.5e-9 * (1 - 1.35e-4 * T) * (T**0.5) * nH * 1e-6
nus[:, :, :, lsp - 1] = np.sum(nusn[:, :, :, lsp - 1, :], axis=3)

Expand Down
10 changes: 7 additions & 3 deletions src/gemini3d/efield/Efield_erf.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,26 @@
import logging
import typing as T

import xarray
import numpy as np
from scipy.special import erf


def Efield_erf(
E: xarray.Dataset,
E,
xg: dict[str, T.Any],
lx1: int,
lx2: int,
lx3: int,
gridflag: int,
flagdip: bool,
) -> xarray.Dataset:
):
"""
synthesize a feature
Parameters
----------
E: xarray.Dataset
"""

if E.Etarg > 1:
Expand Down
18 changes: 16 additions & 2 deletions src/gemini3d/efield/Jcurrent_gaussian.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,22 @@
import xarray
import numpy as np


def Jcurrent_gaussian(E: xarray.Dataset, gridflag: int, flagdip: bool) -> xarray.Dataset:
def Jcurrent_gaussian(E, gridflag: int, flagdip: bool):
"""
Parameters
----------
E: xarray.Dataset
Current density
Returns
-------
E: xarray.Dataset
Current density with boundary conditions
"""
if E.mlon.size > 1:
shapelon = np.exp(-((E.mlon - E.mlonmean) ** 2) / 2 / E.mlonsig**2)
else:
Expand Down
11 changes: 10 additions & 1 deletion src/gemini3d/efield/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,16 @@ def Esigma(pwidth: float, pmax: float, pmin: float, px) -> tuple[float, T.Any]:
return wsig, xsig


def check_finite(v: xarray.DataArray, name: str | None = None):
def check_finite(v, name: str | None = None):
"""
checks that the input is finite
Parameters
----------
v: xarray.DataArray
input data to check
"""
i = np.logical_not(np.isfinite(v))

if i.any():
Expand Down
37 changes: 32 additions & 5 deletions src/gemini3d/hdf5/write.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,13 @@

import h5py
import numpy as np
import xarray

from ..utils import datetime2stem, to_datetime

CLVL = 3 # GZIP compression level: larger => better compression, slower to write


def state(fn: Path, dat: xarray.Dataset) -> None:
def state(fn: Path, dat) -> None:
"""
write STATE VARIABLE initial conditions
Expand All @@ -27,6 +26,14 @@ def state(fn: Path, dat: xarray.Dataset) -> None:
INPUT ARRAYS SHOULD BE TRIMMED TO THE CORRECT SIZE
I.E. THEY SHOULD NOT INCLUDE GHOST CELLS
Parameters
----------
fn: pathlib.Path
output filename
dat: xarray.Dataset
data to write
"""

logging.info(f"state: {fn}")
Expand All @@ -42,7 +49,7 @@ def state(fn: Path, dat: xarray.Dataset) -> None:
_write_var(f, "/Phiall", dat["Phitop"])


def _write_var(fid, name: str, A: xarray.DataArray) -> None:
def _write_var(fid, name: str, A) -> None:
"""
NOTE: The .transpose() reverses the dimension order.
The HDF Group never implemented the intended H5T_array_create(..., perm)
Expand All @@ -51,6 +58,12 @@ def _write_var(fid, name: str, A: xarray.DataArray) -> None:
Fortran, including the HDF Group Fortran interfaces and h5fortran as well as
Matlab read/write HDF5 in Fortran order. h5py read/write HDF5 in C order so we
need the .transpose() for h5py
Parameters
----------
A: xarray.DataArray
variable to write to HDF5
"""

p4s = ("species", "x3", "x2", "x1")
Expand Down Expand Up @@ -212,9 +225,15 @@ def grid(size_fn: Path, grid_fn: Path, xg: dict[str, T.Any]) -> None:
h["/glatctr"] = xg["glatctr"]


def Efield(outdir: Path, E: xarray.Dataset) -> None:
def Efield(outdir: Path, E) -> None:
"""
write Efield to disk
Parameters
----------
E: xarray.Dataset
Electric field
"""

with h5py.File(outdir / "simsize.h5", "w") as f:
Expand Down Expand Up @@ -248,7 +267,15 @@ def Efield(outdir: Path, E: xarray.Dataset) -> None:
f[f"/{k}"] = E[k].loc[time].astype(np.float32)


def precip(outdir: Path, P: xarray.Dataset) -> None:
def precip(outdir: Path, P) -> None:
"""
Parameters
----------
P: xarray.Dataset
precipitation data
"""
with h5py.File(outdir / "simsize.h5", "w") as f:
f.create_dataset("/llon", data=P.mlon.size, dtype=np.int32)
f.create_dataset("/llat", data=P.mlat.size, dtype=np.int32)
Expand Down
19 changes: 11 additions & 8 deletions src/gemini3d/particles/gaussian2d.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,23 @@
import numpy as np
import xarray


def gaussian2d(pg: xarray.Dataset, Qpeak: float, Qbackground: float):
def gaussian2d(pg, Qpeak: float, Qbackground: float):
"""
Parameters
----------
pg: xarray.Dataset
parameters of gaussian
"""
mlon_mean = pg.mlon.mean().item()
mlat_mean = pg.mlat.mean().item()

if "mlon_sigma" in pg.attrs and "mlat_sigma" in pg.attrs:
Q = (
Qpeak
* np.exp(
-((pg.mlon.data[:, None] - mlon_mean) ** 2) / (2 * pg.mlon_sigma**2)
)
* np.exp(
-((pg.mlat.data[None, :] - mlat_mean) ** 2) / (2 * pg.mlat_sigma**2)
)
* np.exp(-((pg.mlon.data[:, None] - mlon_mean) ** 2) / (2 * pg.mlon_sigma**2))
* np.exp(-((pg.mlat.data[None, :] - mlat_mean) ** 2) / (2 * pg.mlat_sigma**2))
)
elif "mlon_sigma" in pg.attrs:
Q = Qpeak * np.exp(
Expand Down
47 changes: 40 additions & 7 deletions src/gemini3d/plasma.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,7 @@ def equilibrium_resample(p: dict[str, T.Any], xg: dict[str, T.Any]):
write.state(p["indat_file"], dat_interp)


def model_resample(
xgin: dict[str, T.Any], dat: xarray.Dataset, xg: dict[str, T.Any]
) -> xarray.Dataset:
def model_resample(xgin: dict[str, T.Any], dat, xg: dict[str, T.Any]):
"""resample a grid
usually used to upsample an equilibrium simulation grid
Expand Down Expand Up @@ -198,7 +196,14 @@ def model_resample(
return dat_interp


def check_density(n: xarray.DataArray):
def check_density(n):
"""
Parameters
----------
n: xarray.DataArray
number density
"""
if not np.isfinite(n).all():
raise ValueError("non-finite density")
if (n < 0).any():
Expand All @@ -207,14 +212,28 @@ def check_density(n: xarray.DataArray):
raise ValueError("too small maximum density")


def check_drift(v: xarray.DataArray):
def check_drift(v):
"""
Parameters
----------
v: xarray.DataArray
velocity
"""
if not np.isfinite(v).all():
raise ValueError("non-finite drift")
if (abs(v) > 10e3).any():
raise ValueError("excessive drift velocity")


def check_temperature(Ts: xarray.DataArray):
def check_temperature(Ts):
"""
Parameters
----------
Ts: xarray.DataArray
temperature
"""
if not np.isfinite(Ts).all():
raise ValueError("non-finite temperature")
if (Ts < 0).any():
Expand All @@ -223,12 +242,26 @@ def check_temperature(Ts: xarray.DataArray):
raise ValueError("too cold maximum temperature")


def equilibrium_state(p: dict[str, T.Any], xg: dict[str, T.Any]) -> xarray.Dataset:
def equilibrium_state(p: dict[str, T.Any], xg: dict[str, T.Any]):
"""
generate (arbitrary) initial conditions for a grid.
NOTE: only works on symmmetric closed grids!
[f107a, f107, ap] = activ
Parameters
----------
p: dict
simulation parameters
xg: dict
simulation grid
Returns
-------
dat: xarray.Dataset
initial conditions of equilibrium state
"""

# %% MAKE UP SOME INITIAL CONDITIONS FOR FORTRAN CODE
Expand Down
Loading

0 comments on commit b764060

Please sign in to comment.