Skip to content

Commit

Permalink
Merge pull request #36 from xinghai-sun/mt_with_external_memory
Browse files Browse the repository at this point in the history
Add model configuration for machine translation with external memory.
  • Loading branch information
lcy-seso authored Sep 13, 2017
2 parents cd16af8 + 54afcbc commit 717ccf5
Show file tree
Hide file tree
Showing 10 changed files with 1,250 additions and 1 deletion.
474 changes: 473 additions & 1 deletion mt_with_external_memory/README.md

Large diffs are not rendered by default.

15 changes: 15 additions & 0 deletions mt_with_external_memory/data_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
"""
Contains data utilities.
"""


def reader_append_wrapper(reader, append_tuple):
"""
Data reader wrapper for appending extra data to exisiting reader.
"""

def new_reader():
for ins in reader():
yield ins + append_tuple

return new_reader
191 changes: 191 additions & 0 deletions mt_with_external_memory/external_memory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
"""
External neural memory class.
"""
import paddle.v2 as paddle


class ExternalMemory(object):
"""External neural memory class.
A simplified Neural Turing Machines (NTM) with only content-based
addressing (including content addressing and interpolation, but excluding
convolutional shift and sharpening). It serves as an external differential
memory bank, with differential write/read head controllers to store
and read information dynamically. Simple feedforward networks are
used as the write/read head controllers.
The ExternalMemory class could be utilized by many neural network structures
to easily expand their memory bandwidth and accomplish a long-term memory
handling. Besides, some existing mechanism can be realized directly with
the ExternalMemory class, e.g. the attention mechanism in Seq2Seq (i.e. an
unbounded external memory).
Besides, the ExternalMemory class must be used together with
paddle.layer.recurrent_group (within its step function). It can never be
used in a standalone manner.
For more details, please refer to
`Neural Turing Machines <https://arxiv.org/abs/1410.5401>`_.
:param name: Memory name.
:type name: basestring
:param mem_slot_size: Size of memory slot/vector.
:type mem_slot_size: int
:param boot_layer: Boot layer for initializing the external memory. The
sequence layer has sequence length indicating the number
of memory slots, and size as memory slot size.
:type boot_layer: LayerOutput
:param readonly: If true, the memory is read-only, and write function cannot
be called. Default is false.
:type readonly: bool
:param enable_interpolation: If set true, the read/write addressing weights
will be interpolated with the weights in the
last step, with the affine coefficients being
a learnable gate function.
:type enable_interpolation: bool
"""

def __init__(self,
name,
mem_slot_size,
boot_layer,
readonly=False,
enable_interpolation=True):
self.name = name
self.mem_slot_size = mem_slot_size
self.readonly = readonly
self.enable_interpolation = enable_interpolation
self.external_memory = paddle.layer.memory(
name=self.name, size=self.mem_slot_size, boot_layer=boot_layer)
# prepare a constant (zero) intializer for addressing weights
self.zero_addressing_init = paddle.layer.slope_intercept(
input=paddle.layer.fc(input=boot_layer, size=1),
slope=0.0,
intercept=0.0)
# set memory to constant when readonly=True
if self.readonly:
self.updated_external_memory = paddle.layer.mixed(
name=self.name,
input=[
paddle.layer.identity_projection(input=self.external_memory)
],
size=self.mem_slot_size)

def _content_addressing(self, key_vector):
"""Get write/read head's addressing weights via content-based addressing.
"""
# content-based addressing: a=tanh(W*M + U*key)
key_projection = paddle.layer.fc(
input=key_vector,
size=self.mem_slot_size,
act=paddle.activation.Linear(),
bias_attr=False)
key_proj_expanded = paddle.layer.expand(
input=key_projection, expand_as=self.external_memory)
memory_projection = paddle.layer.fc(
input=self.external_memory,
size=self.mem_slot_size,
act=paddle.activation.Linear(),
bias_attr=False)
merged_projection = paddle.layer.addto(
input=[key_proj_expanded, memory_projection],
act=paddle.activation.Tanh())
# softmax addressing weight: w=softmax(v^T a)
addressing_weight = paddle.layer.fc(
input=merged_projection,
size=1,
act=paddle.activation.SequenceSoftmax(),
bias_attr=False)
return addressing_weight

