Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

YAML logging and presets #2657

Merged
merged 1 commit into from
Aug 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
323 changes: 314 additions & 9 deletions common/common.cpp

Large diffs are not rendered by default.

18 changes: 18 additions & 0 deletions common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,12 @@
#include <unordered_map>
#include <tuple>

#ifdef _WIN32
#define DIRECTORY_SEPARATOR '\\'
#else
#define DIRECTORY_SEPARATOR '/'
#endif // _WIN32

//
// CLI argument parsing
//
Expand Down Expand Up @@ -61,6 +67,7 @@ struct gpt_params {
std::string input_suffix = ""; // string to suffix user inputs with
std::string grammar = ""; // optional BNF-like grammar to constrain sampling
std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted
std::string logdir = ""; // directory in which to save YAML log files

std::string lora_adapter = ""; // lora adapter path
std::string lora_base = ""; // base model path for the lora adapter
Expand All @@ -82,6 +89,7 @@ struct gpt_params {
bool prompt_cache_ro = false; // open the prompt cache read-only and do not update it

bool embedding = false; // get only sentence embedding
bool escape = false; // escape "\n", "\r", "\t", "\'", "\"", and "\\"
bool interactive_first = false; // wait for user input immediately
bool multiline_input = false; // reverse the usage of `\`
bool simple_io = false; // improves compatibility with subprocesses and limited consoles
Expand Down Expand Up @@ -144,3 +152,13 @@ std::string llama_detokenize_spm(
std::string llama_detokenize_bpe(
llama_context * ctx,
const std::vector<llama_token> & tokens);

bool create_directory_with_parents(const std::string & path);
void dump_vector_float_yaml(FILE * stream, const char * prop_name, const std::vector<float> & data);
void dump_vector_int_yaml(FILE * stream, const char * prop_name, const std::vector<int> & data);
void dump_string_yaml_multiline(FILE * stream, const char * prop_name, const char * data);
std::string get_sortable_timestamp();

void dump_non_result_info_yaml(
FILE * stream, const gpt_params & params, const llama_context * lctx,
const std::string & timestamp, const std::vector<int> & prompt_tokens, const char * model_desc);
78 changes: 76 additions & 2 deletions examples/main/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include <ctime>
#include <fstream>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

Expand All @@ -36,9 +37,57 @@
#pragma warning(disable: 4244 4267) // possible loss of data
#endif

static llama_context ** g_ctx;
static llama_context ** g_ctx;
static llama_model ** g_model;
static gpt_params * g_params;
static std::vector<llama_token> * g_input_tokens;
static std::ostringstream * g_output_ss;
static std::vector<llama_token> * g_output_tokens;
static bool is_interacting = false;

void write_logfile(
const llama_context * ctx, const gpt_params & params, const llama_model * model,
const std::vector<llama_token> input_tokens, const std::string output, const std::vector<llama_token> output_tokens) {

if (params.logdir.empty()) {
return;
}

const std::string timestamp = get_sortable_timestamp();

const bool success = create_directory_with_parents(params.logdir);
if (!success) {
fprintf(stderr, "%s: warning: failed to create logdir %s, cannot write logfile\n",
__func__, params.logdir.c_str());
return;
}

const std::string logfile_path = params.logdir + timestamp + ".yml";
FILE * logfile = fopen(logfile_path.c_str(), "w");

if (logfile == NULL) {
fprintf(stderr, "%s: failed to open logfile %s\n", __func__, logfile_path.c_str());
return;
}

fprintf(logfile, "binary: main\n");
char model_desc[128];
llama_model_desc(model, model_desc, sizeof(model_desc));
dump_non_result_info_yaml(logfile, params, ctx, timestamp, input_tokens, model_desc);

fprintf(logfile, "\n");
fprintf(logfile, "######################\n");
fprintf(logfile, "# Generation Results #\n");
fprintf(logfile, "######################\n");
fprintf(logfile, "\n");

dump_string_yaml_multiline(logfile, "output", output.c_str());
dump_vector_int_yaml(logfile, "output_tokens", output_tokens);

llama_dump_timing_info_yaml(logfile, ctx);
fclose(logfile);
}

#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
void sigint_handler(int signo) {
if (signo == SIGINT) {
Expand All @@ -48,6 +97,7 @@ void sigint_handler(int signo) {
console::cleanup();
printf("\n");
llama_print_timings(*g_ctx);
write_logfile(*g_ctx, *g_params, *g_model, *g_input_tokens, g_output_ss->str(), *g_output_tokens);
_exit(130);
}
}
Expand All @@ -56,6 +106,7 @@ void sigint_handler(int signo) {

int main(int argc, char ** argv) {
gpt_params params;
g_params = &params;

if (gpt_params_parse(argc, argv, params) == false) {
return 1;
Expand Down Expand Up @@ -116,6 +167,7 @@ int main(int argc, char ** argv) {
llama_model * model;
llama_context * ctx;
llama_context * ctx_guidance = NULL;
g_model = &model;
g_ctx = &ctx;

// load the model and apply lora adapter, if any
Expand Down Expand Up @@ -397,6 +449,10 @@ int main(int argc, char ** argv) {
int n_session_consumed = 0;
int n_past_guidance = 0;

std::vector<int> input_tokens; g_input_tokens = &input_tokens;
std::vector<int> output_tokens; g_output_tokens = &output_tokens;
std::ostringstream output_ss; g_output_ss = &output_ss;

// the first thing we will do is to output the prompt, so set color accordingly
console::set_display(console::prompt);

Expand Down Expand Up @@ -667,7 +723,15 @@ int main(int argc, char ** argv) {
// display text
if (input_echo) {
for (auto id : embd) {
printf("%s", llama_token_to_piece(ctx, id).c_str());
const std::string token_str = llama_token_to_piece(ctx, id);
printf("%s", token_str.c_str());

if (embd.size() > 1) {
input_tokens.push_back(id);
} else {
output_tokens.push_back(id);
output_ss << token_str;
}
}
fflush(stdout);
}
Expand Down Expand Up @@ -761,6 +825,8 @@ int main(int argc, char ** argv) {
printf("%s", params.input_suffix.c_str());
}

const size_t original_size = embd_inp.size();

// instruct mode: insert instruction prefix
if (params.instruct && !is_antiprompt) {
n_consumed = embd_inp.size();
Expand All @@ -775,6 +841,12 @@ int main(int argc, char ** argv) {
embd_inp.insert(embd_inp.end(), inp_sfx.begin(), inp_sfx.end());
}

for (size_t i = original_size; i < embd_inp.size(); ++i) {
const llama_token token = embd_inp[i];
output_tokens.push_back(token);
output_ss << llama_token_to_piece(ctx, token);
}

n_remain -= line_inp.size();
}

Expand Down Expand Up @@ -817,6 +889,8 @@ int main(int argc, char ** argv) {
}

llama_print_timings(ctx);
write_logfile(ctx, params, model, input_tokens, output_ss.str(), output_tokens);

if (ctx_guidance) { llama_free(ctx_guidance); }
llama_free(ctx);
llama_free_model(model);
Expand Down
Loading
Loading