Added new annotator solving the NLI problem #311

Open · wants to merge 35 commits into base: dev

Commits (35)
7122cbe
Added model that solves NLI between bot replicas
Kolpnick Jan 19, 2023
57ce7d6
Merge branch 'deeppavlov:dev' into dev
Kolpnick Jan 27, 2023
7e4cc4b
Added tests to ConveRTBasedNLI
Kolpnick Jan 27, 2023
5a765f5
Added ConveRTBasedNLI annotator
Kolpnick Jan 27, 2023
17faff4
Changed nli annotator output format
Kolpnick Jan 31, 2023
14d737a
Changed data copying and added model downloading
Kolpnick Feb 7, 2023
64df82a
Changed protobuf requirements
Kolpnick Feb 7, 2023
eed477b
Made lengths in batch the same
Kolpnick Feb 7, 2023
e73f96c
Changed logs outputs
Kolpnick Feb 7, 2023
a3d91a2
Added NLI contradiction selection into convers_evaluation_selector
Kolpnick Feb 8, 2023
77becf2
Fixed history batch preprocessing and updated test file
Kolpnick Feb 11, 2023
47743ac
Fixed code readability
Kolpnick Mar 28, 2023
3e7aa63
Added component cards
Kolpnick Mar 28, 2023
4ba1726
Fixed assert in a test
Kolpnick Mar 28, 2023
9a7216a
Changed debug message output
Kolpnick Mar 28, 2023
9cef446
Increased model memory limit
Kolpnick Mar 28, 2023
7c64fdd
Updated models paths
Kolpnick Mar 28, 2023
75c80c5
Updated model port
Kolpnick Mar 28, 2023
b3adfad
Merge branch 'dev' into convert_based_nli
Kolpnick Mar 28, 2023
9f41374
Deleted model file from git
Kolpnick Mar 28, 2023
0a614b5
Changed test
Kolpnick Mar 29, 2023
d57f7e3
Fixed debug message output
Kolpnick Mar 29, 2023
a34ecdc
Initialization of variables in Dockerfile changed
Kolpnick Apr 6, 2023
5ce22f4
Fixed issues with custom models training
Kolpnick Apr 6, 2023
03df165
Updated model path
Kolpnick Apr 6, 2023
4b7341b
Added README
Kolpnick Apr 6, 2023
f3ddf53
Merge branch 'dev' into convert_based_nli
dilyararimovna May 29, 2023
07376bc
fix: port
dilyararimovna May 29, 2023
f2f9e99
Updated Dockerfile
Kolpnick May 30, 2023
1754759
Separated toxic and contradiction check
Kolpnick May 30, 2023
a05b8af
Fixed black and flake8 codestyles errors
Kolpnick May 31, 2023
c2487a5
Added component card
Kolpnick Jun 5, 2023
c2f0894
Separated logger out for toxic and contradiction
Kolpnick Jun 5, 2023
755c98e
Refactored dockerfile
Kolpnick Jun 5, 2023
e6bbda5
Updated environment.yml
Kolpnick Jul 4, 2023
25 changes: 25 additions & 0 deletions annotators/ConveRTBasedNLI/Dockerfile
@@ -0,0 +1,25 @@
FROM python:3.9.16-slim

ARG CONVERT_URL=http://files.deeppavlov.ai/tmp/convert_model.tar.gz
ARG NLI_URL=http://files.deeppavlov.ai/tmp/nli_model.tar.gz
ARG TRAINED_MODEL_PATH
ARG SERVICE_PORT
# WORK_DIR must be declared as a build arg so the COPY instructions below can use it
ARG WORK_DIR

ENV TRAINED_MODEL_PATH ${TRAINED_MODEL_PATH}
ENV SERVICE_PORT ${SERVICE_PORT}

RUN apt-get update && \
apt-get install -y --allow-unauthenticated wget && \
rm -rf /var/lib/apt/lists/*

COPY ${WORK_DIR}/requirements.txt /src/requirements.txt
RUN pip install -r /src/requirements.txt
COPY ${WORK_DIR} /src
WORKDIR /src

RUN mkdir /cache /data /data/nli_model/ /data/convert_model/
RUN wget -c -q $NLI_URL -P /tmp/ && \
tar -xf /tmp/nli_model.tar.gz -C /data/nli_model/ && \
wget -c -q $CONVERT_URL -P /tmp/ && \
tar -xf /tmp/convert_model.tar.gz -C /data/convert_model/ && \
rm -rf /tmp/
11 changes: 11 additions & 0 deletions annotators/ConveRTBasedNLI/README.md
@@ -0,0 +1,11 @@
This model is designed to solve the Natural Language Inference (NLI) problem.

It consists of two parts:
* a [ConveRT model](https://arxiv.org/abs/1911.03688) that vectorizes the input sentences
* a custom classifier consisting of 4 linear layers

The model was trained on the **Stanford Natural Language Inference** (SNLI) corpus, which contains human-written English sentence pairs labeled as entailment, contradiction, or neutral.

A pre-trained model is available [here](http://files.deeppavlov.ai/tmp/nli_model.tar.gz).

To train the model from scratch, omit the TRAINED_MODEL_PATH build argument or set it to _None_. A sketch of the corresponding `docker build` invocation is given below.
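As an illustration, the image could be built with the pre-trained model roughly as follows. This is a minimal sketch: the WORK_DIR value and the image tag are assumptions made for this example, while SERVICE_PORT and TRAINED_MODEL_PATH match the component card in this PR.

    docker build \
      --build-arg WORK_DIR=annotators/ConveRTBasedNLI \
      --build-arg SERVICE_PORT=8150 \
      --build-arg TRAINED_MODEL_PATH=/data/nli_model \
      -f annotators/ConveRTBasedNLI/Dockerfile \
      -t convert-based-nli .

Omitting `--build-arg TRAINED_MODEL_PATH` makes the annotator train the classifier on SNLI at startup instead of loading the pre-trained weights.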
243 changes: 243 additions & 0 deletions annotators/ConveRTBasedNLI/convert_annotator.py
@@ -0,0 +1,243 @@
import os
import logging
import numpy as np
import random

from encoder import Encoder
import tensorflow as tf
import tensorflow_datasets as tfds


seed = 1
os.environ["PYTHONHASHSEED"] = str(seed)
random.seed(seed)
tf.random.set_seed(seed)
np.random.seed(seed)

logging.basicConfig(
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
)
logger = logging.getLogger(__name__)

TRAINED_MODEL_PATH = os.environ.get("TRAINED_MODEL_PATH", None)


def data_generation(file_path):
    # Load the .npz file once: arr_0 holds the encoded premise/hypothesis pair,
    # arr_1 holds the labels.
    data = np.load(file_path)
    premise = data["arr_0"][0]
    hypothesis = data["arr_0"][1]
    label = data["arr_1"]
    label = label.reshape((len(label), 1))
    return premise, hypothesis, label


class DataGenerator(tf.compat.v2.keras.utils.Sequence):
def __init__(self, list_examples, shuffle=False):
self.list_examples = list_examples
self.shuffle = shuffle
self.indexes = None
self.on_epoch_end()

def __len__(self):
return len(self.list_examples)

def __getitem__(self, index):
pos = self.indexes[index]
premise, hypothesis, label = data_generation(self.list_examples[pos])

return [premise, hypothesis], label

def on_epoch_end(self):
self.indexes = np.arange(len(self.list_examples))
if self.shuffle:
np.random.shuffle(self.indexes)


class ConveRTAnnotator:
def __init__(self):
self.encoder = Encoder()
self.model = None

if TRAINED_MODEL_PATH:
self.model_path = TRAINED_MODEL_PATH + "/model.h5"
else:
self.batch_size = 1024
self.__prepare_data()
self.__create_model()
self.__train_model()

def __prepare_data(self):
logger.info("The download of SNLI dataset has begun.")
snli_dataset = tfds.text.Snli()
snli_dataset.download_and_prepare(download_dir="/cache")

datasets = snli_dataset.as_dataset()
train_dataset, test_dataset, val_dataset = (
datasets["train"],
datasets["test"],
datasets["validation"],
)
val_dataset = val_dataset.batch(self.batch_size).prefetch(
tf.data.experimental.AUTOTUNE
)
test_dataset = test_dataset.batch(self.batch_size).prefetch(
tf.data.experimental.AUTOTUNE
)
train_dataset = train_dataset.batch(self.batch_size).prefetch(
tf.data.experimental.AUTOTUNE
)

logger.info("Dataset downloaded.")

common_path = "/cache/data"
val_path = common_path + "/validation/"
test_path = common_path + "/test/"
train_path = common_path + "/train/"
if not os.path.exists(val_path):
os.makedirs(val_path)
if not os.path.exists(test_path):
os.makedirs(test_path)
if not os.path.exists(train_path):
os.makedirs(train_path)

logger.info("Started making validation dataset.")
self.__vectorize_data(val_path + "val_", val_dataset)
logger.info("Started making test dataset.")
self.__vectorize_data(test_path + "test_", test_dataset)
logger.info("Started making train dataset.")
self.__vectorize_data(train_path + "train_", train_dataset)

train_examples = os.listdir(train_path)
train_examples = [train_path + f_name for f_name in train_examples]
test_examples = os.listdir(test_path)
test_examples = [test_path + f_name for f_name in test_examples]
val_examples = os.listdir(val_path)
val_examples = [val_path + f_name for f_name in val_examples]

self.train_generator = DataGenerator(train_examples)
self.test_generator = DataGenerator(test_examples)
self.val_generator = DataGenerator(val_examples)

logger.info("All datasets have been created.")

def __vectorize_data(self, data_path, dataset):
counter = 0
for example in tfds.as_numpy(dataset):
counter += 1
premise, hypothesis, label = (
example["premise"],
example["hypothesis"],
example["label"],
)

useless_pos = np.where(label == -1)[0]
premise = np.delete(premise, useless_pos)
hypothesis = np.delete(hypothesis, useless_pos)
label = np.delete(label, useless_pos)

premise_encoded = self.encoder.encode_sentences(premise)
hypothesis_encoded = self.encoder.encode_sentences(hypothesis)
np.savez(
data_path + str(counter), [premise_encoded, hypothesis_encoded], label
)

if counter % 10 == 0:
logger.info(f"Prepared {counter} files.")
logger.info("Prepared all files.")

    def __create_model(self):
        # Classifier head: the encoded premise and hypothesis vectors are
        # concatenated and passed through four dense layers ending in a 3-way
        # softmax (entailment / neutral / contradiction). self.batch_size (1024)
        # is reused here as the dimensionality of the ConveRT sentence vectors.
        inp_p = tf.keras.layers.Input(shape=self.batch_size)
        inp_h = tf.keras.layers.Input(shape=self.batch_size)
        combined = tf.keras.layers.concatenate([inp_p, inp_h])
        linear_1 = tf.keras.layers.Dense(1024, activation="relu")(combined)
        dropout_1 = tf.keras.layers.Dropout(0.45)(linear_1)
        linear_2 = tf.keras.layers.Dense(512, activation="relu")(dropout_1)
        linear_3 = tf.keras.layers.Dense(256, activation="relu")(linear_2)
        output = tf.keras.layers.Dense(3, activation="softmax")(linear_3)

self.model = tf.keras.models.Model(inputs=[inp_p, inp_h], outputs=output)
self.model.compile(
loss=tf.keras.losses.SparseCategoricalCrossentropy(),
optimizer="adam",
metrics=["accuracy"],
)

def __train_model(self):
log_dir = "/cache/logs/"
if not os.path.exists(log_dir):
os.makedirs(log_dir)
csv_logger = tf.keras.callbacks.CSVLogger(log_dir + "log.csv")

ch_path = "/cache/checkpoints"
if not os.path.exists(ch_path):
os.makedirs(ch_path)
ch_path += "/cp-{epoch:04d}.ckpt"
model_checkpoint = tf.keras.callbacks.ModelCheckpoint(
filepath=ch_path, save_weights_only=True
)

early_stopping = tf.keras.callbacks.EarlyStopping(
monitor="val_loss", patience=10
)

_ = self.model.fit(
x=self.train_generator,
validation_data=self.val_generator,
use_multiprocessing=True,
workers=6,
epochs=100,
callbacks=[model_checkpoint, csv_logger, early_stopping],
)

self.model_path = "/cache/model.h5"
self.model.save(self.model_path)
os.environ["TRAINED_MODEL_PATH"] = self.model_path
logger.info("Model is trained.")

    def candidate_selection(self, candidates, bot_uttr_history, threshold=0.8):
        # The trained model is reloaded from disk on every call.
        self.model = tf.keras.models.load_model(self.model_path)
        labels = {0: "entailment", 1: "neutral", 2: "contradiction"}
        # Default result for every candidate: neutral with probability 1.0.
        base_dict = {
            "decision": labels[1],
            labels[0]: 0.0,
            labels[1]: 1.0,
            labels[2]: 0.0,
        }

        rez_list = [base_dict.copy() for _ in range(len(candidates))]
        # Encode each distinct bot utterance from the history only once.
        unique_history = {u for b in bot_uttr_history for u in b}

if unique_history and candidates:
vectorized_candidates = self.__response_encoding(candidates)
vectorized_history = self.__response_encoding(list(unique_history))

vectorized_history = dict(zip(unique_history, vectorized_history))
history_arr = [
vectorized_history.get(u) for b in bot_uttr_history for u in b
]
candidates_arr = []
for i in range(len(candidates)):
candidates_arr.extend(
[vectorized_candidates[i]] * len(bot_uttr_history[i])
)

            pred_rez = self.model.predict([history_arr, candidates_arr])
            pred_rez_idx = 0
            for i in range(len(candidates)):
                for _ in range(len(bot_uttr_history[i])):
                    row_probab = pred_rez[pred_rez_idx]
                    # If the contradiction probability is below the threshold,
                    # temporarily negate it so argmax cannot select it.
                    if row_probab[2] < threshold:
                        row_probab[2] = -row_probab[2]
                    label = int(np.argmax(row_probab, axis=-1))
                    # Once a candidate has been marked as a contradiction,
                    # keep that decision for the remaining history utterances.
                    if rez_list[i]["decision"] != labels[2]:
                        rez_list[i] = {
                            "decision": labels[label],
                            labels[0]: row_probab[0].astype(float),
                            labels[1]: row_probab[1].astype(float),
                            labels[2]: np.abs(row_probab[2]).astype(float),
                        }
                    pred_rez_idx += 1
logger.info(rez_list)
return rez_list

def __response_encoding(self, responses):
return self.encoder.encode_sentences(responses)
51 changes: 51 additions & 0 deletions annotators/ConveRTBasedNLI/encoder.py
@@ -0,0 +1,51 @@
import numpy as np

import tensorflow as tf
import tensorflow_text  # noqa: F401  # registers the custom TF ops required by the ConveRT TF-Hub module
import tensorflow_hub as tfhub


tf.compat.v1.disable_eager_execution()
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)


def normalize_vectors(vectors):
vectors = np.vstack(vectors)
norm = np.linalg.norm(vectors, ord=2, axis=-1, keepdims=True)
return vectors / norm


class Encoder:
def __init__(self):
self.sess = tf.compat.v1.Session()
self.text_placeholder = tf.compat.v1.placeholder(dtype=tf.string, shape=[None])

self.module = tfhub.Module("/data/convert_model")
self.context_encoding_tensor = self.module(
self.text_placeholder, signature="encode_context"
)
self.encoding_tensor = self.module(self.text_placeholder)
self.response_encoding_tensor = self.module(
self.text_placeholder, signature="encode_response"
)

self.sess.run(tf.compat.v1.tables_initializer())
self.sess.run(tf.compat.v1.global_variables_initializer())

def encode_sentences(self, sentences):
vectors = self.sess.run(
self.encoding_tensor, feed_dict={self.text_placeholder: sentences}
)
return normalize_vectors(vectors)

def encode_contexts(self, sentences):
vectors = self.sess.run(
self.context_encoding_tensor, feed_dict={self.text_placeholder: sentences}
)
return normalize_vectors(vectors)

def encode_responses(self, sentences):
vectors = self.sess.run(
self.response_encoding_tensor, feed_dict={self.text_placeholder: sentences}
)
return normalize_vectors(vectors)
13 changes: 13 additions & 0 deletions annotators/ConveRTBasedNLI/requirements.txt
@@ -0,0 +1,13 @@
tensorflow==2.8.0
tensorflow_hub==0.12.0
tensorflow_text==2.8.2
tensorflow-datasets==4.8.1
flask==1.1.1
itsdangerous==2.0.1
numpy==1.21.6
gunicorn==19.9.0
requests==2.22.0
sentry-sdk==0.12.3
jinja2<=3.0.3
Werkzeug<=2.0.3
protobuf==3.20.3
37 changes: 37 additions & 0 deletions annotators/ConveRTBasedNLI/server.py
@@ -0,0 +1,37 @@
import logging
import time
from os import getenv

from convert_annotator import ConveRTAnnotator
import sentry_sdk
from flask import Flask, jsonify, request


sentry_sdk.init(getenv("SENTRY_DSN"))

logging.basicConfig(
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
)
logger = logging.getLogger(__name__)
app = Flask(__name__)
app.config["JSON_SORT_KEYS"] = False

annotator = ConveRTAnnotator()
logger.info("Annotator is loaded.")


@app.route("/batch_model", methods=["POST"])
def respond_batch():
start_time = time.time()
sentences = request.json.get("sentences", [])
last_bot_utterances = request.json.get("last_bot_utterances", [])
logger.debug(f"Sentences: {sentences}")
logger.debug(f"Last bot utterances: {last_bot_utterances}")
result = annotator.candidate_selection(sentences, last_bot_utterances)
total_time = time.time() - start_time
logger.info(f"convert-based-nli exec time: {round(total_time, 2)} sec")
return jsonify([{"batch": result}])


if __name__ == "__main__":
app.run(debug=False, host="0.0.0.0", port=8150)
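For reference, a request to this endpoint might look as follows. This is a hypothetical example: the host and sample texts are made up, while the port, route, and field names come from the code and component card in this PR.

    curl -X POST http://localhost:8150/batch_model \
      -H "Content-Type: application/json" \
      -d '{"sentences": ["I really like cats.", "I cannot stand cats."],
           "last_bot_utterances": [["I love cats."], ["I love cats."]]}'

The response is a JSON list with a single object whose "batch" field holds, for each candidate sentence, the decision (entailment / neutral / contradiction) and the three class probabilities.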
@@ -0,0 +1,4 @@
SERVICE_PORT: 8150
TRAINED_MODEL_PATH: /data/nli_model
SERVICE_NAME: convert_based_nli
FLASK_APP: server