Faster repetition penalty sampling (#199)
li-plus authored Nov 24, 2023
1 parent b071907 commit e09726e
Showing 3 changed files with 35 additions and 12 deletions.
chatglm.cpp (21 changes: 10 additions & 11 deletions)
@@ -14,7 +14,6 @@
 #include <string>
 #include <sys/stat.h>
 #include <thread>
-#include <unordered_set>
 
 #ifdef __has_include
 #if __has_include(<unistd.h>)
@@ -589,19 +588,19 @@ int BaseModelForCausalLM::generate_next_token(const std::vector<int> &input_ids,
 void BaseModelForCausalLM::sampling_repetition_penalty(float *first, float *last, const std::vector<int> &input_ids,
                                                        float penalty) {
     CHATGLM_CHECK(penalty > 0) << "penalty must be a positive float, but got " << penalty;
-    std::unordered_set<int> unique_input_ids(input_ids.begin(), input_ids.end());
-    for (int id : unique_input_ids) {
-        CHATGLM_CHECK(first <= first + id && first + id < last) << "invalid input id " << id;
-        if (first[id] > 0) {
-            first[id] /= penalty;
-        } else {
-            first[id] *= penalty;
+    const float inv_penalty = 1.f / penalty;
+    const int vocab_size = last - first;
+    std::vector<bool> occurrence(vocab_size, false);
+    for (const int id : input_ids) {
+        if (!occurrence[id]) {
+            first[id] *= (first[id] > 0) ? inv_penalty : penalty;
         }
+        occurrence[id] = true;
     }
 }
 
 void BaseModelForCausalLM::sampling_temperature(float *first, float *last, float temp) {
-    float inv_temp = 1.f / temp;
+    const float inv_temp = 1.f / temp;
     for (float *it = first; it != last; it++) {
         *it *= inv_temp;
     }
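
For context: the rewrite above replaces an std::unordered_set of seen token ids with a dense std::vector<bool> occurrence table indexed by token id, so the penalty is applied exactly once per distinct token with no hashing or node allocation on the hot path. Note it also drops the old per-id bounds check, so every id in input_ids is assumed to be a valid index into the logits. A minimal self-contained sketch of the same technique (function name and sample values are illustrative, not part of this commit):

#include <cstdio>
#include <vector>

// Sketch: divide positive logits by `penalty` and multiply negative ones,
// but only on the first occurrence of each token id in the context.
static void repetition_penalty_sketch(std::vector<float> &logits, const std::vector<int> &input_ids, float penalty) {
    const float inv_penalty = 1.f / penalty;
    std::vector<bool> occurrence(logits.size(), false);
    for (const int id : input_ids) {
        if (!occurrence[id]) {
            logits[id] *= (logits[id] > 0) ? inv_penalty : penalty;
        }
        occurrence[id] = true;
    }
}

int main() {
    std::vector<float> logits = {2.4f, -1.2f, 0.8f, 3.0f};
    repetition_penalty_sketch(logits, {0, 1, 1, 0}, 1.2f);
    // Expected: token 0 -> 2.4 / 1.2 = 2.00, token 1 -> -1.2 * 1.2 = -1.44, tokens 2 and 3 unchanged.
    for (const float v : logits) {
        std::printf("%.2f ", v);
    }
    std::printf("\n");
}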
@@ -616,12 +615,12 @@ TokenIdScore *BaseModelForCausalLM::sampling_top_p(TokenIdScore *first, TokenIdS
     sampling_softmax_inplace(first, last);
 
     while (first + 1 < last) {
-        float pivot_score = (last - 1)->score; // use mid score?
+        const float pivot_score = (last - 1)->score; // use mid score?
         TokenIdScore *mid =
             std::partition(first, last - 1, [pivot_score](const TokenIdScore &x) { return x.score > pivot_score; });
         std::swap(*mid, *(last - 1));
 
-        float prefix_sum =
+        const float prefix_sum =
             std::accumulate(first, mid, 0.f, [](float sum, const TokenIdScore &x) { return sum + x.score; });
         if (prefix_sum >= top_p) {
             last = mid;
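The sampling_top_p hunk above only adds const qualifiers, but the loop it touches implements a quickselect-style nucleus (top-p) cutoff: partition the scores around a pivot, then either keep the high-probability prefix or recurse into the tail, depending on how much probability mass the prefix covers. The rest of the loop is truncated in this view; purely to illustrate the idea, here is a self-contained sketch under assumed types (IdScore and the branch details are illustrative, not the commit's hidden lines):

#include <algorithm>
#include <numeric>

struct IdScore {
    int id;
    float score; // assumed to already be a softmax probability
};

// Sketch: shrink [first, last) to the smallest prefix whose probability
// mass reaches top_p; returns one past the cutoff element.
static IdScore *top_p_cutoff_sketch(IdScore *first, IdScore *last, float top_p) {
    while (first + 1 < last) {
        const float pivot_score = (last - 1)->score;
        IdScore *mid =
            std::partition(first, last - 1, [pivot_score](const IdScore &x) { return x.score > pivot_score; });
        std::swap(*mid, *(last - 1));
        const float prefix_sum =
            std::accumulate(first, mid, 0.f, [](float sum, const IdScore &x) { return sum + x.score; });
        if (prefix_sum >= top_p) {
            last = mid; // the prefix alone covers top_p: discard the tail
        } else if (prefix_sum + mid->score < top_p) {
            top_p -= prefix_sum + mid->score; // keep prefix and pivot, search the tail for the rest
            first = mid + 1;
        } else {
            return mid + 1; // the pivot is exactly the cutoff token
        }
    }
    return last;
}

The expected cost is linear in the vocabulary size, which is why this beats fully sorting the distribution.
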
chatglm_cpp/__init__.py (2 changes: 1 addition & 1 deletion)
@@ -6,7 +6,7 @@
 import chatglm_cpp._C as _C
 from chatglm_cpp._C import ChatMessage
 
-__version__ = "0.3.0"
+__version__ = "0.3.1.dev"
 
 
 @dataclass
chatglm_test.cpp (24 changes: 24 additions & 0 deletions)
@@ -36,6 +36,8 @@ static inline char *read_tensor_data(char *ptr, ggml_tensor *tensor) {

 static inline float random() { return rand() / (float)RAND_MAX; }
 
+static inline float random(float lo, float hi) { return lo + random() * (hi - lo); }
+
 static inline void random_fill(ggml_tensor *tensor) {
     std::vector<float> values(ggml_nelements(tensor));
     for (float &v : values) {
@@ -115,6 +117,28 @@ TEST(Sampling, RepetitionPenalty) {
     }
 }
 
+TEST(DISABLED_Sampling, BenchmarkRepetitionPenalty) {
+    const float penalty = 1.2;
+    constexpr size_t vocab_size = 128000;
+    constexpr int seq_len = 32000;
+    std::vector<float> logits(vocab_size);
+    for (auto &x : logits) {
+        x = random(-1, 1);
+    }
+    std::vector<int> input_ids(seq_len);
+    for (size_t i = 0; i < input_ids.size(); i++) {
+        input_ids[i] = i;
+    }
+
+    auto fn = [&logits, &input_ids, penalty] {
+        BaseModelForCausalLM::sampling_repetition_penalty(logits.data(), logits.data() + logits.size(), input_ids,
+                                                          penalty);
+    };
+    auto elapsed_ms = timeit(fn, 2, 100);
+    std::cout << "[" << ::testing::UnitTest::GetInstance()->current_test_info()->name() << "] " << elapsed_ms
+              << " ms\n";
+}
+
 TEST(Sampling, Temperature) {
     constexpr float temp = 0.7;
     std::vector<float> logits(64);
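The benchmark relies on a timeit(fn, warmup, repeat) helper defined elsewhere in chatglm_test.cpp and not shown in this diff. A plausible shape for it, stated only as an assumption inferred from the call site (returning average milliseconds per run):

#include <chrono>

// Assumed sketch of timeit: run `fn` a few warmup rounds, then report the
// mean wall-clock time in milliseconds over `repeat` timed runs.
template <typename Fn>
static double timeit(Fn fn, int warmup, int repeat) {
    for (int i = 0; i < warmup; i++) {
        fn();
    }
    const auto start = std::chrono::high_resolution_clock::now();
    for (int i = 0; i < repeat; i++) {
        fn();
    }
    const auto end = std::chrono::high_resolution_clock::now();
    return std::chrono::duration<double, std::milli>(end - start).count() / repeat;
}

Because the test name carries googletest's DISABLED_ prefix, it is skipped in normal runs; pass --gtest_also_run_disabled_tests to execute it.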
