🌟 This is the code repo for the experiments in *Language Models Need Inductive Biases to Count Inductively* 🌟
In `/scripts`, we maintain separate folders for different architecture types. Note that LSTM and RNN are subsumed under `/scripts/s4`.
To support reproducibility for individual sets of experiments, `mamba` and `rwkv` have their own environments, while `causal_transformer` and `s4` share one. Thus, we provide instructions for building three environments.
Here's how to set up the shared environment for `causal_transformer` and `s4`:
```bash
cd <path_to_this_repo> &&
python3 -m venv venv &&
source venv/bin/activate &&
pip install -r requirements.txt &&
cd scripts/s4 &&
pip install -r s4_requirements.txt
```
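To quickly confirm the environment is usable, a minimal import check like the one below can help. This is a sketch that assumes `torch` is among the pinned dependencies in `requirements.txt` (we have not listed them here):

```bash
# Minimal sanity check (assumes torch is installed by requirements.txt):
# prints the torch version and whether a CUDA device is visible.
python -c "import torch; print(torch.__version__, torch.cuda.is_available())"
```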
Please see the linked instructions for building the `mamba` and `rwkv` environments.
For examples of the input-output formats, see the validation and OOD testing files provided for each task. Our training data is generated in this notebook.
If this is the first time you are using `accelerate` and you haven't configured it yet, please run `accelerate config` and answer the prompts accordingly. The configuration we use looks like this:
```yaml
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: MULTI_GPU
downcast_bf16: 'no'
gpu_ids: all
machine_rank: 0
main_training_function: main
mixed_precision: fp16
num_machines: 1
num_processes: 4
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
```
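To double-check what `accelerate` will actually use, you can print its view of the environment. `accelerate env` is a standard subcommand; the config path below is accelerate's default save location and may differ if you've customized your cache directories:

```bash
# Show accelerate's environment report, including the active default config
accelerate env
# The interactive `accelerate config` normally saves its answers here
cat ~/.cache/huggingface/accelerate/default_config.yaml
```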
Remember to specify `output_dir`, `ckpt_dir`, and `hf_cache_dir` in `config.py`.
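We do not reproduce the repo's actual `config.py` here; as a rough sketch, the entries to fill in might look like the following (paths are placeholders, and the real file may structure them differently):

```python
# Sketch only -- the actual config.py in this repo may differ in structure.
output_dir = "/path/to/outputs"        # validation outputs; each run gets a timestamped subfolder
ckpt_dir = "/path/to/checkpoints"      # where model ckpts are saved
hf_cache_dir = "/path/to/hf_cache"     # cache directory for huggingface assets
```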
Training command:
```bash
cd scripts/causal_transformer &&  # or cd scripts/s4
python run.py --task <task_name> --cuda 0 --port 29500
```
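For instance, substituting the example task name from the notes below:

```bash
cd scripts/causal_transformer
python run.py --task counting_samesymbol_mod10bos --cuda 0 --port 29500
```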
Notes:
- `<task_name>` can be chosen from `scripts/causal_transformer/config_taskspecific.py`, e.g. `counting_samesymbol_mod10bos`.
- Model ckpts will be saved to the `ckpt_dir` specified in `config.py`. Model outputs during validation will be saved to `output_dir`. Specifically, each run creates its own folder under `output_dir`, named by the timestamp, which can be passed to `tester.py` through the `--handle` argument.
- If you're running multiple jobs on the same machine, use a different port for each; otherwise, `accelerate` will complain about a busy port (see the example below this list).
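For example, two concurrent jobs could be launched as follows (the second task name and GPU index are illustrative placeholders):

```bash
# Job 1 on GPU 0, default port
python run.py --task counting_samesymbol_mod10bos --cuda 0 --port 29500
# Job 2 on GPU 1, with a different port to avoid the rendezvous clash
python run.py --task <another_task_name> --cuda 1 --port 29501
```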
Testing command:
```bash
python tester.py --handle <timestamp>
```
E.g., timestamp = `0522_103640`.
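Concretely, with the example timestamp above (run from the same folder as training, since `tester.py` is assumed to live alongside `run.py`):

```bash
python tester.py --handle 0522_103640
```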
Citation:
```bibtex
@article{chang2024language,
  title={Language Models Need Inductive Biases to Count Inductively},
  author={Chang, Yingshan and Bisk, Yonatan},
  journal={arXiv preprint arXiv:2405.20131},
  year={2024}
}
```
- The implementation of the causal Transformer, together with its positional embedding variants, borrows heavily from Hugging Face's implementations of GPT-2, T5, and LLaMA.
- We give credit to the official S4 repo for the implementation of S4.
- We give credit to the official RWKV repo for the implementation of RWKV.
- We give credit to the official Mamba repo for the implementation of Mamba, and to the mamba-chat repo for setting up the mamba environment.