-
Notifications
You must be signed in to change notification settings - Fork 77
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Added new annotator solving the NLI problem #311
base: dev
Are you sure you want to change the base?
Changes from 20 commits
7122cbe
57ce7d6
7e4cc4b
5a765f5
17faff4
14d737a
64df82a
eed477b
e73f96c
a3d91a2
77becf2
47743ac
3e7aa63
4ba1726
9a7216a
9cef446
7c64fdd
75c80c5
b3adfad
9f41374
0a614b5
d57f7e3
a34ecdc
5ce22f4
03df165
4b7341b
f3ddf53
07376bc
f2f9e99
1754759
a05b8af
c2487a5
c2f0894
755c98e
e6bbda5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
FROM python:3.7.4 | ||
|
||
ARG DATA_URL=http://files.deeppavlov.ai/tmp/nocontext_tf_model.tar.gz | ||
ARG NEL_URL=http://files.deeppavlov.ai/tmp/model.h5 | ||
|
||
ENV CACHE_DIR /cache | ||
ENV TRAINED_MODEL_PATH /data/nli_model/model.h5 | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. плохо задавать переменную в докерфайле, лучше не задавать вообще - зачем переменная-то? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Переменную с cache можно убрать, а вот переменная с путём к модели нужна - вдруг захочется в другое место её загружать) |
||
ENV CONVERT_MODEL_PATH /data/convert_model | ||
|
||
WORKDIR /src | ||
RUN mkdir /cache | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. почему кэш, а не в дата? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. что в кэше лежит? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Кэш нужен при обучении: если обучать модель с 0, то в папку cache закачиваются датасеты и сохраняются checkpoints модели |
||
|
||
COPY requirements.txt . | ||
RUN pip install -r requirements.txt | ||
|
||
RUN mkdir -p /data/nli_model/ | ||
RUN curl -L $NEL_URL --output /data/nli_model/model.h5 | ||
|
||
RUN mkdir -p /data/convert_model/ | ||
RUN curl -L $DATA_URL --output /tmp/conv_model.tar.gz && tar -xf /tmp/conv_model.tar.gz -C /data/convert_model && rm /tmp/conv_model.tar.gz | ||
|
||
COPY . . |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
convert-based-nli: | ||
name: convert-based-nli | ||
display_name: ConveRT based NLI | ||
container_name: convert-based-nli | ||
component_type: null | ||
model_type: NN-based | ||
is_customizable: false | ||
author: DeepPavlov | ||
description: determines whether two sentences are related as entailment, neutral or contradiction | ||
ram_usage: 1.5G | ||
gpu_usage: null | ||
port: 8150 | ||
endpoints: | ||
- group: candidate_annotators | ||
endpoint: batch_model | ||
date_created: '2023-03-28T14:22:32' |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,199 @@ | ||
import os | ||
import logging | ||
import numpy as np | ||
import random | ||
|
||
from encoder import Encoder | ||
import tensorflow as tf | ||
import tensorflow_datasets as tfds | ||
|
||
|
||
# Pin every random number generator used by this module so that training runs are repeatable.
seed = 1
# NOTE(review): per the Python docs, hash randomization of the *current* interpreter is fixed at
# start-up, so assigning PYTHONHASHSEED here presumably only affects child processes - confirm
# whether it should also be set in the container environment.
os.environ['PYTHONHASHSEED'] = str(seed)
random.seed(seed)
np.random.seed(seed)
tf.random.set_seed(seed)

logging.basicConfig(format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO)
logger = logging.getLogger(__name__)

# TRAINED_MODEL_PATH: when set, the annotator uses this saved model instead of training one.
# CACHE_DIR: where the training path downloads the dataset and writes vectorized data, logs,
# checkpoints and the final model.
TRAINED_MODEL_PATH = os.environ.get("TRAINED_MODEL_PATH")
CACHE_DIR = os.environ.get("CACHE_DIR")
|
||
|
||
def data_generation(file_path):
    """Load one pre-vectorized batch from an ``.npz`` archive.

    The archive is expected to contain two positional arrays: ``arr_0``, whose
    first two entries are the premise and hypothesis encodings, and ``arr_1``,
    the labels (this is the layout produced by
    ``np.savez(path, [premise_encoded, hypothesis_encoded], label)``).

    Args:
        file_path: Path to the ``.npz`` file.

    Returns:
        A tuple ``(premise, hypothesis, label)`` where ``label`` is reshaped
        into a column of shape ``(len(label), 1)``.
    """
    # Open the archive once and close it deterministically; reopening it for
    # every field triples the disk reads and leaks file handles.
    with np.load(file_path) as archive:
        encoded_pair = archive['arr_0']
        label = archive['arr_1']

    premise, hypothesis = encoded_pair[0], encoded_pair[1]
    label = label.reshape((len(label), 1))
    return premise, hypothesis, label
