From 81f08ba0821c936a41ec2c16115e37fd227deffe Mon Sep 17 00:00:00 2001 From: Daniel Roschka Date: Fri, 14 Jun 2024 13:47:26 +0200 Subject: [PATCH] Change dictionary format to use strings again This changes the format of the dictionary returned by `DictionaryFactory().get_dictionary()` from `Dict[ByteString, ByteString]` to `Mapping[str, str] to accommodate alternative dictionary factory implementations better and to ease the dictionary handling again. This keeps the storage of pickled dictionaries with byte strings though, as they're smaller than when using strings. --- .../dictionaries/dictionary_factory.py | 34 ++++++++++++++----- .../dictionaries/trie_directory_factory.py | 16 ++++----- simplemma/strategies/dictionary_lookup.py | 13 ++----- .../strategies/greedy_dictionary_lookup.py | 4 +-- simplemma/utils.py | 8 ++--- .../test_trie_dictionary_factory.py | 22 ++++++------ tests/test_dictionary_pickler.py | 4 +-- tests/test_lemmatizer.py | 6 ++-- training/dictionary_pickler.py | 10 +++--- 9 files changed, 61 insertions(+), 56 deletions(-) diff --git a/simplemma/strategies/dictionaries/dictionary_factory.py b/simplemma/strategies/dictionaries/dictionary_factory.py index e163b97..50adad0 100644 --- a/simplemma/strategies/dictionaries/dictionary_factory.py +++ b/simplemma/strategies/dictionaries/dictionary_factory.py @@ -14,7 +14,7 @@ from functools import lru_cache from os import listdir, path from pathlib import Path -from typing import ByteString, Dict, Protocol +from typing import ByteString, Dict, Mapping, Protocol DATA_FOLDER = str(Path(__file__).parent / "data") SUPPORTED_LANGUAGES = [ @@ -62,7 +62,7 @@ class DictionaryFactory(Protocol): def get_dictionary( self, lang: str, - ) -> Dict[ByteString, ByteString]: + ) -> Mapping[str, str]: """ Get the dictionary for a specific language. @@ -70,7 +70,7 @@ def get_dictionary( lang (str): The language code. Returns: - Dict[str, str]: The dictionary for the specified language. + Mapping[str, str]: The dictionary for the specified language. Raises: ValueError: If the specified language is not supported. @@ -78,6 +78,25 @@ def get_dictionary( raise NotImplementedError +class MappingStrToByteString(Mapping[str, str]): + """Wrapper around ByString dict to make them behave like str dict.""" + + __slots__ = ["_dict"] + + def __init__(self, dictionary: Dict[bytes, bytes]): + self._dict = dictionary + + def __getitem__(self, item: str): + return self._dict[item.encode()].decode() + + def __iter__(self): + for key in self._dict: + yield key.decode() + + def __len__(self): + return len(self._dict) + + class DefaultDictionaryFactory(DictionaryFactory): """ Default Dictionary Factory. @@ -86,7 +105,7 @@ class DefaultDictionaryFactory(DictionaryFactory): It provides functionality for loading and caching dictionaries from disk that are included in Simplemma. """ - __slots__ = ["_data", "_load_dictionary_from_disk"] + __slots__ = ["_load_dictionary_from_disk"] def __init__(self, cache_max_size: int = 8): """ @@ -96,7 +115,6 @@ def __init__(self, cache_max_size: int = 8): cache_max_size (int): The maximum size of the cache for loaded dictionaries. Defaults to `8`. """ - self._data: Dict[str, Dict[ByteString, ByteString]] = {} self._load_dictionary_from_disk = lru_cache(maxsize=cache_max_size)( _load_dictionary_from_disk ) @@ -104,7 +122,7 @@ def __init__(self, cache_max_size: int = 8): def get_dictionary( self, lang: str, - ) -> Dict[ByteString, ByteString]: + ) -> Mapping[str, str]: """ Get the dictionary for a specific language. @@ -112,11 +130,11 @@ def get_dictionary( lang (str): The language code. Returns: - Dict[str, str]: The dictionary for the specified language. + Mapping[str, str]: The dictionary for the specified language. Raises: ValueError: If the specified language is not supported. """ if lang not in SUPPORTED_LANGUAGES: raise ValueError(f"Unsupported language: {lang}") - return self._load_dictionary_from_disk(lang) + return MappingStrToByteString(self._load_dictionary_from_disk(lang)) diff --git a/simplemma/strategies/dictionaries/trie_directory_factory.py b/simplemma/strategies/dictionaries/trie_directory_factory.py index 03ac4ab..6d2c2ec 100644 --- a/simplemma/strategies/dictionaries/trie_directory_factory.py +++ b/simplemma/strategies/dictionaries/trie_directory_factory.py @@ -2,7 +2,7 @@ from collections.abc import MutableMapping from functools import lru_cache from pathlib import Path -from typing import ByteString, Dict, List, Optional, cast +from typing import List, Mapping, Optional from marisa_trie import BytesTrie, HUGE_CACHE # type: ignore[import-not-found] from platformdirs import user_cache_dir @@ -24,7 +24,7 @@ def __init__(self, trie: BytesTrie): self._trie = trie def __getitem__(self, item): - return self._trie[item.decode()][0] + return self._trie[item][0].decode() def __setitem__(self, key, value): raise NotImplementedError @@ -34,7 +34,7 @@ def __delitem__(self, key): def __iter__(self): for key in self._trie.iterkeys(): - yield key.encode() + yield key def __len__(self): return len(self._trie) @@ -85,8 +85,8 @@ def _create_trie_from_pickled_dict(self, lang: str) -> BytesTrie: unpickled_dict = DefaultDictionaryFactory(cache_max_size=0).get_dictionary(lang) return BytesTrie( zip( - [key.decode() for key in unpickled_dict], # type: ignore[union-attr] - unpickled_dict.values(), + unpickled_dict.keys(), + [value.encode() for value in unpickled_dict.values()], ), cache_size=HUGE_CACHE, ) @@ -102,7 +102,7 @@ def _write_trie_to_disk(self, lang: str, trie: BytesTrie) -> None: trie.save(self._cache_dir / f"{lang}.dic") - def _get_dictionary_uncached(self, lang: str) -> Dict[ByteString, ByteString]: + def _get_dictionary_uncached(self, lang: str) -> Mapping[str, str]: """Get the dictionary for the given language.""" if lang not in SUPPORTED_LANGUAGES: raise ValueError(f"Unsupported language: {lang}") @@ -114,10 +114,10 @@ def _get_dictionary_uncached(self, lang: str) -> Dict[ByteString, ByteString]: if self._use_disk_cache: self._write_trie_to_disk(lang, trie) - return cast(dict, TrieWrapDict(trie)) + return TrieWrapDict(trie) def get_dictionary( self, lang: str, - ) -> Dict[ByteString, ByteString]: + ) -> Mapping[str, str]: return self._get_dictionary(lang) diff --git a/simplemma/strategies/dictionary_lookup.py b/simplemma/strategies/dictionary_lookup.py index a98d365..9262477 100644 --- a/simplemma/strategies/dictionary_lookup.py +++ b/simplemma/strategies/dictionary_lookup.py @@ -3,7 +3,7 @@ It provides lemmatization using dictionary lookup. """ -from typing import ByteString, Dict, Optional +from typing import Optional from .dictionaries.dictionary_factory import DefaultDictionaryFactory, DictionaryFactory from .lemmatization_strategy import LemmatizationStrategy @@ -26,13 +26,6 @@ def __init__( """ self._dictionary_factory = dictionary_factory - def _get( - self, token: str, dictionary: Dict[ByteString, ByteString] - ) -> Optional[str]: - "Convenience function to handle bytestring to string conversion." - result = dictionary.get(token.encode("utf-8")) - return result.decode("utf-8") if result else None # type: ignore[union-attr] - def get_lemma(self, token: str, lang: str) -> Optional[str]: """ Get Lemma using Dictionary Lookup @@ -50,9 +43,9 @@ def get_lemma(self, token: str, lang: str) -> Optional[str]: """ # Search the language data, reverse case to extend coverage. dictionary = self._dictionary_factory.get_dictionary(lang) - result = self._get(token, dictionary) + result = dictionary.get(token) if result: return result # Try upper or lowercase. token = token.lower() if token[0].isupper() else token.capitalize() - return self._get(token, dictionary) + return dictionary.get(token) diff --git a/simplemma/strategies/greedy_dictionary_lookup.py b/simplemma/strategies/greedy_dictionary_lookup.py index ea372de..0915402 100644 --- a/simplemma/strategies/greedy_dictionary_lookup.py +++ b/simplemma/strategies/greedy_dictionary_lookup.py @@ -58,7 +58,7 @@ def get_lemma(self, token: str, lang: str) -> str: return token dictionary = self._dictionary_factory.get_dictionary(lang) - candidate = token.encode("utf-8") + candidate = token for _ in range(self._steps): if candidate not in dictionary: break @@ -73,4 +73,4 @@ def get_lemma(self, token: str, lang: str) -> str: candidate = new_candidate - return candidate.decode("utf-8") + return candidate diff --git a/simplemma/utils.py b/simplemma/utils.py index 57d47cb..1d81fa0 100644 --- a/simplemma/utils.py +++ b/simplemma/utils.py @@ -6,7 +6,7 @@ - [validate_lang_input][simplemma.utils.validate_lang_input]: Validates the language input and ensures it is a valid tuple. """ -from typing import ByteString, Tuple, Union +from typing import Tuple, Union def validate_lang_input(lang: Union[str, Tuple[str, ...]]) -> Tuple[str]: @@ -31,9 +31,7 @@ def validate_lang_input(lang: Union[str, Tuple[str, ...]]) -> Tuple[str]: return lang # type: ignore[return-value] -def levenshtein_dist( - first: Union[ByteString, str], second: Union[ByteString, str] -) -> int: +def levenshtein_dist(str1: str, str2: str) -> int: """ Calculate the Levenshtein distance between two strings. @@ -49,8 +47,6 @@ def levenshtein_dist( int: The Levenshtein distance between the two strings. """ - str1 = first.encode("utf-8") if isinstance(first, str) else first - str2 = second.encode("utf-8") if isinstance(second, str) else second # inspired by this noticeably faster code: # https://gist.github.com/p-hash/9e0f9904ce7947c133308fbe48fe032b if str1 == str2: diff --git a/tests/strategies/dictionaries/test_trie_dictionary_factory.py b/tests/strategies/dictionaries/test_trie_dictionary_factory.py index 8cbd65a..cbbfd86 100644 --- a/tests/strategies/dictionaries/test_trie_dictionary_factory.py +++ b/tests/strategies/dictionaries/test_trie_dictionary_factory.py @@ -188,11 +188,11 @@ def test_dictionary_working_as_a_dict() -> None: assert isinstance(dictionary, TrieWrapDict) - assert (b"balconies" in dictionary) is True - assert (b"balconies123" in dictionary) is False + assert ("balconies" in dictionary) is True + assert ("balconies123" in dictionary) is False with pytest.raises(KeyError): - dictionary[b"balconies123"] - assert dictionary.get(b"balconies") == b"balcony" + dictionary["balconies123"] + assert dictionary.get("balconies") == "balcony" def test_trie_wrap_dict(): @@ -201,21 +201,21 @@ def test_trie_wrap_dict(): ) wrapped_trie = TrieWrapDict(trie) - assert (b"balconies" in wrapped_trie) is True - assert (b"balconies123" in wrapped_trie) is False - assert wrapped_trie[b"balconies"] == b"balcony" + assert ("balconies" in wrapped_trie) is True + assert ("balconies123" in wrapped_trie) is False + assert wrapped_trie["balconies"] == "balcony" with pytest.raises(KeyError): wrapped_trie[b"balconies123"] - assert wrapped_trie.get(b"balconies") == b"balcony" - assert wrapped_trie.get(b"balconies123") is None + assert wrapped_trie.get("balconies") == "balcony" + assert wrapped_trie.get("balconies123") is None assert isinstance(wrapped_trie.keys(), KeysView) assert isinstance(wrapped_trie.items(), ItemsView) assert len(wrapped_trie) == 3 with pytest.raises(NotImplementedError): - wrapped_trie["houses"] = b"teapot" + wrapped_trie["houses"] = "teapot" with pytest.raises(NotImplementedError): del wrapped_trie["balconies"] - assert [key for key in wrapped_trie] == [b"balconies", b"houses", b"ponies"] + assert [key for key in wrapped_trie] == ["balconies", "houses", "ponies"] diff --git a/tests/test_dictionary_pickler.py b/tests/test_dictionary_pickler.py index 2fc806f..37136f2 100644 --- a/tests/test_dictionary_pickler.py +++ b/tests/test_dictionary_pickler.py @@ -26,9 +26,9 @@ def test_logic() -> None: # different order mydict = dictionary_pickler._read_dict(testfile, "es", silent=True) assert len(mydict) == 5 - assert mydict[b"closeones"] == b"closeone" + assert mydict["closeones"] == "closeone" item = sorted(mydict.keys(), reverse=True)[0] - assert item == b"valid-word" + assert item == "valid-word" # file I/O assert dictionary_pickler._determine_path("lists", "de").endswith("de.txt") diff --git a/tests/test_lemmatizer.py b/tests/test_lemmatizer.py index e911cf1..17a8e93 100644 --- a/tests/test_lemmatizer.py +++ b/tests/test_lemmatizer.py @@ -1,6 +1,6 @@ """Tests for `simplemma` package.""" -from typing import ByteString, Dict +from typing import Mapping import pytest @@ -17,8 +17,8 @@ class CustomDictionaryFactory(DictionaryFactory): def get_dictionary( self, lang: str, - ) -> Dict[ByteString, ByteString]: - return {b"testing": b"the test works!!"} + ) -> Mapping[str, str]: + return {"testing": "the test works!!"} assert ( Lemmatizer( diff --git a/training/dictionary_pickler.py b/training/dictionary_pickler.py index 15345d1..69f4692 100644 --- a/training/dictionary_pickler.py +++ b/training/dictionary_pickler.py @@ -10,7 +10,7 @@ import re from operator import itemgetter from pathlib import Path -from typing import ByteString, Dict, List, Optional +from typing import Dict, List, Optional import simplemma from simplemma.strategies.defaultrules import DEFAULT_RULES @@ -49,9 +49,7 @@ def _determine_path(listpath: str, langcode: str) -> str: return str(Path(__file__).parent / filename) -def _read_dict( - filepath: str, langcode: str, silent: bool -) -> Dict[ByteString, ByteString]: +def _read_dict(filepath: str, langcode: str, silent: bool) -> Dict[str, str]: mydict: Dict[str, str] = {} myadditions: List[str] = [] i: int = 0 @@ -122,12 +120,12 @@ def _read_dict( mydict[word] = word LOGGER.debug("%s %s", langcode, i) # sort and convert to bytestrings - return {k.encode("utf-8"): v.encode("utf-8") for k, v in sorted(mydict.items())} + return dict(sorted(mydict.items())) def _load_dict( langcode: str, listpath: str = "lists", silent: bool = True -) -> Dict[ByteString, ByteString]: +) -> Dict[str, str]: filepath = _determine_path(listpath, langcode) return _read_dict(filepath, langcode, silent)