attention mask & bias (#1)
* add support for attn mask

* add mask operation

* add mask operation

* add mask operation

* add interface

* add mask support

* add mask support

* fix up

* add bias

* add template

* add test

* clean

* clean code

* add mask load

* add mask test

* fix forward bugs

* add test

* add mask in backward

* add test case

* add bias

* add mask

* add bias test

* fix test case

* add without mask test

* add kernel test

* add ds save

* fix interface

* add test

* fix dbias

* add bias support

* add mask shape

* add test

* add support

* fix bf16 and mask shape

* fix mask head=1 shape

* add dump

* to fix len 512

* add test

* fix seqlen greater than 256

* fix bias seqlen

* add constexpr

* add const expr for bwd

* add benchmark

* add test tools

* add script

* add cross attention

* add cross attn

* fix bugs

* remove test tools

* clean fmha_api.cpp

* clean fmha_dgrad_fp16_kernel_loop.sm80.cu

* clean fmha_dgrad_kernel_1xN_loop.h

* clean fmha_fprop_fp16_kernel.sm80.cu

* clean fmha_fprop_kernel_1xN.h

* clean gmem_tile.h

* clean softmax.h

* restore test_flash_attn.py

* clean gmem_tile.h

* fix fmha_fprop_kernel_1xN.h

* fix fmha_dgrad_kernel_1xN_loop.h

* rename has_attn to has_attn_mask, has_bias to has_attn_bias

* fix fmha_fprop_kernel_1xN.h

* rename has_attn to has_attn_mask, has_bias to has_attn_bias

* remove useless benchmark code

* add declaration

* remove useless comments

* remove useless comments

* add timeout

* add default timeout for build wheel

* remove timeout

* reduce build worker for workflow oom
robotcator authored Oct 13, 2022
1 parent f515c77 commit a80a963
Showing 17 changed files with 1,201 additions and 104 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/publish.yml
@@ -112,7 +112,7 @@ jobs:
export LD_LIBRARY_PATH=/usr/local/nvidia/lib64:/usr/local/cuda/lib64:$LD_LIBRARY_PATH
export CUDA_INSTALL_DIR=/usr/local/cuda-11.3$CUDA_INSTALL_DIR
pip install wheel
-python setup.py bdist_wheel --dist-dir=dist
+MAX_JOBS=1 python setup.py bdist_wheel --dist-dir=dist
tmpname=cu${{ matrix.cuda-version }}torch${{ matrix.torch-version }}
wheel_name=$(ls dist/*whl | xargs -n 1 basename | sed "s/-/+$tmpname-/2")
ls dist/*whl |xargs -I {} mv {} ${wheel_name}
@@ -127,4 +127,4 @@ jobs:
upload_url: ${{ steps.get_current_release.outputs.upload_url }}
asset_path: ./${{env.wheel_name}}
asset_name: ${{env.wheel_name}}
-asset_content_type: application/*
+asset_content_type: application/*
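
Note on the `MAX_JOBS=1` change in the hunk above: it corresponds to the "reduce build worker for workflow oom" entry in the commit message. PyTorch's C++ extension build tooling reads the `MAX_JOBS` environment variable to cap parallel compile jobs; this assumes `setup.py` builds through that tooling, which the diff itself does not show. The sketch below is not PyTorch's implementation. It is a minimal illustration of the idea, and the helper name `resolve_build_workers` and its fallback behavior are hypothetical.

```python
import os


def resolve_build_workers(env=None, default=None):
    """Hypothetical helper: choose how many parallel compile jobs to run.

    A positive integer in MAX_JOBS wins; otherwise fall back to `default`,
    and finally to the CPU count. This only illustrates the idea and is
    not the logic of any real build tool.
    """
    env = os.environ if env is None else env
    raw = env.get("MAX_JOBS")
    if raw is not None and raw.strip().isdigit() and int(raw) > 0:
        return int(raw)
    if default is not None:
        return default
    return os.cpu_count() or 1


# With MAX_JOBS=1, as in the workflow change, the build runs one job at a
# time, trading build speed for lower peak memory on the CI runner.
print(resolve_build_workers({"MAX_JOBS": "1"}))
```

Fewer parallel compiler processes means fewer large translation units held in memory at once, which is the usual reason to lower this value on memory-constrained runners.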
125 changes: 125 additions & 0 deletions .gitignore
@@ -0,0 +1,125 @@
*.pt
*.tfevents.*
# JetBrains PyCharm IDE
.idea/

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# macOS dir files
.DS_Store

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.args
*.egg

# Checkpoints
checkpoints

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# dotenv
.env

# virtualenv
.venv
venv/
ENV/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mypy
.mypy_cache/

# VSCODE
.vscode/ftp-sync.json
.vscode/settings.json

# too big to git
*.lmdb
*.sto
*.pt
*.pkl

# pytest
.pytest_cache
test/.pytest_cache
/local*
/_*