-
Notifications
You must be signed in to change notification settings - Fork 1
/
utils.py
31 lines (25 loc) · 998 Bytes
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
from vncorenlp import VnCoreNLP
from nltk.tokenize import TweetTokenizer
from pandas import DataFrame
import re
import torch
import json
import os
import numpy as np
def seed_everything(seed):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: Integer seed handed unchanged to every generator.

    Side effects:
        Sets ``torch.backends.cudnn.deterministic = True`` for the process.
    """
    # Function-scope import: the visible import block does not bring in
    # the stdlib `random` module.
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # manual_seed_all seeds every CUDA device rather than only the current one.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
def sigmoid(x):
    """Return the logistic function ``1 / (1 + exp(-x))`` of ``x``.

    Args:
        x: A scalar or any object supporting unary minus and ``np.exp``
            (e.g. a NumPy array).

    Returns:
        The element-wise result of ``1 / (1 + np.exp(-x))``.
    """
    # For very negative x, exp(-x) overflows to inf and the quotient becomes
    # 0.0, which is the saturated value we want. Silence only that overflow
    # so callers do not get RuntimeWarning noise (or a FloatingPointError
    # when they run with over='raise').
    with np.errstate(over='ignore'):
        return 1 / (1 + np.exp(-x))
def save_checkpoint(model, tokenizer, checkpoint_path, epoch='best'):
    """Write model weights, model config and tokenizer vocabulary to a directory.

    Args:
        model: Object exposing ``state_dict()`` and
            ``config.to_json_file(path)``.
        tokenizer: Object exposing ``save_vocabulary(directory)``.
        checkpoint_path: Target directory; created (including parents) when
            it does not exist yet.
        epoch: Tag interpolated into the weights file name
            ``model_{epoch}.bin``.
    """
    # torch.save does not create missing directories, so make sure the
    # target exists before writing anything into it.
    os.makedirs(checkpoint_path, exist_ok=True)

    weights_file = os.path.join(checkpoint_path, f'model_{epoch}.bin')
    torch.save(model.state_dict(), weights_file)

    model.config.to_json_file(os.path.join(checkpoint_path, 'config.json'))
    tokenizer.save_vocabulary(checkpoint_path)
def convert_tokens_to_ids(text, tokenizer, max_len=256):
    """Encode ``text`` with ``tokenizer`` and return ids and mask as tensors.

    Args:
        text: Input forwarded as the first argument of
            ``tokenizer.encode_plus``.
        tokenizer: Object exposing ``encode_plus``; it is called with
            ``padding="max_length"``, ``max_length=max_len`` and
            ``truncation=True``.
        max_len: Value passed as ``max_length`` to the tokenizer.

    Returns:
        A tuple ``(input_ids, attention_mask)`` of ``torch.long`` tensors
        built from the ``"input_ids"`` and ``"attention_mask"`` entries of
        the encoding (presumably each of length ``max_len`` -- depends on
        the tokenizer honouring the padding/truncation request).
    """
    encoded = tokenizer.encode_plus(
        text,
        padding="max_length",
        max_length=max_len,
        truncation=True,
    )
    ids_tensor, mask_tensor = (
        torch.tensor(encoded[field], dtype=torch.long)
        for field in ("input_ids", "attention_mask")
    )
    return ids_tensor, mask_tensor