Skip to content

Commit

Permalink
Change dictionary format to use strings again
Browse files Browse the repository at this point in the history
This changes the format of the dictionary returned by
`DictionaryFactory().get_dictionary()` from
`Dict[ByteString, ByteString]` to `Mapping[str, str]` to better accommodate
alternative dictionary factory implementations and to make dictionary
handling easier again. The pickled dictionaries are still stored with
byte strings, though, as they are smaller that way than when
using strings.
  • Loading branch information
Dunedan committed Jun 26, 2024
1 parent ce70e71 commit 81f08ba
Show file tree
Hide file tree
Showing 9 changed files with 61 additions and 56 deletions.
34 changes: 26 additions & 8 deletions simplemma/strategies/dictionaries/dictionary_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from functools import lru_cache
from os import listdir, path
from pathlib import Path
from typing import ByteString, Dict, Protocol
from typing import ByteString, Dict, Mapping, Protocol

DATA_FOLDER = str(Path(__file__).parent / "data")
SUPPORTED_LANGUAGES = [
Expand Down Expand Up @@ -62,22 +62,41 @@ class DictionaryFactory(Protocol):
def get_dictionary(
self,
lang: str,
) -> Dict[ByteString, ByteString]:
) -> Mapping[str, str]:
"""
Get the dictionary for a specific language.
Args:
lang (str): The language code.
Returns:
Dict[str, str]: The dictionary for the specified language.
Mapping[str, str]: The dictionary for the specified language.
Raises:
ValueError: If the specified language is not supported.
"""
raise NotImplementedError


class MappingStrToByteString(Mapping[str, str]):
    """Read-only wrapper making a bytes-to-bytes dict behave like a str mapping.

    The wrapped dictionary keeps its keys and values as byte strings; keys are
    encoded on lookup and keys/values are decoded on the way out, so callers
    only ever see ``str``. Encoding and decoding use the ``str.encode`` /
    ``bytes.decode`` default codec (UTF-8).

    Args:
        dictionary: The byte-string dictionary to wrap. It is referenced, not
            copied.
    """

    __slots__ = ["_dict"]

    def __init__(self, dictionary: Dict[bytes, bytes]):
        self._dict = dictionary

    def __getitem__(self, item: str) -> str:
        """Return the decoded value stored for ``item``.

        Raises:
            KeyError: If ``item`` is not a ``str`` or has no entry.
        """
        # A non-str key can never be present. Raising KeyError (rather than
        # letting ``item.encode`` fail with AttributeError) keeps the Mapping
        # mixins such as ``get()`` working for arbitrary keys.
        if not isinstance(item, str):
            raise KeyError(item)
        return self._dict[item.encode()].decode()

    def __contains__(self, item: object) -> bool:
        """Return whether ``item`` is a ``str`` key of the wrapped dictionary."""
        # The Mapping mixin implements ``in`` through __getitem__, which would
        # decode the stored value only to throw it away; test the key directly.
        return isinstance(item, str) and item.encode() in self._dict

    def __iter__(self):
        """Yield every key of the wrapped dictionary decoded to ``str``."""
        for key in self._dict:
            yield key.decode()

    def __len__(self) -> int:
        """Return the number of entries in the wrapped dictionary."""
        return len(self._dict)

Check warning on line 97 in simplemma/strategies/dictionaries/dictionary_factory.py

View check run for this annotation

Codecov / codecov/patch

simplemma/strategies/dictionaries/dictionary_factory.py#L97

Added line #L97 was not covered by tests


class DefaultDictionaryFactory(DictionaryFactory):
"""
Default Dictionary Factory.
Expand All @@ -86,7 +105,7 @@ class DefaultDictionaryFactory(DictionaryFactory):
It provides functionality for loading and caching dictionaries from disk that are included in Simplemma.
"""

__slots__ = ["_data", "_load_dictionary_from_disk"]
__slots__ = ["_load_dictionary_from_disk"]

def __init__(self, cache_max_size: int = 8):
"""
Expand All @@ -96,27 +115,26 @@ def __init__(self, cache_max_size: int = 8):
cache_max_size (int): The maximum size of the cache for loaded dictionaries.
Defaults to `8`.
"""
self._data: Dict[str, Dict[ByteString, ByteString]] = {}
self._load_dictionary_from_disk = lru_cache(maxsize=cache_max_size)(
_load_dictionary_from_disk
)

def get_dictionary(
self,
lang: str,
) -> Dict[ByteString, ByteString]:
) -> Mapping[str, str]:
"""
Get the dictionary for a specific language.
Args:
lang (str): The language code.
Returns:
Dict[str, str]: The dictionary for the specified language.
Mapping[str, str]: The dictionary for the specified language.
Raises:
ValueError: If the specified language is not supported.
"""
if lang not in SUPPORTED_LANGUAGES:
raise ValueError(f"Unsupported language: {lang}")
return self._load_dictionary_from_disk(lang)
return MappingStrToByteString(self._load_dictionary_from_disk(lang))
16 changes: 8 additions & 8 deletions simplemma/strategies/dictionaries/trie_directory_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from collections.abc import MutableMapping
from functools import lru_cache
from pathlib import Path
from typing import ByteString, Dict, List, Optional, cast
from typing import List, Mapping, Optional

