Feat/2465 add text2text support for prepare for training spark nlp #2466

Merged
32 changes: 28 additions & 4 deletions src/argilla/client/datasets.py
@@ -789,8 +789,12 @@ def _prepare_for_training_with_spacy(self, nlp: "spacy.Language", records: List[
            doc = nlp.make_doc(text)

            cats = dict.fromkeys(all_labels, 0)
-           for anno in record.annotation:
-               cats[anno] = 1
+
+           if isinstance(record.annotation, list):
+               for anno in record.annotation:
+                   cats[anno] = 1
+           else:
+               cats[record.annotation] = 1

            doc.cats = cats
            db.add(doc)
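Reviewer note: the new `isinstance` branch exists because single-label records store `annotation` as a plain `str` while multi-label records store a `List[str]`; iterating a string here would walk it character by character. A minimal sketch of the two shapes, using hypothetical records:

    import argilla as rg

    # Single-label: annotation is a str; the else-branch sets one cat.
    single = rg.TextClassificationRecord(text="mock", annotation="a")
    # Multi-label: annotation is a List[str]; the loop sets one cat per label.
    multi = rg.TextClassificationRecord(text="mock", annotation=["a", "b"], multi_label=True)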
@@ -821,8 +825,15 @@ def _prepare_for_training_with_spark_nlp(self, records: List[Record]) -> "pandas
    def __all_labels__(self):
        all_labels = set()
        for record in self._records:
-           if record.annotation:
-               all_labels.update(record.annotation)
+           if record.annotation is None:
+               continue
+           elif isinstance(record.annotation, str):
+               all_labels.add(record.annotation)
+           elif isinstance(record.annotation, list):
+               all_labels.update((tuple(record.annotation)))
+           else:
+               # this is highly unlikely
+               raise TypeError("Record.annotation contains an unsupported type: {}".format(type(record.annotation)))

        return list(all_labels)
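The type dispatch above also fixes a subtle pitfall: `set.update` treats a string as an iterable of characters, so the old `all_labels.update(record.annotation)` would have split a single-label (str) annotation into letters. A quick illustration:

    labels = set()
    labels.update("positive")      # wrong for single labels: adds 'p', 'o', 's', ...
    labels = set()
    labels.add("positive")         # right: one element, the whole label
    labels.update(["pos", "neg"])  # right for multi-label lists: unpacks element-wise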

@@ -1258,6 +1269,19 @@ def _prepare_for_training_with_transformers(

        return ds

+   def _prepare_for_training_with_spark_nlp(self, records: List[Record]) -> "pandas.DataFrame":
+       spark_nlp_data = []
+       for record in records:
+           if record.annotation is None:
+               continue
+           if record.id is None:
+               record.id = str(uuid.uuid4())
+           text = record.text
+
+           spark_nlp_data.append([record.id, text, record.annotation])
+
+       return pd.DataFrame(spark_nlp_data, columns=["id", "text", "target"])
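For orientation, a hedged usage sketch of the new `DatasetForText2Text` path (record contents are made up; the column names follow the code above):

    import argilla as rg

    ds = rg.DatasetForText2Text(
        [
            rg.Text2TextRecord(text="hello", annotation="hola"),
            rg.Text2TextRecord(text="unlabeled"),  # skipped: annotation is None
        ]
    )
    df = ds.prepare_for_training("spark-nlp", train_size=1)
    # One-row pandas DataFrame with columns ["id", "text", "target"];
    # the missing record id is filled in with str(uuid.uuid4()).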


Dataset = Union[DatasetForTextClassification, DatasetForTokenClassification, DatasetForText2Text]

101 changes: 85 additions & 16 deletions tests/client/test_dataset.py
@@ -339,7 +339,7 @@ def test_prepare_for_training(self, request, records):
        records = request.getfixturevalue(records)

        ds = rg.DatasetForTextClassification(records)
-       train = ds.prepare_for_training()
+       train = ds.prepare_for_training(seed=42)

        if not ds[0].multi_label:
            column_names = ["text", "context", "label"]
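Passing `seed=42` pins the shuffle behind the split, which is what makes the exact train/test size assertions in these tests reproducible. A sketch of the property being relied on, assuming the transformers-backed path delegates to `datasets.Dataset.train_test_split` as the Hugging Face API does:

    import datasets

    d = datasets.Dataset.from_dict({"x": [1, 2, 3, 4]})
    a = d.train_test_split(train_size=0.5, seed=42)
    b = d.train_test_split(train_size=0.5, seed=42)
    assert a["train"]["x"] == b["train"]["x"]  # same seed -> same split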
@@ -357,12 +357,73 @@
        else:
            assert train.features["label"] == datasets.ClassLabel(names=["a"])

-       train_test = ds.prepare_for_training(train_size=0.5)
+       train_test = ds.prepare_for_training(train_size=0.5, seed=42)
        assert len(train_test["train"]) == 1
        assert len(train_test["test"]) == 1
        for split in ["train", "test"]:
            assert train_test[split].column_names == column_names

+   @pytest.mark.parametrize(
+       "records",
+       [
+           "singlelabel_textclassification_records",
+           "multilabel_textclassification_records",
+       ],
+   )
+   def test_prepare_for_training_with_spacy(self, request, records):
+       records = request.getfixturevalue(records)
+
+       ds = rg.DatasetForTextClassification(records)
+       with pytest.raises(ValueError):
+           train = ds.prepare_for_training(framework="spacy", seed=42)
+       nlp = spacy.blank("en")
+       doc_bin = ds.prepare_for_training(framework="spacy", lang=nlp, seed=42)
+
+       assert isinstance(doc_bin, spacy.tokens.DocBin)
+       docs = list(doc_bin.get_docs(nlp.vocab))
+       assert len(docs) == 2
+
+       if records[0].multi_label:
+           assert set(docs[0].cats.keys()) == {"a", "b"}
+       else:
+           assert isinstance(docs[0].cats, dict)
+
+       train, test = ds.prepare_for_training(train_size=0.5, framework="spacy", lang=nlp, seed=42)
+       docs_train = list(train.get_docs(nlp.vocab))
+       docs_test = list(test.get_docs(nlp.vocab))
+       assert len(docs_train) == 1
+       assert len(docs_test) == 1
+
+   @pytest.mark.parametrize(
+       "records",
+       [
+           "singlelabel_textclassification_records",
+           "multilabel_textclassification_records",
+       ],
+   )
+   def test_prepare_for_training_with_spark_nlp(self, request, records):
+       records = request.getfixturevalue(records)
+
+       ds = rg.DatasetForTextClassification(records)
+       df = ds.prepare_for_training("spark-nlp", train_size=1, seed=42)
+
+       if ds[0].multi_label:
+           column_names = ["id", "text", "labels"]
+       else:
+           column_names = ["id", "text", "label"]
+
+       assert isinstance(df, pd.DataFrame)
+       assert list(df.columns) == column_names
+       assert len(df) == 2
+
+       df_train, df_test = ds.prepare_for_training("spark-nlp", train_size=0.5, seed=42)
+       assert len(df_train) == 1
+       assert len(df_test) == 1
+       assert isinstance(df_train, pd.DataFrame)
+       assert isinstance(df_test, pd.DataFrame)
+       assert list(df_train.columns) == column_names
+       assert list(df_test.columns) == column_names
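These assertions stop at the pandas level; downstream, the frame would typically be handed to Spark. A hypothetical hand-off, assuming an available `SparkSession` (the `spark` variable is illustrative, not part of this PR):

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()
    spark_df = spark.createDataFrame(df)  # pandas DataFrame -> Spark DataFrame
    spark_df.show()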

    @pytest.mark.skipif(
        _HF_HUB_ACCESS_TOKEN is None,
        reason="You need a HF Hub access token to test the push_to_hub feature",
@@ -574,13 +635,15 @@ def test_prepare_for_training_with_spacy(self):
            r.annotation = [(label, start, end) for label, start, end, _ in r.prediction]

        with pytest.raises(ValueError):
-           train = rb_dataset.prepare_for_training(framework="spacy")
+           train = rb_dataset.prepare_for_training(framework="spacy", seed=42)

-       train = rb_dataset.prepare_for_training(framework="spacy", lang=spacy.blank("en"))
+       train = rb_dataset.prepare_for_training(framework="spacy", lang=spacy.blank("en"), seed=42)
        assert isinstance(train, spacy.tokens.DocBin)
        assert len(train) == 100

-       train, test = rb_dataset.prepare_for_training(framework="spacy", lang=spacy.blank("en"), train_size=0.8)
+       train, test = rb_dataset.prepare_for_training(
+           framework="spacy", lang=spacy.blank("en"), train_size=0.8, seed=42
+       )
        assert isinstance(train, spacy.tokens.DocBin)
        assert isinstance(test, spacy.tokens.DocBin)
        assert len(train) == 80
@@ -601,11 +664,11 @@ def test_prepare_for_training_with_spark_nlp(self):
        for r in rb_dataset:
            r.annotation = [(label, start, end) for label, start, end, _ in r.prediction]

-       train = rb_dataset.prepare_for_training(framework="spark-nlp")
+       train = rb_dataset.prepare_for_training(framework="spark-nlp", seed=42)
        assert isinstance(train, pd.DataFrame)
        assert len(train) == 100

-       train, test = rb_dataset.prepare_for_training(framework="spark-nlp", train_size=0.8)
+       train, test = rb_dataset.prepare_for_training(framework="spark-nlp", train_size=0.8, seed=42)
        assert isinstance(train, pd.DataFrame)
        assert isinstance(test, pd.DataFrame)
        assert len(train) == 80
@@ -788,8 +851,10 @@ def test_to_from_pandas(self, text2text_records):
            assert rec == expected

    def test_prepare_for_training(self):
-       ds = rg.DatasetForText2Text([rg.Text2TextRecord(text="mock", annotation="mock")] * 10)
-       train = ds.prepare_for_training(train_size=1)
+       ds = rg.DatasetForText2Text(
+           [rg.Text2TextRecord(text="mock", annotation="mock"), rg.Text2TextRecord(text="mock")] * 10
+       )
+       train = ds.prepare_for_training(train_size=1, seed=42)

        assert isinstance(train, datasets.Dataset)
        assert train.column_names == ["text", "target"]
@@ -799,21 +864,25 @@ def test_prepare_for_training(self):
        assert train.features["text"] == datasets.Value("string")
        assert train.features["target"] == datasets.Value("string")

-       train_test = ds.prepare_for_training(train_size=0.5)
+       train_test = ds.prepare_for_training(train_size=0.5, seed=42)
        assert len(train_test["train"]) == 5
        assert len(train_test["test"]) == 5
        for split in ["train", "test"]:
            assert train_test[split].column_names == ["text", "target"]

-   def test_prepare_for_training_spacy(self):
-       ds = rg.DatasetForText2Text([rg.Text2TextRecord(text="mock", annotation="mock")] * 10)
+   def test_prepare_for_training_with_spacy(self):
+       ds = rg.DatasetForText2Text(
+           [rg.Text2TextRecord(text="mock", annotation="mock"), rg.Text2TextRecord(text="mock")] * 10
+       )
        with pytest.raises(NotImplementedError):
            ds.prepare_for_training("spacy", lang=spacy.blank("en"), train_size=1)

-   def test_prepare_for_training_spark_nlp(self):
-       ds = rg.DatasetForText2Text([rg.Text2TextRecord(text="mock", annotation="mock")] * 10)
-       with pytest.raises(NotImplementedError):
-           ds.prepare_for_training("spark-nlp", train_size=1)
+   def test_prepare_for_training_with_spark_nlp(self):
+       ds = rg.DatasetForText2Text(
+           [rg.Text2TextRecord(text="mock", annotation="mock"), rg.Text2TextRecord(text="mock")] * 10
+       )
+       df = ds.prepare_for_training("spark-nlp", train_size=1)
+       assert list(df.columns) == ["id", "text", "target"]

    @pytest.mark.skipif(
        _HF_HUB_ACCESS_TOKEN is None,