From b0ac9b93cb853f1aa3d4d51991e0eafd5d3a1e68 Mon Sep 17 00:00:00 2001 From: OnnoEbbens Date: Fri, 7 Jul 2023 12:36:44 +0200 Subject: [PATCH] fix for #199 --- nlmod/gwf/gwf.py | 4 ++-- nlmod/util.py | 10 +++++++++- tests/test_003_mfpackages.py | 10 ++++++---- 3 files changed, 17 insertions(+), 7 deletions(-) diff --git a/nlmod/gwf/gwf.py b/nlmod/gwf/gwf.py index 80cce721..7765a595 100644 --- a/nlmod/gwf/gwf.py +++ b/nlmod/gwf/gwf.py @@ -318,8 +318,8 @@ def npf( gwf, pname=pname, icelltype=icelltype, - k=k.data, - k33=k33.data, + k=k, + k33=k33, save_flows=save_flows, **kwargs, ) diff --git a/nlmod/util.py b/nlmod/util.py index 9189f329..18e6885b 100644 --- a/nlmod/util.py +++ b/nlmod/util.py @@ -536,7 +536,7 @@ def _get_value_from_ds_attr(ds, varname, attr=None, value=None, warn=True): return value -def _get_value_from_ds_datavar(ds, varname, datavar=None, warn=True): +def _get_value_from_ds_datavar(ds, varname, datavar=None, warn=True, return_da=False): """Internal function to get value from dataset data variables. Parameters @@ -551,6 +551,9 @@ def _get_value_from_ds_datavar(ds, varname, datavar=None, warn=True): the same as varname. If not passed as string, it is treated as data warn : bool, optional log warning if value not found + return_da : bool, optional + if True a dataarray can be returned, if False a dataarray is always + converted to a numpy array before being returned. The default is False. Returns ------- @@ -597,4 +600,9 @@ def _get_value_from_ds_datavar(ds, varname, datavar=None, warn=True): f"to function or check whether 'ds.{datavar}' was set correctly." ) logger.warning(msg) + + if not return_da: + if isinstance(value, xr.DataArray): + value = value.values + return value diff --git a/tests/test_003_mfpackages.py b/tests/test_003_mfpackages.py index 89654be0..14115a9c 100644 --- a/tests/test_003_mfpackages.py +++ b/tests/test_003_mfpackages.py @@ -102,20 +102,22 @@ def get_value_from_ds_datavar(): ds["test_var"] = ("layer", "y", "x"), np.arange(np.product(shape)).reshape(shape) # get value from ds - v0 = nlmod.util._get_value_from_ds_datavar(ds, "test_var", "test_var") + v0 = nlmod.util._get_value_from_ds_datavar( + ds, "test_var", "test_var", return_da=True + ) xr.testing.assert_equal(ds["test_var"], v0) # get value from ds, variable and stored name are different v1 = nlmod.util._get_value_from_ds_datavar(ds, "test", "test_var") - xr.testing.assert_equal(ds["test_var"], v1) + xr.testing.assert_equal(ds["test_var"].values, v1) # do not get value from ds, value is Data Array, should log info msg - v2 = nlmod.util._get_value_from_ds_datavar(ds, "test", v0) + v2 = nlmod.util._get_value_from_ds_datavar(ds, "test", v0, return_da=True) xr.testing.assert_equal(ds["test_var"], v2) # do not get value from ds, value is Data Array, no msg v0.name = "test2" - v3 = nlmod.util._get_value_from_ds_datavar(ds, "test", v0) + v3 = nlmod.util._get_value_from_ds_datavar(ds, "test", v0, return_da=True) assert (v0 == v3).all() # return None, value is str but not in dataset, should log warning