from marisa_trie import BytesTrie, HUGE_CACHE # type: ignore[import-not-found]
from platformdirs import user_cache_dir
Expand All @@ -24,7 +24,7 @@ def __init__(self, trie: BytesTrie):
self._trie = trie

def __getitem__(self, item):
return self._trie[item.decode()][0]
return self._trie[item][0].decode()

def __setitem__(self, key, value):
raise NotImplementedError
Expand All @@ -34,7 +34,7 @@ def __delitem__(self, key):

def __iter__(self):
for key in self._trie.iterkeys():
yield key.encode()
yield key

def __len__(self):
return len(self._trie)
Expand Down Expand Up @@ -85,8 +85,8 @@ def _create_trie_from_pickled_dict(self, lang: str) -> BytesTrie:
unpickled_dict = DefaultDictionaryFactory(cache_max_size=0).get_dictionary(lang)
return BytesTrie(
zip(
[key.decode() for key in unpickled_dict], # type: ignore[union-attr]
unpickled_dict.values(),
unpickled_dict.keys(),
[value.encode() for value in unpickled_dict.values()],
),
cache_size=HUGE_CACHE,
)
Expand All @@ -102,7 +102,7 @@ def _write_trie_to_disk(self, lang: str, trie: BytesTrie) -> None:

trie.save(self._cache_dir / f"{lang}.dic")

def _get_dictionary_uncached(self, lang: str) -> Dict[ByteString, ByteString]:
def _get_dictionary_uncached(self, lang: str) -> Mapping[str, str]:
"""Get the dictionary for the given language."""
if lang not in SUPPORTED_LANGUAGES:
raise ValueError(f"Unsupported language: {lang}")
Expand All @@ -114,10 +114,10 @@ def _get_dictionary_uncached(self, lang: str) -> Dict[ByteString, ByteString]:
if self._use_disk_cache:
self._write_trie_to_disk(lang, trie)

return cast(dict, TrieWrapDict(trie))
return TrieWrapDict(trie)

def get_dictionary(
self,
lang: str,
) -> Dict[ByteString, ByteString]:
) -> Mapping[str, str]:
return self._get_dictionary(lang)
13 changes: 3 additions & 10 deletions simplemma/strategies/dictionary_lookup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
It provides lemmatization using dictionary lookup.
"""

from typing import ByteString, Dict, Optional
from typing import Optional

from .dictionaries.dictionary_factory import DefaultDictionaryFactory, DictionaryFactory
from .lemmatization_strategy import LemmatizationStrategy
Expand All @@ -26,13 +26,6 @@ def __init__(
"""
self._dictionary_factory = dictionary_factory

def _get(
self, token: str, dictionary: Dict[ByteString, ByteString]
) -> Optional[str]:
"Convenience function to handle bytestring to string conversion."
result = dictionary.get(token.encode("utf-8"))
return result.decode("utf-8") if result else None # type: ignore[union-attr]

def get_lemma(self, token: str, lang: str) -> Optional[str]:
"""
Get Lemma using Dictionary Lookup
Expand All @@ -50,9 +43,9 @@ def get_lemma(self, token: str, lang: str) -> Optional[str]:
"""
# Search the language data, reverse case to extend coverage.
dictionary = self._dictionary_factory.get_dictionary(lang)
result = self._get(token, dictionary)
result = dictionary.get(token)
if result:
return result
# Try upper or lowercase.
token = token.lower() if token[0].isupper() else token.capitalize()
return self._get(token, dictionary)
return dictionary.get(token)
4 changes: 2 additions & 2 deletions simplemma/strategies/greedy_dictionary_lookup.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def get_lemma(self, token: str, lang: str) -> str:
return token

dictionary = self._dictionary_factory.get_dictionary(lang)
candidate = token.encode("utf-8")
candidate = token
for _ in range(self._steps):
if candidate not in dictionary:
break
Expand All @@ -73,4 +73,4 @@ def get_lemma(self, token: str, lang: str) -> str:

candidate = new_candidate

