diff --git a/simplemma/strategies/dictionaries/dictionary_factory.py b/simplemma/strategies/dictionaries/dictionary_factory.py index e163b97..50adad0 100644 --- a/simplemma/strategies/dictionaries/dictionary_factory.py +++ b/simplemma/strategies/dictionaries/dictionary_factory.py @@ -14,7 +14,7 @@ from functools import lru_cache from os import listdir, path from pathlib import Path -from typing import ByteString, Dict, Protocol +from typing import ByteString, Dict, Mapping, Protocol DATA_FOLDER = str(Path(__file__).parent / "data") SUPPORTED_LANGUAGES = [ @@ -62,7 +62,7 @@ class DictionaryFactory(Protocol): def get_dictionary( self, lang: str, - ) -> Dict[ByteString, ByteString]: + ) -> Mapping[str, str]: """ Get the dictionary for a specific language. @@ -70,7 +70,7 @@ def get_dictionary( lang (str): The language code. Returns: - Dict[str, str]: The dictionary for the specified language. + Mapping[str, str]: The dictionary for the specified language. Raises: ValueError: If the specified language is not supported. @@ -78,6 +78,25 @@ def get_dictionary( raise NotImplementedError +class MappingStrToByteString(Mapping[str, str]): + """Wrapper around ByString dict to make them behave like str dict.""" + + __slots__ = ["_dict"] + + def __init__(self, dictionary: Dict[bytes, bytes]): + self._dict = dictionary + + def __getitem__(self, item: str): + return self._dict[item.encode()].decode() + + def __iter__(self): + for key in self._dict: + yield key.decode() + + def __len__(self): + return len(self._dict) + + class DefaultDictionaryFactory(DictionaryFactory): """ Default Dictionary Factory. @@ -86,7 +105,7 @@ class DefaultDictionaryFactory(DictionaryFactory): It provides functionality for loading and caching dictionaries from disk that are included in Simplemma. """ - __slots__ = ["_data", "_load_dictionary_from_disk"] + __slots__ = ["_load_dictionary_from_disk"] def __init__(self, cache_max_size: int = 8): """ @@ -96,7 +115,6 @@ def __init__(self, cache_max_size: int = 8): cache_max_size (int): The maximum size of the cache for loaded dictionaries. Defaults to `8`. """ - self._data: Dict[str, Dict[ByteString, ByteString]] = {} self._load_dictionary_from_disk = lru_cache(maxsize=cache_max_size)( _load_dictionary_from_disk ) @@ -104,7 +122,7 @@ def __init__(self, cache_max_size: int = 8): def get_dictionary( self, lang: str, - ) -> Dict[ByteString, ByteString]: + ) -> Mapping[str, str]: """ Get the dictionary for a specific language. @@ -112,11 +130,11 @@ def get_dictionary( lang (str): The language code. Returns: - Dict[str, str]: The dictionary for the specified language. + Mapping[str, str]: The dictionary for the specified language. Raises: ValueError: If the specified language is not supported. """ if lang not in SUPPORTED_LANGUAGES: raise ValueError(f"Unsupported language: {lang}") - return self._load_dictionary_from_disk(lang) + return MappingStrToByteString(self._load_dictionary_from_disk(lang)) diff --git a/simplemma/strategies/dictionaries/trie_directory_factory.py b/simplemma/strategies/dictionaries/trie_directory_factory.py index 03ac4ab..6d2c2ec 100644 --- a/simplemma/strategies/dictionaries/trie_directory_factory.py +++ b/simplemma/strategies/dictionaries/trie_directory_factory.py @@ -2,7 +2,7 @@ from collections.abc import MutableMapping from functools import lru_cache from pathlib import Path -from typing import ByteString, Dict, List, Optional, cast +from typing import List, Mapping, Optional from marisa_trie import BytesTrie, HUGE_CACHE # type: ignore[import-not-found] from platformdirs import user_cache_dir @@ -24,7 +24,7 @@ def __init__(self, trie: BytesTrie): self._trie = trie def __getitem__(self, item): - return self._trie[item.decode()][0] + return self._trie[item][0].decode() def __setitem__(self, key, value): raise NotImplementedError @@ -34,7 +34,7 @@ def __delitem__(self, key): def __iter__(self): for key in self._trie.iterkeys(): - yield key.encode() + yield key def __len__(self): return len(self._trie) @@ -85,8 +85,8 @@ def _create_trie_from_pickled_dict(self, lang: str) -> BytesTrie: unpickled_dict = DefaultDictionaryFactory(cache_max_size=0).get_dictionary(lang) return BytesTrie( zip( - [key.decode() for key in unpickled_dict], # type: ignore[union-attr] - unpickled_dict.values(), + unpickled_dict.keys(), + [value.encode() for value in unpickled_dict.values()], ), cache_size=HUGE_CACHE, ) @@ -102,7 +102,7 @@ def _write_trie_to_disk(self, lang: str, trie: BytesTrie) -> None: trie.save(self._cache_dir / f"{lang}.dic") - def _get_dictionary_uncached(self, lang: str) -> Dict[ByteString, ByteString]: + def _get_dictionary_uncached(self, lang: str) -> Mapping[str, str]: """Get the dictionary for the given language.""" if lang not in SUPPORTED_LANGUAGES: raise ValueError(f"Unsupported language: {lang}") @@ -114,10 +114,10 @@ def _get_dictionary_uncached(self, lang: str) -> Dict[ByteString, ByteString]: if self._use_disk_cache: self._write_trie_to_disk(lang, trie) - return cast(dict, TrieWrapDict(trie)) + return TrieWrapDict(trie) def get_dictionary( self, lang: str, - ) -> Dict[ByteString, ByteString]: + ) -> Mapping[str, str]: return self._get_dictionary(lang) diff --git a/simplemma/strategies/dictionary_lookup.py b/simplemma/strategies/dictionary_lookup.py index a98d365..9262477 100644 --- a/simplemma/strategies/dictionary_lookup.py +++ b/simplemma/strategies/dictionary_lookup.py @@ -3,7 +3,7 @@ It provides lemmatization using dictionary lookup. """ -from typing import ByteString, Dict, Optional +from typing import Optional from .dictionaries.dictionary_factory import DefaultDictionaryFactory, DictionaryFactory from .lemmatization_strategy import LemmatizationStrategy @@ -26,13 +26,6 @@ def __init__( """ self._dictionary_factory = dictionary_factory - def _get( - self, token: str, dictionary: Dict[ByteString, ByteString] - ) -> Optional[str]: - "Convenience function to handle bytestring to string conversion." - result = dictionary.get(token.encode("utf-8")) - return result.decode("utf-8") if result else None # type: ignore[union-attr] - def get_lemma(self, token: str, lang: str) -> Optional[str]: """ Get Lemma using Dictionary Lookup @@ -50,9 +43,9 @@ def get_lemma(self, token: str, lang: str) -> Optional[str]: """ # Search the language data, reverse case to extend coverage. dictionary = self._dictionary_factory.get_dictionary(lang) - result = self._get(token, dictionary) + result = dictionary.get(token) if result: return result # Try upper or lowercase. token = token.lower() if token[0].isupper() else token.capitalize() - return self._get(token, dictionary) + return dictionary.get(token) diff --git a/simplemma/strategies/greedy_dictionary_lookup.py b/simplemma/strategies/greedy_dictionary_lookup.py index ea372de..0915402 100644 --- a/simplemma/strategies/greedy_dictionary_lookup.py +++ b/simplemma/strategies/greedy_dictionary_lookup.py @@ -58,7 +58,7 @@ def get_lemma(self, token: str, lang: str) -> str: return token dictionary = self._dictionary_factory.get_dictionary(lang) - candidate = token.encode("utf-8") + candidate = token for _ in range(self._steps): if candidate not in dictionary: break @@ -73,4 +73,4 @@ def get_lemma(self, token: str, lang: str) -> str: candidate = new_candidate - return candidate.decode("utf-8") + return candidate diff --git a/simplemma/utils.py b/simplemma/utils.py index 57d47cb..1d81fa0 100644 --- a/simplemma/utils.py +++ b/simplemma/utils.py @@ -6,7 +6,7 @@ - [validate_lang_input][simplemma.utils.validate_lang_input]: Validates the language input and ensures it is a valid tuple. """ -from typing import ByteString, Tuple, Union +from typing import Tuple, Union def validate_lang_input(lang: Union[str, Tuple[str, ...]]) -> Tuple[str]: @@ -31,9 +31,7 @@ def validate_lang_input(lang: Union[str, Tuple[str, ...]]) -> Tuple[str]: return lang # type: ignore[return-value] -def levenshtein_dist( - first: Union[ByteString, str], second: Union[ByteString, str] -) -> int: +def levenshtein_dist(str1: str, str2: str) -> int: """ Calculate the Levenshtein distance between two strings. @@ -49,8 +47,6 @@ def levenshtein_dist( int: The Levenshtein distance between the two strings. """ - str1 = first.encode("utf-8") if isinstance(first, str) else first - str2 = second.encode("utf-8") if isinstance(second, str) else second # inspired by this noticeably faster code: # https://gist.github.com/p-hash/9e0f9904ce7947c133308fbe48fe032b if str1 == str2: diff --git a/tests/strategies/dictionaries/test_trie_dictionary_factory.py b/tests/strategies/dictionaries/test_trie_dictionary_factory.py index 8cbd65a..cbbfd86 100644 --- a/tests/strategies/dictionaries/test_trie_dictionary_factory.py +++ b/tests/strategies/dictionaries/test_trie_dictionary_factory.py @@ -188,11 +188,11 @@ def test_dictionary_working_as_a_dict() -> None: assert isinstance(dictionary, TrieWrapDict) - assert (b"balconies" in dictionary) is True - assert (b"balconies123" in dictionary) is False + assert ("balconies" in dictionary) is True + assert ("balconies123" in dictionary) is False with pytest.raises(KeyError): - dictionary[b"balconies123"] - assert dictionary.get(b"balconies") == b"balcony" + dictionary["balconies123"] + assert dictionary.get("balconies") == "balcony" def test_trie_wrap_dict(): @@ -201,21 +201,21 @@ def test_trie_wrap_dict(): ) wrapped_trie = TrieWrapDict(trie) - assert (b"balconies" in wrapped_trie) is True - assert (b"balconies123" in wrapped_trie) is False - assert wrapped_trie[b"balconies"] == b"balcony" + assert ("balconies" in wrapped_trie) is True + assert ("balconies123" in wrapped_trie) is False + assert wrapped_trie["balconies"] == "balcony" with pytest.raises(KeyError): wrapped_trie[b"balconies123"] - assert wrapped_trie.get(b"balconies") == b"balcony" - assert wrapped_trie.get(b"balconies123") is None + assert wrapped_trie.get("balconies") == "balcony" + assert wrapped_trie.get("balconies123") is None assert isinstance(wrapped_trie.keys(), KeysView) assert isinstance(wrapped_trie.items(), ItemsView) assert len(wrapped_trie) == 3 with pytest.raises(NotImplementedError): - wrapped_trie["houses"] = b"teapot" + wrapped_trie["houses"] = "teapot" with pytest.raises(NotImplementedError): del wrapped_trie["balconies"] - assert [key for key in wrapped_trie] == [b"balconies", b"houses", b"ponies"] + assert [key for key in wrapped_trie] == ["balconies", "houses", "ponies"] diff --git a/tests/test_dictionary_pickler.py b/tests/test_dictionary_pickler.py index 2fc806f..37136f2 100644 --- a/tests/test_dictionary_pickler.py +++ b/tests/test_dictionary_pickler.py @@ -26,9 +26,9 @@ def test_logic() -> None: # different order mydict = dictionary_pickler._read_dict(testfile, "es", silent=True) assert len(mydict) == 5 - assert mydict[b"closeones"] == b"closeone" + assert mydict["closeones"] == "closeone" item = sorted(mydict.keys(), reverse=True)[0] - assert item == b"valid-word" + assert item == "valid-word" # file I/O assert dictionary_pickler._determine_path("lists", "de").endswith("de.txt") diff --git a/tests/test_lemmatizer.py b/tests/test_lemmatizer.py index e911cf1..17a8e93 100644 --- a/tests/test_lemmatizer.py +++ b/tests/test_lemmatizer.py @@ -1,6 +1,6 @@ """Tests for `simplemma` package.""" -from typing import ByteString, Dict +from typing import Mapping import pytest @@ -17,8 +17,8 @@ class CustomDictionaryFactory(DictionaryFactory): def get_dictionary( self, lang: str, - ) -> Dict[ByteString, ByteString]: - return {b"testing": b"the test works!!"} + ) -> Mapping[str, str]: + return {"testing": "the test works!!"} assert ( Lemmatizer( diff --git a/training/dictionary_pickler.py b/training/dictionary_pickler.py index 15345d1..69f4692 100644 --- a/training/dictionary_pickler.py +++ b/training/dictionary_pickler.py @@ -10,7 +10,7 @@ import re from operator import itemgetter from pathlib import Path -from typing import ByteString, Dict, List, Optional +from typing import Dict, List, Optional import simplemma from simplemma.strategies.defaultrules import DEFAULT_RULES @@ -49,9 +49,7 @@ def _determine_path(listpath: str, langcode: str) -> str: return str(Path(__file__).parent / filename) -def _read_dict( - filepath: str, langcode: str, silent: bool -) -> Dict[ByteString, ByteString]: +def _read_dict(filepath: str, langcode: str, silent: bool) -> Dict[str, str]: mydict: Dict[str, str] = {} myadditions: List[str] = [] i: int = 0 @@ -122,12 +120,12 @@ def _read_dict( mydict[word] = word LOGGER.debug("%s %s", langcode, i) # sort and convert to bytestrings - return {k.encode("utf-8"): v.encode("utf-8") for k, v in sorted(mydict.items())} + return dict(sorted(mydict.items())) def _load_dict( langcode: str, listpath: str = "lists", silent: bool = True -) -> Dict[ByteString, ByteString]: +) -> Dict[str, str]: filepath = _determine_path(listpath, langcode) return _read_dict(filepath, langcode, silent)