|
||
|
||
class DataGenerator(tf.compat.v2.keras.utils.Sequence):
    """Keras ``Sequence`` that serves one pre-vectorized ``.npz`` file per step.

    Each item is ``([premise, hypothesis], label)`` as returned by
    ``data_generation`` for the file at the current (optionally shuffled) position.
    """

    def __init__(self, list_examples, shuffle=True):
        """Store the file list and build the initial visiting order.

        Args:
            list_examples: Paths of the ``.npz`` files, one per step.
            shuffle: Whether to reshuffle the visiting order at the end of every epoch.
        """
        self.list_examples = list_examples
        self.shuffle = shuffle
        self.indexes = None
        self.on_epoch_end()

    def __len__(self):
        """Return the number of steps per epoch (one per file)."""
        return len(self.list_examples)

    def __getitem__(self, index):
        """Load and return the batch stored in the file scheduled for step ``index``."""
        file_path = self.list_examples[self.indexes[index]]
        premise, hypothesis, label = data_generation(file_path)
        return [premise, hypothesis], label

    def on_epoch_end(self):
        """Rebuild the visiting order, shuffling it when ``shuffle`` is enabled."""
        order = np.arange(len(self.list_examples))
        if self.shuffle:
            np.random.shuffle(order)
        self.indexes = order
