A customizable 1D/2D U-Net model for libtorch (PyTorch C++ UNet)

divideconcept/PyTorch-libtorch-U-Net

PyTorch/libtorch Customizable 1D/2D U-Net

A customizable 1D/2D U-Net model for libtorch (PyTorch C++ UNet)
Robin Lobel, March 2020 - Requires libtorch 1.4.0 or higher. CPU & CUDA compatible. Qt compatible.

The default parameters produce the original 2D UNet ( https://arxiv.org/pdf/1505.04597.pdf ) with all core improvements activated, resulting in a fully convolutional 2D network.
The default parameters for the 1D UNet are inspired by the Wave-U-Net ( https://arxiv.org/pdf/1806.03185.pdf ) with all core improvements activated, resulting in a fully convolutional 1D network.

You can customize the number of in/out channels, the number of hidden feature channels, the number of levels, the size of the kernel, and activate improvements such as:

You can additionally display the size of all internal layers the first time you use the network.

How to choose the parameters

  • The number of input channels is the number of useful values you can feed the model for each pixel (for instance, 3 channels (RGB) for a picture, or 2 channels (real/imaginary) for a spectrogram).
  • The number of output channels is the number of values you want out for each pixel. It can match the input if you want a filtered picture or spectrogram back, for instance, but it can also carry any other kind of information (classification masks...).
  • The number of hidden feature channels can only be determined by experimenting, which is why I recommend tweaking that parameter last. Start with a low number of feature channels (8, for instance) so that training goes fast, then double it until the output quality stops improving (check the loss value and visualize the results).
  • The number of levels can be determined by opening your input samples in a viewer and downscaling them by a factor of 2 repeatedly until you can no longer discriminate any useful feature. The number of downscales corresponds to the number of useful levels for the model.
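
To make the halving rule for the number of levels concrete, here is a minimal sketch in plain C++ (no libtorch needed). It assumes each level halves the resolution, as in the rule of thumb above. The helper names `levelSizes` and `maxCleanHalvings` are hypothetical and not part of cunet.h, and whether the model actually requires input sizes that halve without a remainder is not stated here.

```cpp
#include <cassert>
#include <vector>

// Sizes seen at each level when the input is halved repeatedly.
// Assumption: each level halves the resolution (integer division).
// Returns levels + 1 entries: the input size followed by one size per level.
std::vector<int> levelSizes(int inputSize, int levels)
{
    assert(inputSize > 0 && levels >= 0);
    std::vector<int> sizes;
    int size = inputSize;
    for (int level = 0; level <= levels; level++)
    {
        sizes.push_back(size);
        size /= 2;
    }
    return sizes;
}

// Number of times inputSize can be halved without leaving a remainder.
int maxCleanHalvings(int inputSize)
{
    assert(inputSize > 0);
    int count = 0;
    while (inputSize > 1 && inputSize % 2 == 0)
    {
        inputSize /= 2;
        count++;
    }
    return count;
}

// Example: levelSizes(512, 4) yields 512, 256, 128, 64, 32.
// Example: maxCleanHalvings(2048) is 11, maxCleanHalvings(500) is 2.
```

Printing `levelSizes` for your own input size gives the resolutions to inspect in a viewer. The level at which features stop being recognizable is a reasonable upper bound for the number of levels.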

Usage (2D UNet)

#include "cunet.h"

int main(int argc, char *argv[])
{
    int batchSize=16;
    int inChannels=3, outChannels=3;
    int height=512, width=512;
    
    // Default parameters: 2D UNet with all core improvements activated
    CUNet2d model(inChannels,outChannels);
    torch::optim::Adam optim(model->parameters(), torch::optim::AdamOptions(1e-3));
    
    // Random placeholder data, shaped {batch, channels, height, width}
    torch::Tensor source=torch::randn({batchSize,inChannels,height,width});
    torch::Tensor target=torch::randn({batchSize,outChannels,height,width});
    torch::Tensor result, loss;
    
    // Training: fit the model output to the target with an MSE loss
    model->train();
    for (int epoch = 0; epoch < 100; epoch++)
    {
        optim.zero_grad();
        result = model(source);
        loss = torch::mse_loss(result, target);
        loss.backward();
        optim.step();
    }
    
    // Inference: evaluation mode, with gradient tracking disabled
    model->eval();
    torch::NoGradGuard noGrad;
    torch::Tensor validation=torch::randn({batchSize,inChannels,height,width});
    torch::Tensor inference = model(validation);
    
    return 0;
}

Usage (1D UNet)

#include "cunet.h"

int main(int argc, char *argv[])
{
    int batchSize=16;
    int inChannels=2, outChannels=2;
    int size=2048;
    
    // Default parameters: 1D UNet with all core improvements activated
    CUNet1d model(inChannels,outChannels);
    torch::optim::Adam optim(model->parameters(), torch::optim::AdamOptions(1e-3));
    
    // Random placeholder data, shaped {batch, channels, size}
    torch::Tensor source=torch::randn({batchSize,inChannels,size});
    torch::Tensor target=torch::randn({batchSize,outChannels,size});
    torch::Tensor result, loss;
    
    // Training: fit the model output to the target with an MSE loss
    model->train();
    for (int epoch = 0; epoch < 100; epoch++)
    {
        optim.zero_grad();
        result = model(source);
        loss = torch::mse_loss(result, target);
        loss.backward();
        optim.step();
    }
    
    // Inference: evaluation mode, with gradient tracking disabled
    model->eval();
    torch::NoGradGuard noGrad;
    torch::Tensor validation=torch::randn({batchSize,inChannels,size});
    torch::Tensor inference = model(validation);
    
    return 0;
}
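
Either training loop above can be wrapped in the channel-doubling search suggested under "How to choose the parameters". The sketch below is one possible way to do it, not part of cunet.h. `chooseFeatureChannels` and the `trainAndGetLoss` callback are hypothetical names, and the `minGain` threshold and `maxChannels` cap are assumed values you would tune yourself. The callback is expected to build a model with the given number of hidden feature channels, train it, and return the final loss.

```cpp
#include <cassert>
#include <functional>

// Doubles the number of feature channels until the loss stops improving
// by at least minGain, then returns the last count that was worth it.
// trainAndGetLoss is a hypothetical callback supplied by the caller.
int chooseFeatureChannels(const std::function<double(int)> &trainAndGetLoss,
                          int startChannels = 8,
                          int maxChannels = 256,
                          double minGain = 0.01)
{
    assert(startChannels > 0 && maxChannels >= startChannels);
    int best = startChannels;
    double bestLoss = trainAndGetLoss(best);
    for (int channels = best * 2; channels <= maxChannels; channels *= 2)
    {
        double loss = trainAndGetLoss(channels);
        if (bestLoss - loss < minGain)
            break;          // no meaningful improvement: stop doubling
        best = channels;
        bestLoss = loss;
    }
    return best;
}
```

The loss value alone does not tell the whole story, so it is still worth visualizing the results for the selected channel count before settling on it.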
