Skip to content

Commit

Permalink
feat(analyze): add function to extract top mae genes
Browse files Browse the repository at this point in the history
- sorted by descending shared time correlation

Signed-off-by: Cameron Smith <[email protected]>
  • Loading branch information
cameronraysmith committed Aug 6, 2024
1 parent 0d5a48a commit 508f1ba
Showing 1 changed file with 43 additions and 0 deletions.
43 changes: 43 additions & 0 deletions src/pyrovelocity/analysis/analyze.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,49 @@ def mae_per_gene(pred_counts: ndarray, true_counts: ndarray) -> ndarray:
return -np.array(error / total)


@beartype
def top_mae_genes(
    volcano_data: pd.DataFrame,
    num_genes: int,
) -> List[str]:
    """
    Identify top N genes based on Mean Absolute Error (MAE), excluding
    ribosomal genes, and sort the final output by time correlation.

    Rows whose index label starts with ``Rpl`` or ``Rps`` (case-insensitive)
    are dropped first. From the remaining rows, the ``num_genes`` rows with
    the largest ``mean_mae`` values are kept and then ordered by descending
    ``time_correlation``.

    Args:
        volcano_data (pd.DataFrame): DataFrame indexed by gene label with
            ``mean_mae`` and ``time_correlation`` columns.
        num_genes (int): Maximum number of genes to return.

    Returns:
        List[str]: Index labels from volcano_data, sorted by descending time
            correlation. Contains fewer than ``num_genes`` entries when not
            enough rows remain after filtering; a warning is logged then.
    """
    # na=False: a missing index label is treated as non-ribosomal, so the
    # mask stays strictly boolean and can be safely negated with `~`.
    is_ribosomal = volcano_data.index.str.contains(
        "^Rpl|^Rps", case=False, na=False
    )
    filtered_data = volcano_data[~is_ribosomal]

    # Select by MAE first, then order only the selected rows by correlation.
    top_genes = filtered_data.nlargest(num_genes, "mean_mae")
    top_genes_sorted = top_genes.sort_values(
        by="time_correlation", ascending=False
    )

    gene_indices = top_genes_sorted.index.tolist()

    logger.info(
        f"\nSelected top {len(gene_indices)} genes based on MAE, sorted by time correlation:\n\n"
        f"  {gene_indices}\n\n"
    )

    if len(gene_indices) < num_genes:
        # Emitted at WARNING level so the shortfall is still visible when
        # INFO-level messages are filtered out.
        logger.warning(
            f"Warning: Only {len(gene_indices)} genes were found,\n"
            f"but {num_genes} were requested.\n"
            "Returning all available genes.\n"
        )

    return gene_indices


@beartype
def pareto_frontier_genes(
volcano_data: pd.DataFrame,
Expand Down

0 comments on commit 508f1ba

Please sign in to comment.