Skip to content

Commit

Permalink
test(analyze): add tests for top_mae_genes function
Browse files Browse the repository at this point in the history
Signed-off-by: Cameron Smith <[email protected]>
  • Loading branch information
cameronraysmith committed Aug 6, 2024
1 parent 7a894ba commit 0d5a48a
Showing 1 changed file with 46 additions and 0 deletions.
46 changes: 46 additions & 0 deletions src/pyrovelocity/tests/analysis/test_analyze.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from pyrovelocity.analysis.analyze import (
pareto_frontier_genes,
top_mae_genes,
)


Expand Down Expand Up @@ -69,3 +70,48 @@ def test_pareto_frontier_genes_fewer_genes(sample_volcano_data):
sample_volcano_data, num_genes=10, max_iters=1000
)
assert len(result) == 5


def test_top_mae_genes_basic(sample_volcano_data):
result = top_mae_genes(sample_volcano_data, num_genes=3)
assert isinstance(result, List)
assert len(result) == 3
assert all(gene in sample_volcano_data.index for gene in result)


def test_top_mae_genes_order(sample_volcano_data):
result = top_mae_genes(sample_volcano_data, num_genes=5)
assert result == ["Gene5", "Gene4", "Gene3", "Gene2", "Gene1"]


def test_top_mae_genes_ribosomal_filtering(sample_volcano_data_with_ribosomal):
result = top_mae_genes(sample_volcano_data_with_ribosomal, num_genes=4)
assert "Rpl1" not in result
assert "Rps2" not in result
assert len(result) == 4


def test_top_mae_genes_fewer_genes(sample_volcano_data):
result = top_mae_genes(sample_volcano_data, num_genes=10)
assert len(result) == 5


def test_top_mae_genes_time_correlation_sorting(sample_volcano_data):
result = top_mae_genes(sample_volcano_data, num_genes=3)
time_correlations = [
sample_volcano_data.loc[gene, "time_correlation"] for gene in result
]
assert time_correlations == sorted(time_correlations, reverse=True)


@pytest.mark.parametrize("func", [pareto_frontier_genes, top_mae_genes])
def test_gene_selection_functions(func, sample_volcano_data):
result = func(sample_volcano_data, num_genes=3)
assert isinstance(result, List)
assert len(result) == 3
assert all(gene in sample_volcano_data.index for gene in result)
assert result == sorted(
result,
key=lambda x: sample_volcano_data.loc[x, "time_correlation"],
reverse=True,
)

0 comments on commit 0d5a48a

Please sign in to comment.