feat: add bench method & update example #269

Merged: 25 commits, Nov 8, 2024

Commits (25)
27da9c9  feat(example): bump to rn 0.74 & add react-navigation (jhen0409, Nov 6, 2024)
a210e88  feat(example): wip Bench container (jhen0409, Nov 6, 2024)
492a6dc  feat(example): download models (jhen0409, Nov 7, 2024)
2b7b8e7  feat(ios, cpp): add bench method (jhen0409, Nov 7, 2024)
b4d7d94  feat(ios): log as markdown (jhen0409, Nov 7, 2024)
b2d5f52  fix: timings (jhen0409, Nov 7, 2024)
089e57f  feat(example): copy to clipboard for logs (jhen0409, Nov 7, 2024)
74428f8  fix(ts): lint errors (jhen0409, Nov 7, 2024)
6edb946  feat(example): update default select (jhen0409, Nov 7, 2024)
9e90e97  feat(cpp, android): add system config & update android module (jhen0409, Nov 7, 2024)
13e38bf  fix(cpp): system_info parse (jhen0409, Nov 7, 2024)
dc6c6d2  fix: patch (jhen0409, Nov 7, 2024)
46badb5  feat(cpp): refactor whisper_timings (jhen0409, Nov 7, 2024)
5783cdd  fix(example): use react-native-gesture-handler (jhen0409, Nov 7, 2024)
d5e4334  feat(example): split copy logs button (jhen0409, Nov 7, 2024)
7681312  fix(example): log (jhen0409, Nov 7, 2024)
920826e  Merge branch 'main' into bench (jhen0409, Nov 7, 2024)
9726d1e  feat(example): move useFlashAttn usage to context-opts (jhen0409, Nov 7, 2024)
b8babc8  feat(example): minor refactor (jhen0409, Nov 8, 2024)
7c6c6fc  chore: update deps & docgen (jhen0409, Nov 8, 2024)
413dc6d  fix(example): remove default props (jhen0409, Nov 8, 2024)
ac3f9db  feat(ts): update mock (jhen0409, Nov 8, 2024)
3c9310c  fix: deps (jhen0409, Nov 8, 2024)
ff751bb  feat(example): update model button style (jhen0409, Nov 8, 2024)
6b6db48  fix(example): realtime button title (jhen0409, Nov 8, 2024)

9 changes: 9 additions & 0 deletions android/src/main/java/com/rnwhisper/RNWhisper.java
@@ -235,6 +235,15 @@ protected void onPostExecute(Void result) {
tasks.put(task, "abortTranscribe-" + id);
}

public void bench(double id, double nThreads, Promise promise) {
final WhisperContext context = contexts.get((int) id);
if (context == null) {
promise.reject("Context not found");
return;
}
promise.resolve(context.bench((int) nThreads));
}

public void releaseContext(double id, Promise promise) {
final int contextId = (int) id;
AsyncTask task = new AsyncTask<Void, Void, Void>() {
5 changes: 5 additions & 0 deletions android/src/main/java/com/rnwhisper/WhisperContext.java
@@ -423,6 +423,10 @@ public void stopCurrentTranscribe() {
stopTranscribe(this.jobId);
}

public String bench(int n_threads) {
return bench(context, n_threads);
}

public void release() {
stopCurrentTranscribe();
freeContext(context);
@@ -527,4 +531,5 @@ protected static native int fullWithJob(
int slice_index,
int n_samples
);
protected static native String bench(long context, int n_threads);
}
13 changes: 13 additions & 0 deletions android/src/main/jni.cpp
@@ -508,4 +508,17 @@ Java_com_rnwhisper_WhisperContext_getTextSegmentSpeakerTurnNext(
return whisper_full_get_segment_speaker_turn_next(context, index);
}

JNIEXPORT jstring JNICALL
Java_com_rnwhisper_WhisperContext_bench(
JNIEnv *env,
jobject thiz,
jlong context_ptr,
jint n_threads
) {
UNUSED(thiz);
struct whisper_context *context = reinterpret_cast<struct whisper_context *>(context_ptr);
std::string result = rnwhisper::bench(context, n_threads);
return env->NewStringUTF(result.c_str());
}

} // extern "C"
5 changes: 5 additions & 0 deletions android/src/newarch/java/com/rnwhisper/RNWhisperModule.java
@@ -57,6 +57,11 @@ public void abortTranscribe(double contextId, double jobId, Promise promise) {
rnwhisper.abortTranscribe(contextId, jobId, promise);
}

@ReactMethod
public void bench(double id, double nThreads, Promise promise) {
rnwhisper.bench(id, nThreads, promise);
}

@ReactMethod
public void releaseContext(double id, Promise promise) {
rnwhisper.releaseContext(id, promise);
5 changes: 5 additions & 0 deletions android/src/oldarch/java/com/rnwhisper/RNWhisperModule.java
@@ -57,6 +57,11 @@ public void abortTranscribe(double contextId, double jobId, Promise promise) {
rnwhisper.abortTranscribe(contextId, jobId, promise);
}

@ReactMethod
public void bench(double id, double nThreads, Promise promise) {
rnwhisper.bench(id, nThreads, promise);
}

@ReactMethod
public void releaseContext(double id, Promise promise) {
rnwhisper.releaseContext(id, promise);
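
Both the new- and old-architecture Android modules register the same `bench(id, nThreads)` method, which resolves its promise with the string produced by the C++ `rnwhisper::bench` shown in the next file. As a rough sketch of how the JavaScript layer could call it (the library's TypeScript wrapper is not part of this diff, so the module name and method shape below are assumptions, not the actual API):

```ts
// Hypothetical sketch only: the native module name and method shape are
// assumptions; whisper.rn's real TypeScript wrapper is not shown in this diff.
import { NativeModules } from 'react-native';

const { RNWhisper } = NativeModules;

// Forwards to the native bench(id, nThreads) added in this PR and resolves
// with the raw result string built by rnwhisper::bench.
async function benchContext(contextId: number, nThreads: number): Promise<string> {
  return RNWhisper.bench(contextId, nThreads);
}
```
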
91 changes: 91 additions & 0 deletions cpp/rn-whisper.cpp
@@ -8,6 +8,97 @@

namespace rnwhisper {

const char * system_info(void) {
static std::string s;
s = "";
if (wsp_ggml_cpu_has_avx() == 1) s += "AVX ";
if (wsp_ggml_cpu_has_avx2() == 1) s += "AVX2 ";
if (wsp_ggml_cpu_has_avx512() == 1) s += "AVX512 ";
if (wsp_ggml_cpu_has_fma() == 1) s += "FMA ";
if (wsp_ggml_cpu_has_neon() == 1) s += "NEON ";
if (wsp_ggml_cpu_has_arm_fma() == 1) s += "ARM_FMA ";
if (wsp_ggml_cpu_has_metal() == 1) s += "METAL ";
if (wsp_ggml_cpu_has_f16c() == 1) s += "F16C ";
if (wsp_ggml_cpu_has_fp16_va() == 1) s += "FP16_VA ";
if (wsp_ggml_cpu_has_blas() == 1) s += "BLAS ";
if (wsp_ggml_cpu_has_sse3() == 1) s += "SSE3 ";
if (wsp_ggml_cpu_has_ssse3() == 1) s += "SSSE3 ";
if (wsp_ggml_cpu_has_vsx() == 1) s += "VSX ";
#ifdef WHISPER_USE_COREML
s += "COREML ";
#endif
s.erase(s.find_last_not_of(" ") + 1);
return s.c_str();
}

std::string bench(struct whisper_context * ctx, int n_threads) {
const int n_mels = whisper_model_n_mels(ctx);

if (int ret = whisper_set_mel(ctx, nullptr, 0, n_mels)) {
return "error: failed to set mel: " + std::to_string(ret);
}
// heat encoder
if (int ret = whisper_encode(ctx, 0, n_threads) != 0) {
return "error: failed to encode: " + std::to_string(ret);
}

whisper_token tokens[512];
memset(tokens, 0, sizeof(tokens));

// prompt heat
if (int ret = whisper_decode(ctx, tokens, 256, 0, n_threads) != 0) {
return "error: failed to decode: " + std::to_string(ret);
}

// text-generation heat
if (int ret = whisper_decode(ctx, tokens, 1, 256, n_threads) != 0) {
return "error: failed to decode: " + std::to_string(ret);
}

whisper_reset_timings(ctx);

// actual run
if (int ret = whisper_encode(ctx, 0, n_threads) != 0) {
return "error: failed to encode: " + std::to_string(ret);
}

// text-generation
for (int i = 0; i < 256; i++) {
if (int ret = whisper_decode(ctx, tokens, 1, i, n_threads) != 0) {
return "error: failed to decode: " + std::to_string(ret);
}
}

// batched decoding
for (int i = 0; i < 64; i++) {
if (int ret = whisper_decode(ctx, tokens, 5, 0, n_threads) != 0) {
return "error: failed to decode: " + std::to_string(ret);
}
}

// prompt processing
for (int i = 0; i < 16; i++) {
if (int ret = whisper_decode(ctx, tokens, 256, 0, n_threads) != 0) {
return "error: failed to decode: " + std::to_string(ret);
}
}

const struct whisper_timings * timings = whisper_get_timings(ctx);

const int32_t n_encode = std::max(1, timings->n_encode);
const int32_t n_decode = std::max(1, timings->n_decode);
const int32_t n_batchd = std::max(1, timings->n_batchd);
const int32_t n_prompt = std::max(1, timings->n_prompt);

return std::string("[") +
"\"" + system_info() + "\"," +
std::to_string(n_threads) + "," +
std::to_string(1e-3f * timings->t_encode_us / n_encode) + "," +
std::to_string(1e-3f * timings->t_decode_us / n_decode) + "," +
std::to_string(1e-3f * timings->t_batchd_us / n_batchd) + "," +
std::to_string(1e-3f * timings->t_prompt_us / n_prompt) + "]";
}

void high_pass_filter(std::vector<float> & data, float cutoff, float sample_rate) {
const float rc = 1.0f / (2.0f * M_PI * cutoff);
const float dt = 1.0f / sample_rate;
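
`rnwhisper::bench` serializes its result as a JSON-style array string: the `system_info()` capability string, the thread count, and the average encode, decode, batched-decode, and prompt times in milliseconds per run. A minimal sketch of unpacking that payload on the JavaScript side (the field names are illustrative labels, not part of the library):

```ts
// Sketch: unpack the array string returned by rnwhisper::bench.
// Field names are illustrative labels, not an official API.
interface BenchResult {
  config: string;    // system_info() string, e.g. "NEON ARM_FMA METAL"
  nThreads: number;
  encodeMs: number;  // average encoder time per run
  decodeMs: number;  // average single-token decode time per run
  batchdMs: number;  // average batched-decode time per run
  promptMs: number;  // average prompt-processing time per run
}

function parseBenchResult(raw: string): BenchResult {
  const [config, nThreads, encodeMs, decodeMs, batchdMs, promptMs] =
    JSON.parse(raw) as [string, number, number, number, number, number];
  return { config, nThreads, encodeMs, decodeMs, batchdMs, promptMs };
}
```
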
2 changes: 2 additions & 0 deletions cpp/rn-whisper.h
@@ -9,6 +9,8 @@

namespace rnwhisper {

std::string bench(whisper_context * ctx, int n_threads);

struct vad_params {
bool use_vad = false;
float vad_thold = 0.6f;
43 changes: 33 additions & 10 deletions cpp/whisper.cpp
@@ -4190,28 +4190,51 @@ whisper_token whisper_token_transcribe(struct whisper_context * ctx) {
return ctx->vocab.token_transcribe;
}

struct whisper_timings * whisper_get_timings(struct whisper_context * ctx) {
if (ctx->state == nullptr) {
return nullptr;
}
return new whisper_timings {
.load_us = ctx->t_load_us,
.t_start_us = ctx->t_start_us,
.fail_p = ctx->state->n_fail_p,
.fail_h = ctx->state->n_fail_h,
.t_mel_us = ctx->state->t_mel_us,
.n_sample = ctx->state->n_sample,
.n_encode = ctx->state->n_encode,
.n_decode = ctx->state->n_decode,
.n_batchd = ctx->state->n_batchd,
.n_prompt = ctx->state->n_prompt,
.t_sample_us = ctx->state->t_sample_us,
.t_encode_us = ctx->state->t_encode_us,
.t_decode_us = ctx->state->t_decode_us,
.t_batchd_us = ctx->state->t_batchd_us,
.t_prompt_us = ctx->state->t_prompt_us,
};
}

void whisper_print_timings(struct whisper_context * ctx) {
const int64_t t_end_us = wsp_ggml_time_us();
const struct whisper_timings * timings = whisper_get_timings(ctx);

WHISPER_LOG_INFO("\n");
WHISPER_LOG_INFO("%s: load time = %8.2f ms\n", __func__, ctx->t_load_us / 1000.0f);
WHISPER_LOG_INFO("%s: load time = %8.2f ms\n", __func__, timings->load_us / 1000.0f);
if (ctx->state != nullptr) {

const int32_t n_sample = std::max(1, ctx->state->n_sample);
const int32_t n_encode = std::max(1, ctx->state->n_encode);
const int32_t n_decode = std::max(1, ctx->state->n_decode);
const int32_t n_batchd = std::max(1, ctx->state->n_batchd);
const int32_t n_prompt = std::max(1, ctx->state->n_prompt);

WHISPER_LOG_INFO("%s: fallbacks = %3d p / %3d h\n", __func__, ctx->state->n_fail_p, ctx->state->n_fail_h);
WHISPER_LOG_INFO("%s: mel time = %8.2f ms\n", __func__, ctx->state->t_mel_us / 1000.0f);
WHISPER_LOG_INFO("%s: sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_sample_us, n_sample, 1e-3f * ctx->state->t_sample_us / n_sample);
WHISPER_LOG_INFO("%s: encode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_encode_us, n_encode, 1e-3f * ctx->state->t_encode_us / n_encode);
WHISPER_LOG_INFO("%s: decode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_decode_us, n_decode, 1e-3f * ctx->state->t_decode_us / n_decode);
WHISPER_LOG_INFO("%s: batchd time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_batchd_us, n_batchd, 1e-3f * ctx->state->t_batchd_us / n_batchd);
WHISPER_LOG_INFO("%s: prompt time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_prompt_us, n_prompt, 1e-3f * ctx->state->t_prompt_us / n_prompt);
WHISPER_LOG_INFO("%s: fallbacks = %3d p / %3d h\n", __func__, timings->fail_p, timings->fail_h);
WHISPER_LOG_INFO("%s: mel time = %8.2f ms\n", __func__, timings->t_mel_us/1000.0f);
WHISPER_LOG_INFO("%s: sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * timings->t_sample_us, n_sample, 1e-3f * timings->t_sample_us / n_sample);
WHISPER_LOG_INFO("%s: encode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * timings->t_encode_us, n_encode, 1e-3f * timings->t_encode_us / n_encode);
WHISPER_LOG_INFO("%s: decode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * timings->t_decode_us, n_decode, 1e-3f * timings->t_decode_us / n_decode);
WHISPER_LOG_INFO("%s: batchd time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * timings->t_batchd_us, n_batchd, 1e-3f * timings->t_batchd_us / n_batchd);
WHISPER_LOG_INFO("%s: prompt time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * timings->t_prompt_us, n_prompt, 1e-3f * timings->t_prompt_us / n_prompt);
}
WHISPER_LOG_INFO("%s: total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f);
WHISPER_LOG_INFO("%s: total time = %8.2f ms\n", __func__, (t_end_us - timings->t_start_us)/1000.0f);
}

void whisper_reset_timings(struct whisper_context * ctx) {
18 changes: 18 additions & 0 deletions cpp/whisper.h
@@ -424,6 +424,24 @@ extern "C" {
WHISPER_API whisper_token whisper_token_transcribe(struct whisper_context * ctx);

// Performance information from the default state.
struct whisper_timings {
int64_t load_us;
int64_t t_start_us;
int32_t fail_p;
int32_t fail_h;
int64_t t_mel_us;
int32_t n_sample;
int32_t n_encode;
int32_t n_decode;
int32_t n_batchd;
int32_t n_prompt;
int64_t t_sample_us;
int64_t t_encode_us;
int64_t t_decode_us;
int64_t t_batchd_us;
int64_t t_prompt_us;
};
WHISPER_API struct whisper_timings * whisper_get_timings(struct whisper_context * ctx);
WHISPER_API void whisper_print_timings(struct whisper_context * ctx);
WHISPER_API void whisper_reset_timings(struct whisper_context * ctx);
