Skip to content

Commit

Permalink
fix: better guard config access
Browse files Browse the repository at this point in the history
+ fix search get parameters conversion
+ more integration tests
  • Loading branch information
alexgarel committed Oct 15, 2024
1 parent 475b87c commit 9828aa7
Show file tree
Hide file tree
Showing 8 changed files with 286 additions and 83 deletions.
80 changes: 53 additions & 27 deletions app/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,14 +237,20 @@ class QuerySearchParameters(BaseModel):
CommonParametersQuery.index_id,
] = None

@field_validator("langs")
@field_validator("langs", mode="before")
@classmethod
def parse_langs_str(cls, langs: str | list[str]) -> list[str]:
"""
Parse for get params 'langs'
"""
if isinstance(langs, str):
langs = langs.split(",")
value_str = _prepare_str_list(langs)
if value_str:
langs = value_str.split(",")
else:
# we already know because of code logic that langs is the right type
# but we need to cast for mypy type checking
langs = cast(list[str], langs)

return langs

@model_validator(mode="after")
Expand All @@ -256,8 +262,7 @@ def validate_index_id(self):
because we want to be able to substitute the default None value,
by the default index
"""
config.check_config_is_defined()
global_config = cast(config.Config, config.CONFIG)
global_config = config.get_config()
check_index_id_is_defined(self.index_id, global_config)
self.index_id, _ = global_config.get_index_config(self.index_id)
return self
Expand All @@ -284,7 +289,7 @@ def check_max_results(self):
@cached_property
def index_config(self):
"""Get the index config once and for all"""
global_config = cast(config.Config, config.CONFIG)
global_config = config.get_config()
_, index_config = global_config.get_index_config(self.index_id)
return index_config

Expand Down Expand Up @@ -396,7 +401,6 @@ def check_charts_are_valid(self):
"""Check that the graph names are valid."""
if self.charts is None:
return self

errors = check_all_values_are_fields_agg(
self.index_id,
[
Expand Down Expand Up @@ -435,6 +439,14 @@ def sign_sort_by(self) -> Tuple[str_utils.BoolOperator, str | None]:
)


def _prepare_str_list(item: Any) -> str | None:
if isinstance(item, str):
return item
elif isinstance(item, list) and all(isinstance(x, str) for x in item):
return ",".join(item)
return None


class SearchParameters(
QuerySearchParameters, ResultSearchParameters, AggregateSearchParameters
):
Expand All @@ -454,30 +466,52 @@ class SearchParameters(
),
] = None

@field_validator("debug_info", mode="before")
@classmethod
def debug_info_list_from_str(
cls, debug_info: str | list[str] | list[DebugInfo] | None
) -> list[DebugInfo] | None:
"""We can pass a comma separated list of DebugInfo values as a string"""
# as we are a before validator, we get a list
str_infos = _prepare_str_list(debug_info)
if str_infos:
values = [getattr(DebugInfo, part, None) for part in str_infos.split(",")]
debug_info = [v for v in values if v is not None]
if debug_info is not None:
# we already know because of code logic that debug_info is the right type
# but we need to cast for mypy type checking
debug_info = cast(list[DebugInfo], debug_info)
return debug_info


class GetSearchParameters(SearchParameters):
"""GET parameters for search"""

@field_validator("charts")
@field_validator("charts", mode="before")
@classmethod
def parse_charts_str(
cls, charts: str | list[ChartType] | None
cls, charts: str | list[str] | list[ChartType] | None
) -> list[ChartType] | None:
"""
Parse for get params are 'field' or 'xfield:yfield'
separated by ',' for Distribution and Scatter charts.
Directly the dictionnaries in POST request
"""
if isinstance(charts, str):
charts_list = charts.split(",")
str_charts = _prepare_str_list(charts)
if str_charts:
charts = []
charts_list = str_charts.split(",")
for c in charts_list:
if ":" in c:
[x, y] = c.split(":")
charts.append(ScatterChartType(x=x, y=y))
else:
charts.append(DistributionChartType(field=c))
if charts is not None:
# we already know because of code logic that charts is the right type
# but we need to cast for mypy type checking
charts = cast(list[ChartType], charts)
return charts

@model_validator(mode="after")
Expand All @@ -487,36 +521,28 @@ def validate_q_or_sort_by(self):
raise ValueError("`sort_by` must be provided when `q` is missing")
return self

@field_validator("debug_info")
@classmethod
def debug_info_list_from_str(
cls, debug_info: str | list[DebugInfo] | None
) -> list[DebugInfo] | None:
"""We can pass a comma separated list of DebugInfo values as a string"""
if isinstance(debug_info, str):
values = [getattr(DebugInfo, part, None) for part in debug_info.split(",")]
debug_info = [v for v in values if v is not None]
return debug_info

fields: Annotated[
list[str] | None,
Query(
description="List of fields to include in the response. All other fields will be ignored."
),
] = None

@field_validator("facets", "fields", mode="before")
@classmethod
def parse_value_str(cls, value: str | list[str] | None) -> list[str] | None:
"""
Parse for get params 'langs'
"""
if isinstance(value, str):
value = value.split(",")
value_str = _prepare_str_list(value)
if value_str:
value = value_str.split(",")
if value is not None:
# we already know because of code logic that value is the right type
# but we need to cast for mypy type checking
value = cast(list[str], value)
return value

field_validator("fields")(parse_value_str)
field_validator("facets")(parse_value_str)

@model_validator(mode="after")
def no_sort_by_scripts_on_get(self):
if self.uses_sort_script:
Expand Down
8 changes: 3 additions & 5 deletions app/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
SearchResponse,
SuccessSearchResponse,
)
from app.config import check_config_is_defined, settings
from app.config import settings
from app.postprocessing import process_taxonomy_completion_response
from app.query import build_completion_query
from app.utils import connection, get_logger, init_sentry
Expand Down Expand Up @@ -81,8 +81,7 @@ def get_document(
index_id: Annotated[str | None, CommonParametersQuery.index_id] = None,
):
"""Fetch a document from Elasticsearch with specific ID."""
check_config_is_defined()
global_config = cast(config.Config, config.CONFIG)
global_config = config.get_config()
check_index_id_is_defined_or_400(index_id, global_config)
index_id, index_config = global_config.get_index_config(index_id)

Expand Down Expand Up @@ -161,8 +160,7 @@ def taxonomy_autocomplete(
index_id: Annotated[str | None, CommonParametersQuery.index_id] = None,
):
"""API endpoint for autocompletion using taxonomies"""
check_config_is_defined()
global_config = cast(config.Config, config.CONFIG)
global_config = config.get_config()
check_index_id_is_defined_or_400(index_id, global_config)
index_id, index_config = global_config.get_index_config(index_id)
taxonomy_names_list = taxonomy_names.split(",")
Expand Down
18 changes: 6 additions & 12 deletions app/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,14 @@
def _get_index_config(
config_path: Optional[Path], index_id: Optional[str]
) -> tuple[str, "app.config.IndexConfig"]:
from typing import cast

from app import config
from app.config import check_config_is_defined, set_global_config
from app.config import set_global_config

if config_path:
set_global_config(config_path)

check_config_is_defined()
global_config = cast(config.Config, config.CONFIG)
global_config = config.get_config()
index_id, index_config = global_config.get_index_config(index_id)
if index_config is None:
raise typer.BadParameter(
Expand Down Expand Up @@ -242,11 +240,9 @@ def cleanup_indexes(
index_configs = [index_config]
else:
_get_index_config(config_path, None) # just to set global config variable
from app.config import CONFIG
from app.config import get_config

if CONFIG is None:
raise ValueError("No configuration found")
index_configs = list(CONFIG.indices.values())
index_configs = list(get_config().indices.values())
start_time = time.perf_counter()
removed = 0
for index_config in index_configs:
Expand All @@ -273,11 +269,10 @@ def run_update_daemon(
It is optional but enables having an always up-to-date index,
for applications where data changes.
"""
from typing import cast