|
||
|
||
class ConveRTAnnotator:
    """NLI annotator built on top of sentence encodings from ``Encoder``.

    When ``TRAINED_MODEL_PATH`` is set, the saved Keras model at that path is used.
    Otherwise the SNLI dataset is downloaded, vectorized into ``CACHE_DIR`` and a
    small feed-forward classifier is trained and saved to ``CACHE_DIR``.
    """

    def __init__(self, batch_size=256):
        """Create the encoder and either point at a trained model or train one.

        Args:
            batch_size: Number of SNLI examples stored per vectorized file on the
                training path. The attribute was read but never defined before,
                so the training path failed with ``AttributeError``; the default
                here is a reviewer-chosen value, tune as needed.

        Raises:
            ValueError: If neither ``TRAINED_MODEL_PATH`` nor ``CACHE_DIR`` is set.
        """
        self.encoder = Encoder()
        self.model = None
        self.batch_size = batch_size

        if TRAINED_MODEL_PATH:
            self.model_path = TRAINED_MODEL_PATH
        else:
            if not CACHE_DIR:
                # Training needs a writable location; fail with a clear message
                # instead of a TypeError from ``None + str`` further down.
                raise ValueError("Either TRAINED_MODEL_PATH or CACHE_DIR must be set.")
            self.__prepare_data()
            self.__create_model()
            self.__train_model()

    def __prepare_data(self):
        """Download SNLI, vectorize every split into ``.npz`` files and build the generators."""
        snli_dataset = tfds.text.Snli()
        snli_dataset.download_and_prepare(download_dir=CACHE_DIR)
        datasets = snli_dataset.as_dataset()

        data_root = os.path.join(CACHE_DIR, 'data')
        # (tfds split name / directory name, file-name prefix), processed in the original order.
        splits = (('validation', 'val_'), ('test', 'test_'), ('train', 'train_'))
        generators = {}
        for split_name, file_prefix in splits:
            split_dir = os.path.join(data_root, split_name)
            os.makedirs(split_dir, exist_ok=True)

            batched = datasets[split_name].batch(self.batch_size).prefetch(tf.data.experimental.AUTOTUNE)
            self.__vectorize_data(os.path.join(split_dir, file_prefix), batched)

            # Join with the directory: plain string concatenation dropped the
            # path separator and produced paths that do not exist.
            example_files = sorted(os.path.join(split_dir, f_name) for f_name in os.listdir(split_dir))
            generators[split_name] = DataGenerator(example_files)

        self.train_generator = generators['train']
        self.test_generator = generators['test']
        self.val_generator = generators['validation']

        logger.info("All datasets are made.")

    def __vectorize_data(self, data_path, dataset):
        """Encode every batch of ``dataset`` and save it as ``<data_path><n>.npz``.

        Args:
            data_path: Path prefix (directory plus file-name prefix) for the output files.
            dataset: Batched dataset yielding dicts with ``premise``, ``hypothesis`` and ``label``.
        """
        split_label = os.path.basename(data_path).rstrip('_')
        logger.info("Started making %s dataset.", split_label)
        for counter, example in enumerate(tfds.as_numpy(dataset), start=1):
            premise, hypothesis, label = example['premise'], example['hypothesis'], example['label']

            # Examples labelled -1 are dropped before encoding.
            keep = label != -1
            premise, hypothesis, label = premise[keep], hypothesis[keep], label[keep]
            if label.size == 0:
                # Nothing left to encode in this batch.
                continue

            premise_encoded = self.encoder.encode_sentences(premise)
            hypothesis_encoded = self.encoder.encode_sentences(hypothesis)
            np.savez(data_path + str(counter), [premise_encoded, hypothesis_encoded], label)

            if counter % 10 == 0:
                logger.info("Prepared %d files.", counter)
        logger.info("Prepared all files.")

    def __create_model(self):
        """Build and compile the classifier: two 1024-d inputs -> dense stack -> 3-way softmax."""
        inp_p = tf.keras.layers.Input(shape=(1024,))
        inp_h = tf.keras.layers.Input(shape=(1024,))
        combined = tf.keras.layers.concatenate([inp_p, inp_h])
        linear_1 = tf.keras.layers.Dense(1024, activation='relu')(combined)
        dropout_1 = tf.keras.layers.Dropout(0.45)(linear_1)
        linear_2 = tf.keras.layers.Dense(512, activation='relu')(dropout_1)
        linear_3 = tf.keras.layers.Dense(256, activation='relu')(linear_2)
        output = tf.keras.layers.Dense(3, activation='softmax')(linear_3)

        self.model = tf.keras.models.Model(inputs=[inp_p, inp_h], outputs=output)
        self.model.compile(
            loss=tf.keras.losses.SparseCategoricalCrossentropy(),
            optimizer='adam',
            metrics=['accuracy'])

    def __train_model(self):
        """Train the classifier with checkpointing, CSV logging and early stopping, then save it."""
        log_dir = os.path.join(CACHE_DIR, 'logs')
        checkpoint_dir = os.path.join(CACHE_DIR, 'checkpoints')
        # The callbacks write into these directories, so make sure they exist.
        os.makedirs(log_dir, exist_ok=True)
        os.makedirs(checkpoint_dir, exist_ok=True)

        csv_logger = tf.keras.callbacks.CSVLogger(os.path.join(log_dir, 'log.csv'))
        model_checkpoint = tf.keras.callbacks.ModelCheckpoint(
            filepath=os.path.join(checkpoint_dir, 'cp-{epoch:04d}.ckpt'),
            save_weights_only=True)
        early_stopping = tf.keras.callbacks.EarlyStopping(
            monitor='val_loss',
            patience=10)

        self.model.fit(x=self.train_generator,
                       validation_data=self.val_generator,
                       use_multiprocessing=True,
                       workers=6, epochs=100,
                       callbacks=[model_checkpoint, csv_logger, early_stopping])

        self.model_path = os.path.join(CACHE_DIR, 'model.h5')
        self.model.save(self.model_path)

    def candidate_selection(self, candidates, bot_uttr_history, threshold=0.8):
        """Label every candidate against its own list of previous bot utterances.

        Args:
            candidates: Candidate sentences.
            bot_uttr_history: For each candidate (same order), the list of bot
                utterances to compare it with.
            threshold: Minimum contradiction probability for a pair to be
                labelled ``contradiction``.

        Returns:
            One dict per candidate with ``decision`` and the ``entailment``,
            ``neutral`` and ``contradiction`` scores. Candidates without history
            keep the default ``neutral`` result. Once a candidate is marked as a
            contradiction by any pair, that result is kept.
        """
        labels = {0: 'entailment', 1: 'neutral', 2: 'contradiction'}
        rez_list = [
            {'decision': labels[1], labels[0]: 0.0, labels[1]: 1.0, labels[2]: 0.0}
            for _ in candidates
        ]

        # Pair each candidate with its own history. zip() stops at the shorter
        # list, so mismatched input lengths can neither raise IndexError nor
        # produce misaligned model inputs.
        paired_history = [list(history) for _, history in zip(candidates, bot_uttr_history)]
        unique_history = list({u for history in paired_history for u in history})

        if unique_history and candidates:
            # NOTE(review): the model is reloaded on every call, as before. Loading it once
            # would be cheaper, but may not be safe if requests are served from a thread
            # other than the one that loaded the model - confirm before caching.
            self.model = tf.keras.models.load_model(self.model_path)

            vectorized_candidates = self.__response_encoding(candidates)
            vectorized_history = dict(zip(unique_history, self.__response_encoding(unique_history)))

            history_arr, candidates_arr, owners = [], [], []
            for i, history in enumerate(paired_history):
                for utterance in history:
                    history_arr.append(vectorized_history[utterance])
                    candidates_arr.append(vectorized_candidates[i])
                    owners.append(i)

            pred_rez = self.model.predict([np.asarray(history_arr), np.asarray(candidates_arr)])
            for i, row_probab in zip(owners, pred_rez):
                if rez_list[i]['decision'] == labels[2]:
                    # A contradiction found earlier for this candidate is final.
                    continue
                contradiction_prob = float(row_probab[2])
                scores = np.array(row_probab, dtype=float)
                if contradiction_prob < threshold:
                    # Negate the score so argmax cannot pick a below-threshold contradiction.
                    scores[2] = -contradiction_prob
                label = int(np.argmax(scores, axis=-1))
                rez_list[i] = {'decision': labels[label],
                               labels[0]: float(row_probab[0]),
                               labels[1]: float(row_probab[1]),
                               labels[2]: contradiction_prob}
        logger.info(rez_list)
        return rez_list

    def __response_encoding(self, responses):
        """Encode sentences with the shared encoder."""
        return self.encoder.encode_sentences(responses)
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
import os | ||
import numpy as np | ||
|
||
import tensorflow as tf | ||
import tensorflow_text | ||
import tensorflow_hub as tfhub | ||
|
||
|
||
# The encoder below is built on a tf.compat.v1 Session and placeholders, so eager execution is
# switched off for the whole process; TensorFlow logging is limited to errors.
tf.compat.v1.disable_eager_execution()
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

