attention mask & bias #1

Merged: 73 commits, Oct 13, 2022

Commits
8b7b72c
add support for attn mask
robotcator Jul 31, 2022
b8acf76
add mask operation
robotcator Aug 1, 2022
abc409d
add mask operation
robotcator Aug 1, 2022
98b290a
add mask operation
robotcator Aug 1, 2022
ce0aff1
add interface
robotcator Aug 1, 2022
5c023ba
add mask support
robotcator Aug 1, 2022
bc5aa56
add mask support
robotcator Aug 2, 2022
a6f232b
fix up
robotcator Aug 4, 2022
2735ee9
add bias
robotcator Aug 4, 2022
3232d8d
add template
robotcator Aug 4, 2022
f33f2ad
add test
robotcator Aug 4, 2022
0402aa5
clean
robotcator Aug 4, 2022
b4793fa
clean code
robotcator Aug 4, 2022
3659eb0
add mask load
robotcator Aug 23, 2022
4368b1b
add mask test
robotcator Aug 24, 2022
81281c3
fix forward bugs
robotcator Aug 25, 2022
4bbe6b1
add test
robotcator Aug 25, 2022
d2b883b
add mask in backward
robotcator Aug 25, 2022
050107c
add test case
robotcator Aug 25, 2022
6f14cb5
add bias
robotcator Aug 26, 2022
dfcd4a7
add mask
robotcator Aug 27, 2022
ffd07e8
add bias test
robotcator Aug 30, 2022
b824852
fix test case
robotcator Aug 30, 2022
319b940
add without mask test
robotcator Aug 30, 2022
1be4a39
add kernel test
robotcator Aug 30, 2022
14f7e08
add ds save
robotcator Sep 1, 2022
debc046
fix interface
robotcator Sep 1, 2022
b812889
add test
robotcator Sep 1, 2022
62a2f88
fix dbias
robotcator Sep 3, 2022
5eb754a
add bias support
robotcator Sep 4, 2022
baa6d1b
add mask shape
robotcator Sep 5, 2022
84e462f
add test
robotcator Sep 5, 2022
81c0743
add support
robotcator Sep 6, 2022
89e74b9
fix bf16 and mask shape
robotcator Sep 6, 2022
05505f9
fix mask head=1 shape
robotcator Sep 7, 2022
e435fd1
add dump
robotcator Sep 8, 2022
30c29a6
to fix len 512
robotcator Sep 9, 2022
f06016e
add test
robotcator Sep 9, 2022
5ad59e9
fix seqlen greater than 256
robotcator Sep 9, 2022
94597de
fix bias seqlen
robotcator Sep 9, 2022
fb7ef92
add constexpr
robotcator Sep 15, 2022
4efdf9e
add const expr for bwd
robotcator Sep 15, 2022
00d3e03
add benchmark
robotcator Sep 16, 2022
24b55bd
add test tools
robotcator Sep 16, 2022
a0b4891
add script
robotcator Sep 16, 2022
95d0308
add cross attention
robotcator Sep 22, 2022
df852f5
add cross attn
robotcator Sep 22, 2022
f71172d
Merge branch 'main' of github.com:robotcator/flash-attention into att…
robotcator Sep 23, 2022
d59fa76
fix bugs
robotcator Oct 10, 2022
bdc1fb3
remove test tools
robotcator Oct 11, 2022
2543703
clean fmha_api.cpp
robotcator Oct 11, 2022
bf68f90
clean fmha_dgrad_fp16_kernel_loop.sm80.cu
robotcator Oct 11, 2022
4f37437
clean fmha_dgrad_kernel_1xN_loop.h
robotcator Oct 11, 2022
60e27d8
clean fmha_fprop_fp16_kernel.sm80.cu
robotcator Oct 11, 2022
96b83bf
clean fmha_fprop_kernel_1xN.h
robotcator Oct 11, 2022
5c167cf
merge from master
robotcator Oct 11, 2022
9c1cb91
clean gmem_tile.h
robotcator Oct 11, 2022
4d20f59
clean softmax.h
robotcator Oct 11, 2022
ede0a96
restore test_flash_attn.py
robotcator Oct 11, 2022
2993cae
clean gmem_tile.h
robotcator Oct 11, 2022
1be0e94
fix fmha_fprop_kernel_1xN.h
robotcator Oct 12, 2022
32f6cd1
fix fmha_dgrad_kernel_1xN_loop.h
robotcator Oct 12, 2022
30e4253
rename has_attn to has_attn_mask, has_bias to has_attn_bias
robotcator Oct 12, 2022
e8a376e
fix fmha_fprop_kernel_1xN.h
robotcator Oct 12, 2022
806e156
rename has_attn to has_attn_mask, has_bias to has_attn_bias
robotcator Oct 12, 2022
15ade00
remove useless benchmark code
robotcator Oct 12, 2022
d663cf5
add declaration
robotcator Oct 12, 2022
de4f2cc
remove useless comments
robotcator Oct 12, 2022
2957838
remove useless comments
robotcator Oct 12, 2022
39fa9d4
add timeout
robotcator Oct 12, 2022
0bb403e
add default timeout for build wheel
robotcator Oct 12, 2022
184991b
remove timeout
robotcator Oct 12, 2022
3384115
reduce build worker for workflow oom
robotcator Oct 13, 2022
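
For orientation before the diffs: the feature being added is ordinary attention with an additive mask and an additive bias folded into the pre-softmax logits. A minimal PyTorch sketch of that semantics (an illustration, not code from this PR; any 1/sqrt(d) scaling of the query is assumed to happen upstream, as in the reference implementation in the diff):

import torch

def attention_with_mask_bias(q, k, v, mask=None, bias=None):
    # q, k, v: [..., seq, head_dim]; mask and bias broadcast against the
    # [..., seq_q, seq_k] logit matrix.
    logits = torch.matmul(q, k.transpose(-1, -2))
    if mask is not None:
        logits = logits + mask  # additive mask: 0 keeps a position, ~-3e4 drops it
    if bias is not None:
        logits = logits + bias  # e.g. a learned pairwise bias term
    return torch.matmul(logits.softmax(dim=-1), v)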
125 changes: 125 additions & 0 deletions .gitignore
@@ -0,0 +1,125 @@
*.pt
*.tfevents.*
# JetBrains PyCharm IDE
.idea/

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# macOS dir files
.DS_Store

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg

# Checkpoints
checkpoints

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# dotenv
.env

# virtualenv
.venv
venv/
ENV/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mypy
.mypy_cache/

# VSCODE
.vscode/ftp-sync.json
.vscode/settings.json

# too big to git
*.lmdb
*.sto
*.pt
*.pkl

# pytest
.pytest_cache
test/.pytest_cache
/local*
/_*
35 changes: 35 additions & 0 deletions benchmarks/correctness/attention.py
@@ -0,0 +1,35 @@
import torch
from typing import Optional, Callable, List, Tuple, Sequence

from unicore.modules import softmax_dropout


def permute_final_dims(tensor: torch.Tensor, inds: List[int]):
    # Permute only the last len(inds) dimensions, leaving leading
    # (batch-like) dimensions untouched.
    zero_index = -1 * len(inds)
    first_inds = list(range(len(tensor.shape[:zero_index])))
    return tensor.permute(first_inds + [zero_index + i for i in inds])


def _attention(query, key, value, mask=None, bias=None, upcast=False) -> torch.Tensor:
    dtype_og = query.dtype

    # Optionally run the reference computation in fp32 for a tighter
    # numerical baseline.
    if upcast:
        query = query.float()
        key = key.float()
        value = value.float()
        if mask is not None:
            mask = mask.float()
        if bias is not None:
            bias = bias.float()

    # [*, H, C_hidden, K]
    key = permute_final_dims(key, (1, 0))

    # [*, H, Q, K]
    a = torch.matmul(query, key)

    a = softmax_dropout(a, dropout_prob=0, is_training=True, mask=mask, bias=bias)

    # [*, H, Q, C_hidden]
    b = torch.matmul(a, value)

    return b.to(dtype_og)
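
A hypothetical usage sketch for the reference above (my addition, not part of the diff; shapes follow the benchmark script below, and unicore must be importable for softmax_dropout):

import torch
from attention import _attention

bs, seq, head, c_dim = 1, 128, 4, 32
q = k = v = torch.randn(bs, seq, head, seq, c_dim, device="cuda", dtype=torch.bfloat16)
mask = torch.zeros(bs, seq, 1, 1, seq, device="cuda", dtype=torch.bfloat16)  # all-zero additive mask keeps every position
bias = torch.randn(1, 1, head, seq, seq, device="cuda", dtype=torch.bfloat16)

out = _attention(q, k, v, mask=mask, bias=bias, upcast=True)
print(out.shape)  # torch.Size([1, 128, 4, 128, 32])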
128 changes: 128 additions & 0 deletions benchmarks/correctness/benchmark_memory.py
@@ -0,0 +1,128 @@
import torch
import torch.utils.benchmark as benchmark

from flash_attention import _flash_attn
from attention import _attention
from torch_attention import _torch_attention

import argparse

parser = argparse.ArgumentParser()
# argparse's type=bool treats any non-empty string (including "False") as
# True, so these boolean switches are declared as store_true flags.
parser.add_argument("--has_mask_bias", required=False, help="add mask and bias in attention", action="store_true")
parser.add_argument("--eval", required=False, help="run forward only (skip the backward pass)", action="store_true")

args = parser.parse_args()
print(args)


def benchmark_memory(fn, inputs, mask=None, bias=None, grad=None, eval=True, desc='', verbose=False, **kwinputs):
    def fwd(grad, inputs, mask=mask, bias=bias, **kwinputs):
        # Forward only, without autograd bookkeeping.
        with torch.no_grad():
            y = fn(inputs, inputs, inputs, mask=mask, bias=bias, **kwinputs)

    def fwd_bwd(grad, inputs, mask=mask, bias=bias, **kwinputs):
        y = fn(inputs, inputs, inputs, mask=mask, bias=bias, **kwinputs)
        if type(y) is tuple:
            y = y[0]
        if grad is None:
            grad = torch.randn_like(y)
        elif grad.shape != y.shape:
            raise RuntimeError('Grad shape does not match output shape')
        y.backward(grad, retain_graph=False)

    if eval:
        f = fwd
        if verbose:
            print("using fwd func...")
    else:
        f = fwd_bwd
        if verbose:
            print("using fwd and bwd func...")

    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()

    f(None, inputs, mask, bias)

    torch.cuda.synchronize()
    # Peak allocated memory in GB (MiB * 1000).
    mem = torch.cuda.max_memory_allocated() / ((2 ** 20) * 1000)
    if verbose:
        print(f"{desc} max memory: ", mem)
    torch.cuda.empty_cache()
    return mem


def gen_attn_mask(mask, neg_inf):
    # Turn a 0/1 keep-mask into an additive mask: kept positions stay 0,
    # masked positions become a large negative number.
    assert neg_inf < -1e4
    attn_mask = torch.zeros_like(mask)
    attn_mask[mask == 0] = neg_inf
    return attn_mask


def fun(seqlen=128, verbose=False, has_bias=True, has_mask=True, eval=True):
    bs = 1
    head = 4
    c_dim = 32
    seq_q = seq_k = seq_v = seqlen
    dtype = torch.bfloat16
    device = "cuda"

    # [bs, seq_q, head, seq_k, c_dim]
    inputs = torch.empty((bs, seq_q, head, seq_q, c_dim), dtype=dtype, device=device).normal_(mean=0, std=.5)
    inputs.requires_grad = True
    if verbose:
        print("inputs shape: ", inputs.shape)

    if has_bias:
        # [1, 1, head, seq_q, seq_k], broadcast over batch and the leading
        # sequence dimension.
        bias = torch.randn(
            1, 1, head, seq_q, seq_k, dtype=dtype, device=device
        )
        bias.requires_grad = True
        if verbose:
            print("bias shape: ", bias.shape)
    else:
        bias = None

    if has_mask:
        # [bs, seq_q, 1, 1, seq_k]: keep ~80% of positions, drop the rest.
        mask = gen_attn_mask(
            (
                torch.rand(bs, seq_q, 1, 1, seq_k, dtype=dtype, device=device) > 0.2
            ).type(dtype),
            -3e4,
        )
        if verbose:
            print("mask shape: ", mask.shape)
    else:
        mask = None

    print("processing seq length: {} in eval mode {} ......".format(seqlen, eval))

    try:
        m1 = benchmark_memory(_attention, inputs, mask=mask, bias=bias, eval=eval, desc='Normal Attention forward')
        print(m1)
    except RuntimeError:  # CUDA OOM surfaces as a RuntimeError
        print("Normal Attention OOM")

    try:
        m2 = benchmark_memory(_flash_attn, inputs, mask=mask, bias=bias, eval=eval, desc='Flash Attention forward')
        print(m2)
    except RuntimeError:
        print("Flash Attention OOM")


for seqlen in [2**8, 2**9, 600, 700, 800, 2**10, 1200, 1400, 2**11, 2500, 3000, 3500, 2**12]:
    if args.has_mask_bias:
        fun(seqlen=seqlen, eval=args.eval)
    else:
        fun(seqlen=seqlen, has_bias=False, has_mask=False, eval=args.eval)
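
With the store_true flags above, a typical invocation is assumed to look like python benchmark_memory.py --has_mask_bias --eval for forward-only peak-memory numbers with mask and bias, or with no flags for forward plus backward without them. A quick illustration of the gen_attn_mask helper (hypothetical values, run in the script's namespace):

keep = torch.tensor([1., 0., 1.])
print(gen_attn_mask(keep, -3e4))  # tensor([     0., -30000.,      0.]), added to the logits before softmax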
