Skip to content

Commit

Permalink
chore: change config schema
Browse files Browse the repository at this point in the history
  • Loading branch information
raphael0202 committed Sep 26, 2023
1 parent 7cb6274 commit 854646d
Show file tree
Hide file tree
Showing 6 changed files with 384 additions and 147 deletions.
50 changes: 24 additions & 26 deletions app/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from pathlib import Path

import yaml
from pydantic import BaseModel, Field, HttpUrl, model_validator
from pydantic import BaseModel, Field, HttpUrl, field_validator, model_validator
from pydantic_settings import BaseSettings


Expand Down Expand Up @@ -76,8 +76,8 @@ def is_numeric(self):


class FieldConfig(BaseModel):
# name of the field, must be unique across the config
name: str
# name of the field (internal field), it's added here for convenience
_name: str = ""
# type of the field, see `FieldType` for possible values
type: FieldType
# if required=True, the field is required in the input data
Expand All @@ -93,6 +93,10 @@ class FieldConfig(BaseModel):
# can the keyword field contain multiple value (keyword type only)
multi: bool = False

@property
def name(self):
return self._name

@model_validator(mode="after")
def multi_should_be_used_for_selected_type_only(self):
"""Validator that checks that `multi` flag is only True for fields
Expand Down Expand Up @@ -138,8 +142,9 @@ class IndexConfig(BaseModel):
class Config(BaseModel):
# configuration of the index
index: IndexConfig
# configuration of all fields in the index
fields: list[FieldConfig]
# configuration of all fields in the index, keys are field names and values
# contain the field configuration
fields: dict[str, FieldConfig]
split_separator: str = ","
# for `text_lang` FieldType, the separator between the name of the field
# and the language code, ex: product_name_it if lang_separator="_"
Expand All @@ -165,7 +170,7 @@ def taxonomy_name_should_be_defined(self):
"""Validator that checks that for if `taxonomy_type` is defined for a
field, it refers to a taxonomy defined in `taxonomy.sources`."""
defined_taxonomies = [source.name for source in self.taxonomy.sources]
for field in self.fields:
for field in self.fields.values():
if (
field.taxonomy_name is not None
and field.taxonomy_name not in defined_taxonomies
Expand All @@ -175,36 +180,21 @@ def taxonomy_name_should_be_defined(self):
)
return self

@model_validator(mode="after")
def field_name_should_be_unique(self):
"""Validator that checks that all fields have unique names."""
seen: set[str] = set()
for field in self.fields:
if field.name in seen:
raise ValueError(
f"each field name should be unique, duplicate found: '{field.name}'"
)
seen.add(field.name)
return self

@model_validator(mode="after")
def field_references_must_exist_and_be_valid(self):
"""Validator that checks that every field reference in IndexConfig
refers to an existing field and is valid."""

fields_by_name = {f.name: f for f in self.fields}

if self.index.id_field_name not in fields_by_name:
if self.index.id_field_name not in self.fields:
raise ValueError(
f"id_field_name={self.index.id_field_name} but field was not declared"
)

if self.index.last_modified_field_name not in fields_by_name:
if self.index.last_modified_field_name not in self.fields:
raise ValueError(
f"last_modified_field_name={self.index.last_modified_field_name} but field was not declared"
)

last_modified_field = fields_by_name[self.index.last_modified_field_name]
last_modified_field = self.fields[self.index.last_modified_field_name]

if last_modified_field.type != FieldType.date:
raise ValueError(
Expand All @@ -216,13 +206,20 @@ def field_references_must_exist_and_be_valid(self):
@model_validator(mode="after")
def if_split_should_be_multi(self):
"""Validator that checks that multi=True if split=True.."""
for field in self.fields:
for field in self.fields.values():
if field.split and not field.multi:
raise ValueError("multi should be True if split=True")
return self

@field_validator("fields")
@classmethod
def add_field_name_to_each_field(cls, fields):
for field_name, field_item in fields.items():
field_item._name = field_name
return fields

def get_input_fields(self) -> set[str]:
return {field.name for field in self.fields} | {
return set(self.fields) | {
field.input_field for field in self.fields if field.input_field is not None
}

Expand All @@ -231,6 +228,7 @@ def get_supported_langs(self) -> set[str]:

@classmethod
def from_yaml(cls, path: Path) -> "Config":
"""Create a Config from a yaml configuration file."""
with path.open("r") as f:
data = yaml.safe_load(f)
return cls(**data)
Expand Down
4 changes: 2 additions & 2 deletions app/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ def from_dict(self, d: JSONType) -> JSONType | None:
if d is None:
return None

for field in self.config.fields:
for field in self.config.fields.values():
input_field = field.get_input_field()

if field.type == FieldType.text_lang:
Expand Down Expand Up @@ -234,7 +234,7 @@ def from_dict(self, d: JSONType) -> JSONType | None:
def generate_mapping_object(config: Config) -> Mapping:
mapping = Mapping()
supported_langs = config.get_supported_langs()
for field in config.fields:
for field in config.fields.values():
mapping.field(
field.name, generate_dsl_field(field, supported_langs=supported_langs)
)
Expand Down
2 changes: 1 addition & 1 deletion app/postprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def process(self, response: Response, projection: set[str] | None) -> JSONType:
result = hit.to_dict()
result["_score"] = hit.meta.score

for field in self.config.fields:
for field in self.config.fields.values():
if field.name not in result:
continue

Expand Down
4 changes: 2 additions & 2 deletions app/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def build_query_clause(query: str, langs: set[str], config: Config) -> Query:
supported_langs = config.get_supported_langs()
match_phrase_boost_queries = []

for field in config.fields:
for field in config.fields.values():
# We don't include all fields in the multi-match clause, only a subset
# of them
if field.include_multi_match:
Expand Down Expand Up @@ -139,7 +139,7 @@ def parse_sort_by_parameter(sort_by: str | None, config: Config) -> str | None:
if negative_operator := sort_by.startswith("-"):
sort_by = sort_by[1:]

for field in config.fields:
for field in config.fields.values():
if field.name == sort_by:
if field.type is FieldType.text_lang:
# use 'main' language subfield for sorting
Expand Down
Loading

0 comments on commit 854646d

Please sign in to comment.