Skip to content

Commit

Permalink
Update base for Update on "Extra CR comments from #95621"
Browse files Browse the repository at this point in the history
Specifically:
https://github.com/pytorch/pytorch/pull/95621/files/063e44147152f4dd7e51852cf8c679692bd9fd53#r1120306196
#95621 (comment)

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

cc soumith voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx desertfire

[ghstack-poisoned]
  • Loading branch information
ezyang committed Mar 9, 2023
2 parents 0f6aeae + 457396f commit 725c5b8
Show file tree
Hide file tree
Showing 72 changed files with 823 additions and 981 deletions.
5 changes: 2 additions & 3 deletions .ci/onnx/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,9 @@ if [[ "$BUILD_ENVIRONMENT" == *onnx* ]]; then
pip install -q --user --no-use-pep517 "git+https://github.com/pytorch/vision.git@$(cat .github/ci_commit_pins/vision.txt)"
pip install -q --user transformers==4.25.1
pip install -q --user ninja flatbuffers==2.0 numpy==1.22.4 onnxruntime==1.14.0 beartype==0.10.4
# TODO: change this when onnx 1.13.1 is released.
pip install --no-use-pep517 'onnx @ git+https://github.com/onnx/onnx@e192ba01e438d22ca2dedd7956e28e3551626c91'
pip install -q --user onnx==1.13.1
# TODO: change this when onnx-script is on testPypi
pip install 'onnx-script @ git+https://github.com/microsoft/onnx-script@0298154caf6b46fc4e30abba034095c1290c26e3'
pip install 'onnx-script @ git+https://github.com/microsoft/onnx-script@29241e15f5182be1384f1cf6ba203d7e2e125196'
# numba requires numpy <= 1.20, onnxruntime requires numpy >= 1.21.
# We don't actually need it for our tests, but it's imported if it's present, so uninstall.
pip uninstall -q --yes numba
Expand Down
14 changes: 11 additions & 3 deletions .ci/pytorch/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -298,10 +298,16 @@ test_single_dynamo_benchmark() {

local partition_flags=()
if [[ -n "$NUM_TEST_SHARDS" && -n "$shard_id" ]]; then
partition_flags=( --total-partitions 2 --partition-id "$shard_id" )
partition_flags=( --total-partitions "$NUM_TEST_SHARDS" --partition-id "$shard_id" )
fi

if [[ "${TEST_CONFIG}" == *perf* ]]; then
if [[ "${TEST_CONFIG}" == *perf_compare* ]]; then
python "benchmarks/dynamo/$suite.py" \
--ci --performance --disable-cudagraphs \
"${DYNAMO_BENCHMARK_FLAGS[@]}" \
"$@" "${partition_flags[@]}" \
--output "$TEST_REPORTS_DIR/${name}_${suite}.csv"
elif [[ "${TEST_CONFIG}" == *perf* ]]; then
# MKL_THREADING_LAYER=GNU to mitigate https://github.com/pytorch/pytorch/issues/37377
MKL_THREADING_LAYER=GNU python benchmarks/dynamo/runner.py --suites="$suite" \
--base-sha="$BASE_SHA" --output-dir="$TEST_REPORTS_DIR" "${partition_flags[@]}" \
Expand All @@ -325,7 +331,9 @@ test_dynamo_benchmark() {
local shard_id="$1"
shift

if [[ "${TEST_CONFIG}" == *perf* ]]; then
if [[ "${TEST_CONFIG}" == *perf_compare* ]]; then
test_single_dynamo_benchmark "amp" "$suite" "$shard_id" --training --amp "$@"
elif [[ "${TEST_CONFIG}" == *perf* ]]; then
# Performance test training only, for float32 and amp
test_single_dynamo_benchmark "amp" "$suite" "$shard_id" --training --dtypes=amp "$@"
test_single_dynamo_benchmark "float32" "$suite" "$shard_id" --training --dtypes=float32 "$@"
Expand Down
8 changes: 7 additions & 1 deletion .github/actions/setup-linux/action.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,13 @@ runs:
# Pulled from instance metadata endpoint for EC2
# see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html
category=$1
curl -fsSL "http://169.254.169.254/latest/meta-data/${category}"
# If it is GCP runner (runner name contains gcp), do not run this
runner_name_str=${{ runner.name }}
if [[ $runner_name_str != *"gcp"* ]]; then
curl -fsSL "http://169.254.169.254/latest/meta-data/${category}"
else
echo "Runner is from Google Cloud Platform, No info on ec2 metadata"
fi
}
echo "ami-id: $(get_ec2_metadata ami-id)"
echo "instance-id: $(get_ec2_metadata instance-id)"
Expand Down
1 change: 1 addition & 0 deletions .github/pytorch-probot.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ ciflow_push_tags:
- ciflow/binaries_libtorch
- ciflow/binaries_wheel
- ciflow/inductor
- ciflow/inductor-perf-compare
- ciflow/inductor-perf-test-nightly
- ciflow/mps
- ciflow/nightly
Expand Down
37 changes: 37 additions & 0 deletions .github/workflows/inductor-perf-compare.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
name: inductor-A100-perf-compare

on:
push:
tags:
- ciflow/inductor-perf-compare/*
workflow_dispatch:

concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }}
cancel-in-progress: true

jobs:
linux-bionic-cuda11_8-py3_10-gcc7-inductor-build:
name: cuda11.8-py3.10-gcc7-sm80
uses: ./.github/workflows/_linux-build.yml
with:
build-environment: linux-bionic-cuda11.8-py3.10-gcc7-sm80
docker-image-name: pytorch-linux-bionic-cuda11.8-cudnn8-py3-gcc7
cuda-arch-list: '8.0'
test-matrix: |
{ include: [
{ config: "inductor_huggingface_perf_compare", shard: 1, num_shards: 1, runner: "linux.gcp.a100" },
{ config: "inductor_timm_perf_compare", shard: 1, num_shards: 2, runner: "linux.gcp.a100" },
{ config: "inductor_timm_perf_compare", shard: 2, num_shards: 2, runner: "linux.gcp.a100" },
{ config: "inductor_torchbench_perf_compare", shard: 1, num_shards: 1, runner: "linux.gcp.a100" },
]}
linux-bionic-cuda11_8-py3_10-gcc7-inductor-test:
name: cuda11.8-py3.10-gcc7-sm80
uses: ./.github/workflows/_linux-test.yml
needs: linux-bionic-cuda11_8-py3_10-gcc7-inductor-build
with:
build-environment: linux-bionic-cuda11.8-py3.10-gcc7-sm80
docker-image: ${{ needs.linux-bionic-cuda11_8-py3_10-gcc7-inductor-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-bionic-cuda11_8-py3_10-gcc7-inductor-build.outputs.test-matrix }}
use-gha: anything-non-empty-to-use-gha
2 changes: 1 addition & 1 deletion .github/workflows/inductor-perf-test-nightly.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
name: inductor-A100-perf
name: inductor-A100-perf-nightly

on:
schedule:
Expand Down
5 changes: 5 additions & 0 deletions .github/workflows/lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ jobs:
with:
runner: linux.2xlarge
docker-image: ${{ needs.docker-image.outputs.docker-image }}
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
script: |
# The generic Linux job chooses to use base env, not the one setup by the image
CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]")
Expand Down Expand Up @@ -71,6 +72,7 @@ jobs:
with:
runner: linux.2xlarge
docker-image: ${{ needs.docker-image.outputs.docker-image }}
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
script: |
# The generic Linux job chooses to use base env, not the one setup by the image
CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]")
Expand Down Expand Up @@ -124,6 +126,7 @@ jobs:
with:
runner: linux.2xlarge
docker-image: ${{ needs.docker-image.outputs.docker-image }}
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
script: |
# The generic Linux job chooses to use base env, not the one setup by the image
CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]")
Expand Down Expand Up @@ -158,6 +161,7 @@ jobs:
with:
runner: linux.2xlarge
docker-image: ${{ needs.docker-image.outputs.docker-image }}
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
script: |
# The generic Linux job chooses to use base env, not the one setup by the image
CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]")
Expand Down Expand Up @@ -196,6 +200,7 @@ jobs:
runner: linux.2xlarge
docker-image: ${{ needs.docker-image.outputs.docker-image }}
fetch-depth: 0
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
script: |
# The generic Linux job chooses to use base env, not the one setup by the image
CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]")
Expand Down
49 changes: 49 additions & 0 deletions .github/workflows/periodic.yml
Original file line number Diff line number Diff line change
Expand Up @@ -318,3 +318,52 @@ jobs:
{ include: [
{ config: "default", shard: 1, num_shards: 1, runner: "ubuntu-latest" },
]}
macos-12-py3-x86-64-build:
name: macos-12-py3-x86-64
uses: ./.github/workflows/_mac-build.yml
with:
build-environment: macos-12-py3-x86-64
xcode-version: "13.3.1"
runner-type: macos-12-xl
build-generates-artifacts: true
sccache-use-gha: true
test-matrix: |
{ include: [
{ config: "default", shard: 1, num_shards: 3, runner: "macos-12" },
{ config: "default", shard: 2, num_shards: 3, runner: "macos-12" },
{ config: "default", shard: 3, num_shards: 3, runner: "macos-12" },
{ config: "functorch", shard: 1, num_shards: 1, runner: "macos-12" },
]}
secrets:
MACOS_SCCACHE_S3_ACCESS_KEY_ID: ${{ secrets.MACOS_SCCACHE_S3_ACCESS_KEY_ID }}
MACOS_SCCACHE_S3_SECRET_ACCESS_KEY: ${{ secrets.MACOS_SCCACHE_S3_SECRET_ACCESS_KEY }}

macos-12-py3-x86-64-test:
name: macos-12-py3-x86-64
uses: ./.github/workflows/_mac-test.yml
needs: macos-12-py3-x86-64-build
with:
build-environment: macos-12-py3-x86-64
test-matrix: ${{ needs.macos-12-py3-x86-64-build.outputs.test-matrix }}
arch: x86_64
secrets:
AWS_OSSCI_METRICS_V2_ACCESS_KEY_ID: ${{ secrets.AWS_OSSCI_METRICS_V2_ACCESS_KEY_ID }}
AWS_OSSCI_METRICS_V2_SECRET_ACCESS_KEY: ${{ secrets.AWS_OSSCI_METRICS_V2_SECRET_ACCESS_KEY }}

macos-12-py3-x86-64-lite-interpreter-build-test:
name: macos-12-py3-x86-64-lite-interpreter
uses: ./.github/workflows/_mac-build.yml
with:
build-environment: macos-12-py3-lite-interpreter-x86-64
xcode-version: "13.3.1"
runner-type: macos-12-xl
build-generates-artifacts: false
sccache-use-gha: true
test-matrix: |
{ include: [
{ config: "default", shard: 1, num_shards: 1, runner: "macos-12" },
]}
secrets:
MACOS_SCCACHE_S3_ACCESS_KEY_ID: ${{ secrets.MACOS_SCCACHE_S3_ACCESS_KEY_ID }}
MACOS_SCCACHE_S3_SECRET_ACCESS_KEY: ${{ secrets.MACOS_SCCACHE_S3_SECRET_ACCESS_KEY }}
49 changes: 0 additions & 49 deletions .github/workflows/trunk.yml
Original file line number Diff line number Diff line change
Expand Up @@ -171,55 +171,6 @@ jobs:
{ config: "default", shard: 1, num_shards: 1, runner: "macos-12" },
]}
macos-12-py3-x86-64-build:
name: macos-12-py3-x86-64
uses: ./.github/workflows/_mac-build.yml
with:
build-environment: macos-12-py3-x86-64
xcode-version: "13.3.1"
runner-type: macos-12-xl
build-generates-artifacts: true
sccache-use-gha: true
test-matrix: |
{ include: [
{ config: "default", shard: 1, num_shards: 3, runner: "macos-12" },
{ config: "default", shard: 2, num_shards: 3, runner: "macos-12" },
{ config: "default", shard: 3, num_shards: 3, runner: "macos-12" },
{ config: "functorch", shard: 1, num_shards: 1, runner: "macos-12" },
]}
secrets:
MACOS_SCCACHE_S3_ACCESS_KEY_ID: ${{ secrets.MACOS_SCCACHE_S3_ACCESS_KEY_ID }}
MACOS_SCCACHE_S3_SECRET_ACCESS_KEY: ${{ secrets.MACOS_SCCACHE_S3_SECRET_ACCESS_KEY }}

macos-12-py3-x86-64-test:
name: macos-12-py3-x86-64
uses: ./.github/workflows/_mac-test.yml
needs: macos-12-py3-x86-64-build
with:
build-environment: macos-12-py3-x86-64
test-matrix: ${{ needs.macos-12-py3-x86-64-build.outputs.test-matrix }}
arch: x86_64
secrets:
AWS_OSSCI_METRICS_V2_ACCESS_KEY_ID: ${{ secrets.AWS_OSSCI_METRICS_V2_ACCESS_KEY_ID }}
AWS_OSSCI_METRICS_V2_SECRET_ACCESS_KEY: ${{ secrets.AWS_OSSCI_METRICS_V2_SECRET_ACCESS_KEY }}

macos-12-py3-x86-64-lite-interpreter-build-test:
name: macos-12-py3-x86-64-lite-interpreter
uses: ./.github/workflows/_mac-build.yml
with:
build-environment: macos-12-py3-lite-interpreter-x86-64
xcode-version: "13.3.1"
runner-type: macos-12-xl
build-generates-artifacts: false
sccache-use-gha: true
test-matrix: |
{ include: [
{ config: "default", shard: 1, num_shards: 1, runner: "macos-12" },
]}
secrets:
MACOS_SCCACHE_S3_ACCESS_KEY_ID: ${{ secrets.MACOS_SCCACHE_S3_ACCESS_KEY_ID }}
MACOS_SCCACHE_S3_SECRET_ACCESS_KEY: ${{ secrets.MACOS_SCCACHE_S3_SECRET_ACCESS_KEY }}

macos-12-py3-arm64-build:
name: macos-12-py3-arm64
uses: ./.github/workflows/_mac-build.yml
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/upload-torch-dynamo-perf-stats.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ name: Upload torch dynamo performance stats

on:
workflow_run:
workflows: [inductor-A100-perf]
workflows: [inductor-A100-perf-nightly]
types:
- completed
branches:
Expand Down
8 changes: 8 additions & 0 deletions .lintrunner.toml
Original file line number Diff line number Diff line change
Expand Up @@ -920,3 +920,11 @@ init_command = [
'--output-name=bazel',
]
is_formatter = true

[[linter]]
code = 'LINTRUNNER_VERSION'
include_patterns = ['**']
command = [
'python3',
'tools/linter/adapters/lintrunner_version_linter.py'
]
4 changes: 4 additions & 0 deletions android/common.sh
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ if [ -z "$PYTORCH_DIR" ]; then
exit 1
fi

retry () {
"$@" || (sleep 10 && "$@") || (sleep 20 && "$@") || (sleep 40 && "$@")
}

check_android_sdk() {
if [ -z "$ANDROID_HOME" ]; then
echo "ANDROID_HOME not set; please set it to Android sdk directory"
Expand Down
3 changes: 2 additions & 1 deletion android/run_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ echo "Waiting for emulator boot completed"
$ADB_PATH wait-for-device shell 'while [[ -z $(getprop sys.boot_completed) ]]; do sleep 1; done;'

{
$GRADLE_PATH -PABI_FILTERS=x86 -p $PYTORCH_ANDROID_DIR connectedAndroidTest
# The test currently takes about 10 minutes
retry $GRADLE_PATH -PABI_FILTERS=x86 -p $PYTORCH_ANDROID_DIR connectedAndroidTest
} || {
echo "::error::Check https://github.com/pytorch/pytorch/tree/master/test/mobile/model_test to see how to fix the failed mobile test"
exit 1
Expand Down
22 changes: 11 additions & 11 deletions aten/src/ATen/native/native_functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13896,7 +13896,7 @@
CUDA, NestedTensorCUDA: native_multi_head_attention_cuda
autogen: _native_multi_head_attention.out

- func: scaled_dot_product_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False) -> Tensor
- func: scaled_dot_product_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, *, float? scale=None) -> Tensor
python_module: nn
variants: function
autogen: scaled_dot_product_attention.out
Expand All @@ -13908,55 +13908,55 @@
autogen: _scaled_dot_product_attention.out

# This aten function is kept so that we can test the choice function from Python
- func: _fused_sdp_choice(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False) -> int
- func: _fused_sdp_choice(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, *, float? scale=None) -> int
dispatch:
Meta: _fused_sdp_choice_meta
CPU, NestedTensorCPU: _fused_sdp_choice_cpp
CUDA, NestedTensorCUDA: _fused_sdp_choice_cuda

- func: _scaled_dot_product_attention_math(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, Tensor? dropout_mask=None) -> (Tensor, Tensor)
- func: _scaled_dot_product_attention_math(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, Tensor? dropout_mask=None, *, float? scale=None) -> (Tensor, Tensor)
variants: function

- func: _scaled_dot_product_flash_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False) -> (Tensor ouput, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, int max_q, int max_k, int philox_seed, int philox_offset, Tensor debug_attn_mask)
- func: _scaled_dot_product_flash_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor ouput, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, int max_q, int max_k, int philox_seed, int philox_offset, Tensor debug_attn_mask)
dispatch:
CUDA: _scaled_dot_product_flash_attention_cuda
NestedTensorCUDA: _scaled_dot_product_flash_attention_nestedtensor_cuda

- func: _scaled_dot_product_flash_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, int max_q, int max_k, float dropout_p, bool is_causal, int philox_seed, int philox_offset) -> (Tensor grad_query, Tensor grad_key, Tensor grad_value)
- func: _scaled_dot_product_flash_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, int max_q, int max_k, float dropout_p, bool is_causal, int philox_seed, int philox_offse, *, float? scale=None) -> (Tensor grad_query, Tensor grad_key, Tensor grad_value)
variants: function
dispatch:
CUDA: _scaled_dot_product_flash_attention_backward_cuda

- func: _scaled_dot_product_efficient_attention(Tensor query, Tensor key, Tensor value, bool compute_log_sumexp, bool is_causal=False) -> (Tensor, Tensor)
- func: _scaled_dot_product_efficient_attention(Tensor query, Tensor key, Tensor value, bool compute_log_sumexp, bool is_causal=False, *, float? scale=None) -> (Tensor, Tensor)
dispatch:
CUDA: _scaled_dot_product_efficient_attention_cuda
NestedTensorCUDA: _scaled_dot_product_efficient_attention_nestedtensor_cuda

- func: _scaled_dot_product_efficient_attention_backward(Tensor grad_out_, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, bool is_causal=False, bool chunk_grad_outputs=False) -> (Tensor, Tensor, Tensor)
- func: _scaled_dot_product_efficient_attention_backward(Tensor grad_out_, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, bool is_causal=False, bool chunk_grad_outputs=False, *, float? scale=None) -> (Tensor, Tensor, Tensor)
dispatch:
CUDA: _scaled_dot_product_efficient_attention_backward_cuda

- func: _chunk_grad_outputs_efficient_attention(Tensor query, Tensor key, Tensor value, bool is_causal=False) -> bool
dispatch:
CUDA: _chunk_grad_outputs_efficient_attention

- func: _flash_attention_forward(Tensor query, Tensor key, Tensor value, Tensor cum_seq_q, Tensor cum_seq_k, int max_q, int max_k, float dropout_p, bool is_causal, bool return_debug_mask) -> (Tensor output, Tensor softmax_logsumexp, int philox_seed, int philox_offset, Tensor debug_attn_mask)
- func: _flash_attention_forward(Tensor query, Tensor key, Tensor value, Tensor cum_seq_q, Tensor cum_seq_k, int max_q, int max_k, float dropout_p, bool is_causal, bool return_debug_mask, *, float? scale=None) -> (Tensor output, Tensor softmax_logsumexp, int philox_seed, int philox_offset, Tensor debug_attn_mask)
variants: function
dispatch:
CUDA: _flash_attention_forward

- func: _flash_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, int max_q, int max_k, float dropout_p, bool is_causal, int philox_seed, int philox_offset) -> (Tensor, Tensor, Tensor)
- func: _flash_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, int max_q, int max_k, float dropout_p, bool is_causal, int philox_seed, int philox_offset, *, float? scale=None) -> (Tensor, Tensor, Tensor)
variants: function
dispatch:
CUDA: _flash_attention_backward

# Returns ouput, logsumexp if compute_logsumexp
- func: _efficient_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, int? max_seqlen_q, bool compute_log_sumexp=False, bool causal=False) -> (Tensor, Tensor)
- func: _efficient_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, int? max_seqlen_q, bool compute_log_sumexp=False, bool causal=False, *, float? scale=None) -> (Tensor, Tensor)
variants: function
dispatch:
CUDA: _efficient_attention_forward

- func: _efficient_attention_backward(Tensor grad_out_, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, bool is_causal=False, bool chunk_grad_outputs=False) -> (Tensor, Tensor, Tensor)
- func: _efficient_attention_backward(Tensor grad_out_, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, bool is_causal=False, bool chunk_grad_outputs=False, *, float? scale=None) -> (Tensor, Tensor, Tensor)
variants: function
dispatch:
CUDA: _efficient_attention_backward
Expand Down
Loading

0 comments on commit 725c5b8

Please sign in to comment.