diff --git a/camphr/pipelines/pattern_search.py b/camphr/pipelines/pattern_search.py
index 2371a24f..67c9882d 100644
--- a/camphr/pipelines/pattern_search.py
+++ b/camphr/pipelines/pattern_search.py
@@ -1,10 +1,21 @@
 """Defines pattern search pipeline based on ahocorasik."""
-from typing import Callable, Dict, Iterable, Iterator, List, Optional, Tuple, cast
+from typing import (
+    Callable,
+    Dict,
+    Iterable,
+    Iterator,
+    List,
+    Optional,
+    Tuple,
+    Union,
+    cast,
+)
 
 import ahocorasick
 import spacy
 import textspan
 from spacy.tokens import Doc
+from spacy.tokens.span import Span
 from spacy.util import filter_spans
 from typing_extensions import Literal
 
@@ -106,7 +117,7 @@ def get_char_spans(self, text: str) -> Iterator[Tuple[int, int, str]]:
             i = j - len(word) + 1
             yield i, j + 1, word
 
-    def _to_text(self, doc: Doc) -> str:
+    def _to_text(self, doc: Union[Doc, Span]) -> str:
         if self.lemma:
             text = _to_lemma_text(doc)
         else:
@@ -133,7 +144,7 @@ def __call__(self, doc: Doc) -> Doc:
                 covering=not self.destructive,
                 label=self.get_label(text),
             )
-            if span:
+            if span and self._to_text(span) == text:
                 spans.append(span)
         [s.text for s in spans]  # TODO: resolve the evaluation bug and remove this line
         ents = filter_spans(doc.ents + tuple(spans))
@@ -141,7 +152,7 @@
         return doc
 
 
-def _to_lemma_text(doc: Doc) -> str:
+def _to_lemma_text(doc: Union[Doc, Span]) -> str:
     ret = ""
     for token in doc:
         ret += token.lemma_
diff --git a/camphr/utils.py b/camphr/utils.py
index dc221506..374dc9ea 100644
--- a/camphr/utils.py
+++ b/camphr/utils.py
@@ -56,8 +56,8 @@ def token_from_char_pos(doc: Doc, i: int) -> Token:
 
 def _get_covering_span(doc: Doc, i: int, j: int) -> Span:
     token_idxs = [t.idx for t in doc]
-    i = doc[bisect.bisect(token_idxs, i) - 1].i
-    j = doc[bisect.bisect_left(token_idxs, j)].i
+    i = bisect.bisect(token_idxs, i) - 1
+    j = bisect.bisect_left(token_idxs, j)
     return doc[i:j]
diff --git a/tests/pipelines/test_pattern_search.py b/tests/pipelines/test_pattern_search.py
index 12a41151..e3400b46 100644
--- a/tests/pipelines/test_pattern_search.py
+++ b/tests/pipelines/test_pattern_search.py
@@ -23,7 +23,7 @@ def nlp(lang):
     _nlp = spacy.blank(lang)
     pipe = PatternSearcher.from_words(
         KEYWORDS,
-        destructive=True,
+        destructive=False,
         lower=True,
         lemma=True,
         normalizer=lambda x: re.sub(r"\W", "", x),
@@ -36,7 +36,8 @@
     ("今日はいい天気だ", ["今日", "は"], "ja_mecab"),
     ("Mice is a plural form of mouse", ["mouse"], "en"),
     ("foo-bar", ["foo-bar"], "en"),
-    ("たくさん走った", ["走"], "ja_mecab"),
+    ("たくさん走った", ["走っ"], "ja_mecab"),
+    ("走れ", ["走れ"], "ja_mecab"),
 ]
 
@@ -49,11 +50,13 @@ def test_call(nlp, text, expected, target, lang):
     ents = [span.text for span in doc.ents]
     assert ents == expected
 
-def test_serialization(nlp, tmpdir):
+def test_serialization(nlp, tmpdir, lang):
+    text, expected, target = TESTCASES[0]
+    if lang != target:
+        pytest.skip(f"target lang is '{target}', but actual lang is {lang}")
     path = Path(tmpdir)
     nlp.to_disk(path)
     nlp = spacy.load(path)
-    text, expected, _ = TESTCASES[0]
     doc = nlp(text)
     ents = [span.text for span in doc.ents]
-    assert ents == expected
+    assert ents == expected, list(doc)
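
A note on the pattern_search.py hunks, since the surrounding loop is not visible in the diff: inside `__call__`, `text` appears to be the matched keyword for the current hit (it is also passed to `self.get_label(text)`), and with `destructive=False`, `get_doc_char_span` returns the *covering* token span, which can be wider than the matched character range. The new `self._to_text(span) == text` guard therefore keeps a candidate only when the covering span's lemma/normalized text still equals the matched pattern. A minimal sketch of the idea, using toy `(surface, lemma)` pairs and hypothetical helper names (`lemma_text`, `keep_span` are not camphr's API):

```python
from typing import List, Tuple

Token = Tuple[str, str]  # (surface text, lemma) -- toy stand-in for spaCy tokens

def lemma_text(tokens: List[Token]) -> str:
    """Concatenated lemma string, analogous to _to_lemma_text."""
    return "".join(lemma for _, lemma in tokens)

def keep_span(covering: List[Token], pattern: str) -> bool:
    """The new guard: keep a covering span only if its lemma text
    still equals the matched pattern exactly."""
    return lemma_text(covering) == pattern

# "たくさん走った" tokenized by MeCab: たくさん / 走っ / た, where 走っ lemmatizes to 走る.
# Matching the keyword 走る against the lemma string hits inside the token 走っ.
assert keep_span([("走っ", "走る")], "走る")                    # kept -> ent text "走っ"
assert not keep_span([("走っ", "走る"), ("た", "た")], "走る")  # too wide, dropped
```

This is why the expected entity for "たくさん走った" changes from the destructively split "走" to the whole covering token "走っ".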
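The utils.py change is more than a simplification. Since `doc[k].i` is just `k`, the `.i` round-trip was redundant for the start index, but for the end index `doc[bisect.bisect_left(token_idxs, j)]` raises `IndexError` whenever the character span ends at or past the start of the last token, because `bisect_left` then returns `len(doc)`. Returning the raw bisect results and slicing avoids the out-of-range lookup. A standalone sketch of the index arithmetic (`covering_token_slice` is a hypothetical name, not camphr's API):

```python
import bisect

# Token start offsets for "foo-bar baz" -> tokens "foo"(0), "-"(3), "bar"(4), "baz"(8)
token_idxs = [0, 3, 4, 8]

def covering_token_slice(token_idxs, i, j):
    """Token slice [start, end) covering the character span [i, j)."""
    start = bisect.bisect(token_idxs, i) - 1  # rightmost token starting at or before i
    end = bisect.bisect_left(token_idxs, j)   # first token starting at or after j
    return start, end

# "bar" occupies chars [4, 7) and is covered by the token slice [2, 3).
assert covering_token_slice(token_idxs, 4, 7) == (2, 3)

# "baz" occupies chars [8, 11): bisect_left returns 4 == len(token_idxs),
# so the old doc[...] lookup raised IndexError, while slicing doc[3:4] is fine.
assert covering_token_slice(token_idxs, 8, 11) == (3, 4)
```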
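On the test changes: `destructive=False` now exercises the covering-span path, the 走っ/走れ cases pin down the guard's behaviour on conjugated forms, and `test_serialization` now skips when the fixture's `lang` does not match the test case's target language instead of failing spuriously. A usage sketch mirroring the fixture (assuming an environment like the test suite's, where `spacy.blank` plus lookup-based lemmas behave as the expectations require):

```python
import re

import spacy

from camphr.pipelines.pattern_search import PatternSearcher

nlp = spacy.blank("en")
pipe = PatternSearcher.from_words(
    ["mouse", "foo-bar"],
    destructive=False,  # keep the original tokenization; match via covering spans
    lower=True,
    lemma=True,
    normalizer=lambda x: re.sub(r"\W", "", x),
)
nlp.add_pipe(pipe)

doc = nlp("Mice is a plural form of mouse")
print([ent.text for ent in doc.ents])  # per the test case: ["mouse"]
```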