Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve cuBLAS performance by dequantizing on the GPU #1065

Merged
merged 4 commits into from
Apr 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 25 additions & 6 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ if (APPLE AND LLAMA_ACCELERATE)
message(WARNING "Accelerate framework not found")
endif()
endif()

if (LLAMA_OPENBLAS)
if (LLAMA_STATIC)
set(BLA_STATIC ON)
Expand Down Expand Up @@ -150,6 +151,10 @@ if (LLAMA_CUBLAS)
if (CUDAToolkit_FOUND)
message(STATUS "cuBLAS found")

enable_language(CUDA)

set(GGML_CUDA_SOURCES ggml-cuda.cu ggml-cuda.h)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there was a discussion somewhere recently about splitting out the accel specific code into dedicated .c files. what was the state on that?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried to keep all the cuda code in ggml-cuda.cu to avoid having to compile ggml with nvcc, but otherwise nothing changed.


add_compile_definitions(GGML_USE_CUBLAS)

if (LLAMA_STATIC)
Expand Down Expand Up @@ -242,21 +247,26 @@ elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "^(x86_64|i686|AMD64)$")
message(STATUS "x86 detected")
if (MSVC)
if (LLAMA_AVX512)
add_compile_options(/arch:AVX512)
add_compile_options($<$<COMPILE_LANGUAGE:C>:/arch:AVX512>)
add_compile_options($<$<COMPILE_LANGUAGE:CXX>:/arch:AVX512>)
# MSVC has no compile-time flags enabling specific
# AVX512 extensions, neither it defines the
# macros corresponding to the extensions.
# Do it manually.
if (LLAMA_AVX512_VBMI)
add_compile_definitions(__AVX512VBMI__)
add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AVX512VBMI__>)
add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AVX512VBMI__>)
endif()
if (LLAMA_AVX512_VNNI)
add_compile_definitions(__AVX512VNNI__)
add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AVX512VNNI__>)
add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AVX512VNNI__>)
endif()
elseif (LLAMA_AVX2)
add_compile_options(/arch:AVX2)
add_compile_options($<$<COMPILE_LANGUAGE:C>:/arch:AVX2>)
add_compile_options($<$<COMPILE_LANGUAGE:CXX>:/arch:AVX2>)
elseif (LLAMA_AVX)
add_compile_options(/arch:AVX)
add_compile_options($<$<COMPILE_LANGUAGE:C>:/arch:AVX>)
add_compile_options($<$<COMPILE_LANGUAGE:CXX>:/arch:AVX>)
endif()
else()
if (LLAMA_F16C)
Expand Down Expand Up @@ -293,7 +303,8 @@ endif()

add_library(ggml OBJECT
ggml.c
ggml.h)
ggml.h
${GGML_CUDA_SOURCES})

target_include_directories(ggml PUBLIC .)
target_compile_features(ggml PUBLIC c_std_11) # don't bump
Expand All @@ -315,6 +326,14 @@ if (BUILD_SHARED_LIBS)
target_compile_definitions(llama PRIVATE LLAMA_SHARED LLAMA_BUILD)
endif()

if (GGML_CUDA_SOURCES)
message(STATUS "GGML CUDA sources found, configuring CUDA architecture")
set_property(TARGET ggml PROPERTY CUDA_ARCHITECTURES OFF)
set_property(TARGET ggml PROPERTY CUDA_SELECT_NVCC_ARCH_FLAGS "Auto")
set_property(TARGET llama PROPERTY CUDA_ARCHITECTURES OFF)
endif()


#
# programs, examples and tests
#
Expand Down
24 changes: 14 additions & 10 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# Define the default target now so that it is always the first target
default: main quantize quantize-stats perplexity embedding vdot

ifndef UNAME_S
UNAME_S := $(shell uname -s)
endif
Expand Down Expand Up @@ -100,6 +103,9 @@ endif
ifdef LLAMA_CUBLAS
CFLAGS += -DGGML_USE_CUBLAS -I/usr/local/cuda/include
LDFLAGS += -lcublas_static -lculibos -lcudart_static -lcublasLt_static -lpthread -ldl -L/usr/local/cuda/lib64
OBJS += ggml-cuda.o
ggml-cuda.o: ggml-cuda.cu ggml-cuda.h
nvcc -arch=native -c -o $@ $<
endif
ifdef LLAMA_GPROF
CFLAGS += -pg
Expand Down Expand Up @@ -137,8 +143,6 @@ $(info I CC: $(CCV))
$(info I CXX: $(CXXV))
$(info )

default: main quantize quantize-stats perplexity embedding vdot

#
# Build library
#
Expand All @@ -155,35 +159,35 @@ common.o: examples/common.cpp examples/common.h
clean:
rm -vf *.o main quantize quantize-stats perplexity embedding benchmark-q4_0-matmult

main: examples/main/main.cpp ggml.o llama.o common.o
main: examples/main/main.cpp ggml.o llama.o common.o $(OBJS)
$(CXX) $(CXXFLAGS) $^ -o $@ $(LDFLAGS)
@echo
@echo '==== Run ./main -h for help. ===='
@echo

quantize: examples/quantize/quantize.cpp ggml.o llama.o
quantize: examples/quantize/quantize.cpp ggml.o llama.o $(OBJS)
$(CXX) $(CXXFLAGS) $^ -o $@ $(LDFLAGS)

quantize-stats: examples/quantize-stats/quantize-stats.cpp ggml.o llama.o
quantize-stats: examples/quantize-stats/quantize-stats.cpp ggml.o llama.o $(OBJS)
$(CXX) $(CXXFLAGS) $^ -o $@ $(LDFLAGS)

