
Commit 144fc6e

Added support to train only classifier layer (transfer learning scenario)
gawy committed Jun 15, 2020
1 parent cb2ba2d commit 144fc6e
Showing 2 changed files with 20 additions and 6 deletions.
10 changes: 9 additions & 1 deletion stanza/models/ner/trainer.py
@@ -32,7 +32,8 @@ def unpack_batch(batch, use_cuda):
 
 class Trainer(BaseTrainer):
     """ A trainer for training models. """
-    def __init__(self, args=None, vocab=None, pretrain_emb_matrix=None, model_file=None, use_cuda=False):
+    def __init__(self, args=None, vocab=None, pretrain_emb_matrix=None, model_file=None, use_cuda=False,
+                 train_classifier_only=False):
         self.use_cuda = use_cuda
         if model_file is not None:
             # load everything from file
@@ -43,6 +44,13 @@ def __init__(self, args=None, vocab=None, pretrain_emb_matrix=None, model_file=None,
             self.args = args
             self.vocab = vocab
             self.model = NERTagger(args, vocab, emb_matrix=pretrain_emb_matrix)
+
+            if train_classifier_only:
+                logger.info('Disabling gradient for non-classifier layers')
+                exclude = ['tag_clf', 'crit']
+                for pname, p in self.model.named_parameters():
+                    if pname.split('.')[0] not in exclude:
+                        p.requires_grad = False
         self.parameters = [p for p in self.model.parameters() if p.requires_grad]
         if self.use_cuda:
             self.model.cuda()
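The freeze above follows the standard PyTorch transfer-learning recipe: set requires_grad to False on every parameter whose top-level module is outside the classifier head, then hand the optimizer only the parameters that still require gradients. A minimal, self-contained sketch of the same idea (the toy model and layer sizes below are illustrative, not Stanza's NERTagger):

import torch
from torch import nn, optim

class ToyTagger(nn.Module):
    # Toy tagger whose classifier head is named 'tag_clf', mirroring the exclude check above.
    def __init__(self, emb_dim=50, hidden=32, num_tags=9):
        super().__init__()
        self.encoder = nn.Linear(emb_dim, hidden)
        self.tag_clf = nn.Linear(hidden, num_tags)

    def forward(self, x):
        return self.tag_clf(torch.relu(self.encoder(x)))

model = ToyTagger()

# Freeze every parameter whose top-level module is not in the exclude set,
# the same pname.split('.')[0] test used in the commit.
exclude = {'tag_clf'}
for name, param in model.named_parameters():
    if name.split('.')[0] not in exclude:
        param.requires_grad = False

# The optimizer only ever sees the still-trainable parameters.
trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.SGD(trainable, lr=0.1)
print([name for name, p in model.named_parameters() if p.requires_grad])
# -> ['tag_clf.weight', 'tag_clf.bias']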
16 changes: 11 additions & 5 deletions stanza/models/ner_tagger.py
@@ -16,6 +16,7 @@
 import random
 import json
 import torch
+from stanza.models.ner.vocab import TagVocab
 from torch import nn, optim
 
 from stanza.models.ner.data import DataLoader
@@ -38,7 +39,9 @@ def parse_args():
     parser.add_argument('--eval_file', type=str, default=None, help='Input file for data loader.')
 
     parser.add_argument('--mode', default='train', choices=['train', 'predict'])
-    parser.add_argument('--finetune', action='store_false', help='Load existing model during `train` mode from `save_dir` path')
+    parser.add_argument('--finetune', action='store_true', help='Load existing model during `train` mode from `save_dir` path')
+    parser.add_argument('--train_classifier_only', action='store_true',
+                        help='In case of applying Transfer-learning approach and training only the classifier layer this will freeze gradient propagation for all other layers.')
     parser.add_argument('--lang', type=str, help='Language')
     parser.add_argument('--shorthand', type=str, help="Treebank shorthand")
 
@@ -112,9 +115,11 @@ def train(args):
         else '{}/{}_nertagger.pt'.format(args['save_dir'], args['shorthand'])
 
     pretrain_vocab = None
+    vocab = None
+    trainer = None
     if args['finetune'] and os.path.exists(model_file):
         logger.warning('Finetune is ON. Using model from "{}"'.format(model_file))
-        _, trainer, pretrain_vocab = load_model(args, model_file)
+        _, trainer, vocab = load_model(args, model_file)
     else:
         if args['finetune']:
             logger.warning('Finetune is set to true but model file is not found. Continuing with training from scratch.')
@@ -139,7 +144,7 @@
     # load data
     logger.info("Loading data with batch size {}...".format(args['batch_size']))
     train_doc = Document(json.load(open(args['train_file'])))
-    train_batch = DataLoader(train_doc, args['batch_size'], args, pretrain_vocab, evaluation=False)
+    train_batch = DataLoader(train_doc, args['batch_size'], args, pretrain_vocab, vocab=vocab, evaluation=False)
     vocab = train_batch.vocab
     dev_doc = Document(json.load(open(args['eval_file'])))
     dev_batch = DataLoader(dev_doc, args['batch_size'], args, pretrain_vocab, vocab=vocab, evaluation=True)
@@ -152,7 +157,8 @@
 
     logger.info("Training tagger...")
     if trainer is None: # init if model was not loaded previously from file
-        trainer = Trainer(args=args, vocab=vocab, pretrain_emb_matrix=pretrain.emb, use_cuda=args['cuda'])
+        trainer = Trainer(args=args, vocab=vocab, pretrain_emb_matrix=pretrain.emb, use_cuda=args['cuda'],
+                          train_classifier_only=args['train_classifier_only'])
     logger.info(trainer.model)
 
     global_step = 0
@@ -253,7 +259,7 @@ def evaluate(args):
 def load_model(args, model_file):
     # load model
     use_cuda = args['cuda'] and not args['cpu']
-    trainer = Trainer(model_file=model_file, use_cuda=use_cuda)
+    trainer = Trainer(model_file=model_file, use_cuda=use_cuda, train_classifier_only=args['train_classifier_only'])
     loaded_args, vocab = trainer.args, trainer.vocab
 
     # load config
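With both files changed, a typical transfer-learning run would pass --finetune (to load an existing tagger from save_dir) together with the new --train_classifier_only flag. A quick sanity check that the freeze took effect could look like the sketch below; note that args, vocab and pretrain are assumed to have been built by the surrounding training script, exactly as in train() above:

from stanza.models.ner.trainer import Trainer

# Sketch only: args, vocab and pretrain come from the training script, as in train().
trainer = Trainer(args=args, vocab=vocab, pretrain_emb_matrix=pretrain.emb,
                  use_cuda=args['cuda'], train_classifier_only=True)

# Only the classifier head ('tag_clf') and the loss module ('crit') should remain trainable.
trainable = [name for name, p in trainer.model.named_parameters() if p.requires_grad]
assert all(name.split('.')[0] in ('tag_clf', 'crit') for name in trainable)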
