diff --git a/neural_seq_qa/.gitignore b/neural_seq_qa/.gitignore new file mode 100644 index 0000000000..3cb1a39c52 --- /dev/null +++ b/neural_seq_qa/.gitignore @@ -0,0 +1,33 @@ +*.DS_Store +build/ +build_doc/ +*.user + +*.swp +.vscode +.idea +.project +.cproject +.pydevproject +.settings/ +Makefile +.test_env/ +third_party/ + +*~ +bazel-* +third_party/ + +# clion workspace. +cmake-build-* + +data/data +data/embedding +data/evaluation +data/LICENSE +data/Readme.md +tmp +eval.*.txt +models* +*.log +run.sh diff --git a/neural_seq_qa/README.md b/neural_seq_qa/README.md new file mode 100644 index 0000000000..52c91a118d --- /dev/null +++ b/neural_seq_qa/README.md @@ -0,0 +1,81 @@ +# Neural Recurrent Sequence Labeling Model for Open-Domain Factoid Question Answering + +This model implements the work in the following paper: + +Peng Li, Wei Li, Zhengyan He, Xuguang Wang, Ying Cao, Jie Zhou, and Wei Xu. Dataset and Neural Recurrent Sequence Labeling Model for Open-Domain Factoid Question Answering. [arXiv:1607.06275](https://arxiv.org/abs/1607.06275). + +If you use the dataset/code in your research, please cite the above paper: + +```text +@article{li:2016:arxiv, + author = {Li, Peng and Li, Wei and He, Zhengyan and Wang, Xuguang and Cao, Ying and Zhou, Jie and Xu, Wei}, + title = {Dataset and Neural Recurrent Sequence Labeling Model for Open-Domain Factoid Question Answering}, + journal = {arXiv:1607.06275v2}, + year = {2016}, + url = {https://arxiv.org/abs/1607.06275v2}, +} +``` + + +# Installation + +1. Install PaddlePaddle v0.10.5 by the following command. Note that v0.10.0 is not supported. + ```bash + # either one is OK + # CPU + pip install paddlepaddle + # GPU + pip install paddlepaddle-gpu + ``` +2. Download the [WebQA](http://idl.baidu.com/WebQA.html) dataset by running + ```bash + cd data && ./download.sh && cd .. + ``` + +# Hyperparameters + +All the hyperparameters are defined in `config.py`. The default values are aligned with the paper.
+ +# Training + +Training can be launched using the following command: + +```bash +PYTHONPATH=data/evaluation:$PYTHONPATH python train.py 2>&1 | tee train.log +``` +# Validation and Test + +WebQA provides two versions of validation and test sets. Automatic validation and test can be launched by + +```bash +PYTHONPATH=data/evaluation:$PYTHONPATH python val_and_test.py models [ann|ir] +``` + +where + +* `models`: the directory where model files are stored. You can use `models` if `config.py` is not changed. +* `ann`: using the validation and test sets with annotated evidence. +* `ir`: using the validation and test sets with retrieved evidence. + +Note that validation and test can run simultaneously with training. `val_and_test.py` will handle the synchronization related problems. + +Intermediate results are stored in the directory `tmp`. You can delete them safely after validation and test. + +The results should be comparable with those shown in Table 3 in the paper. + +# Inferring using a Trained Model + +Infer using a trained model by running: +```bash +PYTHONPATH=data/evaluation:$PYTHONPATH python infer.py \ + MODEL_FILE \ + INPUT_DATA \ + OUTPUT_FILE \ + 2>&1 | tee infer.log +``` + +where + +* `MODEL_FILE`: a trained model produced by `train.py`. +* `INPUT_DATA`: input data in the same format as the validation/test sets of the WebQA dataset. +* `OUTPUT_FILE`: results in the format specified in the WebQA dataset for the evaluation scripts.
diff --git a/neural_seq_qa/config.py b/neural_seq_qa/config.py new file mode 100644 index 0000000000..b6c457a499 --- /dev/null +++ b/neural_seq_qa/config.py @@ -0,0 +1,112 @@ +import math + +__all__ = ["TrainingConfig", "InferConfig"] + + +class CommonConfig(object): + def __init__(self): + # network size: + # dimension of the question LSTM + self.q_lstm_dim = 64 + # dimension of the attention layer + self.latent_chain_dim = 64 + # dimension of the evidence LSTMs + self.e_lstm_dim = 64 + # dimension of the qe.comm and ee.comm feature embeddings + self.com_vec_dim = 2 + self.drop_rate = 0.05 + + # CRF: + # valid values are BIO and BIO2 + self.label_schema = "BIO2" + + # word embedding: + # vocabulary file path + self.word_dict_path = "data/embedding/wordvecs.vcb" + # word embedding file path + self.wordvecs_path = "data/embedding/wordvecs.txt" + self.word_vec_dim = 64 + + # saving model & logs: + # dir for saving models + self.model_save_dir = "models" + + # print training info every log_period batches + self.log_period = 100 + # show parameter status every show_parameter_status_period batches + self.show_parameter_status_period = 100 + + @property + def label_num(self): + if self.label_schema == "BIO": + return 3 + elif self.label_schema == "BIO2": + return 4 + else: + raise ValueError("wrong value for label_schema") + + @property + def default_init_std(self): + return 1 / math.sqrt(self.e_lstm_dim * 4) + + @property + def default_l2_rate(self): + return 8e-4 * self.batch_size / 6 + + @property + def dict_dim(self): + return len(self.vocab) + + +class TrainingConfig(CommonConfig): + def __init__(self): + super(TrainingConfig, self).__init__() + + # data: + # training data path + self.train_data_path = "data/data/training.json.gz" + + # number of batches used in each pass + self.batches_per_pass = 1000 + # number of passes to train + self.num_passes = 25 + # batch size + self.batch_size = 120 + + # the ratio of negative samples used in training + 
self.negative_sample_ratio = 0.2 + # the ratio of negative samples that contain golden answer string + self.hit_ans_negative_sample_ratio = 0.25 + + # keep only first B in golden labels + self.keep_first_b = False + + # use GPU to train the model + self.use_gpu = False + # number of threads + self.trainer_count = 1 + + # random seeds: + # data reader random seed, 0 for random seed + self.seed = 0 + # paddle random seed, 0 for random seed + self.paddle_seed = 0 + + # optimizer: + self.learning_rate = 1e-3 + # rmsprop + self.rho = 0.95 + self.epsilon = 1e-4 + # model average + self.average_window = 0.5 + self.max_average_window = 10000 + + +class InferConfig(CommonConfig): + def __init__(self): + super(InferConfig, self).__init__() + + self.use_gpu = False + self.trainer_count = 1 + self.batch_size = 120 + self.wordvecs = None diff --git a/neural_seq_qa/data/download.sh b/neural_seq_qa/data/download.sh new file mode 100755 index 0000000000..1cae249385 --- /dev/null +++ b/neural_seq_qa/data/download.sh @@ -0,0 +1,18 @@ +#!/bin/bash +if [[ -d data ]] && [[ -d embedding ]] && [[ -d evaluation ]]; then + echo "data exist" + exit 0 +else + wget -c http://paddlepaddle.bj.bcebos.com/dataset/webqa/WebQA.v1.0.zip +fi + +if [[ `md5sum -c md5sum.txt` =~ 'OK' ]] ; then + unzip WebQA.v1.0.zip + mv WebQA.v1.0/* . + rmdir WebQA.v1.0 + rm WebQA.v1.0.zip +else + echo "download data error!" >> /dev/stderr + exit 1 +fi + diff --git a/neural_seq_qa/data/md5sum.txt b/neural_seq_qa/data/md5sum.txt new file mode 100644 index 0000000000..14b74c3ad0 --- /dev/null +++ b/neural_seq_qa/data/md5sum.txt @@ -0,0 +1 @@ +b129df2a4eb547d8b398721dd7ed6cc6 WebQA.v1.0.zip diff --git a/neural_seq_qa/index.html b/neural_seq_qa/index.html new file mode 100644 index 0000000000..a18ebe864b --- /dev/null +++ b/neural_seq_qa/index.html @@ -0,0 +1,145 @@ + + + + + + + + + + + + + + + + + +
+
+ + + + + + + diff --git a/neural_seq_qa/infer.py b/neural_seq_qa/infer.py new file mode 100644 index 0000000000..14bda05a54 --- /dev/null +++ b/neural_seq_qa/infer.py @@ -0,0 +1,82 @@ +import os +import sys +import argparse + +import paddle.v2 as paddle + +import reader +import utils +import network +import config + +from utils import logger + + +class Infer(object): + def __init__(self, conf): + self.conf = conf + + self.settings = reader.Settings( + vocab=conf.vocab, is_training=False, label_schema=conf.label_schema) + + # init paddle + # TODO(lipeng17) v2 API does not support parallel_nn yet. Therefore, we + # can only use CPU currently + paddle.init(use_gpu=conf.use_gpu, trainer_count=conf.trainer_count) + + # define network + self.tags_layer = network.inference_net(conf) + + def infer(self, model_path, data_path, output): + test_reader = paddle.batch( + paddle.reader.buffered( + reader.create_reader(data_path, self.settings), + size=self.conf.batch_size * 1000), + batch_size=self.conf.batch_size) + + # load the trained models + parameters = paddle.parameters.Parameters.from_tar( + utils.open_file(model_path, "r")) + inferer = paddle.inference.Inference( + output_layer=self.tags_layer, parameters=parameters) + + def count_evi_ids(test_batch): + num = 0 + for sample in test_batch: + num += len(sample[reader.E_IDS]) + return num + + for test_batch in test_reader(): + tags = inferer.infer( + input=test_batch, field=["id"], feeding=network.feeding) + evi_ids_num = count_evi_ids(test_batch) + assert len(tags) == evi_ids_num + print >> output, ";\n".join(str(tag) for tag in tags) + ";" + + +def parse_cmd(): + parser = argparse.ArgumentParser() + parser.add_argument("model_path") + parser.add_argument("data_path") + parser.add_argument("output", help="'-' for stdout") + return parser.parse_args() + + +def main(args): + conf = config.InferConfig() + conf.vocab = utils.load_dict(conf.word_dict_path) + logger.info("length of word dictionary is : %d." 
% len(conf.vocab)) + + if args.output == "-": + output = sys.stdout + else: + output = utils.open_file(args.output, "w") + + infer = Infer(conf) + infer.infer(args.model_path, args.data_path, output) + + output.close() + + +if __name__ == "__main__": + main(parse_cmd()) diff --git a/neural_seq_qa/network.py b/neural_seq_qa/network.py new file mode 100644 index 0000000000..0fb19022b7 --- /dev/null +++ b/neural_seq_qa/network.py @@ -0,0 +1,314 @@ +import math +import paddle.v2 as paddle + +import reader + +__all__ = ["training_net", "inference_net", "feeding"] + +feeding = { + reader.Q_IDS_STR: reader.Q_IDS, + reader.E_IDS_STR: reader.E_IDS, + reader.QE_COMM_STR: reader.QE_COMM, + reader.EE_COMM_STR: reader.EE_COMM, + reader.LABELS_STR: reader.LABELS +} + + +def get_embedding(input, word_vec_dim, wordvecs): + """ + Define word embedding + + :param input: layer input + :type input: LayerOutput + :param word_vec_dim: dimension of the word embeddings + :type word_vec_dim: int + :param wordvecs: word embedding matrix + :type wordvecs: numpy array + :return: embedding + :rtype: LayerOutput + """ + return paddle.layer.embedding( + input=input, + size=word_vec_dim, + param_attr=paddle.attr.ParamAttr( + name="wordvecs", is_static=True, initializer=lambda _: wordvecs)) + + +def encoding_question(question, q_lstm_dim, latent_chain_dim, word_vec_dim, + drop_rate, wordvecs, default_init_std, default_l2_rate): + """ + Define network for encoding question + + :param question: question token ids + :type question: LayerOutput + :param q_lstm_dim: dimension of the question LSTM + :type q_lstm_dim: int + :param latent_chain_dim: dimension of the attention layer + :type latent_chain_dim: int + :param word_vec_dim: dimension of the word embeddings + :type word_vec_dim: int + :param drop_rate: dropout rate + :type drop_rate: float + :param wordvecs: word embedding matrix + :type wordvecs: numpy array + :param default_init_std: default initial standard deviation + :type default_init_std: 
float + :param default_l2_rate: default l2 rate + :type default_l2_rate: float + :return: question encoding + :rtype: LayerOutput + """ + # word embedding + emb = get_embedding(question, word_vec_dim, wordvecs) + + # question LSTM + wx = paddle.layer.fc( + act=paddle.activation.Linear(), + size=q_lstm_dim * 4, + input=emb, + param_attr=paddle.attr.ParamAttr( + name="_q_hidden1.w0", + initial_std=default_init_std, + l2_rate=default_l2_rate), + bias_attr=paddle.attr.ParamAttr( + name="_q_hidden1.wbias", initial_std=0, l2_rate=default_l2_rate)) + q_rnn = paddle.layer.lstmemory( + input=wx, + bias_attr=paddle.attr.ParamAttr( + name="_q_rnn1.wbias", initial_std=0, l2_rate=default_l2_rate), + param_attr=paddle.attr.ParamAttr( + name="_q_rnn1.w0", + initial_std=default_init_std, + l2_rate=default_l2_rate)) + q_rnn = paddle.layer.dropout(q_rnn, drop_rate) + + # self attention + fc = paddle.layer.fc( + act=paddle.activation.Tanh(), + size=latent_chain_dim, + input=q_rnn, + param_attr=paddle.attr.ParamAttr( + name="_attention_layer1.w0", + initial_std=default_init_std, + l2_rate=default_l2_rate), + bias_attr=False) + weight = paddle.layer.fc( + size=1, + act=paddle.activation.SequenceSoftmax(), + input=fc, + param_attr=paddle.attr.ParamAttr( + name="_attention_weight.w0", + initial_std=default_init_std, + l2_rate=default_l2_rate), + bias_attr=False) + + scaled_q_rnn = paddle.layer.scaling(input=q_rnn, weight=weight) + + q_encoding = paddle.layer.pooling( + input=scaled_q_rnn, pooling_type=paddle.pooling.Sum()) + return q_encoding + + +def encoding_evidence(evidence, qe_comm, ee_comm, q_encoding, e_lstm_dim, + word_vec_dim, com_vec_dim, drop_rate, wordvecs, + default_init_std, default_l2_rate): + """ + Define network for encoding evidence + + :param qe_comm: qe.ecomm features + :type qe_comm: LayerOutput + :param ee_comm: ee.ecomm features + :type ee_comm: LayerOutput + :param q_encoding: question encoding, a fixed-length vector + :type q_encoding: LayerOutput + :param 
e_lstm_dim: dimension of the evidence LSTMs + :type e_lstm_dim: int + :param word_vec_dim: dimension of the word embeddings + :type word_vec_dim: int + :param com_vec_dim: dimension of the qe.comm and ee.comm feature embeddings + :type com_vec_dim: int + :param drop_rate: dropout rate + :type drop_rate: float + :param wordvecs: word embedding matrix + :type wordvecs: numpy array + :param default_init_std: default initial standard deviation + :type default_init_std: float + :param default_l2_rate: default l2 rate + :type default_l2_rate: float + :return: evidence encoding + :rtype: LayerOutput + """ + + def lstm(idx, reverse, inputs): + """LSTM wrapper""" + bias_attr = paddle.attr.ParamAttr( + name="_e_hidden%d.wbias" % idx, + initial_std=0, + l2_rate=default_l2_rate) + with paddle.layer.mixed(size=e_lstm_dim * 4, bias_attr=bias_attr) as wx: + for i, input in enumerate(inputs): + param_attr = paddle.attr.ParamAttr( + name="_e_hidden%d.w%d" % (idx, i), + initial_std=default_init_std, + l2_rate=default_l2_rate) + wx += paddle.layer.full_matrix_projection( + input=input, param_attr=param_attr) + + e_rnn = paddle.layer.lstmemory( + input=wx, + reverse=reverse, + bias_attr=paddle.attr.ParamAttr( + name="_e_rnn%d.wbias" % idx, + initial_std=0, + l2_rate=default_l2_rate), + param_attr=paddle.attr.ParamAttr( + name="_e_rnn%d.w0" % idx, + initial_std=default_init_std, + l2_rate=default_l2_rate)) + e_rnn = paddle.layer.dropout(e_rnn, drop_rate) + return e_rnn + + # share word embeddings with question + emb = get_embedding(evidence, word_vec_dim, wordvecs) + + # copy q_encoding len(evidence) times + q_encoding_expand = paddle.layer.expand( + input=q_encoding, expand_as=evidence) + + # feature embeddings + comm_initial_std = 1 / math.sqrt(64.0) + qe_comm_emb = paddle.layer.embedding( + input=qe_comm, + size=com_vec_dim, + param_attr=paddle.attr.ParamAttr( + name="_cw_embedding.w0", + initial_std=comm_initial_std, + l2_rate=default_l2_rate)) + + ee_comm_emb = 
paddle.layer.embedding( + input=ee_comm, + size=com_vec_dim, + param_attr=paddle.attr.ParamAttr( + name="_eecom_embedding.w0", + initial_std=comm_initial_std, + l2_rate=default_l2_rate)) + + # evidence LSTMs + first_layer_extra_inputs = [q_encoding_expand, qe_comm_emb, ee_comm_emb] + e_rnn1 = lstm(1, False, [emb] + first_layer_extra_inputs) + e_rnn2 = lstm(2, True, [e_rnn1]) + e_rnn3 = lstm(3, False, [e_rnn2, e_rnn1]) # with cross layer links + + return e_rnn3 + + +def define_data(dict_dim, label_num): + """ + Define data layers + + :param dict_dim: number of words in the vocabulary + :type dict_dim: int + :param label_num: label numbers, BIO:3, BIO2:4 + :type label_num: int + :return: data layers + :rtype: tuple of LayerOutput + """ + question = paddle.layer.data( + name=reader.Q_IDS_STR, + type=paddle.data_type.integer_value_sequence(dict_dim)) + + evidence = paddle.layer.data( + name=reader.E_IDS_STR, + type=paddle.data_type.integer_value_sequence(dict_dim)) + + qe_comm = paddle.layer.data( + name=reader.QE_COMM_STR, + type=paddle.data_type.integer_value_sequence(2)) + + ee_comm = paddle.layer.data( + name=reader.EE_COMM_STR, + type=paddle.data_type.integer_value_sequence(2)) + + label = paddle.layer.data( + name=reader.LABELS_STR, + type=paddle.data_type.integer_value_sequence(label_num), + layer_attr=paddle.attr.ExtraAttr(device=-1)) + + return question, evidence, qe_comm, ee_comm, label + + +def define_common_network(conf): + """ + Define common network + + :param conf: network conf + :return: CRF features, golden labels + :rtype: tuple + """ + # define data layers + question, evidence, qe_comm, ee_comm, label = \ + define_data(conf.dict_dim, conf.label_num) + + # encode question + q_encoding = encoding_question(question, conf.q_lstm_dim, + conf.latent_chain_dim, conf.word_vec_dim, + conf.drop_rate, conf.wordvecs, + conf.default_init_std, conf.default_l2_rate) + + # encode evidence + e_encoding = encoding_evidence( + evidence, qe_comm, ee_comm, q_encoding, 
conf.e_lstm_dim, + conf.word_vec_dim, conf.com_vec_dim, conf.drop_rate, conf.wordvecs, + conf.default_init_std, conf.default_l2_rate) + + # pre-compute CRF features + crf_feats = paddle.layer.fc( + act=paddle.activation.Linear(), + input=e_encoding, + size=conf.label_num, + param_attr=paddle.attr.ParamAttr( + name="_output.w0", + initial_std=conf.default_init_std, + l2_rate=conf.default_l2_rate), + bias_attr=False) + return crf_feats, label + + +def training_net(conf): + """ + Define training network + + :param conf: network conf + :return: CRF cost + :rtype: LayerOutput + """ + e_encoding, label = define_common_network(conf) + crf = paddle.layer.crf( + input=e_encoding, + label=label, + size=conf.label_num, + param_attr=paddle.attr.ParamAttr( + name="_crf.w0", + initial_std=conf.default_init_std, + l2_rate=conf.default_l2_rate), + layer_attr=paddle.attr.ExtraAttr(device=-1)) + + return crf + + +def inference_net(conf): + """ + Define inference network + + :param conf: network conf + :return: CRF Viterbi decoding result + :rtype: LayerOutput + """ + e_encoding, label = define_common_network(conf) + ret = paddle.layer.crf_decoding( + input=e_encoding, + size=conf.label_num, + param_attr=paddle.attr.ParamAttr(name="_crf.w0"), + layer_attr=paddle.attr.ExtraAttr(device=-1)) + + return ret diff --git a/neural_seq_qa/reader.py b/neural_seq_qa/reader.py new file mode 100644 index 0000000000..e55e77b601 --- /dev/null +++ b/neural_seq_qa/reader.py @@ -0,0 +1,409 @@ +import sys +import random +from itertools import izip +import json +import traceback + +from datapoint import DataPoint, Evidence, EecommFeatures +import utils +from utils import logger + +__all__ = [ + "Q_IDS", "E_IDS", "LABELS", "QE_COMM", "EE_COMM", "Q_IDS_STR", "E_IDS_STR", + "LABELS_STR", "QE_COMM_STR", "EE_COMM_STR", "Settings", "create_reader" +] + +# slot names +Q_IDS_STR = "q_ids" +E_IDS_STR = "e_ids" +LABELS_STR = "labels" +QE_COMM_STR = "qe.comm" +EE_COMM_STR = "ee.comm" + +Q_IDS = 0 +E_IDS = 1 +LABELS
= 2 +QE_COMM = 3 +EE_COMM = 4 + +NO_ANSWER = "no_answer" + + +class Settings(object): + """ + class for storing settings + """ + + def __init__(self, + vocab, + is_training, + label_schema="BIO2", + negative_sample_ratio=0.2, + hit_ans_negative_sample_ratio=0.25, + keep_first_b=False, + seed=31425926): + """ + Init function + + :param vocab: word dict + :type vocab: dict + :param is_training: True for training + :type is_training: bool + :param label_schema: label schema, valid values are BIO and BIO2, + the default value is BIO2 + :type label_schema: str + :param negative_sample_ratio: the ratio of negative samples used in + training, the default value is 0.2 + :type negative_sample_ratio: float + :param hit_ans_negative_sample_ratio: the ratio of negative samples + that contain golden answer string, the default value is 0.25 + :type hit_ans_negative_sample_ratio: float + :param keep_first_b: only keep the first B in golden tag sequence, + the default value is False + :type keep_first_b: bool + :param seed: random seed, the default value is 31425926 + :type seed: int + """ + self.negative_sample_ratio = negative_sample_ratio + self.hit_ans_negative_sample_ratio = hit_ans_negative_sample_ratio + self.keep_first_b = keep_first_b + self.is_training = is_training + self.vocab = vocab + + # set up label schema + if label_schema == "BIO": + B, I, O1, O2 = 0, 1, 2, 2 + elif label_schema == "BIO2": + B, I, O1, O2 = 0, 1, 2, 3 + else: + raise ValueError("label_schema should be BIO/BIO2") + self.B, self.I, self.O1, self.O2 = B, I, O1, O2 + self.label_map = { + "B": B, + "I": I, + "O1": O1, + "O2": O2, + "b": B, + "i": I, + "o1": O1, + "o2": O2 + } + self.label_num = len(set((B, I, O1, O2))) + + # id for OOV + self.oov_id = 0 + + # set up random seed + random.seed(seed) + + # booking message + logger.info("negative_sample_ratio: %f", negative_sample_ratio) + logger.info("hit_ans_negative_sample_ratio: %f", + hit_ans_negative_sample_ratio) + logger.info("keep_first_b: %s", 
keep_first_b) + logger.info("data reader random seed: %d", seed) + + +class SampleStream(object): + def __init__(self, filename, settings): + self.filename = filename + self.settings = settings + + def __iter__(self): + return self.load_and_filter_samples(self.filename) + + def load_and_filter_samples(self, filename): + def remove_extra_b(labels): + if labels.count(self.settings.B) <= 1: return + + i = 0 + # find the first B + while i < len(labels) and labels[i] == self.settings.O1: + i += 1 + i += 1 # skip B + # skip the following Is + while i < len(labels) and labels[i] == self.settings.I: + i += 1 + # change all the other tags to O2 + while i < len(labels): + labels[i] = self.settings.O2 + i += 1 + + def filter_and_preprocess_evidences(evidences): + for i, evi in enumerate(evidences): + # convert golden labels to labels ids + if Evidence.GOLDEN_LABELS in evi: + labels = [self.settings.label_map[l] \ + for l in evi[Evidence.GOLDEN_LABELS]] + else: + labels = [self.settings.O1] * len(evi[Evidence.E_TOKENS]) + + # determine the current evidence is negative or not + answer_list = evi[Evidence.GOLDEN_ANSWERS] + is_negative = len(answer_list) == 1 \ + and "".join(answer_list[0]).lower() == NO_ANSWER + + # drop positive evidences that do not contain golden answer + # matches in training + is_all_o1 = labels.count(self.settings.O1) == len(labels) + if self.settings.is_training and is_all_o1 and not is_negative: + evidences[i] = None # dropped + continue + + if self.settings.keep_first_b: + remove_extra_b(labels) + evi[Evidence.GOLDEN_LABELS] = labels + + def get_eecom_feats_list(cur_sample_is_negative, eecom_feats_list, + evidences): + if not self.settings.is_training: + return [item[EecommFeatures.EECOMM_FEATURES] \ + for item in eecom_feats_list] + + positive_eecom_feats_list = [] + negative_eecom_feats_list = [] + + for eecom_feats_, other_evi in izip(eecom_feats_list, evidences): + if not other_evi: continue + + eecom_feats = 
eecom_feats_[EecommFeatures.EECOMM_FEATURES] + if not eecom_feats: continue + + other_evi_type = eecom_feats_[EecommFeatures.OTHER_E_TYPE] + if cur_sample_is_negative and \ + other_evi_type != Evidence.POSITIVE: + continue + + if other_evi_type == Evidence.POSITIVE: + positive_eecom_feats_list.append(eecom_feats) + else: + negative_eecom_feats_list.append(eecom_feats) + + eecom_feats_list = positive_eecom_feats_list + if negative_eecom_feats_list: + eecom_feats_list += [negative_eecom_feats_list] + + return eecom_feats_list + + def process_tokens(data, tok_key): + ids = [self.settings.vocab.get(token, self.settings.oov_id) \ + for token in data[tok_key]] + return ids + + def process_evi(q_ids, evi, evidences): + e_ids = process_tokens(evi, Evidence.E_TOKENS) + + labels = evi[Evidence.GOLDEN_LABELS] + qe_comm = evi[Evidence.QECOMM_FEATURES] + sample_type = evi[Evidence.TYPE] + + ret = [None] * 5 + ret[Q_IDS] = q_ids + ret[E_IDS] = e_ids + ret[LABELS] = labels + ret[QE_COMM] = qe_comm + + eecom_feats_list = get_eecom_feats_list( + sample_type != Evidence.POSITIVE, + evi[Evidence.EECOMM_FEATURES_LIST], evidences) + if not eecom_feats_list: + return None + else: + ret[EE_COMM] = eecom_feats_list + return ret + + with utils.DotBar(utils.open_file(filename)) as f_: + for q_idx, line in enumerate(f_): + # parse json line + try: + data = json.loads(line) + except Exception: + logger.fatal("ERROR LINE: %s", line.strip()) + traceback.print_exc() + continue + + # convert question tokens to ids + q_ids = process_tokens(data, DataPoint.Q_TOKENS) + + # process evidences + evidences = data[DataPoint.EVIDENCES] + filter_and_preprocess_evidences(evidences) + for evi in evidences: + if not evi: continue + sample = process_evi(q_ids, evi, evidences) + if sample: yield q_idx, sample, evi[Evidence.TYPE] + + +class DataReader(object): + def __iter__(self): + return self + + def _next(self): + raise NotImplementedError() # abstract hook: subclasses return the next raw sample + + def next(self): + data_point = self._next() + return
self.post_process_sample(data_point) + + def post_process_sample(self, sample): + ret = list(sample) + + # choose eecom features randomly + eecom_feats = random.choice(sample[EE_COMM]) + if not isinstance(eecom_feats[0], int): + # the other evidence is a negative evidence + eecom_feats = random.choice(eecom_feats) + ret[EE_COMM] = eecom_feats + + return ret + + +class TrainingDataReader(DataReader): + def __init__(self, sample_stream, negative_ratio, hit_ans_negative_ratio): + super(TrainingDataReader, self).__init__() + self.positive_data = [] + self.hit_ans_negative_data = [] + self.other_negative_data = [] + + self.negative_ratio = negative_ratio + self.hit_ans_negative_ratio = hit_ans_negative_ratio + + self.p_idx = 0 + self.hit_idx = 0 + self.other_idx = 0 + + self.load_samples(sample_stream) + + def add_data(self, positive, hit_negative, other_negative): + if not positive: return + self.positive_data.extend(positive) + for samples, target_list in \ + zip((hit_negative, other_negative), + (self.hit_ans_negative_data, self.other_negative_data)): + if not samples: continue + # `0` is an index, further refer to _next_negative_data() + target_list.append([samples, 0]) + + def load_samples(self, sample_stream): + logger.info("loading data...") + last_q_id, positive, hit_negative, other_negative = None, [], [], [] + for q_id, sample, type_ in sample_stream: + if last_q_id is not None and q_id != last_q_id: # a new question starts: flush the previous question's samples + self.add_data(positive, hit_negative, other_negative) + positive, hit_negative, other_negative = [], [], [] + + last_q_id = q_id + if type_ == Evidence.POSITIVE: + positive.append(sample) + elif type_ == Evidence.HIT_ANS_NEGATIVE: + hit_negative.append(sample) + elif type_ == Evidence.OTHER_NEGATIVE: + other_negative.append(sample) + else: + raise ValueError("wrong type: %s" % str(type_)) + self.add_data(positive, hit_negative, other_negative) + + # we are not sure whether the input data is shuffled or not + # so we shuffle them + random.shuffle(self.positive_data) +
random.shuffle(self.hit_ans_negative_data) + random.shuffle(self.other_negative_data) + + # set thresholds + if len(self.positive_data) == 0: + logger.fatal("zero positive sample") + raise ValueError("zero positive sample") + + zero_hit = len(self.hit_ans_negative_data) == 0 + zero_other = len(self.other_negative_data) == 0 + + if zero_hit and zero_other: + logger.fatal("zero negative sample") + raise ValueError("zero negative sample") + + if zero_hit: + logger.warning("zero hit_ans_negative sample") + self.hit_ans_neg_threshold = 0 + else: + self.hit_ans_neg_threshold = \ + self.negative_ratio * self.hit_ans_negative_ratio + + self.other_neg_threshold = self.negative_ratio + if zero_other: + logger.warning("zero other_negative sample") + self.hit_ans_neg_threshold = self.negative_ratio + logger.info("loaded") + + def next_positive_data(self): + if self.p_idx >= len(self.positive_data): + random.shuffle(self.positive_data) + self.p_idx = 0 + + self.p_idx += 1 + return self.positive_data[self.p_idx - 1] + + def _next_negative_data(self, idx, negative_data): + if idx >= len(negative_data): + random.shuffle(negative_data) + idx = 0 + + # a negative evidence is sampled in two steps: + # step 1: sample a question uniformly + # step 2: sample a negative evidence corresponding to the question + # uniformly + # bundle -> (sample, idx) + bundle = negative_data[idx] + if bundle[1] >= len(bundle[0]): + random.shuffle(bundle[0]) + bundle[1] = 0 + bundle[1] += 1 + return idx + 1, bundle[0][bundle[1] - 1] + + def next_hit_ans_negative_data(self): + self.hit_idx, data = self._next_negative_data( + self.hit_idx, self.hit_ans_negative_data) + return data + + def next_other_negative_data(self): + self.other_idx, data = self._next_negative_data( + self.other_idx, self.other_negative_data) + return data + + def _next(self): + rand = random.random() + if rand <= self.hit_ans_neg_threshold: + return self.next_hit_ans_negative_data() + elif rand < self.other_neg_threshold: + return 
self.next_other_negative_data() + else: + return self.next_positive_data() + + +class TestDataReader(DataReader): + def __init__(self, sample_stream): + super(TestDataReader, self).__init__() + self.data_generator = iter(sample_stream) + + def _next(self): + q_idx, sample, type_ = self.data_generator.next() + return sample + + +def create_reader(filename, settings, samples_per_pass=sys.maxint): + if settings.is_training: + training_reader = TrainingDataReader( + SampleStream(filename, settings), settings.negative_sample_ratio, + settings.hit_ans_negative_sample_ratio) + + def wrapper(): + for i, data in izip(xrange(samples_per_pass), training_reader): + yield data + + return wrapper + else: + + def wrapper(): + sample_stream = SampleStream(filename, settings) + return TestDataReader(sample_stream) + + return wrapper diff --git a/neural_seq_qa/test/test_reader.py b/neural_seq_qa/test/test_reader.py new file mode 100644 index 0000000000..2c3725b503 --- /dev/null +++ b/neural_seq_qa/test/test_reader.py @@ -0,0 +1,110 @@ +import unittest +import os +import itertools +import math +import logging + +# set up python path +topdir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..") +import sys +sys.path += [topdir, os.path.join(topdir, "data", "evaluation")] + +import reader +import utils + +formatter = logging.Formatter( + "[%(levelname)s %(asctime)s.%(msecs)d %(filename)s:%(lineno)d] %(message)s", + datefmt='%Y-%m-%d %I:%M:%S') +ch = logging.StreamHandler() +ch.setFormatter(formatter) +utils.logger.addHandler(ch) + + +class Vocab(object): + @property + def data(self): + word_dict_path = os.path.join(topdir, "data", "embedding", + "wordvecs.vcb") + return utils.load_dict(word_dict_path) + + +class NegativeSampleRatioTest(unittest.TestCase): + def check_ratio(self, negative_sample_ratio): + for keep_first_b in [True, False]: + settings = reader.Settings( + vocab=Vocab().data, + is_training=True, + label_schema="BIO2", + 
negative_sample_ratio=negative_sample_ratio, + hit_ans_negative_sample_ratio=0.25, + keep_first_b=keep_first_b) + + filename = os.path.join(topdir, "test", "trn_data.gz") + data_stream = reader.create_reader(filename, settings) + total, negative_num = 5000, 0 + for _, d in itertools.izip(xrange(total), data_stream()): + labels = d[reader.LABELS] + if labels.count(0) == 0: + negative_num += 1 + + ratio = negative_num / float(total) + self.assertLessEqual(math.fabs(ratio - negative_sample_ratio), 0.01) + + def runTest(self): + for ratio in [1., 0.25, 0.]: + self.check_ratio(ratio) + + +class KeepFirstBTest(unittest.TestCase): + def runTest(self): + for keep_first_b in [True, False]: + for label_schema in ["BIO", "BIO2"]: + settings = reader.Settings( + vocab=Vocab().data, + is_training=True, + label_schema=label_schema, + negative_sample_ratio=0.2, + hit_ans_negative_sample_ratio=0.25, + keep_first_b=keep_first_b) + + filename = os.path.join(topdir, "test", "trn_data.gz") + data_stream = reader.create_reader(filename, settings) + total, at_least_one, one = 1000, 0, 0 + for _, d in itertools.izip(xrange(total), data_stream()): + labels = d[reader.LABELS] + b_num = labels.count(0) + if b_num >= 1: + at_least_one += 1 + if b_num == 1: + one += 1 + + self.assertLess(at_least_one, total) + if keep_first_b: + self.assertEqual(one, at_least_one) + else: + self.assertLess(one, at_least_one) + + +class DictTest(unittest.TestCase): + def runTest(self): + settings = reader.Settings( + vocab=Vocab().data, + is_training=True, + label_schema="BIO2", + negative_sample_ratio=0.2, + hit_ans_negative_sample_ratio=0.25, + keep_first_b=True) + + filename = os.path.join(topdir, "test", "trn_data.gz") + data_stream = reader.create_reader(filename, settings) + q_uniq_ids, e_uniq_ids = set(), set() + for _, d in itertools.izip(xrange(1000), data_stream()): + q_uniq_ids.update(d[reader.Q_IDS]) + e_uniq_ids.update(d[reader.E_IDS]) + + self.assertGreater(len(q_uniq_ids), 50) + 
def save_model(trainer, model_save_dir, parameters, pass_id):
    """
    Save the parameters held by ``trainer`` to
    ``<model_save_dir>/params_pass_<pass_id>.tar.gz``.

    :param trainer: object whose ``save_parameter_to_tar`` writes the model
    :param model_save_dir: directory the model file is written to
    :type model_save_dir: str
    :param parameters: not used; kept so that existing callers keep working
    :param pass_id: index of the finished pass, zero padded to five digits in
        the file name
    :type pass_id: int
    """
    model_path = os.path.join(model_save_dir,
                              "params_pass_%05d.tar.gz" % pass_id)
    with utils.open_file(model_path, "w") as model_file:
        trainer.save_parameter_to_tar(model_file)
    # Report success only after the file has been written and closed; logging
    # first would leave a misleading "saved" entry behind if saving failed.
    logger.info("model saved to %s" % model_path)
+ +def train(conf): + if not os.path.exists(conf.model_save_dir): + os.makedirs(conf.model_save_dir, mode=0755) + + settings = reader.Settings( + vocab=conf.vocab, + is_training=True, + label_schema=conf.label_schema, + negative_sample_ratio=conf.negative_sample_ratio, + hit_ans_negative_sample_ratio=conf.hit_ans_negative_sample_ratio, + keep_first_b=conf.keep_first_b, + seed=conf.seed) + samples_per_pass = conf.batch_size * conf.batches_per_pass + train_reader = paddle.batch( + paddle.reader.buffered( + reader.create_reader(conf.train_data_path, settings, + samples_per_pass), + size=samples_per_pass), + batch_size=conf.batch_size) + + # TODO(lipeng17) v2 API does not support parallel_nn yet. Therefore, we can + # only use CPU currently + paddle.init( + use_gpu=conf.use_gpu, + trainer_count=conf.trainer_count, + seed=conf.paddle_seed) + + # network config + cost = network.training_net(conf) + + # create parameters + # NOTE: parameter values are not initilized here, therefore, we need to + # print parameter initialization info in the beginning of the first batch + parameters = paddle.parameters.create(cost) + + # create optimizer + rmsprop_optimizer = paddle.optimizer.RMSProp( + learning_rate=conf.learning_rate, + rho=conf.rho, + epsilon=conf.epsilon, + model_average=paddle.optimizer.ModelAverage( + average_window=conf.average_window, + max_average_window=conf.max_average_window)) + + # create trainer + trainer = paddle.trainer.SGD( + cost=cost, parameters=parameters, update_equation=rmsprop_optimizer) + + # begin training network + def _event_handler(event): + """ + Define end batch and end pass event handler + """ + if isinstance(event, paddle.event.EndIteration): + sys.stderr.write(".") + batch_num = event.batch_id + 1 + total_batch = conf.batches_per_pass * event.pass_id + batch_num + if batch_num % conf.log_period == 0: + sys.stderr.write("\n") + logger.info("Total batch=%d Batch=%d CurrentCost=%f Eval: %s" \ + % (total_batch, batch_num, event.cost, 
def open_file(filename, *args1, **args2):
    """
    Open a file; names ending with ".gz" are opened with ``gzip.open``.

    :param filename: name of the file
    :type filename: str
    :param args1: positional arguments forwarded to the opener (e.g. mode)
    :param args2: keyword arguments forwarded to the opener
    :return: a file handler
    """
    if filename.endswith(".gz"):
        return gzip.open(filename, *args1, **args2)
    else:
        return open(filename, *args1, **args2)


def cumsum(array):
    """
    Calculate the accumulated sum of array. For example, array=[1, 2, 3], the
    result is [1, 1+2, 1+2+3]

    :param array: input array; it is not modified
    :type array: python list or numpy array
    :return: the accumulated sum of array
    :rtype: python list
    """
    ret = list(array)
    # `range` instead of `xrange` keeps the helper usable on Python 2 and 3;
    # empty and single element inputs simply skip the loop.
    for i in range(1, len(ret)):
        ret[i] += ret[i - 1]
    return ret


class DotBar(object):
    """
    A simple dot bar: wraps an iterator and prints a dot every ``step`` items.
    """

    def __init__(self, obj, step=200, dots_per_line=50, f=sys.stderr):
        """
        :param obj: an iterable obj
        :type obj: a python iterator
        :param step: print a dot every step iterations
        :type step: int
        :param dots_per_line: dots each line
        :type dots_per_line: int
        :param f: print dot to f, default value is sys.stderr
        :type f: a file handler
        """
        self.obj = obj
        self.step = step
        self.dots_per_line = dots_per_line
        self.f = f
        # Initialised here as well as in __enter__, so that the bar can also
        # be iterated without a `with` statement.
        self.idx = 0

    def __enter__(self):
        self.obj.__enter__()
        self.idx = 0
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.f.write("\n")
        # the standard streams are never closed by the bar
        if self.obj is sys.stdin or self.obj is sys.stdout:
            return
        self.obj.__exit__(exc_type, exc_value, traceback)

    def __iter__(self):
        return self

    def __next__(self):
        # Fetch first: when the wrapped iterator is exhausted StopIteration
        # propagates before the counter moves, so only delivered items count.
        item = next(self.obj)
        self.idx += 1
        if self.idx % self.step == 0:
            self.f.write(".")
        if self.idx % (self.step * self.dots_per_line) == 0:
            self.f.write("\n")

        return item

    # Python 2 spelling of the iterator protocol
    next = __next__


def load_dict(word_dict_path):
    """
    Load a word dictionary. The first whitespace separated field of every
    line is the word, and its zero based line number becomes its id.

    :param word_dict_path: path of the dictionary file (optionally gzipped)
    :type word_dict_path: str
    :return: mapping from unicode word to id
    :rtype: dict
    :raises ValueError: if a line contains no word
    """
    vocab = {}
    # Binary mode yields byte strings on Python 2 and 3 alike, so the explicit
    # UTF-8 decoding below works on both.
    with open_file(word_dict_path, "rb") as f:
        # the first word must be OOV
        for i, line in enumerate(f):
            fields = line.split()
            if not fields:
                # Skipping the line would shift all following ids away from
                # their rows in the word vector file, so fail loudly instead.
                raise ValueError("line %d of %s contains no word" %
                                 (i + 1, word_dict_path))
            vocab[fields[0].decode("utf-8")] = i
    return vocab


def load_wordvecs(word_dict_path, wordvecs_path):
    """
    Load a word dictionary together with its word vectors.

    :param word_dict_path: path of the dictionary file
    :type word_dict_path: str
    :param wordvecs_path: path of the comma separated word vector file
    :type wordvecs_path: str
    :return: the dictionary and the float32 word vectors
    :rtype: tuple
    :raises ValueError: if the number of words differs from the first
        dimension of the loaded word vectors
    """
    vocab = load_dict(word_dict_path)
    wordvecs = numpy.loadtxt(wordvecs_path, delimiter=",", dtype="float32")
    # explicit check instead of `assert`, which is stripped under `python -O`
    if len(vocab) != wordvecs.shape[0]:
        raise ValueError(
            "dictionary has %d words but the word vector file has %d rows" %
            (len(vocab), wordvecs.shape[0]))
    return vocab, wordvecs
# Raw string: "\d" and "\." are regular expression escapes, not string escapes.
__PATTERN_CHUNK_F1 = re.compile(r"chunk_f1=(\d+(\.\d+)?)")


def find_best_pass(evals):
    """
    Find the pass with the highest chunk_f1; ties are resolved in favour of
    the smallest pass id.

    :param evals: mapping from pass id to the evaluation result string, which
        must contain "chunk_f1=<number>"
    :type evals: dict
    :return: the id of the best pass
    :raises ValueError: if evals is empty or a result has no chunk_f1 value
    """
    if not evals:
        raise ValueError("no evaluation results to select the best pass from")

    results = []
    # `items` works on both Python 2 and 3, unlike `iteritems`
    for pass_id, eval_ret in evals.items():
        match = __PATTERN_CHUNK_F1.search(eval_ret)
        if match is None:
            raise ValueError("no chunk_f1 found in the result of pass %s: %r" %
                             (pass_id, eval_ret))
        results.append((pass_id, float(match.group(1))))

    # only the best entry is needed, so a full sort is unnecessary
    return min(results, key=lambda item: (-item[1], item[0]))[0]
def parse_cmd():
    """
    Parse the command line of the validation and test script.

    :return: the parsed arguments: ``model_dir``, ``data_type``,
        ``val_eval_output``, ``tst_eval_output``, ``start_pass_id``,
        ``end_pass_id`` and ``force_rerun``
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("model_dir")
    # nargs="?" is what makes the declared default effective: argparse ignores
    # the default of a positional argument that is required.
    parser.add_argument(
        "data_type", nargs="?", choices=["ann", "ir"], default="ann")
    parser.add_argument(
        "--val_eval_output", help="validation set evaluation result file")
    parser.add_argument(
        "--tst_eval_output", help="test set evaluation result file")
    parser.add_argument("--start_pass_id", type=int, default=0)
    parser.add_argument(
        "--end_pass_id", type=int, default=24, help="this pass is included")
    parser.add_argument("--force_rerun", action="store_true")
    return parser.parse_args()