perplexity: examples/perplexity/perplexity.cpp ggml.o llama.o common.o
perplexity: examples/perplexity/perplexity.cpp ggml.o llama.o common.o $(OBJS)
$(CXX) $(CXXFLAGS) $^ -o $@ $(LDFLAGS)

embedding: examples/embedding/embedding.cpp ggml.o llama.o common.o
embedding: examples/embedding/embedding.cpp ggml.o llama.o common.o $(OBJS)
$(CXX) $(CXXFLAGS) $^ -o $@ $(LDFLAGS)

vdot: pocs/vdot/vdot.cpp ggml.o
vdot: pocs/vdot/vdot.cpp ggml.o $(OBJS)
$(CXX) $(CXXFLAGS) $^ -o $@ $(LDFLAGS)

libllama.so: llama.o ggml.o
libllama.so: llama.o ggml.o $(OBJS)
$(CXX) $(CXXFLAGS) -shared -fPIC -o $@ $^ $(LDFLAGS)

#
# Tests
#

benchmark: examples/benchmark/benchmark-q4_0-matmult.c ggml.o
benchmark: examples/benchmark/benchmark-q4_0-matmult.c ggml.o $(OBJS)
$(CXX) $(CXXFLAGS) $^ -o benchmark-q4_0-matmult $(LDFLAGS)
./benchmark-q4_0-matmult

Expand Down
116 changes: 116 additions & 0 deletions ggml-cuda.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
#include <stdint.h>
#include <cuda_fp16.h>
#include "ggml-cuda.h"

typedef uint16_t ggml_fp16_t;
static_assert(sizeof(__half) == sizeof(ggml_fp16_t), "wrong fp16 size");

#define QK4_0 32
typedef struct {
float d; // delta
uint8_t qs[QK4_0 / 2]; // nibbles / quants
} block_q4_0;
static_assert(sizeof(block_q4_0) == sizeof(float) + QK4_0 / 2, "wrong q4_0 block size/padding");

#define QK4_1 32
typedef struct {
float d; // delta
float m; // min
uint8_t qs[QK4_1 / 2]; // nibbles / quants
} block_q4_1;
static_assert(sizeof(block_q4_1) == sizeof(float) * 2 + QK4_1 / 2, "wrong q4_1 block size/padding");

#define QK4_2 16
typedef struct {
__half d; // delta
uint8_t qs[QK4_2 / 2]; // nibbles / quants
} block_q4_2;
static_assert(sizeof(block_q4_2) == sizeof(ggml_fp16_t) + QK4_2 / 2, "wrong q4_2 block size/padding");


static __global__ void dequantize_block_q4_0(const void * vx, float * y) {
const block_q4_0 * x = (const block_q4_0 *) vx;

const int i = blockIdx.x;

const float d = x[i].d;

const uint8_t * pp = x[i].qs;

for (int l = 0; l < QK4_0; l += 2) {
const uint8_t vi = pp[l/2];

const int8_t vi0 = vi & 0xf;
const int8_t vi1 = vi >> 4;

const float v0 = (vi0 - 8)*d;
const float v1 = (vi1 - 8)*d;

y[i*QK4_0 + l + 0] = v0;
y[i*QK4_0 + l + 1] = v1;
}
}

static __global__ void dequantize_block_q4_1(const void * vx, float * y) {
const block_q4_1 * x = (const block_q4_1 *) vx;

const int i = blockIdx.x;

const float d = x[i].d;
const float m = x[i].m;

const uint8_t * pp = x[i].qs;

for (int l = 0; l < QK4_1; l += 2) {
const uint8_t vi = pp[l/2];

const int8_t vi0 = vi & 0xf;
const int8_t vi1 = vi >> 4;

const float v0 = vi0*d + m;
const float v1 = vi1*d + m;

y[i*QK4_1 + l + 0] = v0;
y[i*QK4_1 + l + 1] = v1;
}
}

static __global__ void dequantize_block_q4_2(const void * vx, float * y) {
const block_q4_2 * x = (const block_q4_2 *) vx;

const int i = blockIdx.x;

const float d = x[i].d;

const uint8_t * pp = x[i].qs;

for (int l = 0; l < QK4_2; l += 2) {
const uint8_t vi = pp[l/2];

const int8_t vi0 = vi & 0xf;
const int8_t vi1 = vi >> 4;

const float v0 = (vi0 - 8)*d;
const float v1 = (vi1 - 8)*d;

y[i*QK4_2 + l + 0] = v0;
y[i*QK4_2 + l + 1] = v1;
}
}

extern "C" {
__host__ void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
const int nb = k / QK4_0;
dequantize_block_q4_0<<<nb, 1, 0, stream>>>(vx, y);
}

__host__ void dequantize_row_q4_1_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
const int nb = k / QK4_1;
dequantize_block_q4_1<<<nb, 1, 0, stream>>>(vx, y);
}

__host__ void dequantize_row_q4_2_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
const int nb = k / QK4_2;
dequantize_block_q4_2<<<nb, 1, 0, stream>>>(vx, y);
}
}
11 changes: 11 additions & 0 deletions ggml-cuda.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
#ifdef __cplusplus
extern "C" {
#endif

void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream);
void dequantize_row_q4_1_cuda(const void * vx, float * y, int k, cudaStream_t stream);
void dequantize_row_q4_2_cuda(const void * vx, float * y, int k, cudaStream_t stream);

#ifdef __cplusplus
}
#endif
Loading