Python implementation of Geometric Dataset Distances via Optimal Transport

kheyer/OTDD

Optimal Transport Dataset Distances

This repo is a Python implementation of Geometric Dataset Distances via Optimal Transport and Robust Optimal Transport. Routines are implemented in NumPy with Python Optimal Transport (POT) and CVXPY, as well as in PyTorch using KeOps and GeomLoss.

The OTDD algorithm incorporates label information into the optimal transport problem, so the distance between two datasets reflects both their features and their labels.
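As a rough illustration of how labels can enter the transport cost, here is a minimal NumPy-only sketch. It is not this repo's API: the function names (`label_stats`, `label_distances`, `otdd_cost`, `sinkhorn`) are hypothetical. It also assumes each class is approximated by a Gaussian with diagonal covariance, a simplification chosen to keep the label-to-label Wasserstein distance in closed form. The ground cost between two labeled points is the squared feature distance plus the squared distance between their class distributions.

```python
import numpy as np

def label_stats(X, y):
    """Per-class mean and per-feature std (diagonal Gaussian fit, an assumption)."""
    classes = np.unique(y)  # sorted unique labels
    means = np.stack([X[y == c].mean(axis=0) for c in classes])
    stds = np.stack([X[y == c].std(axis=0) for c in classes])
    return classes, means, stds

def label_distances(stats_a, stats_b):
    """Squared 2-Wasserstein distance between diagonal Gaussians, for every class pair."""
    _, m_a, s_a = stats_a
    _, m_b, s_b = stats_b
    mean_term = ((m_a[:, None, :] - m_b[None, :, :]) ** 2).sum(-1)
    std_term = ((s_a[:, None, :] - s_b[None, :, :]) ** 2).sum(-1)
    return mean_term + std_term

def otdd_cost(Xa, ya, Xb, yb):
    """Ground cost: squared feature distance + squared label-to-label distance."""
    stats_a, stats_b = label_stats(Xa, ya), label_stats(Xb, yb)
    W = label_distances(stats_a, stats_b)
    ia = np.searchsorted(stats_a[0], ya)  # class index of each sample in A
    ib = np.searchsorted(stats_b[0], yb)  # class index of each sample in B
    feat = ((Xa[:, None, :] - Xb[None, :, :]) ** 2).sum(-1)
    return feat + W[ia][:, ib]

def sinkhorn(C, reg=0.05, n_iter=500):
    """Entropic OT with uniform marginals; the cost is rescaled by its maximum."""
    n, m = C.shape
    a, b = np.full(n, 1.0 / n), np.full(m, 1.0 / m)
    K = np.exp(-C / (reg * C.max()))
    u = np.ones(n)
    for _ in range(n_iter):
        v = b / (K.T @ u)
        u = a / (K @ v)
    return u[:, None] * K * v[None, :]

# Toy data: two datasets, two classes each (synthetic, for illustration only).
rng = np.random.default_rng(0)
ya = np.repeat([0, 1], 20)
yb = np.repeat([0, 1], 15)
Xa = rng.normal(size=(40, 2)) + 3.0 * ya[:, None]
Xb = rng.normal(size=(30, 2)) + 3.0 * yb[:, None] + 0.5

C = otdd_cost(Xa, ya, Xb, yb)
P = sinkhorn(C)                # approximate coupling
distance = float((P * C).sum())  # transport cost under that coupling
```

The coupling `P` here comes from a plain Sinkhorn loop, so it only approximates the exact optimal transport plan. For real use, a solver from POT or GeomLoss would be the natural replacement.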

[Figure: coupling comparison]

Algorithm Overview | API | Examples

Installing

Core dependencies can be installed from the environment.yml file

conda env create -f environment.yml

To use the PyTorch implementation, install PyTorch, KeOps and GeomLoss

conda install pytorch torchvision torchaudio -c pytorch
pip install pykeops
pip install geomloss

Then validate the KeOps installation

import pykeops
pykeops.clean_pykeops()
pykeops.test_torch_bindings() 

To use the cheminformatics functions in chem.py, install RDKit

conda install -c rdkit rdkit
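Since several of these dependencies are optional, it can help to check which ones are importable before running anything. The snippet below is a small sketch using only the standard library. The helper name `available_backends` is hypothetical and not part of this repo. The module names are the usual import names for each package (`ot` for Python Optimal Transport).

```python
from importlib.util import find_spec

# Import names of the packages mentioned in the install steps above.
OPTIONAL_MODULES = {
    "Python Optimal Transport": "ot",
    "CVXPY": "cvxpy",
    "PyTorch": "torch",
    "KeOps": "pykeops",
    "GeomLoss": "geomloss",
    "RDKit": "rdkit",
}

def available_backends(modules=OPTIONAL_MODULES):
    """Map each package label to True if its top-level module can be found."""
    return {label: find_spec(name) is not None for label, name in modules.items()}

if __name__ == "__main__":
    for label, found in available_backends().items():
        print(f"{label}: {'installed' if found else 'missing'}")
```

This only checks that each module can be located. It does not verify that KeOps compiles correctly, which is what `pykeops.test_torch_bindings()` is for.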