# evaluation script used for each data type
__eval_scripts = {
    "ann": "data/evaluation/evaluate-tagging-result.py",
    "ir": "data/evaluation/evaluate-voting-result.py",
}

# validation set used for each data type
__val_data = {
    "ann": "./data/data/validation.ann.json.gz",
    "ir": "./data/data/validation.ir.json.gz",
}

# test set used for each data type
__tst_data = {
    "ann": "./data/data/test.ann.json.gz",
    "ir": "./data/data/test.ir.json.gz",
}


def main(args):
    """
    Evaluate the requested passes on the validation set, then evaluate the
    best validation pass on the test set and log both results.

    :param args: the parsed command line arguments, see ``parse_cmd``
    """
    conf = config.InferConfig()
    conf.vocab = utils.load_dict(conf.word_dict_path)
    logger.info("length of word dictionary is : %d." % len(conf.vocab))

    # result files default to names derived from the data type
    if args.val_eval_output:
        val_eval_output = args.val_eval_output
    else:
        val_eval_output = "eval.val.%s.txt" % args.data_type

    if args.tst_eval_output:
        tst_eval_output = args.tst_eval_output
    else:
        tst_eval_output = "eval.tst.%s.txt" % args.data_type

    eval_script = __eval_scripts[args.data_type]
    val_data_file = __val_data[args.data_type]
    tst_data_file = __tst_data[args.data_type]

    infer_obj = infer.Infer(conf)
    val_evals = run_eval(
        infer_obj,
        conf,
        args.model_dir,
        val_data_file,
        eval_script,
        val_eval_output,
        args.start_pass_id,
        args.end_pass_id,
        force_rerun=args.force_rerun)

    best_pass_id = find_best_pass(val_evals)

    # only the best validation pass is evaluated on the test set
    tst_evals = run_eval(
        infer_obj,
        conf,
        args.model_dir,
        tst_data_file,
        eval_script,
        tst_eval_output,
        start_pass_id=best_pass_id,
        end_pass_id=best_pass_id,
        force_rerun=args.force_rerun)

    logger.info("Best Pass=%d" % best_pass_id)
    logger.info("Validation: %s" % val_evals[best_pass_id])
    logger.info("Test      : %s" % tst_evals[best_pass_id])


if __name__ == "__main__":
    main(parse_cmd())