WIP YAML logging
JohannesGaessler committed Aug 27, 2023
1 parent 230d46c commit 6094648
Showing 8 changed files with 599 additions and 39 deletions.
288 changes: 279 additions & 9 deletions common/common.cpp

Large diffs are not rendered by default.

12 changes: 12 additions & 0 deletions common/common.h
@@ -61,6 +61,7 @@ struct gpt_params {
std::string input_suffix = ""; // string to suffix user inputs with
std::string grammar = ""; // optional BNF-like grammar to constrain sampling
std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted
std::string logdir = ""; // directory in which to save YAML log files

std::string lora_adapter = ""; // lora adapter path
std::string lora_base = ""; // base model path for the lora adapter
@@ -82,6 +83,7 @@ struct gpt_params {
bool prompt_cache_ro = false; // open the prompt cache read-only and do not update it

bool embedding = false; // get only sentence embedding
bool escape = false; // escape "\n", "\r", "\t", "\'", "\"", and "\\"
bool interactive_first = false; // wait for user input immediately
bool multiline_input = false; // reverse the usage of `\`
bool simple_io = false; // improves compatibility with subprocesses and limited consoles
@@ -144,3 +146,13 @@ std::string llama_detokenize_spm(
std::string llama_detokenize_bpe(
llama_context * ctx,
const std::vector<llama_token> & tokens);

bool create_directory_with_parents(const std::string & path);
void dump_vector_float_yaml(FILE * stream, const char * prop_name, const std::vector<float> & data);
void dump_vector_int_yaml(FILE * stream, const char * prop_name, const std::vector<int> & data);
void dump_string_yaml_multiline(FILE * stream, const char * prop_name, const char * data, bool remove_first);
std::string get_sortable_timestamp();

void dump_non_result_info_yaml(
FILE * stream, const gpt_params & params, const llama_context * lctx,
const std::string & timestamp, const std::vector<int> & prompt_tokens, const char * model);
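
The implementations of these helpers live in common.cpp (the large diff above is not rendered). Below is a minimal usage sketch based only on the signatures declared here; it assumes, as the calls in main.cpp and perplexity.cpp further down do, that logdir already ends with a path separator. It is an illustration, not part of the commit.

#include <cstdio>
#include <string>
#include <vector>

#include "common.h"

// Hypothetical example (not part of the commit): write a tiny YAML log file
// using the helpers declared above.
static bool write_example_yaml_log(const std::string & logdir) {
    if (!create_directory_with_parents(logdir)) {
        return false; // could not create the target directory
    }
    const std::string timestamp = get_sortable_timestamp();
    FILE * logfile = fopen((logdir + timestamp + ".yml").c_str(), "w");
    if (logfile == NULL) {
        return false;
    }
    fprintf(logfile, "binary: example\n");
    const std::vector<int>   tokens = {1, 15043, 3186};      // made-up token ids
    const std::vector<float> probs  = {0.12f, 0.56f, 0.91f}; // made-up probabilities
    dump_vector_int_yaml(logfile, "tokens", tokens);
    dump_vector_float_yaml(logfile, "probs", probs);
    dump_string_yaml_multiline(logfile, "prompt", "Hello\nworld", false);
    fclose(logfile);
    return true;
}
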
44 changes: 43 additions & 1 deletion examples/main/main.cpp
@@ -17,6 +17,7 @@
#include <ctime>
#include <fstream>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

@@ -397,6 +398,10 @@ int main(int argc, char ** argv) {
int n_session_consumed = 0;
int n_past_guidance = 0;

std::vector<int> input_tokens;
std::vector<int> output_tokens;
std::ostringstream output_ss;

// the first thing we will do is to output the prompt, so set color accordingly
console::set_display(console::prompt);

@@ -667,7 +672,15 @@ int main(int argc, char ** argv) {
// display text
if (input_echo) {
for (auto id : embd) {
printf("%s", llama_token_to_piece(ctx, id).c_str());
const std::string token_str = llama_token_to_piece(ctx, id);
printf("%s", token_str.c_str());

if (embd.size() > 1) {
input_tokens.push_back(id);
} else {
output_tokens.push_back(id);
output_ss << token_str;
}
}
fflush(stdout);
}
@@ -817,6 +830,35 @@ int main(int argc, char ** argv) {
}

llama_print_timings(ctx);

if (!params.logdir.empty()) {
const std::string timestamp = get_sortable_timestamp();

const bool success = create_directory_with_parents(params.logdir);
if (success) {

FILE * logfile = fopen((params.logdir + timestamp + ".yml").c_str(), "w");
fprintf(logfile, "binary: main\n");
char model_type[128];
llama_model_desc(model, model_type, sizeof(model_type));
dump_non_result_info_yaml(logfile, params, ctx, timestamp, input_tokens, model_type);

fprintf(logfile, "\n");
fprintf(logfile, "######################\n");
fprintf(logfile, "# Generation Results #\n");
fprintf(logfile, "######################\n");
fprintf(logfile, "\n");

dump_string_yaml_multiline(logfile, "output", output_ss.str().c_str(), false);
dump_vector_int_yaml(logfile, "output_tokens", output_tokens);

llama_dump_timing_info_yaml(logfile, ctx);
fclose(logfile);
} else {
fprintf(stderr, "%s: warning: failed to create logdir %s, cannot write logfile\n",
__func__, params.logdir.c_str());
}
}
if (ctx_guidance) { llama_free(ctx_guidance); }
llama_free(ctx);
llama_free_model(model);
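
One detail worth noting in the logging block above: the FILE * returned by fopen is written to without a NULL check, so an unwritable logdir would be undefined behavior even though the directory-creation failure itself is handled. A minimal defensive sketch (hypothetical helper, not part of the commit) that such a block could call instead:

#include <cstdio>
#include <string>

// Open logdir + timestamp + ".yml" for writing, warning instead of crashing on failure.
static FILE * open_yaml_logfile(const std::string & logdir, const std::string & timestamp) {
    const std::string path = logdir + timestamp + ".yml";
    FILE * logfile = fopen(path.c_str(), "w");
    if (logfile == NULL) {
        fprintf(stderr, "warning: failed to open %s, cannot write logfile\n", path.c_str());
    }
    return logfile;
}
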
128 changes: 100 additions & 28 deletions examples/perplexity/perplexity.cpp
@@ -3,11 +3,15 @@
#include "build-info.h"

#include <cmath>
#include <cstdio>
#include <cstring>
#include <ctime>
#include <sstream>
#include <cstring>
#include <thread>
#include <mutex>
#include <tuple>
#include <utility>
#include <vector>

#if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data
@@ -29,20 +33,20 @@ std::vector<float> softmax(const std::vector<float>& logits) {
return probs;
}

float log_softmax(int n_vocab, const float * logits, int tok) {
std::tuple<double, float, float> log_softmax(int n_vocab, const float * logits, int tok) {
float max_logit = logits[0];
for (int i = 1; i < n_vocab; ++i) max_logit = std::max(max_logit, logits[i]);
double sum_exp = 0.0;
for (int i = 0; i < n_vocab; ++i) sum_exp += expf(logits[i] - max_logit);
return logits[tok] - max_logit - log(sum_exp);
return std::make_tuple(-(logits[tok] - max_logit - log(sum_exp)), logits[tok], expf(logits[tok] - max_logit) / sum_exp);
}

void process_logits(int n_vocab, const float * logits, const int * tokens, int n_token, std::vector<std::thread>& workers,
double& nll, double& nll2) {
void process_logits(int n_vocab, const float * logits, const int * tokens, int n_token, std::vector<std::thread> & workers,
double & nll, double & nll2, float * logit_history, float * prob_history) {

std::mutex mutex;
int counter = 0;
auto compute = [&mutex, &counter, &nll, &nll2, n_vocab, logits, tokens, n_token] () {
auto compute = [&mutex, &counter, &nll, &nll2, logit_history, prob_history, n_vocab, logits, tokens, n_token] () {
double local_nll = 0, local_nll2 = 0;
while (true) {
std::unique_lock<std::mutex> lock(mutex);
@@ -52,34 +56,44 @@ void process_logits(int n_vocab, const float * logits, const int * tokens, int n
break;
}
lock.unlock();
double v = -log_softmax(n_vocab, logits + i*n_vocab, tokens[i+1]);
local_nll += v;
local_nll2 += v*v;
const std::tuple<double, float, float> v = log_softmax(n_vocab, logits + i*n_vocab, tokens[i+1]);
const double v0 = std::get<0>(v);
local_nll += v0;
local_nll2 += v0*v0;

logit_history[i] = std::get<1>(v);
prob_history[i] = std::get<2>(v);
}
};
for (auto& w : workers) w = std::thread(compute);
for (auto & w : workers) w = std::thread(compute);
compute();
for (auto& w : workers) w.join();
for (auto & w : workers) w.join();

}

void perplexity_v2(llama_context * ctx, const gpt_params & params) {
std::tuple<std::vector<llama_token>, std::vector<float>, std::vector<float>, float>
perplexity_v2(llama_context * ctx, const gpt_params & params) {
// Download: https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip?ref=salesforce-research
// Run `./perplexity -m models/7B/ggml-model-q4_0.bin -f wiki.test.raw`
// Output: `perplexity: 13.5106 [114/114]`
// BOS tokens will be added for each chunk before eval

if (params.ppl_stride <= 0) {
fprintf(stderr, "%s: stride is %d but must be greater than zero!\n",__func__,params.ppl_stride);
return;
}

const bool is_spm = llama_vocab_type(ctx) == LLAMA_VOCAB_TYPE_SPM;
const bool add_bos = is_spm;

fprintf(stderr, "%s: tokenizing the input ..\n", __func__);

auto tokens = ::llama_tokenize(ctx, params.prompt, add_bos);
std::vector<llama_token> tokens = ::llama_tokenize(ctx, params.prompt, add_bos);
std::vector<float> logit_history;
std::vector<float> prob_history;

logit_history.resize(tokens.size());
prob_history.resize(tokens.size());

if (params.ppl_stride <= 0) {
fprintf(stderr, "%s: stride is %d but must be greater than zero!\n",__func__,params.ppl_stride);
return std::make_tuple(tokens, logit_history, prob_history, -1);
}

const int calc_chunk = params.n_ctx;

@@ -88,7 +102,7 @@ void perplexity_v2(llama_context * ctx, const gpt_params & params) {
if (int(tokens.size()) <= calc_chunk) {
fprintf(stderr, "%s: there are only %zu tokens, this is not enough for a context size of %d and stride %d\n",__func__,
tokens.size(), params.n_ctx, params.ppl_stride);
return;
return std::make_tuple(tokens, logit_history, prob_history, -1);
}

const int n_chunk_max = (tokens.size() - calc_chunk + params.ppl_stride - 1) / params.ppl_stride;
@@ -120,7 +134,7 @@ void perplexity_v2(llama_context * ctx, const gpt_params & params) {
//fprintf(stderr, " Batch %d: starts at %d, size is %d, n_past is %d\n",j,batch_start,batch_size,j * n_batch);
if (llama_eval(ctx, tokens.data() + batch_start, batch_size, j * n_batch, params.n_threads)) {
//fprintf(stderr, "%s : failed to eval\n", __func__);
return;
return std::make_tuple(tokens, logit_history, prob_history, -1);
}

// save original token and restore it after eval
@@ -161,6 +175,8 @@ void perplexity_v2(llama_context * ctx, const gpt_params & params) {
logits.begin() + (j + 1) * n_vocab);

const float prob = softmax(tok_logits)[tokens[start + j + 1]];
logit_history[start + j + 1] = tok_logits[tokens[start + j + 1]];
prob_history[start + j + 1] = prob;

nll += -std::log(prob);
++count;
@@ -174,12 +190,15 @@ void perplexity_v2(llama_context * ctx, const gpt_params & params) {
fflush(stdout);
}
printf("\n");

return std::make_tuple(tokens, logit_history, prob_history, std::exp(nll / count));
}

void perplexity(llama_context * ctx, const gpt_params & params) {
std::tuple<std::vector<llama_token>, std::vector<float>, std::vector<float>, float>
perplexity(llama_context * ctx, const gpt_params & params) {

if (params.ppl_stride > 0) {
perplexity_v2(ctx, params);
return;
return perplexity_v2(ctx, params);
}

// Download: https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip?ref=salesforce-research
@@ -193,11 +212,17 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
auto tim1 = std::chrono::high_resolution_clock::now();
fprintf(stderr, "%s: tokenizing the input ..\n", __func__);

auto tokens = ::llama_tokenize(ctx, params.prompt, add_bos);
std::vector<llama_token> tokens = ::llama_tokenize(ctx, params.prompt, add_bos);

auto tim2 = std::chrono::high_resolution_clock::now();
fprintf(stderr, "%s: tokenization took %g ms\n",__func__,1e-3*std::chrono::duration_cast<std::chrono::microseconds>(tim2-tim1).count());

std::vector<float> logit_history;
logit_history.resize(tokens.size());

std::vector<float> prob_history;
prob_history.resize(tokens.size());

const int n_chunk_max = tokens.size() / params.n_ctx;

const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max);
@@ -236,7 +261,7 @@ void perplexity(llama_context * ctx, const gpt_params & params) {

if (llama_eval(ctx, tokens.data() + batch_start, batch_size, j * n_batch, params.n_threads)) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return;
return std::make_tuple(tokens, logit_history, prob_history, -1);
}

// restore the original token in case it was set to BOS
@@ -272,7 +297,8 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
// last 256 tokens. Then, we split the input up into context window size chunks to
// process the entire prompt.
const int first = std::min(512, params.n_ctx/2);
process_logits(n_vocab, logits.data() + first*n_vocab, tokens.data() + start + first, params.n_ctx - 1 - first, workers, nll, nll2);
process_logits(n_vocab, logits.data() + first*n_vocab, tokens.data() + start + first, params.n_ctx - 1 - first,
workers, nll, nll2, logit_history.data() + start + first, prob_history.data() + start + first);
count += params.n_ctx - first - 1;

// perplexity is e^(average negative log-likelihood)
@@ -287,16 +313,19 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
fflush(stdout);
}
printf("\n");

nll2 /= count;
nll /= count;
const double ppl = exp(nll);
nll2 -= nll * nll;
if (nll2 > 0) {
nll2 = sqrt(nll2/(count-1));
double ppl = exp(nll);
printf("Final estimate: PPL = %.4lf +/- %.5lf\n", ppl, nll2*ppl);
} else {
printf("Unexpected negative standard deviation of log(prob)\n");
}

return std::make_tuple(tokens, logit_history, prob_history, ppl);
}

std::vector<float> hellaswag_evaluate_tokens(llama_context * ctx, const std::vector<int>& tokens, int n_past, int n_batch,
@@ -604,13 +633,56 @@ int main(int argc, char ** argv) {
params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
}

std::vector<llama_token> tokens;
std::vector<float> logits;
std::vector<float> probs;
double perplexity_value = -1;
if (params.hellaswag) {
hellaswag_score(ctx, params);
} else {
perplexity(ctx, params);
auto ret = perplexity(ctx, params);
tokens = std::get<0>(ret);
logits = std::get<1>(ret);
probs = std::get<2>(ret);
perplexity_value = std::get<3>(ret);
}

llama_print_timings(ctx);

if (params.hellaswag && !params.logdir.empty()) {
fprintf(stderr, "%s: warning: logging results is not implemented for HellaSwag. No files will be written.\n", __func__);
}

if (!params.hellaswag && !params.logdir.empty()) {
const std::string timestamp = get_sortable_timestamp();

const bool success = create_directory_with_parents(params.logdir);
if (success) {

FILE * logfile = fopen((params.logdir + timestamp + ".yml").c_str(), "w");
fprintf(logfile, "binary: perplexity\n");
char model_type[128];
llama_model_desc(model, model_type, sizeof(model_type));
dump_non_result_info_yaml(logfile, params, ctx, timestamp, tokens, model_type);

fprintf(logfile, "\n");
fprintf(logfile, "######################\n");
fprintf(logfile, "# Perplexity Results #\n");
fprintf(logfile, "######################\n");
fprintf(logfile, "\n");

dump_vector_float_yaml(logfile, "logits", logits);
fprintf(logfile, "ppl_value: %f\n", perplexity_value);
dump_vector_float_yaml(logfile, "probs", probs);

llama_dump_timing_info_yaml(logfile, ctx);
fclose(logfile);
} else {
fprintf(stderr, "%s: warning: failed to create logdir %s, cannot write logfile\n",
__func__, params.logdir.c_str());
}
}

llama_free(ctx);
llama_free_model(model);

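
For reference, the nll / nll2 bookkeeping in perplexity() above amounts to: perplexity = exp(mean per-token negative log-likelihood), with the reported +/- being that perplexity times the standard error of the mean (the error propagated through the exponential). A self-contained sketch of the same arithmetic, written as a hypothetical helper that is not part of the commit:

#include <cmath>
#include <cstdio>
#include <vector>

// Given per-token negative log-likelihoods, print the perplexity estimate and
// its uncertainty the same way the running nll / nll2 accumulation above does.
static void print_ppl_estimate(const std::vector<double> & token_nll) {
    const size_t n = token_nll.size();
    if (n < 2) {
        return;
    }
    double mean = 0.0, mean_sq = 0.0;
    for (const double v : token_nll) {
        mean    += v;
        mean_sq += v * v;
    }
    mean    /= n;
    mean_sq /= n;

    const double ppl = exp(mean);
    const double var = mean_sq - mean * mean;   // variance of the per-token nll
    if (var > 0.0) {
        const double sem = sqrt(var / (n - 1)); // standard error of the mean nll
        printf("Final estimate: PPL = %.4lf +/- %.5lf\n", ppl, sem * ppl);
    } else {
        printf("Unexpected negative standard deviation of log(prob)\n");
    }
}
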
2 changes: 1 addition & 1 deletion examples/server/server.cpp
@@ -719,7 +719,7 @@ static void server_print_usage(const char *argv0, const gpt_params &params,
fprintf(stdout, " -ts SPLIT --tensor-split SPLIT\n");
fprintf(stdout, " how to split tensors across multiple GPUs, comma-separated list of proportions, e.g. 3,1\n");
fprintf(stdout, " -mg i, --main-gpu i the GPU to use for scratch and small tensors\n");
fprintf(stdout, " -lv, --low-vram don't allocate VRAM scratch buffer\n");
fprintf(stdout, " -lv, --low-vram don't allocate VRAM scratch buffer\n");
fprintf(stdout, " -nommq, --no-mul-mat-q\n");
fprintf(stdout, " use cuBLAS instead of custom mul_mat_q CUDA kernels.\n");
fprintf(stdout, " Not recommended since this is both slower and uses more VRAM.\n");
(diffs for the remaining changed files are not rendered)
