Skip to content

Commit

Permalink
refactor: data.Phenotypes samples parameter to be of type set (#152)
Browse files Browse the repository at this point in the history
  • Loading branch information
aryarm authored Dec 29, 2022
1 parent 7c84a45 commit 383aac4
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 12 deletions.
4 changes: 2 additions & 2 deletions docs/api/data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -414,7 +414,7 @@ Both the ``load()`` and ``read()`` methods support the ``samples`` parameter tha
.. code-block:: python
phenotypes = data.Phenotypes('tests/data/simple.pheno')
phenotypes.read(samples=["HG00097", "HG00099"])
phenotypes.read(samples={"HG00097", "HG00099"})
Iterating over a file
*********************
Expand All @@ -433,7 +433,7 @@ You'll have to call ``__iter__()`` manually if you want to specify any function
.. code-block:: python
phenotypes = data.Phenotypes('tests/data/simple.pheno')
for line in phenotypes.__iter__(samples=["HG00097", "HG00099"]):
for line in phenotypes.__iter__(samples={"HG00097", "HG00099"}):
print(line)
Writing a file
Expand Down
13 changes: 6 additions & 7 deletions haptools/data/phenotypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def __init__(self, fname: Path | str, log: Logger = None):

@classmethod
def load(
cls: Phenotypes, fname: Path | str, samples: list[str] = None
cls: Phenotypes, fname: Path | str, samples: set[str] = None
) -> Phenotypes:
"""
Load phenotypes from a pheno file
Expand All @@ -53,7 +53,7 @@ def load(
----------
fname
See documentation for :py:attr:`~.Data.fname`
samples : list[str], optional
samples : set[str], optional
See documentation for :py:meth:`~.Data.Phenotypes.read`
Returns
Expand All @@ -66,13 +66,13 @@ def load(
phenotypes.standardize()
return phenotypes

def read(self, samples: list[str] = None):
def read(self, samples: set[str] = None):
"""
Read phenotypes from a pheno file into a numpy matrix stored in :py:attr:`~.Phenotypes.data`
Parameters
----------
samples : list[str], optional
samples : set[str], optional
A subset of the samples from which to extract phenotypes
Defaults to loading phenotypes from all samples
Expand Down Expand Up @@ -128,13 +128,13 @@ def _iterate(
)
phens.close()

def __iter__(self, samples: list[str] = None) -> Iterable[namedtuple]:
def __iter__(self, samples: set[str] = None) -> Iterable[namedtuple]:
"""
Read phenotypes from a pheno file line by line without storing anything
Parameters
----------
samples : list[str], optional
samples : set[str], optional
A subset of the samples from which to extract phenotypes
Defaults to loading phenotypes from all samples
Expand Down Expand Up @@ -162,7 +162,6 @@ def __iter__(self, samples: list[str] = None) -> Iterable[namedtuple]:
" and should be named '#IID' in the header line"
)
self.names = tuple(header[1:])
samples = set(samples) if samples else None
# call another function to force the lines above to be run immediately
# see https://stackoverflow.com/a/36726497
return self._iterate(phens, phen_text, samples)
Expand Down
6 changes: 3 additions & 3 deletions tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,7 +438,7 @@ def test_load_phenotypes_subset(self):

# can we load the data from the phenotype file?
phens = Phenotypes(DATADIR.joinpath("simple.pheno"))
phens.read(samples=samples)
phens.read(samples=set(samples))
np.testing.assert_allclose(phens.data, expected)
assert phens.samples == tuple(samples)

Expand Down Expand Up @@ -503,7 +503,7 @@ def _get_expected_covariates(self):
return expected

def _get_fake_covariates(self):
gts = Phenotypes(fname=None)
gts = Covariates(fname=None)
gts.data = self._get_expected_covariates()
gts.samples = ("HG00096", "HG00097", "HG00099", "HG00100", "HG00101")
gts.names = ("sex", "age")
Expand Down Expand Up @@ -551,7 +551,7 @@ def test_load_covariates_subset(self):

# can we load the data from the covariate file?
covars = Covariates(DATADIR.joinpath("simple.covar"))
covars.read(samples=samples)
covars.read(samples=set(samples))
np.testing.assert_allclose(covars.data, expected)
assert covars.samples == tuple(samples)

Expand Down
1 change: 1 addition & 0 deletions tests/test_simphenotype.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ def test_one_hap_zero_noise_all_same(self):
assert phens.samples == expected.samples
assert phens.names[0] == expected.names[0]

@pytest.mark.filterwarnings("ignore::RuntimeWarning")
def test_one_hap_zero_noise_all_same_nonzero_heritability(self):
gts = self._get_fake_gens()
hps = self._get_fake_haps()
Expand Down

0 comments on commit 383aac4

Please sign in to comment.