Skip to content

Commit

Permalink
Add more tests for update() and concat()
Browse files Browse the repository at this point in the history
  • Loading branch information
gtca committed Aug 8, 2024
1 parent c1feb2a commit feb3d15
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 1 deletion.
2 changes: 1 addition & 1 deletion tests/test_merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def mdata():

@pytest.mark.usefixtures("filepath_h5mu", "filepath_zarr")
class TestMuData:
def test_merge(self, mdata, filepath_h5mu):
def test_merge(self, mdata):
mdata1, mdata2 = mdata[:N1, :].copy(), mdata[N1:, :].copy()
mdata_ = mudata.concat([mdata1, mdata2])
assert list(mdata_.mod.keys()) == ["mod1", "mod2"]
Expand Down
27 changes: 27 additions & 0 deletions tests/test_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,33 @@ def test_update_after_obs_reordered(self, mdata):
[all(true_obsm_values[i] == test_obsm_values[i]) for i in range(len(true_obsm_values))]
)

@pytest.mark.parametrize("obs_mod", ["unique"])
@pytest.mark.parametrize("obs_across", ["intersecting"])
@pytest.mark.parametrize("obs_n", ["joint", "disjoint"])
def test_update_intersecting_var_names_after_filtering(self, mdata):
orig_shape = mdata.shape
mdata.mod["mod1"].var_names = [str(i) for i in range(mdata["mod1"].n_vars)]
mdata.mod["mod2"].var_names = [str(i) for i in range(mdata["mod2"].n_vars)]
mdata.update()
mdata.mod["mod1"] = mdata["mod1"][:, :5].copy()
mdata["mod1"].var["true"] = True
mdata["mod2"].var["false"] = False
assert mdata["mod1"].n_vars == 5
mdata.update()
mdata.pull_var(prefix_unique=False)
assert mdata.n_obs == orig_shape[0]
assert mdata.n_vars == mdata["mod1"].n_vars + mdata["mod2"].n_vars
assert mdata.var["true"].sum() == 5
assert (~mdata.var["false"]).sum() == (~mdata["mod2"].var["false"]).sum()

@pytest.mark.parametrize("obs_mod", ["unique"])
@pytest.mark.parametrize("obs_across", ["intersecting"])
@pytest.mark.parametrize("obs_n", ["joint", "disjoint"])
def test_update_to_new_names(self, mdata):
mdata["mod1"].var_names = [f"_mod1_var{i}" for i in range(1, mdata["mod1"].n_vars + 1)]
mdata["mod2"].var_names = [f"_mod2_var{i}" for i in range(1, mdata["mod2"].n_vars + 1)]
mdata.update()


# @pytest.mark.usefixtures("filepath_h5mu")
# class TestMuDataSameVars:
Expand Down

0 comments on commit feb3d15

Please sign in to comment.