Skip to content

Commit

Permalink
refactor: Renamed value matcher classes to use suffix 'ValueMatcher'
Browse files Browse the repository at this point in the history
This makes value matchers more consistent with other operations
such as schema matching (SchemaMatcher's) and top-k column
matching (TopkColumnMatcher).
  • Loading branch information
aecio committed Jul 17, 2024
1 parent 52fd5fd commit 1c3d9f7
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 35 deletions.
40 changes: 19 additions & 21 deletions bdikit/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,13 @@
)
from bdikit.mapping_algorithms.value_mapping.algorithms import (
ValueMatch,
BaseAlgorithm,
TFIDFAlgorithm,
LLMAlgorithm,
EditAlgorithm,
EmbeddingAlgorithm,
AutoFuzzyJoinAlgorithm,
FastTextAlgorithm,
BaseValueMatcher,
TFIDFValueMatcher,
GPTValueMatcher,
EditDistanceValueMatcher,
EmbeddingValueMatcher,
AutoFuzzyJoinValueMatcher,
FastTextValueMatcher,
)
from bdikit.mapping_algorithms.value_mapping.value_mappers import (
ValueMapper,
Expand Down Expand Up @@ -170,23 +170,21 @@ def top_matches(
return pd.concat(dfs, ignore_index=True)


class ValueMatchingMethod(Enum):
TFIDF = ("tfidf", TFIDFAlgorithm)
EDIT = ("edit_distance", EditAlgorithm)
EMBEDDINGS = ("embedding", EmbeddingAlgorithm)
AUTOFJ = ("auto_fuzzy_join", AutoFuzzyJoinAlgorithm)
FASTTEXT = ("fasttext", FastTextAlgorithm)
GPT = ("gpt", LLMAlgorithm)
class ValueMatchers(Enum):
TFIDF = ("tfidf", TFIDFValueMatcher)
EDIT = ("edit_distance", EditDistanceValueMatcher)
EMBEDDINGS = ("embedding", EmbeddingValueMatcher)
AUTOFJ = ("auto_fuzzy_join", AutoFuzzyJoinValueMatcher)
FASTTEXT = ("fasttext", FastTextValueMatcher)
GPT = ("gpt", GPTValueMatcher)

def __init__(self, method_name: str, method_class: Type[BaseAlgorithm]):
def __init__(self, method_name: str, method_class: Type[BaseValueMatcher]):
self.method_name = method_name
self.method_class = method_class

@staticmethod
def get_instance(method_name: str) -> BaseAlgorithm:
methods = {
method.method_name: method.method_class for method in ValueMatchingMethod
}
def get_instance(method_name: str) -> BaseValueMatcher:
methods = {method.method_name: method.method_class for method in ValueMatchers}
try:
return methods[method_name]()
except KeyError:
Expand Down Expand Up @@ -326,7 +324,7 @@ def match_values(
"The target must be a DataFrame or a standard vocabulary name."
)

value_matcher = ValueMatchingMethod.get_instance(method)
value_matcher = ValueMatchers.get_instance(method)
matches = _match_values(source, target_domain, column_mapping_dict, value_matcher)
return matches

Expand All @@ -335,7 +333,7 @@ def _match_values(
dataset: pd.DataFrame,
target_domain: Dict[str, Optional[List[str]]],
column_mapping: Dict[str, str],
value_matcher: BaseAlgorithm,
value_matcher: BaseValueMatcher,
) -> List[ValueMatchingResult]:

mapping_results: List[ValueMatchingResult] = []
Expand Down
16 changes: 8 additions & 8 deletions bdikit/mapping_algorithms/value_mapping/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class ValueMatch(NamedTuple):
similarity: float


class BaseAlgorithm:
class BaseValueMatcher:
"""
Base class for value matching algorithms, i.e., algorithms that match
values from a source (current) domain to values from a target domain.
Expand All @@ -33,7 +33,7 @@ def match(
raise NotImplementedError("Subclasses must implement this method")


class PolyFuzzAlgorithm(BaseAlgorithm):
class PolyFuzzValueMatcher(BaseValueMatcher):
"""
Base class for value matching algorithms based on the PolyFuzz library.
"""
Expand Down Expand Up @@ -63,7 +63,7 @@ def match(
return matches


class TFIDFAlgorithm(PolyFuzzAlgorithm):
class TFIDFValueMatcher(PolyFuzzValueMatcher):
"""
Value matching algorithm based on the TF-IDF similarity between values.
"""
Expand All @@ -72,7 +72,7 @@ def __init__(self):
super().__init__(PolyFuzz(method=TFIDF(n_gram_range=(1, 3), min_similarity=0)))


class EditAlgorithm(PolyFuzzAlgorithm):
class EditDistanceValueMatcher(PolyFuzzValueMatcher):
"""
Value matching algorithm based on the edit distance between values.
"""
Expand All @@ -89,7 +89,7 @@ def __init__(self, scorer: Callable[[str, str], float] = fuzz.ratio):
)


class EmbeddingAlgorithm(PolyFuzzAlgorithm):
class EmbeddingValueMatcher(PolyFuzzValueMatcher):
"""
Value matching algorithm based on the cosine similarity of value embeddings.
"""
Expand All @@ -100,7 +100,7 @@ def __init__(self, model_path: str = "bert-base-multilingual-cased"):
super().__init__(PolyFuzz(method))


class FastTextAlgorithm(PolyFuzzAlgorithm):
class FastTextValueMatcher(PolyFuzzValueMatcher):
"""
Value matching algorithm based on the cosine similarity of FastText embeddings.
"""
Expand All @@ -111,7 +111,7 @@ def __init__(self, model_name: str = "en-crawl"):
super().__init__(PolyFuzz(method))


class LLMAlgorithm(BaseAlgorithm):
class GPTValueMatcher(BaseValueMatcher):
def __init__(self):
self.client = OpenAI()

Expand Down Expand Up @@ -158,7 +158,7 @@ def match(
return matches


class AutoFuzzyJoinAlgorithm(BaseAlgorithm):
class AutoFuzzyJoinValueMatcher(BaseValueMatcher):

def __init__(self):
pass
Expand Down
2 changes: 1 addition & 1 deletion docs/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@ API

.. automodule:: bdikit.api
:members:
:exclude-members: SchemaMatchers, ValueMatchingMethod, ValueMatchingResult
:exclude-members: SchemaMatchers, ValueMatchers, ValueMatchingResult
10 changes: 5 additions & 5 deletions tests/test_value_matching_algorithms.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
import unittest
import pandas as pd
from bdikit.mapping_algorithms.value_mapping.algorithms import (
TFIDFAlgorithm,
EditAlgorithm,
TFIDFValueMatcher,
EditDistanceValueMatcher,
)


class ValueMatchingAlgorithmsTest(unittest.TestCase):
class ValueMatchingTest(unittest.TestCase):

def test_tfidf_value_matching(self):
# given
current_values = ["Red Apple", "Banana", "Oorange", "Strawberry"]
target_values = ["apple", "banana", "orange", "kiwi"]

tfidf_matcher = TFIDFAlgorithm()
tfidf_matcher = TFIDFValueMatcher()

# when
matches = tfidf_matcher.match(current_values, target_values)
Expand All @@ -35,7 +35,7 @@ def test_edit_distance_value_matching(self):
current_values = ["Red Apple", "Banana", "Oorange", "Strawberry"]
target_values = ["apple", "bananana", "orange", "kiwi"]

edit_distance_matcher = EditAlgorithm()
edit_distance_matcher = EditDistanceValueMatcher()

# when
matches = edit_distance_matcher.match(
Expand Down

0 comments on commit 1c3d9f7

Please sign in to comment.