Code for "Accelerating Training with Neuron Interaction and Nowcasting Networks"

Accelerating Training with Neuron Interaction and Nowcasting Networks
Boris Knyazev, Abhinav Moudgil, Guillaume Lajoie, Eugene Belilovsky, Simon Lacoste-Julien

arXiv: https://arxiv.org/abs/2409.04434

Intro

Neuron interaction and Nowcasting (NiNo) model

We introduce the NiNo model, which predicts future parameters (nowcasting) by learning neuron interactions in vision and language tasks. NiNo takes c past parameter states as input (c=5 by default) and nowcasts K future states, leveraging the neural graph structure of the model with graph neural networks. For a new optimization task, NiNo is applied only rarely: once per 1k steps of Adam (or another base optimizer).
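
To give some intuition for the nowcasting idea (and only the idea, not the actual NiNo architecture), the toy sketch below extrapolates future parameter states from c past states with a simple per-parameter linear fit; NiNo replaces this naive extrapolation with a graph neural network over the model's neural graph. The function naive_nowcast is a hypothetical illustration, not part of this repository.

import torch

# Toy illustration of nowcasting only (NOT the NiNo model):
# given c past parameter states, extrapolate k future states with a
# per-parameter linear fit over the step axis.
def naive_nowcast(past_states: torch.Tensor, k: int) -> torch.Tensor:
    # past_states: (c, num_params) flattened parameter snapshots
    c = past_states.shape[0]
    t = torch.arange(c, dtype=past_states.dtype)
    t_mean, x_mean = t.mean(), past_states.mean(dim=0)
    slope = ((t - t_mean)[:, None] * (past_states - x_mean)).sum(0) / ((t - t_mean) ** 2).sum()
    future_t = torch.arange(c, c + k, dtype=past_states.dtype)
    return x_mean + slope * (future_t - t_mean)[:, None]  # (k, num_params)

past = torch.randn(5, 10).cumsum(dim=0)   # c=5 fake parameter snapshots
print(naive_nowcast(past, k=3).shape)     # torch.Size([3, 10])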

Using NiNo with Adam

Adam without and with nowcasting using our NiNo model on a language task that NiNo has not seen during its training.

Requirements

The experiments from our paper can be run using a single GPU with <= 80GB of memory.

  • python >= 3.8
  • pytorch >= 2.0
  • torch_geometric
  • transformers
  • datasets
  • other optional dependencies (networkx, pydot)
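
A quick way to verify that the main dependencies above are installed and importable (a minimal sanity-check sketch, not an official setup script):

# minimal sanity check for the dependencies listed above
import torch, torch_geometric, transformers, datasets
print('torch', torch.__version__, '| cuda available:', torch.cuda.is_available())
print('torch_geometric', torch_geometric.__version__)
print('transformers', transformers.__version__, '| datasets', datasets.__version__)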

Updates

  • Initial code release with a pretrained NiNo model (see the checkpoints folder).
    • nino.pt - default NiNo model (assumes the GPT2 tokenizer)
    • nino_no_posw.pt - NiNo without positional encoding for word embeddings (can be used for arbitrary models and tokenizers, including Llama)
    • nino_h32.pt - NiNo with hidden size 32 instead of default 128
    • nino_mlp.pt - WNN+ model (does not use graphs)
    • nino_towers4.pt - NiNo with 4 towers in the message passing step for better efficiency
  • Neural graphs and evaluation script for convnet tasks.
  • Neural graphs and evaluation script for transformer tasks:
    • GPT2
    • BERT (experimental code)
    • Llama (experimental code, see a graph for a smaller variant of meta-llama/Meta-Llama-3.1-8B in the results folder)
    • Vision Transformer (experimental code)
  • Training dataset and training code for NiNo.

Pretrained NiNo models

We provide the checkpoint for our best-performing NiNo model at checkpoints/nino.pt.
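
To use one of the other checkpoints, pass its path via the ckpt argument of the NiNo wrapper (the full training loop is shown in the Usage section below). A minimal sketch, assuming model is defined as in that example:

import torch
from optim import NiNo

# sketch: swap in another pretrained checkpoint, e.g. the tokenizer-agnostic one;
# `model` is assumed to be defined as in the Usage example below
opt = NiNo(base_opt=torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2),
           ckpt='checkpoints/nino_no_posw.pt',  # instead of the default checkpoints/nino.pt
           model=model,
           period=1000,
           max_steps=10000)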

Usage

Example

Training loop with NiNo for a language model:

import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM
from optim import NiNo

model = AutoModelForCausalLM.from_config(...)  # some model

# NiNo is implemented as a wrapper around the base optimizer;
# base optimizers other than Adam should also be possible to use with NiNo
opt = NiNo(base_opt=torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2),
           ckpt='checkpoints/nino.pt',
           message_passing_device=None,  # can use 'cpu' when NiNo is applied to larger models 
           model=model,
           period=1000,
           max_steps=10000)
for step in range(10000):
    if opt.need_grads:  # True/False based on the step number and period
        opt.zero_grad()  # zero out gradients
        data, targets = ...  # get some batch of data
        # base optimizer step (majority of the time)
        outputs = model(data)  # forward pass
        loss = F.cross_entropy(outputs, targets)  # compute some loss
        loss.backward()  # only compute gradients for the base optimizer            
    opt.step()  # base_opt step or nowcast params every 1000 steps using NiNo    
    ...
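
As noted in the comment above, base optimizers other than Adam should also be possible; a hedged sketch of the same wrapper with SGD as the base optimizer (only base_opt is swapped, all NiNo arguments stay the same):

# sketch: same NiNo wrapper as above, but with SGD as the base optimizer
opt = NiNo(base_opt=torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9),
           ckpt='checkpoints/nino.pt',
           model=model,
           period=1000,
           max_steps=10000)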

Reproducing the results from our paper

Optimization in vision tasks

Evaluate on all vision tasks:

for task in FM-16 C10-16 FM-32 C10-32 C100-32; 
do for seed in $(seq 1000 1000 10000); 
do python train_vision.py --task $task --seed $seed | tee -a results.log; done; done

To evaluate without the NiNo model, run with --nino_ckpt none. You should get results similar to Tables 1 and 2 in the paper.

Use --verbose 2 for graph visualization and more detailed output.

Optimization in language tasks

Single seed training on the Wiki/3-64 task:

python train_lm.py --dataset_name wikitext --dataset_config_name wikitext-103-raw-v1 --num_train_epochs 4 --layers 3 --dim 64 --heads 4

For LM1B tasks, use --dataset_name lm1b --dataset_config_name plain_text.

Contributing

Pull requests and GitHub issues are welcome. For major changes, please open an issue first to discuss what you would like to change.

LICENSE

MIT, see the LICENSE file.

Citation

@misc{knyazev2024accelerating,
  title={Accelerating Training with Neuron Interaction and Nowcasting Networks}, 
  author={Boris Knyazev and Abhinav Moudgil and Guillaume Lajoie and Eugene Belilovsky and Simon Lacoste-Julien},
  year={2024},
  eprint={2409.04434},
  archivePrefix={arXiv},
  url={https://arxiv.org/abs/2409.04434}, 
}
