Add sdpa support for Albert #32092

Merged: 10 commits into huggingface:main on Sep 3, 2024

Conversation

OmarManzoor (Contributor)
What does this PR do?

Adds SDPA for the Albert model

Towards #28005

Who can review?

@amyeroberts @fxmarty

@amyeroberts (Collaborator) left a comment

Thanks for adding this, @OmarManzoor!

All looks good to me! Could you push a commit whose message contains [run_slow] albert? That will trigger a run of the slow integration (and now SDPA) tests.

```python
import torch
from transformers import AlbertModel

model = AlbertModel.from_pretrained("albert/albert-base-v1", torch_dtype=torch.float16, attn_implementation="sdpa")
...
```
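
As a quick sanity check you can inspect which attention backend was selected after loading; note that `_attn_implementation` is an internal config attribute in recent transformers releases, so treat this small sketch as an assumption rather than a stable API:

```python
import torch
from transformers import AlbertModel

model = AlbertModel.from_pretrained(
    "albert/albert-base-v1", torch_dtype=torch.float16, attn_implementation="sdpa"
)
# Internal attribute; expected to print "sdpa" when the SDPA path is active.
print(model.config._attn_implementation)
```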
Collaborator

We should also run a few benchmarks for the model to show the expected speed-ups when using SDPA.

@OmarManzoor (Contributor, Author) commented Jul 20, 2024

Here are some training benchmarks for the albert-base-v2 model using AutoModelForMaskedLM in half precision (float16). I ran them in a Kaggle notebook on a P100 GPU with 16 GB of memory.

| num_training_steps | batch_size | seq_len | is cuda | Time per batch (eager - s) | Time per batch (sdpa - s) | Speedup (%) | Eager peak mem (MB) | sdpa peak mem (MB) | Mem saving (%) |
|---|---|---|---|---|---|---|---|---|---|
| 200 | 1 | 256 | TRUE | 0.037 | 0.037 | 1.471 | 234.781 | 217.758 | 7.817 |
| 200 | 1 | 512 | TRUE | 0.061 | 0.061 | -0.206 | 454.029 | 378.171 | 20.059 |
| 200 | 2 | 256 | TRUE | 0.058 | 0.057 | 2.046 | 416.281 | 378.171 | 10.077 |
| 200 | 2 | 512 | TRUE | 0.113 | 0.112 | 0.67 | 868.131 | 717.332 | 21.022 |
| 200 | 4 | 256 | TRUE | 0.105 | 0.103 | 1.458 | 794.206 | 717.332 | 10.717 |
| 200 | 4 | 512 | TRUE | 0.211 | 0.208 | 1.286 | 1664.025 | 1363.268 | 22.061 |

The benchmark code is included below.

```python
import argparse
import random
from typing import Dict

import numpy as np
import torch
from tqdm.auto import tqdm
from transformers import AutoModelForMaskedLM
import gc


def get_parser():
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--num_training_steps",
        type=int,
        default=100,
        help="",
    )

    parser.add_argument(
        "--model-name",
        type=str,
        default="albert/albert-base-v2",
        help="",
    )

    parser.add_argument(
        "--use-half",
        action="store_true",
    )

    parser.add_argument(
        "--use-cuda",
        action="store_true",
    )

    return parser


def seed_init_fn(x):
    seed = 42
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    return


def benchmark_training(model, inputs: Dict, num_training_steps: int):
    progress_bar = tqdm(range(num_training_steps))

    model.train()
    # warmup
    for _ in range(10):
        model.zero_grad()
        outputs = model(**inputs)
        loss = outputs.logits.sum()
        loss.backward()

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    torch.cuda.reset_peak_memory_stats(device)
    torch.cuda.empty_cache()

    torch.cuda.synchronize()
    start_event.record()
    for _ in range(num_training_steps):
        model.zero_grad()
        outputs = model(**inputs)
        loss = outputs.logits.sum()
        loss.backward()

        progress_bar.update(1)
    end_event.record()
    torch.cuda.synchronize()

    max_memory = torch.cuda.max_memory_allocated(device)

    return (start_event.elapsed_time(end_event) * 1.0e-3) / num_training_steps, max_memory


if __name__ == "__main__":
    parser = get_parser()
    args = parser.parse_args()


    BATCH_SIZES = [1, 2, 4]
    SEQ_LEN = [256, 512]
    device = torch.device("cuda:0" if torch.cuda.is_available() and args.use_cuda else "cpu")

    output_file = open("log_{}_train.csv".format(args.model_name.replace("/", "-")), "w")
    output_file.write(
        "num_training_steps, batch_size, seq_len, is cuda, Time per batch (eager - s), Time per batch (sdpa - s), "
        "Speedup (%), Eager peak mem (MB), sdpa peak mem (MB), Mem saving (%)\n"
    )
    all_eager_time_per_batch = {}
    all_eager_max_mem = {}
    all_sdpa_max_mem = {}
    all_sdpa_time_per_batch = {}

    with torch.device(device):
        hf_model = AutoModelForMaskedLM.from_pretrained(
            args.model_name, torch_dtype=torch.float16 if args.use_half else None, attn_implementation="sdpa"
        )
    hf_model = hf_model.to(device)
    for batch_size in BATCH_SIZES:
        for sequence_length in SEQ_LEN:
            print(f"Benchmark sdpa on: bs={batch_size}, seq_len={sequence_length}")

            vocab_size = hf_model.config.vocab_size
            inputs = {
                "input_ids": torch.randint(vocab_size - 1, (batch_size, sequence_length), dtype=torch.int64).to(
                    device
                ),
                # "attention_mask": torch.ones(batch_size, sequence_length, dtype=torch.int64).to(device),
            }

            # raise error if no optimized kernel is available
            with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=True):
                sdpa_time_per_batch, sdpa_max_mem = benchmark_training(
                    hf_model, inputs=inputs, num_training_steps=args.num_training_steps
                )

            all_sdpa_max_mem[(batch_size, sequence_length)] = sdpa_max_mem
            all_sdpa_time_per_batch[(batch_size, sequence_length)] = sdpa_time_per_batch
            print(f"PT SDPA: {sdpa_time_per_batch:.3f} s, peak {sdpa_max_mem:.2f} MB")

    del hf_model
    gc.collect()

    with torch.device(device):
        hf_model = AutoModelForMaskedLM.from_pretrained(
            args.model_name, torch_dtype=torch.float16 if args.use_half else None, attn_implementation="eager"
        )
    hf_model = hf_model.to(device)

    for batch_size in BATCH_SIZES:
        for sequence_length in SEQ_LEN:
            print(f"Benchmark eager on: bs={batch_size}, seq_len={sequence_length}")

            vocab_size = hf_model.config.vocab_size
            inputs = {
                "input_ids": torch.randint(vocab_size - 1, (batch_size, sequence_length), dtype=torch.int64).to(
                    device
                ),
                # "attention_mask": torch.ones(batch_size, sequence_length, dtype=torch.int64).to(device),
            }

            eager_time_per_batch, eager_max_mem = benchmark_training(
                hf_model, inputs=inputs, num_training_steps=args.num_training_steps
            )

            all_eager_time_per_batch[(batch_size, sequence_length)] = eager_time_per_batch
            all_eager_max_mem[(batch_size, sequence_length)] = eager_max_mem

            eager_max_mem = all_eager_max_mem[(batch_size, sequence_length)] * 1e-6
            sdpa_max_mem = all_sdpa_max_mem[(batch_size, sequence_length)] * 1e-6

            eager_time_per_batch = all_eager_time_per_batch[(batch_size, sequence_length)]
            sdpa_time_per_batch = all_sdpa_time_per_batch[(batch_size, sequence_length)]

            print(f"PT eager: {eager_time_per_batch:.3f} s, peak {eager_max_mem:.2f} MB")
            print(f"PT SDPA: {sdpa_time_per_batch:.3f} s, peak {sdpa_max_mem:.2f} MB")
            speedup = (eager_time_per_batch / sdpa_time_per_batch - 1) * 100
            mem_saved = (eager_max_mem / sdpa_max_mem - 1) * 100

            output_file.write(
                "{},{},{},{},{},{},{},{},{},{}\n".format(
                    args.num_training_steps,
                    batch_size,
                    sequence_length,
                    args.use_cuda,
                    f"{eager_time_per_batch:.3f}",
                    f"{sdpa_time_per_batch:.3f}",
                    f"{speedup:.3f}",
                    f"{eager_max_mem:.3f}",
                    f"{sdpa_max_mem:.3f}",
                    f"{mem_saved:.3f}",
                )
            )

    output_file.close()
```
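
For reference, a run matching the table above would presumably be invoked as follows (the script name and flags are inferred from the commands quoted later in this thread):

python benchmark_sdpa_training.py --use-half --use-cuda --num_training_steps 200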

@OmarManzoor (Contributor, Author) commented Jul 20, 2024

For inference I used the albert-base-v2 model with AutoModelForQuestionAnswering. Please let me know if the benchmark code can be improved.

| num_batches | batch_size | seq_len | is cuda | is half | use mask | Per token latency eager (ms) | Per token latency SDPA (ms) | Speedup (%) | Mem eager (MB) | Mem BT (MB) | Mem saved (%) |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 50 | 1 | 128 | True | True | True | 0.094 | 0.092 | 2.154 | 35.763 | 36.135 | -1.03 |
| 50 | 1 | 265 | True | True | True | 0.131 | 0.135 | -3.216 | 39.797 | 40.801 | -2.461 |
| 50 | 2 | 128 | True | True | True | 0.116 | 0.115 | 1.209 | 39.599 | 40.528 | -2.293 |
| 50 | 2 | 265 | True | True | True | 0.209 | 0.207 | 0.845 | 46.989 | 47.609 | -1.301 |
| 50 | 4 | 128 | True | True | True | 0.172 | 0.17 | 1.145 | 46.484 | 46.954 | -1.001 |
| 50 | 4 | 265 | True | True | True | 0.381 | 0.375 | 1.553 | 62.494 | 63.418 | -1.457 |

Code:

```python
import argparse

import numpy as np
import pandas as pd
import torch
import gc
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForQuestionAnswering


def get_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--num-batches",
        type=int,
        default=50,
        help="",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=64,
        help="",
    )
    parser.add_argument(
        "--seqlen",
        type=int,
        default=256,
        help="Input sequence length.",
    )
    parser.add_argument(
        "--model-name",
        type=str,
        default="albert/albert-base-v2",
        help="",
    )
    parser.add_argument(
        "--use-cuda",
        action="store_true",
    )
    parser.add_argument(
        "--use-half",
        action="store_true",
    )
    parser.add_argument(
        "--use-mask",
        action="store_true",
    )
    parser.add_argument(
        "--sweep",
        action="store_true",
    )
    parser.add_argument(
        "--max_token",
        type=int,
        default=100,
        help="Number of new tokens, for autoregressive models using generate.",
    )
    return parser


def get_batch(batch_size, sequence_length):
    tokens = torch.randint(high=5, size=(batch_size, sequence_length))
    mask = torch.ones((batch_size, sequence_length), )
    mask[0, 0] = 0  # real world case where we may mask
    return tokens, mask


def timing_cuda(model, num_batches, input_ids, masks):

    torch.cuda.reset_peak_memory_stats(device)
    torch.cuda.synchronize()

    # Note: we deliberately do not call torch.cuda.empty_cache() here, as it appears to negate the warmup.

    latencies = []
    for _ in tqdm(range(num_batches)):
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        torch.cuda.synchronize()
        start_event.record()

        _ = model(input_ids, masks)
        end_event.record()
        torch.cuda.synchronize()

        latency_ms = start_event.elapsed_time(end_event)
        latencies.append(latency_ms)

    max_memory = torch.cuda.max_memory_allocated(device)

    return np.mean(latencies), max_memory


def benchmark(model, input_ids, masks, num_batches, max_token, pad_token_id):
    _ = model(input_ids, masks)
    torch.cuda.synchronize()

    total_time, max_mem = timing_cuda(model, num_batches, input_ids, masks)

    return total_time, max_mem


if __name__ == "__main__":
    parser = get_parser()
    args = parser.parse_args()

    if args.sweep:
        BATCH_SIZES = [1, 2, 4]
        SEQ_LEN = [128, 265]
    else:
        BATCH_SIZES = [args.batch_size]
        SEQ_LEN = [args.seqlen]

    device = torch.device("cuda:0" if torch.cuda.is_available() and args.use_cuda else "cpu")
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)

    if not hasattr(tokenizer, "pad_token") or tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    tokenizer.padding_side = "left"

    autoclass = AutoModelForQuestionAnswering

    if args.use_cuda:
        with torch.device("cuda:0"):
            hf_model = autoclass.from_pretrained(args.model_name, torch_dtype=torch.float16 if args.use_half else None,
                                                 attn_implementation="eager")
        hf_model = hf_model.to("cuda:0")
        # only cast to fp16 when half precision was requested
        if args.use_half:
            hf_model = hf_model.to(torch.float16)
    else:
        hf_model = autoclass.from_pretrained(args.model_name, torch_dtype=torch.float16 if args.use_half else None,
                                             attn_implementation="eager")

    output_name = "log_{}.csv".format(args.model_name.replace("/", "-"))
    output_file = open(output_name, "w")
    output_file.write(
        "num_batches, batch_size, seq_len, is cuda, is half, use mask, Per token latency eager (ms), Per token latency SDPA (ms), Speedup (%), Mem eager (MB), Mem BT (MB), Mem saved (%)\n"
    )

    all_max_mem_eager = {}
    total_eager_time = {}
    for bs in tqdm(BATCH_SIZES):
        for seq_len in tqdm(SEQ_LEN):
            print(f"-- Running: bs={bs}, seq_len={seq_len}")
            input_ids, masks = get_batch(bs, seq_len)

            if args.use_cuda:
                input_ids = input_ids.to(device)
                masks = masks.to(device)

            if args.use_mask is False and bs == 1:
                masks = None

            with torch.inference_mode():
                eager_time, max_mem_eager = benchmark(
                    hf_model,
                    input_ids,
                    masks,
                    args.num_batches,
                    args.max_token,
                    tokenizer.pad_token_id,
                )

            total_eager_time[(bs, seq_len)] = eager_time
            all_max_mem_eager[(bs, seq_len)] = max_mem_eager

    del hf_model
    gc.collect()
    total_sdpa_time = {}
    all_max_mem_sdpa = {}

    if args.use_cuda:
        with torch.device("cuda:0"):
            hf_model = autoclass.from_pretrained(args.model_name, torch_dtype=torch.float16 if args.use_half else None,
                                                 attn_implementation="sdpa")
        hf_model = hf_model.to("cuda:0")
        # only cast to fp16 when half precision was requested
        if args.use_half:
            hf_model = hf_model.to(torch.float16)
    else:
        hf_model = autoclass.from_pretrained(args.model_name, torch_dtype=torch.float16 if args.use_half else None,
                                             attn_implementation="sdpa")

    for bs in tqdm(BATCH_SIZES):
        for seq_len in tqdm(SEQ_LEN):
            print(f"-- Running: bs={bs}, seq_len={seq_len}")
            input_ids, masks = get_batch(bs, seq_len)

            if args.use_cuda:
                input_ids = input_ids.to(device)
                masks = masks.to(device)

            if args.use_mask is False and bs == 1:
                masks = None

            with torch.inference_mode():
                # raise error if no optimized kernel is available
                with torch.backends.cuda.sdp_kernel(
                        enable_flash=True, enable_math=True, enable_mem_efficient=True
                ):
                    sdpa_time, max_mem_sdpa = benchmark(
                        hf_model,
                        input_ids,
                        masks,
                        args.num_batches,
                        args.max_token,
                        tokenizer.pad_token_id,
                    )
                total_sdpa_time[(bs, seq_len)] = sdpa_time
                all_max_mem_sdpa[(bs, seq_len)] = max_mem_sdpa

            per_token_latency_eager = total_eager_time[(bs, seq_len)] / args.max_token
            per_token_latency_sdpa = total_sdpa_time[(bs, seq_len)] / args.max_token

            max_mem_eager = all_max_mem_eager[(bs, seq_len)]
            max_mem_sdpa = all_max_mem_sdpa[(bs, seq_len)]

            speedup = (per_token_latency_eager / per_token_latency_sdpa - 1) * 100
            mem_saved = (max_mem_eager / max_mem_sdpa - 1) * 100

            max_mem_eager = max_mem_eager * 1e-6
            max_mem_sdpa = max_mem_sdpa * 1e-6

            print(f"PT eager: {per_token_latency_eager:.3f} ms, peak {max_mem_eager:.2f} MB")
            print(f"PT SDPA: {per_token_latency_sdpa:.3f} ms, peak {max_mem_sdpa:.2f} MB")

            output_file.write(
                "{},{},{},{},{},{},{},{},{},{},{},{}\n".format(
                    args.num_batches,
                    bs,
                    seq_len,
                    args.use_cuda,
                    args.use_half,
                    args.use_mask,
                    f"{per_token_latency_eager:.3f}",
                    f"{per_token_latency_sdpa:.3f}",
                    f"{speedup:.3f}",
                    f"{max_mem_eager:.3f}",
                    f"{max_mem_sdpa:.3f}",
                    f"{mem_saved:.3f}",
                )
            )

    output_file.close()
    print("RESULTS:")
    df = pd.read_csv(output_name)
    print(df.to_markdown(index=False))
```
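
For reference, the ALBERT rows above would presumably be produced with something like the following (script name and flags inferred from the BERT command quoted later in this thread):

python benchmark_sdpa_inference.py --num-batches 50 --model-name albert/albert-base-v2 --use-half --use-cuda --use-mask --sweep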

@amyeroberts (Collaborator)

@OmarManzoor, thanks for sharing! I'm surprised by these numbers - we would typically see speed-ups of ~30%, especially given the similarity to other models like BERT. Could you try running just on the AlbertModel class? This will also provide a better comparison between inference and training.
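
For reference, a minimal sketch of such a comparison (shapes, iteration counts, and the albert-base-v2 checkpoint are illustrative; it reuses the CUDA-event timing approach from the scripts above):

```python
import torch
from transformers import AlbertModel


def time_forward(model, input_ids, n_iters=100):
    # Warm up, then time n_iters forward passes with CUDA events.
    for _ in range(10):
        model(input_ids)
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(n_iters):
        model(input_ids)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / n_iters  # ms per forward pass


input_ids = torch.randint(0, 30000, (4, 256), device="cuda")
results = {}
for impl in ("eager", "sdpa"):
    model = AlbertModel.from_pretrained(
        "albert/albert-base-v2", torch_dtype=torch.float16, attn_implementation=impl
    ).to("cuda").eval()
    with torch.inference_mode():
        results[impl] = time_forward(model, input_ids)
print(results)  # mean forward latency in ms for each implementation
```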

@OmarManzoor (Contributor, Author) commented Jul 22, 2024

With simple AlbertModel

Inference

| num_batches | batch_size | seq_len | is cuda | is half | use mask | Per token latency eager (ms) | Per token latency SDPA (ms) | Speedup (%) | Mem eager (MB) | Mem BT (MB) | Mem saved (%) |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 50 | 1 | 128 | True | True | True | 0.091 | 0.093 | -1.342 | 37.793 | 37.826 | -0.085 |
| 50 | 1 | 265 | True | True | True | 0.131 | 0.136 | -3.412 | 41.83 | 41.97 | -0.333 |
| 50 | 2 | 128 | True | True | True | 0.116 | 0.116 | 0.398 | 41.565 | 41.63 | -0.156 |
| 50 | 2 | 265 | True | True | True | 0.209 | 0.207 | 0.94 | 49.636 | 49.916 | -0.56 |
| 50 | 4 | 128 | True | True | True | 0.172 | 0.17 | 1.187 | 49.108 | 49.238 | -0.264 |
| 50 | 4 | 265 | True | True | True | 0.38 | 0.375 | 1.365 | 65.25 | 65.809 | -0.85 |

I don't think the bare AlbertModel works for the training benchmark, since it doesn't output logits to compute the loss from.

@OmarManzoor (Contributor, Author) commented Jul 22, 2024

With AlbertForSequenceClassification

python benchmark_sdpa_training.py --use-half --use-cuda

Training

| num_training_steps | batch_size | seq_len | is cuda | Time per batch (eager - s) | Time per batch (sdpa - s) | Speedup (%) | Eager peak mem (MB) | sdpa peak mem (MB) | Mem saving (%) |
|---|---|---|---|---|---|---|---|---|---|
| 100 | 1 | 256 | TRUE | 0.036 | 0.035 | 3.571 | 208.449 | 190.889 | 9.199 |
| 100 | 1 | 512 | TRUE | 0.059 | 0.059 | 0.195 | 398.132 | 321.088 | 23.995 |
| 100 | 2 | 256 | TRUE | 0.055 | 0.054 | 2.483 | 358.542 | 321.088 | 11.665 |
| 100 | 2 | 512 | TRUE | 0.109 | 0.108 | 0.887 | 753.589 | 602.66 | 25.044 |
| 100 | 4 | 256 | TRUE | 0.1 | 0.099 | 1.512 | 680.189 | 602.66 | 12.864 |
| 100 | 4 | 512 | TRUE | 0.202 | 0.199 | 1.33 | 1435.475 | 1134.14 | 26.569 |

Inference 1

| num_batches | batch_size | seq_len | is cuda | is half | use mask | Per token latency eager (ms) | Per token latency SDPA (ms) | Speedup (%) | Mem eager (MB) | Mem BT (MB) | Mem saved (%) |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 50 | 1 | 128 | True | True | True | 0.093 | 0.088 | 5.441 | 37.599 | 37.631 | -0.086 |
| 50 | 1 | 265 | True | True | True | 0.13 | 0.136 | -4.319 | 41.425 | 41.565 | -0.336 |
| 50 | 2 | 128 | True | True | True | 0.117 | 0.117 | -0.606 | 41.172 | 41.238 | -0.158 |
| 50 | 2 | 265 | True | True | True | 0.209 | 0.206 | 1.04 | 48.823 | 49.103 | -0.569 |
| 50 | 4 | 128 | True | True | True | 0.172 | 0.17 | 1.228 | 48.319 | 48.45 | -0.268 |
| 50 | 4 | 265 | True | True | True | 0.381 | 0.375 | 1.608 | 64.019 | 65.285 | -1.939 |

Inference 2

| num_batches | batch_size | seq_len | is cuda | is half | use mask | Per token latency eager (ms) | Per token latency SDPA (ms) | Speedup (%) | Mem eager (MB) | Mem BT (MB) | Mem saved (%) |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 200 | 1 | 128 | True | True | True | 0.094 | 0.09 | 4.541 | 37.599 | 37.631 | -0.086 |
| 200 | 1 | 265 | True | True | True | 0.13 | 0.136 | -4.15 | 41.425 | 41.565 | -0.336 |
| 200 | 2 | 128 | True | True | True | 0.117 | 0.116 | 0.793 | 41.172 | 41.238 | -0.158 |
| 200 | 2 | 265 | True | True | True | 0.209 | 0.207 | 0.854 | 48.823 | 49.103 | -0.569 |
| 200 | 4 | 128 | True | True | True | 0.172 | 0.17 | 1.16 | 48.319 | 48.45 | -0.268 |
| 200 | 4 | 265 | True | True | True | 0.381 | 0.376 | 1.322 | 64.019 | 65.285 | -1.939 |

@OmarManzoor (Contributor, Author)

@amyeroberts Maybe the inference code I am using needs to be modified? Or maybe, since this model does not have a decoder, the gains are not significant?

@OmarManzoor (Contributor, Author)

Some more inference benchmarks with AlbertForSequenceClassification

| num_batches | batch_size | seq_len | is cuda | is half | use mask | Per token latency eager (ms) | Per token latency SDPA (ms) | Speedup (%) | Mem eager (MB) | Mem BT (MB) | Mem saved (%) |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 100 | 4 | 128 | True | True | True | 0.173 | 0.17 | 1.548 | 48.319 | 48.45 | -0.268 |
| 100 | 4 | 256 | True | True | True | 0.349 | 0.346 | 0.869 | 63.4 | 63.922 | -0.817 |
| 100 | 4 | 512 | True | True | True | 0.688 | 0.675 | 1.85 | 110.092 | 94.343 | 16.693 |
| 100 | 8 | 128 | True | True | True | 0.332 | 0.326 | 1.715 | 63.4 | 63.66 | -0.409 |
| 100 | 8 | 256 | True | True | True | 0.621 | 0.61 | 1.887 | 91.202 | 92.246 | -1.132 |
| 100 | 8 | 512 | True | True | True | 1.338 | 1.289 | 3.783 | 186.159 | 152.564 | 22.021 |
| 100 | 16 | 128 | True | True | True | 0.588 | 0.575 | 2.38 | 91.202 | 91.722 | -0.567 |
| 100 | 16 | 256 | True | True | True | 1.209 | 1.177 | 2.692 | 148.378 | 150.467 | -1.388 |
| 100 | 16 | 512 | True | True | True | 2.742 | 2.656 | 3.21 | 338.293 | 271.102 | 24.784 |

@amyeroberts (Collaborator)

@OmarManzoor, thanks for running some more numbers.

> @amyeroberts Maybe the inference code I am using needs to be modified? Or maybe, since this model does not have a decoder, the gains are not significant?

This shouldn't matter - BERT is encoder-only. Could you try running the script on BERT to see if you're able to replicate the same speedups reported in the docs?

@OmarManzoor (Contributor, Author) commented Jul 22, 2024

python benchmark_sdpa_inference.py --num-batches 100 --model-name bert-base-uncased --use-half --use-cuda --use-mask --sweep

AutoModelForSequenceClassification

| num_batches | batch_size | seq_len | is cuda | is half | use mask | Per token latency eager (ms) | Per token latency SDPA (ms) | Speedup (%) | Mem eager (MB) | Mem BT (MB) | Mem saved (%) |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 100 | 4 | 128 | True | True | True | 0.163 | 0.159 | 2.505 | 258.348 | 258.871 | -0.202 |
| 100 | 4 | 256 | True | True | True | 0.325 | 0.322 | 0.915 | 273.821 | 267.921 | 2.202 |
| 100 | 4 | 512 | True | True | True | 0.643 | 0.631 | 1.895 | 323.512 | 296.245 | 9.204 |
| 100 | 8 | 128 | True | True | True | 0.308 | 0.303 | 1.6 | 267.53 | 267.659 | -0.048 |
| 100 | 8 | 256 | True | True | True | 0.576 | 0.565 | 1.917 | 298.346 | 285.759 | 4.405 |
| 100 | 8 | 512 | True | True | True | 1.248 | 1.2 | 3.959 | 394.975 | 342.276 | 15.397 |
| 100 | 16 | 128 | True | True | True | 0.542 | 0.53 | 2.341 | 285.763 | 285.235 | 0.185 |
| 100 | 16 | 256 | True | True | True | 1.118 | 1.087 | 2.791 | 344.643 | 321.828 | 7.089 |
| 100 | 16 | 512 | True | True | True | 2.564 | 2.48 | 3.417 | 539.211 | 434.599 | 24.071 |

@amyeroberts (Collaborator)

@OmarManzoor OK, this indicates to me that there might be something wrong with the script or your setup, as you should be seeing the same speedup numbers as in the docs.

@OmarManzoor (Contributor, Author) commented Jul 23, 2024

@amyeroberts

Inference benchmarks using a GeForce RTX 2060 with 8 GB of memory.

RESULTS:

| num_batches | batch_size | seq_len | is cuda | is half | use mask | Per token latency eager (ms) | Per token latency SDPA (ms) | Speedup (%) | Mem eager (MB) | Mem BT (MB) | Mem saved (%) |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 50 | 4 | 128 | True | True | True | 0.083 | 0.071 | 16.967 | 48.319 | 48.45 | -0.268 |
| 50 | 4 | 256 | True | True | True | 0.148 | 0.127 | 16.37 | 63.4 | 63.922 | -0.817 |
| 50 | 4 | 512 | True | True | True | 0.31 | 0.247 | 25.473 | 110.092 | 94.343 | 16.693 |
| 50 | 8 | 128 | True | True | True | 0.137 | 0.124 | 11.102 | 63.4 | 63.66 | -0.409 |
| 50 | 8 | 256 | True | True | True | 0.271 | 0.231 | 17.271 | 91.202 | 92.246 | -1.132 |
| 50 | 8 | 512 | True | True | True | 0.602 | 0.48 | 25.47 | 186.159 | 152.564 | 22.021 |
| 50 | 16 | 128 | True | True | True | 0.252 | 0.224 | 12.506 | 91.202 | 91.722 | -0.567 |
| 50 | 16 | 256 | True | True | True | 0.526 | 0.448 | 17.604 | 148.378 | 150.467 | -1.388 |
| 50 | 16 | 512 | True | True | True | 1.203 | 0.96 | 25.365 | 338.293 | 271.102 | 24.784 |

So it seems we probably can't get an appropriate setup on Kaggle.

@OmarManzoor (Contributor, Author) commented Jul 23, 2024

For training with AutoModelForSequenceClassification:

| num_training_steps | batch_size | seq_len | is cuda | Time per batch (eager - s) | Time per batch (sdpa - s) | Speedup (%) | Eager peak mem (MB) | sdpa peak mem (MB) | Mem saving (%) |
|---|---|---|---|---|---|---|---|---|---|
| 100 | 2 | 256 | True | 0.028 | 0.024 | 14.388 | 358.411 | 321.088 | 11.624 |
| 100 | 2 | 512 | True | 0.049 | 0.041 | 17.681 | 753.458 | 602.660 | 25.022 |
| 100 | 4 | 256 | True | 0.044 | 0.039 | 12.246 | 679.534 | 602.660 | 12.756 |
| 100 | 4 | 512 | True | 0.090 | 0.076 | 18.472 | 1434.820 | 1134.140 | 26.512 |
| 100 | 8 | 256 | True | 0.081 | 0.072 | 12.664 | 1283.825 | 1134.140 | 13.198 |
| 100 | 8 | 512 | True | 0.170 | 0.143 | 18.957 | 2820.398 | 2219.695 | 27.062 |

@qubvel (Member) commented Jul 23, 2024

Hi @OmarManzoor, thanks for working on this!

I got the following results on your branch.

Env:

- `transformers` version: 4.43.0.dev0
- Platform: Linux-6.5.0-1020-aws-x86_64-with-glibc2.35
- Python version: 3.10.12
- PyTorch version (GPU?): 2.3.1+cu121 (True)
- GPU type: NVIDIA A10G

Inference speedup:

| num_batches | batch_size | seq_len | is cuda | is half | use mask | Per token latency eager (ms) | Per token latency SDPA (ms) | Speedup (%) | Mem eager (MB) | Mem BT (MB) | Mem saved (%) |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 1000 | 1 | 128 | True | True | True | 0.072 | 0.059 | 22.456 | 251.132 | 251.427 | -0.117 |
| 1000 | 1 | 265 | True | True | True | 0.065 | 0.057 | 13.205 | 254.735 | 253.394 | 0.529 |
| 1000 | 2 | 128 | True | True | True | 0.071 | 0.056 | 26.882 | 252.838 | 253.165 | -0.129 |
| 1000 | 2 | 265 | True | True | True | 0.069 | 0.058 | 19.653 | 260.96 | 258.016 | 1.141 |
| 1000 | 4 | 128 | True | True | True | 0.071 | 0.055 | 27.826 | 257.168 | 257.56 | -0.152 |
| 1000 | 4 | 265 | True | True | True | 0.084 | 0.066 | 25.907 | 273.766 | 267.26 | 2.434 |

@amyeroberts (Collaborator)

@OmarManzoor Thanks for iterating on this! Could you rebase on main to include the recent upstream changes? That should fix the code quality checks.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@amyeroberts (Collaborator)

Thanks for all your work on this @OmarManzoor!

The only thing remaining is a final slow model run (via a [run_slow] albert commit message), as there have been a few commits since the last one. Then we're good to merge!

@OmarManzoor (Contributor, Author)

@amyeroberts Can this be merged now?

@amyeroberts (Collaborator)

@OmarManzoor Yep! Thanks for all your work on this

@amyeroberts amyeroberts merged commit 03c12d0 into huggingface:main Sep 3, 2024
22 checks passed
@OmarManzoor OmarManzoor deleted the albert_sdpa branch September 3, 2024 14:05
itazap pushed a commit to NielsRogge/transformers that referenced this pull request Sep 20, 2024
* Add sdpa support for Albert

* [run_slow] albert

* Add benchmarks and PR suggestion

* Fix quality

* Fix

* [run_slow] albert