Skip to content

Commit

Permalink
Cohort coordinates
Browse files: browse the repository at this point in the history
  • Loading branch information
tomwhite committed Sep 21, 2020
1 parent 364c52e commit 606b870
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 11 deletions.
32 changes: 26 additions & 6 deletions sgkit/stats/popgen.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import itertools
from typing import Hashable

import dask.array as da
import numpy as np
from xarray import Dataset
from xarray import DataArray, Dataset

from sgkit.stats.utils import assert_array_shape
from sgkit.utils import merge_datasets

from .aggregation import count_cohort_alleles, count_variant_alleles
Expand Down Expand Up @@ -71,11 +73,29 @@ def divergence(
ac = ds[allele_counts]
an = ac.sum(axis=2)

n_pairs = np.prod(an, axis=1).compute()
n_same = np.prod(ac, axis=1).sum(axis=1).compute()
n_diff = n_pairs - n_same
div = n_diff / n_pairs
new_ds = Dataset({"stat_divergence": div.sum()})
n_variants = ds.dims["variants"]
n_alleles = ds.dims["alleles"]
n_cohorts = ds.dims["cohorts"]
result = np.full([n_cohorts, n_cohorts], np.nan)

# Iterate over cohort pairs
for i, j in itertools.combinations(range(n_cohorts), 2):
an_cohort_pair = an[:, [i, j]]
assert_array_shape(an_cohort_pair, n_variants, 2)
ac_cohort_pair = ac[:, [i, j], :]
assert_array_shape(ac_cohort_pair, n_variants, 2, n_alleles)

n_pairs = np.prod(an_cohort_pair, axis=1).compute()
n_same = np.prod(ac_cohort_pair, axis=1).sum(axis=1).compute()

n_diff = n_pairs - n_same
div = n_diff / n_pairs
div_sum = div.sum().compute() # TODO: avoid this compute

result[i, j] = div_sum

arr = DataArray(result, dims=["cohorts_a", "cohorts_b"])
new_ds = Dataset({"stat_divergence": arr})
return merge_datasets(ds, new_ds) if merge else new_ds


Expand Down
19 changes: 14 additions & 5 deletions sgkit/tests/test_popgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,11 @@ def test_diversity(size):
ds = ts_to_dataset(ts) # type: ignore[no-untyped-call]
sample_cohorts = np.full_like(ts.samples(), 0)
ds["sample_cohort"] = xr.DataArray(sample_cohorts, dims="samples")
div = diversity(ds)["stat_diversity"].compute()
ds = ds.assign_coords({"cohorts": ["co_0"]})
ds = diversity(ds)
div = ds["stat_diversity"].sel(cohorts="co_0").values
ts_div = ts.diversity(span_normalise=False)
np.testing.assert_allclose(div[0], ts_div)
np.testing.assert_allclose(div, ts_div)


@pytest.mark.parametrize("size", [2, 3, 10, 100])
Expand All @@ -55,7 +57,10 @@ def test_divergence(size):
(np.full_like(subset_1, 0), np.full_like(subset_2, 1))
)
ds["sample_cohort"] = xr.DataArray(sample_cohorts, dims="samples")
div = divergence(ds)["stat_divergence"].compute()
cohort_names = ["co_0", "co_1"]
ds = ds.assign_coords({"cohorts_a": cohort_names, "cohorts_b": cohort_names})
ds = divergence(ds)
div = ds["stat_divergence"].sel(cohorts_a="co_0", cohorts_b="co_1").values
ts_div = ts.divergence([subset_1, subset_2], span_normalise=False)
np.testing.assert_allclose(div, ts_div)

Expand All @@ -70,7 +75,10 @@ def test_Fst(size):
(np.full_like(subset_1, 0), np.full_like(subset_2, 1))
)
ds["sample_cohort"] = xr.DataArray(sample_cohorts, dims="samples")
fst = Fst(ds)["stat_Fst"].compute()
cohort_names = ["co_0", "co_1"]
ds = ds.assign_coords({"cohorts_a": cohort_names, "cohorts_b": cohort_names})
ds = Fst(ds)
fst = ds["stat_Fst"].sel(cohorts_a="co_0", cohorts_b="co_1").values
ts_fst = ts.Fst([subset_1, subset_2])
np.testing.assert_allclose(fst, ts_fst)

Expand All @@ -81,6 +89,7 @@ def test_Tajimas_D(size):
ds = ts_to_dataset(ts) # type: ignore[no-untyped-call]
sample_cohorts = np.full_like(ts.samples(), 0)
ds["sample_cohort"] = xr.DataArray(sample_cohorts, dims="samples")
ds = Tajimas_D(ds)
d = ds["stat_Tajimas_D"].compute()
ts_d = ts.Tajimas_D()
d = Tajimas_D(ds)["stat_Tajimas_D"].compute()
np.testing.assert_allclose(d, ts_d)

0 comments on commit 606b870

Please sign in to comment.