Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Merge branch 'master' into kv_map
Browse files Browse the repository at this point in the history
  • Loading branch information
eric-haibin-lin authored Jun 22, 2017
2 parents 6e97bcd + bf4d774 commit 8cefba3
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 2 deletions.
29 changes: 28 additions & 1 deletion R-package/R/lstm.R
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,10 @@ lstm.inference.symbol <- function(num.lstm.layer, input.size,
#' A number in [0,1) containing the dropout ratio from the last hidden layer to the output layer.
#' @param optimizer string, default="sgd"
#' The optimization method.
#' @param epoch.end.callback function, optional
#' The callback invoked at the end of each epoch; training stops early if it returns \code{FALSE}.
#' @param batch.end.callback function, optional
#' The callback invoked when one mini-batch iteration ends.
#' @param ... other parameters passed on to \code{train.rnn}.
#' @return model A trained lstm unrolled model.
#'
Expand All @@ -193,19 +197,29 @@ mx.lstm <- function(train.data, eval.data=NULL,
num.round=10, update.period=1,
initializer=mx.init.uniform(0.01),
dropout=0, optimizer='sgd',
epoch.end.callback=NULL, batch.end.callback=NULL,
model,
arg.params,
...) {
# check data and change data into iterator
train.data <- check.data(train.data, batch.size, TRUE)
eval.data <- check.data(eval.data, batch.size, FALSE)



# get unrolled lstm symbol
rnn.sym <- lstm.unroll(num.lstm.layer=num.lstm.layer,
if(missing(model)){
rnn.sym <- lstm.unroll(num.lstm.layer=num.lstm.layer,
num.hidden=num.hidden,
seq.len=seq.len,
input.size=input.size,
num.embed=num.embed,
num.label=num.label,
dropout=dropout)
} else {
rnn.sym=model$symbol
}

init.states.c <- lapply(1:num.lstm.layer, function(i) {
state.c <- paste0("l", i, ".init.c")
return (state.c)
Expand All @@ -229,13 +243,26 @@ mx.lstm <- function(train.data, eval.data=NULL,
init.states.name=init.states.name,
initializer=initializer,
dropout=dropout)
# restore states
if (!missing(arg.params)){
arg.names <- names(model$rnn.exec$ref.arg.arrays)
for (k in names(arg.params)) {
if ((k %in% arg.names) && is.param.name(k) ) {
rnn.input <- list()
rnn.input[[k]] <- arg.params[[k]]
mx.exec.update.arg.arrays(model$rnn.exec, rnn.input, match.name=TRUE)
}
}
}

# train lstm model
model <- train.rnn( model, train.data, eval.data,
num.round=num.round,
update.period=update.period,
ctx=ctx,
init.states.name=init.states.name,
epoch.end.callback=epoch.end.callback,
batch.end.callback=batch.end.callback,
...)
# change model into MXFeedForwardModel
model <- list(symbol=model$symbol, arg.params=model$rnn.exec$ref.arg.arrays, aux.params=model$rnn.exec$ref.aux.arrays)
Expand Down
26 changes: 25 additions & 1 deletion R-package/R/rnn_model.R
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,16 @@ get.label <- function(label, ctx) {
train.rnn <- function (model, train.data, eval.data,
num.round, update.period,
init.states.name,
optimizer='sgd', ctx=mx.ctx.default(), ...) {
optimizer='sgd', ctx=mx.ctx.default(),
epoch.end.callback,
batch.end.callback,
verbose=TRUE,
...) {
m <- model

model <- list(symbol=model$symbol, arg.params=model$rnn.exec$ref.arg.arrays,
aux.params=model$rnn.exec$ref.aux.arrays)

seq.len <- m$seq.len
batch.size <- m$batch.size
num.rnn.layer <- m$num.rnn.layer
Expand Down Expand Up @@ -173,6 +181,11 @@ train.rnn <- function (model, train.data, eval.data,
train.nll <- train.nll + calc.nll(as.array(seq.label.probs), batch.size)

nbatch <- nbatch + seq.len

if (!is.null(batch.end.callback)) {
batch.end.callback(iteration, nbatch, environment())
}

if ((epoch.counter %% log.period) == 0) {
message(paste0("Epoch [", epoch.counter,
"] Train: NLL=", train.nll / nbatch,
Expand Down Expand Up @@ -220,6 +233,17 @@ train.rnn <- function (model, train.data, eval.data,
"] Val: NLL=", val.nll / nbatch,
", Perp=", exp(val.nll / nbatch)))
}
# run the epoch-end callback; a FALSE return value stops training early


epoch_continue <- TRUE
if (!is.null(epoch.end.callback)) {
epoch_continue <- epoch.end.callback(iteration, 0, environment(), verbose = verbose)
}

if (!epoch_continue) {
break
}
}

return (m)
Expand Down

0 comments on commit 8cefba3

Please sign in to comment.