from app import config
from app._import import run_update_daemon
from app.config import check_config_is_defined, set_global_config, settings
from app.config import set_global_config, settings
from app.utils import get_logger, init_sentry

# Create root logger
Expand All @@ -288,8 +283,7 @@ def run_update_daemon(
if config_path:
set_global_config(config_path)

check_config_is_defined()
global_config = cast(config.Config, config.CONFIG)
global_config = config.get_config()
run_update_daemon(global_config)


Expand Down
33 changes: 19 additions & 14 deletions app/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -908,28 +908,33 @@ def from_yaml(cls, path: Path) -> "Config":
return cls(**data)


# CONFIG is a global variable that contains the search-a-licious configuration
# _CONFIG is a global variable that contains the search-a-licious configuration
# used. It is specified by the envvar CONFIG_PATH.
CONFIG: Config | None = None
# use get_config() to access it.
_CONFIG: Config | None = None


def get_config() -> Config:
"""Return the object containing global configuration
It raises if configuration was not yet set
"""
if _CONFIG is None:
raise RuntimeError(
"No configuration is configured, set envvar "
"CONFIG_PATH with the path of the yaml configuration file"
)
return _CONFIG


def set_global_config(config_path: Path):
global CONFIG
CONFIG = Config.from_yaml(config_path)
return CONFIG
global _CONFIG
_CONFIG = Config.from_yaml(config_path)
return _CONFIG


if settings.config_path:
if not settings.config_path.is_file():
raise RuntimeError(f"config file does not exist: {settings.config_path}")

set_global_config(settings.config_path)


def check_config_is_defined():
"""Raise a RuntimeError if the Config path is not set."""
if CONFIG is None:
raise RuntimeError(
"No configuration is configured, set envvar "
"CONFIG_PATH with the path of the yaml configuration file"
)
39 changes: 20 additions & 19 deletions app/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,23 +17,24 @@

logger = logging.getLogger(__name__)

if config.CONFIG is None:
# We want to be able to import api.py (for tests for example) without
# failure, but we add a warning message as it's not expected in a
# production settings
logger.warning("Main configuration is not set, use CONFIG_PATH envvar")
ES_QUERY_BUILDERS = {}
RESULT_PROCESSORS = {}
else:
# we cache query builder and result processor here for faster processing
ES_QUERY_BUILDERS = {
index_id: build_elasticsearch_query_builder(index_config)
for index_id, index_config in config.CONFIG.indices.items()
}
RESULT_PROCESSORS = {
index_id: load_result_processor(index_config)
for index_id, index_config in config.CONFIG.indices.items()
}

# we cache query builder and result processor here for faster processing
_ES_QUERY_BUILDERS = {}
_RESULT_PROCESSORS = {}


def get_es_query_builder(index_id):
if index_id not in _ES_QUERY_BUILDERS:
index_config = config.get_config().indices[index_id]
_ES_QUERY_BUILDERS[index_id] = build_elasticsearch_query_builder(index_config)
return _ES_QUERY_BUILDERS[index_id]


def get_result_processor(index_id):
if index_id not in _RESULT_PROCESSORS:
index_config = config.get_config().indices[index_id]
_RESULT_PROCESSORS[index_id] = load_result_processor(index_config)
return _RESULT_PROCESSORS[index_id]


def add_debug_info(
Expand Down Expand Up @@ -64,7 +65,7 @@ def search(
) -> SearchResponse:
"""Run a search"""
result_processor = cast(
BaseResultProcessor, RESULT_PROCESSORS[params.valid_index_id]
BaseResultProcessor, get_result_processor(params.valid_index_id)
)
logger.debug(
"Received search query: q='%s', langs='%s', page=%d, "
Expand All @@ -82,7 +83,7 @@ def search(
params,
# ES query builder is generated from elasticsearch mapping and
# takes ~40ms to generate, build-it before hand to avoid this delay
es_query_builder=ES_QUERY_BUILDERS[params.valid_index_id],
es_query_builder=get_es_query_builder(params.valid_index_id),
)
(
logger.debug(
Expand Down
8 changes: 3 additions & 5 deletions app/validations.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
from typing import cast

from .config import CONFIG, Config
from .config import Config, get_config


def check_index_id_is_defined(index_id: str | None, config: Config) -> None:
Expand Down Expand Up @@ -31,7 +29,7 @@ def check_all_values_are_fields_agg(
errors: list[str] = []
if values is None:
return errors
global_config = cast(Config, CONFIG)
global_config = get_config()
index_id, index_config = global_config.get_index_config(index_id)
if index_config is None:
raise ValueError(f"Cannot get index config for index_id {index_id}")
Expand All @@ -55,7 +53,7 @@ def check_fields_are_numeric(
if values is None:
return errors

global_config = cast(Config, CONFIG)
global_config = get_config()
index_id, index_config = global_config.get_index_config(index_id)
if index_config is None:
raise ValueError(f"Cannot get index config for index_id {index_id}")
Expand Down
2 changes: 1 addition & 1 deletion tests/int/data_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class Product(factory.DictFactory):
nova_groups = "2"
last_modified_t = 1700525044
created_t = 1537090000
completeness = (0.5874999999999999,)
completeness = 0.5874999999999999
product_name_en = "Granulated sugar"
product_name_fr = "Sucre semoule"
lc = "en"
Expand Down
Loading

0 comments on commit 9828aa7

Please sign in to comment.