diff --git a/src/pyrovelocity/tests/analysis/test_analyze.py b/src/pyrovelocity/tests/analysis/test_analyze.py index 1a75e30d9..c28fb28a1 100644 --- a/src/pyrovelocity/tests/analysis/test_analyze.py +++ b/src/pyrovelocity/tests/analysis/test_analyze.py @@ -5,6 +5,7 @@ from pyrovelocity.analysis.analyze import ( pareto_frontier_genes, + top_mae_genes, ) @@ -69,3 +70,48 @@ def test_pareto_frontier_genes_fewer_genes(sample_volcano_data): sample_volcano_data, num_genes=10, max_iters=1000 ) assert len(result) == 5 + + +def test_top_mae_genes_basic(sample_volcano_data): + result = top_mae_genes(sample_volcano_data, num_genes=3) + assert isinstance(result, List) + assert len(result) == 3 + assert all(gene in sample_volcano_data.index for gene in result) + + +def test_top_mae_genes_order(sample_volcano_data): + result = top_mae_genes(sample_volcano_data, num_genes=5) + assert result == ["Gene5", "Gene4", "Gene3", "Gene2", "Gene1"] + + +def test_top_mae_genes_ribosomal_filtering(sample_volcano_data_with_ribosomal): + result = top_mae_genes(sample_volcano_data_with_ribosomal, num_genes=4) + assert "Rpl1" not in result + assert "Rps2" not in result + assert len(result) == 4 + + +def test_top_mae_genes_fewer_genes(sample_volcano_data): + result = top_mae_genes(sample_volcano_data, num_genes=10) + assert len(result) == 5 + + +def test_top_mae_genes_time_correlation_sorting(sample_volcano_data): + result = top_mae_genes(sample_volcano_data, num_genes=3) + time_correlations = [ + sample_volcano_data.loc[gene, "time_correlation"] for gene in result + ] + assert time_correlations == sorted(time_correlations, reverse=True) + + +@pytest.mark.parametrize("func", [pareto_frontier_genes, top_mae_genes]) +def test_gene_selection_functions(func, sample_volcano_data): + result = func(sample_volcano_data, num_genes=3) + assert isinstance(result, List) + assert len(result) == 3 + assert all(gene in sample_volcano_data.index for gene in result) + assert result == sorted( + result, + key=lambda x: sample_volcano_data.loc[x, "time_correlation"], + reverse=True, + )