Skip to content

Commit

Permalink
fix for #199
Browse files Browse the repository at this point in the history
  • Loading branch information
OnnoEbbens committed Jul 7, 2023
1 parent 6f3ea00 commit b0ac9b9
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 7 deletions.
4 changes: 2 additions & 2 deletions nlmod/gwf/gwf.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,8 +318,8 @@ def npf(
gwf,
pname=pname,
icelltype=icelltype,
k=k.data,
k33=k33.data,
k=k,
k33=k33,
save_flows=save_flows,
**kwargs,
)
Expand Down
10 changes: 9 additions & 1 deletion nlmod/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,7 +536,7 @@ def _get_value_from_ds_attr(ds, varname, attr=None, value=None, warn=True):
return value


def _get_value_from_ds_datavar(ds, varname, datavar=None, warn=True):
def _get_value_from_ds_datavar(ds, varname, datavar=None, warn=True, return_da=False):
"""Internal function to get value from dataset data variables.
Parameters
Expand All @@ -551,6 +551,9 @@ def _get_value_from_ds_datavar(ds, varname, datavar=None, warn=True):
the same as varname. If not passed as string, it is treated as data
warn : bool, optional
log warning if value not found
return_da : bool, optional
if True a dataarray can be returned, if False a dataarray is always
converted to a numpy array before being returned. The default is False.
Returns
-------
Expand Down Expand Up @@ -597,4 +600,9 @@ def _get_value_from_ds_datavar(ds, varname, datavar=None, warn=True):
f"to function or check whether 'ds.{datavar}' was set correctly."
)
logger.warning(msg)

if not return_da:
if isinstance(value, xr.DataArray):
value = value.values

return value
10 changes: 6 additions & 4 deletions tests/test_003_mfpackages.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,20 +102,22 @@ def get_value_from_ds_datavar():
ds["test_var"] = ("layer", "y", "x"), np.arange(np.product(shape)).reshape(shape)

# get value from ds
v0 = nlmod.util._get_value_from_ds_datavar(ds, "test_var", "test_var")
v0 = nlmod.util._get_value_from_ds_datavar(
ds, "test_var", "test_var", return_da=True
)
xr.testing.assert_equal(ds["test_var"], v0)

# get value from ds, variable and stored name are different
v1 = nlmod.util._get_value_from_ds_datavar(ds, "test", "test_var")
xr.testing.assert_equal(ds["test_var"], v1)
xr.testing.assert_equal(ds["test_var"].values, v1)

# do not get value from ds, value is Data Array, should log info msg
v2 = nlmod.util._get_value_from_ds_datavar(ds, "test", v0)
v2 = nlmod.util._get_value_from_ds_datavar(ds, "test", v0, return_da=True)
xr.testing.assert_equal(ds["test_var"], v2)

# do not get value from ds, value is Data Array, no msg
v0.name = "test2"
v3 = nlmod.util._get_value_from_ds_datavar(ds, "test", v0)
v3 = nlmod.util._get_value_from_ds_datavar(ds, "test", v0, return_da=True)
assert (v0 == v3).all()

# return None, value is str but not in dataset, should log warning
Expand Down

0 comments on commit b0ac9b9

Please sign in to comment.