From cb4360baca34c4678fb79e3455aca641d4e9c35a Mon Sep 17 00:00:00 2001 From: Rick Staa Date: Sun, 14 Jul 2024 16:36:20 +0200 Subject: [PATCH] feat(ai): apply code improvements to AudioToText pipeline This commit applies several code improvements to the AudioToText codebase. --- cmd/livepeer/starter/starter.go | 1 - common/util.go | 10 +++++----- core/orchestrator.go | 8 +------- server/ai_http.go | 1 - server/ai_mediaserver.go | 18 +++++++++--------- server/ai_process.go | 30 ++++++++++++++---------------- 6 files changed, 29 insertions(+), 39 deletions(-) diff --git a/cmd/livepeer/starter/starter.go b/cmd/livepeer/starter/starter.go index 945d0dee58..1495101ce6 100755 --- a/cmd/livepeer/starter/starter.go +++ b/cmd/livepeer/starter/starter.go @@ -625,7 +625,6 @@ func StartLivepeer(ctx context.Context, cfg LivepeerConfig) { constraints[core.Capability_AudioToText].Models[config.ModelID] = modelConstraint n.SetBasePriceForCap("default", core.Capability_AudioToText, config.ModelID, big.NewRat(config.PricePerUnit, config.PixelsPerUnit)) - } if len(aiCaps) > 0 { diff --git a/common/util.go b/common/util.go index 1a6bf71510..46dae51aee 100644 --- a/common/util.go +++ b/common/util.go @@ -75,6 +75,9 @@ var ( ErrProfEncoder = fmt.Errorf("unknown VideoProfile encoder for protobufs") ErrProfName = fmt.Errorf("unknown VideoProfile profile name") + ErrUnsupportedAudioFormat = fmt.Errorf("audio format unsupported") + ErrAudioDurationCalculation = fmt.Errorf("audio duration calculation failed") + ext2mime = map[string]string{ ".ts": "video/mp2t", ".mp4": "video/mp4", @@ -532,10 +535,7 @@ func ParseEthAddr(strJsonKey string) (string, error) { return "", errors.New("Error parsing address from keyfile") } -// determines the duration of an mp3 audio file by reading the frames -var ErrUnsupportedFormat = errors.New("Unsupported audio file format") -var ErrorCalculatingDuration = errors.New("Error calculating duration") - +// CalculateAudioDuration calculates audio file duration using the lpms/ffmpeg package. func CalculateAudioDuration(audio types.File) (int64, error) { read, err := audio.Reader() if err != nil { @@ -551,7 +551,7 @@ func CalculateAudioDuration(audio types.File) (int64, error) { duration := int64(mediaFormat.DurSecs) if duration <= 0 { - return 0, ErrorCalculatingDuration + return 0, ErrAudioDurationCalculation } return duration, nil diff --git a/core/orchestrator.go b/core/orchestrator.go index 9505e97b30..4ceea94bba 100644 --- a/core/orchestrator.go +++ b/core/orchestrator.go @@ -950,13 +950,7 @@ func (n *LivepeerNode) upscale(ctx context.Context, req worker.UpscaleMultipartR } func (n *LivepeerNode) AudioToText(ctx context.Context, req worker.AudioToTextMultipartRequestBody) (*worker.TextResponse, error) { - - resp, err := n.AIWorker.AudioToText(ctx, req) - if err != nil { - return nil, err - } - - return resp, nil + return n.AIWorker.AudioToText(ctx, req) } func (n *LivepeerNode) imageToVideo(ctx context.Context, req worker.ImageToVideoMultipartRequestBody) (*worker.ImageResponse, error) { diff --git a/server/ai_http.go b/server/ai_http.go index b23bc61794..4ef6eac545 100644 --- a/server/ai_http.go +++ b/server/ai_http.go @@ -269,7 +269,6 @@ func handleAIRequest(ctx context.Context, w http.ResponseWriter, r *http.Request return } outPixels *= 1000 // Convert to milliseconds - default: respondWithError(w, "Unknown request type", http.StatusBadRequest) return diff --git a/server/ai_mediaserver.go b/server/ai_mediaserver.go index 5ad5c6d865..29d33b9fed 100644 --- a/server/ai_mediaserver.go +++ b/server/ai_mediaserver.go @@ -340,29 +340,29 @@ func (ls *LivepeerServer) AudioToText() http.Handler { return } - clog.V(common.VERBOSE).Infof(ctx, "Received AudioToText request model_id=%v", *req.ModelId) + clog.V(common.VERBOSE).Infof(ctx, "Received AudioToText request audioSize=%v model_id=%v", req.Audio.FileSize(), *req.ModelId) params := aiRequestParams{ node: ls.LivepeerNode, - os: drivers.NodeStorage.NewSession(string(core.RandomManifestID())), + os: drivers.NodeStorage.NewSession(requestID), sessManager: ls.AISessionManager, } start := time.Now() resp, err := processAudioToText(ctx, params, req) if err != nil { - var e *ServiceUnavailableError - var reqError *BadRequestError - if errors.As(err, &e) { + var serviceUnavailableErr *ServiceUnavailableError + var badRequestErr *BadRequestError + if errors.As(err, &serviceUnavailableErr) { respondJsonError(ctx, w, err, http.StatusServiceUnavailable) return - } else if errors.As(err, &reqError) { + } + if errors.As(err, &badRequestErr) { respondJsonError(ctx, w, err, http.StatusBadRequest) return - } else { - respondJsonError(ctx, w, err, http.StatusInternalServerError) - return } + respondJsonError(ctx, w, err, http.StatusInternalServerError) + return } took := time.Since(start) diff --git a/server/ai_process.go b/server/ai_process.go index fdb3f8888f..49e27ab4ce 100644 --- a/server/ai_process.go +++ b/server/ai_process.go @@ -15,7 +15,6 @@ import ( "strings" "time" - "github.com/golang/glog" "github.com/livepeer/ai-worker/worker" "github.com/livepeer/go-livepeer/clog" "github.com/livepeer/go-livepeer/common" @@ -35,6 +34,10 @@ type ServiceUnavailableError struct { err error } +func (e *ServiceUnavailableError) Error() string { + return e.err.Error() +} + type BadRequestError struct { err error } @@ -43,10 +46,6 @@ func (e *BadRequestError) Error() string { return e.err.Error() } -func (e *ServiceUnavailableError) Error() string { - return e.err.Error() -} - type aiRequestParams struct { node *core.LivepeerNode os drivers.OSSession @@ -419,13 +418,13 @@ func submitAudioToText(ctx context.Context, params aiRequestParams, sess *AISess return nil, err } - outPixels, err := common.CalculateAudioDuration(req.Audio) + durationSeconds, err := common.CalculateAudioDuration(req.Audio) if err != nil { return nil, err } - glog.Infof("Submitting audio-to-text media with duration: %d seconds", outPixels) - outPixels *= 1000 // Convert to milliseconds - setHeaders, balUpdate, err := prepareAIPayment(ctx, sess, outPixels) + + clog.V(common.VERBOSE).Infof(ctx, "Submitting audio-to-text media with duration: %d seconds", durationSeconds) + setHeaders, balUpdate, err := prepareAIPayment(ctx, sess, durationSeconds*1000) if err != nil { return nil, err } @@ -459,7 +458,7 @@ func submitAudioToText(ctx context.Context, params aiRequestParams, sess *AISess } // TODO: Refine this rough estimate in future iterations - sess.LatencyScore = took.Seconds() / float64(outPixels) + sess.LatencyScore = took.Seconds() / float64(durationSeconds) return &res, nil } @@ -505,7 +504,6 @@ func processAIRequest(ctx context.Context, params aiRequestParams, req interface if v.ModelId != nil { modelID = *v.ModelId } - // Assuming submitImageToVideo returns a VideoResponse submitFn = func(ctx context.Context, params aiRequestParams, sess *AISession) (interface{}, error) { return submitImageToVideo(ctx, params, sess, v) } @@ -527,7 +525,6 @@ func processAIRequest(ctx context.Context, params aiRequestParams, req interface submitFn = func(ctx context.Context, params aiRequestParams, sess *AISession) (interface{}, error) { return submitAudioToText(ctx, params, sess, v) } - // Add more cases as needed... default: return nil, fmt.Errorf("unsupported request type %T", req) } @@ -561,18 +558,19 @@ func processAIRequest(ctx context.Context, params aiRequestParams, req interface params.sessManager.Complete(ctx, sess) break } - if errors.Is(err, common.ErrorCalculatingDuration) || errors.Is(err, common.ErrUnsupportedFormat) { - return nil, &BadRequestError{err} - } clog.Infof(ctx, "Error submitting request cap=%v modelID=%v try=%v orch=%v err=%v", cap, modelID, tries, sess.Transcoder(), err) params.sessManager.Remove(ctx, sess) + + if errors.Is(err, common.ErrAudioDurationCalculation) || errors.Is(err, common.ErrUnsupportedAudioFormat) { + return nil, &BadRequestError{err} + } } if resp == nil { return nil, &ServiceUnavailableError{err: errors.New("no orchestrators available")} } - return resp.(interface{}), nil + return resp, nil } func prepareAIPayment(ctx context.Context, sess *AISession, outPixels int64) (worker.RequestEditorFn, *BalanceUpdate, error) {