From 144fc6e576cf59ab82f72313ef80d49a8d6cf089 Mon Sep 17 00:00:00 2001
From: Andrii Garkavyi
Date: Fri, 5 Jun 2020 15:55:41 +0300
Subject: [PATCH] Added support to train only classifier layer (transfer learning scenario)

---
 stanza/models/ner/trainer.py | 10 +++++++++-
 stanza/models/ner_tagger.py  | 16 +++++++++++-----
 2 files changed, 20 insertions(+), 6 deletions(-)

diff --git a/stanza/models/ner/trainer.py b/stanza/models/ner/trainer.py
index 6504483dd4..920acaa5c8 100644
--- a/stanza/models/ner/trainer.py
+++ b/stanza/models/ner/trainer.py
@@ -32,7 +32,8 @@ def unpack_batch(batch, use_cuda):
 
 class Trainer(BaseTrainer):
     """ A trainer for training models. """
-    def __init__(self, args=None, vocab=None, pretrain_emb_matrix=None, model_file=None, use_cuda=False):
+    def __init__(self, args=None, vocab=None, pretrain_emb_matrix=None, model_file=None, use_cuda=False,
+                 train_classifier_only=False):
         self.use_cuda = use_cuda
         if model_file is not None:
             # load everything from file
@@ -43,6 +44,13 @@ def __init__(self, args=None, vocab=None, pretrain_emb_matrix=None, model_file=N
             self.args = args
             self.vocab = vocab
             self.model = NERTagger(args, vocab, emb_matrix=pretrain_emb_matrix)
+
+        if train_classifier_only:
+            logger.info('Disabling gradient for non-classifier layers')
+            exclude = ['tag_clf', 'crit']
+            for pname, p in self.model.named_parameters():
+                if pname.split('.')[0] not in exclude:
+                    p.requires_grad = False
         self.parameters = [p for p in self.model.parameters() if p.requires_grad]
         if self.use_cuda:
             self.model.cuda()
diff --git a/stanza/models/ner_tagger.py b/stanza/models/ner_tagger.py
index 3b64f9182f..5b4bccc72b 100644
--- a/stanza/models/ner_tagger.py
+++ b/stanza/models/ner_tagger.py
@@ -16,6 +16,7 @@
 import random
 import json
 import torch
+from stanza.models.ner.vocab import TagVocab
 from torch import nn, optim
 
 from stanza.models.ner.data import DataLoader
@@ -38,7 +39,9 @@ def parse_args():
     parser.add_argument('--eval_file', type=str, default=None, help='Input file for data loader.')
 
     parser.add_argument('--mode', default='train', choices=['train', 'predict'])
-    parser.add_argument('--finetune', action='store_false', help='Load existing model during `train` mode from `save_dir` path')
+    parser.add_argument('--finetune', action='store_true', help='Load existing model during `train` mode from `save_dir` path')
+    parser.add_argument('--train_classifier_only', action='store_true',
+                        help='In case of applying Transfer-learning approach and training only the classifier layer this will freeze gradient propagation for all other layers.')
     parser.add_argument('--lang', type=str, help='Language')
     parser.add_argument('--shorthand', type=str, help="Treebank shorthand")
@@ -112,9 +115,11 @@ def train(args):
         else '{}/{}_nertagger.pt'.format(args['save_dir'], args['shorthand'])
 
     pretrain_vocab = None
+    vocab = None
+    trainer = None
     if args['finetune'] and os.path.exists(model_file):
         logger.warning('Finetune is ON. Using model from "{}"'.format(model_file))
-        _, trainer, pretrain_vocab = load_model(args, model_file)
+        _, trainer, vocab = load_model(args, model_file)
     else:
         if args['finetune']:
             logger.warning('Finetune is set to true but model file is not found. Continuing with training from scratch.')
@@ -139,7 +144,7 @@ def train(args):
     # load data
     logger.info("Loading data with batch size {}...".format(args['batch_size']))
     train_doc = Document(json.load(open(args['train_file'])))
-    train_batch = DataLoader(train_doc, args['batch_size'], args, pretrain_vocab, evaluation=False)
+    train_batch = DataLoader(train_doc, args['batch_size'], args, pretrain_vocab, vocab=vocab, evaluation=False)
     vocab = train_batch.vocab
     dev_doc = Document(json.load(open(args['eval_file'])))
     dev_batch = DataLoader(dev_doc, args['batch_size'], args, pretrain_vocab, vocab=vocab, evaluation=True)
@@ -152,7 +157,8 @@ def train(args):
     logger.info("Training tagger...")
     if trainer is None:  # init if model was not loaded previously from file
-        trainer = Trainer(args=args, vocab=vocab, pretrain_emb_matrix=pretrain.emb, use_cuda=args['cuda'])
+        trainer = Trainer(args=args, vocab=vocab, pretrain_emb_matrix=pretrain.emb, use_cuda=args['cuda'],
+                          train_classifier_only=args['train_classifier_only'])
     logger.info(trainer.model)
 
     global_step = 0
@@ -253,7 +259,7 @@ def evaluate(args):
 def load_model(args, model_file):
     # load model
     use_cuda = args['cuda'] and not args['cpu']
-    trainer = Trainer(model_file=model_file, use_cuda=use_cuda)
+    trainer = Trainer(model_file=model_file, use_cuda=use_cuda, train_classifier_only=args['train_classifier_only'])
     loaded_args, vocab = trainer.args, trainer.vocab
 
     # load config
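
Editor's note (not part of the patch): the core of the change is the parameter-freezing loop
added to Trainer.__init__. The self-contained sketch below reproduces that pattern on a toy
module so the effect is easy to verify; TinyTagger and its word_emb/lstm submodules are
hypothetical stand-ins, while the 'tag_clf' and 'crit' names in the exclude list are the ones
used by the patch.

    from torch import nn

    class TinyTagger(nn.Module):
        """Toy stand-in for NERTagger: an encoder plus a classifier head."""
        def __init__(self):
            super().__init__()
            self.word_emb = nn.Embedding(100, 16)  # stand-in for embedding/encoder layers
            self.lstm = nn.LSTM(16, 32)            # stand-in for the sequence encoder
            self.tag_clf = nn.Linear(32, 9)        # classifier head, kept trainable
            self.crit = nn.CrossEntropyLoss()      # loss module (holds no parameters)

    model = TinyTagger()
    exclude = ['tag_clf', 'crit']                  # same exclude list as the patch
    for pname, p in model.named_parameters():
        if pname.split('.')[0] not in exclude:     # freeze everything outside the classifier
            p.requires_grad = False

    # Mirrors `self.parameters = [p for p in self.model.parameters() if p.requires_grad]`:
    # only the classifier head would reach the optimizer.
    print([n for n, p in model.named_parameters() if p.requires_grad])
    # -> ['tag_clf.weight', 'tag_clf.bias']

With the patch applied, the intended use is presumably the usual ner_tagger training command
plus the two new flags, e.g. `python -m stanza.models.ner_tagger --mode train --finetune
--train_classifier_only ...` with the remaining data and pretrain arguments as in a normal
run, so that a previously saved model is loaded and only its classifier layer is updated.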