fix transformers eval batchsize failure (#50)
* add albert test

* add test

* add data

* fix

* modified:   fail.json

* refactor

* modified:   fail.json
tamuhey authored Apr 24, 2020
1 parent 6a642b7 commit fdbdbe7
Showing 10 changed files with 75 additions and 13 deletions.
camphr/lang/juman/tag_map.py (19 additions, 1 deletion)
@@ -2,7 +2,25 @@
 
 from __future__ import unicode_literals
 
-from spacy.symbols import ADJ, ADP, ADV, AUX, CCONJ, DET, INTJ, NOUN, NUM, PART, POS, PRON, PROPN, PUNCT, SPACE, SYM, VERB  # type: ignore
+from spacy.symbols import (  # type: ignore
+    ADJ,
+    ADP,
+    ADV,
+    AUX,
+    CCONJ,
+    DET,
+    INTJ,
+    NOUN,
+    NUM,
+    PART,
+    POS,
+    PRON,
+    PROPN,
+    PUNCT,
+    SPACE,
+    SYM,
+    VERB,
+)
 
 TAG_MAP = {
     "名詞,普通名詞": {POS: NOUN},
camphr/lang/torch.py (7 additions, 3 deletions)
@@ -94,7 +94,7 @@ def evaluate(  # type: ignore
             docs, golds = zip(*batch)
             docs, golds = self._format_docs_and_golds(docs, golds)  # type: ignore
             for _, pipe in self.pipeline:
-                self._eval_pipe(pipe, docs, golds)
+                self._eval_pipe(pipe, docs, golds, batch_size=batch_size)
             loss += cast(float, get_loss_from_docs(docs).cpu().float().item())
             for doc, gold in zip(docs, golds):
                 scorer.score(doc, gold)
@@ -103,14 +103,18 @@ def evaluate(  # type: ignore
         return scores
 
     def _eval_pipe(
-        self, pipe: Pipe, docs: Sequence[Doc], golds: Sequence[GoldParse]
+        self,
+        pipe: Pipe,
+        docs: Sequence[Doc],
+        golds: Sequence[GoldParse],
+        batch_size: int,
     ) -> Sequence[Doc]:
         if not hasattr(pipe, "pipe"):
            docs = spacy.language._pipe(docs, pipe, {})
         elif hasattr(pipe, "eval"):
             pipe.eval(docs, golds)  # type: ignore
         else:
-            docs = list(pipe.pipe(docs))  # type: ignore
+            docs = list(pipe.pipe(docs, batch_size=batch_size))  # type: ignore
         return docs
 
     def resume_training(self, **kwargs) -> Optimizer:  # type: ignore
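Taken together, the two hunks above thread the caller's batch_size from `evaluate` through `_eval_pipe` into `pipe.pipe(docs, batch_size=batch_size)`; previously each pipe batched with its own default, which appears to be what failed for the transformers components in #50. The following is a minimal, self-contained sketch of that pattern, not camphr's actual code: `TinyPipe` and the toy `evaluate` are invented for illustration.

```python
from typing import Iterable, List, Sequence


class TinyPipe:
    """Stand-in for a pipeline component that batches internally."""

    def pipe(self, docs: Sequence[str], batch_size: int = 32) -> Iterable[str]:
        # A real component would run a model here; we just chunk and echo.
        for i in range(0, len(docs), batch_size):
            yield from docs[i : i + batch_size]


def evaluate(docs: List[str], pipes: List[TinyPipe], batch_size: int = 256) -> List[str]:
    out: List[str] = []
    for start in range(0, len(docs), batch_size):
        batch = docs[start : start + batch_size]
        for pipe in pipes:
            # The fix: forward the caller's batch_size instead of relying on
            # each pipe's own default (the pre-fix behaviour).
            batch = list(pipe.pipe(batch, batch_size=batch_size))
        out.extend(batch)
    return out


print(len(evaluate([f"doc{i}" for i in range(300)], [TinyPipe()])))  # 300
```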
tests/conftest.py (3 additions, 7 deletions)
@@ -92,14 +92,10 @@ def lang(request):
 
 @pytest.fixture(scope="session", params=TRF_TESTMODEL_PATH)
 def trf_name_or_path(request):
-    if "bert-base-japanese" in request.param and not check_mecab():
+    name = request.param
+    if "bert-base-japanese" in name and not check_mecab():
         pytest.skip("mecab is required")
-    return request.param
-
-
-@pytest.fixture(scope="session", params=TRF_TESTMODEL_PATH)
-def trf_testmodel_path(request) -> str:
-    return request.param
+    return name
 
 
 @pytest.fixture(scope="session")
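With the duplicate fixture gone, tests that used to request `trf_testmodel_path` now take `trf_name_or_path`, the single session-scoped fixture parametrized over `TRF_TESTMODEL_PATH` (see the `test_freeze_model` change below). A hypothetical consumer, just to show the shape; the test body is invented and only the fixture name comes from this commit.

```python
def test_model_name_looks_valid(trf_name_or_path):
    # Runs once per entry in TRF_TESTMODEL_PATH; the Japanese BERT entry is
    # skipped automatically when mecab is unavailable.
    assert isinstance(trf_name_or_path, str) and trf_name_or_path
```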
tests/pipelines/knp/test_dependency_parser.py (3 additions, 1 deletion)
@@ -21,7 +21,9 @@ def test_dependency_parse(nlp, text, heads):
         assert token.head.i == headi
 
 
-@pytest.mark.parametrize("text,deps", [("太郎が本を読む", ["nsubj", "case", "obj", "case", "ROOT"])])
+@pytest.mark.parametrize(
+    "text,deps", [("太郎が本を読む", ["nsubj", "case", "obj", "case", "ROOT"])]
+)
 def test_dependency_deps(nlp, text, deps):
     doc = nlp(text)
     for token, depi in itertools.zip_longest(doc, deps):
tests/pipelines/transformers/test_model.py (1 addition, 1 deletion)
@@ -101,7 +101,7 @@ def train():
     pipe.cfg["freeze"] = False
 
 
-def test_freeze_model(trf_testmodel_path, trf_model_config: NLPConfig):
+def test_freeze_model(trf_name_or_path, trf_model_config: NLPConfig):
     config = omegaconf.OmegaConf.to_container(trf_model_config)
     config["pipeline"][TRANSFORMERS_MODEL]["freeze"] = True
     nlp = create_model(config)
tests/regression/test49/.gitignore (1 addition)
@@ -0,0 +1 @@
!fail.json
Empty file.
tests/regression/test49/fail.json (1 addition)
@@ -0,0 +1 @@
[["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}], ["a", {}]]
tests/regression/test49/test_main.py (39 additions)
@@ -0,0 +1,39 @@
import json
from pathlib import Path

import pytest
import torch

from camphr.lang.torch import TorchLanguage
from camphr.models import create_model
from camphr.ner_labels.labels_ene import ALL_LABELS


@pytest.fixture
def nlp():
    name = "albert-base-v2"
    config = f"""
    lang:
        name: en
        optimizer:
            class: torch.optim.SGD
            params:
                lr: 0.01
    pipeline:
        transformers_model:
            trf_name_or_path: {name}
        transformers_ner:
            labels: {ALL_LABELS}
    """
    return create_model(config)


@pytest.fixture
def data():
    return json.loads((Path(__file__).parent / "fail.json").read_text())


def test(nlp: TorchLanguage, data):
    if torch.cuda.is_available():
        nlp.to(torch.device("cuda"))
    nlp.evaluate(data, batch_size=256)
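Assuming a checkout with the test dependencies installed, the regression test above can be run on its own; this is plain pytest usage rather than anything added by this commit.

```python
import pytest

# Equivalent to running `pytest -q tests/regression/test49/test_main.py` on the CLI.
pytest.main(["-q", "tests/regression/test49/test_main.py"])
```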
tests/utils.py (1 addition)
@@ -77,3 +77,4 @@ def check_serialization(nlp, text: str = "It is a serialization set. 今日は
 DATA_DIR = (Path(__file__).parent / "data/").absolute()
 
 TRF_TESTMODEL_PATH = [str(BERT_JA_DIR), str(XLNET_DIR), str(BERT_DIR)]
+LARGE_MODELS = {"albert-base-v2"}
