Quickly fit neural fields to an entire dataset.
Creators: Samuele Papa, Riccardo Valperga, David Knigge, Phillip Lippe.
Official repository of both the fit-a-nef library and an example of how to use it effectively.
This code is the basis for the benchmark and study in: How to Train Neural Field Representations: A Comprehensive Study and Benchmark.
For the neural dataset collection and to use the neural fields as representations, see the neural-field-arena repository.
Using JAX's ability to easily parallelize operations on a GPU with vmap, a sizeable set of neural fields can be fit to distinct samples at the same time.
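The core idea, stripped down to plain JAX with a toy two-layer field (independent of the library's own API), looks roughly like this:

```python
import jax
import jax.numpy as jnp

# Toy setup: fit `num_fields` tiny sine-based fields, each to a different 1D signal.
num_fields, num_points = 8, 64
coords = jnp.linspace(-1.0, 1.0, num_points)[:, None]  # coordinates shared by all fields
targets = jax.random.normal(jax.random.PRNGKey(0), (num_fields, num_points, 1))  # one signal per field

def init_params(key, hidden=32):
    k1, k2 = jax.random.split(key)
    return {
        "w1": jax.random.normal(k1, (1, hidden)) * 0.5,
        "w2": jax.random.normal(k2, (hidden, 1)) * 0.1,
    }

def forward(params, x):
    return jnp.sin(x @ params["w1"]) @ params["w2"]

def loss_fn(params, x, y):
    return jnp.mean((forward(params, x) - y) ** 2)

# One set of parameters per field, stacked along a leading axis.
keys = jax.random.split(jax.random.PRNGKey(1), num_fields)
params = jax.vmap(init_params)(keys)

# vmap the gradient step over the parameter and target axes; coordinates are shared (in_axes=None).
@jax.jit
def step(params, x, y, lr=1e-2):
    grads = jax.vmap(jax.grad(loss_fn), in_axes=(0, None, 0))(params, x, y)
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

for _ in range(100):
    params = step(params, coords, targets)
```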
The fit-a-nef library is designed to let users easily add their own training task, dataset, and model. It provides a uniform format to store and load large amounts of neural fields in a platform-agnostic way. Whether you use PyTorch, JAX, or any other framework, the neural fields can be loaded and used in your project.
This repository also provides a simple interface that uses optuna to find the best parameters for any neural field while tracking all relevant metrics using wandb.
- Getting started with fit-a-nef
- The repository
- Citing
- Contributing
- Code of conduct
- License
- Acknowledgements and Contributions
For further information see the documentation.
To use the fit-a-nef library, simply clone the repository and run:
pip install .
This will install the fit-a-nef library and all its dependencies. To ensure that JAX and PyTorch are installed with the right CUDA/cuDNN version for your platform, we recommend installing them first (see the official JAX and PyTorch installation instructions), and then running the command above.
The basic usage of the library is to fit neural fields to a collection of signals. The current signals supported are images and shapes (through occupancy).
The library provides a SignalTrainer class, which supports fitting signals given coordinates. This library is agnostic to the type of signal being used. Additionally, it provides the basic infrastructure to store and load the neural fields in a platform-agnostic way.
Images and shapes can be fit using the SignalImageTrainer and SignalShapeTrainer classes, respectively. These classes are agnostic to the type of neural field being used. For example, SignalImageTrainer can be used to fit images with a SIREN or an RFFNet.
The trainer classes have a compile method, which can be used to trigger jit compilation, and a train_model method to fit the neural fields to the provided signals.
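As a rough illustration of that flow, the sketch below strings the two calls together. The import path and constructor arguments are hypothetical placeholders, not the library's documented signature, so treat it as pseudocode rather than a working snippet.

```python
# Hypothetical usage sketch: `compile` and `train_model` are the methods described
# above, but the import path and constructor arguments are illustrative guesses,
# not the library's actual signature.
from fit_a_nef import SignalImageTrainer  # assumed import path

trainer = SignalImageTrainer(
    signals=images,   # placeholder: array of images, one neural field fit to each
    coords=coords,    # placeholder: pixel coordinates shared by all images
)

trainer.compile()      # trigger jit compilation of the training step
trainer.train_model()  # fit the neural fields to the provided signals
```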
To handle initialization, the library provides InitModel classes, which can be used to initialize the weights of the neural fields. These classes are agnostic to the type of neural field being used. For example, SharedInit initializes all the neural fields with the same random weights, while RandomInit initializes each neural field with different random weights.
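The difference between the two is easiest to see with plain JAX PRNG keys. The sketch below mimics the behaviour with a stand-in parameter tensor; it is not the library's InitModel classes themselves.

```python
import jax
import jax.numpy as jnp

num_fields, in_dim, hidden = 4, 2, 16

def init_single(key):
    # Stand-in for one neural field's parameters: a single weight matrix.
    return jax.random.normal(key, (in_dim, hidden)) * 0.1

# SharedInit-style: every field starts from the same random weights.
shared_key = jax.random.PRNGKey(0)
shared_params = jnp.stack([init_single(shared_key)] * num_fields)

# RandomInit-style: every field starts from different random weights.
keys = jax.random.split(jax.random.PRNGKey(0), num_fields)
random_params = jax.vmap(init_single)(keys)
```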
Finally, the library has several neural field architectures already implemented. These can be found in the fit_a_nef.nef module.
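For reference, a SIREN layer is just a linear map followed by a scaled sine. The sketch below is a conceptual, plain-JAX version using the initialization bounds from the SIREN paper, not the implementation found in fit_a_nef.nef.

```python
import jax
import jax.numpy as jnp

def siren_layer(key, in_dim, out_dim, omega_0=30.0, is_first=False):
    # SIREN layer: a linear map followed by a sine activation scaled by omega_0.
    # The uniform initialization bounds follow Sitzmann et al. (2020).
    bound = 1.0 / in_dim if is_first else jnp.sqrt(6.0 / in_dim) / omega_0
    w = jax.random.uniform(key, (in_dim, out_dim), minval=-bound, maxval=bound)
    b = jnp.zeros(out_dim)
    return lambda x: jnp.sin(omega_0 * (x @ w + b))

# Example: map 2D pixel coordinates to a 64-dimensional hidden representation.
layer = siren_layer(jax.random.PRNGKey(0), in_dim=2, out_dim=64, is_first=True)
features = layer(jnp.zeros((128, 2)))  # shape (128, 64)
```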
For the full documentation, see here.
After fitting the neural fields, you can use them as representations for downstream tasks. For example, you can use them to classify images or shapes, or to reconstruct the input signal. For this, we recommend using the framework-independent datasets defined in the neural-field-arena repository.
The library provides trainers for fitting images and shapes. Additionally, it allows reliable storing of large-scale neural datasets and includes code for several neural field architectures.
However, to improve flexibility, the library does not ship with specific datasets or a defined config management system. This repository is meant to provide an example and template for how to correctly use the library, and to allow for easy extension to new datasets and tasks.
For some of these, additional dependencies are required, which can be found under Optional dependencies above and in the INSTALL.md file.
Depending on the use case you are aiming for with fit-a-nef, additional optional dependencies may be beneficial to install. For example, for tuning the hyperparameters of neural field fitting, we recommend installing optuna for automatic hyperparameter selection and wandb for better logging (see the sketch below). Further, if you want to use a specific dataset (e.g. ShapeNet or CIFAR10) for fitting or tuning, which is set up in this repository, ensure that all dependencies for these datasets are met. You can check the needed dependencies by running the simple tuning image and shape tasks (more info on how to do that below).
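As a rough starting point, a tuning loop with optuna and wandb could look like the sketch below; `fit_and_evaluate` is a hypothetical stand-in for whichever fitting routine you want to tune, not a function provided by fit-a-nef.

```python
import optuna
import wandb

def fit_and_evaluate(lr, hidden_dim):
    # Hypothetical placeholder: run the fitting task with the given
    # hyperparameters and return a scalar metric (e.g. final reconstruction loss).
    raise NotImplementedError

def objective(trial):
    lr = trial.suggest_float("lr", 1e-5, 1e-2, log=True)
    hidden_dim = trial.suggest_categorical("hidden_dim", [32, 64, 128])

    run = wandb.init(project="fit-a-nef-tuning", config={"lr": lr, "hidden_dim": hidden_dim})
    final_loss = fit_and_evaluate(lr=lr, hidden_dim=hidden_dim)
    run.log({"final_loss": final_loss})
    run.finish()
    return final_loss

study = optuna.create_study(direction="minimize")
study.optimize(objective, n_trials=50)
```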
The repository is structured as follows:
- ./config. Configuration files for the tasks.
- ./fit_a_nef. Library for quickly fitting and storing neural fields. Here you can add the trainer for your own task and your own NeF models.
- ./dataset. Package to load the targets and inputs used during training. Here you can add your own dataset.
- ./tasks. Collection of the tasks that we want to carry out. Both fitting and downstream tasks fit here. Here you can add the fitting scripts for your own tasks.
- ./tests. Tests for the code in the repository.
- ./assets. Contains the images used in this README.
The basic usage of this repository is to fit neural fields to a collection of signals. The current signals supported are images and shapes (through occupancy).
Each task has its own fit.py file, which is called to fit the neural fields to the provided signals. The fit.py file is optimized to provide maximum speed when fitting; therefore, all logging options have been removed.
Let us look at a simple example. From the root folder we can run:
python tasks/image/fit.py --nef=config/nef.py:SIREN --task.train.end_idx=10000 --task.train.num_parallel_nefs=2000
This will fit 10k SIRENs, each to a different sample from CIFAR10 (the default dataset), with 2k NeFs trained in parallel at a time.
For more details, refer to the how-to guide.
If you use this repository in your research, please cite it using the following BibTeX entry:
@misc{papa2023train,
    title={How to Train Neural Field Representations: A Comprehensive Study and Benchmark},
    author={Samuele Papa and Riccardo Valperga and David Knigge and Miltiadis Kofinas and Phillip Lippe and Jan-Jakob Sonke and Efstratios Gavves},
    year={2023},
    eprint={2312.10531},
    archivePrefix={arXiv},
    primaryClass={cs.CV}
}
Please help us improve this repository by providing your own suggestions or bug reports through this repository's GitHub issues system.
Before committing, please ensure you have pre-commit installed. This will ensure that the code is formatted correctly and that the tests pass. To install it, run:
pip install pre-commit
pre-commit install
Please note that this project has a Code of Conduct. By participating in this project, you agree to abide by its terms.
Distributed under the MIT License. See LICENSE for more information.
We thank Miltiadis Kofinas and David Romero for their feedback during development.