Skip to content

Commit

Permalink
Improve the code logic
Browse files Browse the repository at this point in the history
  • Loading branch information
bilalebi committed Jul 17, 2024
1 parent 2b7d980 commit 97c1532
Show file tree
Hide file tree
Showing 2 changed files with 126 additions and 59 deletions.
24 changes: 24 additions & 0 deletions common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
"""

import logging
from typing import List

from graphql import GraphQLResolveInfo

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -69,3 +72,24 @@ def get_ensembl_metadata_api_version():
version = line.strip().split("@")[-1]
break
return version


def check_requested_fields(info: GraphQLResolveInfo, fields: List[str]) -> List[bool]:
    """
    Check if specific fields are requested in the GraphQL query.

    The sub-selections of every node of the field being resolved are scanned.
    Inline fragments and named fragment spreads are expanded, so a field
    requested through a fragment is still detected. A field without any
    sub-selection yields ``False`` for every name instead of raising.

    Args:
        info (GraphQLResolveInfo): The GraphQL resolve information containing query details.
        fields (List[str]): A list of field names to check for in the query.

    Returns:
        List[bool]: One boolean per entry of ``fields``, in the same order,
        indicating whether that field is present in the query. Note that the
        result is always a list, even for a single field name, so it must be
        unpacked (or indexed) before being used in a condition.

    Usage example:
        fields_to_check = ["assembly", "dataset"]
        is_assembly_present, is_dataset_present = check_requested_fields(info, fields_to_check)
    """
    # Named fragment definitions of the document, keyed by fragment name.
    fragments = getattr(info, "fragments", None) or {}
    requested_fields = set()
    # Guards against scanning the same named fragment more than once.
    visited_fragments = set()

    # Selection sets still to scan, seeded with every node of the resolved field.
    pending = [node.selection_set for node in info.field_nodes]
    while pending:
        selection_set = pending.pop()
        if selection_set is None:
            # The field was requested without any sub-selection.
            continue
        for selection in selection_set.selections:
            # Nodes that do not expose a ``kind`` are treated as plain fields.
            kind = getattr(selection, "kind", "field")
            if kind == "inline_fragment":
                # Inline fragments carry no name, only a nested selection set.
                pending.append(selection.selection_set)
            elif kind == "fragment_spread":
                # Here ``name`` is the fragment's name, not a field name.
                fragment_name = selection.name.value
                if fragment_name not in visited_fragments:
                    visited_fragments.add(fragment_name)
                    fragment = fragments.get(fragment_name)
                    if fragment is not None:
                        pending.append(fragment.selection_set)
            else:
                requested_fields.add(selection.name.value)

    return [field in requested_fields for field in fields]
161 changes: 102 additions & 59 deletions graphql_service/resolver/gene_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from graphql import GraphQLResolveInfo, GraphQLError
from pymongo.database import Database, Collection

from common import utils
from graphql_service.resolver.data_loaders import BatchLoaders

from graphql_service.resolver.exceptions import (
Expand Down Expand Up @@ -687,68 +688,30 @@ async def resolve_region(_, info: GraphQLResolveInfo, by_name: Dict[str, str]) -
return result


def fetch_genome_and_combine(
info: GraphQLResolveInfo, grpc_model: GRPC_MODEL, by_keyword: Dict, key: str
@QUERY_TYPE.field("genomes")
def resolve_genomes(
_, info: GraphQLResolveInfo, by_keyword: Dict[str, str] = None
) -> List:
"""
Fetches genomes by a specific keyword and combines genome data with assembly data if requested.
Resolve the genomes based on provided keyword arguments.
Under the hood, this resolver might execute and combine 3 different queries based on the requested data:
- The default `get_genome_by_specific_keyword()` gRPC call (Metadata DB)
- If `assembly` is requested, `fetch_assembly_data()` is triggered fetching data from Mongo DB
- If `dataset` is requested, `fetch_dataset_data()` is triggered which triggers `get_datasets_list_by_uuid()`
gRPC call to fetch dataset info (Metadata DB)
Args:
info (GraphQLResolveInfo): The GraphQL resolver information containing the field nodes and other query details.
grpc_model: The gRPC model to fetch genome data.
by_keyword (dict): Dictionary containing the keyword to search genomes by.
key (str): The specific key to use for fetching genomes.
info (GraphQLResolveInfo): GraphQL resolve information containing query details.
by_keyword (Dict[str, str]): Dictionary containing keyword arguments for fetching genomes.
Returns:
List: A list of combined genome and assembly data objects. If assembly data is not requested, only genome data is included.
List: A list of genomes matching the provided keyword.
Raises:
GenomeNotFoundError: If no genomes are found for the given keyword.
MissingArgumentException: If 'by_keyword' argument is not provided.
GraphQLError: If not exactly one field in 'by_keyword' is provided.
GenomeNotFoundError: If no genomes are found matching the provided keyword.
"""
# Fetch genomes data from metadata using gRPC
result = grpc_model.get_genome_by_specific_keyword(
**{key: by_keyword.get(key)},
release_version=by_keyword.get("release_version"),
)
genomes = list(result)
if not genomes:
raise GenomeNotFoundError(by_keyword)

requested_fields = [
field.name.value for field in info.field_nodes[0].selection_set.selections
]
# Check if the assembly and/or dataset fields are requested in the query
is_assembly_present = "assembly" in requested_fields
is_dataset_present = "dataset" in requested_fields

combined_results = []
for genome in genomes:
set_db_conn_for_uuid(info, genome.genome_uuid)
connection_db = get_db_conn(info)
# logging.debug("Collections in the database:", connection_db.list_collection_names())
assembly_collection = connection_db["assembly"]
# logging.debug("assembly_collection.name:", assembly_collection.name)

assembly_data = None
dataset_data = None
if is_assembly_present:
assembly_data = fetch_assembly_data(
assembly_collection, genome.assembly.name
)
if is_dataset_present:
dataset_data = fetch_dataset_data(grpc_model, genome.genome_uuid)
combined_results.append(
create_genome_response(genome, assembly_data, dataset_data)
)
return combined_results


@QUERY_TYPE.field("genomes")
def resolve_genomes(
_, info: GraphQLResolveInfo, by_keyword: Optional[Dict[str, str]] = None
) -> List:

# ask them to provide at least one argument
if not by_keyword:
raise MissingArgumentException("You must provide 'by_keyword' argument.")

Expand All @@ -770,8 +733,48 @@ def resolve_genomes(
"scientific_parlance_name",
"species_taxonomy_id",
]:
# if one of the keys is provided
if by_keyword.get(key):
return fetch_genome_and_combine(info, grpc_model, by_keyword, key)
# Fetch genomes data from metadata using gRPC
result = grpc_model.get_genome_by_specific_keyword(
**{key: by_keyword.get(key)},
release_version=by_keyword.get("release_version"),
)
genomes = list(result)

if not genomes:
raise GenomeNotFoundError(by_keyword)

# Check if the assembly and dataset fields are requested in the query
fields_to_check = ["assembly", "dataset"]
is_assembly_present, is_dataset_present = utils.check_requested_fields(
info, fields_to_check
)

combined_results = []
for genome in genomes:
set_db_conn_for_uuid(info, genome.genome_uuid)
connection_db = get_db_conn(info)
# logging.debug("Collections in the database:", connection_db.list_collection_names())
assembly_collection = connection_db["assembly"]
# logging.debug("assembly_collection.name:", assembly_collection.name)

assembly_data = (
fetch_assembly_data(assembly_collection, genome.assembly.name)
if is_assembly_present
else None
)
dataset_data = (
fetch_dataset_data(grpc_model, genome.genome_uuid)
if is_dataset_present
else None
)

combined_results.append(
create_genome_response(genome, dataset_data, assembly_data)
)

return combined_results

return []

Expand All @@ -787,20 +790,36 @@ def resolve_genome(_, info: GraphQLResolveInfo, by_genome_uuid: Dict[str, str])
if not genome.genome_uuid:
raise GenomeNotFoundError(by_genome_uuid)

# fetch dataset info
dataset_data = fetch_dataset_data(grpc_model, genome.genome_uuid)
genomes = create_genome_response(
genome=genome, assembly_data=None, dataset_data=dataset_data
# Check if the dataset field is requested in the query.
# check_requested_fields() returns one boolean per requested name, so the
# single result has to be unpacked: binding the list itself (e.g. [False])
# would always be truthy and the dataset gRPC call would never be skipped.
fields_to_check = ["dataset"]
(is_dataset_present,) = utils.check_requested_fields(info, fields_to_check)

dataset_data = (
    fetch_dataset_data(grpc_model, genome.genome_uuid)
    if is_dataset_present
    else None
)

genomes = create_genome_response(genome=genome, dataset_data=dataset_data)
return genomes


def create_genome_response(
genome: Genome,
assembly_data: Optional[Dict[str, Any]] = None,
dataset_data: Optional[List] = None,
assembly_data: Optional[Dict[str, Any]] = None,
) -> Dict:
"""
Create a response dictionary for a genome with optional assembly and dataset data.
Args:
genome (Genome): The genome object containing genome-related information.
assembly_data (Optional[Dict[str, Any]]): Optional dictionary containing assembly data.
dataset_data (Optional[List]): Optional list of dataset objects containing dataset information.
Returns:
Dict: A dictionary containing the genome response data.
"""
datasets_response = []
if dataset_data:
for dataset in dataset_data:
Expand Down Expand Up @@ -836,6 +855,20 @@ def create_genome_response(


def fetch_assembly_data(assembly_collection: Collection, assembly_id: str) -> Mapping:
"""
Fetch assembly data from a collection using the assembly ID.
Args:
assembly_collection (Collection): The collection to search for the assembly data.
assembly_id (str): The ID of the assembly to fetch.
Returns:
Mapping: The assembly data if found.
Raises:
CollectionNotFoundError: If there is an issue accessing the collection.
AssemblyNotFoundError: If the assembly with the given ID is not found.
"""
query = {"assembly_id": assembly_id}
try:
assembly = assembly_collection.find_one(query)
Expand All @@ -851,6 +884,16 @@ def fetch_assembly_data(assembly_collection: Collection, assembly_id: str) -> Ma


def fetch_dataset_data(grpc_model: GRPC_MODEL, genome_uuid: str) -> List:
    """
    Retrieve the datasets attached to a genome through the gRPC model.

    Args:
        grpc_model (GRPC_MODEL): The gRPC model queried via ``get_datasets_list_by_uuid``.
        genome_uuid (str): The UUID of the genome whose datasets are wanted.

    Returns:
        List: The ``datasets`` of the gRPC response, materialised as a list.
    """
    response = grpc_model.get_datasets_list_by_uuid(genome_uuid)
    # Copy the response's dataset container into a plain Python list.
    return [*response.datasets]
Expand Down

0 comments on commit 97c1532

Please sign in to comment.