From 97c15329d375c36ec915aa0199a7316b2365680c Mon Sep 17 00:00:00 2001 From: Bilal Date: Wed, 17 Jul 2024 13:46:44 +0100 Subject: [PATCH] Improve the code logic --- common/utils.py | 24 ++++ graphql_service/resolver/gene_model.py | 161 ++++++++++++++++--------- 2 files changed, 126 insertions(+), 59 deletions(-) diff --git a/common/utils.py b/common/utils.py index b711e43..f3f93e8 100644 --- a/common/utils.py +++ b/common/utils.py @@ -13,6 +13,9 @@ """ import logging +from typing import List + +from graphql import GraphQLResolveInfo logger = logging.getLogger(__name__) @@ -69,3 +72,24 @@ def get_ensembl_metadata_api_version(): version = line.strip().split("@")[-1] break return version + + +def check_requested_fields(info: GraphQLResolveInfo, fields: List[str]) -> List[bool]: + """ + Check if specific fields are requested in the GraphQL query. + + Args: + info (ResolveInfo): The GraphQL resolve information containing query details. + fields (List[str]): A list of field names to check for in the query. + + Returns: + List[bool]: A list of booleans indicating whether each field is present in the query. + + Usage example: + fields_to_check = ["assembly", "dataset"] + is_assembly_present, is_dataset_present = check_requested_fields(info, fields_to_check) + """ + requested_fields = [ + field.name.value for field in info.field_nodes[0].selection_set.selections + ] + return [field in requested_fields for field in fields] diff --git a/graphql_service/resolver/gene_model.py b/graphql_service/resolver/gene_model.py index ad272ff..f4cb5a6 100644 --- a/graphql_service/resolver/gene_model.py +++ b/graphql_service/resolver/gene_model.py @@ -20,6 +20,7 @@ from graphql import GraphQLResolveInfo, GraphQLError from pymongo.database import Database, Collection +from common import utils from graphql_service.resolver.data_loaders import BatchLoaders from graphql_service.resolver.exceptions import ( @@ -687,68 +688,30 @@ async def resolve_region(_, info: GraphQLResolveInfo, by_name: Dict[str, str]) - return result -def fetch_genome_and_combine( - info: GraphQLResolveInfo, grpc_model: GRPC_MODEL, by_keyword: Dict, key: str +@QUERY_TYPE.field("genomes") +def resolve_genomes( + _, info: GraphQLResolveInfo, by_keyword: Dict[str, str] = None ) -> List: """ - Fetches genomes by a specific keyword and combines genome data with assembly data if requested. + Resolve the genomes based on provided keyword arguments. + Under the hood, this resolver might execute and combine 3 different queries based on the requested data: + - The default `get_genome_by_specific_keyword()` gRPC call (Metadata DB) + - If `assembly` is requested, `fetch_assembly_data()` is triggered fetching data from Mongo DB + - If `dataset` is requested, `fetch_dataset_data()` is triggered which triggers `get_datasets_list_by_uuid()` + gRPC call to fetch dataset info (Metadata DB) Args: - info (GraphQLResolveInfo): The GraphQL resolver information containing the field nodes and other query details. - grpc_model: The gRPC model to fetch genome data. - by_keyword (dict): Dictionary containing the keyword to search genomes by. - key (str): The specific key to use for fetching genomes. + info (GraphQLResolveInfo): GraphQL resolve information containing query details. + by_keyword (Dict[str, str]): Dictionary containing keyword arguments for fetching genomes. Returns: - List: A list of combined genome and assembly data objects. If assembly data is not requested, only genome data is included. + List: A list of genomes matching the provided keyword. Raises: - GenomeNotFoundError: If no genomes are found for the given keyword. + MissingArgumentException: If 'by_keyword' argument is not provided. + GraphQLError: If not exactly one field in 'by_keyword' is provided. + GenomeNotFoundError: If no genomes are found matching the provided keyword. """ - # Fetch genomes data from metadata using gRPC - result = grpc_model.get_genome_by_specific_keyword( - **{key: by_keyword.get(key)}, - release_version=by_keyword.get("release_version"), - ) - genomes = list(result) - if not genomes: - raise GenomeNotFoundError(by_keyword) - - requested_fields = [ - field.name.value for field in info.field_nodes[0].selection_set.selections - ] - # Check if the assembly and/or dataset fields are requested in the query - is_assembly_present = "assembly" in requested_fields - is_dataset_present = "dataset" in requested_fields - - combined_results = [] - for genome in genomes: - set_db_conn_for_uuid(info, genome.genome_uuid) - connection_db = get_db_conn(info) - # logging.debug("Collections in the database:", connection_db.list_collection_names()) - assembly_collection = connection_db["assembly"] - # logging.debug("assembly_collection.name:", assembly_collection.name) - - assembly_data = None - dataset_data = None - if is_assembly_present: - assembly_data = fetch_assembly_data( - assembly_collection, genome.assembly.name - ) - if is_dataset_present: - dataset_data = fetch_dataset_data(grpc_model, genome.genome_uuid) - combined_results.append( - create_genome_response(genome, assembly_data, dataset_data) - ) - return combined_results - - -@QUERY_TYPE.field("genomes") -def resolve_genomes( - _, info: GraphQLResolveInfo, by_keyword: Optional[Dict[str, str]] = None -) -> List: - - # ask them to provide at least one argument if not by_keyword: raise MissingArgumentException("You must provide 'by_keyword' argument.") @@ -770,8 +733,48 @@ def resolve_genomes( "scientific_parlance_name", "species_taxonomy_id", ]: + # if one of the keys is provided if by_keyword.get(key): - return fetch_genome_and_combine(info, grpc_model, by_keyword, key) + # Fetch genomes data from metadata using gRPC + result = grpc_model.get_genome_by_specific_keyword( + **{key: by_keyword.get(key)}, + release_version=by_keyword.get("release_version"), + ) + genomes = list(result) + + if not genomes: + raise GenomeNotFoundError(by_keyword) + + # Check if the assembly and dataset fields are requested in the query + fields_to_check = ["assembly", "dataset"] + is_assembly_present, is_dataset_present = utils.check_requested_fields( + info, fields_to_check + ) + + combined_results = [] + for genome in genomes: + set_db_conn_for_uuid(info, genome.genome_uuid) + connection_db = get_db_conn(info) + # logging.debug("Collections in the database:", connection_db.list_collection_names()) + assembly_collection = connection_db["assembly"] + # logging.debug("assembly_collection.name:", assembly_collection.name) + + assembly_data = ( + fetch_assembly_data(assembly_collection, genome.assembly.name) + if is_assembly_present + else None + ) + dataset_data = ( + fetch_dataset_data(grpc_model, genome.genome_uuid) + if is_dataset_present + else None + ) + + combined_results.append( + create_genome_response(genome, dataset_data, assembly_data) + ) + + return combined_results return [] @@ -787,20 +790,36 @@ def resolve_genome(_, info: GraphQLResolveInfo, by_genome_uuid: Dict[str, str]) if not genome.genome_uuid: raise GenomeNotFoundError(by_genome_uuid) - # fetch dataset info - dataset_data = fetch_dataset_data(grpc_model, genome.genome_uuid) - genomes = create_genome_response( - genome=genome, assembly_data=None, dataset_data=dataset_data + # Check if the dataset fields is requested in the query + fields_to_check = ["dataset"] + is_dataset_present = utils.check_requested_fields(info, fields_to_check) + + dataset_data = ( + fetch_dataset_data(grpc_model, genome.genome_uuid) + if is_dataset_present + else None ) + + genomes = create_genome_response(genome=genome, dataset_data=dataset_data) return genomes def create_genome_response( genome: Genome, - assembly_data: Optional[Dict[str, Any]] = None, dataset_data: Optional[List] = None, + assembly_data: Optional[Dict[str, Any]] = None, ) -> Dict: + """ + Create a response dictionary for a genome with optional assembly and dataset data. + + Args: + genome (Genome): The genome object containing genome-related information. + assembly_data (Optional[Dict[str, Any]]): Optional dictionary containing assembly data. + dataset_data (Optional[List]): Optional list of dataset objects containing dataset information. + Returns: + Dict: A dictionary containing the genome response data. + """ datasets_response = [] if dataset_data: for dataset in dataset_data: @@ -836,6 +855,20 @@ def create_genome_response( def fetch_assembly_data(assembly_collection: Collection, assembly_id: str) -> Mapping: + """ + Fetch assembly data from a collection using the assembly ID. + + Args: + assembly_collection (Collection): The collection to search for the assembly data. + assembly_id (str): The ID of the assembly to fetch. + + Returns: + Mapping: The assembly data if found. + + Raises: + CollectionNotFoundError: If there is an issue accessing the collection. + AssemblyNotFoundError: If the assembly with the given ID is not found. + """ query = {"assembly_id": assembly_id} try: assembly = assembly_collection.find_one(query) @@ -851,6 +884,16 @@ def fetch_assembly_data(assembly_collection: Collection, assembly_id: str) -> Ma def fetch_dataset_data(grpc_model: GRPC_MODEL, genome_uuid: str) -> List: + """ + Fetch dataset data using a gRPC model based on the genome UUID. + + Args: + grpc_model (GRPC_MODEL): The gRPC model to use for fetching the dataset data. + genome_uuid (str): The UUID of the genome for which to fetch dataset data. + + Returns: + List: A list of datasets associated with the given genome UUID. + """ result = grpc_model.get_datasets_list_by_uuid(genome_uuid) datasets = list(result.datasets) return datasets