-
Notifications
You must be signed in to change notification settings - Fork 485
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: start PyTorch TabNet Paper Implementation
- Loading branch information
0 parents
commit e7dc059
Showing
15 changed files
with
2,964 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
* |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
* text=auto | ||
# Basic .gitattributes for a python repo. | ||
|
||
# Source files | ||
# ============ | ||
*.pxd text diff=python | ||
*.py text diff=python | ||
*.py3 text diff=python | ||
*.pyw text diff=python | ||
*.pyx text diff=python | ||
*.pyz text diff=python | ||
|
||
# Binary files | ||
# ============ | ||
*.db binary | ||
*.p binary | ||
*.pkl binary | ||
*.pickle binary | ||
*.pyc binary | ||
*.pyd binary | ||
*.pyo binary | ||
|
||
# Jupyter notebook | ||
*.ipynb text | ||
|
||
# Note: .db, .p, and .pkl files are associated | ||
# with the python modules ``pickle``, ``dbm.*``, | ||
# ``shelve``, ``marshal``, ``anydbm``, & ``bsddb`` | ||
# (among others). |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,131 @@ | ||
.cache/ | ||
../.history/ | ||
data/ | ||
.ipynb_checkpoints/ | ||
|
||
# Byte-compiled / optimized / DLL files | ||
__pycache__/ | ||
*.py[cod] | ||
*$py.class | ||
|
||
# C extensions | ||
*.so | ||
|
||
# Distribution / packaging | ||
.Python | ||
build/ | ||
develop-eggs/ | ||
dist/ | ||
downloads/ | ||
eggs/ | ||
.eggs/ | ||
lib/ | ||
lib64/ | ||
parts/ | ||
sdist/ | ||
var/ | ||
wheels/ | ||
pip-wheel-metadata/ | ||
share/python-wheels/ | ||
*.egg-info/ | ||
.installed.cfg | ||
*.egg | ||
MANIFEST | ||
|
||
# PyInstaller | ||
# Usually these files are written by a python script from a template | ||
# before PyInstaller builds the exe, so as to inject date/other infos into it. | ||
*.manifest | ||
*.spec | ||
|
||
# Installer logs | ||
pip-log.txt | ||
pip-delete-this-directory.txt | ||
|
||
# Unit test / coverage reports | ||
htmlcov/ | ||
.tox/ | ||
.nox/ | ||
.coverage | ||
.coverage.* | ||
.cache | ||
nosetests.xml | ||
coverage.xml | ||
*.cover | ||
*.py,cover | ||
.hypothesis/ | ||
.pytest_cache/ | ||
|
||
# Translations | ||
*.mo | ||
*.pot | ||
|
||
# Django stuff: | ||
*.log | ||
local_settings.py | ||
db.sqlite3 | ||
db.sqlite3-journal | ||
|
||
# Flask stuff: | ||
instance/ | ||
.webassets-cache | ||
|
||
# Scrapy stuff: | ||
.scrapy | ||
|
||
# Sphinx documentation | ||
docs/_build/ | ||
|
||
# PyBuilder | ||
target/ | ||
|
||
# Jupyter Notebook | ||
.ipynb_checkpoints | ||
|
||
# IPython | ||
profile_default/ | ||
ipython_config.py | ||
|
||
# pyenv | ||
.python-version | ||
|
||
# pipenv | ||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. | ||
# However, in case of collaboration, if having platform-specific dependencies or dependencies | ||
# having no cross-platform support, pipenv may install dependencies that don't work, or not | ||
# install all needed dependencies. | ||
#Pipfile.lock | ||
|
||
# celery beat schedule file | ||
celerybeat-schedule | ||
|
||
# SageMath parsed files | ||
*.sage.py | ||
|
||
# Environments | ||
.env | ||
.venv | ||
env/ | ||
venv/ | ||
ENV/ | ||
env.bak/ | ||
venv.bak/ | ||
|
||
# Spyder project settings | ||
.spyderproject | ||
.spyproject | ||
|
||
# Rope project settings | ||
.ropeproject | ||
|
||
# mkdocs documentation | ||
/site | ||
|
||
# mypy | ||
.mypy_cache/ | ||
.dmypy.json | ||
dmypy.json | ||
|
||
# Pyre type checker | ||
.pyre/ | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
FROM python:3.7-slim-buster | ||
RUN apt update && apt install curl make git -y | ||
RUN curl -sSL https://raw.githubusercontent.com/sdispater/poetry/master/get-poetry.py | python | ||
ENV SHELL /bin/bash -l | ||
|
||
ENV POETRY_CACHE /work/.cache/poetry | ||
ENV PIP_CACHE_DIR /work/.cache/pip | ||
ENV JUPYTER_RUNTIME_DIR /work/.cache/jupyter/runtime | ||
ENV JUPYTER_CONFIG_DIR /work/.cache/jupyter/config | ||
|
||
RUN $HOME/.poetry/bin/poetry config settings.virtualenvs.path $POETRY_CACHE | ||
|
||
# ENTRYPOINT ["poetry", "run"] | ||
CMD ["bash", "-l"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
MIT License | ||
|
||
Copyright (c) 2019 DreamQuark | ||
|
||
Permission is hereby granted, free of charge, to any person obtaining a copy | ||
of this software and associated documentation files (the "Software"), to deal | ||
in the Software without restriction, including without limitation the rights | ||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||
copies of the Software, and to permit persons to whom the Software is | ||
furnished to do so, subject to the following conditions: | ||
|
||
The above copyright notice and this permission notice shall be included in all | ||
copies or substantial portions of the Software. | ||
|
||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | ||
SOFTWARE. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
# set default shell | ||
SHELL := $(shell which bash) | ||
FOLDER=$$(pwd) | ||
# default shell options | ||
.SHELLFLAGS = -c | ||
NO_COLOR=\\e[39m | ||
OK_COLOR=\\e[32m | ||
ERROR_COLOR=\\e[31m | ||
WARN_COLOR=\\e[33m | ||
PORT=8889 | ||
.SILENT: ; | ||
default: help; # default target | ||
|
||
IMAGE_NAME=python-poetry:latest | ||
|
||
build: | ||
echo "Building Dockerfile" | ||
docker build -t ${IMAGE_NAME} . | ||
.PHONY: build | ||
|
||
start: build | ||
echo "Starting container ${IMAGE_NAME}" | ||
docker run --rm -it -v ${FOLDER}:/work -w /work -p ${PORT}:${PORT} -e "JUPYTER_PORT=${PORT}" ${IMAGE_NAME} | ||
.PHONY: start | ||
|
||
notebook: | ||
poetry run jupyter notebook --allow-root --ip 0.0.0.0 --port ${PORT} --no-browser --notebook-dir . | ||
.PHONY: notebook | ||
|
||
root_bash: | ||
docker exec -it --user root $$(docker ps --filter ancestor=${IMAGE_NAME} --filter expose=${PORT} -q) bash | ||
.PHONY: root_bash | ||
|
||
help: | ||
echo -e "make [ACTION] <OPTIONAL_ARGS>" | ||
echo | ||
echo -e "This image uses Poetry for dependency management (https://poetry.eustace.io/)" | ||
echo | ||
echo -e "Default port for Jupyter notebook is 8888" | ||
echo | ||
echo -e "$(UDLINE_TEXT)ACTIONS$(NORMAL_TEXT):" | ||
echo -e "- $(BOLD_TEXT)init$(NORMAL_TEXT): create pyproject.toml interactive and install virtual env" | ||
echo -e "- $(BOLD_TEXT)run$(NORMAL_TEXT) port=<port>: run the Jupyter notebook on the given port" | ||
echo -e "- $(BOLD_TEXT)stop$(NORMAL_TEXT) port=<port>: stop the running notebook on this port" | ||
echo -e "- $(BOLD_TEXT)logs$(NORMAL_TEXT) port=<port>: show and tail the logs of the notebooks" | ||
echo -e "- $(BOLD_TEXT)shell$(NORMAL_TEXT) port=<port>: open a poetry shell" | ||
.PHONY: help |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
# README | ||
|
||
# TabNet : Attentive Interpretable Tabular Learning | ||
|
||
This is a PyTorch implementation of TabNet (Arik, S. O., & Pfister, T. (2019). TabNet: Attentive Interpretable Tabular Learning. arXiv preprint arXiv:1908.07442.) https://arxiv.org/pdf/1908.07442.pdf. | ||
|
||
# Installation | ||
|
||
You can install using pip by running: | ||
`pip install tabnet` | ||
|
||
If you want to use it locally within a docker container: | ||
|
||
`git clone [email protected]:dreamquark-ai/tabnet.git` | ||
|
||
`cd tabnet` to get inside the repository | ||
|
||
`make start` to build and get inside the container | ||
|
||
`poetry install` to install all the dependencies, including jupyter | ||
|
||
`make notebook` inside the same terminal | ||
|
||
You can then follow the link to a jupyter notebook with tabnet installed. | ||
|
||
|
||
|
||
GPU version is available and should be working but is not supported yet. | ||
|
||
# How to use it? | ||
|
||
The implementation makes it easy to try different architectures of TabNet. | ||
All you need is to change the network parameters and training parameters. All parameters are quickly described below; to get a better understanding of what each parameter does, please refer to the original paper. | ||
|
||
You can also get comfortable with how the code works by playing with the **notebook tutorials** for the adult census income dataset and the forest cover type dataset. | ||
|
||
## Network parameters | ||
|
||
- input_dim : int | ||
|
||
Number of initial features of the dataset | ||
|
||
- output_dim : int | ||
|
||
Size of the desired output. Ex : | ||
- 1 for regression task | ||
- 2 for binary classification | ||
- N > 2 for multiclass classification | ||
|
||
- nd : int | ||
|
||
Width of the decision prediction layer. Bigger values give more capacity to the model, with the risk of overfitting. | ||
Values typically range from 8 to 64. | ||
|
||
- na : int | ||
|
||
Width of the attention embedding for each mask. | ||
According to the paper nd=na is usually a good choice. | ||
|
||
- n_steps : int | ||
Number of steps in the architecture (usually between 3 and 10) | ||
|
||
- gamma : float | ||
This is the coefficient for feature reusage in the masks. | ||
A value close to 1 will make mask selection less correlated between layers. | ||
Values range from 1.0 to 2.0 | ||
- cat_idxs : list of int | ||
|
||
List of categorical feature indices. | ||
- cat_emb_dim : list of int | ||
|
||
List of embedding sizes for each categorical feature. | ||
- n_independent : int | ||
|
||
Number of independent Gated Linear Units layers at each step. | ||
Usual values range from 1 to 5 (default=2) | ||
- n_shared : int | ||
|
||
Number of shared Gated Linear Units at each step | ||
Usual values range from 1 to 5 (default=2) | ||
- virtual_batch_size : int | ||
|
||
Size of the mini batches used for Ghost Batch Normalization | ||
|
||
## Training parameters | ||
|
||
- max_epochs : int (default = 200) | ||
|
||
Maximum number of epochs for training. | ||
- patience : int (default = 15) | ||
|
||
Number of consecutive epochs without improvement before performing early stopping. | ||
- lr : float (default = 0.02) | ||
|
||
Initial learning rate used for training. As mentioned in the original paper, a large initial learning rate of ```0.02``` with decay is a good option. | ||
- clip_value : float (default None) | ||
|
||
If a float is given this will clip the gradient at clip_value. | ||
- lambda_sparse : float (default = 1e-3) | ||
|
||
This is the extra sparsity loss coefficient as proposed in the original paper. The bigger this coefficient is, the sparser your model will be in terms of feature selection. Depending on the difficulty of your problem, reducing this value could help. | ||
- model_name : str (default = 'DQTabNet') | ||
|
||
Name of the model used when saving to disk; you can customize this to easily retrieve and reuse your trained models. | ||
- saving_path : str (default = './') | ||
|
||
Path defining where to save models. | ||
- scheduler_fn : torch.optim.lr_scheduler (default = None) | ||
|
||
Pytorch Scheduler to change learning rates during training. | ||
- scheduler_params: dict | ||
|
||
Parameter dictionary for the scheduler_fn. Ex : {"gamma": 0.95, "step_size": 10} | ||
- verbose : int (default=-1) | ||
|
||
Verbosity for notebooks plots, set to 1 to see every epoch. |
Oops, something went wrong.