Skip to content

Commit

Permalink
Merge pull request #4 from eliteprox/add-speech-to-text-code-improvements
Browse files Browse the repository at this point in the history

feat(ai): apply code improvements to AudioToText pipeline
  • Loading branch information
eliteprox authored Jul 15, 2024
2 parents 5b24400 + e307c70 commit d40d41b
Show file tree
Hide file tree
Showing 6 changed files with 29 additions and 35 deletions.
1 change: 0 additions & 1 deletion cmd/livepeer/starter/starter.go
Original file line number Diff line number Diff line change
Expand Up @@ -625,7 +625,6 @@ func StartLivepeer(ctx context.Context, cfg LivepeerConfig) {
constraints[core.Capability_AudioToText].Models[config.ModelID] = modelConstraint

n.SetBasePriceForCap("default", core.Capability_AudioToText, config.ModelID, big.NewRat(config.PricePerUnit, config.PixelsPerUnit))

}

if len(aiCaps) > 0 {
Expand Down
10 changes: 5 additions & 5 deletions common/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,9 @@ var (
ErrProfEncoder = fmt.Errorf("unknown VideoProfile encoder for protobufs")
ErrProfName = fmt.Errorf("unknown VideoProfile profile name")

ErrUnsupportedAudioFormat = fmt.Errorf("audio format unsupported")
ErrAudioDurationCalculation = fmt.Errorf("audio duration calculation failed")

ext2mime = map[string]string{
".ts": "video/mp2t",
".mp4": "video/mp4",
Expand Down Expand Up @@ -532,10 +535,7 @@ func ParseEthAddr(strJsonKey string) (string, error) {
return "", errors.New("Error parsing address from keyfile")
}

// determines the duration of an mp3 audio file by reading the frames
var ErrUnsupportedFormat = errors.New("Unsupported audio file format")
var ErrorCalculatingDuration = errors.New("Error calculating duration")

// CalculateAudioDuration calculates audio file duration using the lpms/ffmpeg package.
func CalculateAudioDuration(audio types.File) (int64, error) {
read, err := audio.Reader()
if err != nil {
Expand All @@ -551,7 +551,7 @@ func CalculateAudioDuration(audio types.File) (int64, error) {

duration := int64(mediaFormat.DurSecs)
if duration <= 0 {
return 0, ErrorCalculatingDuration
return 0, ErrAudioDurationCalculation
}

return duration, nil
Expand Down
4 changes: 1 addition & 3 deletions core/orchestrator.go
Original file line number Diff line number Diff line change
Expand Up @@ -950,9 +950,7 @@ func (n *LivepeerNode) upscale(ctx context.Context, req worker.UpscaleMultipartR
}

// AudioToText forwards an audio-to-text request to the node's AI worker and
// returns the worker's response and error unchanged.
func (n *LivepeerNode) AudioToText(ctx context.Context, req worker.AudioToTextMultipartRequestBody) (*worker.TextResponse, error) {
	// Return the worker call directly; keeping both the intermediate
	// resp/err pair and a second return left the latter unreachable.
	return n.AIWorker.AudioToText(ctx, req)
}

func (n *LivepeerNode) imageToVideo(ctx context.Context, req worker.ImageToVideoMultipartRequestBody) (*worker.ImageResponse, error) {
Expand Down
1 change: 0 additions & 1 deletion server/ai_http.go
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,6 @@ func handleAIRequest(ctx context.Context, w http.ResponseWriter, r *http.Request
return
}
outPixels *= 1000 // Convert to milliseconds

default:
respondWithError(w, "Unknown request type", http.StatusBadRequest)
return
Expand Down
18 changes: 9 additions & 9 deletions server/ai_mediaserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -340,29 +340,29 @@ func (ls *LivepeerServer) AudioToText() http.Handler {
return
}

clog.V(common.VERBOSE).Infof(ctx, "Received AudioToText request model_id=%v", *req.ModelId)
clog.V(common.VERBOSE).Infof(ctx, "Received AudioToText request audioSize=%v model_id=%v", req.Audio.FileSize(), *req.ModelId)

params := aiRequestParams{
node: ls.LivepeerNode,
os: drivers.NodeStorage.NewSession(string(core.RandomManifestID())),
os: drivers.NodeStorage.NewSession(requestID),
sessManager: ls.AISessionManager,
}

start := time.Now()
resp, err := processAudioToText(ctx, params, req)
if err != nil {
var e *ServiceUnavailableError
var reqError *BadRequestError
if errors.As(err, &e) {
var serviceUnavailableErr *ServiceUnavailableError
var badRequestErr *BadRequestError
if errors.As(err, &serviceUnavailableErr) {
respondJsonError(ctx, w, err, http.StatusServiceUnavailable)
return
} else if errors.As(err, &reqError) {
}
if errors.As(err, &badRequestErr) {
respondJsonError(ctx, w, err, http.StatusBadRequest)
return
} else {
respondJsonError(ctx, w, err, http.StatusInternalServerError)
return
}
respondJsonError(ctx, w, err, http.StatusInternalServerError)
return
}

took := time.Since(start)
Expand Down
30 changes: 14 additions & 16 deletions server/ai_process.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ import (
"strings"
"time"

"github.com/golang/glog"
"github.com/livepeer/ai-worker/worker"
"github.com/livepeer/go-livepeer/clog"
"github.com/livepeer/go-livepeer/common"
Expand All @@ -35,6 +34,10 @@ type ServiceUnavailableError struct {
err error
}

// Error implements the error interface by reporting the message of the
// wrapped error.
func (e *ServiceUnavailableError) Error() string {
	wrapped := e.err
	return wrapped.Error()
}

// BadRequestError wraps an underlying error so callers can recognise it by
// type (for example with errors.As) and treat the failure as a bad request.
type BadRequestError struct {
// err is the wrapped error whose message is reported.
err error
}
Expand All @@ -43,10 +46,6 @@ func (e *BadRequestError) Error() string {
return e.err.Error()
}

// Error satisfies the error interface, returning the wrapped error's message.
func (e *ServiceUnavailableError) Error() string {
	inner := e.err
	return inner.Error()
}

type aiRequestParams struct {
node *core.LivepeerNode
os drivers.OSSession
Expand Down Expand Up @@ -419,13 +418,13 @@ func submitAudioToText(ctx context.Context, params aiRequestParams, sess *AISess
return nil, err
}

outPixels, err := common.CalculateAudioDuration(req.Audio)
durationSeconds, err := common.CalculateAudioDuration(req.Audio)
if err != nil {
return nil, err
}
glog.Infof("Submitting audio-to-text media with duration: %d seconds", outPixels)
outPixels *= 1000 // Convert to milliseconds
setHeaders, balUpdate, err := prepareAIPayment(ctx, sess, outPixels)

clog.V(common.VERBOSE).Infof(ctx, "Submitting audio-to-text media with duration: %d seconds", durationSeconds)
setHeaders, balUpdate, err := prepareAIPayment(ctx, sess, durationSeconds*1000)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -459,7 +458,7 @@ func submitAudioToText(ctx context.Context, params aiRequestParams, sess *AISess
}

// TODO: Refine this rough estimate in future iterations
sess.LatencyScore = took.Seconds() / float64(outPixels)
sess.LatencyScore = took.Seconds() / float64(durationSeconds)

return &res, nil
}
Expand Down Expand Up @@ -505,7 +504,6 @@ func processAIRequest(ctx context.Context, params aiRequestParams, req interface
if v.ModelId != nil {
modelID = *v.ModelId
}
// Assuming submitImageToVideo returns a VideoResponse
submitFn = func(ctx context.Context, params aiRequestParams, sess *AISession) (interface{}, error) {
return submitImageToVideo(ctx, params, sess, v)
}
Expand All @@ -527,7 +525,6 @@ func processAIRequest(ctx context.Context, params aiRequestParams, req interface
submitFn = func(ctx context.Context, params aiRequestParams, sess *AISession) (interface{}, error) {
return submitAudioToText(ctx, params, sess, v)
}
// Add more cases as needed...
default:
return nil, fmt.Errorf("unsupported request type %T", req)
}
Expand Down Expand Up @@ -561,18 +558,19 @@ func processAIRequest(ctx context.Context, params aiRequestParams, req interface
params.sessManager.Complete(ctx, sess)
break
}
if errors.Is(err, common.ErrorCalculatingDuration) || errors.Is(err, common.ErrUnsupportedFormat) {
return nil, &BadRequestError{err}
}

clog.Infof(ctx, "Error submitting request cap=%v modelID=%v try=%v orch=%v err=%v", cap, modelID, tries, sess.Transcoder(), err)
params.sessManager.Remove(ctx, sess)

if errors.Is(err, common.ErrAudioDurationCalculation) || errors.Is(err, common.ErrUnsupportedAudioFormat) {
return nil, &BadRequestError{err}
}
}

if resp == nil {
return nil, &ServiceUnavailableError{err: errors.New("no orchestrators available")}
}
return resp.(interface{}), nil
return resp, nil
}

func prepareAIPayment(ctx context.Context, sess *AISession, outPixels int64) (worker.RequestEditorFn, *BalanceUpdate, error) {
Expand Down

0 comments on commit d40d41b

Please sign in to comment.