-
Notifications
You must be signed in to change notification settings - Fork 2
/
seq2attn_example.sh
executable file
·72 lines (68 loc) · 2.12 KB
/
seq2attn_example.sh
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
# DATA
# Lookup-tables compositionality task: one train/validation split plus seven
# held-out test splits that probe different kinds of generalization.
DATA_PATH="../machine-tasks/LookupTables/data/samples/sample1"
TRAIN_PATH="${DATA_PATH}/train.tsv"
VALIDATION_PATH="${DATA_PATH}/validation.tsv"
TEST1_PATH="${DATA_PATH}/heldout_compositions.tsv"
TEST2_PATH="${DATA_PATH}/heldout_inputs.tsv"
TEST3_PATH="${DATA_PATH}/heldout_tables.tsv"
TEST4_PATH="${DATA_PATH}/new_compositions.tsv"
TEST5_PATH="${DATA_PATH}/longer_compositions_incremental.tsv"
TEST6_PATH="${DATA_PATH}/longer_compositions_new.tsv"
TEST7_PATH="${DATA_PATH}/longer_compositions_seen.tsv"
# Space-separated list of datasets to monitor during training; it is expanded
# UNQUOTED on purpose where used, so each path becomes its own argument.
# Fix: the original also expanded the never-defined TEST8_PATH/TEST9_PATH,
# which added empty words and would abort the script under `set -u`.
MONITOR_DATA="${VALIDATION_PATH} ${TEST1_PATH} ${TEST2_PATH} ${TEST3_PATH} ${TEST4_PATH} ${TEST5_PATH} ${TEST6_PATH} ${TEST7_PATH}"
# TRAIN SETTINGS
EPOCHS=100                               # number of training epochs
BATCH_SIZE=1                             # training batch size
EVAL_BATCH_SIZE=2000                     # batch size for evaluation passes
TF=0.5                                   # teacher-forcing ratio
METRICS="seq_acc"                        # metric(s) reported during training
SAVE_EVERY=100                           # checkpoint interval
PRINT_EVERY=100                          # logging interval
CUDA=0                                   # CUDA device index
EXPT_DIR="seq2attn_lookup_checkpoints"   # where checkpoints are written
# MODEL PARAMETERS
DROPOUT=0.5                      # dropout prob., applied to encoder and decoder
ATTENTION="pre-rnn"              # attention computed before the decoder RNN step
ATTN_METHOD="mlp"                # MLP attention scoring
EMB_SIZE=256                     # embedding dimensionality
HIDDEN_SIZE=256                  # RNN hidden-state dimensionality
RNN_CELL="gru"
ATTN_VALS="embeddings"           # attention values taken from input embeddings
SAMPLE_TRAIN="gumbel_st"         # presumably straight-through Gumbel sampling at
                                 # train time — confirm against train_model.py
SAMPLE_INFER="argmax"            # deterministic attention at inference
INIT_TEMP=5                      # initial sampling temperature
LEARN_TEMP="conditioned"
FULL_ATTENTION_FOCUS="yes"
# Train the seq2attn model. Scalar expansions are quoted (ShellCheck SC2086);
# $MONITOR_DATA and $METRICS stay deliberately unquoted so they word-split
# into one argument per path / per metric name.
python train_model.py \
    --train "$TRAIN_PATH" \
    --dev "$VALIDATION_PATH" \
    --monitor $MONITOR_DATA \
    --metrics $METRICS \
    --output_dir "$EXPT_DIR" \
    --epochs "$EPOCHS" \
    --rnn_cell "$RNN_CELL" \
    --embedding_size "$EMB_SIZE" \
    --hidden_size "$HIDDEN_SIZE" \
    --dropout_p_encoder "$DROPOUT" \
    --dropout_p_decoder "$DROPOUT" \
    --teacher_forcing_ratio "$TF" \
    --attention "$ATTENTION" \
    --attention_method "$ATTN_METHOD" \
    --batch_size "$BATCH_SIZE" \
    --eval_batch_size "$EVAL_BATCH_SIZE" \
    --save_every "$SAVE_EVERY" \
    --print_every "$PRINT_EVERY" \
    --write-logs "${EXPT_DIR}_LOG" \
    --cuda_device "$CUDA" \
    --sample_train "$SAMPLE_TRAIN" \
    --sample_infer "$SAMPLE_INFER" \
    --initial_temperature "$INIT_TEMP" \
    --learn_temperature "$LEARN_TEMP" \
    --attn_vals "$ATTN_VALS" \
    --full_attention_focus "$FULL_ATTENTION_FOCUS"
# Evaluate a saved checkpoint on the validation set.
# `ls -t | head -2 | tail -1` selects the second-newest entry in the
# checkpoint directory — presumably to skip the very last checkpoint in
# favor of the previous one; confirm against train_model.py's save logic.
# NOTE(review): parsing `ls` output is fragile with unusual filenames; it is
# tolerable here because checkpoint names are program-generated.
# Fix: the original command ended with a dangling `\` continuation after the
# last option, which breaks the command if trailing whitespace or a new line
# of text ever follows it.
python evaluate.py \
    --checkpoint_path "$EXPT_DIR/$(ls -t "$EXPT_DIR"/ | head -2 | tail -1)" \
    --test_data "$VALIDATION_PATH" \
    --batch_size "$EVAL_BATCH_SIZE" \
    --attention pre-rnn \
    --attention_method mlp