diff --git a/train_gpt2.cu b/train_gpt2.cu index 074708d0a..913a65393 100644 --- a/train_gpt2.cu +++ b/train_gpt2.cu @@ -198,6 +198,7 @@ __device__ __forceinline__ void stochastic_rounding(float in, float *out, unsign // ---------------------------------------------------------------------------- // fread convenience utils, with nice handling of error checking using macros // simple replace fopen, fread, fclose with fopenCheck, freadCheck, fcloseCheck + FILE *fopen_check(const char *path, const char *mode, const char *file, int line) { FILE *fp = fopen(path, mode); if (fp == NULL) { @@ -211,6 +212,7 @@ FILE *fopen_check(const char *path, const char *mode, const char *file, int line } return fp; } + #define fopenCheck(path, mode) fopen_check(path, mode, __FILE__, __LINE__) void fread_check(void *ptr, size_t size, size_t nmemb, FILE *stream, const char *file, int line) { @@ -232,6 +234,7 @@ void fread_check(void *ptr, size_t size, size_t nmemb, FILE *stream, const char exit(EXIT_FAILURE); } } + #define freadCheck(ptr, size, nmemb, stream) fread_check(ptr, size, nmemb, stream, __FILE__, __LINE__) void fclose_check(FILE *fp, const char *file, int line) { @@ -243,6 +246,7 @@ void fclose_check(FILE *fp, const char *file, int line) { exit(EXIT_FAILURE); } } + #define fcloseCheck(fp) fclose_check(fp, __FILE__, __LINE__) // ---------------------------------------------------------------------------- @@ -260,6 +264,7 @@ void *malloc_check(size_t size, const char *file, int line) { } return ptr; } + #define mallocCheck(size) malloc_check(size, __FILE__, __LINE__) // ---------------------------------------------------------------------------- @@ -2055,6 +2060,7 @@ int sample_softmax(const float* logits, int n, float coin) { // ---------------------------------------------------------------------------- // Tokenizer (only supports decoding: tokens (integers) -> strings) + typedef struct { uint32_t vocab_size; char **token_table; @@ -2134,6 +2140,7 @@ void tokenizer_free(Tokenizer *tokenizer) { // ---------------------------------------------------------------------------- // Logger lite, will probably grow/change some over time + typedef struct { FILE *logfile; int flush_every; // every how many steps to flush the log @@ -2164,6 +2171,7 @@ void logger_free(Logger *logger) { // ---------------------------------------------------------------------------- // CLI, poor man's argparse + void error_usage() { // default run = debugging run with TinyShakespeare // bigger run = train on TinyStories! e.g. val/sample less often, but sample more tokens, write to logfile