# Location of the ConveRT hub module; read from the environment, None when unset.
CONVERT_MODEL_PATH = os.environ.get("CONVERT_MODEL_PATH")
|
||
|
||
def normalize_vectors(vectors):
    """Scale every row to unit L2 norm.

    Args:
        vectors: Sequence of equally sized vectors (or a 2-D array); stacked
            row-wise with ``np.vstack``.

    Returns:
        A 2-D array whose rows have L2 norm 1. All-zero rows are returned
        unchanged instead of turning into NaN through a division by zero.
    """
    vectors = np.vstack(vectors)
    norm = np.linalg.norm(vectors, ord=2, axis=-1, keepdims=True)
    # Dividing a zero row by 1 keeps it at zero and avoids NaN / RuntimeWarning.
    norm[norm == 0] = 1
    return vectors / norm
|
||
|
||
class Encoder:
    """Sentence encoder backed by the hub module at ``CONVERT_MODEL_PATH``.

    A single session and a single string placeholder are shared by three output
    tensors: the module's default signature, ``encode_context`` and
    ``encode_response``. Every public method returns L2-normalized vectors.
    """

    def __init__(self):
        """Create the session, wire the hub module to the placeholder and initialize it."""
        self.sess = tf.compat.v1.Session()
        self.text_placeholder = tf.compat.v1.placeholder(dtype=tf.string, shape=[None])

        self.module = tfhub.Module(CONVERT_MODEL_PATH)
        self.context_encoding_tensor = self.module(self.text_placeholder, signature="encode_context")
        self.encoding_tensor = self.module(self.text_placeholder)
        self.response_encoding_tensor = self.module(self.text_placeholder, signature="encode_response")

        self.sess.run(tf.compat.v1.tables_initializer())
        self.sess.run(tf.compat.v1.global_variables_initializer())

    def _encode(self, output_tensor, sentences):
        """Feed ``sentences`` through ``output_tensor`` and normalize the result."""
        raw_vectors = self.sess.run(output_tensor, feed_dict={self.text_placeholder: sentences})
        return normalize_vectors(raw_vectors)

    def encode_sentences(self, sentences):
        """Encode with the module's default signature."""
        return self._encode(self.encoding_tensor, sentences)

    def encode_contexts(self, sentences):
        """Encode with the ``encode_context`` signature."""
        return self._encode(self.context_encoding_tensor, sentences)

    def encode_responses(self, sentences):
        """Encode with the ``encode_response`` signature."""
        return self._encode(self.response_encoding_tensor, sentences)
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
convert_based_nli: | ||
- group: candidate_annotators | ||
connector: | ||
protocol: http | ||
timeout: 2.0 | ||
url: http://convert-based-nli:8150/batch_model | ||
dialog_formatter: state_formatters.dp_formatters:convert_nli_hypotheses_annotator_formatter | ||
response_formatter: state_formatters.dp_formatters:simple_formatter_service | ||
state_manager_method: add_hypothesis_annotation_batch |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
tensorflow==2.8.0 | ||
tensorflow_hub==0.12.0 | ||
tensorflow_text==2.8.2 | ||
tensorflow-datasets==4.8.1 | ||
flask==1.1.1 | ||
itsdangerous==2.0.1 | ||
numpy==1.21.6 | ||
gunicorn==19.9.0 | ||
requests==2.22.0 | ||
sentry-sdk==0.12.3 | ||
jinja2<=3.0.3 | ||
Werkzeug<=2.0.3 | ||
protobuf==3.20.3 |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
import logging | ||
import time | ||
from os import getenv | ||
|
||
from convert_annotator import ConveRTAnnotator | ||
import sentry_sdk | ||
from flask import Flask, jsonify, request | ||
|
||
|
||
# Error reporting; the DSN is taken from the environment.
sentry_sdk.init(getenv("SENTRY_DSN"))

