Add a dictionary factory backed by MARISA-tries
This adds an additional dictionary factory backed by MARISA-tries. This
dictionary factory on average offers 20x lower memory usage and 100x
faster initialization time, in exchange for reduced lemmatization and
language detection performance.

The first time loading a dictionary with the `TrieDictionaryFactory`
requires more memory and will take a few seconds, as the trie-backed
dictionary has to be generated on-the-fly from the pickled dict-based
dictionary first.
Dunedan committed Jun 26, 2024
1 parent 5f4fa16 commit ce70e71
Showing 7 changed files with 408 additions and 1 deletion.
55 changes: 55 additions & 0 deletions README.md
@@ -259,6 +259,61 @@ LANG_CACHE_SIZE = 5 # How many language dictionaries to keep in memory at once
For more information see the
[extended documentation](https://adbar.github.io/simplemma/).

### Reducing memory usage

For situations where low memory usage and fast initialization time are
more important than lemmatization and language detection speed,
Simplemma ships another `DictionaryFactory`, which uses a trie as its
underlying data structure instead of a Python dict.

Using the `TrieDictionaryFactory` reduces memory usage by about 20x and
initialization time by about 100x on average, but throughput can drop by
50% or more compared to what Simplemma otherwise achieves, depending on
the specific usage.

To use the `TrieDictionaryFactory` you have to install Simplemma with
the `marisa-trie` extra dependency:

```
pip install simplemma[marisa-trie]
```

Then you have to create a custom strategy using the
`TrieDictionaryFactory` and use that for `Lemmatizer` and
`LanguageDetector` instances:

``` python
>>> from simplemma import LanguageDetector, Lemmatizer
>>> from simplemma.strategies import DefaultStrategy
>>> from simplemma.strategies.dictionaries import TrieDictionaryFactory

>>> lemmatization_strategy = DefaultStrategy(dictionary_factory=TrieDictionaryFactory())

>>> lemmatizer = Lemmatizer(lemmatization_strategy=lemmatization_strategy)
>>> lemmatizer.lemmatize('doughnuts', lang='en')
'doughnut'

>>> language_detector = LanguageDetector('la', lemmatization_strategy=lemmatization_strategy)
>>> language_detector.proportion_in_target_languages("opera post physica posita (τὰ μετὰ τὰ φυσικά)")
0.5
```
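
To see how this trade-off plays out on a particular workload, the two factories can be timed side by side. The following is only a sketch: `time_call` and `compare_factories` are hypothetical helper names, not part of Simplemma, and `compare_factories` assumes Simplemma is installed with the `marisa-trie` extra. Taking the best of several runs keeps the one-time dictionary loading out of the result.

```python
import time
from typing import Callable, Dict, Sequence


def time_call(func: Callable[[], object], repeat: int = 3) -> float:
    """Return the best wall-clock time in seconds over `repeat` calls."""
    best = float("inf")
    for _ in range(repeat):
        start = time.perf_counter()
        func()
        best = min(best, time.perf_counter() - start)
    return best


def compare_factories(tokens: Sequence[str], lang: str = "en") -> Dict[str, float]:
    """Time lemmatizing `tokens` with the dict-backed and trie-backed factories."""
    # Imported lazily so that time_call() stays usable without Simplemma.
    from simplemma import Lemmatizer
    from simplemma.strategies import DefaultStrategy
    from simplemma.strategies.dictionaries import (
        DefaultDictionaryFactory,
        TrieDictionaryFactory,
    )

    results: Dict[str, float] = {}
    factories = (("dict", DefaultDictionaryFactory()), ("trie", TrieDictionaryFactory()))
    for name, factory in factories:
        lemmatizer = Lemmatizer(
            lemmatization_strategy=DefaultStrategy(dictionary_factory=factory)
        )
        results[name] = time_call(
            lambda: [lemmatizer.lemmatize(token, lang=lang) for token in tokens]
        )
    return results
```

Calling `compare_factories(["doughnuts", "houses", "running"] * 1000)` returns the best time per factory in seconds. The numbers depend on the language, the tokens and the machine, so they are best read as a rough indication.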

Memory usage and initialization time are only significantly lower than
with the `DefaultDictionaryFactory` once the trie dictionaries are
available on disk. That is not the case the first time the
`TrieDictionaryFactory` is used for a language, as Simplemma ships its
dictionaries only as pickled Python dicts. The trie dictionary is then
generated on-the-fly from the Python dict, which takes a few seconds and
uses as much memory as loading the Python dict for that language. The
generated trie dictionary is cached on disk and reused by later
invocations.

If the computer that is supposed to run Simplemma doesn't have enough
memory to generate the trie dictionaries, they can also be generated on
another computer with the same CPU architecture and copied over to the
cache directory.
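
Pre-generating the tries on another machine could look roughly like the sketch below. It relies on the `disk_cache_dir` argument of `TrieDictionaryFactory` and on the `<lang>.dic` file naming used by this commit; `pregenerate_tries` and its `factory_cls` parameter are hypothetical names introduced only for illustration.

```python
from pathlib import Path
from typing import Callable, Iterable, List, Optional


def pregenerate_tries(
    languages: Iterable[str],
    cache_dir: str,
    factory_cls: Optional[Callable] = None,
) -> List[Path]:
    """Build the trie dictionaries for `languages` inside `cache_dir`."""
    if factory_cls is None:
        # Imported lazily; requires Simplemma with the marisa-trie extra.
        from simplemma.strategies.dictionaries import TrieDictionaryFactory

        factory_cls = TrieDictionaryFactory

    factory = factory_cls(disk_cache_dir=cache_dir)
    for lang in languages:
        # The first access for a language generates the trie and, with the
        # disk cache enabled (the default), writes <cache_dir>/<lang>.dic.
        factory.get_dictionary(lang)
    return sorted(Path(cache_dir).glob("*.dic"))
```

After running `pregenerate_tries(["en", "de"], "trie-cache")` on the larger machine, the resulting `.dic` files can be copied to the target machine and picked up there with `TrieDictionaryFactory(disk_cache_dir=...)` pointing at the copied directory.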

## Supported languages

2 changes: 2 additions & 0 deletions requirements-dev.txt
@@ -1,6 +1,8 @@
black==24.4.2
flake8==7.0.0
marisa_trie==1.2.0
mypy==1.10.0
platformdirs==4.2.2
pytest==8.2.1
pytest-cov==5.0.0
types-requests==2.32.0.20240523
1 change: 1 addition & 0 deletions setup.py
@@ -85,6 +85,7 @@ def get_version(package):
    ],
    description="A simple multilingual lemmatizer for Python.",
    install_requires=requirements,
    extras_require={"marisa-trie": ["marisa-trie", "platformdirs"]},
    license="MIT license",
    long_description=readme,  # + '\n\n' + history,
    long_description_content_type="text/markdown",
6 changes: 5 additions & 1 deletion simplemma/strategies/__init__.py
@@ -2,7 +2,11 @@

from .affix_decomposition import AffixDecompositionStrategy
from .default import DefaultStrategy
from .dictionaries import DefaultDictionaryFactory, DictionaryFactory
from .dictionaries import (
    DefaultDictionaryFactory,
    DictionaryFactory,
    TrieDictionaryFactory,
)
from .dictionary_lookup import DictionaryLookupStrategy
from .fallback.lemmatization_fallback_strategy import LemmatizationFallbackStrategy
from .fallback.raise_error import RaiseErrorFallbackStrategy
1 change: 1 addition & 0 deletions simplemma/strategies/dictionaries/__init__.py
@@ -1,3 +1,4 @@
"""Dictionary-based lemmatization strategy."""

from .dictionary_factory import DefaultDictionaryFactory, DictionaryFactory
from .trie_directory_factory import TrieDictionaryFactory
123 changes: 123 additions & 0 deletions simplemma/strategies/dictionaries/trie_directory_factory.py
@@ -0,0 +1,123 @@
import logging
from collections.abc import MutableMapping
from functools import lru_cache
from pathlib import Path
from typing import ByteString, Dict, List, Optional, cast

from marisa_trie import BytesTrie, HUGE_CACHE # type: ignore[import-not-found]
from platformdirs import user_cache_dir

from simplemma import __version__ as SIMPLEMMA_VERSION
from simplemma.strategies.dictionaries.dictionary_factory import (
    DefaultDictionaryFactory,
    DictionaryFactory,
    SUPPORTED_LANGUAGES,
)

logger = logging.getLogger(__name__)


class TrieWrapDict(MutableMapping):
    """Wrapper around BytesTrie to make them behave like dicts."""

    def __init__(self, trie: BytesTrie):
        self._trie = trie

    def __getitem__(self, item):
        return self._trie[item.decode()][0]

    def __setitem__(self, key, value):
        raise NotImplementedError

    def __delitem__(self, key):
        raise NotImplementedError

    def __iter__(self):
        for key in self._trie.iterkeys():
            yield key.encode()

    def __len__(self):
        return len(self._trie)


class TrieDictionaryFactory(DictionaryFactory):
    """Memory optimized DictionaryFactory backed by MARISA-tries.

    This dictionary factory creates dictionaries, which are backed by a
    MARISA-trie instead of a dict, to make them consume very little
    memory compared to the DefaultDictionaryFactory. The trade-off is
    that lookup performance isn't as good as with dicts.
    """

    __slots__: List[str] = []

    def __init__(
        self,
        cache_max_size: int = 8,
        use_disk_cache: bool = True,
        disk_cache_dir: Optional[str] = None,
    ) -> None:
        """Initialize the TrieDictionaryFactory.

        Args:
            cache_max_size (int): The maximum number of dictionaries to
                keep in memory. Defaults to `8`.
            use_disk_cache (bool): Whether to cache the tries on disk to
                speed up loading time. Defaults to `True`.
            disk_cache_dir (Optional[str]): Path where the generated
                tries should be stored. Defaults to a
                Simplemma-specific subdirectory of the user's cache
                directory.
        """

        if disk_cache_dir:
            self._cache_dir = Path(disk_cache_dir)
        else:
            self._cache_dir = (
                Path(user_cache_dir("simplemma")) / "marisa_trie" / SIMPLEMMA_VERSION
            )
        self._use_disk_cache = use_disk_cache
        self._get_dictionary = lru_cache(maxsize=cache_max_size)(
            self._get_dictionary_uncached
        )

    def _create_trie_from_pickled_dict(self, lang: str) -> BytesTrie:
        """Create a trie from a pickled dictionary."""
        unpickled_dict = DefaultDictionaryFactory(cache_max_size=0).get_dictionary(lang)
        return BytesTrie(
            zip(
                [key.decode() for key in unpickled_dict],  # type: ignore[union-attr]
                unpickled_dict.values(),
            ),
            cache_size=HUGE_CACHE,
        )

    def _write_trie_to_disk(self, lang: str, trie: BytesTrie) -> None:
        """Persist the trie to disk for later usage.

        The persisted trie can be loaded by subsequent runs to speed up
        loading times.
        """
        logger.debug("Caching trie on disk. This might take a second.")
        self._cache_dir.mkdir(parents=True, exist_ok=True)

        trie.save(self._cache_dir / f"{lang}.dic")

    def _get_dictionary_uncached(self, lang: str) -> Dict[ByteString, ByteString]:
        """Get the dictionary for the given language."""
        if lang not in SUPPORTED_LANGUAGES:
            raise ValueError(f"Unsupported language: {lang}")

        if self._use_disk_cache and (self._cache_dir / f"{lang}.dic").exists():
            trie = BytesTrie().load(str(self._cache_dir / f"{lang}.dic"))
        else:
            trie = self._create_trie_from_pickled_dict(lang)
            if self._use_disk_cache:
                self._write_trie_to_disk(lang, trie)

        return cast(dict, TrieWrapDict(trie))

    def get_dictionary(
        self,
        lang: str,
    ) -> Dict[ByteString, ByteString]:
        return self._get_dictionary(lang)
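
The `TrieWrapDict` adapter in this file bridges two conventions: Simplemma looks up `bytes` keys, while a `BytesTrie` stores `str` keys and returns a list of `bytes` values per key. The standard-library sketch below illustrates that adapter pattern in isolation. A plain dict of lists stands in for the trie, and `BytesKeyView` is an illustrative name that is not part of Simplemma.

```python
from collections.abc import Mapping
from typing import Dict, Iterator, List


class BytesKeyView(Mapping):
    """Read-only view exposing bytes keys over a str-keyed store."""

    def __init__(self, store: Dict[str, List[bytes]]) -> None:
        # Stand-in for a BytesTrie: str keys, a list of bytes values per key.
        self._store = store

    def __getitem__(self, key: bytes) -> bytes:
        # Decode the bytes key for the lookup and return the first value.
        return self._store[key.decode()][0]

    def __iter__(self) -> Iterator[bytes]:
        # Encode the str keys again so callers only ever see bytes.
        for key in self._store:
            yield key.encode()

    def __len__(self) -> int:
        return len(self._store)


view = BytesKeyView({"doughnuts": [b"doughnut"]})
print(view[b"doughnuts"])  # b'doughnut'
print(b"missing" in view)  # False
```

Deriving from `Mapping` instead of `MutableMapping` is a choice made for this sketch only: it avoids the `__setitem__` and `__delitem__` stubs, whereas the commit keeps `MutableMapping` and raises `NotImplementedError` for writes.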