def _interpolation(self, head_name, key_vector, addressing_weight):
"""Interpolate between previous and current addressing weights.
"""
# prepare interpolation scalar gate: g=sigmoid(W*key)
gate = paddle.layer.fc(
input=key_vector,
size=1,
act=paddle.activation.Sigmoid(),
bias_attr=False)
# interpolation: w_t = g*w_t+(1-g)*w_{t-1}
last_addressing_weight = paddle.layer.memory(
name=self.name + "_addressing_weight_" + head_name,
size=1,
boot_layer=self.zero_addressing_init)
interpolated_weight = paddle.layer.interpolation(
name=self.name + "_addressing_weight_" + head_name,
input=[addressing_weight, addressing_weight],
weight=paddle.layer.expand(input=gate, expand_as=addressing_weight))
return interpolated_weight

def _get_addressing_weight(self, head_name, key_vector):
"""Get final addressing weights for read/write heads, including content
addressing and interpolation.
"""
# current content-based addressing
addressing_weight = self._content_addressing(key_vector)
# interpolation with previous addresing weight
if self.enable_interpolation:
return self._interpolation(head_name, key_vector, addressing_weight)
else:
return addressing_weight

def write(self, write_key):
"""Write onto the external memory.
It cannot be called if "readonly" set True.
:param write_key: Key vector for write heads to generate writing
content and addressing signals.
:type write_key: LayerOutput
"""
# check readonly
if self.readonly:
raise ValueError("ExternalMemory with readonly=True cannot write.")
# get addressing weight for write head
write_weight = self._get_addressing_weight("write_head", write_key)
# prepare add_vector and erase_vector
erase_vector = paddle.layer.fc(
input=write_key,
size=self.mem_slot_size,
act=paddle.activation.Sigmoid(),
bias_attr=False)
add_vector = paddle.layer.fc(
input=write_key,
size=self.mem_slot_size,
act=paddle.activation.Sigmoid(),
bias_attr=False)
erase_vector_expand = paddle.layer.expand(
input=erase_vector, expand_as=self.external_memory)
add_vector_expand = paddle.layer.expand(
input=add_vector, expand_as=self.external_memory)
# prepare scaled add part and erase part
scaled_erase_vector_expand = paddle.layer.scaling(
weight=write_weight, input=erase_vector_expand)
erase_memory_part = paddle.layer.mixed(
input=paddle.layer.dotmul_operator(
a=self.external_memory,
b=scaled_erase_vector_expand,
scale=-1.0))
add_memory_part = paddle.layer.scaling(
weight=write_weight, input=add_vector_expand)
# update external memory
self.updated_external_memory = paddle.layer.addto(
input=[self.external_memory, add_memory_part, erase_memory_part],
name=self.name)

def read(self, read_key):
"""Read from the external memory.
:param write_key: Key vector for read head to generate addressing
signals.
:type write_key: LayerOutput
:return: Content (vector) read from external memory.
:rtype: LayerOutput
"""
# get addressing weight for write head
read_weight = self._get_addressing_weight("read_head", read_key)
# read content from external memory
scaled = paddle.layer.scaling(
weight=read_weight, input=self.updated_external_memory)
return paddle.layer.pooling(
input=scaled, pooling_type=paddle.pooling.Sum())
Binary file added mt_with_external_memory/image/lstm_c_state.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
155 changes: 155 additions & 0 deletions mt_with_external_memory/infer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
"""
Contains infering script for machine translation with external memory.
"""
import distutils.util
import argparse
import gzip

import paddle.v2 as paddle
from external_memory import ExternalMemory
from model import memory_enhanced_seq2seq
from data_utils import reader_append_wrapper

parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--dict_size",
default=30000,
type=int,
help="Vocabulary size. (default: %(default)s)")
parser.add_argument(
"--word_vec_dim",
default=512,
type=int,
help="Word embedding size. (default: %(default)s)")
parser.add_argument(
"--hidden_size",
default=1024,
type=int,
help="Hidden cell number in RNN. (default: %(default)s)")
parser.add_argument(
"--memory_slot_num",
default=8,
type=int,
help="External memory slot number. (default: %(default)s)")
parser.add_argument(
"--beam_size",
default=3,
type=int,
help="Beam search width. (default: %(default)s)")
parser.add_argument(
"--use_gpu",
default=False,
type=distutils.util.strtobool,
help="Use gpu or not. (default: %(default)s)")
parser.add_argument(
"--trainer_count",
default=1,
type=int,
help="Trainer number. (default: %(default)s)")
parser.add_argument(
"--batch_size",
default=5,
type=int,
help="Batch size. (default: %(default)s)")
parser.add_argument(
"--infer_data_num",
default=3,
type=int,
help="Instance num to infer. (default: %(default)s)")
parser.add_argument(
"--model_filepath",
default="checkpoints/params.latest.tar.gz",
type=str,
help="Model filepath. (default: %(default)s)")
parser.add_argument(
"--memory_perturb_stddev",
default=0.1,
type=float,
help="Memory perturb stddev for memory initialization."
"(default: %(default)s)")
args = parser.parse_args()


def parse_beam_search_result(beam_result, dictionary):
"""
Beam search result parser.
"""
sentence_list = []
sentence = []
for word in beam_result[1]:
if word != -1:
sentence.append(word)
else:
sentence_list.append(
' '.join([dictionary.get(word) for word in sentence[1:]]))
sentence = []
beam_probs = beam_result[0]
beam_size = len(beam_probs[0])
beam_sentences = [
sentence_list[i:i + beam_size]
for i in range(0, len(sentence_list), beam_size)
]
return beam_probs, beam_sentences


def infer():
"""
For inferencing.
"""
# create network config
source_words = paddle.layer.data(
name="source_words",
type=paddle.data_type.integer_value_sequence(args.dict_size))
beam_gen = memory_enhanced_seq2seq(
encoder_input=source_words,
decoder_input=None,
decoder_target=None,
hidden_size=args.hidden_size,
word_vec_dim=args.word_vec_dim,
dict_size=args.dict_size,
is_generating=True,
beam_size=args.beam_size)

# load parameters
parameters = paddle.parameters.Parameters.from_tar(
gzip.open(args.model_filepath))

# prepare infer data
infer_data = []
random.seed(0) # for keeping consitancy for multiple runs
bounded_memory_perturbation = [[
random.gauss(0, memory_perturb_stddev) for i in xrange(args.hidden_size)
] for j in xrange(args.memory_slot_num)]
test_append_reader = reader_append_wrapper(
reader=paddle.dataset.wmt14.test(dict_size),
append_tuple=(bounded_memory_perturbation, ))
for i, item in enumerate(test_append_reader()):
if i < args.infer_data_num:
infer_data.append((item[0], item[3], ))

# run inference
beam_result = paddle.infer(
output_layer=beam_gen,
parameters=parameters,
input=infer_data,
field=['prob', 'id'])

# parse beam result and print
source_dict, target_dict = paddle.dataset.wmt14.get_dict(dict_size)
beam_probs, beam_sentences = parse_beam_search_result(beam_result,
target_dict)
for i in xrange(args.infer_data_num):
print "\n***************************************************\n"
print "src:", ' '.join(
[source_dict.get(word) for word in infer_data[i][0]]), "\n"
for j in xrange(args.beam_size):
print "prob = %f : %s" % (beam_probs[i][j], beam_sentences[i][j])


def main():
paddle.init(use_gpu=False, trainer_count=1)
infer()


if __name__ == '__main__':
main()
Loading

0 comments on commit 717ccf5

Please sign in to comment.