fix index error #77

Merged: 3 commits, Aug 15, 2020
19 changes: 15 additions & 4 deletions camphr/pipelines/pattern_search.py
@@ -1,10 +1,21 @@
"""Defines pattern search pipeline based on ahocorasik."""
-from typing import Callable, Dict, Iterable, Iterator, List, Optional, Tuple, cast
+from typing import (
+    Callable,
+    Dict,
+    Iterable,
+    Iterator,
+    List,
+    Optional,
+    Tuple,
+    Union,
+    cast,
+)

import ahocorasick
import spacy
import textspan
from spacy.tokens import Doc
+from spacy.tokens.span import Span
from spacy.util import filter_spans
from typing_extensions import Literal

@@ -106,7 +117,7 @@ def get_char_spans(self, text: str) -> Iterator[Tuple[int, int, str]]:
i = j - len(word) + 1
yield i, j + 1, word

-def _to_text(self, doc: Doc) -> str:
+def _to_text(self, doc: Union[Doc, Span]) -> str:
if self.lemma:
text = _to_lemma_text(doc)
else:
@@ -133,15 +144,15 @@ def __call__(self, doc: Doc) -> Doc:
covering=not self.destructive,
label=self.get_label(text),
)
-if span:
+if span and self._to_text(span) == text:
spans.append(span)
[s.text for s in spans] # TODO: resolve the evaluation bug and remove this line
ents = filter_spans(doc.ents + tuple(spans))
doc.ents = tuple(ents)
return doc


-def _to_lemma_text(doc: Doc) -> str:
+def _to_lemma_text(doc: Union[Doc, Span]) -> str:
ret = ""
for token in doc:
ret += token.lemma_
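Note on the pattern_search.py change: `_to_text` and `_to_lemma_text` now accept a `Span` as well as a `Doc` because `__call__` re-checks every candidate span. With `destructive=False` the character-level match is widened to a covering token span, whose text can differ from the matched keyword, and the new `span and self._to_text(span) == text` guard drops those partial matches. A minimal sketch of the idea in plain Python (the tokenizer and keyword below are illustrative, not taken from this PR):

```python
# Illustrative only: why a covering (non-destructive) token span must be
# re-compared against the matched keyword before it becomes an entity.

def char_matches(text: str, keyword: str):
    """Yield (start, end) character offsets of keyword occurrences in text."""
    start = text.find(keyword)
    while start != -1:
        yield start, start + len(keyword)
        start = text.find(keyword, start + 1)

def covering_span_text(tokens: list, start: int, end: int) -> str:
    """Text of the smallest run of whole tokens covering [start, end)."""
    out, pos = [], 0
    for tok in tokens:
        tok_start, tok_end = pos, pos + len(tok)
        if tok_end > start and tok_start < end:
            out.append(tok)
        pos = tok_end
    return "".join(out)

tokens = ["catfish", " ", "swim"]   # hypothetical tokenization of "catfish swim"
text = "".join(tokens)
keyword = "cat"                     # hypothetical search keyword

for start, end in char_matches(text, keyword):
    span_text = covering_span_text(tokens, start, end)  # -> "catfish"
    # The guard added in this PR rejects the match: the covering token span's
    # text no longer equals the matched keyword.
    print(span_text, span_text == keyword)              # catfish False
```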
4 changes: 2 additions & 2 deletions camphr/utils.py
@@ -56,8 +56,8 @@ def token_from_char_pos(doc: Doc, i: int) -> Token:

def _get_covering_span(doc: Doc, i: int, j: int) -> Span:
token_idxs = [t.idx for t in doc]
-i = doc[bisect.bisect(token_idxs, i) - 1].i
-j = doc[bisect.bisect_left(token_idxs, j)].i
+i = bisect.bisect(token_idxs, i) - 1
+j = bisect.bisect_left(token_idxs, j)
return doc[i:j]


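The utils.py hunk is the index error named in the PR title. `_get_covering_span` locates token boundaries with `bisect`, and `bisect_left` returns `len(doc)` when the end offset falls inside or beyond the last token; the old code then indexed `doc[...]` with that value and raised `IndexError`, while the new code keeps the raw index and lets the `doc[i:j]` slice handle the edge. A small sketch of the failure mode with plain lists (offsets are illustrative):

```python
import bisect

# Stand-ins for a three-token doc "foo bar baz" and its token start offsets.
doc = ["foo", "bar", "baz"]
token_idxs = [0, 4, 8]

j = 9  # a character offset inside the last token

# Old behaviour: index the doc before slicing; fails at the right edge.
k = bisect.bisect_left(token_idxs, j)  # -> 3, which equals len(doc)
try:
    doc[k]
except IndexError as exc:
    print("old code:", exc)  # list index out of range

# New behaviour: keep k as a plain index; slicing tolerates k == len(doc).
i = bisect.bisect(token_idxs, 0) - 1   # -> 0
print("new code:", doc[i:k])           # ['foo', 'bar', 'baz']
```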
13 changes: 8 additions & 5 deletions tests/pipelines/test_pattern_search.py
@@ -23,7 +23,7 @@ def nlp(lang):
_nlp = spacy.blank(lang)
pipe = PatternSearcher.from_words(
KEYWORDS,
-destructive=True,
+destructive=False,
lower=True,
lemma=True,
normalizer=lambda x: re.sub(r"\W", "", x),
@@ -36,7 +36,8 @@ def nlp(lang):
("今日はいい天気だ", ["今日", "は"], "ja_mecab"),
("Mice is a plural form of mouse", ["mouse"], "en"),
("foo-bar", ["foo-bar"], "en"),
("たくさん走った", ["走"], "ja_mecab"),
("たくさん走った", ["走っ"], "ja_mecab"),
("走れ", ["走れ"], "ja_mecab"),
]


@@ -49,11 +50,13 @@ def test_call(nlp, text, expected, target, lang):
assert ents == expected


-def test_serialization(nlp, tmpdir):
+def test_serialization(nlp, tmpdir, lang):
+    text, expected, target = TESTCASES[0]
+    if lang != target:
+        pytest.skip(f"target lang is '{target}', but actual lang is {lang}")
path = Path(tmpdir)
nlp.to_disk(path)
nlp = spacy.load(path)
-text, expected, _ = TESTCASES[0]
doc = nlp(text)
ents = [span.text for span in doc.ents]
-assert ents == expected
+assert ents == expected, list(doc)
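For context, the fixture now builds the searcher with `destructive=False`, and the serialization test is skipped when the fixture language does not match the test case. A hedged sketch of how the pieces fit together outside pytest; `KEYWORDS` and the `add_pipe` call are collapsed in the diff above, so the keyword list and the wiring below are assumptions:

```python
import re

import spacy

from camphr.pipelines.pattern_search import PatternSearcher

keywords = ["mouse", "foo-bar"]  # illustrative; the test module's KEYWORDS is not shown here

nlp = spacy.blank("en")
pipe = PatternSearcher.from_words(
    keywords,
    destructive=False,  # the value this PR switches the fixture to
    lower=True,
    lemma=True,
    normalizer=lambda x: re.sub(r"\W", "", x),
)
nlp.add_pipe(pipe)  # assumed wiring; the fixture's add_pipe call is collapsed above

doc = nlp("Mice is a plural form of mouse")
print([ent.text for ent in doc.ents])  # the tests expect ['mouse']
```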