return candidate.decode("utf-8")
return candidate
8 changes: 2 additions & 6 deletions simplemma/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
- [validate_lang_input][simplemma.utils.validate_lang_input]: Validates the language input and ensures it is a valid tuple.
"""

from typing import ByteString, Tuple, Union
from typing import Tuple, Union


def validate_lang_input(lang: Union[str, Tuple[str, ...]]) -> Tuple[str]:
Expand All @@ -31,9 +31,7 @@ def validate_lang_input(lang: Union[str, Tuple[str, ...]]) -> Tuple[str]:
return lang # type: ignore[return-value]


def levenshtein_dist(
first: Union[ByteString, str], second: Union[ByteString, str]
) -> int:
def levenshtein_dist(str1: str, str2: str) -> int:
"""
Calculate the Levenshtein distance between two strings.
Expand All @@ -49,8 +47,6 @@ def levenshtein_dist(
int: The Levenshtein distance between the two strings.
"""
str1 = first.encode("utf-8") if isinstance(first, str) else first
str2 = second.encode("utf-8") if isinstance(second, str) else second
# inspired by this noticeably faster code:
# https://gist.github.com/p-hash/9e0f9904ce7947c133308fbe48fe032b
if str1 == str2:
Expand Down
22 changes: 11 additions & 11 deletions tests/strategies/dictionaries/test_trie_dictionary_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,11 +188,11 @@ def test_dictionary_working_as_a_dict() -> None:

assert isinstance(dictionary, TrieWrapDict)

assert (b"balconies" in dictionary) is True
assert (b"balconies123" in dictionary) is False
assert ("balconies" in dictionary) is True
assert ("balconies123" in dictionary) is False
with pytest.raises(KeyError):
dictionary[b"balconies123"]
assert dictionary.get(b"balconies") == b"balcony"
dictionary["balconies123"]

Check notice

Code scanning / CodeQL

Statement has no effect Note test

This statement has no effect.
assert dictionary.get("balconies") == "balcony"


def test_trie_wrap_dict():
Expand All @@ -201,21 +201,21 @@ def test_trie_wrap_dict():
)
wrapped_trie = TrieWrapDict(trie)

assert (b"balconies" in wrapped_trie) is True
assert (b"balconies123" in wrapped_trie) is False
assert wrapped_trie[b"balconies"] == b"balcony"
assert ("balconies" in wrapped_trie) is True
assert ("balconies123" in wrapped_trie) is False
assert wrapped_trie["balconies"] == "balcony"
with pytest.raises(KeyError):
wrapped_trie[b"balconies123"]

Check notice

Code scanning / CodeQL

Statement has no effect Note test

This statement has no effect.
assert wrapped_trie.get(b"balconies") == b"balcony"
assert wrapped_trie.get(b"balconies123") is None
assert wrapped_trie.get("balconies") == "balcony"
assert wrapped_trie.get("balconies123") is None

assert isinstance(wrapped_trie.keys(), KeysView)
assert isinstance(wrapped_trie.items(), ItemsView)
assert len(wrapped_trie) == 3

with pytest.raises(NotImplementedError):
wrapped_trie["houses"] = b"teapot"
wrapped_trie["houses"] = "teapot"
with pytest.raises(NotImplementedError):
del wrapped_trie["balconies"]

assert [key for key in wrapped_trie] == [b"balconies", b"houses", b"ponies"]
assert [key for key in wrapped_trie] == ["balconies", "houses", "ponies"]
4 changes: 2 additions & 2 deletions tests/test_dictionary_pickler.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@ def test_logic() -> None:
# different order
mydict = dictionary_pickler._read_dict(testfile, "es", silent=True)
assert len(mydict) == 5
assert mydict[b"closeones"] == b"closeone"
assert mydict["closeones"] == "closeone"
item = sorted(mydict.keys(), reverse=True)[0]
assert item == b"valid-word"
assert item == "valid-word"

# file I/O
assert dictionary_pickler._determine_path("lists", "de").endswith("de.txt")
Expand Down
6 changes: 3 additions & 3 deletions tests/test_lemmatizer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Tests for `simplemma` package."""

from typing import ByteString, Dict
from typing import Mapping

import pytest

Expand All @@ -17,8 +17,8 @@ class CustomDictionaryFactory(DictionaryFactory):
def get_dictionary(
self,
lang: str,
) -> Dict[ByteString, ByteString]:
return {b"testing": b"the test works!!"}
) -> Mapping[str, str]:
return {"testing": "the test works!!"}

assert (
Lemmatizer(
Expand Down
10 changes: 4 additions & 6 deletions training/dictionary_pickler.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import re
from operator import itemgetter
from pathlib import Path
from typing import ByteString, Dict, List, Optional
from typing import Dict, List, Optional

import simplemma
from simplemma.strategies.defaultrules import DEFAULT_RULES
Expand Down Expand Up @@ -49,9 +49,7 @@ def _determine_path(listpath: str, langcode: str) -> str:
return str(Path(__file__).parent / filename)


def _read_dict(
filepath: str, langcode: str, silent: bool
) -> Dict[ByteString, ByteString]:
def _read_dict(filepath: str, langcode: str, silent: bool) -> Dict[str, str]:
mydict: Dict[str, str] = {}
myadditions: List[str] = []
i: int = 0
Expand Down Expand Up @@ -122,12 +120,12 @@ def _read_dict(
mydict[word] = word
LOGGER.debug("%s %s", langcode, i)
# sort and convert to bytestrings
return {k.encode("utf-8"): v.encode("utf-8") for k, v in sorted(mydict.items())}
return dict(sorted(mydict.items()))


def _load_dict(
langcode: str, listpath: str = "lists", silent: bool = True
) -> Dict[ByteString, ByteString]:
) -> Dict[str, str]:
filepath = _determine_path(listpath, langcode)
return _read_dict(filepath, langcode, silent)

Expand Down

0 comments on commit 81f08ba

Please sign in to comment.