logging.basicConfig(format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO)
logger = logging.getLogger(__name__)

app = Flask(__name__)
# Do not sort JSON keys in responses.
app.config.update(JSON_SORT_KEYS=False)

# The annotator is constructed once, at import time, and shared by all requests.
annotator = ConveRTAnnotator()
logger.info("Annotator is loaded.")
|
||
|
||
@app.route("/batch_model", methods=["POST"])
def respond_batch():
    """Annotate a batch of candidate sentences against the last bot utterances.

    Reads ``sentences`` and ``last_bot_utterances`` from the JSON body (both
    default to empty lists) and returns ``[{"batch": <annotations>}]``.
    """
    started_at = time.time()

    payload = request.json
    sentences = payload.get("sentences", [])
    last_bot_utterances = payload.get("last_bot_utterances", [])
    logger.debug(f"Sentences: {sentences}")
    logger.debug(f"Last bot utterances: {last_bot_utterances}")

    result = annotator.candidate_selection(sentences, last_bot_utterances)

    elapsed = time.time() - started_at
    logger.info(f"convert-based-nli exec time: {round(elapsed, 2)} sec")
    return jsonify([{"batch": result}])
|
||
|
||
if __name__ == "__main__": | ||
app.run(debug=False, host="0.0.0.0", port=8150) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
import requests | ||
|
||
|
||
def main():
    """Smoke-test the running service: post a fixed batch and compare the predicted decisions.

    Raises:
        requests.HTTPError: If the service answers with an error status.
        AssertionError: If the predicted decisions differ from the expected ones.
    """
    url = "http://0.0.0.0:8150/batch_model"

    input_data = {"sentences": ["Do you like ice cream?", "It's going to be sunny today",
                                "I love dogs", "Do you want to know some interesting fact?",
                                "Wolves have small teeth"],
                  "last_bot_utterances": [["I hate dogs", "The moon is a satellite of the earth"],
                                          [],
                                          ["I hate dogs", "Wolves have big teeth",
                                           "The moon is a satellite of the earth"],
                                          ["The moon is a satellite of the earth"],
                                          ["Wolves have big teeth", "The moon is a satellite of the earth"]]}
    desired_labels = ['neutral', 'neutral', 'contradiction', 'neutral', 'contradiction']

    # A timeout keeps the test from hanging forever when the service is down; the value is
    # generous because the service loads the model while handling the request.
    response = requests.post(url, json=input_data, timeout=60)
    # Surface HTTP failures directly instead of a confusing JSON/KeyError below.
    response.raise_for_status()
    result = response.json()
    labels = [r['decision'] for r in result[0]['batch']]

    # NOTE(review): in the review discussion of this check a reviewer guessed that some random
    # seed is not fixed where it should be, and asked not to remove the numeric scores from the
    # comparison but to round them to two decimal places with round() instead.
    assert labels == desired_labels, f"Expected {desired_labels}, got {labels}"
    print("Successfully predicted contradiction!")
|
||
|
||
if __name__ == "__main__": | ||
main() |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
#!/bin/bash | ||
|
||
python test.py |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
почему h5 не архив? может заархивировать? сколько весит файл?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
давай пути к моделям принимать в виде параметров без дефолтных значений. значения задавать в докер компоуз
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Модель занимает немного места - всего 10,5 Мб весит (которая model.h5)