Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

initial support for macOS aarch64 #49

Merged
merged 9 commits into from
Mar 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions .github/workflows/cmake-darwin.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# CI workflow: configure, build and test the project on a macos-14 runner
# for pushes and pull requests that target the main branch.
name: macOS

on:
  push:
    branches: [ "main" ]
  pull_request:
    branches: [ "main" ]

env:
  # Build configuration handed to the configure, build and test steps.
  BUILD_TYPE: Release

jobs:
  build:
    runs-on: macos-14
    steps:
      # checkout@v4 replaces v3, which runs on the deprecated Node 16 runtime.
      - uses: actions/checkout@v4
      - name: Install openmp
        run: brew install libomp
      - name: Configure CMake
        # OpenMP_ROOT points CMake at the Homebrew libomp prefix installed above.
        run: OpenMP_ROOT=$(brew --prefix)/opt/libomp cmake -B ${{github.workspace}}/build -DCMAKE_BUILD_TYPE=${{env.BUILD_TYPE}}
      - name: Build
        run: cmake --build ${{github.workspace}}/build --config ${{env.BUILD_TYPE}}
      - name: Test
        run: ctest --verbose -C ${{env.BUILD_TYPE}} --test-dir ${{github.workspace}}/build/src/libllm
2 changes: 1 addition & 1 deletion .github/workflows/cmake-windows.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,4 @@ jobs:
- name: Build
run: cmake --build ${{github.workspace}}/build --config ${{env.BUILD_TYPE}}
- name: Test
run: ctest --verbose -C ${{env.BUILD_TYPE}} --test-dir ${{github.workspace}}/build/src/libllm
run: ${{github.workspace}}\build\src\libllm\${{env.BUILD_TYPE}}\unittest.exe
18 changes: 14 additions & 4 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,22 @@ set(MKL_PREFIX "/opt/intel/mkl" CACHE STRING "Prefix for MKL headers and librari
#add_link_options(-fsanitize=address)

if(WIN32)
add_definitions( "/D_CRT_SECURE_NO_WARNINGS /DCATCH_AMALGAMATED_CUSTOM_MAIN /DCATCH_CONFIG_PREFIX_ALL" )
add_definitions( "/D_CRT_SECURE_NO_WARNINGS /DCATCH_AMALGAMATED_CUSTOM_MAIN /DCATCH_CONFIG_PREFIX_ALL" )
endif(WIN32)
if(UNIX)
add_definitions( "-DCATCH_AMALGAMATED_CUSTOM_MAIN -DCATCH_CONFIG_PREFIX_ALL" )
set(CMAKE_CXX_FLAGS "-O3 -g")
set(CMAKE_C_FLAGS "-O3 -g")
add_definitions( "-DCATCH_AMALGAMATED_CUSTOM_MAIN -DCATCH_CONFIG_PREFIX_ALL -D_FILE_OFFSET_BITS=64" )
set(CMAKE_CXX_FLAGS "-O3 -g")
set(CMAKE_C_FLAGS "-O3 -g")
endif(UNIX)

# Select the architecture-specific code paths from the processor the build
# TARGETS. CMAKE_SYSTEM_PROCESSOR equals the host processor in a native build
# but follows the toolchain file when cross-compiling, whereas
# CMAKE_HOST_SYSTEM_PROCESSOR always describes the build machine.
message(STATUS "CMAKE_SYSTEM_PROCESSOR=${CMAKE_SYSTEM_PROCESSOR}")

if("${CMAKE_SYSTEM_PROCESSOR}" MATCHES "(x86)|(X86)|(amd64)|(AMD64)")
  # x86 family: expose the arch macro to the sources and set the switch that
  # src/libllm/CMakeLists.txt checks before adding the AVX2/AVX512 kernels.
  add_compile_definitions(LIBLLM_ARCH_X86_64)
  set(LIBLLM_KERNEL_X86_64 ON)
elseif("${CMAKE_SYSTEM_PROCESSOR}" MATCHES "(aarch64)|(arm64)")
  add_compile_definitions(LIBLLM_ARCH_AARCH64)
else()
  # Neither LIBLLM_ARCH_* macro is defined in this case; surface it at
  # configure time instead of leaving it to be discovered later.
  message(WARNING
    "Unrecognized CMAKE_SYSTEM_PROCESSOR '${CMAKE_SYSTEM_PROCESSOR}': "
    "no LIBLLM_ARCH_* definition was added.")
endif()

add_subdirectory("src/libllm")
70 changes: 44 additions & 26 deletions src/libllm/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@ set(lut_SOURCES
"lut/zip_file.cc")

set(libllm_SOURCES
"cpu/kernel/kernel_avx2.cc"
"cpu/kernel/kernel_avx512.cc"
"cpu/kernel/kernel_fallback.cc"
"cpu/kernel/kernel.cc"
"cpu/kernel/util.cc"
Expand Down Expand Up @@ -80,7 +78,7 @@ set(llm_SOURCES
"dialog_manager.cc"
"llm_main.cc")

set(libllm_INCDIR ".." "../../third_party")
set(libllm_INCDIR ".." "../../third_party" ${OpenMP_CXX_INCLUDE_DIRS})

if (WITH_CUDA)
set(libllm_INCDIR ${libllm_INCDIR} ${CUDAToolkit_INCLUDE_DIRS})
Expand Down Expand Up @@ -118,38 +116,58 @@ if (WITH_CUDA)
"lut/internal/log.cc")
endif()

# OS specific code
if(WIN32)
set_source_files_properties(
"cpu/kernel/kernel_avx512.cc"
PROPERTIES COMPILE_FLAGS /arch:AVX512)
set_source_files_properties(
"cpu/kernel/kernel_avx2.cc"
PROPERTIES COMPILE_FLAGS /arch:AVX2)
set(libllm_SOURCES
${libllm_SOURCES}
"lut/path_windows.cc"
"lut/platform_windows.cc"
"lut/shared_library_windows.cc")
endif(WIN32)

endif()
if(UNIX)
set_source_files_properties(
"cpu/kernel/kernel_avx512.cc"
PROPERTIES COMPILE_FLAGS "-mavx512f")
set_source_files_properties(
"cpu/kernel/kernel_avx2.cc"
PROPERTIES COMPILE_FLAGS "-mavx2 -mfma -mf16c")
set(libllm_SOURCES
${libllm_SOURCES}
"lut/path_linux.cc"
"lut/platform_linux.cc"
"lut/shared_library_linux.cc")
endif(UNIX)
endif()
if(UNIX AND APPLE)
set(libllm_SOURCES
${libllm_SOURCES}
"lut/path_darwin.cc")
endif()
if(UNIX AND NOT APPLE)
set(libllm_SOURCES
${libllm_SOURCES}
"lut/path_linux.cc")
endif()

# CPU specific code
# The AVX2/AVX512 kernel sources are added only when LIBLLM_KERNEL_X86_64 is
# set. Each of the two files receives its own ISA flags through the per-source
# COMPILE_FLAGS property.
if(LIBLLM_KERNEL_X86_64)
  list(APPEND libllm_SOURCES
    "cpu/kernel/kernel_avx2.cc"
    "cpu/kernel/kernel_avx512.cc")

  # Choose the flag spelling from the compiler front end rather than the OS:
  # /arch:* is only understood by MSVC-style drivers, so GCC or Clang on
  # Windows (e.g. MinGW) must receive the -m* form instead.
  if(MSVC)
    set_source_files_properties(
      "cpu/kernel/kernel_avx512.cc"
      PROPERTIES COMPILE_FLAGS "/arch:AVX512")
    set_source_files_properties(
      "cpu/kernel/kernel_avx2.cc"
      PROPERTIES COMPILE_FLAGS "/arch:AVX2")
  else()
    set_source_files_properties(
      "cpu/kernel/kernel_avx512.cc"
      PROPERTIES COMPILE_FLAGS "-mavx512f")
    set_source_files_properties(
      "cpu/kernel/kernel_avx2.cc"
      PROPERTIES COMPILE_FLAGS "-mavx2 -mfma -mf16c")
  endif()
endif()

add_library(lut STATIC ${lut_SOURCES})
set_target_properties(lut PROPERTIES CXX_VISIBILITY_PRESET hidden)
target_include_directories(lut PRIVATE ".." "../../third_party/")

target_include_directories(lut PRIVATE ${libllm_INCDIR})

set(libllm_LIBADD
lut
Expand All @@ -165,22 +183,22 @@ target_include_directories(libllm_static PRIVATE ${libllm_INCDIR})
add_library(libllm SHARED $<TARGET_OBJECTS:libllm_static>)
target_link_libraries(libllm ${libllm_LIBADD} )
set_property(TARGET libllm PROPERTY OUTPUT_NAME llm)
if(UNIX)
if(UNIX AND NOT APPLE)
target_link_options(libllm PUBLIC "-Wl,--no-undefined")
endif(UNIX)
endif()

add_library(catch2 STATIC "../../third_party/catch2/catch_amalgamated.cpp")
add_executable(unittest ${unittest_SOURCES})
target_include_directories(unittest PRIVATE .. "../../third_party/")
target_link_libraries(unittest libllm_static lut catch2)
target_include_directories(unittest PRIVATE ${libllm_INCDIR})
target_link_libraries(unittest libllm_static lut catch2 OpenMP::OpenMP_CXX)

add_executable(llm ${llm_SOURCES})
target_include_directories(llm PRIVATE ..)
target_link_libraries(llm libllm lut)

if (WITH_CUDA)
add_library(llmextcublas SHARED ${llmextcublas_SOURCES})
target_include_directories(llmextcublas PRIVATE .. "../../third_party/")
target_include_directories(llmextcublas PRIVATE ${libllm_INCDIR})
target_link_libraries(llmextcublas lut CUDA::cublas)
if(UNIX)
target_link_options(llmextcublas PUBLIC "-Wl,--no-undefined")
Expand Down
24 changes: 19 additions & 5 deletions src/libllm/cpu/kernel/kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ namespace kernel {
enum class CPUMathBackend {
DEFAULT,
AVX2,
AVX512
AVX512,
ASIMDHP
};

CPUMathBackend findBestCpuMathBackend() {
Expand All @@ -49,19 +50,30 @@ CPUMathBackend findBestCpuMathBackend() {
bool isaAvx512f = ruapu_supports("avx512f") > 0;
bool isaF16c = ruapu_supports("f16c") > 0;

#ifdef LIBLLM_ARCH_X86_64
LOG(INFO) << lut::sprintf(
"ISA support: AVX2=%d F16C=%d AVX512F=%d", isaAvx2, isaF16c, isaAvx512f);
#endif // LIBLLM_ARCH_X86_64

#ifdef LIBLLM_ARCH_X86_64
if (isaAvx512f && isaF16c) {
LOG(INFO) << "Use Avx512 backend.";
return CPUMathBackend::AVX512;
} else if (isaAvx2 && isaF16c) {
}

if (isaAvx2 && isaF16c) {
LOG(INFO) << "Use Avx2 backend.";
return CPUMathBackend::AVX2;
} else {
LOG(FATAL) << "CPU not supported (AVX2 and F16C is required).";
NOT_IMPL();
}
#endif // LIBLLM_ARCH_X86_64

#ifdef LIBLLM_ARCH_AARCH64
LOG(INFO) << "Use default backend.";
return CPUMathBackend::DEFAULT;
#endif // LIBLLM_ARCH_AARCH64

LOG(FATAL) << "CPU not supported.";
NOT_IMPL();
}

// instance of Api.
Expand Down Expand Up @@ -106,6 +118,7 @@ void Api::init() {

_instance = new Api();
switch (findBestCpuMathBackend()) {
#ifdef LIBLLM_ARCH_X86_64
case CPUMathBackend::AVX512:
_instance->_sgemm = std::make_unique<SGEMMImplAvx512>();
_instance->_sgemmOmp = std::make_unique<SGEMMImplAvx512OMP>();
Expand All @@ -120,6 +133,7 @@ void Api::init() {
_instance->_q4dequant = std::make_unique<DequantQ4Avx2OMP>();
_instance->_cvtHalfToFloat = std::make_unique<CvtHalfToFloatAvx2OMP>();
break;
#endif // LIBLLM_ARCH_X86_64
case CPUMathBackend::DEFAULT:
_instance->_sgemm = std::make_unique<SGEMMImplDefault>();
_instance->_sgemmOmp = std::make_unique<SGEMMImplDefaultOMP>();
Expand Down
19 changes: 18 additions & 1 deletion src/libllm/cpu/kernel/kernel_fallback.cc
Original file line number Diff line number Diff line change
Expand Up @@ -105,12 +105,29 @@ void SAxpyFallbackKernel::apply(int64_t n, float a, PCFp32 x, PFp32 y) {
const float *px = x;
float *py = y;
for (int i = 0; i < n; ++i) {
*py = a * *px;
*py += a * *px;
++px;
++py;
}
}

// Forwards to apply() for a single column index: the scalar is args.x[column]
// and the input vector starts at args.A + column * args.lda; y is passed on
// unchanged as the output buffer.
void SAxpyFallbackKernel::applyColumn(const SGEMVArgs &args, int column, float *y) {
  const auto columnData = args.A + column * args.lda;
  const auto scale = args.x[column];
  apply(args.N, scale, columnData, y);
}

// Scalar dot product: returns the sum of x[i] * y[i] for i in [0, n),
// accumulated in a single float in ascending index order. Returns 0 when
// n is not positive.
float SDotFallbackKernel::apply(int64_t n, const float *x, const float *y) {
  float acc = 0;
  const float *px = x;
  const float *py = y;
  for (int64_t remaining = n; remaining > 0; --remaining) {
    acc += *px * *py;
    ++px;
    ++py;
  }

  return acc;
}

// Dot product for a single row index: multiplies the args.N elements starting
// at args.A + row * args.lda with args.x by delegating to apply().
float SDotFallbackKernel::applyRow(const SGEMVArgs &args, int row) {
  const auto rowData = args.A + row * args.lda;
  return apply(args.N, rowData, args.x);
}

void CvtHalfToFloatFallbackKernel::apply(int64_t n, PCFp16 x, PFp32 y) {
for (int i = 0; i < n; ++i) {
y[i] = lut::cvtsh_ss(x[i]);
Expand Down
4 changes: 2 additions & 2 deletions src/libllm/cpu/kernel/sgemv.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,10 @@ class SGEMVImpl : public SGEMV {

typedef SGEMVImpl<SAxpyAvx2Kernel, SDotAvx2Kernel, Mode::SingleThread> SGEMVImplAvx512;
typedef SGEMVImpl<SAxpyAvx2Kernel, SDotAvx2Kernel, Mode::SingleThread> SGEMVImplAvx2;
typedef SGEMVImpl<SAxpyAvx2Kernel, SDotAvx2Kernel, Mode::SingleThread> SGEMVImplDefault;
typedef SGEMVImpl<SAxpyFallbackKernel, SDotFallbackKernel, Mode::SingleThread> SGEMVImplDefault;
typedef SGEMVImpl<SAxpyAvx2Kernel, SDotAvx2Kernel, Mode::OMP> SGEMVImplAvx512OMP;
typedef SGEMVImpl<SAxpyAvx2Kernel, SDotAvx2Kernel, Mode::OMP> SGEMVImplAvx2OMP;
typedef SGEMVImpl<SAxpyAvx2Kernel, SDotAvx2Kernel, Mode::OMP> SGEMVImplDefaultOMP;
typedef SGEMVImpl<SAxpyFallbackKernel, SDotFallbackKernel, Mode::OMP> SGEMVImplDefaultOMP;

} // namespace kernel
} // namespace cpu
Expand Down
7 changes: 7 additions & 0 deletions src/libllm/cpu/kernel/skernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,13 @@ struct SDotAvx2Kernel {
static float applyRow(const SGEMVArgs &args, int row);
};

// Fallback (non-AVX2) counterpart of SDotAvx2Kernel, exposing the same two
// static entry points.
struct SDotFallbackKernel {
  // Element type this kernel operates on.
  using ValueType = float;

  // Dot product over n elements of x and y.
  static float apply(int64_t n, const float *x, const float *y);

  // Dot product for the row selected by `row` in the SGEMV arguments.
  static float applyRow(const SGEMVArgs &args, int row);
};

} // namespace kernel
} // namespace cpu
} // namespace op
Expand Down
6 changes: 5 additions & 1 deletion src/libllm/cpu/kernel/test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,9 @@ void refSgemm(
}
}


#ifdef LIBLLM_ARCH_X86_64

CATCH_TEST_CASE("test q4 dequantization", "[lymath][dequant][q4]") {
constexpr int DIM = DequantMinElemPerThread * 2 + GroupSizeQ4;

Expand Down Expand Up @@ -269,6 +272,7 @@ CATCH_TEST_CASE("test q4 dot kernels apply row", "[lymath][dot][q4]") {
float a = DotQ4Avx2Kernel::apply(NUM_COL * 2, x2.data(), {A.data(), scaleA.data(), zeroA.data()}, 0);
CATCH_REQUIRE(isClose(a, a0 + a1));
}
#endif // LIBLLM_ARCH_X86_64

CATCH_TEST_CASE("test lymath_q4gemm", "[lymath][api][q4]") {
testGemmQ4(true, 1, 32, 128);
Expand Down Expand Up @@ -359,7 +363,7 @@ void testHalfToFloat(int n) {
random.fill(lut::makeSpan(yr));
std::transform(yr.begin(), yr.end(), x.begin(), lut::cvtss_sh);

CvtHalfToFloatAvx2OMP().apply(n, x.data(), y.data());
convertHalfToFloat(n, x.data(), y.data());
CATCH_REQUIRE(isClose(yr, y, 1e-4, 1e-3));
}

Expand Down
8 changes: 0 additions & 8 deletions src/libllm/dtype.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,14 +70,6 @@ DType DType::getTypeImpl<half>() {
#endif


template DType DType::getTypeImpl<float>();
template DType DType::getTypeImpl<int64_t>();
template DType DType::getTypeImpl<UInt8>();
template DType DType::getTypeImpl<Float16>();
template DType DType::getTypeImpl<Q4>();
template DType DType::getTypeImpl<Int8>();


int64_t DType::getTotalSize(int64_t numel) const {
switch (_dtype) {
case DType::kFloat:
Expand Down
Loading
Loading