diff --git a/README.md b/README.md index 1088b3338ff17..8fe1f4b4b6a7a 100644 --- a/README.md +++ b/README.md @@ -93,6 +93,7 @@ Typically finetunes of the base models below are supported as well. - [x] [FalconMamba Models](https://huggingface.co/collections/tiiuae/falconmamba-7b-66b9a580324dd1598b0f6d4a) - [x] [Jais](https://huggingface.co/inceptionai/jais-13b-chat) - [x] [Bielik-11B-v2.3](https://huggingface.co/collections/speakleash/bielik-11b-v23-66ee813238d9b526a072408a) +- [x] [RWKV-6](https://github.com/BlinkDL/RWKV-LM) (instructions for supporting more models: [HOWTO-add-model.md](./docs/development/HOWTO-add-model.md)) @@ -122,6 +123,7 @@ Typically finetunes of the base models below are supported as well. - Rust (nicer API): [mdrokz/rust-llama.cpp](https://github.com/mdrokz/rust-llama.cpp) - Rust (more direct bindings): [utilityai/llama-cpp-rs](https://github.com/utilityai/llama-cpp-rs) - C#/.NET: [SciSharp/LLamaSharp](https://github.com/SciSharp/LLamaSharp) +- C#/VB.NET (more features - community license): [LM-Kit.NET](https://docs.lm-kit.com/lm-kit-net/index.html) - Scala 3: [donderom/llm4s](https://github.com/donderom/llm4s) - Clojure: [phronmophobic/llama.clj](https://github.com/phronmophobic/llama.clj) - React Native: [mybigday/llama.rn](https://github.com/mybigday/llama.rn) @@ -172,6 +174,7 @@ Unless otherwise noted these projects are open-source with permissive licensing: - [LARS - The LLM & Advanced Referencing Solution](https://github.com/abgulati/LARS) (AGPL) - [LLMUnity](https://github.com/undreamai/LLMUnity) (MIT) - [Llama Assistant](https://github.com/vietanhdev/llama-assistant) (GPL) +- [PocketPal AI - An iOS and Android App](https://github.com/a-ghorbani/pocketpal-ai) (MIT) *(to have a project listed here, it should clearly state that it depends on `llama.cpp`)* @@ -187,6 +190,7 @@ Unless otherwise noted these projects are open-source with permissive licensing: - [Paddler](https://github.com/distantmagic/paddler) - Stateful load balancer custom-tailored for llama.cpp - [GPUStack](https://github.com/gpustack/gpustack) - Manage GPU clusters for running LLMs +- [llama_cpp_canister](https://github.com/onicai/llama_cpp_canister) - llama.cpp as a smart contract on the Internet Computer, using WebAssembly **Games:** - [Lucy's Labyrinth](https://github.com/MorganRO8/Lucys_Labyrinth) - A simple maze game where agents controlled by an AI model will try to trick you. diff --git a/ci/run.sh b/ci/run.sh index e067782193b9b..dc26d94eed1fd 100755 --- a/ci/run.sh +++ b/ci/run.sh @@ -53,7 +53,7 @@ if [ ! -z ${GG_BUILD_SYCL} ]; then exit 1 fi - CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_SYCL=1 DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx -DGGML_SYCL_F16=ON" + CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_SYCL=1 -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx -DGGML_SYCL_F16=ON" fi if [ ! -z ${GG_BUILD_VULKAN} ]; then diff --git a/common/arg.cpp b/common/arg.cpp index d6a8e1f6ff0bf..e1e933934f0ef 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -128,13 +128,13 @@ static void common_params_handle_model_default(common_params & params) { } params.hf_file = params.model; } else if (params.model.empty()) { - params.model = fs_get_cache_file(string_split(params.hf_file, '/').back()); + params.model = fs_get_cache_file(string_split(params.hf_file, '/').back()); } } else if (!params.model_url.empty()) { if (params.model.empty()) { - auto f = string_split(params.model_url, '#').front(); - f = string_split(f, '?').front(); - params.model = fs_get_cache_file(string_split(f, '/').back()); + auto f = string_split(params.model_url, '#').front(); + f = string_split(f, '?').front(); + params.model = fs_get_cache_file(string_split(f, '/').back()); } } else if (params.model.empty()) { params.model = DEFAULT_MODEL_PATH; @@ -251,6 +251,9 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context for (auto & antiprompt : params.antiprompt) { string_process_escapes(antiprompt); } + for (auto & seq_breaker : params.sparams.dry_sequence_breakers) { + string_process_escapes(seq_breaker); + } } if (!params.kv_overrides.empty()) { @@ -879,7 +882,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex {"--samplers"}, "SAMPLERS", string_format("samplers that will be used for generation in the order, separated by \';\'\n(default: %s)", sampler_type_names.c_str()), [](common_params & params, const std::string & value) { - const auto sampler_names = string_split(value, ';'); + const auto sampler_names = string_split(value, ';'); params.sparams.samplers = common_sampler_types_from_names(sampler_names, true); } ).set_sparam()); @@ -997,6 +1000,64 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.sparams.penalty_freq = std::stof(value); } ).set_sparam()); + add_opt(common_arg( + {"--dry-multiplier"}, "N", + string_format("set DRY sampling multiplier (default: %.1f, 0.0 = disabled)", (double)params.sparams.dry_multiplier), + [](common_params & params, const std::string & value) { + params.sparams.dry_multiplier = std::stof(value); + } + ).set_sparam()); + add_opt(common_arg( + {"--dry-base"}, "N", + string_format("set DRY sampling base value (default: %.2f)", (double)params.sparams.dry_base), + [](common_params & params, const std::string & value) { + float potential_base = std::stof(value); + if (potential_base >= 1.0f) + { + params.sparams.dry_base = potential_base; + } + } + ).set_sparam()); + add_opt(common_arg( + {"--dry-allowed-length"}, "N", + string_format("set allowed length for DRY sampling (default: %d)", params.sparams.dry_allowed_length), + [](common_params & params, int value) { + params.sparams.dry_allowed_length = value; + } + ).set_sparam()); + add_opt(common_arg( + {"--dry-penalty-last-n"}, "N", + string_format("set DRY penalty for the last n tokens (default: %d, 0 = disable, -1 = context size)", params.sparams.dry_penalty_last_n), + [](common_params & params, int value) { + params.sparams.dry_penalty_last_n = value; + } + ).set_sparam()); + add_opt(common_arg( + {"--dry-sequence-breaker"}, "STRING", + string_format("add sequence breaker for DRY sampling, clearing out default breakers (%s) in the process; use \"none\" to not use any sequence breakers\n", + params.sparams.dry_sequence_breakers.empty() ? "none" : + std::accumulate(std::next(params.sparams.dry_sequence_breakers.begin()), + params.sparams.dry_sequence_breakers.end(), + std::string("'") + (params.sparams.dry_sequence_breakers[0] == "\n" ? "\\n" : params.sparams.dry_sequence_breakers[0]) + "'", + [](const std::string& a, const std::string& b) { + std::string formatted_b = (b == "\n") ? "\\n" : b; + return a + ", '" + formatted_b + "'"; + }).c_str()), + [](common_params & params, const std::string & value) { + static bool defaults_cleared = false; + + if (!defaults_cleared) { + params.sparams.dry_sequence_breakers.clear(); + defaults_cleared = true; + } + + if (value == "none") { + params.sparams.dry_sequence_breakers.clear(); + } else { + params.sparams.dry_sequence_breakers.emplace_back(value); + } + } + ).set_sparam()); add_opt(common_arg( {"--dynatemp-range"}, "N", string_format("dynamic temperature range (default: %.1f, 0.0 = disabled)", (double)params.sparams.dynatemp_range), @@ -1097,7 +1158,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex } ).set_examples({LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_RETRIEVAL, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_POOLING")); add_opt(common_arg( - {"--attention"}, "{causal,non,causal}", + {"--attention"}, "{causal,non-causal}", "attention type for embeddings, use model default if unspecified", [](common_params & params, const std::string & value) { /**/ if (value == "causal") { params.attention_type = LLAMA_ATTENTION_TYPE_CAUSAL; } @@ -1695,7 +1756,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex ).set_examples({LLAMA_EXAMPLE_BENCH})); add_opt(common_arg( {"--embd-normalize"}, "N", - string_format("normalisation for embendings (default: %d) (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm)", params.embd_normalize), + string_format("normalisation for embeddings (default: %d) (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm)", params.embd_normalize), [](common_params & params, int value) { params.embd_normalize = value; } @@ -1709,7 +1770,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex ).set_examples({LLAMA_EXAMPLE_EMBEDDING})); add_opt(common_arg( {"--embd-separator"}, "STRING", - "separator of embendings (default \\n) for example \"<#sep#>\"", + "separator of embeddings (default \\n) for example \"<#sep#>\"", [](common_params & params, const std::string & value) { params.embd_sep = value; } diff --git a/common/common.cpp b/common/common.cpp index 2bc0b8800a939..ff8cc4076e95d 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -416,19 +416,6 @@ std::string string_format(const char * fmt, ...) { return std::string(buf.data(), size); } -std::vector string_split(std::string input, char separator) { - std::vector parts; - size_t separator_pos = input.find(separator); - while (separator_pos != std::string::npos) { - std::string part = input.substr(0, separator_pos); - parts.emplace_back(part); - input = input.substr(separator_pos + 1); - separator_pos = input.find(separator); - } - parts.emplace_back(input); - return parts; -} - std::string string_strip(const std::string & str) { size_t start = 0; size_t end = str.size(); @@ -1035,7 +1022,7 @@ static ggml_type kv_cache_type_from_str(const std::string & s) { return GGML_TYPE_Q5_1; } - throw std::runtime_error("Invalid cache type: " + s); + throw std::runtime_error("Unsupported cache type: " + s); } struct llama_context_params common_context_params_to_llama(const common_params & params) { @@ -1047,7 +1034,7 @@ struct llama_context_params common_context_params_to_llama(const common_params & cparams.n_ubatch = params.n_ubatch; cparams.n_threads = params.cpuparams.n_threads; cparams.n_threads_batch = params.cpuparams_batch.n_threads == -1 ? - params.cpuparams.n_threads : params.cpuparams_batch.n_threads; + params.cpuparams.n_threads : params.cpuparams_batch.n_threads; cparams.logits_all = params.logits_all; cparams.embeddings = params.embedding; cparams.rope_scaling_type = params.rope_scaling_type; @@ -2019,6 +2006,10 @@ void yaml_dump_non_result_info(FILE * stream, const common_params & params, cons fprintf(stream, "chunks: %d # default: -1 (unlimited)\n", params.n_chunks); fprintf(stream, "color: %s # default: false\n", params.use_color ? "true" : "false"); fprintf(stream, "ctx_size: %d # default: 512\n", params.n_ctx); + fprintf(stream, "dry_allowed_length: %d # default: 2\n", sparams.dry_allowed_length); + fprintf(stream, "dry_base: %.2f # default: 1.75\n", sparams.dry_base); + fprintf(stream, "dry_multiplier: %.1f # default: 0.0\n", sparams.dry_multiplier); + fprintf(stream, "dry_penalty_last_n: %d # default: -1 (0 = disable, -1 = context size)\n", sparams.dry_penalty_last_n); fprintf(stream, "escape: %s # default: false\n", params.escape ? "true" : "false"); fprintf(stream, "file: # never logged, see prompt instead. Can still be specified for input.\n"); fprintf(stream, "frequency_penalty: %f # default: 0.0 \n", sparams.penalty_freq); diff --git a/common/common.h b/common/common.h index 5ca8fd391ab74..18b2121ed89b0 100644 --- a/common/common.h +++ b/common/common.h @@ -84,14 +84,15 @@ enum llama_example { enum common_sampler_type { COMMON_SAMPLER_TYPE_NONE = 0, - COMMON_SAMPLER_TYPE_TOP_K = 1, - COMMON_SAMPLER_TYPE_TOP_P = 2, - COMMON_SAMPLER_TYPE_MIN_P = 3, - COMMON_SAMPLER_TYPE_TFS_Z = 4, - COMMON_SAMPLER_TYPE_TYPICAL_P = 5, - COMMON_SAMPLER_TYPE_TEMPERATURE = 6, - COMMON_SAMPLER_TYPE_XTC = 7, - COMMON_SAMPLER_TYPE_INFILL = 8, + COMMON_SAMPLER_TYPE_DRY = 1, + COMMON_SAMPLER_TYPE_TOP_K = 2, + COMMON_SAMPLER_TYPE_TOP_P = 3, + COMMON_SAMPLER_TYPE_MIN_P = 4, + COMMON_SAMPLER_TYPE_TFS_Z = 5, + COMMON_SAMPLER_TYPE_TYPICAL_P = 6, + COMMON_SAMPLER_TYPE_TEMPERATURE = 7, + COMMON_SAMPLER_TYPE_XTC = 8, + COMMON_SAMPLER_TYPE_INFILL = 9, }; // dimensionality reduction methods, used by cvector-generator @@ -104,32 +105,39 @@ enum dimre_method { struct common_sampler_params { uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampler - int32_t n_prev = 64; // number of previous tokens to remember - int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens. - int32_t min_keep = 0; // 0 = disabled, otherwise samplers should return at least min_keep tokens - int32_t top_k = 40; // <= 0 to use vocab size - float top_p = 0.95f; // 1.0 = disabled - float min_p = 0.05f; // 0.0 = disabled - float xtc_probability = 0.00f; // 0.0 = disabled - float xtc_threshold = 0.10f; // > 0.5 disables XTC - float tfs_z = 1.00f; // 1.0 = disabled - float typ_p = 1.00f; // typical_p, 1.0 = disabled - float temp = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities - float dynatemp_range = 0.00f; // 0.0 = disabled - float dynatemp_exponent = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler - int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size) - float penalty_repeat = 1.00f; // 1.0 = disabled - float penalty_freq = 0.00f; // 0.0 = disabled - float penalty_present = 0.00f; // 0.0 = disabled - int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0 - float mirostat_tau = 5.00f; // target entropy - float mirostat_eta = 0.10f; // learning rate - bool penalize_nl = false; // consider newlines as a repeatable token - bool ignore_eos = false; - bool no_perf = false; // disable performance metrics + int32_t n_prev = 64; // number of previous tokens to remember + int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens. + int32_t min_keep = 0; // 0 = disabled, otherwise samplers should return at least min_keep tokens + int32_t top_k = 40; // <= 0 to use vocab size + float top_p = 0.95f; // 1.0 = disabled + float min_p = 0.05f; // 0.0 = disabled + float xtc_probability = 0.00f; // 0.0 = disabled + float xtc_threshold = 0.10f; // > 0.5 disables XTC + float tfs_z = 1.00f; // 1.0 = disabled + float typ_p = 1.00f; // typical_p, 1.0 = disabled + float temp = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities + float dynatemp_range = 0.00f; // 0.0 = disabled + float dynatemp_exponent = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler + int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size) + float penalty_repeat = 1.00f; // 1.0 = disabled + float penalty_freq = 0.00f; // 0.0 = disabled + float penalty_present = 0.00f; // 0.0 = disabled + float dry_multiplier = 0.0f; // 0.0 = disabled; DRY repetition penalty for tokens extending repetition: + float dry_base = 1.75f; // 0.0 = disabled; multiplier * base ^ (length of sequence before token - allowed length) + int32_t dry_allowed_length = 2; // tokens extending repetitions beyond this receive penalty + int32_t dry_penalty_last_n = -1; // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size) + int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0 + float mirostat_tau = 5.00f; // target entropy + float mirostat_eta = 0.10f; // learning rate + bool penalize_nl = false; // consider newlines as a repeatable token + bool ignore_eos = false; + bool no_perf = false; // disable performance metrics + + std::vector dry_sequence_breakers = {"\n", ":", "\"", "*"}; // default sequence breakers for DRY std::vector samplers = { + COMMON_SAMPLER_TYPE_DRY, COMMON_SAMPLER_TYPE_TOP_K, COMMON_SAMPLER_TYPE_TFS_Z, COMMON_SAMPLER_TYPE_TYPICAL_P, @@ -274,9 +282,9 @@ struct common_params { // embedding bool embedding = false; // get only sentence embedding - int32_t embd_normalize = 2; // normalisation for embendings (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm) + int32_t embd_normalize = 2; // normalisation for embeddings (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm) std::string embd_out = ""; // empty = default, "array" = [[],[]...], "json" = openai style, "json+" = same "json" + cosine similarity matrix - std::string embd_sep = "\n"; // separator of embendings + std::string embd_sep = "\n"; // separator of embeddings bool reranking = false; // enable reranking support on server // server params @@ -380,8 +388,6 @@ bool set_process_priority(enum ggml_sched_priority prio); LLAMA_COMMON_ATTRIBUTE_FORMAT(1, 2) std::string string_format(const char * fmt, ...); -std::vector string_split(std::string input, char separator); - std::string string_strip(const std::string & str); std::string string_get_sortable_timestamp(); @@ -389,6 +395,7 @@ void string_replace_all(std::string & s, const std::string & search, const std:: template static std::vector string_split(const std::string & str, char delim) { + static_assert(!std::is_same::value, "Please use the specialized version for std::string"); std::vector values; std::istringstream str_stream(str); std::string token; @@ -401,6 +408,22 @@ static std::vector string_split(const std::string & str, char delim) { return values; } +template<> +std::vector string_split(const std::string & input, char separator) +{ + std::vector parts; + size_t begin_pos = 0; + size_t separator_pos = input.find(separator); + while (separator_pos != std::string::npos) { + std::string part = input.substr(begin_pos, separator_pos - begin_pos); + parts.emplace_back(part); + begin_pos = separator_pos + 1; + separator_pos = input.find(separator, begin_pos); + } + parts.emplace_back(input.substr(begin_pos, separator_pos - begin_pos)); + return parts; +} + bool string_parse_kv_override(const char * data, std::vector & overrides); void string_process_escapes(std::string & input); diff --git a/common/sampling.cpp b/common/sampling.cpp index 56cd0df6b81bc..48a9df8ba5b88 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -130,9 +130,11 @@ std::string common_sampler_params::print() const { snprintf(result, sizeof(result), "\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n" + "\tdry_multiplier = %.3f, dry_base = %.3f, dry_allowed_length = %d, dry_penalty_last_n = %d\n" "\ttop_k = %d, tfs_z = %.3f, top_p = %.3f, min_p = %.3f, xtc_probability = %.3f, xtc_threshold = %.3f, typical_p = %.3f, temp = %.3f\n" "\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f", penalty_last_n, penalty_repeat, penalty_freq, penalty_present, + dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n, top_k, tfs_z, top_p, min_p, xtc_probability, xtc_threshold, typ_p, temp, mirostat, mirostat_eta, mirostat_tau); @@ -171,60 +173,57 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co params.penalize_nl, params.ignore_eos)); - if (params.temp > 0.0f) { - if (params.mirostat == 0) { - for (const auto & cnstr : params.samplers) { - switch (cnstr) { - case COMMON_SAMPLER_TYPE_TOP_K: - llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k)); + if (params.mirostat == 0) { + for (const auto & cnstr : params.samplers) { + switch (cnstr) { + case COMMON_SAMPLER_TYPE_DRY: + { + std::vector c_breakers; + c_breakers.reserve(params.dry_sequence_breakers.size()); + for (const auto& str : params.dry_sequence_breakers) { + c_breakers.push_back(str.c_str()); + } + + llama_sampler_chain_add(result->chain, llama_sampler_init_dry (model, params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size())); + } break; - case COMMON_SAMPLER_TYPE_TOP_P: - llama_sampler_chain_add(result->chain, llama_sampler_init_top_p (params.top_p, params.min_keep)); - break; - case COMMON_SAMPLER_TYPE_MIN_P: - llama_sampler_chain_add(result->chain, llama_sampler_init_min_p (params.min_p, params.min_keep)); - break; - case COMMON_SAMPLER_TYPE_XTC: - llama_sampler_chain_add(result->chain, llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed)); - break; - case COMMON_SAMPLER_TYPE_TFS_Z: - llama_sampler_chain_add(result->chain, llama_sampler_init_tail_free(params.tfs_z, params.min_keep)); - break; - case COMMON_SAMPLER_TYPE_TYPICAL_P: - llama_sampler_chain_add(result->chain, llama_sampler_init_typical (params.typ_p, params.min_keep)); - break; - case COMMON_SAMPLER_TYPE_TEMPERATURE: - llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent)); - break; - case COMMON_SAMPLER_TYPE_INFILL: - llama_sampler_chain_add(result->chain, llama_sampler_init_infill (model)); - break; - default: - GGML_ASSERT(false && "unknown sampler type"); - } + case COMMON_SAMPLER_TYPE_TOP_K: + llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k)); + break; + case COMMON_SAMPLER_TYPE_TOP_P: + llama_sampler_chain_add(result->chain, llama_sampler_init_top_p (params.top_p, params.min_keep)); + break; + case COMMON_SAMPLER_TYPE_MIN_P: + llama_sampler_chain_add(result->chain, llama_sampler_init_min_p (params.min_p, params.min_keep)); + break; + case COMMON_SAMPLER_TYPE_XTC: + llama_sampler_chain_add(result->chain, llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed)); + break; + case COMMON_SAMPLER_TYPE_TFS_Z: + llama_sampler_chain_add(result->chain, llama_sampler_init_tail_free(params.tfs_z, params.min_keep)); + break; + case COMMON_SAMPLER_TYPE_TYPICAL_P: + llama_sampler_chain_add(result->chain, llama_sampler_init_typical (params.typ_p, params.min_keep)); + break; + case COMMON_SAMPLER_TYPE_TEMPERATURE: + llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent)); + break; + case COMMON_SAMPLER_TYPE_INFILL: + llama_sampler_chain_add(result->chain, llama_sampler_init_infill (model)); + break; + default: + GGML_ASSERT(false && "unknown sampler type"); } - llama_sampler_chain_add(result->chain, llama_sampler_init_softmax()); - llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed)); - } else if (params.mirostat == 1) { - llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp)); - llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat(llama_n_vocab(model), params.seed, params.mirostat_tau, params.mirostat_eta, 100)); - } else if (params.mirostat == 2) { - llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp)); - llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat_v2(params.seed, params.mirostat_tau, params.mirostat_eta)); - } else { - GGML_ASSERT(false && "unknown mirostat version"); } + llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed)); + } else if (params.mirostat == 1) { + llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp)); + llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat(llama_n_vocab(model), params.seed, params.mirostat_tau, params.mirostat_eta, 100)); + } else if (params.mirostat == 2) { + llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp)); + llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat_v2(params.seed, params.mirostat_tau, params.mirostat_eta)); } else { - if (params.n_probs > 0) { - // some use cases require to sample greedily, but still obtain the probabilities of the top tokens - // ref: https://github.com/ggerganov/llama.cpp/pull/9605 - // - // the following will not produce exactly the same probs as applyging softmax to the full vocabulary, but - // it is much faster, since we avoid sorting all tokens and should give a good approximation - llama_sampler_chain_add(result->chain, llama_sampler_init_top_k(params.n_probs)); - llama_sampler_chain_add(result->chain, llama_sampler_init_softmax()); - } - llama_sampler_chain_add(result->chain, llama_sampler_init_greedy()); + GGML_ASSERT(false && "unknown mirostat version"); } return result; @@ -372,6 +371,7 @@ std::string common_sampler_prev_str(common_sampler * gsmpl, llama_context * ctx_ char common_sampler_type_to_chr(enum common_sampler_type cnstr) { switch (cnstr) { + case COMMON_SAMPLER_TYPE_DRY: return 'd'; case COMMON_SAMPLER_TYPE_TOP_K: return 'k'; case COMMON_SAMPLER_TYPE_TFS_Z: return 'f'; case COMMON_SAMPLER_TYPE_TYPICAL_P: return 'y'; @@ -386,6 +386,7 @@ char common_sampler_type_to_chr(enum common_sampler_type cnstr) { std::string common_sampler_type_to_str(enum common_sampler_type cnstr) { switch (cnstr) { + case COMMON_SAMPLER_TYPE_DRY: return "dry"; case COMMON_SAMPLER_TYPE_TOP_K: return "top_k"; case COMMON_SAMPLER_TYPE_TFS_Z: return "tfs_z"; case COMMON_SAMPLER_TYPE_TYPICAL_P: return "typ_p"; @@ -400,6 +401,7 @@ std::string common_sampler_type_to_str(enum common_sampler_type cnstr) { std::vector common_sampler_types_from_names(const std::vector & names, bool allow_alt_names) { std::unordered_map sampler_canonical_name_map { + { "dry", COMMON_SAMPLER_TYPE_DRY }, { "top_k", COMMON_SAMPLER_TYPE_TOP_K }, { "top_p", COMMON_SAMPLER_TYPE_TOP_P }, { "typ_p", COMMON_SAMPLER_TYPE_TYPICAL_P }, @@ -448,6 +450,7 @@ std::vector common_sampler_types_from_names(const std::vect std::vector common_sampler_types_from_chars(const std::string & chars) { std::unordered_map sampler_name_map = { + { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_DRY), COMMON_SAMPLER_TYPE_DRY }, { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_K), COMMON_SAMPLER_TYPE_TOP_K }, { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TFS_Z), COMMON_SAMPLER_TYPE_TFS_Z }, { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TYPICAL_P), COMMON_SAMPLER_TYPE_TYPICAL_P }, diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index da5feb25b1961..a34dabe235a34 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -573,6 +573,9 @@ def get_vocab_base_pre(self, tokenizer) -> str: if chkhsh == "0876d13b50744004aa9aeae05e7b0647eac9d801b5ba4668afc01e709c15e19f": # ref: https://huggingface.co/BAAI/bge-small-en-v1.5 res = "bert-bge" + if chkhsh == "8e62295832751ca1e8f92f2226f403dea30dc5165e448b5bfa05af5340c64ec7": + # ref: https://huggingface.co/BAAI/bge-large-zh-v1.5 + res = "bert-bge-large" if chkhsh == "b6dc8df998e1cfbdc4eac8243701a65afe638679230920b50d6f17d81c098166": # ref: https://huggingface.co/mosaicml/mpt-7b res = "mpt" @@ -2864,6 +2867,9 @@ def set_vocab(self): self.gguf_writer.add_token_list(tokens) self.gguf_writer.add_token_types(toktypes) special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False) + special_vocab.chat_template = "rwkv-world" + # hack: Add '\n\n' as the EOT token to make it chat normally + special_vocab._set_special_token("eot", 261) special_vocab.add_to_gguf(self.gguf_writer) def set_gguf_parameters(self): diff --git a/convert_hf_to_gguf_update.py b/convert_hf_to_gguf_update.py index 022354a3b624e..28cd02e5a7f66 100755 --- a/convert_hf_to_gguf_update.py +++ b/convert_hf_to_gguf_update.py @@ -72,6 +72,7 @@ class TOKENIZER_TYPE(IntEnum): {"name": "deepseek-coder", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/deepseek-ai/deepseek-coder-6.7b-base", }, {"name": "falcon", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tiiuae/falcon-7b", }, {"name": "bert-bge", "tokt": TOKENIZER_TYPE.WPM, "repo": "https://huggingface.co/BAAI/bge-small-en-v1.5", }, + {"name": "bert-bge-large", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/BAAI/bge-large-zh-v1.5", }, {"name": "mpt", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/mosaicml/mpt-7b", }, {"name": "starcoder", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/bigcode/starcoder2-3b", }, {"name": "gpt-2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/openai-community/gpt2", }, diff --git a/convert_lora_to_gguf.py b/convert_lora_to_gguf.py index 439a78de108ca..bc68f68afb768 100755 --- a/convert_lora_to_gguf.py +++ b/convert_lora_to_gguf.py @@ -348,6 +348,9 @@ def get_tensors(self) -> Iterator[tuple[str, Tensor]]: if ".base_layer.weight" in name: continue logger.error(f"Unexpected name '{name}': Not a lora_A or lora_B tensor") + if ".embed_tokens.weight" in name or ".lm_head.weight" in name: + logger.error("Embeddings is present in the adapter. This can be due to new tokens added during fine tuning") + logger.error("Hint: if you are using TRL, make sure not to call setup_chat_format()") sys.exit(1) if base_name in tensor_map: diff --git a/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift b/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift index dcd9803a2adc2..65cd4eb515c7f 100644 --- a/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift +++ b/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift @@ -46,7 +46,6 @@ actor LlamaContext { let sparams = llama_sampler_chain_default_params() self.sampling = llama_sampler_chain_init(sparams) llama_sampler_chain_add(self.sampling, llama_sampler_init_temp(0.4)) - llama_sampler_chain_add(self.sampling, llama_sampler_init_softmax()) llama_sampler_chain_add(self.sampling, llama_sampler_init_dist(1234)) } diff --git a/examples/llama.vim b/examples/llama.vim new file mode 100644 index 0000000000000..57eb2a9772d51 --- /dev/null +++ b/examples/llama.vim @@ -0,0 +1,783 @@ +" LLM-based text completion using llama.cpp +" +" requires: +" +" - neovim or vim +" - curl +" - llama.cpp server instance +" - FIM-compatible model +" +" sample config: +" +" - Tab - accept the current suggestion +" - Shift+Tab - accept just the first line of the suggestion +" - Ctrl+F - toggle FIM completion manually +" +" make symlink or copy this file to ~/.config/nvim/autoload/llama.vim +" +" start the llama.cpp server with a FIM-compatible model. for example: +" +" $ llama-server -m {model.gguf} --port 8012 -ngl 99 -fa -dt 0.1 --ubatch-size 512 --batch-size 1024 --cache-reuse 256 +" +" --batch-size [512, model max context] +" +" adjust the batch size to control how much of the provided local context will be used during the inference +" lower values will use smaller part of the context around the cursor, which will result in faster processing +" +" --ubatch-size [64, 2048] +" +" chunks the batch into smaller chunks for faster processing +" depends on the specific hardware. use llama-bench to profile and determine the best size +" +" --cache-reuse (ge:llama_config.n_predict, 1024] +" +" this should be either 0 (disabled) or strictly larger than g:llama_config.n_predict +" using non-zero value enables context reuse on the server side which dramatically improves the performance at +" large contexts. a value of 256 should be good for all cases +" +" run this once to initialise llama.vim: +" +" :call llama#init() +" +" more info: https://github.com/ggerganov/llama.cpp/pull/9787 +" + +" colors (adjust to your liking) +highlight llama_hl_hint guifg=#ff772f ctermfg=202 +highlight llama_hl_info guifg=#77ff2f ctermfg=119 + +" general parameters: +" +" endpoint: llama.cpp server endpoint +" n_prefix: number of lines before the cursor location to include in the local prefix +" n_suffix: number of lines after the cursor location to include in the local suffix +" n_predict: max number of tokens to predict +" t_max_prompt_ms: max alloted time for the prompt processing (TODO: not yet supported) +" t_max_predict_ms: max alloted time for the prediction +" show_info: show extra info about the inference (0 - disabled, 1 - statusline, 2 - inline) +" auto_fim: trigger FIM completion automatically on cursor movement +" max_line_suffix: do not auto-trigger FIM completion if there are more than this number of characters to the right of the cursor +" +" ring buffer of chunks, accumulated with time upon: +" +" - completion request +" - yank +" - entering a buffer +" - leaving a buffer +" - writing a file +" +" parameters for the ring-buffer with extra context: +" +" ring_n_chunks: max number of chunks to pass as extra context to the server (0 to disable) +" ring_chunk_size: max size of the chunks (in number of lines) +" note: adjust these numbers so that you don't overrun your context +" at ring_n_chunks = 64 and ring_chunk_size = 64 you need ~32k context +" ring_scope: the range around the cursor position (in number of lines) for gathering chunks after FIM +" ring_update_ms: how often to process queued chunks in normal mode +" +let s:default_config = { + \ 'endpoint': 'http://127.0.0.1:8012/infill', + \ 'n_prefix': 256, + \ 'n_suffix': 64, + \ 'n_predict': 128, + \ 't_max_prompt_ms': 500, + \ 't_max_predict_ms': 3000, + \ 'show_info': 2, + \ 'auto_fim': v:true, + \ 'max_line_suffix': 8, + \ 'ring_n_chunks': 64, + \ 'ring_chunk_size': 64, + \ 'ring_scope': 1024, + \ 'ring_update_ms': 1000, + \ } + +let g:llama_config = get(g:, 'llama_config', s:default_config) + +function! s:get_indent(str) + let l:count = 0 + for i in range(len(a:str)) + if a:str[i] == "\t" + let l:count += &tabstop - 1 + else + break + endif + endfor + return l:count +endfunction + +function! s:rand(i0, i1) abort + return a:i0 + rand() % (a:i1 - a:i0 + 1) +endfunction + +function! llama#init() + if !executable('curl') + echohl WarningMsg + echo 'llama.vim requires the "curl" command to be available' + echohl None + return + endif + + let s:pos_x = 0 " cursor position upon start of completion + let s:pos_y = 0 + + let s:line_cur = '' + + let s:line_cur_prefix = '' + let s:line_cur_suffix = '' + + let s:ring_chunks = [] " current set of chunks used as extra context + let s:ring_queued = [] " chunks that are queued to be sent for processing + let s:ring_n_evict = 0 + + let s:hint_shown = v:false + let s:pos_y_pick = -9999 " last y where we picked a chunk + let s:pos_dx = 0 + let s:content = [] + let s:can_accept = v:false + + let s:timer_fim = -1 + let s:t_fim_start = reltime() " used to measure total FIM time + let s:t_last_move = reltime() " last time the cursor moved + + let s:current_job = v:null + + let s:ghost_text_nvim = exists('*nvim_buf_get_mark') + let s:ghost_text_vim = has('textprop') + + if s:ghost_text_vim + let s:hlgroup_hint = 'llama_hl_hint' + let s:hlgroup_info = 'llama_hl_info' + + if empty(prop_type_get(s:hlgroup_hint)) + call prop_type_add(s:hlgroup_hint, {'highlight': s:hlgroup_hint}) + endif + if empty(prop_type_get(s:hlgroup_info)) + call prop_type_add(s:hlgroup_info, {'highlight': s:hlgroup_info}) + endif + endif + + augroup llama + autocmd! + autocmd InsertEnter * inoremap llama#fim_inline(v:false) + autocmd InsertLeavePre * call llama#fim_cancel() + + autocmd CursorMoved * call s:on_move() + autocmd CursorMovedI * call s:on_move() + autocmd CompleteChanged * call llama#fim_cancel() + + if g:llama_config.auto_fim + autocmd CursorMovedI * call llama#fim(v:true) + endif + + " gather chunks upon yanking + autocmd TextYankPost * if v:event.operator ==# 'y' | call s:pick_chunk(v:event.regcontents, v:false, v:true) | endif + + " gather chunks upon entering/leaving a buffer + autocmd BufEnter * call timer_start(100, {-> s:pick_chunk(getline(max([1, line('.') - g:llama_config.ring_chunk_size/2]), min([line('.') + g:llama_config.ring_chunk_size/2, line('$')])), v:true, v:true)}) + autocmd BufLeave * call s:pick_chunk(getline(max([1, line('.') - g:llama_config.ring_chunk_size/2]), min([line('.') + g:llama_config.ring_chunk_size/2, line('$')])), v:true, v:true) + + " gather chunk upon saving the file + autocmd BufWritePost * call s:pick_chunk(getline(max([1, line('.') - g:llama_config.ring_chunk_size/2]), min([line('.') + g:llama_config.ring_chunk_size/2, line('$')])), v:true, v:true) + augroup END + + silent! call llama#fim_cancel() + + " init background update of the ring buffer + if g:llama_config.ring_n_chunks > 0 + call s:ring_update() + endif +endfunction + +" compute how similar two chunks of text are +" 0 - no similarity, 1 - high similarity +" TODO: figure out something better +function! s:chunk_sim(c0, c1) + let l:lines0 = len(a:c0) + let l:lines1 = len(a:c1) + + let l:common = 0 + + for l:line0 in a:c0 + for l:line1 in a:c1 + if l:line0 == l:line1 + let l:common += 1 + break + endif + endfor + endfor + + return 2.0 * l:common / (l:lines0 + l:lines1) +endfunction + +" pick a random chunk of size g:llama_config.ring_chunk_size from the provided text and queue it for processing +" +" no_mod - do not pick chunks from buffers with pending changes +" do_evict - evict chunks that are very similar to the new one +" +function! s:pick_chunk(text, no_mod, do_evict) + " do not pick chunks from buffers with pending changes or buffers that are not files + if a:no_mod && (getbufvar(bufnr('%'), '&modified') || !buflisted(bufnr('%')) || !filereadable(expand('%'))) + return + endif + + " if the extra context option is disabled - do nothing + if g:llama_config.ring_n_chunks <= 0 + return + endif + + " don't pick very small chunks + if len(a:text) < 3 + return + endif + + if len(a:text) + 1 < g:llama_config.ring_chunk_size + let l:chunk = a:text + else + let l:l0 = s:rand(0, max([0, len(a:text) - g:llama_config.ring_chunk_size/2])) + let l:l1 = min([l:l0 + g:llama_config.ring_chunk_size/2, len(a:text)]) + + let l:chunk = a:text[l:l0:l:l1] + endif + + let l:chunk_str = join(l:chunk, "\n") . "\n" + + " check if this chunk is already added + let l:exist = v:false + + for i in range(len(s:ring_chunks)) + if s:ring_chunks[i].data == l:chunk + let l:exist = v:true + break + endif + endfor + + for i in range(len(s:ring_queued)) + if s:ring_queued[i].data == l:chunk + let l:exist = v:true + break + endif + endfor + + if l:exist + return + endif + + " evict queued chunks that are very similar to the new one + for i in range(len(s:ring_queued) - 1, 0, -1) + if s:chunk_sim(s:ring_queued[i].data, l:chunk) > 0.9 + if a:do_evict + call remove(s:ring_queued, i) + let s:ring_n_evict += 1 + else + return + endif + endif + endfor + + " also from s:ring_chunks + for i in range(len(s:ring_chunks) - 1, 0, -1) + if s:chunk_sim(s:ring_chunks[i].data, l:chunk) > 0.9 + if a:do_evict + call remove(s:ring_chunks, i) + let s:ring_n_evict += 1 + else + return + endif + endif + endfor + + " TODO: become parameter ? + if len(s:ring_queued) == 16 + call remove(s:ring_queued, 0) + endif + + call add(s:ring_queued, {'data': l:chunk, 'str': l:chunk_str, 'time': reltime(), 'filename': expand('%')}) + + "let &statusline = 'extra context: ' . len(s:ring_chunks) . ' / ' . len(s:ring_queued) +endfunction + +" picks a queued chunk, sends it for processing and adds it to s:ring_chunks +" called every g:llama_config.ring_update_ms +function! s:ring_update() + call timer_start(g:llama_config.ring_update_ms, {-> s:ring_update()}) + + " update only if in normal mode or if the cursor hasn't moved for a while + if mode() !=# 'n' && reltimefloat(reltime(s:t_last_move)) < 3.0 + return + endif + + if len(s:ring_queued) == 0 + return + endif + + " move the first queued chunk to the ring buffer + if len(s:ring_chunks) == g:llama_config.ring_n_chunks + call remove(s:ring_chunks, 0) + endif + + call add(s:ring_chunks, remove(s:ring_queued, 0)) + + "let &statusline = 'updated context: ' . len(s:ring_chunks) . ' / ' . len(s:ring_queued) + + " send asynchronous job with the new extra context so that it is ready for the next FIM + let l:extra_context = [] + for l:chunk in s:ring_chunks + call add(l:extra_context, { + \ 'text': l:chunk.str, + \ 'time': l:chunk.time, + \ 'filename': l:chunk.filename + \ }) + endfor + + " no samplers needed here + let l:request = json_encode({ + \ 'input_prefix': "", + \ 'input_suffix': "", + \ 'input_extra': l:extra_context, + \ 'prompt': "", + \ 'n_predict': 1, + \ 'temperature': 0.0, + \ 'stream': v:false, + \ 'samplers': ["temperature"], + \ 'cache_prompt': v:true, + \ 't_max_prompt_ms': 1, + \ 't_max_predict_ms': 1 + \ }) + + let l:curl_command = [ + \ "curl", + \ "--silent", + \ "--no-buffer", + \ "--request", "POST", + \ "--url", g:llama_config.endpoint, + \ "--header", "Content-Type: application/json", + \ "--data", l:request + \ ] + + " no callbacks because we don't need to process the response + if s:ghost_text_nvim + call jobstart(l:curl_command, {}) + elseif s:ghost_text_vim + call job_start(l:curl_command, {}) + endif +endfunction + +" necessary for 'inoremap ' +function! llama#fim_inline(is_auto) abort + call llama#fim(a:is_auto) + return '' +endfunction + +" the main FIM call +" takes local context around the cursor and sends it together with the extra context to the server for completion +function! llama#fim(is_auto) abort + " we already have a suggestion for the current cursor position + if s:hint_shown && !a:is_auto + call llama#fim_cancel() + return + endif + + call llama#fim_cancel() + + " avoid sending repeated requests too fast + if reltimefloat(reltime(s:t_fim_start)) < 0.6 + if s:timer_fim != -1 + call timer_stop(s:timer_fim) + let s:timer_fim = -1 + endif + + let s:t_fim_start = reltime() + let s:timer_fim = timer_start(600, {-> llama#fim(v:true)}) + return + endif + + let s:t_fim_start = reltime() + + let s:content = [] + let s:can_accept = v:false + + let s:pos_x = col('.') - 1 + let s:pos_y = line('.') + let l:max_y = line('$') + + let l:lines_prefix = getline(max([1, s:pos_y - g:llama_config.n_prefix]), s:pos_y - 1) + let l:lines_suffix = getline(s:pos_y + 1, min([l:max_y, s:pos_y + g:llama_config.n_suffix])) + + let s:line_cur = getline('.') + + let s:line_cur_prefix = strpart(s:line_cur, 0, s:pos_x) + let s:line_cur_suffix = strpart(s:line_cur, s:pos_x) + + if a:is_auto && len(s:line_cur_suffix) > g:llama_config.max_line_suffix + return + endif + + let l:prefix = "" + \ . join(l:lines_prefix, "\n") + \ . "\n" + + let l:prompt = "" + \ . s:line_cur_prefix + + let l:suffix = "" + \ . s:line_cur_suffix + \ . "\n" + \ . join(l:lines_suffix, "\n") + \ . "\n" + + " prepare the extra context data + let l:extra_context = [] + for l:chunk in s:ring_chunks + call add(l:extra_context, { + \ 'text': l:chunk.str, + \ 'time': l:chunk.time, + \ 'filename': l:chunk.filename + \ }) + endfor + + " the indentation of the current line + let l:indent = strlen(matchstr(s:line_cur_prefix, '^\s*')) + + let l:request = json_encode({ + \ 'input_prefix': l:prefix, + \ 'input_suffix': l:suffix, + \ 'input_extra': l:extra_context, + \ 'prompt': l:prompt, + \ 'n_predict': g:llama_config.n_predict, + \ 'n_indent': l:indent, + \ 'top_k': 40, + \ 'top_p': 0.99, + \ 'stream': v:false, + \ 'samplers': ["top_k", "top_p", "infill"], + \ 'cache_prompt': v:true, + \ 't_max_prompt_ms': g:llama_config.t_max_prompt_ms, + \ 't_max_predict_ms': g:llama_config.t_max_predict_ms + \ }) + + let l:curl_command = [ + \ "curl", + \ "--silent", + \ "--no-buffer", + \ "--request", "POST", + \ "--url", g:llama_config.endpoint, + \ "--header", "Content-Type: application/json", + \ "--data", l:request + \ ] + + if s:current_job != v:null + if s:ghost_text_nvim + call jobstop(s:current_job) + elseif s:ghost_text_vim + call job_stop(s:current_job) + endif + endif + + " send the request asynchronously + if s:ghost_text_nvim + let s:current_job = jobstart(l:curl_command, { + \ 'on_stdout': function('s:fim_on_stdout', [s:pos_x, s:pos_y, a:is_auto]), + \ 'on_exit': function('s:fim_on_exit'), + \ 'stdout_buffered': v:true + \ }) + elseif s:ghost_text_vim + let s:current_job = job_start(l:curl_command, { + \ 'out_cb': function('s:fim_on_stdout', [s:pos_x, s:pos_y, a:is_auto]), + \ 'exit_cb': function('s:fim_on_exit') + \ }) + endif + + " TODO: per-file location + let l:delta_y = abs(s:pos_y - s:pos_y_pick) + + " gather some extra context nearby and process it in the background + " only gather chunks if the cursor has moved a lot + " TODO: something more clever? reranking? + if a:is_auto && l:delta_y > 32 + " expand the prefix even further + call s:pick_chunk(getline(max([1, s:pos_y - g:llama_config.ring_scope]), max([1, s:pos_y - g:llama_config.n_prefix])), v:false, v:false) + + " pick a suffix chunk + call s:pick_chunk(getline(min([l:max_y, s:pos_y + g:llama_config.n_suffix]), min([l:max_y, s:pos_y + g:llama_config.n_suffix + g:llama_config.ring_chunk_size])), v:false, v:false) + + let s:pos_y_pick = s:pos_y + endif +endfunction + +" if first_line == v:true accept only the first line of the response +function! llama#fim_accept(first_line) + " insert the suggestion at the cursor location + if s:can_accept && len(s:content) > 0 + call setline(s:pos_y, s:line_cur[:(s:pos_x - 1)] . s:content[0]) + if len(s:content) > 1 + if !a:first_line + call append(s:pos_y, s:content[1:-1]) + endif + endif + + " move the cursor to the end of the accepted text + if !a:first_line && len(s:content) > 1 + call cursor(s:pos_y + len(s:content) - 1, s:pos_x + s:pos_dx + 1) + else + call cursor(s:pos_y, s:pos_x + len(s:content[0])) + endif + endif + + call llama#fim_cancel() +endfunction + +function! llama#fim_cancel() + let s:hint_shown = v:false + + " clear the virtual text + let l:bufnr = bufnr('%') + + if s:ghost_text_nvim + let l:id_vt_fim = nvim_create_namespace('vt_fim') + call nvim_buf_clear_namespace(l:bufnr, l:id_vt_fim, 0, -1) + elseif s:ghost_text_vim + call prop_remove({'type': s:hlgroup_hint, 'all': v:true}) + call prop_remove({'type': s:hlgroup_info, 'all': v:true}) + endif + + " remove the mappings + silent! iunmap + silent! iunmap + silent! iunmap +endfunction + +function! s:on_move() + let s:t_last_move = reltime() + + call llama#fim_cancel() +endfunction + +" callback that processes the FIM result from the server and displays the suggestion +function! s:fim_on_stdout(pos_x, pos_y, is_auto, job_id, data, event = v:null) + if s:ghost_text_nvim + let l:raw = join(a:data, "\n") + elseif s:ghost_text_vim + let l:raw = a:data + endif + + if len(l:raw) == 0 + return + endif + + if a:pos_x != col('.') - 1 || a:pos_y != line('.') + return + endif + + " show the suggestion only in insert mode + if mode() !=# 'i' + return + endif + + let s:pos_x = a:pos_x + let s:pos_y = a:pos_y + + let s:can_accept = v:true + let l:has_info = v:false + + if s:can_accept && v:shell_error + if !a:is_auto + call add(s:content, "<| curl error: is the server on? |>") + endif + let s:can_accept = v:false + endif + + let l:n_prompt = 0 + let l:t_prompt_ms = 1.0 + let l:s_prompt = 0 + + let l:n_predict = 0 + let l:t_predict_ms = 1.0 + let l:s_predict = 0 + + " get the generated suggestion + if s:can_accept + let l:response = json_decode(l:raw) + + for l:part in split(get(l:response, 'content', ''), "\n", 1) + call add(s:content, l:part) + endfor + + " remove trailing new lines + while len(s:content) > 0 && s:content[-1] == "" + call remove(s:content, -1) + endwhile + + let l:generation_settings = get(l:response, 'generation_settings', {}) + let l:n_ctx = get(l:generation_settings, 'n_ctx', 0) + + let l:n_cached = get(l:response, 'tokens_cached', 0) + let l:truncated = get(l:response, 'truncated', v:false) + + " if response.timings is available + if len(get(l:response, 'timings', {})) > 0 + let l:has_info = v:true + let l:timings = get(l:response, 'timings', {}) + + let l:n_prompt = get(l:timings, 'prompt_n', 0) + let l:t_prompt_ms = get(l:timings, 'prompt_ms', 1) + let l:s_prompt = get(l:timings, 'prompt_per_second', 0) + + let l:n_predict = get(l:timings, 'predicted_n', 0) + let l:t_predict_ms = get(l:timings, 'predicted_ms', 1) + let l:s_predict = get(l:timings, 'predicted_per_second', 0) + endif + endif + + if len(s:content) == 0 + call add(s:content, "") + let s:can_accept = v:false + endif + + if len(s:content) == 0 + return + endif + + " NOTE: the following is logic for discarding predictions that repeat existing text + " the code is quite ugly and there is very likely a simpler and more canonical way to implement this + " + " still, I wonder if there is some better way that avoids having to do these special hacks? + " on one hand, the LLM 'sees' the contents of the file before we start editing, so it is normal that it would + " start generating whatever we have given it via the extra context. but on the other hand, it's not very + " helpful to re-generate the same code that is already there + + " truncate the suggestion if the first line is empty + if len(s:content) == 1 && s:content[0] == "" + let s:content = [""] + endif + + " ... and the next lines are repeated + if len(s:content) > 1 && s:content[0] == "" && s:content[1:] == getline(s:pos_y + 1, s:pos_y + len(s:content) - 1) + let s:content = [""] + endif + + " truncate the suggestion if it repeats the suffix + if len(s:content) == 1 && s:content[0] == s:line_cur_suffix + let s:content = [""] + endif + + " find the first non-empty line (strip whitespace) + let l:cmp_y = s:pos_y + 1 + while l:cmp_y < line('$') && getline(l:cmp_y) =~? '^\s*$' + let l:cmp_y += 1 + endwhile + + if (s:line_cur_prefix . s:content[0]) == getline(l:cmp_y) + " truncate the suggestion if it repeats the next line + if len(s:content) == 1 + let s:content = [""] + endif + + " ... or if the second line of the suggestion is the prefix of line l:cmp_y + 1 + if len(s:content) == 2 && s:content[-1] == getline(l:cmp_y + 1)[:len(s:content[-1]) - 1] + let s:content = [""] + endif + + " ... or if the middle chunk of lines of the suggestion is the same as [l:cmp_y + 1, l:cmp_y + len(s:content) - 1) + if len(s:content) > 2 && join(s:content[1:-1], "\n") == join(getline(l:cmp_y + 1, l:cmp_y + len(s:content) - 1), "\n") + let s:content = [""] + endif + endif + + " keep only lines that have the same or larger whitespace prefix as s:line_cur_prefix + "let l:indent = strlen(matchstr(s:line_cur_prefix, '^\s*')) + "for i in range(1, len(s:content) - 1) + " if strlen(matchstr(s:content[i], '^\s*')) < l:indent + " let s:content = s:content[:i - 1] + " break + " endif + "endfor + + let s:pos_dx = len(s:content[-1]) + + let s:content[-1] .= s:line_cur_suffix + + call llama#fim_cancel() + + " display virtual text with the suggestion + let l:bufnr = bufnr('%') + + if s:ghost_text_nvim + let l:id_vt_fim = nvim_create_namespace('vt_fim') + endif + + " construct the info message + if g:llama_config.show_info > 0 && l:has_info + let l:prefix = ' ' + + if l:truncated + let l:info = printf("%s | WARNING: the context is full: %d / %d, increase the server context size or reduce g:llama_config.ring_n_chunks", + \ g:llama_config.show_info == 2 ? l:prefix : 'llama.vim', + \ l:n_cached, l:n_ctx + \ ) + else + let l:info = printf("%s | c: %d / %d, r: %d / %d, e: %d, q: %d / 16 | p: %d (%.2f ms, %.2f t/s) | g: %d (%.2f ms, %.2f t/s) | t: %.2f ms", + \ g:llama_config.show_info == 2 ? l:prefix : 'llama.vim', + \ l:n_cached, l:n_ctx, len(s:ring_chunks), g:llama_config.ring_n_chunks, s:ring_n_evict, len(s:ring_queued), + \ l:n_prompt, l:t_prompt_ms, l:s_prompt, + \ l:n_predict, l:t_predict_ms, l:s_predict, + \ 1000.0 * reltimefloat(reltime(s:t_fim_start)) + \ ) + endif + + if g:llama_config.show_info == 1 + " display the info in the statusline + let &statusline = l:info + let l:info = '' + endif + endif + + " display the suggestion and append the info to the end of the first line + if s:ghost_text_nvim + call nvim_buf_set_extmark(l:bufnr, l:id_vt_fim, s:pos_y - 1, s:pos_x - 1, { + \ 'virt_text': [[s:content[0], 'llama_hl_hint'], [l:info, 'llama_hl_info']], + \ 'virt_text_win_col': virtcol('.') - 1 + \ }) + + call nvim_buf_set_extmark(l:bufnr, l:id_vt_fim, s:pos_y - 1, 0, { + \ 'virt_lines': map(s:content[1:], {idx, val -> [[val, 'llama_hl_hint']]}), + \ 'virt_text_win_col': virtcol('.') + \ }) + elseif s:ghost_text_vim + let l:new_suffix = s:content[0] + if !empty(l:new_suffix) + call prop_add(s:pos_y, s:pos_x + 1, { + \ 'type': s:hlgroup_hint, + \ 'text': l:new_suffix + \ }) + endif + for line in s:content[1:] + call prop_add(s:pos_y, 0, { + \ 'type': s:hlgroup_hint, + \ 'text': line, + \ 'text_padding_left': s:get_indent(line), + \ 'text_align': 'below' + \ }) + endfor + if !empty(l:info) + call prop_add(s:pos_y, 0, { + \ 'type': s:hlgroup_info, + \ 'text': l:info, + \ 'text_padding_left': col('$'), + \ 'text_wrap': 'truncate' + \ }) + endif + endif + + " setup accept shortcuts + inoremap :call llama#fim_accept(v:false) + inoremap :call llama#fim_accept(v:true) + + let s:hint_shown = v:true +endfunction + +function! s:fim_on_exit(job_id, exit_code, event = v:null) + if a:exit_code != 0 + echom "Job failed with exit code: " . a:exit_code + endif + + let s:current_job = v:null +endfunction diff --git a/examples/main/README.md b/examples/main/README.md index 7e192b9f2837c..c7c8231717980 100644 --- a/examples/main/README.md +++ b/examples/main/README.md @@ -187,6 +187,30 @@ Use the `--no-penalize-nl` option to disable newline penalization when applying Example usage: `--repeat-penalty 1.15 --repeat-last-n 128 --no-penalize-nl` +### DRY Repetition Penalty + +DRY (Don't Repeat Yourself) sampling is an effective technique for reducing repetition in generated text even across long contexts by penalizing tokens based on their recent usage patterns (original [PR link](https://github.com/oobabooga/text-generation-webui/pull/5677)). + +- `--dry-multiplier N`: Set the DRY sampling multiplier (default: 0.0, 0.0 = disabled). +- `--dry-base N`: Set the DRY sampling base value (default: 1.75). +- `--dry-allowed-length N`: Set the allowed length for DRY sampling (default: 2). +- `--dry-penalty-last-n N`: Set DRY penalty for the last n tokens (default: -1, 0 = disable, -1 = context size). +- `--dry-sequence-breaker STRING`: Add a sequence breaker for DRY sampling. Can be used more than once to add multiple sequence breakers. Using this clears out the default breakers, which consist of: `['\n', ':', '"', '*']`. If the string `"none"` is supplied, no sequence breakers are used. + +The `dry-multiplier` option controls the strength of the DRY sampling effect. A value of 0.0 disables DRY sampling, while higher values increase its influence. A typical recommended value is 0.8. + +The `dry-base` option sets the base value for the exponential penalty calculation in DRY sampling. Higher values lead to more aggressive penalization of repetitions. + +The `dry-allowed-length` option sets the maximum length of repeated sequences that will not be penalized. Repetitions shorter than or equal to this length are not penalized, allowing for natural repetitions of short phrases or common words. + +The `dry-penalty-last-n` option controls how many recent tokens to consider when applying the DRY penalty. A value of -1 considers the entire context. Use a positive value to limit the consideration to a specific number of recent tokens. + +The `dry-sequence-breaker` option adds a single sequence breaker and can be used more than once to specify multiple sequence breakers. Sequence breakers interrupt sequence matching and break the input into parts where matching can be applied. + +DRY sampling provides more nuanced control over text generation, particularly for reducing long-range repetitions and maintaining global coherence. + +Example usage: `--dry-multiplier 0.8 --dry-base 1.75 --dry-allowed-length 2 --dry-penalty-last-n -1 --dry-sequence-breaker "—" --dry-sequence-breaker "##"` + ### Top-K Sampling - `--top-k N`: Limit the next token selection to the K most probable tokens (default: 40). diff --git a/examples/save-load-state/save-load-state.cpp b/examples/save-load-state/save-load-state.cpp index 5f60a86cbc2d2..8c49a52a66124 100644 --- a/examples/save-load-state/save-load-state.cpp +++ b/examples/save-load-state/save-load-state.cpp @@ -42,7 +42,6 @@ int main(int argc, char ** argv) { llama_sampler * smpl = llama_sampler_chain_init(sparams); - llama_sampler_chain_add(smpl, llama_sampler_init_softmax()); llama_sampler_chain_add(smpl, llama_sampler_init_dist(params.sparams.seed)); // tokenize prompt @@ -107,7 +106,6 @@ int main(int argc, char ** argv) { llama_sampler * smpl2 = llama_sampler_chain_init(sparams); - llama_sampler_chain_add(smpl2, llama_sampler_init_softmax()); llama_sampler_chain_add(smpl2, llama_sampler_init_dist(params.sparams.seed)); printf("\nsecond run: %s", params.prompt.c_str()); @@ -171,7 +169,6 @@ int main(int argc, char ** argv) { llama_sampler * smpl3 = llama_sampler_chain_init(sparams); - llama_sampler_chain_add(smpl3, llama_sampler_init_softmax()); llama_sampler_chain_add(smpl3, llama_sampler_init_dist(params.sparams.seed)); printf("\nsingle seq run: %s", params.prompt.c_str()); diff --git a/examples/server/README.md b/examples/server/README.md index 09f1aa249ab1f..bc737237eb018 100644 --- a/examples/server/README.md +++ b/examples/server/README.md @@ -114,6 +114,11 @@ The project is under active development, and we are [looking for feedback and co | `--repeat-penalty N` | penalize repeat sequence of tokens (default: 1.0, 1.0 = disabled) | | `--presence-penalty N` | repeat alpha presence penalty (default: 0.0, 0.0 = disabled) | | `--frequency-penalty N` | repeat alpha frequency penalty (default: 0.0, 0.0 = disabled) | +| `--dry-multiplier N` | DRY sampling multiplier (default: 0.0, 0.0 = disabled) | +| `--dry-base N` | DRY sampling base value (default: 1.75) | +| `--dry-allowed-length N` | allowed length for DRY sampling (default: 2) | +| `--dry-penalty-last-n N` | DRY penalty for the last n tokens (default: -1, 0 = disable, -1 = context size) | +| `--dry-sequence-breaker STRING` | add sequence breaker for DRY sampling, clearing out default breakers (`['\n', ':', '"', '*']`) in the process; use `"none"` to not use any sequence breakers | `--dynatemp-range N` | dynamic temperature range (default: 0.0, 0.0 = disabled) | | `--dynatemp-exp N` | dynamic temperature exponent (default: 1.0) | | `--mirostat N` | use Mirostat sampling.
Top K, Nucleus, Tail Free and Locally Typical samplers are ignored if used.
(default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0) | @@ -319,6 +324,18 @@ node index.js - The prompt is a string or an array with the first element given as a string - The model's `tokenizer.ggml.add_bos_token` metadata is `true` + These input shapes and data type are allowed for `prompt`: + + - Single string: `"string"` + - Single sequence of tokens: `[12, 34, 56]` + - Mixed tokens and strings: `[12, 34, "string", 56, 78]` + + Multiple prompts are also supported. In this case, the completion result will be an array. + + - Only strings: `["string1", "string2"]` + - Strings and sequences of tokens: `["string1", [12, 34, 56]]` + - Mixed types: `[[12, 34, "string", 56, 78], [12, 34, 56], "string"]` + `temperature`: Adjust the randomness of the generated text. Default: `0.8` `dynatemp_range`: Dynamic temperature range. The final temperature will be in the range of `[temperature - dynatemp_range; temperature + dynatemp_range]` Default: `0.0`, which is disabled. @@ -357,6 +374,16 @@ node index.js `frequency_penalty`: Repeat alpha frequency penalty. Default: `0.0`, which is disabled. + `dry_multiplier`: Set the DRY (Don't Repeat Yourself) repetition penalty multiplier. Default: `0.0`, which is disabled. + + `dry_base`: Set the DRY repetition penalty base value. Default: `1.75` + + `dry_allowed_length`: Tokens that extend repetition beyond this receive exponentially increasing penalty: multiplier * base ^ (length of repeating sequence before token - allowed length). Default: `2` + + `dry_penalty_last_n`: How many tokens to scan for repetitions. Default: `-1`, where `0` is disabled and `-1` is context size. + + `dry_sequence_breakers`: Specify an array of sequence breakers for DRY sampling. Only a JSON array of strings is accepted. Default: `['\n', ':', '"', '*']` + `mirostat`: Enable Mirostat sampling, controlling perplexity during text generation. Default: `0`, where `0` is disabled, `1` is Mirostat, and `2` is Mirostat 2.0. `mirostat_tau`: Set the Mirostat target entropy, parameter tau. Default: `5.0` diff --git a/examples/server/public/index-new.html b/examples/server/public/index-new.html index ad4183cd928f7..cb3995abee0ef 100644 --- a/examples/server/public/index-new.html +++ b/examples/server/public/index-new.html @@ -40,6 +40,10 @@ repeat_last_n: 0, // 0 = disable penalty, -1 = context size repeat_penalty: 1.0, // 1.0 = disabled penalize_nl: false, // true only useful for infinite completion + dry_multiplier: 0.0, // 0.0 = disabled, 0.8 works well + dry_base: 1.75, // 0.0 = disabled + dry_allowed_length: 2, // tokens extending repetitions beyond this receive penalty, 2 works well + dry_penalty_last_n: -1, // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size) top_k: 0, // <= 0 to use vocab size top_p: 1.0, // 1.0 = disabled min_p: 0.05, // 0 = disabled; recommended for non-english: ~ 0.4 @@ -833,13 +837,17 @@
${IntField({ label: "Top-K", title: "Limits the selection of the next token to the K most probable tokens. 1 means no randomness = greedy sampling. If set to 0, it means the entire vocabulary size is considered.", max: 100, min: 0, step: 1, name: "top_k", value: params.value.top_k })} ${IntField({ label: "Penalize Last N", title: "The last n tokens that are taken into account to penalise repetitions. A value of 0 means that this function is deactivated and -1 means that the entire size of the context is taken into account.", max: 2048, min: 0, step: 16, name: "repeat_last_n", value: params.value.repeat_last_n })} - ${FloatField({ label: "Top-P", title: "Limits the selection of the next token to a subset of tokens whose combined probability reaches a threshold value P = top-P. If set to 1, it means the entire vocabulary size is considered.", max: 1.0, min: 0.0, name: "top_p", step: 0.01, value: params.value.top_p })} ${FloatField({ label: "Presence Penalty", title: "A penalty that is applied if certain tokens appear repeatedly in the generated text. A higher value leads to fewer repetitions.", max: 1.0, min: 0.0, name: "presence_penalty", step: 0.01, value: params.value.presence_penalty })} - ${FloatField({ label: "TFS-Z", title: "Activates tail-free sampling, a method used to limit the prediction of tokens that are too frequent. The parameter z controls the strength of this limitation. A value of 1.0 means that this function is deactivated.", max: 1.0, min: 0.0, name: "tfs_z", step: 0.01, value: params.value.tfs_z })} ${FloatField({ label: "Frequency Penalty", title: "A penalty that is applied based on the frequency with which certain tokens occur in the training data set. A higher value results in rare tokens being favoured.", max: 1.0, min: 0.0, name: "frequency_penalty", step: 0.01, value: params.value.frequency_penalty })} + ${FloatField({ label: "Top-P", title: "Limits the selection of the next token to a subset of tokens whose combined probability reaches a threshold value P = top-P. If set to 1, it means the entire vocabulary size is considered.", max: 1.0, min: 0.0, name: "top_p", step: 0.01, value: params.value.top_p })} ${FloatField({ label: "Typical-P", title: "Activates local typical sampling, a method used to limit the prediction of tokens that are atypical in the current context. The parameter p controls the strength of this limitation. A value of 1.0 means that this function is deactivated.", max: 1.0, min: 0.0, name: "typical_p", step: 0.01, value: params.value.typical_p })} ${FloatField({ label: "XTC probability", title: "Sets the chance for token removal (checked once on sampler start)", max: 1.0, min: 0.0, name: "xtc_probability", step: 0.01, value: params.value.xtc_probability })} ${FloatField({ label: "XTC threshold", title: "Sets a minimum probability threshold for tokens to be removed", max: 0.5, min: 0.0, name: "xtc_threshold", step: 0.01, value: params.value.xtc_threshold })} + ${FloatField({ label: "DRY Penalty Multiplier", title: "Set the DRY repetition penalty multiplier. Default is 0.0, which disables DRY.", max: 5.0, min: 0.0, name: "dry_multiplier", step: 0.01, value: params.value.dry_multiplier })} + ${FloatField({ label: "DRY Base", title: "Set the DRY repetition penalty base value. Default is 1.75", max: 3.0, min: 1.0, name: "dry_base", step: 0.01, value: params.value.dry_base })} + ${IntField({ label: "DRY Allowed Length", title: "Tokens that extend repetition beyond this receive exponentially increasing penalty. Default is 2", max: 10, min: 1, step: 1, name: "dry_allowed_length", value: params.value.dry_allowed_length })} + ${IntField({ label: "DRY Penalty Last N", title: "How many tokens to scan for repetitions. Default is -1, where 0 is disabled and -1 is context size", max: 2048, min: -1, step: 16, name: "dry_penalty_last_n", value: params.value.dry_penalty_last_n })} + ${FloatField({ label: "TFS-Z", title: "Activates tail-free sampling, a method used to limit the prediction of tokens that are too frequent. The parameter z controls the strength of this limitation. A value of 1.0 means that this function is deactivated.", max: 1.0, min: 0.0, name: "tfs_z", step: 0.01, value: params.value.tfs_z })} ${IntField({ label: "Min Keep", title: "If greater than 0, samplers are forced to return N possible tokens at minimum. Default is 0", max: 10, min: 0, name: "min_keep", value: params.value.min_keep })}
@@ -1144,6 +1152,8 @@

llama.cpp

repeat_penalty: { snapValue: 1.0, snapRangeMultiplier: 4 }, presence_penalty: { snapValue: 0.0, snapRangeMultiplier: 4 }, frequency_penalty: { snapValue: 0.0, snapRangeMultiplier: 4 }, + dry_multiplier: { snapValue: 0.0, snapRangeMultiplier: 4 }, + dry_base: { snapValue: 1.75, snapRangeMultiplier: 4 }, }; // add an event listener for each slider Object.keys(snapSettings).forEach(sliderName => { diff --git a/examples/server/public/index.html b/examples/server/public/index.html index 88065705fb669..7f9b02bfbf83b 100644 --- a/examples/server/public/index.html +++ b/examples/server/public/index.html @@ -304,6 +304,10 @@ repeat_last_n: 256, // 0 = disable penalty, -1 = context size repeat_penalty: 1.18, // 1.0 = disabled penalize_nl: false, + dry_multiplier: 0.0, // 0.0 = disabled, 0.8 works well + dry_base: 1.75, // 0.0 = disabled + dry_allowed_length: 2, // tokens extending repetitions beyond this receive penalty, 2 works well + dry_penalty_last_n: -1, // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size) top_k: 40, // <= 0 to use vocab size top_p: 0.95, // 1.0 = disabled min_p: 0.05, // 0 = disabled @@ -1015,6 +1019,10 @@ ${FloatField({ label: "Typical P", max: 1.0, min: 0.0, name: "typical_p", step: 0.01, value: params.value.typical_p })} ${FloatField({ label: "Presence penalty", max: 1.0, min: 0.0, name: "presence_penalty", step: 0.01, value: params.value.presence_penalty })} ${FloatField({ label: "Frequency penalty", max: 1.0, min: 0.0, name: "frequency_penalty", step: 0.01, value: params.value.frequency_penalty })} + ${FloatField({ label: "DRY Penalty Multiplier", max: 5.0, min: 0.0, name: "dry_multiplier", step: 0.01, value: params.value.dry_multiplier })} + ${FloatField({ label: "DRY Base", max: 3.0, min: 1.0, name: "dry_base", step: 0.01, value: params.value.dry_base })} + ${IntField({ label: "DRY Allowed Length", max: 10, min: 2, step: 1, name: "dry_allowed_length", value: params.value.dry_allowed_length })} + ${IntField({ label: "DRY Penalty Last N", max: 2048, min: -1, step: 16, name: "dry_penalty_last_n", value: params.value.dry_penalty_last_n })} ${FloatField({ label: "XTC probability", max: 1.0, min: 0.0, name: "xtc_probability", step: 0.01, value: params.value.xtc_probability })} ${FloatField({ label: "XTC threshold", max: 0.5, min: 0.0, name: "xtc_threshold", step: 0.01, value: params.value.xtc_threshold })} diff --git a/examples/server/public/style.css b/examples/server/public/style.css old mode 100755 new mode 100644 diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 3992108e7f383..ff1d9b03cec5d 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -43,21 +43,6 @@ #include #include -#define SLT_INF(slot, fmt, ...) LOG_INF("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__) -#define SLT_WRN(slot, fmt, ...) LOG_WRN("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__) -#define SLT_ERR(slot, fmt, ...) LOG_ERR("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__) -#define SLT_DBG(slot, fmt, ...) LOG_DBG("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__) - -#define SRV_INF(fmt, ...) LOG_INF("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__) -#define SRV_WRN(fmt, ...) LOG_WRN("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__) -#define SRV_ERR(fmt, ...) LOG_ERR("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__) -#define SRV_DBG(fmt, ...) LOG_DBG("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__) - -#define QUE_INF(fmt, ...) LOG_INF("que %12.*s: " fmt, 12, __func__, __VA_ARGS__) -#define QUE_WRN(fmt, ...) LOG_WRN("que %12.*s: " fmt, 12, __func__, __VA_ARGS__) -#define QUE_ERR(fmt, ...) LOG_ERR("que %12.*s: " fmt, 12, __func__, __VA_ARGS__) -#define QUE_DBG(fmt, ...) LOG_DBG("que %12.*s: " fmt, 12, __func__, __VA_ARGS__) - using json = nlohmann::ordered_json; enum stop_type { @@ -68,6 +53,7 @@ enum stop_type { // state diagram: https://github.com/ggerganov/llama.cpp/pull/9283 enum slot_state { SLOT_STATE_IDLE, + SLOT_STATE_STARTED, // TODO: this state is only used for setting up the initial prompt processing; maybe merge it with launch_slot_with_task in the future SLOT_STATE_PROCESSING_PROMPT, SLOT_STATE_DONE_PROMPT, SLOT_STATE_GENERATING, @@ -79,7 +65,7 @@ enum server_state { }; enum server_task_type { - SERVER_TASK_TYPE_COMPLETION, + SERVER_TASK_TYPE_INFERENCE, SERVER_TASK_TYPE_CANCEL, SERVER_TASK_TYPE_NEXT_RESPONSE, SERVER_TASK_TYPE_METRICS, @@ -89,21 +75,22 @@ enum server_task_type { SERVER_TASK_TYPE_SET_LORA, }; -enum server_task_cmpl_type { - SERVER_TASK_CMPL_TYPE_NORMAL, - SERVER_TASK_CMPL_TYPE_EMBEDDING, - SERVER_TASK_CMPL_TYPE_RERANK, - SERVER_TASK_CMPL_TYPE_INFILL, +enum server_task_inf_type { + SERVER_TASK_INF_TYPE_COMPLETION, + SERVER_TASK_INF_TYPE_EMBEDDING, + SERVER_TASK_INF_TYPE_RERANK, + SERVER_TASK_INF_TYPE_INFILL, }; struct server_task { int id = -1; // to be filled by server_queue int id_target = -1; // used by SERVER_TASK_TYPE_CANCEL + llama_tokens prompt_tokens; server_task_type type; json data; - server_task_cmpl_type cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL; + server_task_inf_type inf_type = SERVER_TASK_INF_TYPE_COMPLETION; // utility function static std::unordered_set get_list_id(const std::vector & tasks) { @@ -161,26 +148,20 @@ struct server_slot { int32_t i_batch = -1; int32_t n_predict = -1; // TODO: disambiguate from params.n_predict + // n_prompt_tokens may not be equal to prompt_tokens.size(), because prompt maybe truncated int32_t n_prompt_tokens = 0; int32_t n_prompt_tokens_processed = 0; - json prompt; // can be either a string, array of strings or array of token ids - - json input_prefix; - json input_suffix; - json input_extra; - - // when a task is submitted, we first tokenize the prompt and store it here - std::vector prompt_tokens; - std::vector extra_tokens; + // input prompt tokens + llama_tokens prompt_tokens; size_t last_nl_pos = 0; std::string generated_text; - std::vector cache_tokens; + llama_tokens cache_tokens; std::vector generated_token_probs; - server_task_cmpl_type cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL; + server_task_inf_type inf_type = SERVER_TASK_INF_TYPE_COMPLETION; bool has_next_token = true; bool has_new_line = false; @@ -229,7 +210,7 @@ struct server_slot { n_past = 0; n_sent_text = 0; n_sent_token_probs = 0; - cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL; + inf_type = SERVER_TASK_INF_TYPE_COMPLETION; generated_token_probs.clear(); } @@ -734,42 +715,6 @@ struct server_context { metrics.init(); } - std::vector tokenize(const json & json_prompt, bool add_special, bool parse_special) const { - // If `add_bos` is true, we only add BOS, when json_prompt is a string, - // or the first element of the json_prompt array is a string. - std::vector prompt_tokens; - - if (json_prompt.is_array()) { - bool first = true; - for (const auto & p : json_prompt) { - if (p.is_string()) { - auto s = p.template get(); - - std::vector p; - if (first) { - p = common_tokenize(ctx, s, add_special, parse_special); - first = false; - } else { - p = common_tokenize(ctx, s, false, parse_special); - } - - prompt_tokens.insert(prompt_tokens.end(), p.begin(), p.end()); - } else { - if (first) { - first = false; - } - - prompt_tokens.push_back(p.template get()); - } - } - } else { - auto s = json_prompt.template get(); - prompt_tokens = common_tokenize(ctx, s, add_special, parse_special); - } - - return prompt_tokens; - } - server_slot * get_slot_by_id(int id) { for (server_slot & slot : slots) { if (slot.id == id) { @@ -794,22 +739,16 @@ struct server_context { continue; } - // skip the slot if it does not contains prompt - if (!slot.prompt.is_string()) { + // skip the slot if it does not contains cached tokens + if (slot.prompt_tokens.empty()) { continue; } - // current slot's prompt - std::string slot_prompt = slot.prompt.get(); - - // length of the current slot's prompt - int slot_prompt_len = slot_prompt.size(); - // length of the Longest Common Prefix between the current slot's prompt and the input prompt - int lcp_len = longest_common_prefix(slot_prompt, prompt); + int lcp_len = longest_common_prefix(slot.cache_tokens, slot.prompt_tokens); // fraction of the common substring length compared to the current slot's prompt length - similarity = static_cast(lcp_len) / slot_prompt_len; + similarity = static_cast(lcp_len) / static_cast(slot.prompt_tokens.size()); // select the current slot if the criteria match if (lcp_len > max_lcp_len && similarity > slot_prompt_similarity) { @@ -861,35 +800,58 @@ struct server_context { slot.oaicompat_model = ""; } - slot.params.stream = json_value(data, "stream", false); - slot.params.cache_prompt = json_value(data, "cache_prompt", false); - slot.params.n_predict = json_value(data, "n_predict", json_value(data, "max_tokens", default_params.n_predict)); - slot.params.n_indent = json_value(data, "n_indent", default_params.n_indent); - slot.sparams.top_k = json_value(data, "top_k", default_sparams.top_k); - slot.sparams.top_p = json_value(data, "top_p", default_sparams.top_p); - slot.sparams.min_p = json_value(data, "min_p", default_sparams.min_p); - slot.sparams.xtc_probability = json_value(data, "xtc_probability", default_sparams.xtc_probability); - slot.sparams.xtc_threshold = json_value(data, "xtc_threshold", default_sparams.xtc_threshold); - slot.sparams.tfs_z = json_value(data, "tfs_z", default_sparams.tfs_z); - slot.sparams.typ_p = json_value(data, "typical_p", default_sparams.typ_p); - slot.sparams.temp = json_value(data, "temperature", default_sparams.temp); - slot.sparams.dynatemp_range = json_value(data, "dynatemp_range", default_sparams.dynatemp_range); - slot.sparams.dynatemp_exponent = json_value(data, "dynatemp_exponent", default_sparams.dynatemp_exponent); - slot.sparams.penalty_last_n = json_value(data, "repeat_last_n", default_sparams.penalty_last_n); - slot.sparams.penalty_repeat = json_value(data, "repeat_penalty", default_sparams.penalty_repeat); - slot.sparams.penalty_freq = json_value(data, "frequency_penalty", default_sparams.penalty_freq); - slot.sparams.penalty_present = json_value(data, "presence_penalty", default_sparams.penalty_present); - slot.sparams.mirostat = json_value(data, "mirostat", default_sparams.mirostat); - slot.sparams.mirostat_tau = json_value(data, "mirostat_tau", default_sparams.mirostat_tau); - slot.sparams.mirostat_eta = json_value(data, "mirostat_eta", default_sparams.mirostat_eta); - slot.sparams.penalize_nl = json_value(data, "penalize_nl", default_sparams.penalize_nl); - slot.params.n_keep = json_value(data, "n_keep", default_params.n_keep); - slot.params.n_discard = json_value(data, "n_discard", default_params.n_discard); - slot.sparams.seed = json_value(data, "seed", default_sparams.seed); - slot.sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs); - slot.sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep); - //slot.params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms", default_params.t_max_prompt_ms); // TODO: implement - slot.params.t_max_predict_ms = json_value(data, "t_max_predict_ms", default_params.t_max_predict_ms); + slot.params.stream = json_value(data, "stream", false); + slot.params.cache_prompt = json_value(data, "cache_prompt", false); + slot.params.n_predict = json_value(data, "n_predict", json_value(data, "max_tokens", default_params.n_predict)); + slot.params.n_indent = json_value(data, "n_indent", default_params.n_indent); + slot.sparams.top_k = json_value(data, "top_k", default_sparams.top_k); + slot.sparams.top_p = json_value(data, "top_p", default_sparams.top_p); + slot.sparams.min_p = json_value(data, "min_p", default_sparams.min_p); + slot.sparams.xtc_probability = json_value(data, "xtc_probability", default_sparams.xtc_probability); + slot.sparams.xtc_threshold = json_value(data, "xtc_threshold", default_sparams.xtc_threshold); + slot.sparams.tfs_z = json_value(data, "tfs_z", default_sparams.tfs_z); + slot.sparams.typ_p = json_value(data, "typical_p", default_sparams.typ_p); + slot.sparams.temp = json_value(data, "temperature", default_sparams.temp); + slot.sparams.dynatemp_range = json_value(data, "dynatemp_range", default_sparams.dynatemp_range); + slot.sparams.dynatemp_exponent = json_value(data, "dynatemp_exponent", default_sparams.dynatemp_exponent); + slot.sparams.penalty_last_n = json_value(data, "repeat_last_n", default_sparams.penalty_last_n); + slot.sparams.penalty_repeat = json_value(data, "repeat_penalty", default_sparams.penalty_repeat); + slot.sparams.penalty_freq = json_value(data, "frequency_penalty", default_sparams.penalty_freq); + slot.sparams.penalty_present = json_value(data, "presence_penalty", default_sparams.penalty_present); + slot.sparams.dry_multiplier = json_value(data, "dry_multiplier", default_sparams.dry_multiplier); + slot.sparams.dry_base = json_value(data, "dry_base", default_sparams.dry_base); + slot.sparams.dry_allowed_length = json_value(data, "dry_allowed_length", default_sparams.dry_allowed_length); + slot.sparams.dry_penalty_last_n = json_value(data, "dry_penalty_last_n", default_sparams.dry_penalty_last_n); + slot.sparams.mirostat = json_value(data, "mirostat", default_sparams.mirostat); + slot.sparams.mirostat_tau = json_value(data, "mirostat_tau", default_sparams.mirostat_tau); + slot.sparams.mirostat_eta = json_value(data, "mirostat_eta", default_sparams.mirostat_eta); + slot.sparams.penalize_nl = json_value(data, "penalize_nl", default_sparams.penalize_nl); + slot.params.n_keep = json_value(data, "n_keep", default_params.n_keep); + slot.params.n_discard = json_value(data, "n_discard", default_params.n_discard); + slot.sparams.seed = json_value(data, "seed", default_sparams.seed); + slot.sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs); + slot.sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep); + //slot.params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms", default_params.t_max_prompt_ms); // TODO: implement + slot.params.t_max_predict_ms = json_value(data, "t_max_predict_ms", default_params.t_max_predict_ms); + + if (slot.sparams.dry_base < 1.0f) + { + slot.sparams.dry_base = default_sparams.dry_base; + } + + // sequence breakers for DRY + { + // Currently, this is not compatible with TextGen WebUI, Koboldcpp and SillyTavern format + // Ref: https://github.com/oobabooga/text-generation-webui/blob/d1af7a41ade7bd3c3a463bfa640725edb818ebaf/extensions/openai/typing.py#L39 + + if (data.contains("dry_sequence_breakers")) { + slot.sparams.dry_sequence_breakers = json_value(data, "dry_sequence_breakers", std::vector()); + if (slot.sparams.dry_sequence_breakers.empty()) { + send_error(task, "Error: dry_sequence_breakers must be a non-empty array of strings", ERROR_TYPE_INVALID_REQUEST); + return false; + } + } + } // process "json_schema" and "grammar" if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) { @@ -914,57 +876,6 @@ struct server_context { SLT_WRN(slot, "n_predict = %d exceeds server configuration, setting to %d", slot.n_predict, slot.n_predict); } - // infill - slot.input_prefix = json_value(data, "input_prefix", json()); - slot.input_suffix = json_value(data, "input_suffix", json()); - slot.input_extra = json_value(data, "input_extra", json()); - - SLT_DBG(slot, "extra_context chunks: %d\n", (int) slot.input_extra.size()); - for (const auto & chunk : slot.input_extra) { - // { "text": string, "filename": string } - if (!chunk.contains("text") || !chunk["text"].is_string()) { - send_error(task, "extra_context chunk must contain a \"text\" field with a string value", ERROR_TYPE_INVALID_REQUEST); - return false; - } - - // filename is optional - if (chunk.contains("filename") && !chunk["filename"].is_string()) { - send_error(task, "extra_context chunk's \"filename\" field must be a string", ERROR_TYPE_INVALID_REQUEST); - return false; - } - - SLT_DBG(slot, "extra_context chunk in file '%s':\n%s\n", chunk.value("filename", "").c_str(), chunk.value("text", "").c_str()); - } - - // get prompt - { - const auto & prompt = data.find("prompt"); - if (prompt == data.end()) { - send_error(task, "\"prompt\" must be provided", ERROR_TYPE_INVALID_REQUEST); - return false; - } - - if ((prompt->is_string()) || - (prompt->is_array() && prompt->size() == 1 && prompt->at(0).is_string()) || - (prompt->is_array() && !prompt->empty() && prompt->at(0).is_number_integer())) { - slot.prompt = *prompt; - } else if (prompt->is_array() && prompt->size() == 1 && prompt->at(0).is_array()) { - slot.prompt = prompt->at(0); - } else if (prompt->is_array() && prompt->size() > 1) { - // array of strings - for (const auto & el : *prompt) { - if (!el.is_string()) { - send_error(task, "\"prompt\" must be a string, an array of strings or an array of integers", ERROR_TYPE_INVALID_REQUEST); - return false; - } - } - slot.prompt = *prompt; - } else { - send_error(task, "\"prompt\" must be a string, an array of strings or an array of integers", ERROR_TYPE_INVALID_REQUEST); - return false; - } - } - { slot.sparams.logit_bias.clear(); @@ -1044,8 +955,7 @@ struct server_context { } } - slot.state = SLOT_STATE_PROCESSING_PROMPT; - slot.prompt_tokens.clear(); + slot.state = SLOT_STATE_STARTED; SLT_INF(slot, "%s", "processing task\n"); @@ -1245,6 +1155,11 @@ struct server_context { {"repeat_penalty", slot.sparams.penalty_repeat}, {"presence_penalty", slot.sparams.penalty_present}, {"frequency_penalty", slot.sparams.penalty_freq}, + {"dry_multiplier", slot.sparams.dry_multiplier}, + {"dry_base", slot.sparams.dry_base}, + {"dry_allowed_length", slot.sparams.dry_allowed_length}, + {"dry_penalty_last_n", slot.sparams.dry_penalty_last_n}, + {"dry_sequence_breakers", slot.sparams.dry_sequence_breakers}, {"mirostat", slot.sparams.mirostat}, {"mirostat_tau", slot.sparams.mirostat_tau}, {"mirostat_eta", slot.sparams.mirostat_eta}, @@ -1297,7 +1212,7 @@ struct server_context { }; if (slot.sparams.n_probs > 0) { - const std::vector to_send_toks = common_tokenize(ctx, tkn.text_to_send, false); + const llama_tokens to_send_toks = common_tokenize(ctx, tkn.text_to_send, false); const size_t probs_pos = std::min(slot.n_sent_token_probs, slot.generated_token_probs.size()); const size_t probs_stop_pos = std::min(slot.n_sent_token_probs + to_send_toks.size(), slot.generated_token_probs.size()); @@ -1333,7 +1248,7 @@ struct server_context { {"tokens_predicted", slot.n_decoded}, {"tokens_evaluated", slot.n_prompt_tokens}, {"generation_settings", get_formated_generation(slot)}, - {"prompt", slot.prompt}, + {"prompt", common_detokenize(ctx, slot.prompt_tokens)}, {"has_new_line", slot.has_new_line}, {"truncated", slot.truncated}, {"stopped_eos", slot.stopped_eos}, @@ -1348,7 +1263,7 @@ struct server_context { if (slot.sparams.n_probs > 0) { std::vector probs; if (!slot.params.stream && slot.stopped_word) { - const std::vector stop_word_toks = common_tokenize(ctx, slot.stopping_word, false); + const llama_tokens stop_word_toks = common_tokenize(ctx, slot.stopping_word, false); size_t safe_offset = std::min(slot.generated_token_probs.size(), stop_word_toks.size()); probs = std::vector( @@ -1457,19 +1372,17 @@ struct server_context { // Functions to create new task(s) and receive result(s) // - std::vector create_tasks_cmpl(json data, server_task_cmpl_type cmpl_type) { + // break the input "prompt" into multiple tasks if needed, then format and tokenize the input prompt(s) + std::vector create_tasks_inference(json data, server_task_inf_type inf_type) { std::vector tasks; - auto create_task = [&](json & task_data, bool replace_prompt, json prompt) { + auto create_task = [&](json & task_data, llama_tokens & prompt_tokens) { + SRV_DBG("create task, n_tokens = %d\n", (int) prompt_tokens.size()); server_task task; - task.id = queue_tasks.get_new_id(); - task.cmpl_type = cmpl_type; - task.type = SERVER_TASK_TYPE_COMPLETION; - if (replace_prompt) { - task.data = task_data; - task.data["prompt"] = std::move(prompt); - } else { - task.data = std::move(task_data); - } + task.id = queue_tasks.get_new_id(); + task.inf_type = inf_type; + task.type = SERVER_TASK_TYPE_INFERENCE; + task.data = task_data; + task.prompt_tokens = std::move(prompt_tokens); tasks.push_back(std::move(task)); }; @@ -1478,41 +1391,49 @@ struct server_context { throw std::runtime_error(error_msg); } - json prompt = data.at("prompt"); - - // if the prompt is a singleton (i.e. a string or a list of tokens), we only need to create single task - if (prompt.is_string() || json_is_array_of_numbers(prompt)) { - data["index"] = 0; - create_task(data, false, nullptr); - } else if (prompt.is_array()) { - // otherwise, it's a multiple-prompt task, we break it into smaller tasks - std::vector prompts = prompt; - if (cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) { - // prompts[0] is the question - // the rest are the answers/documents - SRV_DBG("creating rerank tasks, n_prompts = %d\n", (int) prompts.size() - 1); - for (size_t i = 1; i < prompts.size(); i++) { - json qd; - qd.push_back(prompts[0]); - qd.push_back(prompts[i]); - data["index"] = i - 1; - create_task(data, true, qd); - } - } else { - SRV_DBG("creating multi-prompt tasks, n_prompts = %d\n", (int) prompts.size()); - for (size_t i = 0; i < prompts.size(); i++) { - const auto & e = prompts[i]; - if (e.is_string() || json_is_array_of_numbers(e)) { + // because llama_tokenize api is thread-safe, we can tokenize the prompt from HTTP thread + bool add_special = inf_type != SERVER_TASK_INF_TYPE_RERANK && inf_type != SERVER_TASK_INF_TYPE_INFILL; + std::vector tokenized_prompts = tokenize_input_prompts(ctx, data.at("prompt"), add_special, true); + switch (inf_type) { + case SERVER_TASK_INF_TYPE_RERANK: + { + // prompts[0] is the question + // the rest are the answers/documents + GGML_ASSERT(tokenized_prompts.size() > 1); + SRV_DBG("creating rerank tasks, n_prompts = %d\n", (int) tokenized_prompts.size() - 1); + for (size_t i = 1; i < tokenized_prompts.size(); i++) { + data["index"] = i - 1; + auto tokens = format_rerank(model, tokenized_prompts[0], tokenized_prompts[i]); + create_task(data, tokens); + } + } break; + case SERVER_TASK_INF_TYPE_INFILL: + { + SRV_DBG("creating infill tasks, n_prompts = %d\n", (int) tokenized_prompts.size()); + for (size_t i = 0; i < tokenized_prompts.size(); i++) { data["index"] = i; - create_task(data, true, e); - } else { - throw std::runtime_error(error_msg); + auto tokens = format_infill( + ctx, + data.at("input_prefix"), + data.at("input_suffix"), + data.at("input_extra"), + params.n_batch, + params.n_predict, + slots[0].n_ctx, // TODO: there should be a better way + params.spm_infill, + tokenized_prompts[i] + ); + create_task(data, tokens); + } + } break; + default: + { + SRV_DBG("creating multi-prompt tasks, n_prompts = %d\n", (int) tokenized_prompts.size()); + for (size_t i = 0; i < tokenized_prompts.size(); i++) { + data["index"] = i; + create_task(data, tokenized_prompts[i]); } } - } - } else { - // invalid case - throw std::runtime_error(error_msg); } return tasks; @@ -1534,7 +1455,7 @@ struct server_context { queue_tasks.post(cancel_tasks, true); } - // receive the results from task(s) created by create_tasks_cmpl + // receive the results from task(s) created by create_tasks_inference void receive_cmpl_results( const std::unordered_set & id_tasks, const std::function&)> & result_handler, @@ -1558,7 +1479,7 @@ struct server_context { result_handler(results); } - // receive the results from task(s) created by create_tasks_cmpl, in stream mode + // receive the results from task(s) created by create_tasks_inference, in stream mode void receive_cmpl_results_stream( const std::unordered_set & id_tasks, const std::function & result_handler, const @@ -1591,7 +1512,7 @@ struct server_context { void process_single_task(const server_task & task) { switch (task.type) { - case SERVER_TASK_TYPE_COMPLETION: + case SERVER_TASK_TYPE_INFERENCE: { const int id_slot = json_value(task.data, "id_slot", -1); @@ -1623,9 +1544,10 @@ struct server_context { slot->reset(); - slot->id_task = task.id; - slot->cmpl_type = task.cmpl_type; - slot->index = json_value(task.data, "index", 0); + slot->id_task = task.id; + slot->inf_type = task.inf_type; + slot->index = json_value(task.data, "index", 0); + slot->prompt_tokens = std::move(task.prompt_tokens); if (!launch_slot_with_task(*slot, task)) { SRV_ERR("failed to launch slot with task, id_task = %d\n", task.id); @@ -1658,7 +1580,7 @@ struct server_context { slot_data["id"] = slot.id; slot_data["id_task"] = slot.id_task; slot_data["state"] = slot.state; - slot_data["prompt"] = slot.prompt; + slot_data["prompt"] = common_detokenize(ctx, slot.prompt_tokens); slot_data["next_token"] = { {"has_next_token", slot.has_next_token}, {"has_new_line", slot.has_new_line}, @@ -1785,9 +1707,6 @@ struct server_context { } slot->cache_tokens.resize(token_count); - // TODO: maybe detokenize the slot->cache_tokens instead? - slot->prompt = string_format("[restored %d tokens from file]", (int) token_count); - const int64_t t_end = ggml_time_us(); const double t_restore_ms = (t_end - t_start) / 1000.0; @@ -1954,142 +1873,18 @@ struct server_context { if (params.cont_batching || batch.n_tokens == 0) { for (auto & slot : slots) { // this slot still has a prompt to be processed - if (slot.state == SLOT_STATE_PROCESSING_PROMPT) { + if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_STARTED) { auto & prompt_tokens = slot.prompt_tokens; - // we haven't tokenized the prompt yet - do it now: - if (prompt_tokens.empty()) { - SLT_INF(slot, "tokenizing prompt, len = %d\n", (int) slot.prompt.size()); - + // TODO: maybe move branch to outside of this loop in the future + if (slot.state == SLOT_STATE_STARTED) { slot.t_start_process_prompt = ggml_time_us(); slot.t_start_generation = 0; - - switch (slot.cmpl_type) { - case SERVER_TASK_CMPL_TYPE_NORMAL: - case SERVER_TASK_CMPL_TYPE_EMBEDDING: - { - prompt_tokens = tokenize(slot.prompt, llama_add_bos_token(model), true); - } break; - case SERVER_TASK_CMPL_TYPE_RERANK: - { - // require slot.prompt to be array of 2 strings - if (!slot.prompt.is_array() || slot.prompt.size() != 2) { - SLT_ERR(slot, "%s", "invalid prompt for rerank task\n"); - slot.release(); - send_error(slot, "invalid prompt for rerank task", ERROR_TYPE_INVALID_REQUEST); - continue; - } - - // prompt: [BOS]query[EOS][SEP]doc[EOS] - prompt_tokens.clear(); - prompt_tokens.push_back(llama_token_bos(model)); - { - const auto part = tokenize(slot.prompt[0], false, false); - prompt_tokens.insert(prompt_tokens.end(), part.begin(), part.end()); - } - prompt_tokens.push_back(llama_token_eos(model)); - prompt_tokens.push_back(llama_token_sep(model)); - { - const auto part = tokenize(slot.prompt[1], false, false); - prompt_tokens.insert(prompt_tokens.end(), part.begin(), part.end()); - } - prompt_tokens.push_back(llama_token_eos(model)); - } break; - case SERVER_TASK_CMPL_TYPE_INFILL: - { - // TODO: optimize this block by reducing memory allocations and movement - - // use FIM repo-level pattern: - // ref: https://arxiv.org/pdf/2409.12186 - // - // [FIM_REP]myproject - // [FIM_SEP]filename0 - // extra chunk 0 - // [FIM_SEP]filename1 - // extra chunk 1 - // ... - // [FIM_SEP]filename - // [FIM_PRE]prefix[FIM_SUF]suffix[FIM_MID]prompt - // - auto tokens_prefix = tokenize(slot.input_prefix, false, false); - auto tokens_suffix = tokenize(slot.input_suffix, false, false); - auto tokens_prompt = tokenize(slot.prompt, false, false); - - slot.extra_tokens.clear(); - if (llama_token_fim_rep(model) != LLAMA_TOKEN_NULL) { - static const auto k_fim_repo = tokenize("myproject\n", false, false); - - slot.extra_tokens.push_back(llama_token_fim_rep(model)); - slot.extra_tokens.insert(slot.extra_tokens.end(), k_fim_repo.begin(), k_fim_repo.end()); - } - - for (const auto & chunk : slot.input_extra) { - // { "text": string, "filename": string } - const std::string text = chunk.value("text", ""); - const std::string filename = chunk.value("filename", "tmp"); - - if (llama_token_fim_sep(model) != LLAMA_TOKEN_NULL) { - const auto k_fim_file = tokenize(filename + "\n", false, false); - - slot.extra_tokens.insert(slot.extra_tokens.end(), llama_token_fim_sep(model)); - slot.extra_tokens.insert(slot.extra_tokens.end(), k_fim_file.begin(), k_fim_file.end()); - } else { - // chunk separator in binary form to avoid confusing the AI - static const char k_chunk_prefix_str[] = {0x0a, 0x0a, 0x2d, 0x2d, 0x2d, 0x20, 0x73, 0x6e, 0x69, 0x70, 0x70, 0x65, 0x74, 0x20, 0x2d, 0x2d, 0x2d, 0x0a, 0x0a, 0x00}; - static const auto k_chunk_prefix_tokens = tokenize(k_chunk_prefix_str, false, false); - - slot.extra_tokens.insert(slot.extra_tokens.end(), k_chunk_prefix_tokens.begin(), k_chunk_prefix_tokens.end()); - } - - const auto chunk_tokens = tokenize(text, false, false); - slot.extra_tokens.insert(slot.extra_tokens.end(), chunk_tokens.begin(), chunk_tokens.end()); - } - - if (llama_token_fim_sep(model) != LLAMA_TOKEN_NULL) { - // TODO: current filename - static const auto k_fim_file = tokenize("filename\n", false, false); - - slot.extra_tokens.insert(slot.extra_tokens.end(), llama_token_fim_sep(model)); - slot.extra_tokens.insert(slot.extra_tokens.end(), k_fim_file.begin(), k_fim_file.end()); - } - - // for now pick FIM context to fit in a batch (ratio prefix:suffix = 3:1, TODO: configurable?) - const int n_suffix_take = std::min(tokens_suffix.size(), (n_batch/4)); - const int n_prefix_take = std::min(tokens_prefix.size(), 3*(n_batch/4) - 3); - - // fill the rest of the context with extra chunks - const int n_extra_take = std::min(std::max(0, slot.n_ctx - (n_batch) - 2*slot.n_predict), slot.extra_tokens.size()); - - tokens_prefix.erase(tokens_prefix.begin(), tokens_prefix.begin() + tokens_prefix.size() - n_prefix_take); - tokens_suffix.resize(n_suffix_take); - - tokens_prefix.insert(tokens_prefix.begin(), llama_token_fim_pre(model)); - tokens_prefix.insert(tokens_prefix.end(), tokens_prompt.begin(), tokens_prompt.end()); - tokens_suffix.insert(tokens_suffix.begin(), llama_token_fim_suf(model)); - - auto embd_inp = params.spm_infill ? tokens_suffix : tokens_prefix; - auto embd_end = params.spm_infill ? tokens_prefix : tokens_suffix; - - if (llama_add_bos_token(model)) { - embd_inp.insert(embd_inp.begin(), llama_token_bos(model)); - } - - SLT_DBG(slot, "extra: n_ctx = %d, n_extra_take = %d, n_extra = %d\n", slot.n_ctx, n_extra_take, (int) slot.extra_tokens.size()); - - // put the extra context before the FIM prefix - embd_inp.insert(embd_inp.begin(), slot.extra_tokens.end() - n_extra_take, slot.extra_tokens.end()); - - embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end()); - embd_inp.push_back(llama_token_fim_mid(model)); - - prompt_tokens = std::move(embd_inp); - } break; - } - slot.n_past = 0; slot.n_prompt_tokens = prompt_tokens.size(); + slot.state = SLOT_STATE_PROCESSING_PROMPT; - SLT_INF(slot, "prompt tokenized, n_ctx_slot = %d, n_keep = %d, n_prompt_tokens = %d\n", slot.n_ctx, slot.params.n_keep, slot.n_prompt_tokens); + SLT_INF(slot, "new prompt, n_ctx_slot = %d, n_keep = %d, n_prompt_tokens = %d\n", slot.n_ctx, slot.params.n_keep, slot.n_prompt_tokens); // print prompt tokens (for debugging) if (1) { @@ -2114,13 +1909,18 @@ struct server_context { continue; } - if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING || slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) { - // this prompt is too large to process - discard it + if (slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING || slot.inf_type == SERVER_TASK_INF_TYPE_RERANK) { if (slot.n_prompt_tokens > n_ubatch) { slot.release(); send_error(slot, "input is too large to process. increase the physical batch size", ERROR_TYPE_SERVER); continue; } + + if (slot.n_prompt_tokens > slot.n_ctx) { + slot.release(); + send_error(slot, "input is larger than the max context size. skipping", ERROR_TYPE_SERVER); + continue; + } } else { if (!params.ctx_shift) { // if context shift is disabled, we make sure prompt size is smaller than KV size @@ -2144,7 +1944,7 @@ struct server_context { const int n_block_size = n_left / 2; const int erased_blocks = (slot.n_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size; - std::vector new_tokens( + llama_tokens new_tokens( prompt_tokens.begin(), prompt_tokens.begin() + slot.params.n_keep); @@ -2163,17 +1963,10 @@ struct server_context { GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx); } - common_sampler_reset(slot.smpl); - if (slot.params.cache_prompt) { // reuse any previously computed tokens that are common with the new prompt slot.n_past = longest_common_prefix(slot.cache_tokens, prompt_tokens); - // push the prompt into the sampling context (do not apply grammar) - for (int i = 0; i < slot.n_past; ++i) { - common_sampler_accept(slot.smpl, slot.cache_tokens[i], false); - } - // reuse chunks from the cached prompt by shifting their KV cache in the new position if (params.n_cache_reuse > 0) { size_t head_c = slot.n_past; // cache @@ -2205,9 +1998,6 @@ struct server_context { for (size_t i = 0; i < n_match; i++) { slot.cache_tokens[head_p + i] = slot.cache_tokens[head_c + i]; - - common_sampler_accept(slot.smpl, slot.cache_tokens[head_p + i], false); - slot.n_past++; } @@ -2234,7 +2024,7 @@ struct server_context { } // non-causal tasks require to fit the entire prompt in the physical batch - if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING || slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) { + if (slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING || slot.inf_type == SERVER_TASK_INF_TYPE_RERANK) { // cannot fit the prompt in the current batch - will try next iter if (batch.n_tokens + slot.n_prompt_tokens > n_batch) { continue; @@ -2243,8 +2033,8 @@ struct server_context { // check that we are in the right batch_type, if not defer the slot const bool slot_type = - slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING || - slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK ? 1 : 0; + slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING || + slot.inf_type == SERVER_TASK_INF_TYPE_RERANK ? 1 : 0; if (batch_type == -1) { batch_type = slot_type; @@ -2259,8 +2049,6 @@ struct server_context { // there is no common part left slot.n_past = 0; - - common_sampler_reset(slot.smpl); } SLT_INF(slot, "kv cache rm [%d, end)\n", slot.n_past); @@ -2288,6 +2076,13 @@ struct server_context { GGML_ASSERT(batch.n_tokens > 0); + common_sampler_reset(slot.smpl); + + // Process all prompt tokens through sampler system + for (int i = 0; i < slot.n_prompt_tokens; ++i) { + common_sampler_accept(slot.smpl, prompt_tokens[i], false); + } + // extract the logits only for the last token batch.logits[batch.n_tokens - 1] = true; @@ -2357,7 +2152,7 @@ struct server_context { } if (slot.state == SLOT_STATE_DONE_PROMPT) { - if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING) { + if (slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING) { // prompt evaluated for embedding send_embedding(slot, batch_view); slot.release(); @@ -2365,7 +2160,7 @@ struct server_context { continue; // continue loop of slots } - if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) { + if (slot.inf_type == SERVER_TASK_INF_TYPE_RERANK) { send_rerank(slot, batch_view); slot.release(); slot.i_batch = -1; @@ -2612,7 +2407,7 @@ int main(int argc, char ** argv) { auto middleware_server_state = [&res_error, &state](const httplib::Request & req, httplib::Response & res) { server_state current_state = state.load(); if (current_state == SERVER_STATE_LOADING_MODEL) { - auto tmp = string_split(req.path, '.'); + auto tmp = string_split(req.path, '.'); if (req.path == "/" || tmp.back() == "html") { res.set_content(reinterpret_cast(loading_html), loading_html_len, "text/html; charset=utf-8"); res.status = 503; @@ -2919,13 +2714,13 @@ int main(int argc, char ** argv) { res_ok(res, {{ "success", true }}); }; - const auto handle_completions_generic = [&ctx_server, &res_error, &res_ok](server_task_cmpl_type cmpl_type, json & data, httplib::Response & res) { + const auto handle_completions_generic = [&ctx_server, &res_error, &res_ok](server_task_inf_type inf_type, json & data, httplib::Response & res) { if (ctx_server.params.embedding || ctx_server.params.reranking) { res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings` or `--reranking`", ERROR_TYPE_NOT_SUPPORTED)); return; } - std::vector tasks = ctx_server.create_tasks_cmpl(data, cmpl_type); + std::vector tasks = ctx_server.create_tasks_inference(data, inf_type); ctx_server.queue_results.add_waiting_tasks(tasks); ctx_server.queue_tasks.post(tasks); @@ -2971,10 +2766,11 @@ int main(int argc, char ** argv) { const auto handle_completions = [&handle_completions_generic](const httplib::Request & req, httplib::Response & res) { json data = json::parse(req.body); - return handle_completions_generic(SERVER_TASK_CMPL_TYPE_NORMAL, data, res); + return handle_completions_generic(SERVER_TASK_INF_TYPE_COMPLETION, data, res); }; const auto handle_infill = [&ctx_server, &res_error, &handle_completions_generic](const httplib::Request & req, httplib::Response & res) { + // check model compatibility std::string err; if (llama_token_fim_pre(ctx_server.model) == LLAMA_TOKEN_NULL) { err += "prefix token is missing. "; @@ -2985,14 +2781,42 @@ int main(int argc, char ** argv) { if (llama_token_fim_mid(ctx_server.model) == LLAMA_TOKEN_NULL) { err += "middle token is missing. "; } - if (!err.empty()) { res_error(res, format_error_response(string_format("Infill is not supported by this model: %s", err.c_str()), ERROR_TYPE_NOT_SUPPORTED)); return; } json data = json::parse(req.body); - return handle_completions_generic(SERVER_TASK_CMPL_TYPE_INFILL, data, res); + + // validate input + if (!data.contains("input_prefix")) { + res_error(res, format_error_response("\"input_prefix\" is required", ERROR_TYPE_INVALID_REQUEST)); + } + + if (!data.contains("input_suffix")) { + res_error(res, format_error_response("\"input_suffix\" is required", ERROR_TYPE_INVALID_REQUEST)); + } + + if (data.contains("input_extra") && !data.at("input_extra").is_array()) { + res_error(res, format_error_response("\"input_extra\" must be an array of {\"filename\": string, \"text\": string}", ERROR_TYPE_INVALID_REQUEST)); + return; + } + json input_extra = json_value(data, "input_extra", json::array()); + for (const auto & chunk : input_extra) { + // { "text": string, "filename": string } + if (!chunk.contains("text") || !chunk.at("text").is_string()) { + res_error(res, format_error_response("extra_context chunk must contain a \"text\" field with a string value", ERROR_TYPE_INVALID_REQUEST)); + return; + } + // filename is optional + if (chunk.contains("filename") && !chunk.at("filename").is_string()) { + res_error(res, format_error_response("extra_context chunk's \"filename\" field must be a string", ERROR_TYPE_INVALID_REQUEST)); + return; + } + } + data["input_extra"] = input_extra; // default to empty array if it's not exist + + return handle_completions_generic(SERVER_TASK_INF_TYPE_INFILL, data, res); }; // TODO: maybe merge this function with "handle_completions_generic" @@ -3004,7 +2828,7 @@ int main(int argc, char ** argv) { json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template); - std::vector tasks = ctx_server.create_tasks_cmpl(data, SERVER_TASK_CMPL_TYPE_NORMAL); + std::vector tasks = ctx_server.create_tasks_inference(data, SERVER_TASK_INF_TYPE_COMPLETION); ctx_server.queue_results.add_waiting_tasks(tasks); ctx_server.queue_tasks.post(tasks); @@ -3077,7 +2901,7 @@ int main(int argc, char ** argv) { const bool add_special = json_value(body, "add_special", false); const bool with_pieces = json_value(body, "with_pieces", false); - std::vector tokens = ctx_server.tokenize(body.at("content"), add_special, true); + llama_tokens tokens = tokenize_mixed(ctx_server.ctx, body.at("content"), add_special, true); if (with_pieces) { for (const auto& token : tokens) { @@ -3114,7 +2938,7 @@ int main(int argc, char ** argv) { std::string content; if (body.count("tokens") != 0) { - const std::vector tokens = body.at("tokens"); + const llama_tokens tokens = body.at("tokens"); content = tokens_to_str(ctx_server.ctx, tokens.cbegin(), tokens.cend()); } @@ -3148,7 +2972,7 @@ int main(int argc, char ** argv) { json responses = json::array(); bool error = false; { - std::vector tasks = ctx_server.create_tasks_cmpl({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_EMBEDDING); + std::vector tasks = ctx_server.create_tasks_inference({{"prompt", prompt}}, SERVER_TASK_INF_TYPE_EMBEDDING); ctx_server.queue_results.add_waiting_tasks(tasks); ctx_server.queue_tasks.post(tasks); @@ -3225,7 +3049,7 @@ int main(int argc, char ** argv) { json responses = json::array(); bool error = false; { - std::vector tasks = ctx_server.create_tasks_cmpl({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_RERANK); + std::vector tasks = ctx_server.create_tasks_inference({{"prompt", prompt}}, SERVER_TASK_INF_TYPE_RERANK); ctx_server.queue_results.add_waiting_tasks(tasks); ctx_server.queue_tasks.post(tasks); diff --git a/examples/server/tests/features/infill.feature b/examples/server/tests/features/infill.feature new file mode 100644 index 0000000000000..a0bbfef77707b --- /dev/null +++ b/examples/server/tests/features/infill.feature @@ -0,0 +1,36 @@ +@llama.cpp +@infill +Feature: llama.cpp server + + # The current model is made by adding FIM tokens to the existing stories260K + # We may want to use a better model in the future, maybe something like SmolLM 360M + + Background: Server startup + Given a server listening on localhost:8080 + And a model file tinyllamas/stories260K-infill.gguf from HF repo ggml-org/models + And a model file test-model-infill.gguf + And a model alias tinyllama-infill + And 42 as server seed + And 1024 as batch size + And 1024 as ubatch size + And 2048 KV cache size + And 64 max tokens to predict + And 0.0 temperature + Then the server is starting + Then the server is healthy + + Scenario: Infill without input_extra + Given a prompt "Complete this" + And an infill input extra none none + And an infill input prefix "#include \n#include \"llama.h\"\n\nint main() {\n int n_threads = llama_" + And an infill input suffix "}\n" + And an infill request with no api error + Then 64 tokens are predicted matching One|day|she|saw|big|scary|bird + + Scenario: Infill with input_extra + Given a prompt "Complete this" + And an infill input extra "llama.h" "LLAMA_API int32_t llama_n_threads();\n" + And an infill input prefix "#include \n#include \"llama.h\"\n\nint main() {\n int n_threads = llama_" + And an infill input suffix "}\n" + And an infill request with no api error + Then 64 tokens are predicted matching cuts|Jimmy|mom|came|into|the|room" diff --git a/examples/server/tests/features/steps/steps.py b/examples/server/tests/features/steps/steps.py index 540a2ecd56374..2e418d8aa571b 100644 --- a/examples/server/tests/features/steps/steps.py +++ b/examples/server/tests/features/steps/steps.py @@ -80,6 +80,11 @@ def step_server_config(context, server_fqdn: str, server_port: str): context.lora_file = None context.disable_ctx_shift = False + # infill + context.infill_input_extra = None + context.infill_input_suffix = '' + context.infill_input_prefix = '' + context.tasks_result = [] context.concurrent_tasks = [] context.prompts = [] @@ -291,6 +296,28 @@ async def step_request_completion(context, api_error: Literal['raised'] | str): assert completion == api_error_code, f"completion must be an {api_error_code} status code: {completion}" +@step('an infill request with {api_error} api error') +@async_run_until_complete +async def step_request_completion(context, api_error: Literal['raised'] | str): + if api_error != 'no': + raise ValueError(f'api_error={api_error} is not yet implemented') + payload = { + "prompt": context.prompts[0], + "input_suffix": context.infill_input_suffix, + "input_prefix": context.infill_input_prefix, + "n_predict": context.n_predict, + "seed": context.seed, + "temperature": context.temperature, + } + if context.infill_input_extra is not None: + payload['input_extra'] = context.infill_input_extra + async with aiohttp.ClientSession(timeout=DEFAULT_TIMEOUT_SECONDS) as session: + async with session.post(f'{context.base_url}/infill', + json=payload) as response: + assert response.status == 200 + context.tasks_result = [await response.json()] + + @step('{predicted_n:d} tokens are predicted matching {re_content}') def step_n_tokens_predicted_with_content(context, predicted_n, re_content): context.completion = context.tasks_result.pop() @@ -539,6 +566,25 @@ def step_a_prompt_prompt(context, prompt): context.n_prompts = len(context.prompts) +# TODO: allow this to be repeated +@step('an infill input extra {filename} {text}') +def step_infill_input_extra(context, filename, text): + if filename == 'none': + context.infill_input_extra = None + else: + context.infill_input_extra = [{'filename': filename, 'text': text}] + + +@step('an infill input suffix {text}') +def step_infill_input_suffix(context, text): + context.infill_input_suffix = text + + +@step('an infill input prefix {text}') +def step_infill_input_prefix(context, text): + context.infill_input_prefix = text + + @step('{num_prompts:d} prompts {prompt} with seed {seed:d}') def step_many_prompts(context, num_prompts, prompt, seed): if context.seed is None: diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp index 69519ef95b2d9..8112420624185 100644 --- a/examples/server/utils.hpp +++ b/examples/server/utils.hpp @@ -24,6 +24,22 @@ #define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo-0613" using json = nlohmann::ordered_json; +using llama_tokens = std::vector; + +#define SLT_INF(slot, fmt, ...) LOG_INF("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__) +#define SLT_WRN(slot, fmt, ...) LOG_WRN("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__) +#define SLT_ERR(slot, fmt, ...) LOG_ERR("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__) +#define SLT_DBG(slot, fmt, ...) LOG_DBG("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__) + +#define SRV_INF(fmt, ...) LOG_INF("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__) +#define SRV_WRN(fmt, ...) LOG_WRN("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__) +#define SRV_ERR(fmt, ...) LOG_ERR("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__) +#define SRV_DBG(fmt, ...) LOG_DBG("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__) + +#define QUE_INF(fmt, ...) LOG_INF("que %12.*s: " fmt, 12, __func__, __VA_ARGS__) +#define QUE_WRN(fmt, ...) LOG_WRN("que %12.*s: " fmt, 12, __func__, __VA_ARGS__) +#define QUE_ERR(fmt, ...) LOG_ERR("que %12.*s: " fmt, 12, __func__, __VA_ARGS__) +#define QUE_DBG(fmt, ...) LOG_DBG("que %12.*s: " fmt, 12, __func__, __VA_ARGS__) // https://community.openai.com/t/openai-chat-list-of-error-codes-and-types/357791/11 enum error_type { @@ -52,9 +68,235 @@ static T json_value(const json & body, const std::string & key, const T & defaul } // -// chat template utils +// tokenizer and input processing utils // +static bool json_is_array_of_numbers(const json & data) { + if (data.is_array()) { + for (const auto & e : data) { + if (!e.is_number_integer()) { + return false; + } + } + return true; + } + return false; +} + +// is array having BOTH numbers & strings? +static bool json_is_array_of_mixed_numbers_strings(const json & data) { + bool seen_string = false; + bool seen_number = false; + if (data.is_array()) { + for (const auto & e : data) { + seen_string |= e.is_string(); + seen_number |= e.is_number_integer(); + if (seen_number && seen_string) { + return true; + } + } + } + return false; +} + +/** + * this handles 2 cases: + * - only string, example: "string" + * - mixed string and tokens, example: [12, 34, "string", 56, 78] + */ +static llama_tokens tokenize_mixed(const llama_context * ctx, const json & json_prompt, bool add_special, bool parse_special) { + // If `add_bos` is true, we only add BOS, when json_prompt is a string, + // or the first element of the json_prompt array is a string. + llama_tokens prompt_tokens; + + if (json_prompt.is_array()) { + bool first = true; + for (const auto & p : json_prompt) { + if (p.is_string()) { + auto s = p.template get(); + + llama_tokens p; + if (first) { + p = common_tokenize(ctx, s, add_special, parse_special); + first = false; + } else { + p = common_tokenize(ctx, s, false, parse_special); + } + + prompt_tokens.insert(prompt_tokens.end(), p.begin(), p.end()); + } else { + if (first) { + first = false; + } + + prompt_tokens.push_back(p.template get()); + } + } + } else { + auto s = json_prompt.template get(); + prompt_tokens = common_tokenize(ctx, s, add_special, parse_special); + } + + return prompt_tokens; +} + +/** + * break the input "prompt" object into multiple prompt if needed, then tokenize them + * this supports these cases: + * - "prompt": "string" + * - "prompt": [12, 34, 56] + * - "prompt": [12, 34, "string", 56, 78] + * and multiple prompts (multi-tasks): + * - "prompt": ["string1", "string2"] + * - "prompt": ["string1", [12, 34, 56]] + * - "prompt": [[12, 34, "string", 56, 78], [12, 34, 56]] + */ +static std::vector tokenize_input_prompts(llama_context * ctx, const json & json_prompt, bool add_special, bool parse_special) { + std::vector result; + if (json_prompt.is_string() || json_is_array_of_mixed_numbers_strings(json_prompt)) { + // string or mixed + result.push_back(tokenize_mixed(ctx, json_prompt, add_special, parse_special)); + } else if (json_is_array_of_numbers(json_prompt)) { + // array of tokens + result.push_back(json_prompt.get()); + } else if (json_prompt.is_array()) { + // array of prompts + result.reserve(json_prompt.size()); + for (const auto & p : json_prompt) { + if (p.is_string() || json_is_array_of_mixed_numbers_strings(p)) { + result.push_back(tokenize_mixed(ctx, p, add_special, parse_special)); + } else if (json_is_array_of_numbers(p)) { + // array of tokens + result.push_back(p.get()); + } else { + throw std::runtime_error("element of \"prompt\" must be a string, an list of tokens, or a list of mixed strings & tokens"); + } + } + } else { + throw std::runtime_error("\"prompt\" must be a string, an list of tokens, a list of mixed strings & tokens, or a list of prompts"); + } + return result; +} + +// +// template utils +// + +// format rerank task: [BOS]query[EOS][SEP]doc[EOS] +static llama_tokens format_rerank(const struct llama_model * model, const llama_tokens & query, const llama_tokens & doc) { + llama_tokens result; + result.reserve(doc.size() + query.size() + 4); + result.push_back(llama_token_bos(model)); + result.insert(result.end(), query.begin(), query.end()); + result.push_back(llama_token_eos(model)); + result.push_back(llama_token_sep(model)); + result.insert(result.end(), doc.begin(), doc.end()); + result.push_back(llama_token_eos(model)); + return result; +} + +// format infill task +static llama_tokens format_infill( + const llama_context * ctx, + const json & input_prefix, + const json & input_suffix, + const json & input_extra, + const int n_batch, + const int n_predict, + const int n_ctx, + const bool spm_infill, + const llama_tokens & tokens_prompt + ) { + // TODO: optimize this block by reducing memory allocations and movement + + // use FIM repo-level pattern: + // ref: https://arxiv.org/pdf/2409.12186 + // + // [FIM_REP]myproject + // [FIM_SEP]filename0 + // extra chunk 0 + // [FIM_SEP]filename1 + // extra chunk 1 + // ... + // [FIM_SEP]filename + // [FIM_PRE]prefix[FIM_SUF]suffix[FIM_MID]prompt + // + llama_tokens extra_tokens; + extra_tokens.reserve(n_ctx); + + auto model = llama_get_model(ctx); + auto tokens_prefix = tokenize_mixed(ctx, input_prefix, false, false); + auto tokens_suffix = tokenize_mixed(ctx, input_suffix, false, false); + + if (llama_token_fim_rep(model) != LLAMA_TOKEN_NULL) { + // TODO: make project name an input + static const auto k_fim_repo = common_tokenize(ctx, "myproject\n", false, false); + + extra_tokens.push_back(llama_token_fim_rep(model)); + extra_tokens.insert(extra_tokens.end(), k_fim_repo.begin(), k_fim_repo.end()); + } + for (const auto & chunk : input_extra) { + // { "text": string, "filename": string } + const std::string text = json_value(chunk, "text", std::string()); + const std::string filename = json_value(chunk, "filename", std::string("tmp")); + + if (llama_token_fim_sep(model) != LLAMA_TOKEN_NULL) { + const auto k_fim_file = common_tokenize(ctx, filename + "\n", false, false); + + extra_tokens.insert(extra_tokens.end(), llama_token_fim_sep(model)); + extra_tokens.insert(extra_tokens.end(), k_fim_file.begin(), k_fim_file.end()); + } else { + // chunk separator in binary form to avoid confusing the AI + static const char k_chunk_prefix_str[] = {0x0a, 0x0a, 0x2d, 0x2d, 0x2d, 0x20, 0x73, 0x6e, 0x69, 0x70, 0x70, 0x65, 0x74, 0x20, 0x2d, 0x2d, 0x2d, 0x0a, 0x0a, 0x00}; + static const auto k_chunk_prefix_tokens = common_tokenize(ctx, k_chunk_prefix_str, false, false); + + extra_tokens.insert(extra_tokens.end(), k_chunk_prefix_tokens.begin(), k_chunk_prefix_tokens.end()); + } + + const auto chunk_tokens = common_tokenize(ctx, text, false, false); + extra_tokens.insert(extra_tokens.end(), chunk_tokens.begin(), chunk_tokens.end()); + } + + if (llama_token_fim_sep(model) != LLAMA_TOKEN_NULL) { + // TODO: current filename + static const auto k_fim_file = common_tokenize(ctx, "filename\n", false, false); + + extra_tokens.insert(extra_tokens.end(), llama_token_fim_sep(model)); + extra_tokens.insert(extra_tokens.end(), k_fim_file.begin(), k_fim_file.end()); + } + + // for now pick FIM context to fit in a batch (ratio prefix:suffix = 3:1, TODO: configurable?) + const int n_suffix_take = std::min(tokens_suffix.size(), (n_batch/4)); + const int n_prefix_take = std::min(tokens_prefix.size(), 3*(n_batch/4) - 3); + + // fill the rest of the context with extra chunks + const int n_extra_take = std::min(std::max(0, n_ctx - (n_batch) - 2*n_predict), extra_tokens.size()); + + tokens_prefix.erase(tokens_prefix.begin(), tokens_prefix.begin() + tokens_prefix.size() - n_prefix_take); + tokens_suffix.resize(n_suffix_take); + + tokens_prefix.insert(tokens_prefix.begin(), llama_token_fim_pre(model)); + tokens_prefix.insert(tokens_prefix.end(), tokens_prompt.begin(), tokens_prompt.end()); + tokens_suffix.insert(tokens_suffix.begin(), llama_token_fim_suf(model)); + + auto embd_inp = spm_infill ? tokens_suffix : tokens_prefix; + auto embd_end = spm_infill ? tokens_prefix : tokens_suffix; + + if (llama_add_bos_token(model)) { + embd_inp.insert(embd_inp.begin(), llama_token_bos(model)); + } + + SRV_DBG("extra: n_ctx = %d, n_extra_take = %d, n_extra = %d\n", n_ctx, n_extra_take, (int) extra_tokens.size()); + + // put the extra context before the FIM prefix + embd_inp.insert(embd_inp.begin(), extra_tokens.end() - n_extra_take, extra_tokens.end()); + + embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end()); + embd_inp.push_back(llama_token_fim_mid(model)); + + return embd_inp; +} + // Format given chat. If tmpl is empty, we take the template from model metadata inline std::string format_chat(const struct llama_model * model, const std::string & tmpl, const std::vector & messages) { std::vector chat; @@ -229,18 +471,6 @@ static size_t find_partial_stop_string(const std::string &stop, const std::strin return std::string::npos; } -static bool json_is_array_of_numbers(const json & data) { - if (data.is_array()) { - for (const auto & e : data) { - if (!e.is_number()) { - return false; - } - } - return true; - } - return false; -} - // TODO: reuse llama_detokenize template static std::string tokens_to_str(llama_context * ctx, Iter begin, Iter end) { diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index b201bd714a447..a40e755a26f84 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -39,6 +39,11 @@ int main(int argc, char ** argv) { return 1; } + if (params.n_predict < -1) { + LOG_ERR("%s: --n-predict must be >= -1\n", __func__); + return 1; + } + common_init(); if (params.model_draft.empty()) { @@ -180,8 +185,6 @@ int main(int argc, char ** argv) { // target model sampling context (reuse the llama_context's sampling instance) struct common_sampler * smpl = common_sampler_init(model_tgt, params.sparams); - struct llama_sampler * softmax = llama_sampler_init_softmax(); - // draft sequence data std::vector drafts(n_seq_dft); @@ -190,8 +193,8 @@ int main(int argc, char ** argv) { drafts[s].smpl = common_sampler_init(model_dft, params.sparams); } - llama_batch batch_dft = llama_batch_init(params.n_ctx, 0, 1); - llama_batch batch_tgt = llama_batch_init(params.n_ctx, 0, n_seq_dft); + llama_batch batch_dft = llama_batch_init(llama_n_batch(ctx_dft), 0, 1); + llama_batch batch_tgt = llama_batch_init(llama_n_batch(ctx_tgt), 0, n_seq_dft); const auto t_dec_start = ggml_time_us(); @@ -441,7 +444,7 @@ int main(int argc, char ** argv) { ++n_past_dft; } - if (n_predict > params.n_predict || has_eos) { + if ((params.n_predict >= 0 && n_predict > params.n_predict) || has_eos) { break; } @@ -624,7 +627,6 @@ int main(int argc, char ** argv) { common_sampler_free(drafts[s].smpl); } - llama_sampler_free(softmax); llama_batch_free(batch_dft); llama_free(ctx_tgt); diff --git a/flake.lock b/flake.lock index 70252702896d1..1f8defab78f0c 100644 --- a/flake.lock +++ b/flake.lock @@ -20,11 +20,11 @@ }, "nixpkgs": { "locked": { - "lastModified": 1728492678, - "narHash": "sha256-9UTxR8eukdg+XZeHgxW5hQA9fIKHsKCdOIUycTryeVw=", + "lastModified": 1729256560, + "narHash": "sha256-/uilDXvCIEs3C9l73JTACm4quuHUsIHcns1c+cHUJwA=", "owner": "NixOS", "repo": "nixpkgs", - "rev": "5633bcff0c6162b9e4b5f1264264611e950c8ec7", + "rev": "4c2fcb090b1f3e5b47eaa7bd33913b574a11e0a0", "type": "github" }, "original": { diff --git a/ggml/include/ggml-cann.h b/ggml/include/ggml-cann.h index 95bdaf10d17d0..5289754938871 100644 --- a/ggml/include/ggml-cann.h +++ b/ggml/include/ggml-cann.h @@ -34,6 +34,8 @@ extern "C" { */ #define GGML_CANN_MAX_DEVICES 16 +GGML_API ggml_backend_reg_t ggml_backend_cann_reg(void); + /** * @brief Initializes the CANN backend for a specified device. * diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp index 81d09cd8bce9c..7d7b63a15a17f 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp @@ -561,6 +561,10 @@ void * ggml_backend_reg_get_proc_address(ggml_backend_reg_t reg, const char * na # include "ggml-amx.h" #endif +#ifdef GGML_USE_CANN +#include "ggml-cann.h" +#endif + struct ggml_backend_registry { std::vector backends; std::vector devices; @@ -587,8 +591,11 @@ struct ggml_backend_registry { #ifdef GGML_USE_AMX register_backend(ggml_backend_amx_reg()); #endif +#ifdef GGML_USE_CANN + register_backend(ggml_backend_cann_reg()); +#endif - // TODO: kompute, cann + // TODO: kompute register_backend(ggml_backend_cpu_reg()); } diff --git a/ggml/src/ggml-cann.cpp b/ggml/src/ggml-cann.cpp index ec3c0a6889a2e..af0fb603a7c36 100644 --- a/ggml/src/ggml-cann.cpp +++ b/ggml/src/ggml-cann.cpp @@ -39,6 +39,8 @@ #include "ggml-common.h" +#define GGML_CANN_NAME "CANN" + /** * @brief Handles CANN errors by printing an error message and aborting. * @@ -851,13 +853,6 @@ static void ggml_backend_cann_buffer_set_tensor( void *transform_buffer = malloc(size); ggml_backend_cann_transform(tensor, data, transform_buffer); -#ifndef NDEBUG - void *check_buffer = malloc(size); - ggml_backend_cann_transform_back(tensor, transform_buffer, - check_buffer); - GGML_ASSERT(memcmp(data, check_buffer, size) == 0); - free(check_buffer); -#endif ACL_CHECK(aclrtMemcpy((char *)tensor->data + offset, size, transform_buffer, size, ACL_MEMCPY_HOST_TO_DEVICE)); @@ -969,7 +964,7 @@ static void ggml_backend_cann_buffer_clear( * This structure defines function pointers to operations that can be performed * on a CANN buffer within the backend. */ -static ggml_backend_buffer_i ggml_backend_cann_buffer_interface = { +static const ggml_backend_buffer_i ggml_backend_cann_buffer_interface = { /* .get_name = */ ggml_backend_cann_buffer_get_name, /* .free_buffer = */ ggml_backend_cann_buffer_free_buffer, /* .get_base = */ ggml_backend_cann_buffer_get_base, @@ -1105,19 +1100,25 @@ static size_t ggml_backend_cann_buffer_type_get_alloc_size( GGML_UNUSED(buft); } +static bool ggml_backend_cann_buffer_type_is_host(ggml_backend_buffer_type_t buft) { + return false; + + GGML_UNUSED(buft); +} + /** * @brief Interface for managing CANN buffer types in the GGML backend. * * Provides function pointers for allocating, querying properties, and managing * memory for CANN buffer types in the GGML backend. */ -static ggml_backend_buffer_type_i ggml_backend_cann_buffer_type_interface = { +static const ggml_backend_buffer_type_i ggml_backend_cann_buffer_type_interface = { /* .get_name = */ ggml_backend_cann_buffer_type_name, /* .alloc_buffer = */ ggml_backend_cann_buffer_type_alloc_buffer, /* .get_alignment = */ ggml_backend_cann_buffer_type_get_alignment, /* .get_max_size = */ NULL, // defaults to SIZE_MAX /* .get_alloc_size = */ ggml_backend_cann_buffer_type_get_alloc_size, - /* .is_host = */ NULL, + /* .is_host = */ ggml_backend_cann_buffer_type_is_host, }; /** @@ -1148,7 +1149,7 @@ ggml_backend_cann_buffer_type(int32_t device) { for (int32_t i = 0; i < GGML_CANN_MAX_DEVICES; i++) { ggml_backend_cann_buffer_types[i] = { /* .iface = */ ggml_backend_cann_buffer_type_interface, - /* .device = */ nullptr, + /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cann_reg(), device), /* .context = */ new ggml_backend_cann_buffer_type_context{ i, "CANN" + std::to_string(i)}, @@ -1264,7 +1265,7 @@ ggml_backend_buffer_type_t ggml_backend_cann_host_buffer_type() { /* .get_alloc_size = */ ggml_backend_cpu_buffer_type()->iface.get_alloc_size, /* .is_host = */ ggml_backend_cpu_buffer_type()->iface.is_host, }, - /* .device = */ nullptr, + /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cann_reg(), 0), /* .context = */ nullptr, }; @@ -1511,13 +1512,6 @@ static void ggml_backend_cann_set_tensor_async(ggml_backend_t backend, void *transform_buffer = malloc(size); ggml_backend_cann_transform(tensor, data, transform_buffer); -#ifndef NDEBUG - void *check_buffer = malloc(size); - ggml_backend_cann_transform_back(tensor, transform_buffer, - check_buffer); - GGML_ASSERT(memcmp(data, check_buffer, size)); - free(check_buffer); -#endif ACL_CHECK(aclrtMemcpyAsync( (char *)tensor->data + offset, size, transform_buffer, size, ACL_MEMCPY_HOST_TO_DEVICE, cann_ctx->stream())); @@ -1692,7 +1686,7 @@ static enum ggml_status ggml_backend_cann_graph_compute( * @return bool Returns true if the operation is supported by the backend, * otherwise false. */ -static bool ggml_backend_cann_supports_op(ggml_backend_t backend, +static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_tensor* op) { switch (op->op) { case GGML_OP_UNARY: @@ -1783,7 +1777,7 @@ static bool ggml_backend_cann_supports_op(ggml_backend_t backend, return false; } - GGML_UNUSED(backend); + GGML_UNUSED(dev); } /** @@ -1801,31 +1795,6 @@ static bool ggml_backend_buft_is_cann(ggml_backend_buffer_type_t buft) { return buft->iface.get_name == ggml_backend_cann_buffer_type_name; } -/** - * @brief Checks if the CANN backend supports a specific backend buffer type. - * - * This function determines whether the CANN backend supports the given backend - * buffer type by comparing the device context of the backend and buffer type. - * It returns true if the devices are same between the backend context and - * buffer type context. - * - * @param backend Pointer to the CANN backend. - * @param buft Pointer to the backend buffer type to check. - * @return bool Returns true if the CANN backend supports the buffer type, - * otherwise false. - */ -static bool ggml_backend_cann_supports_buft( - ggml_backend_t backend, ggml_backend_buffer_type_t buft) { - if (ggml_backend_buft_is_cann(buft)) { - ggml_backend_cann_context * cann_ctx = - (ggml_backend_cann_context *)backend->context; - ggml_backend_cann_buffer_type_context * buft_ctx = - (ggml_backend_cann_buffer_type_context *)buft->context; - return buft_ctx->device == cann_ctx->device; - } - return false; -} - /** * @brief Determines if a tensor operation should be offloaded to the CANN * backend. @@ -1840,54 +1809,14 @@ static bool ggml_backend_cann_supports_buft( * @return bool Returns true if the operation should be offloaded, otherwise * false. */ -static bool ggml_backend_cann_offload_op(ggml_backend_t backend, +static bool ggml_backend_cann_offload_op(ggml_backend_dev_t dev, const ggml_tensor* op) { const int min_batch_size = 32; - GGML_UNUSED(backend); + GGML_UNUSED(dev); return op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS; } -/** - * @brief Creates a new event for the CANN backend. - * - * This function initializes a new event for the CANN backend by setting the - * device and creating an ACL runtime event. The created event is then wrapped - * in a ggml_backend_event structure and returned. - * - * @param backend Pointer to the CANN backend. - * @return ggml_backend_event_t Returns a pointer to the new event structure. - */ -static ggml_backend_event_t ggml_backend_cann_event_new( - ggml_backend_t backend) { - ggml_backend_cann_context* cann_ctx = - (ggml_backend_cann_context*)backend->context; - - ggml_cann_set_device(cann_ctx->device); - - aclrtEvent event; - ACL_CHECK(aclrtCreateEvent(&event)); - - return new ggml_backend_event{ - /* .device = */ nullptr, - /* .context = */ event, - }; -} - -/** - * @brief Frees a CANN backend event. - * - * This function destroys the ACL runtime event associated with the given CANN - * backend event and then deletes the event structure itself. - * - * @param event Pointer to the event structure to be freed. - */ -static void ggml_backend_cann_event_free(ggml_backend_event_t event) { - ACL_CHECK(aclrtDestroyEvent((aclrtEvent)event->context)); - - delete event; -} - /** * @brief Records an event on the CANN backend stream. * @@ -1924,17 +1853,6 @@ static void ggml_backend_cann_event_wait(ggml_backend_t backend, } } -/** - * @brief Synchronizes the given event on the CANN backend. - * - * This function waits for the specified event to complete on the ACL runtime. - * - * @param event Pointer to the event structure to be synchronized. - */ -static void ggml_backend_cann_event_synchronize(ggml_backend_event_t event) { - ACL_CHECK(aclrtSynchronizeEvent((aclrtEvent)event->context)); -} - /** * @brief Structure defining the interface for the CANN backend. * @@ -1942,7 +1860,7 @@ static void ggml_backend_cann_event_synchronize(ggml_backend_event_t event) { * supported by the CANN backend, including name retrieval, memory * management, tensor operations, synchronization, and event handling. */ -static ggml_backend_i ggml_backend_cann_interface = { +static const ggml_backend_i ggml_backend_cann_interface = { /* .get_name = */ ggml_backend_cann_name, /* .free = */ ggml_backend_cann_free, /* .get_default_buffer_type = */ ggml_backend_cann_get_default_buffer_type, @@ -1955,9 +1873,9 @@ static ggml_backend_i ggml_backend_cann_interface = { /* .graph_plan_update = */ NULL, /* .graph_plan_compute = */ NULL, /* .graph_compute = */ ggml_backend_cann_graph_compute, - /* .supports_op = */ ggml_backend_cann_supports_op, - /* .supports_buft = */ ggml_backend_cann_supports_buft, - /* .offload_op = */ ggml_backend_cann_offload_op, + /* .supports_op = */ NULL, // moved to device + /* .supports_buft = */ NULL, // moved to device + /* .offload_op = */ NULL, // moved to device /* .event_record = */ ggml_backend_cann_event_record, /* .event_wait = */ ggml_backend_cann_event_wait, }; @@ -1976,6 +1894,234 @@ static ggml_guid_t ggml_backend_cann_guid() { return &guid; } +// backend device +struct ggml_backend_cann_device_context { + int device; + std::string name; + std::string description; +}; + +static const char * ggml_backend_cann_device_get_name(ggml_backend_dev_t dev) { + ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *)dev->context; + return ctx->name.c_str(); +} + +static const char* ggml_backend_cann_device_get_description(ggml_backend_dev_t dev) { + ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *)dev->context; + return ctx->description.c_str(); +} + +static void ggml_backend_cann_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) { + ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *)dev->context; + ggml_backend_cann_get_device_memory(ctx->device, free, total); +} + +static enum ggml_backend_dev_type ggml_backend_cann_device_get_type(ggml_backend_dev_t dev) { + GGML_UNUSED(dev); + return GGML_BACKEND_DEVICE_TYPE_GPU_FULL; +} + +static void ggml_backend_cann_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) { + props->name = ggml_backend_cann_device_get_name(dev); + props->description = ggml_backend_cann_device_get_description(dev); + props->type = ggml_backend_cann_device_get_type(dev); + ggml_backend_cann_device_get_memory(dev, &props->memory_free, &props->memory_total); + + bool host_buffer = getenv("GGML_CANN_NO_PINNED") == nullptr; + + props->caps = { + /* .async = */ false, + /* .host_buffer = */ host_buffer, + /* .buffer_from_host_ptr = */ false, + /* .events = */ true, + }; +} + +static ggml_backend_t ggml_backend_cann_device_init(ggml_backend_dev_t dev, const char * params) { + GGML_UNUSED(params); + ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *)dev->context; + return ggml_backend_cann_init(ctx->device); +} + +/** + * @brief Checks if the CANN backend supports a specific backend buffer type. + * + * This function determines whether the CANN backend supports the given backend + * buffer type by comparing the device context of the backend and buffer type. + * It returns true if the devices are same between the backend context and + * buffer type context. + * + * @param backend Pointer to the CANN backend. + * @param buft Pointer to the backend buffer type to check. + * @return bool Returns true if the CANN backend supports the buffer type, + * otherwise false. + */ +static bool ggml_backend_cann_supports_buft( + ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) { + if (ggml_backend_buft_is_cann(buft)) { + ggml_backend_cann_device_context * dev_ctx = (ggml_backend_cann_device_context *)dev->context; + ggml_backend_cann_buffer_type_context * buft_ctx = + (ggml_backend_cann_buffer_type_context *)buft->context; + return buft_ctx->device == dev_ctx->device; + } + return false; +} + +static ggml_backend_buffer_type_t ggml_backend_cann_device_get_buffer_type(ggml_backend_dev_t dev) { + ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *)dev->context; + return ggml_backend_cann_buffer_type(ctx->device); +} + +static ggml_backend_buffer_type_t ggml_backend_cann_device_get_host_buffer_type(ggml_backend_dev_t dev) { + GGML_UNUSED(dev); + return ggml_backend_cann_host_buffer_type(); +} + +/** + * @brief Creates a new event for the CANN backend device. + * + * This function initializes a new event for the CANN backend by setting the + * device and creating an ACL runtime event. The created event is then wrapped + * in a ggml_backend_event structure and returned. + * + * @param backend Pointer to the CANN backend. + * @return ggml_backend_event_t Returns a pointer to the new event structure. + */ +static ggml_backend_event_t ggml_backend_cann_device_event_new( + ggml_backend_dev_t dev) { + ggml_backend_cann_device_context * dev_ctx = (ggml_backend_cann_device_context *)dev->context; + + ggml_cann_set_device(dev_ctx->device); + + aclrtEvent event; + ACL_CHECK(aclrtCreateEvent(&event)); + + return new ggml_backend_event{ + /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cann_reg(), dev_ctx->device), + /* .context = */ event, + }; +} + +/** + * @brief Frees a CANN backend event. + * + * This function destroys the ACL runtime event associated with the given CANN + * backend event and then deletes the event structure itself. + * + * @param event Pointer to the event structure to be freed. + */ +static void ggml_backend_cann_device_event_free(ggml_backend_dev_t dev, ggml_backend_event_t event) { + ACL_CHECK(aclrtDestroyEvent((aclrtEvent)event->context)); + + delete event; + GGML_UNUSED(dev); +} + +/** + * @brief Synchronizes the given event on the CANN backend. + * + * This function waits for the specified event to complete on the ACL runtime. + * + * @param event Pointer to the event structure to be synchronized. + */ +static void ggml_backend_cann_device_event_synchronize(ggml_backend_dev_t dev, ggml_backend_event_t event) { + ACL_CHECK(aclrtSynchronizeEvent((aclrtEvent)event->context)); + + GGML_UNUSED(dev); +} + +static const ggml_backend_device_i ggml_backend_cann_device_interface = { + /* .get_name = */ ggml_backend_cann_device_get_name, + /* .get_description = */ ggml_backend_cann_device_get_description, + /* .get_memory = */ ggml_backend_cann_device_get_memory, + /* .get_type = */ ggml_backend_cann_device_get_type, + /* .get_props = */ ggml_backend_cann_device_get_props, + /* .init_backend = */ ggml_backend_cann_device_init, // called for every card + /* .get_buffer_type = */ ggml_backend_cann_device_get_buffer_type, + /* .get_host_buffer_type = */ ggml_backend_cann_device_get_host_buffer_type, + /* .buffer_from_host_ptr = */ NULL, // not supported for CANN + /* .supports_op = */ ggml_backend_cann_supports_op, + /* .supports_buft = */ ggml_backend_cann_supports_buft, + /* .offload_op = */ ggml_backend_cann_offload_op, + /* .event_new = */ ggml_backend_cann_device_event_new, + /* .event_free = */ ggml_backend_cann_device_event_free, + /* .event_synchronize = */ ggml_backend_cann_device_event_synchronize, +}; + + +// backend reg +struct ggml_backend_cann_reg_context { + std::vector devices; +}; + +static const char * ggml_backend_cann_reg_get_name(ggml_backend_reg_t reg) { + GGML_UNUSED(reg); + return GGML_CANN_NAME; +} + +static size_t ggml_backend_cann_reg_get_device_count(ggml_backend_reg_t reg) { + ggml_backend_cann_reg_context * ctx = (ggml_backend_cann_reg_context *)reg->context; + return ctx->devices.size(); +} + +static ggml_backend_dev_t ggml_backend_cann_reg_get_device(ggml_backend_reg_t reg, size_t index) { + ggml_backend_cann_reg_context * ctx = (ggml_backend_cann_reg_context *)reg->context; + GGML_ASSERT(index < ctx->devices.size()); + return ctx->devices[index]; +} + +static void * ggml_backend_cann_reg_get_proc_address(ggml_backend_reg_t reg, const char * name) { + GGML_UNUSED(reg); + GGML_UNUSED(name); + // reserved for future use + return nullptr; +} + +static const ggml_backend_reg_i ggml_backend_cann_reg_interface = { + /* .get_name = */ ggml_backend_cann_reg_get_name, + /* .get_device_count = */ ggml_backend_cann_reg_get_device_count, + /* .get_device_get = */ ggml_backend_cann_reg_get_device, + /* .get_proc_address = */ ggml_backend_cann_reg_get_proc_address, +}; + +// backend registry, called only once for cann backend +ggml_backend_reg_t ggml_backend_cann_reg() { + static ggml_backend_reg reg; + static bool initialized = false; + + { + static std::mutex mutex; + std::lock_guard lock(mutex); + if (!initialized) { + aclInit(nullptr); + ggml_backend_cann_reg_context * ctx = new ggml_backend_cann_reg_context; + + for (int i = 0; i < ggml_cann_info().device_count; i++) { + ggml_backend_cann_device_context* dev_ctx = new ggml_backend_cann_device_context(); + dev_ctx->description = aclrtGetSocName(); + dev_ctx->device = i; + dev_ctx->name = GGML_CANN_NAME + std::to_string(i); + ggml_cann_set_device(i); + ggml_backend_dev_t dev = new ggml_backend_device { + /* .interface = */ ggml_backend_cann_device_interface, + /* .reg = */ ®, + /* .context = */ dev_ctx + }; + ctx->devices.push_back(dev); + } + + reg = ggml_backend_reg { + /* .interface = */ ggml_backend_cann_reg_interface, + /* .context = */ ctx + }; + } + + initialized = true; + } + + return ® +} + ggml_backend_t ggml_backend_cann_init(int32_t device) { aclInit(nullptr); if (device < 0 || device >= ggml_backend_cann_get_device_count()) { @@ -1992,7 +2138,7 @@ ggml_backend_t ggml_backend_cann_init(int32_t device) { ggml_backend_t cann_backend = new ggml_backend{/* .guid = */ ggml_backend_cann_guid(), /* .interface = */ ggml_backend_cann_interface, - /* .device = */ nullptr, + /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cann_reg(), device), /* .context = */ ctx}; return cann_backend; diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index 1338bd45836bb..21c9f5e3829cb 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -1151,8 +1151,8 @@ static cudaError_t ggml_cuda_cpy_tensor_2d( void * dst, const struct ggml_tensor * src, int64_t i3, int64_t i2, int64_t i1_low, int64_t i1_high, cudaStream_t stream) { GGML_ASSERT(ggml_backend_buffer_is_cuda(src->buffer)); - char * src_ptr = (char *) src->data; - char * dst_ptr = (char *) dst; + const char * src_ptr = (const char *) src->data; + char * dst_ptr = (char *) dst; const int64_t ne0 = src->ne[0]; const int64_t nb0 = src->nb[0]; @@ -1162,7 +1162,7 @@ static cudaError_t ggml_cuda_cpy_tensor_2d( const enum ggml_type type = src->type; const int64_t ts = ggml_type_size(type); const int64_t bs = ggml_blck_size(type); - int64_t i1_diff = i1_high - i1_low; + const int64_t i1_diff = i1_high - i1_low; const char * x = src_ptr + i1_low*nb1 + i2*nb2 + i3*nb3; if (nb0 == ts && nb1 == ts*ne0/bs) { @@ -1479,13 +1479,18 @@ static void ggml_cuda_op_mul_mat( if (src0_is_contiguous) { dev[id].src0_dd = split ? (char *) src0_extra->data_device[id] : (char *) src0->data; } else { - dev[id].src0_dd = dev[id].src0_dd_alloc.alloc(ctx.pool(id), ggml_nbytes(src0)); + // If src0 is not contiguous it will be copied to a temporary buffer. + // This buffer needs to be cleared entirely because multiple regions will function as padding. + const size_t nbytes_data = ggml_nbytes(src0); + const size_t nbytes_padding = ggml_row_size(src0->type, MATRIX_ROW_PADDING - ne00 % MATRIX_ROW_PADDING); + dev[id].src0_dd = dev[id].src0_dd_alloc.alloc(ctx.pool(id), nbytes_data + nbytes_padding); + CUDA_CHECK(cudaMemsetAsync(dev[id].src0_dd, 0, nbytes_data + nbytes_padding, stream)); } - // If src0 is on a temporary compute buffers (partial offloading) there may be some padding that needs to be cleared: + // If src0 is on a temporary compute buffer (partial offloading) there may be some padding that needs to be cleared: if (ne00 % MATRIX_ROW_PADDING != 0 && ggml_is_quantized(src0->type) && ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE && src0->view_src == nullptr) { - const int64_t nbytes_data = ggml_row_size(src0->type, (dev[id].row_high - dev[id].row_low)*ne00); - const int64_t nbytes_padding = ggml_row_size(src0->type, MATRIX_ROW_PADDING - ne00 % MATRIX_ROW_PADDING); + const size_t nbytes_data = ggml_row_size(src0->type, (dev[id].row_high - dev[id].row_low)*ne00); + const size_t nbytes_padding = ggml_row_size(src0->type, MATRIX_ROW_PADDING - ne00 % MATRIX_ROW_PADDING); CUDA_CHECK(cudaMemsetAsync(dev[id].src0_dd + nbytes_data , 0, nbytes_padding, stream)); } @@ -3141,7 +3146,6 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_ROPE: return ggml_is_contiguous(op->src[0]); case GGML_OP_IM2COL: - return op->src[0]->type == GGML_TYPE_F16; case GGML_OP_POOL_2D: case GGML_OP_SUM: case GGML_OP_SUM_ROWS: diff --git a/ggml/src/ggml-cuda/cpy.cuh b/ggml/src/ggml-cuda/cpy.cuh index 7961674266ee1..28b06cddaa87b 100644 --- a/ggml/src/ggml-cuda/cpy.cuh +++ b/ggml/src/ggml-cuda/cpy.cuh @@ -1,6 +1,6 @@ #include "common.cuh" -#define CUDA_CPY_BLOCK_SIZE 32 +#define CUDA_CPY_BLOCK_SIZE 64 void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1); diff --git a/ggml/src/ggml-cuda/im2col.cu b/ggml/src/ggml-cuda/im2col.cu index 16463ab0fb683..86a54e42bb7e6 100644 --- a/ggml/src/ggml-cuda/im2col.cu +++ b/ggml/src/ggml-cuda/im2col.cu @@ -91,9 +91,9 @@ void ggml_cuda_op_im2col(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const int64_t OH = is_2D ? dst->ne[2] : 1; const int64_t OW = dst->ne[1]; - const size_t delta_offset = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32 - const int64_t batch = src1->ne[3]; - const size_t batch_offset = src1->nb[3] / 4; // nb is byte offset, src is type float32 + const size_t delta_offset = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32 + const int64_t batch = src1->ne[is_2D ? 3 : 2]; + const size_t batch_offset = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32 if(dst->type == GGML_TYPE_F16) { im2col_cuda_f16(src1_d, (half *) dst_d, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, stream); diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu index 4935f8818679f..ae5c68ab35129 100644 --- a/ggml/src/ggml-cuda/mmq.cu +++ b/ggml/src/ggml-cuda/mmq.cu @@ -8,8 +8,6 @@ void ggml_cuda_op_mul_mat_q( const int64_t ne00 = src0->ne[0]; - const int64_t nb01 = src0->nb[1]; - const int64_t ne10 = src1->ne[0]; const int64_t ne11 = src1->ne[1]; GGML_ASSERT(ne10 % QK8_1 == 0); @@ -17,7 +15,7 @@ void ggml_cuda_op_mul_mat_q( const int64_t ne0 = dst->ne[0]; const int64_t row_diff = row_high - row_low; - const int64_t stride00 = nb01 / ggml_type_size(src0->type); + const int64_t stride00 = ne00 / ggml_blck_size(src0->type); int id = ggml_cuda_get_device(); const int compute_capability = ggml_cuda_info().devices[id].cc; diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m index 172a0f925d316..80c08f15b2999 100644 --- a/ggml/src/ggml-metal.m +++ b/ggml/src/ggml-metal.m @@ -241,6 +241,8 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16, GGML_METAL_KERNEL_TYPE_IM2COL_F16, GGML_METAL_KERNEL_TYPE_IM2COL_F32, + GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16, + GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32, GGML_METAL_KERNEL_TYPE_UPSCALE_F32, GGML_METAL_KERNEL_TYPE_PAD_F32, GGML_METAL_KERNEL_TYPE_ARANGE_F32, @@ -272,6 +274,8 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte GGML_METAL_KERNEL_TYPE_SIN, GGML_METAL_KERNEL_TYPE_COS, GGML_METAL_KERNEL_TYPE_SUM_ROWS, + GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, + GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, GGML_METAL_KERNEL_TYPE_COUNT }; @@ -685,6 +689,8 @@ @implementation GGMLMetalClass GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16, rope_neox_f16, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16, im2col_ext_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32, im2col_ext_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, timestep_embedding_f32, true); @@ -716,6 +722,8 @@ @implementation GGMLMetalClass GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN, sin, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, pool_2d_max_f32, true); } [metal_library release]; @@ -844,8 +852,8 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex case GGML_OP_IM2COL: return op->src[0]->type == GGML_TYPE_F16; case GGML_OP_POOL_1D: - case GGML_OP_POOL_2D: return false; + case GGML_OP_POOL_2D: case GGML_OP_UPSCALE: case GGML_OP_PAD: case GGML_OP_ARANGE: @@ -1007,19 +1015,21 @@ static void ggml_metal_encode_node( id id_src2 = src2 ? ggml_metal_get_buffer(src2, &offs_src2) : nil; id id_dst = dst ? ggml_metal_get_buffer(dst, &offs_dst) : nil; - //GGML_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op)); - //if (src0) { - // GGML_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02, - // ggml_is_contiguous(src0), src0->name); - //} - //if (src1) { - // GGML_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12, - // ggml_is_contiguous(src1), src1->name); - //} - //if (dst) { - // GGML_LOG_INFO("%s: dst - %4s [%5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2, - // dst->name); - //} +#if 0 + GGML_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op)); + if (src0) { + GGML_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03, + ggml_is_contiguous(src0), src0->name); + } + if (src1) { + GGML_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13, + ggml_is_contiguous(src1), src1->name); + } + if (dst) { + GGML_LOG_INFO("%s: dst - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2, ne3, nb0, nb1, nb2, nb3, + dst->name); + } +#endif id device = ctx_dev->mtl_device; @@ -1802,14 +1812,16 @@ static void ggml_metal_encode_node( [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5]; [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6]; - [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:8]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:9]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:10]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:11]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12]; - [encoder setBytes:&r2 length:sizeof(r2) atIndex:13]; - [encoder setBytes:&r3 length:sizeof(r3) atIndex:14]; + [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:7]; + [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:8]; + [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:9]; + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:10]; + [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:11]; + [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:12]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14]; + [encoder setBytes:&r2 length:sizeof(r2) atIndex:15]; + [encoder setBytes:&r3 length:sizeof(r3) atIndex:16]; [encoder setThreadgroupMemoryLength:8192 atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; } else { @@ -1978,20 +1990,22 @@ static void ggml_metal_encode_node( [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; - [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:9]; - [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:10]; - [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:11]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:12]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:13]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:15]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16]; - [encoder setBytes:&r2 length:sizeof(r2) atIndex:17]; - [encoder setBytes:&r3 length:sizeof(r3) atIndex:18]; + [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9]; + [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10]; + [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11]; + [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12]; + [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:13]; + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:14]; + [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:15]; + [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:16]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:18]; + [encoder setBytes:&r2 length:sizeof(r2) atIndex:19]; + [encoder setBytes:&r3 length:sizeof(r3) atIndex:20]; if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 || - src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K || - src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) { + src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K || + src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) { [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) { @@ -2040,6 +2054,9 @@ static void ggml_metal_encode_node( GGML_ASSERT(src1t == GGML_TYPE_F32); + GGML_ASSERT(ne03 == 1); + GGML_ASSERT(ne13 == 1); + // find the break-even point where the matrix-matrix kernel becomes more efficient compared // to the matrix-vector kernel // ne20 = n_used_experts @@ -2545,6 +2562,8 @@ static void ggml_metal_encode_node( } break; case GGML_OP_IM2COL: { + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(src1)); GGML_ASSERT(src0->type == GGML_TYPE_F16); GGML_ASSERT(src1->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32); @@ -2574,30 +2593,54 @@ static void ggml_metal_encode_node( const int32_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4; const int32_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4; - id pipeline = nil; + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline; + + const bool is_gt_mttpt = ((size_t)(N * KH * KW)) > pipeline.maxTotalThreadsPerThreadgroup; switch (dst->type) { - case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline; break; - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline; break; + case GGML_TYPE_F32: { + pipeline = (is_gt_mttpt ? + ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32].pipeline + : + ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline); + } break; + case GGML_TYPE_F16: { + pipeline = (is_gt_mttpt ? + ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16].pipeline + : + ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline); + } break; default: GGML_ABORT("fatal error"); }; [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ofs0 length:sizeof( int32_t) atIndex:2]; - [encoder setBytes:&ofs1 length:sizeof( int32_t) atIndex:3]; - [encoder setBytes:&IW length:sizeof( int32_t) atIndex:4]; - [encoder setBytes:&IH length:sizeof( int32_t) atIndex:5]; - [encoder setBytes:&CHW length:sizeof( int32_t) atIndex:6]; - [encoder setBytes:&s0 length:sizeof( int32_t) atIndex:7]; - [encoder setBytes:&s1 length:sizeof( int32_t) atIndex:8]; - [encoder setBytes:&p0 length:sizeof( int32_t) atIndex:9]; - [encoder setBytes:&p1 length:sizeof( int32_t) atIndex:10]; - [encoder setBytes:&d0 length:sizeof( int32_t) atIndex:11]; - [encoder setBytes:&d1 length:sizeof( int32_t) atIndex:12]; - - [encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ofs0 length:sizeof(int32_t) atIndex:2]; + [encoder setBytes:&ofs1 length:sizeof(int32_t) atIndex:3]; + [encoder setBytes:&IW length:sizeof(int32_t) atIndex:4]; + [encoder setBytes:&IH length:sizeof(int32_t) atIndex:5]; + [encoder setBytes:&CHW length:sizeof(int32_t) atIndex:6]; + [encoder setBytes:&s0 length:sizeof(int32_t) atIndex:7]; + [encoder setBytes:&s1 length:sizeof(int32_t) atIndex:8]; + [encoder setBytes:&p0 length:sizeof(int32_t) atIndex:9]; + [encoder setBytes:&p1 length:sizeof(int32_t) atIndex:10]; + [encoder setBytes:&d0 length:sizeof(int32_t) atIndex:11]; + [encoder setBytes:&d1 length:sizeof(int32_t) atIndex:12]; + + if (is_gt_mttpt) { + [encoder setBytes:&N length:sizeof(int32_t) atIndex:13]; + [encoder setBytes:&KH length:sizeof(int32_t) atIndex:14]; + [encoder setBytes:&KW length:sizeof(int32_t) atIndex:15]; + + const uint64_t n_threads = MIN(pipeline.maxTotalThreadsPerThreadgroup, (uint64_t)N); + + const int64_t quotient = N / n_threads + (N % n_threads > 0 ? 1 : 0); + + [encoder dispatchThreadgroups:MTLSizeMake(quotient * CHW, OH, OW) threadsPerThreadgroup:MTLSizeMake(n_threads, 1, 1)]; + } else { + [encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)]; + } } break; case GGML_OP_UPSCALE: { @@ -3001,6 +3044,64 @@ static void ggml_metal_encode_node( [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; } break; + case GGML_OP_POOL_2D: + { + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(src0t == GGML_TYPE_F32 && src0t == dstt); + + const int32_t * opts = dst->op_params; + enum ggml_op_pool op = opts[0]; + + id pipeline = nil; + switch (src0t) { + case GGML_TYPE_F32: { + switch(op) { + case GGML_OP_POOL_AVG: + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32].pipeline; break; + case GGML_OP_POOL_MAX: + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32].pipeline; break; + default: GGML_ASSERT(false && "not implemented"); + } + } break; + default: GGML_ASSERT(false && "not implemented"); + } + + const int32_t k0 = opts[1]; + const int32_t k1 = opts[2]; + const int32_t s0 = opts[3]; + const int32_t s1 = opts[4]; + const int32_t p0 = opts[5]; + const int32_t p1 = opts[6]; + + const int64_t IH = src0->ne[1]; + const int64_t IW = src0->ne[0]; + + const int64_t N = dst->ne[3]; + const int64_t OC = dst->ne[2]; + const int64_t OH = dst->ne[1]; + const int64_t OW = dst->ne[0]; + + const int64_t parallel_elements = N * OC * OH * OW; + const int64_t n_threads = MIN((int64_t)[pipeline maxTotalThreadsPerThreadgroup], parallel_elements); + const int64_t n_tg = (parallel_elements + n_threads - 1) / n_threads; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&k0 length:sizeof(int32_t) atIndex:2]; + [encoder setBytes:&k1 length:sizeof(int32_t) atIndex:3]; + [encoder setBytes:&s0 length:sizeof(int32_t) atIndex:4]; + [encoder setBytes:&s1 length:sizeof(int32_t) atIndex:5]; + [encoder setBytes:&p0 length:sizeof(int32_t) atIndex:6]; + [encoder setBytes:&p1 length:sizeof(int32_t) atIndex:7]; + [encoder setBytes:&IH length:sizeof(int64_t) atIndex:8]; + [encoder setBytes:&IW length:sizeof(int64_t) atIndex:9]; + [encoder setBytes:&OH length:sizeof(int64_t) atIndex:10]; + [encoder setBytes:&OW length:sizeof(int64_t) atIndex:11]; + [encoder setBytes:¶llel_elements length:sizeof(int64_t) atIndex:12]; + + [encoder dispatchThreadgroups:MTLSizeMake(n_tg, 1, 1) threadsPerThreadgroup:MTLSizeMake(n_threads, 1, 1)]; + } break; default: { GGML_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op)); diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index 2b200032394b1..defde6246f129 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -777,10 +777,10 @@ kernel void kernel_ssm_conv_f32( const int64_t i3 = tgpig.z; const int64_t nc = ne10; - const int64_t ncs = ne00; - const int64_t nr = ne01; - const int64_t n_t = ne1; - const int64_t n_s = ne2; + //const int64_t ncs = ne00; + //const int64_t nr = ne01; + //const int64_t n_t = ne1; + //const int64_t n_s = ne2; device const float * s = (device const float *) ((device const char *) src0 + ir*nb01 + i2*nb00 + i3*nb02); device const float * c = (device const float *) ((device const char *) src1 + ir*nb11); @@ -834,9 +834,9 @@ kernel void kernel_ssm_scan_f32( const int64_t i3 = tgpig.y; const int64_t nc = d_state; - const int64_t nr = d_inner; + //const int64_t nr = d_inner; const int64_t n_t = n_seq_tokens; - const int64_t n_s = n_seqs; + //const int64_t n_s = n_seqs; for (int64_t i2 = 0; i2 < n_t; ++i2) { device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb01 + i3*nb02); @@ -1064,17 +1064,18 @@ kernel void kernel_group_norm( inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) { float d = qb_curr->d; - float2 acc = 0.f; + float acc[4] = { 0.0f, 0.0f, 0.0f, 0.0f }; - device const uint16_t * qs = ((device const uint16_t *)qb_curr + 1 + il/2); + device const uint16_t * qs = ((device const uint16_t *) qb_curr + 1 + il/2); - for (int i = 0; i < 8; i+=2) { - acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F) - + yl[i + 1] * (qs[i / 2] & 0x0F00); - acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0) - + yl[i + 9] * (qs[i / 2] & 0xF000); + for (int i = 0; i < 8; i += 2) { + acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F); + acc[1] += yl[i + 1] * (qs[i / 2] & 0x0F00); + acc[2] += yl[i + 8] * (qs[i / 2] & 0x00F0); + acc[3] += yl[i + 9] * (qs[i / 2] & 0xF000); } - return d * (sumy * -8.f + acc[0] + acc[1]); + + return d * (sumy * -8.f + acc[0] + acc[1] + acc[2] + acc[3]); } // function for calculate inner product between half a q4_1 block and 16 floats (yl), sumy is SUM(yl[i]) @@ -1085,17 +1086,18 @@ inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thre float d = qb_curr->d; float m = qb_curr->m; - float2 acc = 0.f; + float acc[4] = { 0.0f, 0.0f, 0.0f, 0.0f }; - device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2); + device const uint16_t * qs = ((device const uint16_t *) qb_curr + 2 + il/2); for (int i = 0; i < 8; i+=2) { - acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F) - + yl[i + 1] * (qs[i / 2] & 0x0F00); - acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0) - + yl[i + 9] * (qs[i / 2] & 0xF000); + acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F); + acc[1] += yl[i + 1] * (qs[i / 2] & 0x0F00); + acc[2] += yl[i + 8] * (qs[i / 2] & 0x00F0); + acc[3] += yl[i + 9] * (qs[i / 2] & 0xF000); } - return d * (acc[0] + acc[1]) + sumy * m; + + return d * (acc[0] + acc[1] + acc[2] + acc[3]) + sumy * m; } // function for calculate inner product between half a q5_0 block and 16 floats (yl), sumy is SUM(yl[i]) @@ -1105,18 +1107,19 @@ inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thre inline float block_q_n_dot_y(device const block_q5_0 * qb_curr, float sumy, thread float * yl, int il) { float d = qb_curr->d; - float2 acc = 0.f; + float acc[4] = { 0.0f, 0.0f, 0.0f, 0.0f }; device const uint16_t * qs = ((device const uint16_t *)qb_curr + 3 + il/2); const uint32_t qh = *((device const uint32_t *)qb_curr->qh); for (int i = 0; i < 8; i+=2) { - acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010)) - + yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000)); - acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100)) - + yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000)); + acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010)); + acc[1] += yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000)); + acc[2] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100)); + acc[3] += yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000)); } - return d * (sumy * -16.f + acc[0] + acc[1]); + + return d * (sumy * -16.f + acc[0] + acc[1] + acc[2] + acc[3]); } // function for calculate inner product between half a q5_1 block and 16 floats (yl), sumy is SUM(yl[i]) @@ -1127,18 +1130,19 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre float d = qb_curr->d; float m = qb_curr->m; - float2 acc = 0.f; + float acc[4] = { 0.0f, 0.0f, 0.0f, 0.0f }; device const uint16_t * qs = ((device const uint16_t *)qb_curr + 4 + il/2); const uint32_t qh = *((device const uint32_t *)qb_curr->qh); for (int i = 0; i < 8; i+=2) { - acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010)) - + yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000)); - acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100)) - + yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000)); + acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010)); + acc[1] += yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000)); + acc[2] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100)); + acc[3] += yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000)); } - return d * (acc[0] + acc[1]) + sumy * m; + + return d * (acc[0] + acc[1] + acc[2] + acc[3]) + sumy * m; } // putting them in the kernel cause a significant performance penalty @@ -1156,14 +1160,22 @@ void mul_vec_q_n_f32_impl( int64_t ne00, int64_t ne01, int64_t ne02, + uint64_t nb01, + uint64_t nb02, + uint64_t nb03, int64_t ne10, int64_t ne12, + uint64_t nb11, + uint64_t nb12, + uint64_t nb13, int64_t ne0, int64_t ne1, uint r2, uint r3, threadgroup int8_t * shared_values, - uint3 tgpig, uint tiisg, uint sgitg) { + uint3 tgpig, + uint tiisg, + uint sgitg) { const int nb = ne00/QK4_0; const int r0 = tgpig.x; @@ -1175,10 +1187,19 @@ void mul_vec_q_n_f32_impl( const uint i12 = im%ne12; const uint i13 = im/ne12; - const uint offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + //const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + + //device const block_q_type * x = (device const block_q_type *) ((device char *) src0 + offset0); + device const float * y = (device const float *) ((device char *) src1 + offset1); - device const block_q_type * x = (device const block_q_type *) src0 + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + // pointers to src0 rows + device const block_q_type * ax[nr]; + for (int row = 0; row < nr; ++row) { + const uint offset0 = (first_row + row)*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + + ax[row] = (device const block_q_type *) ((device char *) src0 + offset0); + } float yl[16]; // src1 vector cache float sumf[nr] = {0.f}; @@ -1190,19 +1211,22 @@ void mul_vec_q_n_f32_impl( // each thread in a SIMD group deals with half a block. for (int ib = ix; ib < nb; ib += nw/2) { - float sumy = 0; + float sumy[2] = { 0.f, 0.f }; + +#pragma unroll for (int i = 0; i < 8; i += 2) { - sumy += yb[i] + yb[i+1]; - yl[i+0] = yb[i+ 0]; - yl[i+1] = yb[i+ 1]/256.f; + sumy[0] += yb[i + 0] + yb[i + 1]; + yl[i + 0] = yb[i + 0]; + yl[i + 1] = yb[i + 1]/256.f; - sumy += yb[i+16] + yb[i+17]; - yl[i+8] = yb[i+16]/16.f; - yl[i+9] = yb[i+17]/4096.f; + sumy[1] += yb[i + 16] + yb[i + 17]; + yl[i + 8] = yb[i + 16]/16.f; + yl[i + 9] = yb[i + 17]/4096.f; } +#pragma unroll for (int row = 0; row < nr; row++) { - sumf[row] += block_q_n_dot_y(x+ib+row*nb, sumy, yl, il); + sumf[row] += block_q_n_dot_y(ax[row] + ib, sumy[0] + sumy[1], yl, il); } yb += QK4_0 * 16; @@ -1226,12 +1250,14 @@ kernel void kernel_mul_mv_q4_0_f32( constant uint64_t & nb00, constant uint64_t & nb01, constant uint64_t & nb02, + constant uint64_t & nb03, constant int64_t & ne10, constant int64_t & ne11, constant int64_t & ne12, constant uint64_t & nb10, constant uint64_t & nb11, constant uint64_t & nb12, + constant uint64_t & nb13, constant int64_t & ne0, constant int64_t & ne1, constant uint & r2, @@ -1239,7 +1265,7 @@ kernel void kernel_mul_mv_q4_0_f32( uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg); + mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,nb01,nb02,nb03,ne10,ne12,nb11,nb12,nb13,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg); } kernel void kernel_mul_mv_q4_1_f32( @@ -1252,12 +1278,14 @@ kernel void kernel_mul_mv_q4_1_f32( constant uint64_t & nb00, constant uint64_t & nb01, constant uint64_t & nb02, + constant uint64_t & nb03, constant int64_t & ne10, constant int64_t & ne11, constant int64_t & ne12, constant uint64_t & nb10, constant uint64_t & nb11, constant uint64_t & nb12, + constant uint64_t & nb13, constant int64_t & ne0, constant int64_t & ne1, constant uint & r2, @@ -1265,7 +1293,7 @@ kernel void kernel_mul_mv_q4_1_f32( uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg); + mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,nb01,nb02,nb03,ne10,ne12,nb11,nb12,nb13,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg); } kernel void kernel_mul_mv_q5_0_f32( @@ -1278,12 +1306,14 @@ kernel void kernel_mul_mv_q5_0_f32( constant uint64_t & nb00, constant uint64_t & nb01, constant uint64_t & nb02, + constant uint64_t & nb03, constant int64_t & ne10, constant int64_t & ne11, constant int64_t & ne12, constant uint64_t & nb10, constant uint64_t & nb11, constant uint64_t & nb12, + constant uint64_t & nb13, constant int64_t & ne0, constant int64_t & ne1, constant uint & r2, @@ -1291,7 +1321,7 @@ kernel void kernel_mul_mv_q5_0_f32( uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg); + mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,nb01,nb02,nb03,ne10,ne12,nb11,nb12,nb13,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg); } kernel void kernel_mul_mv_q5_1_f32( @@ -1304,12 +1334,14 @@ kernel void kernel_mul_mv_q5_1_f32( constant uint64_t & nb00, constant uint64_t & nb01, constant uint64_t & nb02, + constant uint64_t & nb03, constant int64_t & ne10, constant int64_t & ne11, constant int64_t & ne12, constant uint64_t & nb10, constant uint64_t & nb11, constant uint64_t & nb12, + constant uint64_t & nb13, constant int64_t & ne0, constant int64_t & ne1, constant uint & r2, @@ -1317,7 +1349,7 @@ kernel void kernel_mul_mv_q5_1_f32( uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg); + mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,nb01,nb02,nb03,ne10,ne12,nb11,nb12,nb13,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg); } @@ -1330,8 +1362,14 @@ void kernel_mul_mv_q8_0_f32_impl( int64_t ne00, int64_t ne01, int64_t ne02, + uint64_t nb01, + uint64_t nb02, + uint64_t nb03, int64_t ne10, int64_t ne12, + uint64_t nb11, + uint64_t nb12, + uint64_t nb13, int64_t ne0, int64_t ne1, uint r2, @@ -1354,10 +1392,19 @@ void kernel_mul_mv_q8_0_f32_impl( const uint i12 = im%ne12; const uint i13 = im/ne12; - const uint offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + //const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; - device const block_q8_0 * x = (device const block_q8_0 *) src0 + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + //device const block_q8_0 * x = (device const block_q8_0 *) ((device char *) src0 + offset0); + device const float * y = (device const float *) ((device char *) src1 + offset1); + + // pointers to src0 rows + device const block_q8_0 * ax[nr]; + for (int row = 0; row < nr; ++row) { + const uint offset0 = (first_row + row)*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + + ax[row] = (device const block_q8_0 *) ((device char *) src0 + offset0); + } float yl[NB_Q8_0]; float sumf[nr]={0.f}; @@ -1374,12 +1421,12 @@ void kernel_mul_mv_q8_0_f32_impl( } for (int row = 0; row < nr; row++) { - device const int8_t * qs = x[ib+row*nb].qs + NB_Q8_0*il; + device const int8_t * qs = ax[row][ib].qs + NB_Q8_0*il; float sumq = 0.f; for (int iq = 0; iq < NB_Q8_0; ++iq) { sumq += qs[iq] * yl[iq]; } - sumf[row] += sumq*x[ib+row*nb].d; + sumf[row] += sumq*ax[row][ib].d; } yb += NB_Q8_0 * nw; @@ -1404,12 +1451,14 @@ kernel void kernel_mul_mv_q8_0_f32( constant uint64_t & nb00, constant uint64_t & nb01, constant uint64_t & nb02, + constant uint64_t & nb03, constant int64_t & ne10, constant int64_t & ne11, constant int64_t & ne12, constant uint64_t & nb10, constant uint64_t & nb11, constant uint64_t & nb12, + constant uint64_t & nb13, constant int64_t & ne0, constant int64_t & ne1, constant uint & r2, @@ -1417,7 +1466,7 @@ kernel void kernel_mul_mv_q8_0_f32( uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg); + kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,nb01,nb02,nb03,ne10,ne12,nb11,nb12,nb13,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg); } #define N_MV_T_T 4 @@ -1433,12 +1482,14 @@ void kernel_mul_mv_impl( uint64_t nb00, uint64_t nb01, uint64_t nb02, + uint64_t nb03, int64_t ne10, int64_t ne11, int64_t ne12, uint64_t nb10, uint64_t nb11, uint64_t nb12, + uint64_t nb13, int64_t ne0, int64_t ne1, uint r2, @@ -1452,7 +1503,7 @@ void kernel_mul_mv_impl( const uint i12 = im%ne12; const uint i13 = im/ne12; - const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02; + const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; device const T0 * x = (device const T0 *) (src0 + offset0); @@ -1463,7 +1514,9 @@ void kernel_mul_mv_impl( break; } - device const T1 * y = (device const T1 *) (src1 + r1*nb11 + im*nb12); + const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + + device const T1 * y = (device const T1 *) (src1 + offset1); float sumf = 0; for (int i = tiisg; i < ne00; i += 32) { @@ -1483,7 +1536,9 @@ void kernel_mul_mv_impl( break; } - device const T1 * y = (device const T1 *) (src1 + r1*nb11 + im*nb12); + const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + + device const T1 * y = (device const T1 *) (src1 + offset1); device const T14 * y4 = (device const T14 *) y; float sumf = 0; @@ -1511,12 +1566,14 @@ kernel void kernel_mul_mv( constant uint64_t & nb00, constant uint64_t & nb01, constant uint64_t & nb02, + constant uint64_t & nb03, constant int64_t & ne10, constant int64_t & ne11, constant int64_t & ne12, constant uint64_t & nb10, constant uint64_t & nb11, constant uint64_t & nb12, + constant uint64_t & nb13, constant int64_t & ne0, constant int64_t & ne1, constant uint & r2, @@ -1533,12 +1590,14 @@ kernel void kernel_mul_mv( nb00, nb01, nb02, + nb03, ne10, ne11, ne12, nb10, nb11, nb12, + nb13, ne0, ne1, r2, @@ -1564,12 +1623,14 @@ kernel void kernel_mul_mv_1row( constant uint64_t & nb00, constant uint64_t & nb01, constant uint64_t & nb02, + constant uint64_t & nb03, constant int64_t & ne10, constant int64_t & ne11, constant int64_t & ne12, constant uint64_t & nb10, constant uint64_t & nb11, constant uint64_t & nb12, + constant uint64_t & nb13, constant int64_t & ne0, constant int64_t & ne1, constant uint & r2, @@ -1584,10 +1645,11 @@ kernel void kernel_mul_mv_1row( const uint i12 = im%ne12; const uint i13 = im/ne12; - const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02; + const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; device const T * x = (device const T *) (src0 + offset0); - device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); + device const float * y = (device const float *) (src1 + offset1); float sumf = 0; if (ne00 < 128) { @@ -1631,12 +1693,14 @@ kernel void kernel_mul_mv_l4( constant uint64_t & nb00, constant uint64_t & nb01, constant uint64_t & nb02, + constant uint64_t & nb03, constant int64_t & ne10, constant int64_t & ne11, constant int64_t & ne12, constant uint64_t & nb10, constant uint64_t & nb11, constant uint64_t & nb12, + constant uint64_t & nb13, constant int64_t & ne0, constant int64_t & ne1, constant uint & r2, @@ -1651,12 +1715,14 @@ kernel void kernel_mul_mv_l4( const uint i12 = im%ne12; const uint i13 = im/ne12; - const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02; + const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; device const T4 * x4 = (device const T4 *) (src0 + offset0); for (int r1 = 0; r1 < nrows; ++r1) { - device const float4 * y4 = (device const float4 *) (src1 + r1*nb11 + im*nb12); + const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + + device const float4 * y4 = (device const float4 *) (src1 + offset1); float sumf = 0; for (int i = tiisg; i < ne00/4; i += 32) { @@ -1933,6 +1999,85 @@ kernel void kernel_im2col( template [[host_name("kernel_im2col_f32")]] kernel im2col_t kernel_im2col; template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col; +typedef void (im2col_ext_t)( + device const float * x, + device char * dst, + constant int32_t & ofs0, + constant int32_t & ofs1, + constant int32_t & IW, + constant int32_t & IH, + constant int32_t & CHW, + constant int32_t & s0, + constant int32_t & s1, + constant int32_t & p0, + constant int32_t & p1, + constant int32_t & d0, + constant int32_t & d1, + constant int32_t & N, + constant int32_t & KH, + constant int32_t & KW, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tgpg[[threadgroups_per_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]); + +template +kernel void kernel_im2col_ext( + device const float * x, + device char * dst, + constant int32_t & ofs0, + constant int32_t & ofs1, + constant int32_t & IW, + constant int32_t & IH, + constant int32_t & CHW, + constant int32_t & s0, + constant int32_t & s1, + constant int32_t & p0, + constant int32_t & p1, + constant int32_t & d0, + constant int32_t & d1, + constant int32_t & N, + constant int32_t & KH, + constant int32_t & KW, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tgpg[[threadgroups_per_grid]], // tgpg[0] = D x IC x KH x KW, CHW = IC x KH x KW + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { // [M, 1, 1] + const int32_t KHW = KH * KW; // KHW == ntg[1] * ntg[2], KW == ntg[2] + + const int32_t d = tgpig[0] / CHW; + const int32_t chw = tgpig[0] % CHW; + const int32_t tgpig_0 = chw / KHW; // 0 ~ (IC - 1) + const int32_t HW = tgpig[0] % KHW; + + const int32_t tpitg_0 = (d * ntg[0]) + tpitg[0]; + if (tpitg_0 >= N) { + return; + } + + const int32_t tpitg_1 = HW / KW; + const int32_t tpitg_2 = HW % KW; + + const int32_t iiw = tgpig[2] * s0 + tpitg_2 * d0 - p0; + const int32_t iih = tgpig[1] * s1 + tpitg_1 * d1 - p1; + + const int32_t offset_dst = + (tpitg_0 * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * CHW + + (tgpig_0 * KHW + tpitg_1 * KW + tpitg_2); + + device T * pdst = (device T *) (dst); + + if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) { + pdst[offset_dst] = 0.0f; + } else { + const int32_t offset_src = tpitg_0 * ofs0 + tgpig_0 * ofs1; + pdst[offset_dst] = x[offset_src + iih * IW + iiw]; + } +} + +template [[host_name("kernel_im2col_ext_f32")]] kernel im2col_ext_t kernel_im2col_ext; +template [[host_name("kernel_im2col_ext_f16")]] kernel im2col_ext_t kernel_im2col_ext; + kernel void kernel_upscale_f32( device const char * src0, device char * dst, @@ -3337,8 +3482,14 @@ void kernel_mul_mv_q2_K_f32_impl( int64_t ne00, int64_t ne01, int64_t ne02, + uint64_t nb01, + uint64_t nb02, + uint64_t nb03, int64_t ne10, int64_t ne12, + uint64_t nb11, + uint64_t nb12, + uint64_t nb13, int64_t ne0, int64_t ne1, uint r2, @@ -3354,21 +3505,19 @@ void kernel_mul_mv_q2_K_f32_impl( const int im = tgpig.z; const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; - const int ib_row = first_row * nb; const uint i12 = im%ne12; const uint i13 = im/ne12; - const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; - device const block_q2_K * x = (device const block_q2_K *) src0 + ib_row + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + device const block_q2_K * x = (device const block_q2_K *) ((device char *) src0 + offset0); + device const float * y = (device const float *) ((device char *) src1 + offset1); float yl[32]; float sumf[N_DST]={0.f}, all_sum; - const int step = sizeof(block_q2_K) * nb; - const int ix = tiisg/8; // 0...3 const int it = tiisg%8; // 0...7 const int iq = it/4; // 0 or 1 @@ -3413,9 +3562,9 @@ void kernel_mul_mv_q2_K_f32_impl( (acc1[3] + 1.f/256.f * acc2[3]) * (sc[6] & 0xF) * 1.f/64.f) - dmin * (sumy[0] * (sc[0] & 0xF0) + sumy[1] * (sc[2] & 0xF0) + sumy[2] * (sc[4] & 0xF0) + sumy[3] * (sc[6] & 0xF0)); - qs += step/2; - sc += step; - dh += step/2; + qs += nb01/2; + sc += nb01; + dh += nb01/2; } y4 += 4 * QK_K; @@ -3440,12 +3589,14 @@ kernel void kernel_mul_mv_q2_K_f32( constant uint64_t & nb00, constant uint64_t & nb01, constant uint64_t & nb02, + constant uint64_t & nb03, constant int64_t & ne10, constant int64_t & ne11, constant int64_t & ne12, constant uint64_t & nb10, constant uint64_t & nb11, constant uint64_t & nb12, + constant uint64_t & nb13, constant int64_t & ne0, constant int64_t & ne1, constant uint & r2, @@ -3454,7 +3605,7 @@ kernel void kernel_mul_mv_q2_K_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); } void kernel_mul_mv_q3_K_f32_impl( @@ -3464,8 +3615,14 @@ void kernel_mul_mv_q3_K_f32_impl( int64_t ne00, int64_t ne01, int64_t ne02, + uint64_t nb01, + uint64_t nb02, + uint64_t nb03, int64_t ne10, int64_t ne12, + uint64_t nb11, + uint64_t nb12, + uint64_t nb13, int64_t ne0, int64_t ne1, uint r2, @@ -3486,10 +3643,11 @@ void kernel_mul_mv_q3_K_f32_impl( const uint i12 = im%ne12; const uint i13 = im/ne12; - const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; - device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb + offset0; - device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + device const block_q3_K * x = (device const block_q3_K *) ((device char *) src0 + offset0); + device const float * yy = (device const float *) ((device char *) src1 + offset1); float yl[32]; @@ -3529,8 +3687,6 @@ void kernel_mul_mv_q3_K_f32_impl( const int q_offset = 32*ip + l0; const int y_offset = 128*ip + 32*il + l0; - const int step = sizeof(block_q3_K) * nb / 2; - device const float * y1 = yy + ix*QK_K + y_offset; uint32_t scales32, aux32; @@ -3540,7 +3696,6 @@ void kernel_mul_mv_q3_K_f32_impl( float sumf1[2] = {0.f}; float sumf2[2] = {0.f}; for (int i = ix; i < nb; i += 4) { - for (int l = 0; l < 8; ++l) { yl[l+ 0] = y1[l+ 0]; yl[l+ 8] = y1[l+16]; @@ -3554,7 +3709,6 @@ void kernel_mul_mv_q3_K_f32_impl( device const half * dh = &x[i].d; for (int row = 0; row < 2; ++row) { - const float d_all = (float)dh[0]; scales16[0] = a[4]; @@ -3594,15 +3748,13 @@ void kernel_mul_mv_q3_K_f32_impl( sumf1[row] += d1 * (scales[1] - 32); sumf2[row] += d2 * (scales[3] - 32); - q += step; - h += step; - a += step; - dh += step; - + q += nb01/2; + h += nb01/2; + a += nb01/2; + dh += nb01/2; } y1 += 4 * QK_K; - } for (int row = 0; row < 2; ++row) { @@ -3627,12 +3779,14 @@ kernel void kernel_mul_mv_q3_K_f32( constant uint64_t & nb00, constant uint64_t & nb01, constant uint64_t & nb02, + constant uint64_t & nb03, constant int64_t & ne10, constant int64_t & ne11, constant int64_t & ne12, constant uint64_t & nb10, constant uint64_t & nb11, constant uint64_t & nb12, + constant uint64_t & nb13, constant int64_t & ne0, constant int64_t & ne1, constant uint & r2, @@ -3641,7 +3795,7 @@ kernel void kernel_mul_mv_q3_K_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); } void kernel_mul_mv_q4_K_f32_impl( @@ -3651,8 +3805,14 @@ void kernel_mul_mv_q4_K_f32_impl( int64_t ne00, int64_t ne01, int64_t ne02, + uint64_t nb01, + uint64_t nb02, + uint64_t nb03, int64_t ne10, int64_t ne12, + uint64_t nb11, + uint64_t nb12, + uint64_t nb13, int64_t ne0, int64_t ne1, uint r2, @@ -3677,29 +3837,26 @@ void kernel_mul_mv_q4_K_f32_impl( const int im = tgpig.z; //const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; const int first_row = r0 * N_DST; - const int ib_row = first_row * nb; const uint i12 = im%ne12; const uint i13 = im/ne12; - const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; - device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + device const block_q4_K * x = (device const block_q4_K *) ((device char *) src0 + offset0); + device const float * y = (device const float *) ((device char *) src1 + offset1); float yl[16]; float yh[16]; float sumf[N_DST]={0.f}, all_sum; - const int step = sizeof(block_q4_K) * nb / 2; - device const float * y4 = y + ix * QK_K + 64 * iq + 8 * ir; uint16_t sc16[4]; thread const uint8_t * sc8 = (thread const uint8_t *)sc16; for (int ib = ix; ib < nb; ib += 4) { - float4 sumy = {0.f, 0.f, 0.f, 0.f}; for (int i = 0; i < 8; ++i) { yl[i+0] = y4[i+ 0]; sumy[0] += yl[i+0]; @@ -3713,7 +3870,6 @@ void kernel_mul_mv_q4_K_f32_impl( device const half * dh = &x[ib].d; for (int row = 0; row < N_DST; row++) { - sc16[0] = sc[0] & kmask1; sc16[1] = sc[2] & kmask1; sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2); @@ -3742,9 +3898,9 @@ void kernel_mul_mv_q4_K_f32_impl( (acc2[2] + 1.f/256.f * acc2[3]) * sc8[5] * 1.f/16.f) - dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]); - q1 += step; - sc += step; - dh += step; + q1 += nb01/2; + sc += nb01/2; + dh += nb01/2; } y4 += 4 * QK_K; @@ -3769,12 +3925,14 @@ kernel void kernel_mul_mv_q4_K_f32( constant uint64_t & nb00, constant uint64_t & nb01, constant uint64_t & nb02, + constant uint64_t & nb03, constant int64_t & ne10, constant int64_t & ne11, constant int64_t & ne12, constant uint64_t & nb10, constant uint64_t & nb11, constant uint64_t & nb12, + constant uint64_t & nb13, constant int64_t & ne0, constant int64_t & ne1, constant uint & r2, @@ -3783,7 +3941,7 @@ kernel void kernel_mul_mv_q4_K_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q4_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q4_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); } void kernel_mul_mv_q5_K_f32_impl( @@ -3793,8 +3951,14 @@ void kernel_mul_mv_q5_K_f32_impl( int64_t ne00, int64_t ne01, int64_t ne02, + uint64_t nb01, + uint64_t nb02, + uint64_t nb03, int64_t ne10, int64_t ne12, + uint64_t nb11, + uint64_t nb12, + uint64_t nb13, int64_t ne0, int64_t ne1, uint r2, @@ -3815,15 +3979,14 @@ void kernel_mul_mv_q5_K_f32_impl( const uint i12 = im%ne12; const uint i13 = im/ne12; - const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; - device const block_q5_K * x = (device const block_q5_K *) src0 + first_row*nb + offset0; - device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + device const block_q5_K * x = (device const block_q5_K *) ((device char *) src0 + offset0); + device const float * yy = (device const float *) ((device char *) src1 + offset1); float sumf[2]={0.f}; - const int step = sizeof(block_q5_K) * nb; - float yl[16], yh[16]; const uint16_t kmask1 = 0x3f3f; @@ -3851,7 +4014,6 @@ void kernel_mul_mv_q5_K_f32_impl( device const float * y1 = yy + ix*QK_K + y_offset; for (int i = ix; i < nb; i += 4) { - device const uint8_t * q1 = x[i].qs + q_offset; device const uint8_t * qh = x[i].qh + l0; device const half * dh = &x[i].d; @@ -3867,7 +4029,6 @@ void kernel_mul_mv_q5_K_f32_impl( } for (int row = 0; row < 2; ++row) { - device const uint8_t * q2 = q1 + 64; sc16[0] = a[0] & kmask1; @@ -3896,15 +4057,13 @@ void kernel_mul_mv_q5_K_f32_impl( sc8[5] * (acc1[3]/16.f + 16.f*acc2[3])) - dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]); - q1 += step; - qh += step; - dh += step/2; - a += step/2; - + q1 += nb01; + qh += nb01; + dh += nb01/2; + a += nb01/2; } y1 += 4 * QK_K; - } for (int row = 0; row < 2; ++row) { @@ -3926,12 +4085,14 @@ kernel void kernel_mul_mv_q5_K_f32( constant uint64_t & nb00, constant uint64_t & nb01, constant uint64_t & nb02, + constant uint64_t & nb03, constant int64_t & ne10, constant int64_t & ne11, constant int64_t & ne12, constant uint64_t & nb10, constant uint64_t & nb11, constant uint64_t & nb12, + constant uint64_t & nb13, constant int64_t & ne0, constant int64_t & ne1, constant uint & r2, @@ -3940,7 +4101,7 @@ kernel void kernel_mul_mv_q5_K_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q5_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q5_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); } void kernel_mul_mv_q6_K_f32_impl( @@ -3950,8 +4111,14 @@ void kernel_mul_mv_q6_K_f32_impl( int64_t ne00, int64_t ne01, int64_t ne02, + uint64_t nb01, + uint64_t nb02, + uint64_t nb03, int64_t ne10, int64_t ne12, + uint64_t nb11, + uint64_t nb12, + uint64_t nb13, int64_t ne0, int64_t ne1, uint r2, @@ -3977,10 +4144,11 @@ void kernel_mul_mv_q6_K_f32_impl( const uint i12 = im%ne12; const uint i13 = im/ne12; - const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + const uint offset0 = row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; - device const block_q6_K * x = (device const block_q6_K *) src0 + row * nb + offset0; - device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + device const block_q6_K * x = (device const block_q6_K *) ((device char *) src0 + offset0); + device const float * yy = (device const float *) ((device char *) src1 + offset1); float sumf = 0; @@ -4036,12 +4204,14 @@ kernel void kernel_mul_mv_q6_K_f32( constant uint64_t & nb00, constant uint64_t & nb01, constant uint64_t & nb02, + constant uint64_t & nb03, constant int64_t & ne10, constant int64_t & ne11, constant int64_t & ne12, constant uint64_t & nb10, constant uint64_t & nb11, constant uint64_t & nb12, + constant uint64_t & nb13, constant int64_t & ne0, constant int64_t & ne1, constant uint & r2, @@ -4050,7 +4220,7 @@ kernel void kernel_mul_mv_q6_K_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); } // ======================= "True" 2-bit @@ -4062,8 +4232,14 @@ void kernel_mul_mv_iq2_xxs_f32_impl( int64_t ne00, int64_t ne01, int64_t ne02, + uint64_t nb01, + uint64_t nb02, + uint64_t nb03, int64_t ne10, int64_t ne12, + uint64_t nb11, + uint64_t nb12, + uint64_t nb13, int64_t ne0, int64_t ne1, uint r2, @@ -4079,15 +4255,15 @@ void kernel_mul_mv_iq2_xxs_f32_impl( const int im = tgpig.z; const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; - const int ib_row = first_row * nb; const uint i12 = im%ne12; const uint i13 = im/ne12; - const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; - device const block_iq2_xxs * x = (device const block_iq2_xxs *) src0 + ib_row + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + device const block_iq2_xxs * x = (device const block_iq2_xxs *) ((device char *) src0 + offset0); + device const float * y = (device const float *) ((device char *) src1 + offset1); float yl[32]; float sumf[N_DST]={0.f}, all_sum; @@ -4140,8 +4316,8 @@ void kernel_mul_mv_iq2_xxs_f32_impl( } sumf[row] += d * sum; - dh += nb*sizeof(block_iq2_xxs)/2; - q2 += nb*sizeof(block_iq2_xxs)/2; + dh += nb01/2; + q2 += nb01/2; } y4 += 32 * 32; @@ -4166,12 +4342,14 @@ kernel void kernel_mul_mv_iq2_xxs_f32( constant uint64_t & nb00, constant uint64_t & nb01, constant uint64_t & nb02, + constant uint64_t & nb03, constant int64_t & ne10, constant int64_t & ne11, constant int64_t & ne12, constant uint64_t & nb10, constant uint64_t & nb11, constant uint64_t & nb12, + constant uint64_t & nb13, constant int64_t & ne0, constant int64_t & ne1, constant uint & r2, @@ -4181,7 +4359,7 @@ kernel void kernel_mul_mv_iq2_xxs_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq2_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); + kernel_mul_mv_iq2_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); } void kernel_mul_mv_iq2_xs_f32_impl( @@ -4191,8 +4369,14 @@ void kernel_mul_mv_iq2_xs_f32_impl( int64_t ne00, int64_t ne01, int64_t ne02, + uint64_t nb01, + uint64_t nb02, + uint64_t nb03, int64_t ne10, int64_t ne12, + uint64_t nb11, + uint64_t nb12, + uint64_t nb13, int64_t ne0, int64_t ne1, uint r2, @@ -4208,15 +4392,15 @@ void kernel_mul_mv_iq2_xs_f32_impl( const int im = tgpig.z; const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; - const int ib_row = first_row * nb; const uint i12 = im%ne12; const uint i13 = im/ne12; - const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; - device const block_iq2_xs * x = (device const block_iq2_xs *) src0 + ib_row + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + device const block_iq2_xs * x = (device const block_iq2_xs *) ((device char *) src0 + offset0); + device const float * y = (device const float *) ((device char *) src1 + offset1); float yl[32]; float sumf[N_DST]={0.f}, all_sum; @@ -4278,9 +4462,9 @@ void kernel_mul_mv_iq2_xs_f32_impl( } sumf[row] += d1 * sum1 + d2 * sum2; - dh += nb*sizeof(block_iq2_xs)/2; - q2 += nb*sizeof(block_iq2_xs)/2; - sc += nb*sizeof(block_iq2_xs); + dh += nb01/2; + q2 += nb01/2; + sc += nb01; } y4 += 32 * 32; @@ -4305,12 +4489,14 @@ kernel void kernel_mul_mv_iq2_xs_f32( constant uint64_t & nb00, constant uint64_t & nb01, constant uint64_t & nb02, + constant uint64_t & nb03, constant int64_t & ne10, constant int64_t & ne11, constant int64_t & ne12, constant uint64_t & nb10, constant uint64_t & nb11, constant uint64_t & nb12, + constant uint64_t & nb13, constant int64_t & ne0, constant int64_t & ne1, constant uint & r2, @@ -4320,7 +4506,7 @@ kernel void kernel_mul_mv_iq2_xs_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq2_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); + kernel_mul_mv_iq2_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); } void kernel_mul_mv_iq3_xxs_f32_impl( @@ -4330,8 +4516,14 @@ void kernel_mul_mv_iq3_xxs_f32_impl( int64_t ne00, int64_t ne01, int64_t ne02, + uint64_t nb01, + uint64_t nb02, + uint64_t nb03, int64_t ne10, int64_t ne12, + uint64_t nb11, + uint64_t nb12, + uint64_t nb13, int64_t ne0, int64_t ne1, uint r2, @@ -4347,15 +4539,15 @@ void kernel_mul_mv_iq3_xxs_f32_impl( const int im = tgpig.z; const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; - const int ib_row = first_row * nb; const uint i12 = im%ne12; const uint i13 = im/ne12; - const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; - device const block_iq3_xxs * x = (device const block_iq3_xxs *) src0 + ib_row + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + device const block_iq3_xxs * x = (device const block_iq3_xxs *) ((device char *) src0 + offset0); + device const float * y = (device const float *) ((device char *) src1 + offset1); float yl[32]; float sumf[N_DST]={0.f}, all_sum; @@ -4410,9 +4602,9 @@ void kernel_mul_mv_iq3_xxs_f32_impl( } sumf[row] += d * (sum[0] + sum[1]); - dh += nb*sizeof(block_iq3_xxs)/2; - q3 += nb*sizeof(block_iq3_xxs); - gas += nb*sizeof(block_iq3_xxs)/2; + dh += nb01/2; + q3 += nb01; + gas += nb01/2; } y4 += 32 * 32; @@ -4437,12 +4629,14 @@ kernel void kernel_mul_mv_iq3_xxs_f32( constant uint64_t & nb00, constant uint64_t & nb01, constant uint64_t & nb02, + constant uint64_t & nb03, constant int64_t & ne10, constant int64_t & ne11, constant int64_t & ne12, constant uint64_t & nb10, constant uint64_t & nb11, constant uint64_t & nb12, + constant uint64_t & nb13, constant int64_t & ne0, constant int64_t & ne1, constant uint & r2, @@ -4452,7 +4646,7 @@ kernel void kernel_mul_mv_iq3_xxs_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq3_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); + kernel_mul_mv_iq3_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); } void kernel_mul_mv_iq3_s_f32_impl( @@ -4462,8 +4656,14 @@ void kernel_mul_mv_iq3_s_f32_impl( int64_t ne00, int64_t ne01, int64_t ne02, + uint64_t nb01, + uint64_t nb02, + uint64_t nb03, int64_t ne10, int64_t ne12, + uint64_t nb11, + uint64_t nb12, + uint64_t nb13, int64_t ne0, int64_t ne1, uint r2, @@ -4479,15 +4679,15 @@ void kernel_mul_mv_iq3_s_f32_impl( const int im = tgpig.z; const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; - const int ib_row = first_row * nb; const uint i12 = im%ne12; const uint i13 = im/ne12; - const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; - device const block_iq3_s * x = (device const block_iq3_s *) src0 + ib_row + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + device const block_iq3_s * x = (device const block_iq3_s *) ((device char *) src0 + offset0); + device const float * y = (device const float *) ((device char *) src1 + offset1); float yl[32]; float sumf[N_DST]={0.f}, all_sum; @@ -4540,11 +4740,11 @@ void kernel_mul_mv_iq3_s_f32_impl( } sumf[row] += d * (sum[0] + sum[1]); - dh += nb*sizeof(block_iq3_s)/2; - qs += nb*sizeof(block_iq3_s); - qh += nb*sizeof(block_iq3_s); - sc += nb*sizeof(block_iq3_s); - signs += nb*sizeof(block_iq3_s); + dh += nb01/2; + qs += nb01; + qh += nb01; + sc += nb01; + signs += nb01; } y4 += 32 * 32; @@ -4569,12 +4769,14 @@ kernel void kernel_mul_mv_iq3_s_f32( constant uint64_t & nb00, constant uint64_t & nb01, constant uint64_t & nb02, + constant uint64_t & nb03, constant int64_t & ne10, constant int64_t & ne11, constant int64_t & ne12, constant uint64_t & nb10, constant uint64_t & nb11, constant uint64_t & nb12, + constant uint64_t & nb13, constant int64_t & ne0, constant int64_t & ne1, constant uint & r2, @@ -4584,7 +4786,7 @@ kernel void kernel_mul_mv_iq3_s_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq3_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); + kernel_mul_mv_iq3_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); } void kernel_mul_mv_iq2_s_f32_impl( @@ -4594,8 +4796,14 @@ void kernel_mul_mv_iq2_s_f32_impl( int64_t ne00, int64_t ne01, int64_t ne02, + uint64_t nb01, + uint64_t nb02, + uint64_t nb03, int64_t ne10, int64_t ne12, + uint64_t nb11, + uint64_t nb12, + uint64_t nb13, int64_t ne0, int64_t ne1, uint r2, @@ -4611,15 +4819,15 @@ void kernel_mul_mv_iq2_s_f32_impl( const int im = tgpig.z; const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; - const int ib_row = first_row * nb; const uint i12 = im%ne12; const uint i13 = im/ne12; - const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; - device const block_iq2_s * x = (device const block_iq2_s *) src0 + ib_row + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + device const block_iq2_s * x = (device const block_iq2_s *) ((device char *) src0 + offset0); + device const float * y = (device const float *) ((device char *) src1 + offset1); float yl[32]; float sumf[N_DST]={0.f}, all_sum; @@ -4673,11 +4881,11 @@ void kernel_mul_mv_iq2_s_f32_impl( } sumf[row] += d1 * sum[0] + d2 * sum[1]; - dh += nb*sizeof(block_iq2_s)/2; - qs += nb*sizeof(block_iq2_s); - qh += nb*sizeof(block_iq2_s); - sc += nb*sizeof(block_iq2_s); - signs += nb*sizeof(block_iq2_s); + dh += nb01/2; + qs += nb01; + qh += nb01; + sc += nb01; + signs += nb01; } y4 += 32 * 32; @@ -4702,12 +4910,14 @@ kernel void kernel_mul_mv_iq2_s_f32( constant uint64_t & nb00, constant uint64_t & nb01, constant uint64_t & nb02, + constant uint64_t & nb03, constant int64_t & ne10, constant int64_t & ne11, constant int64_t & ne12, constant uint64_t & nb10, constant uint64_t & nb11, constant uint64_t & nb12, + constant uint64_t & nb13, constant int64_t & ne0, constant int64_t & ne1, constant uint & r2, @@ -4717,7 +4927,7 @@ kernel void kernel_mul_mv_iq2_s_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq2_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); + kernel_mul_mv_iq2_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); } void kernel_mul_mv_iq1_s_f32_impl( @@ -4727,8 +4937,14 @@ void kernel_mul_mv_iq1_s_f32_impl( int64_t ne00, int64_t ne01, int64_t ne02, + uint64_t nb01, + uint64_t nb02, + uint64_t nb03, int64_t ne10, int64_t ne12, + uint64_t nb11, + uint64_t nb12, + uint64_t nb13, int64_t ne0, int64_t ne1, uint r2, @@ -4744,14 +4960,15 @@ void kernel_mul_mv_iq1_s_f32_impl( const int im = tgpig.z; const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; - const int ib_row = first_row * nb; const uint i12 = im%ne12; const uint i13 = im/ne12; - const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - device const block_iq1_s * x = (device const block_iq1_s *) src0 + ib_row + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + + device const block_iq1_s * x = (device const block_iq1_s *) ((device char *) src0 + offset0); + device const float * y = (device const float *) ((device char *) src1 + offset1); float yl[32]; float sumf[N_DST]={0.f}, all_sum; @@ -4794,9 +5011,9 @@ void kernel_mul_mv_iq1_s_f32_impl( } sumf[row] += (float)dh[0] * (sum + sumy * (qh[0] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA)) * (2*((qh[0] >> 12) & 7) + 1); - dh += nb*sizeof(block_iq1_s)/2; - qs += nb*sizeof(block_iq1_s); - qh += nb*sizeof(block_iq1_s)/2; + dh += nb01/2; + qs += nb01; + qh += nb01/2; } y4 += 32 * 32; @@ -4817,8 +5034,14 @@ void kernel_mul_mv_iq1_m_f32_impl( int64_t ne00, int64_t ne01, int64_t ne02, + uint64_t nb01, + uint64_t nb02, + uint64_t nb03, int64_t ne10, int64_t ne12, + uint64_t nb11, + uint64_t nb12, + uint64_t nb13, int64_t ne0, int64_t ne1, uint r2, @@ -4834,14 +5057,15 @@ void kernel_mul_mv_iq1_m_f32_impl( const int im = tgpig.z; const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; - const int ib_row = first_row * nb; const uint i12 = im%ne12; const uint i13 = im/ne12; - const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - device const block_iq1_m * x = (device const block_iq1_m *) src0 + ib_row + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + + device const block_iq1_m * x = (device const block_iq1_m *) ((device char *) src0 + offset0); + device const float * y = (device const float *) ((device char *) src1 + offset1); float yl[32]; float sumf[N_DST]={0.f}, all_sum; @@ -4893,9 +5117,9 @@ void kernel_mul_mv_iq1_m_f32_impl( sumf[row] += (float)scale.f16 * ((sum[0] + delta1) * (2*((sc[ib/2] >> (6*(ib%2)+0)) & 7) + 1) + (sum[1] + delta2) * (2*((sc[ib/2] >> (6*(ib%2)+3)) & 7) + 1)); - sc += nb*sizeof(block_iq1_m)/2; - qs += nb*sizeof(block_iq1_m); - qh += nb*sizeof(block_iq1_m); + sc += nb01/2; + qs += nb01; + qh += nb01; } y4 += 32 * 32; @@ -4916,8 +5140,14 @@ void kernel_mul_mv_iq4_nl_f32_impl( int64_t ne00, int64_t ne01, int64_t ne02, + uint64_t nb01, + uint64_t nb02, + uint64_t nb03, int64_t ne10, int64_t ne12, + uint64_t nb11, + uint64_t nb12, + uint64_t nb13, int64_t ne0, int64_t ne1, uint r2, @@ -4933,14 +5163,15 @@ void kernel_mul_mv_iq4_nl_f32_impl( const int r1 = tgpig.y; const int im = tgpig.z; const int first_row = (r0 * 2 + sgitg) * 2; - const int ib_row = first_row * nb; const uint i12 = im%ne12; const uint i13 = im/ne12; - const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - device const block_iq4_nl * x = (device const block_iq4_nl *) src0 + ib_row + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + + device const block_iq4_nl * x = (device const block_iq4_nl *) ((device char *) src0 + offset0); + device const float * y = (device const float *) ((device char *) src1 + offset1); const int ix = tiisg/2; // 0...15 const int it = tiisg%2; // 0 or 1 @@ -5010,8 +5241,14 @@ void kernel_mul_mv_iq4_xs_f32_impl( int64_t ne00, int64_t ne01, int64_t ne02, + uint64_t nb01, + uint64_t nb02, + uint64_t nb03, int64_t ne10, int64_t ne12, + uint64_t nb11, + uint64_t nb12, + uint64_t nb13, int64_t ne0, int64_t ne1, uint r2, @@ -5027,14 +5264,15 @@ void kernel_mul_mv_iq4_xs_f32_impl( const int r1 = tgpig.y; const int im = tgpig.z; const int first_row = (r0 * 2 + sgitg) * 2; - const int ib_row = first_row * nb; const uint i12 = im%ne12; const uint i13 = im/ne12; - const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - device const block_iq4_xs * x = (device const block_iq4_xs *) src0 + ib_row + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + + device const block_iq4_xs * x = (device const block_iq4_xs *) ((device char *) src0 + offset0); + device const float * y = (device const float *) ((device char *) src1 + offset1); const int ix = tiisg/16; // 0 or 1 const int it = tiisg%16; // 0...15 @@ -5109,12 +5347,14 @@ kernel void kernel_mul_mv_iq1_s_f32( constant uint64_t & nb00, constant uint64_t & nb01, constant uint64_t & nb02, + constant uint64_t & nb03, constant int64_t & ne10, constant int64_t & ne11, constant int64_t & ne12, constant uint64_t & nb10, constant uint64_t & nb11, constant uint64_t & nb12, + constant uint64_t & nb13, constant int64_t & ne0, constant int64_t & ne1, constant uint & r2, @@ -5123,7 +5363,7 @@ kernel void kernel_mul_mv_iq1_s_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq1_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_iq1_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); } [[host_name("kernel_mul_mv_iq1_m_f32")]] @@ -5137,12 +5377,14 @@ kernel void kernel_mul_mv_iq1_m_f32( constant uint64_t & nb00, constant uint64_t & nb01, constant uint64_t & nb02, + constant uint64_t & nb03, constant int64_t & ne10, constant int64_t & ne11, constant int64_t & ne12, constant uint64_t & nb10, constant uint64_t & nb11, constant uint64_t & nb12, + constant uint64_t & nb13, constant int64_t & ne0, constant int64_t & ne1, constant uint & r2, @@ -5151,7 +5393,7 @@ kernel void kernel_mul_mv_iq1_m_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq1_m_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_iq1_m_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); } [[host_name("kernel_mul_mv_iq4_nl_f32")]] @@ -5165,12 +5407,14 @@ kernel void kernel_mul_mv_iq4_nl_f32( constant uint64_t & nb00, constant uint64_t & nb01, constant uint64_t & nb02, + constant uint64_t & nb03, constant int64_t & ne10, constant int64_t & ne11, constant int64_t & ne12, constant uint64_t & nb10, constant uint64_t & nb11, constant uint64_t & nb12, + constant uint64_t & nb13, constant int64_t & ne0, constant int64_t & ne1, constant uint & r2, @@ -5180,7 +5424,7 @@ kernel void kernel_mul_mv_iq4_nl_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq4_nl_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); + kernel_mul_mv_iq4_nl_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); } [[host_name("kernel_mul_mv_iq4_xs_f32")]] @@ -5194,12 +5438,14 @@ kernel void kernel_mul_mv_iq4_xs_f32( constant uint64_t & nb00, constant uint64_t & nb01, constant uint64_t & nb02, + constant uint64_t & nb03, constant int64_t & ne10, constant int64_t & ne11, constant int64_t & ne12, constant uint64_t & nb10, constant uint64_t & nb11, constant uint64_t & nb12, + constant uint64_t & nb13, constant int64_t & ne0, constant int64_t & ne1, constant uint & r2, @@ -5209,7 +5455,7 @@ kernel void kernel_mul_mv_iq4_xs_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq4_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); + kernel_mul_mv_iq4_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); } //============================= templates and their specializations ============================= @@ -5754,10 +6000,12 @@ kernel void kernel_mul_mm(device const uchar * src0, constant int64_t & ne02, constant uint64_t & nb01, constant uint64_t & nb02, + constant uint64_t & nb03, constant int64_t & ne12, constant uint64_t & nb10, constant uint64_t & nb11, constant uint64_t & nb12, + constant uint64_t & nb13, constant int64_t & ne0, constant int64_t & ne1, constant uint & r2, @@ -5794,12 +6042,13 @@ kernel void kernel_mul_mm(device const uchar * src0, const uint i12 = im%ne12; const uint i13 = im/ne12; - uint offset0 = (i12/r2)*nb02 + (i13/r3)*(nb02*ne02); + uint offset0 = (i12/r2)*nb02 + (i13/r3)*nb03; ushort offset1 = il/nl; device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1; device const float * y = (device const float *)(src1 - + nb12 * im + + nb13 * i13 + + nb12 * i12 + nb11 * (r1 * BLOCK_SIZE_N + thread_col) + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL))); @@ -6178,12 +6427,14 @@ typedef void (kernel_mul_mv_impl_t)( uint64_t nb00, uint64_t nb01, uint64_t nb02, + uint64_t nb03, int64_t ne10, int64_t ne11, int64_t ne12, uint64_t nb10, uint64_t nb11, uint64_t nb12, + uint64_t nb13, int64_t ne0, int64_t ne1, uint r2, @@ -6198,8 +6449,14 @@ typedef void (kernel_mul_mv2_impl_t)( int64_t ne00, int64_t ne01, int64_t ne02, + uint64_t nb01, + uint64_t nb02, + uint64_t nb03, int64_t ne10, int64_t ne12, + uint64_t nb11, + uint64_t nb12, + uint64_t nb13, int64_t ne0, int64_t ne1, uint r2, @@ -6220,6 +6477,7 @@ void mmv_fn( uint64_t nb00, uint64_t nb01, uint64_t nb02, + uint64_t nb03, int64_t ne10, int64_t ne11, int64_t ne12, @@ -6227,6 +6485,7 @@ void mmv_fn( uint64_t nb10, uint64_t nb11, uint64_t nb12, + uint64_t nb13, int64_t ne0, int64_t ne1, uint64_t nb1, @@ -6237,7 +6496,7 @@ void mmv_fn( uint tiitg, uint tiisg, uint sgitg) { - impl_fn(src0,src1,dst,ne00,ne01,ne02,nb00,nb01,nb02,ne10,ne11,ne12,nb10,nb11,nb12,ne0,ne1,r2,r3,tgpig,tiisg); + impl_fn(src0,src1,dst,ne00,ne01,ne02,nb00,nb01,nb02,nb03,ne10,ne11,ne12,nb10,nb11,nb12,nb13,ne0,ne1,r2,r3,tgpig,tiisg); } template @@ -6251,6 +6510,7 @@ void mmv_fn( uint64_t nb00, uint64_t nb01, uint64_t nb02, + uint64_t nb03, int64_t ne10, int64_t ne11, int64_t ne12, @@ -6258,6 +6518,7 @@ void mmv_fn( uint64_t nb10, uint64_t nb11, uint64_t nb12, + uint64_t nb13, int64_t ne0, int64_t ne1, uint64_t nb1, @@ -6268,7 +6529,7 @@ void mmv_fn( uint tiitg, uint tiisg, uint sgitg) { - impl_fn(src0,(const device float *)src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,shared_values,tgpig,tiisg,sgitg); + impl_fn(src0,(const device float *)src1,dst,ne00,ne01,ne02,nb01,nb02,nb03,ne10,ne12,nb11,nb12,nb13,ne0,ne1,r2,r3,shared_values,tgpig,tiisg,sgitg); } typedef decltype(mmv_fn>) mul_mv_impl_fn_t; @@ -6317,8 +6578,8 @@ kernel void kernel_mul_mv_id( const int64_t i2 = i12; device const char * src0_cur = src0s + i02*nb02; - device const char * src1_cur = src1 + i11*nb11 + i12*nb12; - device float * dst_cur = dst + i1*ne0 + i2*ne1*ne0; + device const char * src1_cur = src1 + i11*nb11 + i12*nb12; + device float * dst_cur = dst + i1*ne0 + i2*ne1*ne0; impl_fn( /* src0 */ src0_cur, @@ -6326,19 +6587,21 @@ kernel void kernel_mul_mv_id( /* dst */ dst_cur, /* ne00 */ ne00, /* ne01 */ ne01, - /* ne02 */ 1,//ne02, + /* ne02 */ 1, // ne02, /* nb00 */ nb00, /* nb01 */ nb01, /* nb02 */ nb02, + /* nb03 */ nb02, // ne02 == 1 /* ne10 */ ne10, - /* ne11 */ 1,//ne11, - /* ne12 */ 1,//ne12, - /* ne13 */ 1,//ne13, + /* ne11 */ 1, // ne11, + /* ne12 */ 1, // ne12, + /* ne13 */ 1, // ne13, /* nb10 */ nb10, /* nb11 */ nb11, /* nb12 */ nb12, + /* ne13 */ nb12, // ne12 == 1 /* ne0 */ ne0, - /* ne1 */ 1,//ne1, + /* ne1 */ 1, // ne1, /* nb1 */ nb1, /* r2 */ 1, /* r3 */ 1, @@ -6372,3 +6635,102 @@ template [[host_name("kernel_mul_mv_id_iq3_s_f32")]] kernel kernel_mul_mv_id_t template [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; + +kernel void kernel_pool_2d_max_f32( + device const float * src0, + device float * dst, + constant int32_t & k0, + constant int32_t & k1, + constant int32_t & s0, + constant int32_t & s1, + constant int32_t & p0, + constant int32_t & p1, + constant int64_t & IH, + constant int64_t & IW, + constant int64_t & OH, + constant int64_t & OW, + constant int64_t & parallel_elements, + uint gid[[thread_position_in_grid]]) { + + if (gid >= parallel_elements) { + return; + } + + const int idx = gid; + const int I_HW = IH * IW; + const int O_HW = OH * OW; + const int nc = idx / O_HW; + const int cur_oh = idx % O_HW / OW; + const int cur_ow = idx % O_HW % OW; + + device const float * i_ptr = src0 + nc * I_HW; + device float * o_ptr = dst + nc * O_HW; + + const int start_h = cur_oh * s1 - p1; + const int bh = MAX(0, start_h); + const int eh = MIN(IH, start_h + k1); + const int start_w = cur_ow * s0 - p0; + const int bw = MAX(0, start_w); + const int ew = MIN(IW, start_w + k0); + + float res = -INFINITY; + + for (int i = bh; i < eh; i += 1) { + for (int j = bw; j < ew; j += 1) { + res = MAX(res, i_ptr[i * IW + j]); + } + } + + o_ptr[cur_oh * OW + cur_ow] = res; +} + +kernel void kernel_pool_2d_avg_f32( + device const float * src0, + device float * dst, + constant int32_t & k0, + constant int32_t & k1, + constant int32_t & s0, + constant int32_t & s1, + constant int32_t & p0, + constant int32_t & p1, + constant int64_t & IH, + constant int64_t & IW, + constant int64_t & OH, + constant int64_t & OW, + constant int64_t & parallel_elements, + uint gid[[thread_position_in_grid]]) { + + if (gid >= parallel_elements) { + return; + } + + const int idx = gid; + const int I_HW = IH * IW; + const int O_HW = OH * OW; + const int nc = idx / O_HW; + const int cur_oh = idx % O_HW / OW; + const int cur_ow = idx % O_HW % OW; + + device const float * i_ptr = src0 + nc * I_HW; + device float * o_ptr = dst + nc * O_HW; + + const int start_h = cur_oh * s1 - p1; + const int bh = MAX(0, start_h); + const int eh = MIN(IH, start_h + k1); + const int start_w = cur_ow * s0 - p0; + const int bw = MAX(0, start_w); + const int ew = MIN(IW, start_w + k0); + // const float scale = 1. / ((eh - bh) * (ew - bw)); + const float scale = 1. / (k0 * k1); + + float res = 0; + + for (int i = bh; i < eh; i += 1) { + for (int j = bw; j < ew; j += 1) { + float cur = i_ptr[i * IW + j]; + res += cur * scale; + } + } + + o_ptr[cur_oh * OW + cur_ow] = res; +} diff --git a/ggml/src/ggml-rpc.cpp b/ggml/src/ggml-rpc.cpp index f95233284d854..0e936b3437e3e 100644 --- a/ggml/src/ggml-rpc.cpp +++ b/ggml/src/ggml-rpc.cpp @@ -57,8 +57,9 @@ struct socket_t { } }; +// all RPC structures must be packed +#pragma pack(push, 1) // ggml_tensor is serialized into rpc_tensor -#pragma pack(1) struct rpc_tensor { uint64_t id; uint32_t type; @@ -95,76 +96,64 @@ enum rpc_cmd { RPC_CMD_COUNT, }; -#pragma pack(1) struct rpc_msg_alloc_buffer_req { uint64_t size; }; -#pragma pack(1) struct rpc_msg_alloc_buffer_rsp { uint64_t remote_ptr; uint64_t remote_size; }; -#pragma pack(1) struct rpc_msg_get_alignment_rsp { uint64_t alignment; }; -#pragma pack(1) struct rpc_msg_get_max_size_rsp { uint64_t max_size; }; -#pragma pack(1) struct rpc_msg_buffer_get_base_req { uint64_t remote_ptr; }; -#pragma pack(1) struct rpc_msg_buffer_get_base_rsp { uint64_t base_ptr; }; -#pragma pack(1) struct rpc_msg_free_buffer_req { uint64_t remote_ptr; }; -#pragma pack(1) struct rpc_msg_buffer_clear_req { uint64_t remote_ptr; uint8_t value; }; -#pragma pack(1) struct rpc_msg_get_tensor_req { rpc_tensor tensor; uint64_t offset; uint64_t size; }; -#pragma pack(1) struct rpc_msg_copy_tensor_req { rpc_tensor src; rpc_tensor dst; }; -#pragma pack(1) struct rpc_msg_copy_tensor_rsp { uint8_t result; }; -#pragma pack(1) struct rpc_msg_graph_compute_rsp { uint8_t result; }; -#pragma pack(1) struct rpc_msg_get_device_memory_rsp { uint64_t free_mem; uint64_t total_mem; }; +#pragma pack(pop) // RPC data structures diff --git a/ggml/src/ggml-sycl/mmvq.cpp b/ggml/src/ggml-sycl/mmvq.cpp index 1b96925e14eba..7b10cf688148e 100644 --- a/ggml/src/ggml-sycl/mmvq.cpp +++ b/ggml/src/ggml-sycl/mmvq.cpp @@ -1,6 +1,6 @@ #include "mmvq.hpp" #include "vecdotq.hpp" - +#include template static void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols, const int nrows, @@ -13,7 +13,8 @@ static void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict_ } const int blocks_per_row = ncols / qk; - const int blocks_per_warp = vdr * WARP_SIZE / qi; + const int blocks_per_warp = vdr * QK_WARP_SIZE / qi; + assert(blocks_per_warp>0); // partial sum for each thread float tmp = 0.0f; @@ -37,7 +38,7 @@ static void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict_ // sum up partial sums and write back result #pragma unroll - for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) { + for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) { tmp += dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); } @@ -61,7 +62,8 @@ static void mul_mat_vec_q_iq2_xxs_q8_1(const void *__restrict__ vx, } const int blocks_per_row = ncols / qk; - const int blocks_per_warp = vdr * WARP_SIZE / qi; + const int blocks_per_warp = vdr * QK_WARP_SIZE / qi; + assert(blocks_per_warp>0); // partial sum for each thread float tmp = 0.0f; @@ -85,7 +87,7 @@ static void mul_mat_vec_q_iq2_xxs_q8_1(const void *__restrict__ vx, // sum up partial sums and write back result #pragma unroll - for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) { + for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) { tmp += dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); } @@ -109,8 +111,8 @@ static void mul_mat_vec_q_iq2_xs_q8_1(const void *__restrict__ vx, } const int blocks_per_row = ncols / qk; - const int blocks_per_warp = vdr * WARP_SIZE / qi; - + const int blocks_per_warp = vdr * QK_WARP_SIZE / qi; + assert(blocks_per_warp>0); // partial sum for each thread float tmp = 0.0f; @@ -133,7 +135,7 @@ static void mul_mat_vec_q_iq2_xs_q8_1(const void *__restrict__ vx, // sum up partial sums and write back result #pragma unroll - for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) { + for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) { tmp += dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); } @@ -157,8 +159,8 @@ static void mul_mat_vec_q_iq2_s_q8_1(const void *__restrict__ vx, } const int blocks_per_row = ncols / qk; - const int blocks_per_warp = vdr * WARP_SIZE / qi; - + const int blocks_per_warp = vdr * QK_WARP_SIZE / qi; + assert(blocks_per_warp>0); // partial sum for each thread float tmp = 0.0f; @@ -181,7 +183,7 @@ static void mul_mat_vec_q_iq2_s_q8_1(const void *__restrict__ vx, // sum up partial sums and write back result #pragma unroll - for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) { + for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) { tmp += dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); } @@ -205,8 +207,8 @@ static void mul_mat_vec_q_iq3_xxs_q8_1(const void *__restrict__ vx, } const int blocks_per_row = ncols / qk; - const int blocks_per_warp = vdr * WARP_SIZE / qi; - + const int blocks_per_warp = vdr * QK_WARP_SIZE / qi; + assert(blocks_per_warp>0); // partial sum for each thread float tmp = 0.0f; @@ -229,7 +231,7 @@ static void mul_mat_vec_q_iq3_xxs_q8_1(const void *__restrict__ vx, // sum up partial sums and write back result #pragma unroll - for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) { + for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) { tmp += dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); } @@ -253,8 +255,8 @@ static void mul_mat_vec_q_iq3_s_q8_1(const void *__restrict__ vx, } const int blocks_per_row = ncols / qk; - const int blocks_per_warp = vdr * WARP_SIZE / qi; - + const int blocks_per_warp = vdr * QK_WARP_SIZE / qi; + assert(blocks_per_warp>0); // partial sum for each thread float tmp = 0.0f; @@ -277,7 +279,7 @@ static void mul_mat_vec_q_iq3_s_q8_1(const void *__restrict__ vx, // sum up partial sums and write back result #pragma unroll - for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) { + for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) { tmp += dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); } @@ -301,8 +303,8 @@ static void mul_mat_vec_q_iq1_s_q8_1(const void *__restrict__ vx, } const int blocks_per_row = ncols / qk; - const int blocks_per_warp = vdr * WARP_SIZE / qi; - + const int blocks_per_warp = vdr * QK_WARP_SIZE / qi; + assert(blocks_per_warp>0); // partial sum for each thread float tmp = 0.0f; @@ -325,7 +327,7 @@ static void mul_mat_vec_q_iq1_s_q8_1(const void *__restrict__ vx, // sum up partial sums and write back result #pragma unroll - for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) { + for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) { tmp += dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); } @@ -349,8 +351,8 @@ static void mul_mat_vec_q_iq1_m_q8_1(const void *__restrict__ vx, } const int blocks_per_row = ncols / qk; - const int blocks_per_warp = vdr * WARP_SIZE / qi; - + const int blocks_per_warp = vdr * QK_WARP_SIZE / qi; + assert(blocks_per_warp>0); // partial sum for each thread float tmp = 0.0f; @@ -373,7 +375,7 @@ static void mul_mat_vec_q_iq1_m_q8_1(const void *__restrict__ vx, // sum up partial sums and write back result #pragma unroll - for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) { + for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) { tmp += dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); } @@ -397,8 +399,8 @@ static void mul_mat_vec_q_iq4_nl_q8_1(const void *__restrict__ vx, } const int blocks_per_row = ncols / qk; - const int blocks_per_warp = vdr * WARP_SIZE / qi; - + const int blocks_per_warp = vdr * QK_WARP_SIZE / qi; + assert(blocks_per_warp>0); // partial sum for each thread float tmp = 0.0f; @@ -421,7 +423,7 @@ static void mul_mat_vec_q_iq4_nl_q8_1(const void *__restrict__ vx, // sum up partial sums and write back result #pragma unroll - for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) { + for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) { tmp += dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); } @@ -446,8 +448,8 @@ static void mul_mat_vec_q_iq4_xs_q8_1(const void *__restrict__ vx, } const int blocks_per_row = ncols / qk; - const int blocks_per_warp = vdr * WARP_SIZE / qi; - + const int blocks_per_warp = vdr * QK_WARP_SIZE / qi; + assert(blocks_per_warp>0); // partial sum for each thread float tmp = 0.0f; @@ -470,7 +472,7 @@ static void mul_mat_vec_q_iq4_xs_q8_1(const void *__restrict__ vx, // sum up partial sums and write back result #pragma unroll - for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) { + for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) { tmp += dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); } @@ -487,7 +489,7 @@ static void mul_mat_vec_q4_0_q8_1_sycl(const void *vx, const void *vy, GGML_ASSERT(ncols % QK4_0 == 0); const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; const sycl::range<3> block_nums(1, 1, block_num_y); - const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE); { stream->submit([&](sycl::handler &cgh) { @@ -495,7 +497,7 @@ static void mul_mat_vec_q4_0_q8_1_sycl(const void *vx, const void *vy, cgh.parallel_for( sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) - [[intel::reqd_sub_group_size(WARP_SIZE)]] { + [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] { mul_mat_vec_q( vx, vy, dst, ncols, nrows, item_ct1); @@ -511,7 +513,7 @@ static void mul_mat_vec_q4_1_q8_1_sycl(const void *vx, const void *vy, GGML_ASSERT(ncols % QK4_1 == 0); const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; const sycl::range<3> block_nums(1, 1, block_num_y); - const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE); { stream->submit([&](sycl::handler &cgh) { @@ -519,7 +521,7 @@ static void mul_mat_vec_q4_1_q8_1_sycl(const void *vx, const void *vy, cgh.parallel_for( sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) - [[intel::reqd_sub_group_size(WARP_SIZE)]] { + [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] { mul_mat_vec_q( vx, vy, dst, ncols, nrows, item_ct1); @@ -535,7 +537,7 @@ static void mul_mat_vec_q5_0_q8_1_sycl(const void *vx, const void *vy, GGML_ASSERT(ncols % QK5_0 == 0); const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; const sycl::range<3> block_nums(1, 1, block_num_y); - const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE); { stream->submit([&](sycl::handler &cgh) { @@ -543,7 +545,7 @@ static void mul_mat_vec_q5_0_q8_1_sycl(const void *vx, const void *vy, cgh.parallel_for( sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) - [[intel::reqd_sub_group_size(WARP_SIZE)]] { + [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] { mul_mat_vec_q( vx, vy, dst, ncols, nrows, item_ct1); @@ -559,7 +561,7 @@ static void mul_mat_vec_q5_1_q8_1_sycl(const void *vx, const void *vy, GGML_ASSERT(ncols % QK5_1 == 0); const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; const sycl::range<3> block_nums(1, 1, block_num_y); - const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE); { stream->submit([&](sycl::handler &cgh) { @@ -567,7 +569,7 @@ static void mul_mat_vec_q5_1_q8_1_sycl(const void *vx, const void *vy, cgh.parallel_for( sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) - [[intel::reqd_sub_group_size(WARP_SIZE)]] { + [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] { mul_mat_vec_q( vx, vy, dst, ncols, nrows, item_ct1); @@ -583,7 +585,7 @@ static void mul_mat_vec_q8_0_q8_1_sycl(const void *vx, const void *vy, GGML_ASSERT(ncols % QK8_0 == 0); const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; const sycl::range<3> block_nums(1, 1, block_num_y); - const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE); { stream->submit([&](sycl::handler &cgh) { @@ -591,7 +593,7 @@ static void mul_mat_vec_q8_0_q8_1_sycl(const void *vx, const void *vy, cgh.parallel_for( sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) - [[intel::reqd_sub_group_size(WARP_SIZE)]] { + [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] { mul_mat_vec_q( vx, vy, dst, ncols, nrows, item_ct1); @@ -607,7 +609,7 @@ static void mul_mat_vec_q2_K_q8_1_sycl(const void *vx, const void *vy, GGML_ASSERT(ncols % QK_K == 0); const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; const sycl::range<3> block_nums(1, 1, block_num_y); - const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE); { stream->submit([&](sycl::handler &cgh) { @@ -615,7 +617,7 @@ static void mul_mat_vec_q2_K_q8_1_sycl(const void *vx, const void *vy, cgh.parallel_for( sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) - [[intel::reqd_sub_group_size(WARP_SIZE)]] { + [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] { mul_mat_vec_q( vx, vy, dst, ncols, nrows, item_ct1); @@ -631,7 +633,7 @@ static void mul_mat_vec_q3_K_q8_1_sycl(const void *vx, const void *vy, GGML_ASSERT(ncols % QK_K == 0); const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; const sycl::range<3> block_nums(1, 1, block_num_y); - const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE); { stream->submit([&](sycl::handler &cgh) { @@ -639,7 +641,7 @@ static void mul_mat_vec_q3_K_q8_1_sycl(const void *vx, const void *vy, cgh.parallel_for( sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) - [[intel::reqd_sub_group_size(WARP_SIZE)]] { + [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] { mul_mat_vec_q( vx, vy, dst, ncols, nrows, item_ct1); @@ -655,7 +657,7 @@ static void mul_mat_vec_q4_K_q8_1_sycl(const void *vx, const void *vy, GGML_ASSERT(ncols % QK_K == 0); const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; const sycl::range<3> block_nums(1, 1, block_num_y); - const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE); { stream->submit([&](sycl::handler &cgh) { @@ -663,7 +665,7 @@ static void mul_mat_vec_q4_K_q8_1_sycl(const void *vx, const void *vy, cgh.parallel_for( sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) - [[intel::reqd_sub_group_size(WARP_SIZE)]] { + [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] { mul_mat_vec_q( vx, vy, dst, ncols, nrows, item_ct1); @@ -679,7 +681,7 @@ static void mul_mat_vec_q5_K_q8_1_sycl(const void *vx, const void *vy, GGML_ASSERT(ncols % QK_K == 0); const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; const sycl::range<3> block_nums(1, 1, block_num_y); - const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE); { stream->submit([&](sycl::handler &cgh) { @@ -687,7 +689,7 @@ static void mul_mat_vec_q5_K_q8_1_sycl(const void *vx, const void *vy, cgh.parallel_for( sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) - [[intel::reqd_sub_group_size(WARP_SIZE)]] { + [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] { mul_mat_vec_q( vx, vy, dst, ncols, nrows, item_ct1); @@ -703,7 +705,7 @@ static void mul_mat_vec_q6_K_q8_1_sycl(const void *vx, const void *vy, GGML_ASSERT(ncols % QK_K == 0); const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; const sycl::range<3> block_nums(1, 1, block_num_y); - const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE); { stream->submit([&](sycl::handler &cgh) { @@ -711,7 +713,7 @@ static void mul_mat_vec_q6_K_q8_1_sycl(const void *vx, const void *vy, cgh.parallel_for( sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) - [[intel::reqd_sub_group_size(WARP_SIZE)]] { + [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] { mul_mat_vec_q( vx, vy, dst, ncols, nrows, item_ct1); @@ -728,13 +730,13 @@ static void mul_mat_vec_iq2_xxs_q8_1_sycl(const void *vx, const void *vy, GGML_ASSERT(ncols % QK_K == 0); const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; const sycl::range<3> block_nums(1, 1, block_num_y); - const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE); { stream->submit([&](sycl::handler &cgh) { cgh.parallel_for( sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) - [[intel::reqd_sub_group_size(WARP_SIZE)]] { + [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] { mul_mat_vec_q_iq2_xxs_q8_1( vx, vy, dst, ncols, nrows, item_ct1); }); @@ -749,7 +751,7 @@ static void mul_mat_vec_iq2_xs_q8_1_sycl(const void *vx, const void *vy, GGML_ASSERT(ncols % QK_K == 0); const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; const sycl::range<3> block_nums(1, 1, block_num_y); - const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE); { stream->submit([&](sycl::handler &cgh) { @@ -759,7 +761,7 @@ static void mul_mat_vec_iq2_xs_q8_1_sycl(const void *vx, const void *vy, cgh.parallel_for( sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) - [[intel::reqd_sub_group_size(WARP_SIZE)]] { + [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] { mul_mat_vec_q_iq2_xs_q8_1( vx, vy, dst, ncols, nrows, item_ct1); }); @@ -774,7 +776,7 @@ static void mul_mat_vec_iq2_s_q8_1_sycl(const void *vx, const void *vy, GGML_ASSERT(ncols % QK_K == 0); const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; const sycl::range<3> block_nums(1, 1, block_num_y); - const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE); { stream->submit([&](sycl::handler &cgh) { @@ -784,7 +786,7 @@ static void mul_mat_vec_iq2_s_q8_1_sycl(const void *vx, const void *vy, cgh.parallel_for( sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) - [[intel::reqd_sub_group_size(WARP_SIZE)]] { + [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] { mul_mat_vec_q_iq2_s_q8_1( vx, vy, dst, ncols, nrows, item_ct1); }); @@ -799,7 +801,7 @@ static void mul_mat_vec_iq3_xxs_q8_1_sycl(const void *vx, const void *vy, GGML_ASSERT(ncols % QK_K == 0); const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; const sycl::range<3> block_nums(1, 1, block_num_y); - const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE); { stream->submit([&](sycl::handler &cgh) { @@ -809,7 +811,7 @@ static void mul_mat_vec_iq3_xxs_q8_1_sycl(const void *vx, const void *vy, cgh.parallel_for( sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) - [[intel::reqd_sub_group_size(WARP_SIZE)]] { + [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] { mul_mat_vec_q_iq3_xxs_q8_1( vx, vy, dst, ncols, nrows, item_ct1); }); @@ -824,7 +826,7 @@ static void mul_mat_vec_iq3_s_q8_1_sycl(const void *vx, const void *vy, GGML_ASSERT(ncols % QK_K == 0); const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; const sycl::range<3> block_nums(1, 1, block_num_y); - const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE); { stream->submit([&](sycl::handler &cgh) { @@ -833,7 +835,7 @@ static void mul_mat_vec_iq3_s_q8_1_sycl(const void *vx, const void *vy, cgh.parallel_for( sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) - [[intel::reqd_sub_group_size(WARP_SIZE)]] { + [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] { mul_mat_vec_q_iq3_s_q8_1( vx, vy, dst, ncols, nrows, item_ct1); }); @@ -848,7 +850,7 @@ static void mul_mat_vec_iq1_s_q8_1_sycl(const void *vx, const void *vy, GGML_ASSERT(ncols % QK_K == 0); const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; const sycl::range<3> block_nums(1, 1, block_num_y); - const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE); { stream->submit([&](sycl::handler &cgh) { @@ -858,7 +860,7 @@ static void mul_mat_vec_iq1_s_q8_1_sycl(const void *vx, const void *vy, cgh.parallel_for( sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) - [[intel::reqd_sub_group_size(WARP_SIZE)]] { + [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] { mul_mat_vec_q_iq1_s_q8_1( vx, vy, dst, ncols, nrows, item_ct1); }); @@ -873,13 +875,13 @@ static void mul_mat_vec_iq1_m_q8_1_sycl(const void *vx, const void *vy, GGML_ASSERT(ncols % QK_K == 0); const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; const sycl::range<3> block_nums(1, 1, block_num_y); - const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE); { stream->submit([&](sycl::handler &cgh) { cgh.parallel_for( sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) - [[intel::reqd_sub_group_size(WARP_SIZE)]] { + [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] { mul_mat_vec_q_iq1_m_q8_1( vx, vy, dst, ncols, nrows, item_ct1); }); @@ -894,14 +896,14 @@ static void mul_mat_vec_iq4_nl_q8_1_sycl(const void *vx, const void *vy, GGML_ASSERT(ncols % QK4_NL == 0); const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; const sycl::range<3> block_nums(1, 1, block_num_y); - const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE); { stream->submit([&](sycl::handler &cgh) { cgh.parallel_for( sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) - [[intel::reqd_sub_group_size(WARP_SIZE)]] { + [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] { mul_mat_vec_q_iq4_nl_q8_1( vx, vy, dst, ncols, nrows, item_ct1); }); @@ -916,14 +918,14 @@ static void mul_mat_vec_iq4_xs_q8_1_sycl(const void *vx, const void *vy, GGML_ASSERT(ncols % QK_K == 0); const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; const sycl::range<3> block_nums(1, 1, block_num_y); - const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE); { stream->submit([&](sycl::handler &cgh) { cgh.parallel_for( sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) - [[intel::reqd_sub_group_size(WARP_SIZE)]] { + [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] { mul_mat_vec_q_iq4_xs_q8_1( vx, vy, dst, ncols, nrows, item_ct1); }); diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 7e24313edb269..66df9a9c1e621 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -324,8 +324,9 @@ struct ggml_logger_state { static struct ggml_logger_state g_logger_state = {ggml_log_callback_default, NULL}; static void ggml_log_internal_v(enum ggml_log_level level, const char * format, va_list args) { - if (format == NULL) + if (format == NULL) { return; + } va_list args_copy; va_copy(args_copy, args); char buffer[128]; @@ -3463,7 +3464,7 @@ int64_t ggml_nrows(const struct ggml_tensor * tensor) { size_t ggml_nbytes(const struct ggml_tensor * tensor) { size_t nbytes; - size_t blck_size = ggml_blck_size(tensor->type); + const size_t blck_size = ggml_blck_size(tensor->type); if (blck_size == 1) { nbytes = ggml_type_size(tensor->type); for (int i = 0; i < GGML_MAX_DIMS; ++i) { @@ -3851,10 +3852,6 @@ struct ggml_context * ggml_init(struct ggml_init_params params) { }, }; - for (int i = 0; i < GGML_MAX_CONTEXTS; ++i) { - g_state.contexts[i].used = false; - } - const uint64_t t_end = ggml_time_us(); UNUSED(t_end); GGML_PRINT_DEBUG("%s: g_state initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f); @@ -15723,6 +15720,9 @@ static void ggml_compute_forward_flash_attn_ext_f16( ggml_vec_dot_t const kq_vec_dot = type_traits[k->type].vec_dot; ggml_to_float_t const v_to_float = type_traits[v->type].to_float; + GGML_ASSERT(q_to_vec_dot && "fattn: unsupported K-type"); + GGML_ASSERT(v_to_float && "fattn: unsupported V-type"); + // loop over n_batch and n_head for (int ir = ir0; ir < ir1; ++ir) { // q indices diff --git a/ggml/src/llamafile/sgemm.cpp b/ggml/src/llamafile/sgemm.cpp index 0193a463aefec..9eead3f61e090 100644 --- a/ggml/src/llamafile/sgemm.cpp +++ b/ggml/src/llamafile/sgemm.cpp @@ -942,6 +942,36 @@ class tinyBLAS_Q0_AVX { return _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(x, 4)), _mm_set1_epi8(8)); } + inline __m256i load(const block_q5_0 *b) { + return _mm256_or_si256(denibble(b->qs), bittobyte(b->qh)); + } + + inline __m128i load0(const block_q5_0* b) { + const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs)); + uint32_t x32; + memcpy(&x32, b->qh, sizeof(uint32_t)); + __m128i qxl = _mm_and_si128(_mm_set1_epi8(15), x); + __m128i bytesl = _mm_cmpeq_epi8(_mm_set1_epi64x(-1), + _mm_or_si128(_mm_set1_epi64x(0x7fbfdfeff7fbfdfe), + _mm_shuffle_epi8(_mm_set1_epi32(x32), + _mm_set_epi64x(0x0101010101010101, 0x0000000000000000)))); + bytesl = _mm_andnot_si128(bytesl, _mm_set1_epi8((char)0xF0)); + return _mm_or_si128(qxl, bytesl); + } + + inline __m128i load1(const block_q5_0* b) { + const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs)); + uint32_t x32; + memcpy(&x32, b->qh, sizeof(uint32_t)); + __m128i qxh = _mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(x, 4)); + __m128i bytesh = _mm_cmpeq_epi8(_mm_set1_epi64x(-1), + _mm_or_si128(_mm_set1_epi64x(0x7fbfdfeff7fbfdfe), + _mm_shuffle_epi8(_mm_set1_epi32(x32), + _mm_set_epi64x(0x0303030303030303, 0x0202020202020202)))); + bytesh = _mm_andnot_si128(bytesh, _mm_set1_epi8((char)0xF0)); + return _mm_or_si128(qxh, bytesh); + } + inline __m256i load(const block_iq4_nl *b) { return MM256_SET_M128I(load1(b), load0(b)); } @@ -973,6 +1003,17 @@ class tinyBLAS_Q0_AVX { _mm_srli_epi16(x, 4), 1)); } + static inline __m256i bittobyte(const uint8_t *p) { + uint32_t x32; + memcpy(&x32, p, sizeof(uint32_t)); + __m256i bytes = _mm256_cmpeq_epi8(_mm256_set1_epi64x(-1), + _mm256_or_si256(_mm256_set1_epi64x(0x7fbfdfeff7fbfdfe), + _mm256_shuffle_epi8(_mm256_set1_epi32(x32), + _mm256_set_epi64x(0x0303030303030303, 0x0202020202020202, + 0x0101010101010101, 0x0000000000000000)))); + return _mm256_andnot_si256(bytes, _mm256_set1_epi8((char)0xF0)); + } + const TA *const A; const TB *const B; TC *const C; @@ -1182,6 +1223,22 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda #endif } + case GGML_TYPE_Q5_0: { + if (Btype != GGML_TYPE_Q8_0) + return false; +#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__) + tinyBLAS_Q0_AVX tb{ + k, (const block_q5_0 *)A, lda, + (const block_q8_0 *)B, ldb, + (float *)C, ldc, + ith, nth}; + tb.matmul(m, n); + return true; +#else + return false; +#endif + } + case GGML_TYPE_IQ4_NL: { if (Btype != GGML_TYPE_Q8_0) return false; diff --git a/include/llama.h b/include/llama.h index 2558e9267905b..b2d1e7d5ae16b 100644 --- a/include/llama.h +++ b/include/llama.h @@ -217,6 +217,7 @@ extern "C" { typedef struct llama_token_data_array { // TODO: consider SoA + // NOTE: this pointer can be modified by the samplers llama_token_data * data; size_t size; int64_t selected; // this is the index in the data array (i.e. not the token id) @@ -1069,12 +1070,13 @@ extern "C" { // available samplers: - LLAMA_API struct llama_sampler * llama_sampler_init_greedy (void); - LLAMA_API struct llama_sampler * llama_sampler_init_dist (uint32_t seed); + LLAMA_API struct llama_sampler * llama_sampler_init_greedy(void); + LLAMA_API struct llama_sampler * llama_sampler_init_dist (uint32_t seed); /// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits. /// NOTE: Avoid using on the full vocabulary as the sorting can become slow. For example, apply top-k or top-p sampling first. - LLAMA_API struct llama_sampler * llama_sampler_init_softmax (void); + DEPRECATED(LLAMA_API struct llama_sampler * llama_sampler_init_softmax (void), + "will be removed in the future (see https://github.com/ggerganov/llama.cpp/pull/9896#discussion_r1800920915)"); /// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 LLAMA_API struct llama_sampler * llama_sampler_init_top_k (int32_t k); @@ -1090,6 +1092,8 @@ extern "C" { /// @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666. LLAMA_API struct llama_sampler * llama_sampler_init_typical (float p, size_t min_keep); + + /// #details Updates the logits l_i` = l_i/t. When t <= 0.0f, the maximum logit is kept at it's original value, the rest are set to -inf LLAMA_API struct llama_sampler * llama_sampler_init_temp (float t); /// @details Dynamic temperature implementation (a.k.a. entropy) described in the paper https://arxiv.org/abs/2309.02772. @@ -1137,6 +1141,16 @@ extern "C" { bool penalize_nl, // consider newlines as a repeatable token bool ignore_eos); // ignore the end-of-sequence token + /// @details DRY sampler, designed by p-e-w, as described in: https://github.com/oobabooga/text-generation-webui/pull/5677, porting Koboldcpp implementation authored by pi6am: https://github.com/LostRuins/koboldcpp/pull/982 + LLAMA_API struct llama_sampler * llama_sampler_init_dry( + const struct llama_model * model, + float dry_multiplier, + float dry_base, + int32_t dry_allowed_length, + int32_t dry_penalty_last_n, + const char ** seq_breakers, + size_t num_breakers); + LLAMA_API struct llama_sampler * llama_sampler_init_logit_bias( int32_t n_vocab, int32_t n_logit_bias, diff --git a/scripts/sync-ggml-am.sh b/scripts/sync-ggml-am.sh index ffce2aab0918e..fba29b9352e68 100755 --- a/scripts/sync-ggml-am.sh +++ b/scripts/sync-ggml-am.sh @@ -76,6 +76,7 @@ while read c; do src/ggml*.m \ src/ggml*.metal \ src/ggml*.cu \ + src/ggml-amx/* \ src/ggml-cann/* \ src/ggml-cuda/* \ src/ggml-sycl/* \ @@ -121,6 +122,8 @@ if [ -f $SRC_LLAMA/ggml-src.patch ]; then # src/ggml-aarch64.c -> ggml/src/ggml-aarch64.c # src/ggml-aarch64.h -> ggml/src/ggml-aarch64.h # src/ggml-alloc.c -> ggml/src/ggml-alloc.c + # src/ggml-amx/* -> ggml/src/ggml-amx/ + # src/ggml-amx.cpp -> ggml/src/ggml-amx.cpp # src/ggml-backend-impl.h -> ggml/src/ggml-backend-impl.h # src/ggml-backend.cpp -> ggml/src/ggml-backend.cpp # src/ggml-cann/* -> ggml/src/ggml-cann/ @@ -141,6 +144,7 @@ if [ -f $SRC_LLAMA/ggml-src.patch ]; then # # include/ggml.h -> ggml/include/ggml.h # include/ggml-alloc.h -> ggml/include/ggml-alloc.h + # include/ggml-amx.h -> ggml/include/ggml-amx.h # include/ggml-backend.h -> ggml/include/ggml-backend.h # include/ggml-blas.h -> ggml/include/ggml-blas.h # include/ggml-cann.h -> ggml/include/ggml-cann.h @@ -168,6 +172,8 @@ if [ -f $SRC_LLAMA/ggml-src.patch ]; then -e 's/([[:space:]]|[ab]\/)src\/ggml-aarch64\.c/\1ggml\/src\/ggml-aarch64.c/g' \ -e 's/([[:space:]]|[ab]\/)src\/ggml-aarch64\.h/\1ggml\/src\/ggml-aarch64.h/g' \ -e 's/([[:space:]]|[ab]\/)src\/ggml-alloc\.c/\1ggml\/src\/ggml-alloc.c/g' \ + -e 's/([[:space:]]|[ab]\/)src\/ggml-amx\//\1ggml\/src\/ggml-amx\//g' \ + -e 's/([[:space:]]|[ab]\/)src\/ggml-amx\.cpp/\1ggml\/src\/ggml-amx.cpp/g' \ -e 's/([[:space:]]|[ab]\/)src\/ggml-backend-impl\.h/\1ggml\/src\/ggml-backend-impl.h/g' \ -e 's/([[:space:]]|[ab]\/)src\/ggml-backend\.cpp/\1ggml\/src\/ggml-backend.cpp/g' \ -e 's/([[:space:]]|[ab]\/)src\/ggml-cann\//\1ggml\/src\/ggml-cann\//g' \ @@ -187,6 +193,7 @@ if [ -f $SRC_LLAMA/ggml-src.patch ]; then -e 's/([[:space:]]|[ab]\/)src\/vulkan-shaders\//\1ggml\/src\/vulkan-shaders\//g' \ -e 's/([[:space:]]|[ab]\/)include\/ggml\.h/\1ggml\/include\/ggml.h/g' \ -e 's/([[:space:]]|[ab]\/)include\/ggml-alloc\.h/\1ggml\/include\/ggml-alloc.h/g' \ + -e 's/([[:space:]]|[ab]\/)include\/ggml-amx\.h/\1ggml\/include\/ggml-amx.h/g' \ -e 's/([[:space:]]|[ab]\/)include\/ggml-backend\.h/\1ggml\/include\/ggml-backend.h/g' \ -e 's/([[:space:]]|[ab]\/)include\/ggml-blas\.h/\1ggml\/include\/ggml-blas.h/g' \ -e 's/([[:space:]]|[ab]\/)include\/ggml-cann\.h/\1ggml\/include\/ggml-cann.h/g' \ diff --git a/scripts/sync-ggml.last b/scripts/sync-ggml.last index 6d31b21b9df35..da40927e196cd 100644 --- a/scripts/sync-ggml.last +++ b/scripts/sync-ggml.last @@ -1 +1 @@ -2327bda7a55ac6b72614ac5ebd5c5a5e02553b9b +162e232411ee98ceb0cccfa84886118d917d2123 diff --git a/scripts/sync-ggml.sh b/scripts/sync-ggml.sh index f6ff5e68354f1..f5d87324ab366 100755 --- a/scripts/sync-ggml.sh +++ b/scripts/sync-ggml.sh @@ -8,6 +8,8 @@ cp -rpv ../ggml/src/ggml.c ./ggml/src/ggml.c cp -rpv ../ggml/src/ggml-aarch64.c ./ggml/src/ggml-aarch64.c cp -rpv ../ggml/src/ggml-aarch64.h ./ggml/src/ggml-aarch64.h cp -rpv ../ggml/src/ggml-alloc.c ./ggml/src/ggml-alloc.c +cp -rpv ../ggml/src/ggml-amx/* ./ggml/src/ggml-amx/ +cp -rpv ../ggml/src/ggml-amx.cpp ./ggml/src/ggml-amx.cpp cp -rpv ../ggml/src/ggml-backend-impl.h ./ggml/src/ggml-backend-impl.h cp -rpv ../ggml/src/ggml-backend.cpp ./ggml/src/ggml-backend.cpp cp -rpv ../ggml/src/ggml-cann/* ./ggml/src/ggml-cann/ @@ -29,6 +31,7 @@ cp -rpv ../ggml/src/vulkan-shaders/* ./ggml/src/vulkan-shaders/ cp -rpv ../ggml/include/ggml.h ./ggml/include/ggml.h cp -rpv ../ggml/include/ggml-alloc.h ./ggml/include/ggml-alloc.h +cp -rpv ../ggml/include/ggml-amx.h ./ggml/include/ggml-amx.h cp -rpv ../ggml/include/ggml-backend.h ./ggml/include/ggml-backend.h cp -rpv ../ggml/include/ggml-blas.h ./ggml/include/ggml-blas.h cp -rpv ../ggml/include/ggml-cann.h ./ggml/include/ggml-cann.h diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index bd750c40ec651..25536eb6c5a06 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -63,6 +63,30 @@ static void llama_log_softmax(float * array, size_t size) { } */ +static void llama_sampler_temp_impl(llama_token_data_array * cur_p, float temp) { + if (temp <= 0.0f) { + // find the token with the highest logit and set the rest to -inf + size_t max_i = 0; + float max_l = cur_p->data[0].logit; + + for (size_t i = 1; i < cur_p->size; ++i) { + if (cur_p->data[i ].logit > max_l) { + cur_p->data[max_i].logit = -INFINITY; + max_i = i; + max_l = cur_p->data[i].logit; + } else { + cur_p->data[i].logit = -INFINITY; + } + } + + return; + } + + for (size_t i = 0; i < cur_p->size; ++i) { + cur_p->data[i].logit /= temp; + } +} + static void llama_sampler_softmax_impl(llama_token_data_array * cur_p) { GGML_ASSERT(cur_p->size > 0); @@ -427,6 +451,9 @@ static const char * llama_sampler_dist_name(const struct llama_sampler * /*smpl* static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { auto * ctx = (llama_sampler_dist *) smpl->ctx; + + llama_sampler_softmax_impl(cur_p); + cur_p->selected = llama_sample_dist(cur_p, ctx->rng); } @@ -912,9 +939,8 @@ static const char * llama_sampler_temp_name(const struct llama_sampler * /*smpl* static void llama_sampler_temp_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { const auto * ctx = (llama_sampler_temp *) smpl->ctx; - for (size_t i = 0; i < cur_p->size; ++i) { - cur_p->data[i].logit /= ctx->temp; - } + + llama_sampler_temp_impl(cur_p, ctx->temp); } static struct llama_sampler * llama_sampler_temp_clone(const struct llama_sampler * smpl) { @@ -961,6 +987,7 @@ static void llama_sampler_temp_ext_apply(struct llama_sampler * smpl, llama_toke if (ctx->delta > 0) { const float min_temp = std::max(0.0f, ctx->temp - ctx->delta); const float max_temp = ctx->temp + ctx->delta; + float exponent_val = ctx->exponent; // no need to do anything if there is only one (or zero) candidates @@ -998,9 +1025,7 @@ static void llama_sampler_temp_ext_apply(struct llama_sampler * smpl, llama_toke #endif // Apply the dynamically calculated temperature scaling - for (size_t i = 0; i < cur_p->size; ++i) { - cur_p->data[i].logit /= dyn_temp; - } + llama_sampler_temp_impl(cur_p, dyn_temp); // Re-compute softmax probabilities after scaling logits with dynamic temperature const double max_l_double = cur_p->data[0].logit; @@ -1024,9 +1049,7 @@ static void llama_sampler_temp_ext_apply(struct llama_sampler * smpl, llama_toke } #endif } else { - for (size_t i = 0; i < cur_p->size; ++i) { - cur_p->data[i].logit /= ctx->temp; - } + llama_sampler_temp_impl(cur_p, ctx->temp); } } @@ -1660,6 +1683,397 @@ struct llama_sampler * llama_sampler_init_penalties( }; } +// DRY + +struct llama_sampler_dry { + int32_t total_context_size; + + const float dry_multiplier; + const float dry_base; + const int32_t dry_allowed_length; + const int32_t dry_penalty_last_n; + + std::unordered_multimap> dry_processed_breakers; + std::vector dry_repeat_count; + std::unordered_map dry_max_token_repeat; + ring_buffer last_tokens; +}; + +// Ported from Koboldcpp, original PR: https://github.com/LostRuins/koboldcpp/pull/982 (Original author: pi6am) +static void get_overlapping_token_sequences(const llama_vocab & vocab, const std::string& str, std::unordered_multimap>& token_sequences, int max_tail_len = -1) { + for (llama_token token_id = 0; token_id < (llama_token)vocab.n_vocab; token_id++) { + std::string word = llama_detokenize(vocab, {token_id}, true); + if (word.find(str) != std::string::npos) { + token_sequences.emplace(token_id, std::vector()); + } else { + size_t word_len = word.size(), str_len = str.size(); + size_t pos = -1; + while ((pos = word.find(str[0], pos + 1)) != std::string::npos) { + bool match = true; + size_t i; + for (i = 1; i < str_len && i + pos < word_len; ++i) { + if (word[pos + i] != str[i]) { + match = false; + break; + } + } + if (match) { + std::vector tokenization = llama_tokenize_internal(vocab, str.substr(i), false, false); + if (max_tail_len >= 0 && tokenization.size() > (size_t)max_tail_len) { + tokenization.resize(max_tail_len); + } + + // Ensure we don't already have a duplicate matching tokenization + auto its = token_sequences.equal_range(token_id); + bool found = false; + for (auto it = its.first; it != its.second; ++it) { + if (tokenization == it->second) { + found = true; + break; + } + } + if (!found) { + token_sequences.emplace(token_id, tokenization); + } + } + } + } + } +} + +static const char * llama_sampler_dry_name(const struct llama_sampler * /*smpl*/) { + return "dry"; +} + +static void llama_sampler_dry_accept(struct llama_sampler * smpl, llama_token token) { + auto * ctx = (llama_sampler_dry *) smpl->ctx; + if (ctx->dry_multiplier == 0.0f || ctx->dry_base < 1.0f || ctx->dry_penalty_last_n == 0) { + return; + } + + ctx->last_tokens.push_back(token); +} + +// Ported from Koboldcpp, original PR: https://github.com/LostRuins/koboldcpp/pull/982 (Original author: pi6am) +static void llama_sampler_dry_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { + auto * ctx = (llama_sampler_dry *) smpl->ctx; + + if (ctx->dry_multiplier == 0.0f || ctx->dry_base < 1.0f || ctx->dry_penalty_last_n == 0) { + return; + } + + int32_t effective_dry_penalty_last_n = (ctx->dry_penalty_last_n == -1) ? ctx->total_context_size : std::max(ctx->dry_penalty_last_n, 0); + int last_n_repeat = std::min(std::min((int)ctx->last_tokens.size(), effective_dry_penalty_last_n), ctx->total_context_size); + + if (last_n_repeat <= ctx->dry_allowed_length) { + return; + } + + ctx->dry_repeat_count.assign(last_n_repeat, 0); + ctx->dry_max_token_repeat.clear(); + + // Step 1: Look for restart sequences to limit the maximum repetition length. + // Work backwards through the context looking for any token that begins a restart sequence. + // + // The collection `restart_sequences` is a mapping from a "head" token to all "tail" + // sequences that together comprise a restart sequence. This allows us to quickly check + // whether each token is the head of a complete sequence. Most restart sequences are actually + // a single token, and for these the "tail" is an empty vector. + // + // If the token is a "head", test all restart sequences that begin with this token + // (there will often only be one sequence for each token, but if sequences like 'aaaq1' and + // 'aaa1' are used as restart strings, both could start with 'aaa' when tokenized). The + // longest matching sequence (if any) is used to limit the maximum repetition length. + // + // Note that in the case case of a short sequence contained in a longer one, this might fail to + // find the smallest value for `rep_limit`. For example, if 'amniotic' and 'ni' are both used as + // restart sequences, 'ni' will be found first, and since it's shorter it will fail to suppress + // 'otic'. This is a minor issue since fully contained restart sequences are likely to be rare. + // + // This is theoretically worst-case O(N^2) for arbitrary restart sequences, which is why we + // have already clamped the maximum tail sequence length when generating `restart_sequences`. + // With clamping, this scan is O(N) in the context length. + + int rep_limit = last_n_repeat; + for (int i = 0; i < last_n_repeat; ++i) { + llama_token token = ctx->last_tokens.rat(i); + auto its = ctx->dry_processed_breakers.equal_range(token); + if (its.first == ctx->dry_processed_breakers.end()) { + continue; + } + int longest_match = -1; + for (auto it = its.first; it != its.second; ++it) { + // Note that (*it) does not contain the head character, so seq_len will be + // the restart sequence length minus 1. + // In the common case of a single-token restart sequence, (*it) will be empty + // and we will trivially match. + int seq_len = (int)it->second.size(); + if (seq_len > longest_match && seq_len <= (int)i) { + bool match = true; + for (int offset = 0; offset < seq_len; ++offset) { + // The -1 when indexing `last_tokens` is because we already matched the head. + if (it->second[offset] != ctx->last_tokens.rat(i - offset - 1)) { + match = false; + break; + } + } + if (match) { + longest_match = seq_len; + } + } + } + if (longest_match >= 0) { + // We found a restart sequence starting `i` tokens from the end and continuing for + // `longest_match` tokens. + rep_limit = i - longest_match; + break; + } + } + if (rep_limit < ctx->dry_allowed_length) { + return; + } + + // Step 2: Iterate in reverse over the last N tokens of the context, using the "Z-algorithm" (in + // the reverse direction) to efficiently compute the positions and lengths of suffixes appearing + // elsewhere in the context. We limit the suffix length to `rep_limit` to respect restart sequences. + // + // This algorithm is not currently documented on Wikipedia, but there is a clear description here: + // https://ivanyu.me/blog/2014/10/15/z-algorithm/ + // + // The code below is adapted from the public domain implementation by the same author here: + // https://github.com/ivanyu/string-algorithms/blob/master/z_algorithm.py + // + // Example: + // Last N tokens: a b c c b c y a b c + // Repeat counts: 0 0 3 1 0 2 0 0 0 0 + // ^ + // This `3` means that the last three tokens of the context (a b c) also appear here. + // + // This step is worst case O(N) since the Z-algorithm is linear, despite the appearance of nested + // for/while loops. This can be seen by observing that the `lt` and `rt` bounds are set after each + // repeated suffix is detected (i.e. after each while loop when n > 0). These bound variables + // ensure that the inner while loops only examine each token in the context once as the outer + // for loop iterates over the context. + + { + const int last = last_n_repeat - 1; + int rt = 0, lt = 0; + + for (int k = 1; k < last_n_repeat; ++k) { + if (k > rt) { + // If k is outside the current Z-box, do naive computation. + int n = 0; + while (n + k < last_n_repeat && ctx->last_tokens.rat(n) == ctx->last_tokens.rat(n+k)) { + ++n; + } + ctx->dry_repeat_count[last - k] = std::min(n, rep_limit); + if (n > 0) { + lt = k; + rt = k+n-1; + } + } else { + // If k is inside the current Z-box, consider two cases. + + int p = k - lt; // Pair index. + int right_part_len = rt - k + 1; + + if (ctx->dry_repeat_count[last - p] < right_part_len) { + int n = std::min(ctx->dry_repeat_count[last - p], rep_limit); + ctx->dry_repeat_count[last - k] = n; + } else { + int i = rt + 1; + while (i < last_n_repeat && ctx->last_tokens.rat(i) == ctx->last_tokens.rat(i - k)) { + i += 1; + } + + int n = std::min(i - k, rep_limit); + ctx->dry_repeat_count[last - k] = n; + lt = k; + rt = i - 1; + } + } + } + } + + // Step 3: Iterate over dry_repeat_count and last_tokens, examining the maximum repeat length + // that would be generated by emitting each new token that would extend a sequence. + // + // Following the same example as above: + // Last N tokens: a b c c b c y a b c + // Repeat counts: 0 0 3 1 0 2 0 0 0 0 + // + // For each non-zero, look ahead one token. This token, if emitted, would extend the repetition. + // c: 3 -> 4 (from `a b c` to `a b c c`) + // b: 1 -> 2 (from `c` to `c b`) + // y: 2 -> 3 (from `b c` to `b c y`) + + for (int i = 0; i < last_n_repeat - 1; ++i) { + int repeat_len = ctx->dry_repeat_count[i]; + if (repeat_len >= ctx->dry_allowed_length) { + // This token ends a repeat, so the next token would continue one. + // By convention, the value of `repeat_len` only includes the tokens currently + // in the context, not the new token that would be added. + llama_token token = ctx->last_tokens.rat(last_n_repeat - 2 - i); + // Track the maximum sequence ending in this token. + const auto& it = ctx->dry_max_token_repeat.find(token); + if (it == ctx->dry_max_token_repeat.end() || it->second < repeat_len) { + ctx->dry_max_token_repeat[token] = repeat_len; + } + } + } + + // Step 4: Apply logit penalties based on the maximum repeat length for relevant tokens. + + // Prevent floating point overflow in `pow(penalty_base, exponent)` by clamping to `max_exponent`. + // Compute it from `penalty_base` and the approximate log of `std::numeric_limits::max()` + const float FLOAT_MAX_LOG = 88.7228391f; + int max_exponent = 0; + if (ctx->dry_base > 1.000001f) { + max_exponent = FLOAT_MAX_LOG / std::log(ctx->dry_base); + } + + for (size_t i = 0; i < cur_p->size; ++i) { + const auto& af_kvp = ctx->dry_max_token_repeat.find(cur_p->data[i].id); + if (af_kvp != ctx->dry_max_token_repeat.end()) { + // Check all sequence breakers starting with this token + auto range = ctx->dry_processed_breakers.equal_range(cur_p->data[i].id); + bool is_single_token_breaker = false; + + for (auto it = range.first; it != range.second; ++it) { + if (it->second.empty()) { + is_single_token_breaker = true; + break; + } + } + + // Apply penalty only if it's not a single-token sequence breaker + if (!is_single_token_breaker) { + int repeat_exp = af_kvp->second - ctx->dry_allowed_length; + if (max_exponent > 0 && repeat_exp > max_exponent) { + repeat_exp = max_exponent; + } + float penalty = ctx->dry_multiplier * std::pow(ctx->dry_base, repeat_exp); + cur_p->data[i].logit -= penalty; + } + } + } + + cur_p->sorted = false; +} + +static void llama_sampler_dry_reset(struct llama_sampler * smpl) { + auto * ctx = (llama_sampler_dry *) smpl->ctx; + ctx->last_tokens.clear(); + ctx->dry_repeat_count.clear(); + ctx->dry_max_token_repeat.clear(); +} + +static struct llama_sampler * llama_sampler_dry_clone(const struct llama_sampler * smpl) { + const auto * ctx = (llama_sampler_dry *) smpl->ctx; + + // nullptr is passed as vocab because it is only needed for raw sequence breaker processing, which we have already done and will be copying + auto * result = llama_sampler_init_dry(nullptr, ctx->dry_multiplier, ctx->dry_base, ctx->dry_allowed_length, ctx->dry_penalty_last_n, NULL, 0); + // Copy the state, including the processed breakers + { + auto * result_ctx = (llama_sampler_dry *) result->ctx; + result_ctx->dry_processed_breakers = ctx->dry_processed_breakers; + result_ctx->dry_repeat_count = ctx->dry_repeat_count; + result_ctx->dry_max_token_repeat = ctx->dry_max_token_repeat; + result_ctx->last_tokens = ctx->last_tokens; + } + + return result; +} + +static void llama_sampler_dry_free(struct llama_sampler * smpl) { + delete (llama_sampler_dry *) smpl->ctx; +} + +static struct llama_sampler_i llama_sampler_dry_i = { + /* .name = */ llama_sampler_dry_name, + /* .accept = */ llama_sampler_dry_accept, + /* .apply = */ llama_sampler_dry_apply, + /* .reset = */ llama_sampler_dry_reset, + /* .clone = */ llama_sampler_dry_clone, + /* .free = */ llama_sampler_dry_free, +}; + +struct llama_sampler * llama_sampler_init_dry_impl(const struct llama_vocab & vocab, int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const char** seq_breakers, size_t num_breakers) { + int32_t effective_dry_penalty_last_n = (dry_penalty_last_n == -1) ? context_size : std::max(dry_penalty_last_n, 0); + std::unordered_multimap> processed_breakers; + const int MAX_CHAR_LEN = 40; + const int MAX_SEQ_LEN = 20; + + const bool dry_enabled = (dry_multiplier != 0.0f && dry_base >= 1.0f && dry_penalty_last_n != 0); + + if (dry_enabled && seq_breakers != nullptr && num_breakers > 0) { + // Process sequence breakers + for (size_t i = 0; i < num_breakers; ++i) { + if (seq_breakers[i] == nullptr || std::strlen(seq_breakers[i]) == 0) { + LLAMA_LOG_WARN("skipping null or empty DRY sequence breaker at index %zu\n", i); + continue; + } + + std::string sequence_break(seq_breakers[i]); + if (sequence_break.empty()) { + LLAMA_LOG_WARN("skipping empty DRY sequence breaker\n"); + continue; + } + + if (sequence_break.size() > MAX_CHAR_LEN) { + LLAMA_LOG_WARN("truncating DRY sequence breaker to %d characters\n", MAX_CHAR_LEN); + sequence_break.resize(MAX_CHAR_LEN); + } + + get_overlapping_token_sequences(vocab, sequence_break, processed_breakers, MAX_SEQ_LEN); + } + } + + return new llama_sampler { + /* .iface = */ &llama_sampler_dry_i, + /* .ctx = */ new llama_sampler_dry { + /* .total_context_size = */ context_size, + /* .dry_multiplier = */ dry_multiplier, + /* .dry_base = */ dry_base, + /* .dry_allowed_length = */ dry_allowed_length, + /* .dry_penalty_last_n = */ dry_penalty_last_n, + /* .dry_processed_breakers = */ std::move(processed_breakers), + /* .dry_repeat_count = */ dry_enabled ? std::vector(effective_dry_penalty_last_n, 0) : std::vector{}, + /* .dry_max_token_repeat = */ {}, + /* .last_tokens = */ dry_enabled ? ring_buffer(effective_dry_penalty_last_n) : ring_buffer(0), + }, + }; +} + +// wrapper for test-sampling.cpp +struct llama_sampler * llama_sampler_init_dry_testing(int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const std::vector>& seq_breakers) { + llama_vocab dummy_vocab; + auto * result = llama_sampler_init_dry_impl(dummy_vocab, context_size, dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n, NULL, 0); + auto * ctx = (llama_sampler_dry *) result->ctx; + + // Process the token-based sequence breakers + ctx->dry_processed_breakers.clear(); + if (seq_breakers.empty()) { + LLAMA_LOG_WARN("empty DRY sequence breakers list in llama_sampler_init_dry_testing\n"); + } else { + for (const auto& breaker : seq_breakers) { + if (breaker.empty()) { + LLAMA_LOG_WARN("skipping DRY empty sequence breaker\n"); + continue; + } + llama_token head_token = breaker[0]; + std::vector tail_tokens(breaker.begin() + 1, breaker.end()); + ctx->dry_processed_breakers.emplace(head_token, std::move(tail_tokens)); + } + + if (ctx->dry_processed_breakers.empty()) { + LLAMA_LOG_WARN("no valid DRY sequence breakers processed in llama_sampler_init_dry_testing\n"); + } + } + + return result; +} + // logit-bias struct llama_sampler_logit_bias { diff --git a/src/llama-sampling.h b/src/llama-sampling.h index 2683f1b92696f..919f6fdfcefb8 100644 --- a/src/llama-sampling.h +++ b/src/llama-sampling.h @@ -28,3 +28,21 @@ struct llama_sampler * llama_sampler_init_grammar_impl( struct llama_sampler * llama_sampler_init_infill_impl( const struct llama_vocab & vocab); + +struct llama_sampler * llama_sampler_init_dry_impl( + const struct llama_vocab & vocab, + int32_t context_size, + float dry_multiplier, + float dry_base, + int32_t dry_allowed_length, + int32_t dry_penalty_last_n, + const char ** seq_breakers, + size_t num_breakers); + +struct llama_sampler * llama_sampler_init_dry_testing( + int32_t context_size, + float dry_multiplier, + float dry_base, + int32_t dry_allowed_length, + int32_t dry_penalty_last_n, + const std::vector>& seq_breakers); diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp index 0a49ddbe3e291..d1dc96276c2a2 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -1966,3 +1966,19 @@ int32_t llama_detokenize_impl( return total <= text_len_max ? total : -total; } + +std::string llama_detokenize(const struct llama_vocab & vocab, const std::vector & tokens, bool special) { + std::string text; + text.resize(std::max(text.capacity(), tokens.size())); + int32_t n_chars = llama_detokenize_impl(vocab, tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special); + if (n_chars < 0) { + text.resize(-n_chars); + n_chars = llama_detokenize_impl(vocab, tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special); + GGML_ASSERT(n_chars <= (int32_t)text.size()); // whitespace trimming is performed after per-token detokenization + } + + text.resize(n_chars); + + // NOTE: the original tokenizer decodes bytes after collecting the pieces. + return text; +} diff --git a/src/llama-vocab.h b/src/llama-vocab.h index d958d0073be95..4bb16d2e4299f 100644 --- a/src/llama-vocab.h +++ b/src/llama-vocab.h @@ -163,3 +163,8 @@ int32_t llama_detokenize_impl( int32_t text_len_max, bool remove_special, bool unparse_special); + +std::string llama_detokenize( + const struct llama_vocab & vocab, + const std::vector & tokens, + bool special); diff --git a/src/llama.cpp b/src/llama.cpp index 1813dd29be2b2..50eebc2c298f5 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -10,8 +10,6 @@ #if defined(GGML_USE_KOMPUTE) # include "ggml-kompute.h" -#elif defined(GGML_USE_CANN) -# include "ggml-cann.h" #endif #ifndef __AMX_INT8__ @@ -3399,10 +3397,6 @@ static int llama_get_device_count(const llama_model & model) { count += (int) model.rpc_servers.size(); #endif -#if defined(GGML_USE_CANN) - count += ggml_backend_cann_get_device_count(); -#endif - return count; GGML_UNUSED(model); @@ -3420,11 +3414,7 @@ static ggml_backend_buffer_type_t llama_default_buffer_type_cpu(const llama_mode } } -#if defined(GGML_USE_CANN) - if (host_buffer) { - buft = ggml_backend_cann_host_buffer_type(); - } -#elif defined(GGML_USE_CPU_HBM) +#if defined(GGML_USE_CPU_HBM) buft = ggml_backend_cpu_hbm_buffer_type(); #endif @@ -3446,8 +3436,6 @@ static ggml_backend_buffer_type_t llama_default_buffer_type_offload(const llama_ #if defined(GGML_USE_KOMPUTE) buft = ggml_backend_kompute_buffer_type(device); -#elif defined(GGML_USE_CANN) - buft = ggml_backend_cann_buffer_type(device); #endif if (buft == nullptr) { @@ -3491,14 +3479,13 @@ static size_t llama_get_device_memory(const llama_model & model, int device) { return free; } -#if defined(GGML_USE_CANN) - size_t total; - size_t free; - ggml_backend_cann_get_device_memory(device, &free, &total); - return free; -#else + if (model.devices.size() > 0) { + ggml_backend_reg_t reg = ggml_backend_dev_backend_reg(model.devices[0]); + LLAMA_LOG_WARN("%s: failed to get free memmory of device:%d of backend:%s, for device id is out of range.\n", __func__, device, ggml_backend_reg_name(reg)); + } else { + LLAMA_LOG_WARN("%s: failed to get free memmory of device, no devices in inputted model.\n", __func__); + } return 1; -#endif GGML_UNUSED(model); GGML_UNUSED(device); @@ -5190,6 +5177,57 @@ struct llama_model_loader { } }; +// temporary allocate memory for the input batch if needed +static const llama_seq_id batch_default_seq_id = 0; +struct llama_batch_allocr { + std::array seq_id_0 = {batch_default_seq_id}; + std::vector pos; + std::vector n_seq_id; + std::vector seq_id; + std::vector logits; + struct llama_batch batch; + // optionally fulfill the batch returned by llama_batch_get_one + llama_batch_allocr(llama_context & ctx, struct llama_batch in_batch) { + batch = in_batch; + GGML_ASSERT(batch.n_tokens > 0); + if (!batch.pos) { + // determine the last position in KV cache + llama_pos last_pos = -1; + for (const auto & cell : ctx.kv_self.cells) { + if (cell.has_seq_id(batch_default_seq_id)) { + last_pos = std::max(last_pos, cell.pos); + } + } + last_pos++; // next position + pos.resize(batch.n_tokens); + for (int32_t i = 0; i < batch.n_tokens; i++) { + pos[i] = i+last_pos; + } + batch.pos = pos.data(); + } + if (!batch.n_seq_id) { + n_seq_id.resize(batch.n_tokens); + for (int32_t i = 0; i < batch.n_tokens; i++) { + n_seq_id[i] = seq_id_0.size(); + } + batch.n_seq_id = n_seq_id.data(); + } + if (!batch.seq_id) { + seq_id.resize(batch.n_tokens + 1); + seq_id[batch.n_tokens] = NULL; + for (int32_t i = 0; i < batch.n_tokens; i++) { + seq_id[i] = seq_id_0.data(); + } + batch.seq_id = seq_id.data(); + } + if (!batch.logits) { + logits.resize(batch.n_tokens); + logits[logits.size() - 1] = true; + batch.logits = logits.data(); + } + } +}; + template<> bool llama_model_loader::get_key(const enum llm_kv kid, enum llama_pooling_type & result, const bool required) { uint32_t tmp; @@ -10030,7 +10068,7 @@ struct llm_build_context { llama_context & lctx; const llama_hparams & hparams; const llama_cparams & cparams; - const llama_ubatch & batch; + const llama_ubatch & ubatch; const llama_kv_cache & kv_self; const int64_t n_embd; @@ -10076,14 +10114,14 @@ struct llm_build_context { // TODO: consider making the entire interface noexcept llm_build_context( llama_context & lctx, - const llama_ubatch & batch, + const llama_ubatch & ubatch, const llm_build_cb & cb, bool worst_case) : model (lctx.model), lctx (lctx), hparams (model.hparams), cparams (lctx.cparams), - batch (batch), + ubatch (ubatch), kv_self (lctx.kv_self), n_embd (hparams.n_embd), n_layer (hparams.n_layer), @@ -10105,7 +10143,7 @@ struct llm_build_context { beta_slow (cparams.yarn_beta_slow), norm_eps (hparams.f_norm_eps), norm_rms_eps (hparams.f_norm_rms_eps), - n_tokens (batch.n_tokens), + n_tokens (ubatch.n_tokens), n_kv (worst_case ? kv_self.size : kv_self.n), n_outputs (worst_case ? n_tokens : lctx.n_outputs), n_outputs_enc (worst_case ? n_tokens : lctx.embd_enc.size() / hparams.n_embd), @@ -10474,7 +10512,7 @@ struct llm_build_context { struct ggml_tensor * cur; struct ggml_tensor * inpL; - inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb); // inp_pos - contains the positions struct ggml_tensor * inp_pos = build_inp_pos(); @@ -10634,7 +10672,7 @@ struct llm_build_context { struct ggml_tensor * cur; struct ggml_tensor * inpL; - inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb); // inp_pos - contains the positions struct ggml_tensor * inp_pos = model.type == MODEL_7B ? build_inp_pos() : nullptr; @@ -10749,7 +10787,7 @@ struct llm_build_context { struct ggml_tensor * cur; struct ggml_tensor * inpL; - inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb); // inp_pos - contains the positions struct ggml_tensor * inp_pos = build_inp_pos(); @@ -10853,7 +10891,7 @@ struct llm_build_context { struct ggml_tensor * cur; struct ggml_tensor * inpL; - inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb); // inp_pos - contains the positions struct ggml_tensor * inp_pos = build_inp_pos(); @@ -10975,7 +11013,7 @@ struct llm_build_context { struct ggml_tensor * cur; struct ggml_tensor * inpL; - inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb); // multiply by embedding_multiplier_scale of 78.38367176906169 inpL = ggml_scale(ctx0, inpL, 78.38367176906169f); @@ -11133,7 +11171,7 @@ struct llm_build_context { struct ggml_tensor * cur; struct ggml_tensor * inpL; - inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb); // inp_pos - contains the positions struct ggml_tensor * inp_pos = build_inp_pos(); @@ -11255,7 +11293,7 @@ struct llm_build_context { struct ggml_tensor * cur; struct ggml_tensor * inpL; - inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb); // inp_pos - contains the positions struct ggml_tensor * inp_pos = build_inp_pos(); @@ -11358,7 +11396,7 @@ struct llm_build_context { struct ggml_tensor * cur; struct ggml_tensor * inpL; - inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); @@ -11460,7 +11498,7 @@ struct llm_build_context { } // construct input embeddings (token, type, position) - inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb); // token types are hardcoded to zero ("Sentence A") struct ggml_tensor * type_row0 = ggml_view_1d(ctx0, model.type_embd, n_embd, 0); @@ -11647,7 +11685,7 @@ struct llm_build_context { struct ggml_tensor * cur; struct ggml_tensor * inpL; - inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); @@ -11749,7 +11787,7 @@ struct llm_build_context { struct ggml_tensor * pos; struct ggml_tensor * inpL; - inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); @@ -11887,7 +11925,7 @@ struct llm_build_context { struct ggml_tensor * cur; struct ggml_tensor * inpL; - inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb); // inp_pos - contains the positions struct ggml_tensor * inp_pos = build_inp_pos(); @@ -12037,7 +12075,7 @@ struct llm_build_context { struct ggml_tensor * cur; struct ggml_tensor * inpL; - inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb); // inp_pos - contains the positions struct ggml_tensor * inp_pos = build_inp_pos(); @@ -12150,7 +12188,7 @@ struct llm_build_context { struct ggml_tensor * cur; struct ggml_tensor * inpL; - inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb); // inp_pos - contains the positions struct ggml_tensor * inp_pos = build_inp_pos(); @@ -12265,7 +12303,7 @@ struct llm_build_context { struct ggml_tensor * cur; struct ggml_tensor * inpL; - inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb); // inp_pos - contains the positions struct ggml_tensor * inp_pos = build_inp_pos(); @@ -12410,7 +12448,7 @@ struct llm_build_context { struct ggml_tensor * ffn_output; struct ggml_tensor * inpL; - inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb); // inp_pos - contains the positions struct ggml_tensor * inp_pos = build_inp_pos(); @@ -12529,7 +12567,7 @@ struct llm_build_context { struct ggml_tensor * cur; struct ggml_tensor * inpL; - inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb); // inp_pos - contains the positions struct ggml_tensor * inp_pos = build_inp_pos(); @@ -12657,7 +12695,7 @@ struct llm_build_context { struct ggml_tensor * cur; struct ggml_tensor * inpL; - inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb); // inp_pos - contains the positions struct ggml_tensor * inp_pos = build_inp_pos(); @@ -12762,7 +12800,7 @@ struct llm_build_context { struct ggml_tensor * pos; struct ggml_tensor * inpL; - inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb); // inp_pos - contains the positions struct ggml_tensor * inp_pos = build_inp_pos(); @@ -12867,7 +12905,7 @@ struct llm_build_context { struct ggml_tensor * cur; struct ggml_tensor * inpL; - inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb); // inp_pos - contains the positions struct ggml_tensor * inp_pos = build_inp_pos(); @@ -12977,7 +13015,7 @@ struct llm_build_context { struct ggml_tensor * cur; struct ggml_tensor * inpL; - inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb); // inp_pos - contains the positions struct ggml_tensor * inp_pos = build_inp_pos(); @@ -13095,7 +13133,7 @@ struct llm_build_context { struct ggml_tensor * cur; struct ggml_tensor * inpL; - inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb); // inp_pos - contains the positions struct ggml_tensor * inp_pos = build_inp_pos(); @@ -13222,7 +13260,7 @@ struct llm_build_context { struct ggml_tensor * cur; struct ggml_tensor * inpL; - inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb); // scale the input embeddings inpL = ggml_scale(ctx0, inpL, scale_embd); @@ -13366,7 +13404,7 @@ struct llm_build_context { struct ggml_tensor * cur; struct ggml_tensor * inpL; - inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb); // scale the input embeddings inpL = ggml_scale(ctx0, inpL, scale_embd); @@ -13567,7 +13605,7 @@ struct llm_build_context { struct ggml_tensor * cur; struct ggml_tensor * inpL; - inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb); inpL = ggml_scale(ctx0, inpL, sqrtf(n_embd)); cb(inpL, "inp_scaled", -1); @@ -13675,7 +13713,7 @@ struct llm_build_context { struct ggml_tensor * cur; struct ggml_tensor * inpL; - inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb); inpL = ggml_scale(ctx0, inpL, sqrtf(n_embd)); cb(inpL, "inp_scaled", -1); @@ -13813,7 +13851,7 @@ struct llm_build_context { struct ggml_tensor * cur; struct ggml_tensor * inpL; - inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb); // inp_pos - contains the positions struct ggml_tensor * inp_pos = build_inp_pos(); @@ -13929,7 +13967,7 @@ struct llm_build_context { struct ggml_tensor * inpL; // {n_embd, n_tokens} - inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb); struct ggml_tensor * state_copy = build_inp_s_copy(); struct ggml_tensor * state_mask = build_inp_s_mask(); @@ -13941,7 +13979,7 @@ struct llm_build_context { LLM_NORM_RMS, cb, il); cb(cur, "attn_norm", il); - cur = llm_build_mamba(ctx0, lctx, batch, gf, cur, + cur = llm_build_mamba(ctx0, lctx, ubatch, gf, cur, state_copy, state_mask, kv_head, n_kv, cb, il); @@ -13987,7 +14025,7 @@ struct llm_build_context { struct ggml_tensor * cur; struct ggml_tensor * inpL; - inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb); // inp_pos - contains the positions struct ggml_tensor * inp_pos = build_inp_pos(); @@ -14144,7 +14182,7 @@ struct llm_build_context { struct ggml_tensor * cur; struct ggml_tensor * inpL; - inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb); // inp_pos - contains the positions struct ggml_tensor * inp_pos = build_inp_pos(); @@ -14272,7 +14310,7 @@ struct llm_build_context { struct ggml_tensor * cur; struct ggml_tensor * inpL; - inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb); // inp_pos - contains the positions struct ggml_tensor * inp_pos = build_inp_pos(); @@ -14391,7 +14429,7 @@ struct llm_build_context { struct ggml_tensor * cur; struct ggml_tensor * inpL; - inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb); // inp_pos - contains the positions struct ggml_tensor * inp_pos = build_inp_pos(); @@ -14518,7 +14556,7 @@ struct llm_build_context { struct ggml_tensor * cur; struct ggml_tensor * inpL; - inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb); // inp_pos - contains the positions struct ggml_tensor * inp_pos = build_inp_pos(); @@ -14663,7 +14701,7 @@ struct llm_build_context { struct ggml_tensor * cur; struct ggml_tensor * inpL; - inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb); // inp_pos - contains the positions struct ggml_tensor * inp_pos = build_inp_pos(); @@ -14804,7 +14842,7 @@ struct llm_build_context { struct ggml_tensor * inpL; // {n_embd, n_tokens} - inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb); // inp_pos - contains the positions struct ggml_tensor * inp_pos = build_inp_pos(); @@ -15019,7 +15057,7 @@ struct llm_build_context { struct ggml_tensor * cur; struct ggml_tensor * inpL; - inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb); // inp_pos - contains the positions struct ggml_tensor * inp_pos = build_inp_pos(); @@ -15173,7 +15211,7 @@ struct llm_build_context { struct ggml_tensor * cur; struct ggml_tensor * inpL; - inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb); GGML_ASSERT(lctx.is_encoding); struct ggml_tensor * pos_bucket_enc = llm_build_pos_bucket(false); @@ -15305,7 +15343,7 @@ struct llm_build_context { struct ggml_tensor * cur; struct ggml_tensor * inpL; - inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb); GGML_ASSERT(!lctx.is_encoding); GGML_ASSERT(n_outputs_enc > 0 && "call llama_encode() first"); @@ -15507,7 +15545,7 @@ struct llm_build_context { struct ggml_tensor * cur; struct ggml_tensor * inpL; - inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); @@ -15599,7 +15637,7 @@ struct llm_build_context { struct ggml_tensor * cur; struct ggml_tensor * inpL; - inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb); // inp_pos - contains the positions struct ggml_tensor * inp_pos = build_inp_pos(); @@ -15713,7 +15751,7 @@ struct llm_build_context { struct ggml_tensor * cur; struct ggml_tensor * inpL; - inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb); // inp_pos - contains the positions struct ggml_tensor * inp_pos = build_inp_pos(); @@ -15837,7 +15875,7 @@ struct llm_build_context { struct ggml_tensor * cur; struct ggml_tensor * inpL; - inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb); // inp_pos - contains the positions struct ggml_tensor * inp_pos = build_inp_pos(); @@ -15957,11 +15995,11 @@ struct llm_build_context { // Token shift state dimensions should be 2 * n_emb GGML_ASSERT(n_embd == hparams.n_embd_k_s() / 2); - const int64_t n_seqs = batch.n_seqs; - const int64_t n_seq_tokens = batch.n_seq_tokens; - const int64_t n_tokens = batch.n_tokens; + const int64_t n_seqs = ubatch.n_seqs; + const int64_t n_seq_tokens = ubatch.n_seq_tokens; + const int64_t n_tokens = ubatch.n_tokens; GGML_ASSERT(n_seqs != 0); - GGML_ASSERT(batch.equal_seqs); + GGML_ASSERT(ubatch.equal_seqs); GGML_ASSERT(n_tokens == n_seq_tokens * n_seqs); struct ggml_tensor * cur; @@ -15969,7 +16007,7 @@ struct llm_build_context { struct ggml_tensor * state_copy = build_inp_s_copy(); struct ggml_tensor * state_mask = build_inp_s_mask(); - inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb); inpL = llm_build_norm(ctx0, inpL, hparams, model.tok_norm, model.tok_norm_b, LLM_NORM, cb, -1); for (int il = 0; il < n_layer; ++il) { @@ -16083,7 +16121,7 @@ struct llm_build_context { struct ggml_tensor * cur; struct ggml_tensor * inpL; - inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb); // inp_pos - contains the positions struct ggml_tensor * inp_pos = build_inp_pos(); @@ -16279,7 +16317,7 @@ static struct ggml_cgraph * llama_build_graph_k_shift(llama_context & lctx) { static struct ggml_cgraph * llama_build_graph( llama_context & lctx, - const llama_ubatch & batch, + const llama_ubatch & ubatch, bool worst_case) { const auto & model = lctx.model; @@ -16301,7 +16339,7 @@ static struct ggml_cgraph * llama_build_graph( // norm may be automatically assigned to the backend of the previous layer, increasing data transfer between backends // FIXME: fix in ggml_backend_sched const bool full_offload = lctx.model.n_gpu_layers > (int)lctx.model.hparams.n_layer; - if (batch.n_tokens < 32 || full_offload) { + if (ubatch.n_tokens < 32 || full_offload) { if (il != -1 && strcmp(name, "norm") == 0) { for (auto * backend : lctx.backends) { if (ggml_backend_supports_buft(backend, lctx.model.buft_layer[il].buft) && @@ -16316,7 +16354,7 @@ static struct ggml_cgraph * llama_build_graph( struct ggml_cgraph * result = NULL; - struct llm_build_context llm(lctx, batch, cb, worst_case); + struct llm_build_context llm(lctx, ubatch, cb, worst_case); llm.init(); @@ -16567,7 +16605,7 @@ static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t return relative_bucket; } -static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) { +static void llama_set_inputs(llama_context & lctx, const llama_ubatch & ubatch) { // // set input data // @@ -16576,28 +16614,28 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) { const auto & cparams = lctx.cparams; const auto & kv_self = lctx.kv_self; - if (batch.token) { - const int64_t n_tokens = batch.n_tokens; + if (ubatch.token) { + const int64_t n_tokens = ubatch.n_tokens; - ggml_backend_tensor_set(lctx.inp_tokens, batch.token, 0, n_tokens*ggml_element_size(lctx.inp_tokens)); + ggml_backend_tensor_set(lctx.inp_tokens, ubatch.token, 0, n_tokens*ggml_element_size(lctx.inp_tokens)); } - if (batch.embd) { + if (ubatch.embd) { const int64_t n_embd = hparams.n_embd; - const int64_t n_tokens = batch.n_tokens; + const int64_t n_tokens = ubatch.n_tokens; - ggml_backend_tensor_set(lctx.inp_embd, batch.embd, 0, n_tokens*n_embd*ggml_element_size(lctx.inp_embd)); + ggml_backend_tensor_set(lctx.inp_embd, ubatch.embd, 0, n_tokens*n_embd*ggml_element_size(lctx.inp_embd)); } - if (batch.pos && lctx.inp_pos) { - const int64_t n_tokens = batch.n_tokens; + if (ubatch.pos && lctx.inp_pos) { + const int64_t n_tokens = ubatch.n_tokens; - ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*ggml_element_size(lctx.inp_pos)); + ggml_backend_tensor_set(lctx.inp_pos, ubatch.pos, 0, n_tokens*ggml_element_size(lctx.inp_pos)); } if (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) { GGML_ASSERT(lctx.inp_out_ids && "every model that can must skip unused outputs"); - const int64_t n_tokens = batch.n_tokens; + const int64_t n_tokens = ubatch.n_tokens; GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_out_ids->buffer)); int32_t * data = (int32_t *) lctx.inp_out_ids->data; @@ -16606,10 +16644,10 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) { for (int i = 0; i < n_tokens; ++i) { data[i] = i; } - } else if (batch.output) { + } else if (ubatch.output) { int32_t n_outputs = 0; for (int i = 0; i < n_tokens; ++i) { - if (batch.output[i]) { + if (ubatch.output[i]) { data[n_outputs++] = i; } } @@ -16634,9 +16672,9 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) { // NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache. if (cparams.causal_attn && !lctx.is_encoding) { const int64_t n_kv = kv_self.n; - const int64_t n_tokens = batch.n_tokens; - const int64_t n_seq_tokens = batch.n_seq_tokens; - const int64_t n_seqs = batch.n_seqs; + const int64_t n_tokens = ubatch.n_tokens; + const int64_t n_seq_tokens = ubatch.n_seq_tokens; + const int64_t n_seqs = ubatch.n_seqs; float * data = nullptr; @@ -16653,14 +16691,14 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) { } // For causal attention, use only the previous KV cells - // of the correct sequence for each token of the batch. + // of the correct sequence for each token of the ubatch. // It's assumed that if a token in the batch has multiple sequences, they are equivalent. for (int h = 0; h < 1; ++h) { for (int s = 0; s < n_seqs; ++s) { - const llama_seq_id seq_id = batch.seq_id[s][0]; + const llama_seq_id seq_id = ubatch.seq_id[s][0]; for (int j = 0; j < n_seq_tokens; ++j) { - const llama_pos pos = batch.pos[s*n_seq_tokens + j]; + const llama_pos pos = ubatch.pos[s*n_seq_tokens + j]; for (int i = 0; i < n_kv; ++i) { float f; @@ -16706,9 +16744,9 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) { } } } else { - const int64_t n_tokens = batch.n_tokens; - const int64_t n_seq_tokens = batch.n_seq_tokens; - const int64_t n_seqs = batch.n_seqs; + const int64_t n_tokens = ubatch.n_tokens; + const int64_t n_seq_tokens = ubatch.n_seq_tokens; + const int64_t n_seqs = ubatch.n_seqs; // when using kv cache, the mask needs to match the kv cache size const int64_t n_stride = hparams.causal_attn && !lctx.is_encoding ? kv_self.n : n_tokens; @@ -16718,7 +16756,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) { for (int h = 0; h < 1; ++h) { for (int s1 = 0; s1 < n_seqs; ++s1) { - const llama_seq_id seq_id = batch.seq_id[s1][0]; + const llama_seq_id seq_id = ubatch.seq_id[s1][0]; for (int j = 0; j < n_seq_tokens; ++j) { const int32_t tj = s1*n_seq_tokens + j; @@ -16728,10 +16766,10 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) { const int32_t ti = s0*n_seq_tokens + i; float f = -INFINITY; - for (int s = 0; s < batch.n_seq_id[s0]; ++s) { - if (batch.seq_id[s0][s] == seq_id) { + for (int s = 0; s < ubatch.n_seq_id[s0]; ++s) { + if (ubatch.seq_id[s0][s] == seq_id) { if (hparams.use_alibi) { - f = -std::abs(batch.pos[ti] - batch.pos[tj]); + f = -std::abs(ubatch.pos[ti] - ubatch.pos[tj]); } else { f = 0.0f; } @@ -16753,9 +16791,9 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) { } if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) { - const int64_t n_tokens = batch.n_tokens; - const int64_t n_seq_tokens = batch.n_seq_tokens; - const int64_t n_seqs = batch.n_seqs; + const int64_t n_tokens = ubatch.n_tokens; + const int64_t n_seq_tokens = ubatch.n_seq_tokens; + const int64_t n_seqs = ubatch.n_seqs; GGML_ASSERT(lctx.inp_mean); GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_mean->buffer)); @@ -16766,12 +16804,12 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) { std::vector sum(n_tokens, 0); for (int s = 0; s < n_seqs; ++s) { - const llama_seq_id seq_id = batch.seq_id[s][0]; + const llama_seq_id seq_id = ubatch.seq_id[s][0]; - // TODO: adapt limits to n_seqs when batch.equal_seqs is true + // TODO: adapt limits to n_seqs when ubatch.equal_seqs is true GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == MEAN"); - sum[seq_id] += batch.n_seq_tokens; + sum[seq_id] += ubatch.n_seq_tokens; } std::vector div(n_tokens, 0.0f); @@ -16783,7 +16821,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) { } for (int s = 0; s < n_seqs; ++s) { - const llama_seq_id seq_id = batch.seq_id[s][0]; + const llama_seq_id seq_id = ubatch.seq_id[s][0]; for (int i = 0; i < n_seq_tokens; ++i) { data[seq_id*n_tokens + s*n_seq_tokens + i] = div[seq_id]; @@ -16794,9 +16832,9 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) { if (cparams.embeddings && ( cparams.pooling_type == LLAMA_POOLING_TYPE_CLS || cparams.pooling_type == LLAMA_POOLING_TYPE_RANK)) { - const int64_t n_tokens = batch.n_tokens; - const int64_t n_seq_tokens = batch.n_seq_tokens; - const int64_t n_seqs = batch.n_seqs; + const int64_t n_tokens = ubatch.n_tokens; + const int64_t n_seq_tokens = ubatch.n_seq_tokens; + const int64_t n_seqs = ubatch.n_seqs; GGML_ASSERT(lctx.inp_cls); GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_cls->buffer)); @@ -16805,13 +16843,13 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) { memset(lctx.inp_cls->data, 0, n_tokens * ggml_element_size(lctx.inp_cls)); for (int s = 0; s < n_seqs; ++s) { - const llama_seq_id seq_id = batch.seq_id[s][0]; + const llama_seq_id seq_id = ubatch.seq_id[s][0]; - // TODO: adapt limits to n_seqs when batch.equal_seqs is true + // TODO: adapt limits to n_seqs when ubatch.equal_seqs is true GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == CLS or RANK"); for (int i = 0; i < n_seq_tokens; ++i) { - const llama_pos pos = batch.pos[s*n_seq_tokens + i]; + const llama_pos pos = ubatch.pos[s*n_seq_tokens + i]; if (pos == 0) { data[seq_id] = s*n_seq_tokens + i; @@ -16821,9 +16859,9 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) { } if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_LAST) { - const int64_t n_tokens = batch.n_tokens; - const int64_t n_seq_tokens = batch.n_seq_tokens; - const int64_t n_seqs = batch.n_seqs; + const int64_t n_tokens = ubatch.n_tokens; + const int64_t n_seq_tokens = ubatch.n_seq_tokens; + const int64_t n_seqs = ubatch.n_seqs; GGML_ASSERT(lctx.inp_cls); GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_cls->buffer)); @@ -16835,13 +16873,13 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) { std::vector last_row(n_tokens, -1); for (int s = 0; s < n_seqs; ++s) { - const llama_seq_id seq_id = batch.seq_id[s][0]; + const llama_seq_id seq_id = ubatch.seq_id[s][0]; - // TODO: adapt limits to n_seqs when batch.equal_seqs is true + // TODO: adapt limits to n_seqs when ubatch.equal_seqs is true GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == LAST"); for (int i = 0; i < n_seq_tokens; ++i) { - const llama_pos pos = batch.pos[s*n_seq_tokens + i]; + const llama_pos pos = ubatch.pos[s*n_seq_tokens + i]; if (pos >= last_pos[seq_id]) { last_pos[seq_id] = pos; @@ -16903,10 +16941,10 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) { } if (lctx.inp_pos_bucket) { - const int64_t n_tokens = batch.n_tokens; + const int64_t n_tokens = ubatch.n_tokens; GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_pos_bucket->buffer)); - GGML_ASSERT(!batch.equal_seqs); // TODO: use batch.n_seqs instead of failing + GGML_ASSERT(!ubatch.equal_seqs); // TODO: use ubatch.n_seqs instead of failing int32_t * data = (int32_t *) lctx.inp_pos_bucket->data; @@ -16915,7 +16953,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) { for (int h = 0; h < 1; ++h) { for (int j = 0; j < n_tokens; ++j) { for (int i = 0; i < n_kv; ++i) { - data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(lctx.kv_self.cells[i].pos, batch.pos[j], hparams.n_rel_attn_bkts, lctx.is_encoding); + data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(lctx.kv_self.cells[i].pos, ubatch.pos[j], hparams.n_rel_attn_bkts, lctx.is_encoding); } } } @@ -16923,7 +16961,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) { for (int h = 0; h < 1; ++h) { for (int j = 0; j < n_tokens; ++j) { for (int i = 0; i < n_tokens; ++i) { - data[h*(n_tokens*n_tokens) + j*n_tokens + i] = llama_relative_position_bucket(batch.pos[i], batch.pos[j], hparams.n_rel_attn_bkts, lctx.is_encoding); + data[h*(n_tokens*n_tokens) + j*n_tokens + i] = llama_relative_position_bucket(ubatch.pos[i], ubatch.pos[j], hparams.n_rel_attn_bkts, lctx.is_encoding); } } } @@ -16939,10 +16977,10 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) { if (!lctx.is_encoding && lctx.inp_KQ_mask_cross) { const int64_t n_output_enc = lctx.embd_enc.size() / hparams.n_embd; - const int64_t n_tokens = batch.n_tokens; + const int64_t n_tokens = ubatch.n_tokens; GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask_cross->buffer)); - GGML_ASSERT(!batch.equal_seqs); // TODO: use batch.n_seqs instead of failing + GGML_ASSERT(!ubatch.equal_seqs); // TODO: use ubatch.n_seqs instead of failing float * data = (float *) lctx.inp_KQ_mask_cross->data; @@ -16950,8 +16988,8 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) { for (int j = 0; j < n_tokens; ++j) { for (int i = 0; i < n_output_enc; ++i) { float f = -INFINITY; - for (int s = 0; s < batch.n_seq_id[j]; ++s) { - const llama_seq_id seq_id = batch.seq_id[j][s]; + for (int s = 0; s < ubatch.n_seq_id[j]; ++s) { + const llama_seq_id seq_id = ubatch.seq_id[j][s]; if (lctx.seq_ids_enc[i].find(seq_id) != lctx.seq_ids_enc[i].end()) { f = 0.0f; } @@ -17108,16 +17146,20 @@ static void llama_graph_compute( // static int llama_decode_internal( llama_context & lctx, - llama_batch batch) { + llama_batch inp_batch) { lctx.is_encoding = false; - const uint32_t n_tokens_all = batch.n_tokens; - if (n_tokens_all == 0) { + if (inp_batch.n_tokens == 0) { LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__); return -1; } + // temporary allocate memory for the input batch if needed + llama_batch_allocr batch_allocr(lctx, inp_batch); + const llama_batch & batch = batch_allocr.batch; + const uint32_t n_tokens_all = batch.n_tokens; + const auto & model = lctx.model; const auto & hparams = model.hparams; const auto & cparams = lctx.cparams; @@ -17422,17 +17464,20 @@ static int llama_decode_internal( // static int llama_encode_internal( llama_context & lctx, - llama_batch batch) { + llama_batch inp_batch) { lctx.is_encoding = true; - const uint32_t n_tokens = batch.n_tokens; - - if (n_tokens == 0) { + if (inp_batch.n_tokens == 0) { LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__); return -1; } + // temporary allocate memory for the input batch if needed + llama_batch_allocr batch_allocr(lctx, inp_batch); + const llama_batch & batch = batch_allocr.batch; + const uint32_t n_tokens = batch.n_tokens; + const auto & model = lctx.model; const auto & hparams = model.hparams; const auto & cparams = lctx.cparams; @@ -19243,7 +19288,7 @@ struct llama_context * llama_new_context_with_model( params.flash_attn = false; } - if (params.type_v != GGML_TYPE_F16 && !params.flash_attn) { + if (ggml_is_quantized(params.type_v) && !params.flash_attn) { LLAMA_LOG_ERROR("%s: V cache quantization requires flash_attn\n", __func__); return nullptr; } @@ -19396,30 +19441,6 @@ struct llama_context * llama_new_context_with_model( } ctx->backends.push_back(backend); } -#elif defined(GGML_USE_CANN) - // with split_mode LLAMA_SPLIT_MODE_NONE or LLAMA_SPLIT_MODE_ROW, only the main GPU backend is used - // TODO: ggml_backend_cann is not support split tensor now, just leave code here. - if (model->split_mode == LLAMA_SPLIT_MODE_NONE || model->split_mode == LLAMA_SPLIT_MODE_ROW) { - ggml_backend_t backend = ggml_backend_cann_init(main_gpu); - if (backend == nullptr) { - LLAMA_LOG_ERROR("%s: failed to initialize CANN%d backend\n", __func__, main_gpu); - llama_free(ctx); - return nullptr; - } - ctx->backends.push_back(backend); - } else { - // LLAMA_SPLIT_MODE_LAYER requires a backend for each GPU - // TODO: currently, CANN can't use multi-gpus, just leave code here for further cann version. - for (int32_t device = 0; device < ggml_backend_cann_get_device_count(); ++device) { - ggml_backend_t backend = ggml_backend_cann_init(device); - if (backend == nullptr) { - LLAMA_LOG_ERROR("%s: failed to initialize CANN%d backend\n", __func__, device); - llama_free(ctx); - return nullptr; - } - ctx->backends.push_back(backend); - } - } #endif // add other backends (such as BLAS) @@ -21127,61 +21148,10 @@ void llama_batch_free(struct llama_batch batch) { if (batch.logits) free(batch.logits); } -// temporary allocate memory for the input batch if needed -static const llama_seq_id batch_default_seq_id = 0; -struct llama_batch_allocr { - std::array seq_id_0 = {batch_default_seq_id}; - std::vector pos; - std::vector n_seq_id; - std::vector seq_id; - std::vector logits; - struct llama_batch batch; - // optionally fulfill the batch returned by llama_batch_get_one - llama_batch_allocr(struct llama_context * ctx, struct llama_batch in_batch) { - batch = in_batch; - if (!batch.pos) { - // determine the last position in KV cache - llama_pos last_pos = -1; - for (const auto & cell : ctx->kv_self.cells) { - if (cell.has_seq_id(batch_default_seq_id)) { - last_pos = std::max(last_pos, cell.pos); - } - } - last_pos++; // next position - pos.resize(batch.n_tokens); - for (int32_t i = 0; i < batch.n_tokens; i++) { - pos[i] = i+last_pos; - } - batch.pos = pos.data(); - } - if (!batch.n_seq_id) { - n_seq_id.resize(batch.n_tokens); - for (int32_t i = 0; i < batch.n_tokens; i++) { - n_seq_id[i] = seq_id_0.size(); - } - batch.n_seq_id = n_seq_id.data(); - } - if (!batch.seq_id) { - seq_id.resize(batch.n_tokens + 1); - seq_id[batch.n_tokens] = NULL; - for (int32_t i = 0; i < batch.n_tokens; i++) { - seq_id[i] = seq_id_0.data(); - } - batch.seq_id = seq_id.data(); - } - if (!batch.logits) { - logits.resize(batch.n_tokens); - logits[logits.size() - 1] = true; - batch.logits = logits.data(); - } - } -}; - int32_t llama_encode( struct llama_context * ctx, struct llama_batch batch) { - llama_batch_allocr batch_allocr(ctx, batch); - const int ret = llama_encode_internal(*ctx, batch_allocr.batch); + const int ret = llama_encode_internal(*ctx, batch); if (ret != 0) { LLAMA_LOG_ERROR("%s: failed to encode, ret = %d\n", __func__, ret); } @@ -21192,8 +21162,7 @@ int32_t llama_encode( int32_t llama_decode( struct llama_context * ctx, struct llama_batch batch) { - llama_batch_allocr batch_allocr(ctx, batch); - const int ret = llama_decode_internal(*ctx, batch_allocr.batch); + const int ret = llama_decode_internal(*ctx, batch); if (ret != 0) { LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret); } @@ -21734,6 +21703,16 @@ static int32_t llama_chat_apply_template_internal( if (add_ass) { ss << "[|assistant|]"; } + } else if (tmpl == "rwkv-world" || tmpl_contains("rwkv-world")) { + // this template requires the model to have "\n\n" as EOT token + for (auto message : chat) { + std::string role(message->role); + if (role == "user") { + ss << "User: " << message->content << "\n\nAssistant:"; + } else { + ss << message->content << "\n\n"; + } + } } else { // template not supported return -1; @@ -21796,6 +21775,10 @@ struct llama_sampler * llama_sampler_init_infill(const struct llama_model * mode return llama_sampler_init_infill_impl(model->vocab); } +struct llama_sampler * llama_sampler_init_dry(const struct llama_model * model, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const char** seq_breakers, size_t num_breakers) { + return llama_sampler_init_dry_impl(model->vocab, llama_n_ctx_train(model), dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n, seq_breakers, num_breakers); +} + // // model split // diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index ee1a8877e1b0a..2e3ad79f01944 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1650,11 +1650,12 @@ struct test_mul_mat : public test_case { const int64_t m; const int64_t n; const int64_t k; - const std::array bs; // dims 3 and 4 - const std::array nr; // repeat in dims 3 and 4 + const std::array bs; // dims 3 and 4 + const std::array nr; // repeat in dims 3 and 4 + const std::array per; // permutation of dimensions std::string vars() override { - return VARS_TO_STR7(type_a, type_b, m, n, k, bs, nr); + return VARS_TO_STR8(type_a, type_b, m, n, k, bs, nr, per); } double max_nmse_err() override { @@ -1669,17 +1670,44 @@ struct test_mul_mat : public test_case { test_mul_mat(ggml_type type_a = GGML_TYPE_F32, ggml_type type_b = GGML_TYPE_F32, int64_t m = 32, int64_t n = 32, int64_t k = 32, std::array bs = {10, 10}, - std::array nr = {2, 2}) - : type_a(type_a), type_b(type_b), m(m), n(n), k(k), bs(bs), nr(nr) {} + std::array nr = {2, 2}, + std::array per = {0, 1, 2, 3}) + : type_a(type_a), type_b(type_b), m(m), n(n), k(k), bs(bs), nr(nr), per(per) {} ggml_tensor * build_graph(ggml_context * ctx) override { // C^T = A * B^T: (k, m) * (k, n) => (m, n) - ggml_tensor * a = ggml_new_tensor_4d(ctx, type_a, k, m, bs[0] , bs[1]); - ggml_tensor * b = ggml_new_tensor_4d(ctx, type_b, k, n, bs[0]*nr[0], bs[1]*nr[1]); - ggml_set_param(ctx, a); - ggml_set_param(ctx, b); - ggml_set_name(a, "a"); - ggml_set_name(b, "b"); + ggml_tensor * a; + ggml_tensor * b; + + const int npermuted = (per[0] != 0) + (per[1] != 1) + (per[2] != 2) + (per[3] != 3); + if (npermuted > 0) { + GGML_ASSERT(npermuted == 2); + GGML_ASSERT(!ggml_is_quantized(type_a) || per[0] == 0); + GGML_ASSERT(!ggml_is_quantized(type_b) || per[0] == 0); + + // Create tensors with the permuted dimensions, then permute them back to the dimensions given by m,n,k. + const int64_t ne_a[4] = {k, m, bs[0], bs[1]}; + const int64_t ne_b[4] = {k, n, bs[0]*nr[0], bs[1]*nr[1]}; + + a = ggml_new_tensor_4d(ctx, type_a, ne_a[per[0]], ne_a[per[1]], ne_a[per[2]], ne_a[per[3]]); + b = ggml_new_tensor_4d(ctx, type_b, ne_b[per[0]], ne_b[per[1]], ne_b[per[2]], ne_b[per[3]]); + ggml_set_param(ctx, a); + ggml_set_param(ctx, b); + ggml_set_name(a, "a"); + ggml_set_name(b, "b"); + + a = ggml_permute(ctx, a, per[0], per[1], per[2], per[3]); + b = ggml_permute(ctx, b, per[0], per[1], per[2], per[3]); + ggml_set_name(a, "a_permuted"); + ggml_set_name(b, "b_permuted"); + } else { + a = ggml_new_tensor_4d(ctx, type_a, k, m, bs[0], bs[1]); + b = ggml_new_tensor_4d(ctx, type_b, k, n, bs[0]*nr[0], bs[1]*nr[1]); + ggml_set_param(ctx, a); + ggml_set_param(ctx, b); + ggml_set_name(a, "a"); + ggml_set_name(b, "b"); + } ggml_tensor * out = ggml_mul_mat(ctx, a, b); ggml_set_name(out, "out"); @@ -3308,13 +3336,49 @@ static std::vector> make_test_cases_eval() { } } - test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32)); - test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32)); - test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16)); - // test cases for 1D im2col + // im2col 1D test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, {3000, 128, 1, 1}, {3, 128, 1280, 1}, 1, 0, 1, 0, 1, 0, false)); test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32, {3000, 128, 1, 1}, {3, 128, 1280, 1}, 1, 0, 1, 0, 1, 0, false)); test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {3000, 128, 1, 1}, {3, 128, 1280, 1}, 1, 0, 1, 0, 1, 0, false)); + for (int s0 : {1, 3}) { + for (int p0 : {0, 3}) { + for (int d0 : {1, 3}) { + test_cases.emplace_back(new test_im2col( + GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, {20, 2, 2, 1}, {3, 2, 2, 1}, + s0, 0, p0, 0, d0, 0, false)); + } + } + } + + // im2col 2D + test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32)); + test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32)); + test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16)); + for (int s0 : {1, 3}) { + for (int s1 : {1, 3}) { + for (int p0 : {0, 3}) { + for (int p1 : {0, 3}) { + for (int d0 : {1, 3}) { + for (int d1 : {1, 3}) { + test_cases.emplace_back(new test_im2col( + GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, {20, 20, 2, 2}, {3, 3, 2, 2}, + s0, s1, p0, p1, d0, d1, true)); + } + } + } + } + } + } + + // extra tests for im2col 2D + test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 1, 32}, {3, 3, 1, 32}, 1, 1, 1, 1, 1, 1, true)); + test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 2, 32}, {3, 3, 2, 32}, 1, 1, 1, 1, 1, 1, true)); + test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 1, 1024}, {3, 3, 1, 1024}, 1, 1, 1, 1, 1, 1, true)); + test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 2, 1024}, {3, 3, 2, 1024}, 1, 1, 1, 1, 1, 1, true)); + test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 1, 2048}, {3, 3, 1, 2048}, 1, 1, 1, 1, 1, 1, true)); + test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 2, 2048}, {3, 3, 2, 2048}, 1, 1, 1, 1, 1, 1, true)); + test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 1, 2560}, {3, 3, 1, 2560}, 1, 1, 1, 1, 1, 1, true)); + test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 2, 2560}, {3, 3, 2, 2560}, 1, 1, 1, 1, 1, 1, true)); // sycl backend will limit task global_range < MAX_INT // test cases for 2D im2col with large input W and H (occurs in stable-diffusion) @@ -3442,13 +3506,14 @@ static std::vector> make_test_cases_eval() { #if 1 for (ggml_type type_a : base_types) { for (ggml_type type_b : {GGML_TYPE_F32, GGML_TYPE_F16}) { - test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, { 1, 1}, {1, 1})); - test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 1}, {1, 1})); - test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 1}, {2, 1})); - test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {1, 1})); - test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {2, 1})); - test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {1, 2})); - test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {2, 2})); + // test cases without permutation + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, { 1, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 1}, {2, 1})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {2, 1})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {1, 2})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {2, 2})); test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, { 1, 1}, {1, 1})); test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 1}, {1, 1})); @@ -3457,6 +3522,19 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {2, 1})); test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {1, 2})); test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {2, 2})); + + // test cases with permutation + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {2, 3}, {1, 1}, {0, 2, 1, 3})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {2, 3}, {1, 1}, {0, 1, 3, 2})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {2, 3}, {1, 1}, {0, 3, 2, 1})); + + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 8, 256, {2, 3}, {1, 1}, {0, 2, 1, 3})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 8, 256, {2, 3}, {1, 1}, {0, 1, 3, 2})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 8, 256, {2, 3}, {1, 1}, {0, 3, 2, 1})); + + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {2, 3}, {1, 1}, {0, 2, 1, 3})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {2, 3}, {1, 1}, {0, 1, 3, 2})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {2, 3}, {1, 1}, {0, 3, 2, 1})); } } for (ggml_type type_a : other_types) { diff --git a/tests/test-sampling.cpp b/tests/test-sampling.cpp index 1372bdf13f2f6..eb39661c3698f 100644 --- a/tests/test-sampling.cpp +++ b/tests/test-sampling.cpp @@ -10,6 +10,8 @@ #include #include +extern struct llama_sampler * llama_sampler_init_dry_testing(int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const std::vector>& seq_breakers); + static void dump(const llama_token_data_array * cur_p) { for (size_t i = 0; i < cur_p->size; i++) { printf("%d: %f (%f)\n", cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit); @@ -18,203 +20,199 @@ static void dump(const llama_token_data_array * cur_p) { #define DUMP(__cur_p) do { printf("%s:%d (%s)\n", __FILE__, __LINE__, __func__); dump((__cur_p)); printf("-\n"); } while(0) -#define APPLY(__cnstr, __cur_p) do { \ - auto * cnstr = (__cnstr); \ - llama_sampler_apply(cnstr, (__cur_p)); \ - llama_sampler_free(cnstr); \ -} while(0) +struct sampler_tester { + sampler_tester(size_t n_vocab) { + cur.reserve(n_vocab); + for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) { + const float logit = logf(token_id); + cur.emplace_back(llama_token_data{token_id, logit, 0.0f}); + } -static void test_top_k(const std::vector & probs, const std::vector & expected_probs, int k) { - const size_t n_vocab = probs.size(); + cur_p = llama_token_data_array { cur.data(), cur.size(), -1, false }; + } - std::vector cur; - cur.reserve(n_vocab); - for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) { - const float logit = logf(probs[token_id]); - cur.emplace_back(llama_token_data{token_id, logit, 0.0f}); + sampler_tester(const std::vector & probs, const std::vector & probs_expected) : probs_expected(probs_expected) { + cur.reserve(probs.size()); + for (llama_token token_id = 0; token_id < (llama_token)probs.size(); token_id++) { + const float logit = logf(probs[token_id]); + cur.emplace_back(llama_token_data{token_id, logit, probs[token_id]}); + } + + cur_p = llama_token_data_array { cur.data(), cur.size(), -1, false }; } - llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false }; - APPLY(llama_sampler_init_softmax(), &cur_p); - DUMP(&cur_p); - APPLY(llama_sampler_init_top_k(k), &cur_p); - DUMP(&cur_p); - - GGML_ASSERT(cur_p.size == expected_probs.size()); - for (size_t i = 0; i < cur_p.size; i++) { - GGML_ASSERT(fabs(cur_p.data[i].p - expected_probs[i]) < 1e-5); + void apply(llama_sampler * sampler) { + llama_sampler_apply(sampler, &cur_p); + llama_sampler_free(sampler); } -} -static void test_top_p(const std::vector & probs, const std::vector & expected_probs, float p) { - const size_t n_vocab = probs.size(); + void check() { + GGML_ASSERT(cur_p.size == probs_expected.size()); + for (size_t i = 0; i < cur_p.size; i++) { + GGML_ASSERT(fabs(cur_p.data[i].p - probs_expected[i]) < 1e-5); + } + } + + llama_token_data_array cur_p; + +private: + const std::vector probs_expected; std::vector cur; - cur.reserve(n_vocab); - for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) { - const float logit = logf(probs[token_id]); - cur.emplace_back(llama_token_data{token_id, logit, 0.0f}); - } +}; - llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false }; - APPLY(llama_sampler_init_softmax(), &cur_p); - DUMP(&cur_p); - APPLY(llama_sampler_init_top_p(p, 1), &cur_p); - DUMP(&cur_p); - - GGML_ASSERT(cur_p.size == expected_probs.size()); - for (size_t i = 0; i < cur_p.size; i++) { - GGML_ASSERT(fabs(cur_p.data[i].p - expected_probs[i]) < 1e-3); - } +static void test_temp(const std::vector & probs, const std::vector & probs_expected, float temp) { + sampler_tester tester(probs, probs_expected); + + DUMP(&tester.cur_p); + tester.apply(llama_sampler_init_temp(temp)); + tester.apply(llama_sampler_init_dist(0)); + DUMP(&tester.cur_p); + + tester.check(); } -static void test_tfs(const std::vector & probs, const std::vector & expected_probs, float z) { - const size_t n_vocab = probs.size(); +static void test_temp_ext(const std::vector & probs, const std::vector & probs_expected, float temp, float delta, float exponent) { + sampler_tester tester(probs, probs_expected); - std::vector cur; - cur.reserve(n_vocab); - for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) { - const float logit = logf(probs[token_id]); - cur.emplace_back(llama_token_data{token_id, logit, 0.0f}); - } + DUMP(&tester.cur_p); + tester.apply(llama_sampler_init_temp_ext(temp, delta, exponent)); + tester.apply(llama_sampler_init_dist (0)); + DUMP(&tester.cur_p); - llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false }; - DUMP(&cur_p); - APPLY(llama_sampler_init_tail_free(z, 1), &cur_p); - DUMP(&cur_p); + tester.check(); +} - GGML_ASSERT(cur_p.size == expected_probs.size()); - for (size_t i = 0; i < cur_p.size; i++) { - GGML_ASSERT(fabs(cur_p.data[i].p - expected_probs[i]) < 1e-3); - } +static void test_top_k(const std::vector & probs, const std::vector & probs_expected, int k) { + sampler_tester tester(probs, probs_expected); + + DUMP(&tester.cur_p); + tester.apply(llama_sampler_init_top_k(k)); + tester.apply(llama_sampler_init_dist (0)); + DUMP(&tester.cur_p); + + tester.check(); } -static void test_min_p(const std::vector & probs, const std::vector & expected_probs, float p) { - const size_t n_vocab = probs.size(); +static void test_top_p(const std::vector & probs, const std::vector & probs_expected, float p) { + sampler_tester tester(probs, probs_expected); - std::vector cur; - cur.reserve(n_vocab); - for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) { - const float logit = logf(probs[token_id]); - cur.emplace_back(llama_token_data{token_id, logit, 0.0f}); - } + DUMP(&tester.cur_p); + tester.apply(llama_sampler_init_top_p(p, 1)); + tester.apply(llama_sampler_init_dist (0)); + DUMP(&tester.cur_p); - llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false }; - DUMP(&cur_p); - APPLY(llama_sampler_init_min_p(p, 1), &cur_p); - DUMP(&cur_p); - APPLY(llama_sampler_init_softmax(), &cur_p); - - GGML_ASSERT(cur_p.size == expected_probs.size()); - for (size_t i = 0; i < cur_p.size; i++) { - GGML_ASSERT(fabs(cur_p.data[i].p - expected_probs[i]) < 1e-3); - } + tester.check(); } -static void test_xtc(const std::vector & probs, const std::vector & expected_probs, float p, float t) { - const size_t n_vocab = probs.size(); +static void test_tfs(const std::vector & probs, const std::vector & probs_expected, float z) { + sampler_tester tester(probs, probs_expected); - std::vector cur; - cur.reserve(n_vocab); - for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) { - const float logit = logf(probs[token_id]); - cur.emplace_back(llama_token_data{token_id, logit, 0.0f}); - } + DUMP(&tester.cur_p); + tester.apply(llama_sampler_init_tail_free(z, 1)); + DUMP(&tester.cur_p); - llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false }; - APPLY(llama_sampler_init_softmax(), &cur_p); - DUMP(&cur_p); - APPLY(llama_sampler_init_xtc(p, t, 0, 0), &cur_p); - DUMP(&cur_p); - - GGML_ASSERT(cur_p.size == expected_probs.size()); - for (size_t i = 0; i < cur_p.size; i++) { - GGML_ASSERT(fabs(cur_p.data[i].p - expected_probs[i]) < 1e-5); - } + tester.check(); } -static void test_typical(const std::vector & probs, const std::vector & expected_probs, float p) { - const size_t n_vocab = probs.size(); +static void test_min_p(const std::vector & probs, const std::vector & probs_expected, float p) { + sampler_tester tester(probs, probs_expected); - std::vector cur; - cur.reserve(n_vocab); - for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) { - const float logit = logf(probs[token_id]); - cur.emplace_back(llama_token_data{token_id, logit, 0.0f}); - } + DUMP(&tester.cur_p); + tester.apply(llama_sampler_init_min_p(p, 1)); + tester.apply(llama_sampler_init_dist (0)); + DUMP(&tester.cur_p); - llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false }; - DUMP(&cur_p); - APPLY(llama_sampler_init_typical(p, 1), &cur_p); - DUMP(&cur_p); + tester.check(); +} - GGML_ASSERT(cur_p.size == expected_probs.size()); - for (size_t i = 0; i < cur_p.size; i++) { - GGML_ASSERT(fabs(cur_p.data[i].p - expected_probs[i]) < 1e-3); - } +static void test_xtc(const std::vector & probs, const std::vector & probs_expected, float p, float t) { + sampler_tester tester(probs, probs_expected); + + DUMP(&tester.cur_p); + tester.apply(llama_sampler_init_xtc(p, t, 0, 0)); + DUMP(&tester.cur_p); + + tester.check(); +} + +static void test_typical(const std::vector & probs, const std::vector & probs_expected, float p) { + sampler_tester tester(probs, probs_expected); + + DUMP(&tester.cur_p); + tester.apply(llama_sampler_init_typical(p, 1)); + DUMP(&tester.cur_p); + + tester.check(); } static void test_penalties( const std::vector & probs, const std::vector & last_tokens, - const std::vector & expected_probs, float repeat_penalty, float alpha_frequency, float alpha_presence + const std::vector & probs_expected, float repeat_penalty, float alpha_frequency, float alpha_presence ) { - GGML_ASSERT(probs.size() == expected_probs.size()); + GGML_ASSERT(probs.size() == probs_expected.size()); + + sampler_tester tester(probs, probs_expected); const size_t n_vocab = probs.size(); + auto * sampler = llama_sampler_init_penalties(n_vocab, LLAMA_TOKEN_NULL, LLAMA_TOKEN_NULL, last_tokens.size(), repeat_penalty, alpha_frequency, alpha_presence, false, false); - std::vector cur; - cur.reserve(n_vocab); - for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) { - const float logit = logf(probs[token_id]); - cur.emplace_back(llama_token_data{token_id, logit, 0.0f}); + for (size_t i = 0; i < last_tokens.size(); i++) { + llama_sampler_accept(sampler, last_tokens[i]); } - llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false }; + DUMP(&tester.cur_p); + tester.apply(sampler); + tester.apply(llama_sampler_init_dist(0)); + DUMP(&tester.cur_p); - auto * sampler = llama_sampler_init_penalties(n_vocab, LLAMA_TOKEN_NULL, LLAMA_TOKEN_NULL, last_tokens.size(), repeat_penalty, alpha_frequency, alpha_presence, false, false); + tester.check(); +} + +static void test_dry( + const std::vector & probs, const std::vector & last_tokens, + const std::vector & expected_probs, float dry_multiplier, float dry_base, + int dry_allowed_length, int dry_penalty_last_n, + const std::vector> & seq_breakers +) { + GGML_ASSERT(probs.size() == expected_probs.size()); + + sampler_tester tester(probs, expected_probs); + + auto * sampler = llama_sampler_init_dry_testing(1024, dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n, seq_breakers); for (size_t i = 0; i < last_tokens.size(); i++) { llama_sampler_accept(sampler, last_tokens[i]); } - APPLY(llama_sampler_init_softmax(), &cur_p); - DUMP(&cur_p); - APPLY(sampler, &cur_p); - APPLY(llama_sampler_init_softmax(), &cur_p); - DUMP(&cur_p); - - GGML_ASSERT(cur_p.size == expected_probs.size()); - for (size_t i = 0; i < cur_p.size; i++) { - GGML_ASSERT(fabs(cur_p.data[i].p - expected_probs[i]) < 1e-3); - } + DUMP(&tester.cur_p); + tester.apply(sampler); + tester.apply(llama_sampler_init_dist(0)); + DUMP(&tester.cur_p); + tester.check(); } static void test_sampler_queue(const size_t n_vocab, const std::string & samplers_sequence, const int top_k, const float top_p, const float min_p ) { - std::vector cur; - cur.reserve(n_vocab); - for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) { - const float logit = logf(token_id); - cur.emplace_back(llama_token_data{token_id, logit, 0.0f}); - } - - llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false }; + sampler_tester tester(n_vocab); llama_token min_token_id = 0; const llama_token max_token_id = n_vocab-1; for (auto s : samplers_sequence) { switch (s){ - case 'k': APPLY(llama_sampler_init_top_k(top_k), &cur_p); break; + case 'k': tester.apply(llama_sampler_init_top_k(top_k)); break; case 'f': GGML_ABORT("tail_free test not implemented"); case 'y': GGML_ABORT("typical test not implemented"); - case 'p': APPLY(llama_sampler_init_top_p(top_p, 1), &cur_p); break; - case 'm': APPLY(llama_sampler_init_min_p(min_p, 1), &cur_p); break; + case 'p': tester.apply(llama_sampler_init_top_p(top_p, 1)); break; + case 'm': tester.apply(llama_sampler_init_min_p(min_p, 1)); break; case 't': GGML_ABORT("temperature test not implemented"); default : GGML_ABORT("Unknown sampler"); } - APPLY(llama_sampler_init_softmax(), &cur_p); // make sure tokens are sorted for tests + tester.apply(llama_sampler_init_dist(0)); + + auto & cur_p = tester.cur_p; const int size = cur_p.size; @@ -307,21 +305,26 @@ static void test_perf() { BENCH(llama_sampler_init_tail_free(0.5f, 1), data, 32); BENCH(llama_sampler_init_typical (0.5f, 1), data, 32); BENCH(llama_sampler_init_xtc (1.0f, 0.1f, 1, 1), data, 32); - BENCH(llama_sampler_init_softmax (), data, 32); } int main(void) { ggml_time_init(); - test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f}, 1); - test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f}, 3); + test_temp({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 1.0f); + test_temp({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f, 0.0f, 0.0f, 0.0f}, 0.0f); + + test_temp_ext({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 1.0f, 0.0f, 1.0f); + test_temp_ext({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f, 0.0f, 0.0f, 0.0f}, 0.0f, 0.0f, 1.0f); + + test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f}, 1); + test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.44444f, 0.33333f, 0.22222f}, 3); test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 4); test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 0); - test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f}, 0); - test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f}, 0.7f); - test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f}, 0.8f); - test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 1); + test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f}, 0); + test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.571429f, 0.428571f}, 0.7f); + test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.44444f, 0.33333f, 0.22222f}, 0.8f); + test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 1.0f); test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/1.0f, 0.3f/1.0f, 0.2f/1.0f, 0.1f/1.0f}, 0.00f); test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/1.0f, 0.3f/1.0f, 0.2f/1.0f, 0.1f/1.0f}, 0.24f); @@ -355,6 +358,13 @@ int main(void) { test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.499966f, 0.499966f, 0.000023f, 0.000023f, 0.000023f}, 1.0f, 5.0f, 5.0f); test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.499977f, 0.499977f, 0.000023f, 0.000023f, 0.000000f}, 1.0f, 5.0f, 5.0f); + + test_dry({0.25f, 0.25f, 0.25f, 0.25f}, {0, 1}, {0.25f, 0.25f, 0.25f, 0.25f}, 1.0f, 1.1f, 2, 4, {}); + test_dry({0.25f, 0.25f, 0.25f, 0.25f}, {0, 1, 2, 0, 1}, {0.296923f, 0.296923f, 0.296923f, 0.109232f}, 1.0f, 1.1f, 2, 5, {}); + test_dry({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 3, 4, 0, 1}, {0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, 1.0f, 1.1f, 2, 6, {{3}}); + test_dry({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 1}, {0.241818f, 0.241818f, 0.241818f, 0.241818f, 0.032727f}, 2.0f, 1.1f, 2, 5, {}); + test_dry({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 3, 4, 0, 1}, {0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, 1.0f, 1.1f, 4, 7, {}); + test_sampler_queue(10000, "k", 10000, 1.0f, 1.0f); test_sampler_queue(10000, "k", 1, 1.0f, 1.0f); test_sampler_queue(10000, "p", 10000, 1.0f, 1.0f);