diff --git a/cmd/livepeer/starter/starter.go b/cmd/livepeer/starter/starter.go
index 15cca7f8c3..1495101ce6 100755
--- a/cmd/livepeer/starter/starter.go
+++ b/cmd/livepeer/starter/starter.go
@@ -613,6 +613,18 @@ func StartLivepeer(ctx context.Context, cfg LivepeerConfig) {
 				constraints[core.Capability_Upscale].Models[config.ModelID] = modelConstraint
 
 				n.SetBasePriceForCap("default", core.Capability_Upscale, config.ModelID, big.NewRat(config.PricePerUnit, config.PixelsPerUnit))
+			case "audio-to-text":
+				_, ok := constraints[core.Capability_AudioToText]
+				if !ok {
+					aiCaps = append(aiCaps, core.Capability_AudioToText)
+					constraints[core.Capability_AudioToText] = &core.Constraints{
+						Models: make(map[string]*core.ModelConstraint),
+					}
+				}
+
+				constraints[core.Capability_AudioToText].Models[config.ModelID] = modelConstraint
+
+				n.SetBasePriceForCap("default", core.Capability_AudioToText, config.ModelID, big.NewRat(config.PricePerUnit, config.PixelsPerUnit))
 			}
 
 			if len(aiCaps) > 0 {
diff --git a/common/util.go b/common/util.go
index 639cf07ccd..8a9eb15219 100644
--- a/common/util.go
+++ b/common/util.go
@@ -25,6 +25,7 @@ import (
 	"github.com/jaypipes/ghw/pkg/pci"
 	"github.com/livepeer/go-livepeer/net"
 	ffmpeg "github.com/livepeer/lpms/ffmpeg"
+	"github.com/oapi-codegen/runtime/types"
 	"github.com/pkg/errors"
 	"google.golang.org/grpc/peer"
 )
@@ -74,6 +75,8 @@ var (
 	ErrProfEncoder = fmt.Errorf("unknown VideoProfile encoder for protobufs")
 	ErrProfName = fmt.Errorf("unknown VideoProfile profile name")
 
+	ErrAudioDurationCalculation = fmt.Errorf("audio duration calculation failed")
+
 	ext2mime = map[string]string{
 		".ts": "video/mp2t",
 		".mp4": "video/mp4",
@@ -530,3 +533,25 @@ func ParseEthAddr(strJsonKey string) (string, error) {
 	}
 	return "", errors.New("Error parsing address from keyfile")
 }
+
+// CalculateAudioDuration calculates audio file duration using the lpms/ffmpeg package.
+func CalculateAudioDuration(audio types.File) (int64, error) {
+	read, err := audio.Reader()
+	if err != nil {
+		return 0, err
+	}
+	defer read.Close()
+
+	bytearr, _ := audio.Bytes()
+	_, mediaFormat, err := ffmpeg.GetCodecInfoBytes(bytearr)
+	if err != nil {
+		return 0, errors.New("Error getting codec info")
+	}
+
+	duration := int64(mediaFormat.DurSecs)
+	if duration <= 0 {
+		return 0, ErrAudioDurationCalculation
+	}
+
+	return duration, nil
+}
diff --git a/core/ai.go b/core/ai.go
index c4a7146ec6..fde3060a96 100644
--- a/core/ai.go
+++ b/core/ai.go
@@ -16,6 +16,7 @@ type AI interface {
 	ImageToImage(context.Context, worker.ImageToImageMultipartRequestBody) (*worker.ImageResponse, error)
 	ImageToVideo(context.Context, worker.ImageToVideoMultipartRequestBody) (*worker.VideoResponse, error)
 	Upscale(context.Context, worker.UpscaleMultipartRequestBody) (*worker.ImageResponse, error)
+	AudioToText(context.Context, worker.AudioToTextMultipartRequestBody) (*worker.TextResponse, error)
 	Warm(context.Context, string, string, worker.RunnerEndpoint, worker.OptimizationFlags) error
 	Stop(context.Context) error
 	HasCapacity(pipeline, modelID string) bool
diff --git a/core/capabilities.go b/core/capabilities.go
index 240222ab29..12eb7bbc6d 100644
--- a/core/capabilities.go
+++ b/core/capabilities.go
@@ -71,6 +71,7 @@ const (
 	Capability_ImageToImage
 	Capability_ImageToVideo
 	Capability_Upscale
+	Capability_AudioToText
 )
 
 var CapabilityNameLookup = map[Capability]string{
@@ -106,6 +107,7 @@ var CapabilityNameLookup = map[Capability]string{
 	Capability_ImageToImage: "Image to image",
 	Capability_ImageToVideo: "Image to video",
 	Capability_Upscale: "Upscale",
+	Capability_AudioToText: "Audio to text",
 }
 
 var CapabilityTestLookup = map[Capability]CapabilityTest{
@@ -195,6 +197,7 @@ func OptionalCapabilities() []Capability {
 		Capability_ImageToImage,
 		Capability_ImageToVideo,
 		Capability_Upscale,
+		Capability_AudioToText,
 	}
 }
 
diff --git a/core/orchestrator.go b/core/orchestrator.go
index f9d42b3faf..4ceea94bba 100644
--- a/core/orchestrator.go
+++ b/core/orchestrator.go
@@ -126,6 +126,10 @@ func (orch *orchestrator) Upscale(ctx context.Context, req worker.UpscaleMultipa
 	return orch.node.upscale(ctx, req)
 }
 
+func (orch *orchestrator) AudioToText(ctx context.Context, req worker.AudioToTextMultipartRequestBody) (*worker.TextResponse, error) {
+	return orch.node.AudioToText(ctx, req)
+}
+
 func (orch *orchestrator) ProcessPayment(ctx context.Context, payment net.Payment, manifestID ManifestID) error {
 	if orch.node == nil || orch.node.Recipient == nil {
 		return nil
@@ -945,6 +949,10 @@ func (n *LivepeerNode) upscale(ctx context.Context, req worker.UpscaleMultipartR
 	return n.AIWorker.Upscale(ctx, req)
 }
 
+func (n *LivepeerNode) AudioToText(ctx context.Context, req worker.AudioToTextMultipartRequestBody) (*worker.TextResponse, error) {
+	return n.AIWorker.AudioToText(ctx, req)
+}
+
 func (n *LivepeerNode) imageToVideo(ctx context.Context, req worker.ImageToVideoMultipartRequestBody) (*worker.ImageResponse, error) {
 	// We might support generating more than one video in the future (i.e. multiple input images/prompts)
 	numVideos := 1
diff --git a/go.mod b/go.mod
index e77da1df23..c80fe64f37 100644
--- a/go.mod
+++ b/go.mod
@@ -9,13 +9,13 @@ require (
 	github.com/getkin/kin-openapi v0.124.0
 	github.com/golang/glog v1.1.1
 	github.com/golang/mock v1.6.0
-	github.com/golang/protobuf v1.5.3
+	github.com/golang/protobuf v1.5.4
 	github.com/jaypipes/ghw v0.10.0
 	github.com/jaypipes/pcidb v1.0.0
-	github.com/livepeer/ai-worker v0.0.8
+	github.com/livepeer/ai-worker v0.1.0
 	github.com/livepeer/go-tools v0.3.6-0.20240130205227-92479de8531b
 	github.com/livepeer/livepeer-data v0.7.5-0.20231004073737-06f1f383fb18
-	github.com/livepeer/lpms v0.0.0-20240120150405-de94555cdc69
+	github.com/livepeer/lpms v0.0.0-20240711175220-227325841434
 	github.com/livepeer/m3u8 v0.11.1
 	github.com/mattn/go-sqlite3 v1.14.18
 	github.com/oapi-codegen/nethttp-middleware v1.0.1
@@ -32,7 +32,7 @@ require (
 	go.uber.org/goleak v1.3.0
 	golang.org/x/net v0.25.0
 	google.golang.org/grpc v1.57.1
-	google.golang.org/protobuf v1.31.0
+	google.golang.org/protobuf v1.33.0
 	pgregory.net/rapid v1.1.0
 )
 
@@ -85,6 +85,7 @@ require (
 	github.com/go-openapi/swag v0.22.8 // indirect
 	github.com/go-sourcemap/sourcemap v2.1.3+incompatible // indirect
 	github.com/go-stack/stack v1.8.1 // indirect
+	github.com/go-test/deep v1.1.0 // indirect
 	github.com/gogo/protobuf v1.3.2 // indirect
 	github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
 	github.com/golang/snappy v0.0.5-0.20220116011046-fa5810519dcb // indirect
diff --git a/go.sum b/go.sum
index 3073624a39..74df1b7267 100644
--- a/go.sum
+++ b/go.sum
@@ -251,8 +251,8 @@ github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LB
 github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY=
 github.com/go-stack/stack v1.8.1 h1:ntEHSVwIt7PNXNpgPmVfMrNhLtgjlmnZha2kOpuRiDw=
 github.com/go-stack/stack v1.8.1/go.mod h1:dcoOX6HbPZSZptuspn9bctJ+N/CnF5gGygcUP3XYfe4=
-github.com/go-test/deep v1.0.8 h1:TDsG77qcSprGbC6vTN8OuXp5g+J+b5Pcguhf7Zt61VM=
-github.com/go-test/deep v1.0.8/go.mod h1:5C2ZWiW0ErCdrYzpqxLbTX7MG14M9iiw8DgHncVwcsE=
+github.com/go-test/deep v1.1.0 h1:WOcxcdHcvdgThNXjw0t76K42FXTU7HpNQWHpA2HHNlg=
+github.com/go-test/deep v1.1.0/go.mod h1:5C2ZWiW0ErCdrYzpqxLbTX7MG14M9iiw8DgHncVwcsE=
 github.com/go-yaml/yaml v2.1.0+incompatible/go.mod h1:w2MrLa16VYP0jy6N7M5kHaCkaLENm+P+Tv+MfurjSw0=
 github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
 github.com/godbus/dbus/v5 v5.0.6/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
@@ -297,8 +297,8 @@ github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw
 github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI=
 github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
 github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
-github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg=
-github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
+github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
+github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
 github.com/golang/snappy v0.0.5-0.20220116011046-fa5810519dcb h1:PBC98N2aIaM3XXiurYmW7fx4GZkL8feAMVq7nEjURHk=
 github.com/golang/snappy v0.0.5-0.20220116011046-fa5810519dcb/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
 github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ=
@@ -531,16 +531,16 @@ github.com/libp2p/go-netroute v0.2.0 h1:0FpsbsvuSnAhXFnCY0VLFbJOzaK0VnP0r1QT/o4n
 github.com/libp2p/go-netroute v0.2.0/go.mod h1:Vio7LTzZ+6hoT4CMZi5/6CpY3Snzh2vgZhWgxMNwlQI=
 github.com/libp2p/go-openssl v0.1.0 h1:LBkKEcUv6vtZIQLVTegAil8jbNpJErQ9AnT+bWV+Ooo=
 github.com/libp2p/go-openssl v0.1.0/go.mod h1:OiOxwPpL3n4xlenjx2h7AwSGaFSC/KZvf6gNdOBQMtc=
-github.com/livepeer/ai-worker v0.0.8 h1:FAjYJgSOaZslA06Wb6MolYohI30IMIujDTB26nfw8YE=
-github.com/livepeer/ai-worker v0.0.8/go.mod h1:Xlnb0nFG2VsGeMG9hZmReVQXeFt0Dv28ODiUT2ooyLE=
+github.com/livepeer/ai-worker v0.1.0 h1:SJBZuxeK0vEzJPBzf5osdgVCxHYZt7ZKR2CvZ7Q7iog=
+github.com/livepeer/ai-worker v0.1.0/go.mod h1:Xlnb0nFG2VsGeMG9hZmReVQXeFt0Dv28ODiUT2ooyLE=
 github.com/livepeer/go-tools v0.3.6-0.20240130205227-92479de8531b h1:VQcnrqtCA2UROp7q8ljkh2XA/u0KRgVv0S1xoUvOweE=
 github.com/livepeer/go-tools v0.3.6-0.20240130205227-92479de8531b/go.mod h1:hwJ5DKhl+pTanFWl+EUpw1H7ukPO/H+MFpgA7jjshzw=
 github.com/livepeer/joy4 v0.1.2-0.20191121080656-b2fea45cbded h1:ZQlvR5RB4nfT+cOQee+WqmaDOgGtP2oDMhcVvR4L0yA=
 github.com/livepeer/joy4 v0.1.2-0.20191121080656-b2fea45cbded/go.mod h1:xkDdm+akniYxVT9KW1Y2Y7Hso6aW+rZObz3nrA9yTHw=
 github.com/livepeer/livepeer-data v0.7.5-0.20231004073737-06f1f383fb18 h1:4oH3NqV0NvcdS44Ld3zK2tO8IUiNozIggm74yobQeZg=
 github.com/livepeer/livepeer-data v0.7.5-0.20231004073737-06f1f383fb18/go.mod h1:Jpf4jHK+fbWioBHRDRM1WadNT1qmY27g2YicTdO0Rtc=
-github.com/livepeer/lpms v0.0.0-20240120150405-de94555cdc69 h1:4A6geMb+HfxBBfaS24t8R3ddpEDfWbpx7NTQZMt5Fp4=
-github.com/livepeer/lpms v0.0.0-20240120150405-de94555cdc69/go.mod h1:Hr/JhxxPDipOVd4ZrGYWrdJfpVF8/SEI0nNr2ctAlkM=
+github.com/livepeer/lpms v0.0.0-20240711175220-227325841434 h1:E7PKN6q/jMLapEV+eEwlwv87Xe5zacaVhvZ8T6AJR3c=
+github.com/livepeer/lpms v0.0.0-20240711175220-227325841434/go.mod h1:Hr/JhxxPDipOVd4ZrGYWrdJfpVF8/SEI0nNr2ctAlkM=
 github.com/livepeer/m3u8 v0.11.1 h1:VkUJzfNTyjy9mqsgp5JPvouwna8wGZMvd/gAfT5FinU=
 github.com/livepeer/m3u8 v0.11.1/go.mod h1:IUqAtwWPAG2CblfQa4SVzTQoDcEMPyfNOaBSxqHMS04=
 github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0=
@@ -1217,8 +1217,8 @@ google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQ
 google.golang.org/protobuf v1.27.1/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
 google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
 google.golang.org/protobuf v1.28.1/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
-google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs8=
-google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
+google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI=
+google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
 gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw=
 gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
 gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
diff --git a/install_ffmpeg.sh b/install_ffmpeg.sh
index bb99292ef0..df864b4968 100755
--- a/install_ffmpeg.sh
+++ b/install_ffmpeg.sh
@@ -208,13 +208,13 @@ if [[ ! -e "$ROOT/ffmpeg/libavcodec/libavcodec.a" ]]; then
   ./configure ${TARGET_OS:-} $DISABLE_FFMPEG_COMPONENTS --fatal-warnings \
     --enable-libx264 --enable-gpl \
     --enable-protocol=rtmp,file,pipe \
-    --enable-muxer=mpegts,hls,segment,mp4,hevc,matroska,webm,null --enable-demuxer=flv,mpegts,mp4,mov,webm,matroska,image2 \
+    --enable-muxer=mp3,wav,flac,mpegts,hls,segment,mp4,hevc,matroska,webm,null --enable-demuxer=mp3,wav,flac,flv,mpegts,mp4,mov,webm,matroska,image2 \
     --enable-bsf=h264_mp4toannexb,aac_adtstoasc,h264_metadata,h264_redundant_pps,hevc_mp4toannexb,extract_extradata \
-    --enable-parser=aac,aac_latm,h264,hevc,vp8,vp9,png \
+    --enable-parser=mpegaudio,vorbis,opus,flac,aac,aac_latm,h264,hevc,vp8,vp9,png \
     --enable-filter=abuffer,buffer,abuffersink,buffersink,afifo,fifo,aformat,format \
     --enable-filter=aresample,asetnsamples,fps,scale,hwdownload,select,livepeer_dnn,signature \
-    --enable-encoder=aac,opus,libx264 \
-    --enable-decoder=aac,opus,h264,png \
+    --enable-encoder=mp3,vorbis,flac,aac,opus,libx264 \
+    --enable-decoder=mp3,vorbis,flac,aac,opus,h264,png \
     --extra-cflags="${EXTRA_CFLAGS} -I${ROOT}/compiled/include -I/usr/local/cuda/include" \
     --extra-ldflags="${EXTRA_FFMPEG_LDFLAGS} -L${ROOT}/compiled/lib -L/usr/local/cuda/lib64" \
     --prefix="$ROOT/compiled" \
diff --git a/server/ai_http.go b/server/ai_http.go
index fea249ca51..4ef6eac545 100644
--- a/server/ai_http.go
+++ b/server/ai_http.go
@@ -42,6 +42,7 @@ func startAIServer(lp lphttp) error {
 	lp.transRPC.Handle("/image-to-image", oapiReqValidator(lp.ImageToImage()))
 	lp.transRPC.Handle("/image-to-video", oapiReqValidator(lp.ImageToVideo()))
 	lp.transRPC.Handle("/upscale", oapiReqValidator(lp.Upscale()))
+	lp.transRPC.Handle("/audio-to-text", oapiReqValidator(lp.AudioToText()))
 
 	return nil
 }
@@ -132,6 +133,29 @@ func (h *lphttp) Upscale() http.Handler {
 	})
 }
 
+func (h *lphttp) AudioToText() http.Handler {
+	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		orch := h.orchestrator
+
+		remoteAddr := getRemoteAddr(r)
+		ctx := clog.AddVal(r.Context(), clog.ClientIP, remoteAddr)
+
+		multiRdr, err := r.MultipartReader()
+		if err != nil {
+			respondWithError(w, err.Error(), http.StatusBadRequest)
+			return
+		}
+
+		var req worker.AudioToTextMultipartRequestBody
+		if err := runtime.BindMultipart(&req, *multiRdr); err != nil {
+			respondWithError(w, err.Error(), http.StatusInternalServerError)
+			return
+		}
+
+		handleAIRequest(ctx, w, r, orch, req)
+	})
+}
+
 func handleAIRequest(ctx context.Context, w http.ResponseWriter, r *http.Request, orch Orchestrator, req interface{}) {
 	payment, err := getPayment(r.Header.Get(paymentHeader))
 	if err != nil {
@@ -149,7 +173,7 @@ func handleAIRequest(ctx context.Context, w http.ResponseWriter, r *http.Request
 	var cap core.Capability
 	var pipeline string
 	var modelID string
-	var submitFn func(context.Context) (*worker.ImageResponse, error)
+	var submitFn func(context.Context) (interface{}, error)
 	var outPixels int64
 
 	switch v := req.(type) {
@@ -157,7 +181,7 @@ func handleAIRequest(ctx context.Context, w http.ResponseWriter, r *http.Request
 		pipeline = "text-to-image"
 		cap = core.Capability_TextToImage
 		modelID = *v.ModelId
-		submitFn = func(ctx context.Context) (*worker.ImageResponse, error) {
+		submitFn = func(ctx context.Context) (interface{}, error) {
 			return orch.TextToImage(ctx, v)
 		}
 
@@ -176,7 +200,7 @@ func handleAIRequest(ctx context.Context, w http.ResponseWriter, r *http.Request
 		pipeline = "image-to-image"
 		cap = core.Capability_ImageToImage
 		modelID = *v.ModelId
-		submitFn = func(ctx context.Context) (*worker.ImageResponse, error) {
+		submitFn = func(ctx context.Context) (interface{}, error) {
 			return orch.ImageToImage(ctx, v)
 		}
 
@@ -195,7 +219,7 @@ func handleAIRequest(ctx context.Context, w http.ResponseWriter, r *http.Request
 		pipeline = "upscale"
 		cap = core.Capability_Upscale
 		modelID = *v.ModelId
-		submitFn = func(ctx context.Context) (*worker.ImageResponse, error) {
+		submitFn = func(ctx context.Context) (interface{}, error) {
 			return orch.Upscale(ctx, v)
 		}
 
@@ -214,7 +238,7 @@ func handleAIRequest(ctx context.Context, w http.ResponseWriter, r *http.Request
 		pipeline = "image-to-video"
 		cap = core.Capability_ImageToVideo
 		modelID = *v.ModelId
-		submitFn = func(ctx context.Context) (*worker.ImageResponse, error) {
+		submitFn = func(ctx context.Context) (interface{}, error) {
 			return orch.ImageToVideo(ctx, v)
 		}
 
@@ -231,6 +255,20 @@ func handleAIRequest(ctx context.Context, w http.ResponseWriter, r *http.Request
 		frames := int64(25)
 
 		outPixels = height * width * int64(frames)
+	case worker.AudioToTextMultipartRequestBody:
+		pipeline = "audio-to-text"
+		cap = core.Capability_AudioToText
+		modelID = *v.ModelId
+		submitFn = func(ctx context.Context) (interface{}, error) {
+			return orch.AudioToText(ctx, v)
+		}
+
+		outPixels, err = common.CalculateAudioDuration(v.Audio)
+		if err != nil {
+			respondWithError(w, "Unable to calculate duration", http.StatusBadRequest)
+			return
+		}
+		outPixels *= 1000 // Convert to milliseconds
 	default:
 		respondWithError(w, "Unknown request type", http.StatusBadRequest)
 		return
diff --git a/server/ai_mediaserver.go b/server/ai_mediaserver.go
index 62e5b7ad39..29d33b9fed 100644
--- a/server/ai_mediaserver.go
+++ b/server/ai_mediaserver.go
@@ -68,6 +68,7 @@ func startAIMediaServer(ls *LivepeerServer) error {
 	ls.HTTPMux.Handle("/upscale", oapiReqValidator(ls.Upscale()))
 	ls.HTTPMux.Handle("/image-to-video", oapiReqValidator(ls.ImageToVideo()))
 	ls.HTTPMux.Handle("/image-to-video/result", ls.ImageToVideoResult())
+	ls.HTTPMux.Handle("/audio-to-text", oapiReqValidator(ls.AudioToText()))
 
 	return nil
 }
@@ -320,6 +321,59 @@ func (ls *LivepeerServer) Upscale() http.Handler {
 	})
 }
 
+func (ls *LivepeerServer) AudioToText() http.Handler {
+	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		remoteAddr := getRemoteAddr(r)
+		ctx := clog.AddVal(r.Context(), clog.ClientIP, remoteAddr)
+		requestID := string(core.RandomManifestID())
+		ctx = clog.AddVal(ctx, "request_id", requestID)
+
+		multiRdr, err := r.MultipartReader()
+		if err != nil {
+			respondJsonError(ctx, w, err, http.StatusBadRequest)
+			return
+		}
+
+		var req worker.AudioToTextMultipartRequestBody
+		if err := runtime.BindMultipart(&req, *multiRdr); err != nil {
+			respondJsonError(ctx, w, err, http.StatusBadRequest)
+			return
+		}
+
+		clog.V(common.VERBOSE).Infof(ctx, "Received AudioToText request audioSize=%v model_id=%v", req.Audio.FileSize(), *req.ModelId)
+
+		params := aiRequestParams{
+			node: ls.LivepeerNode,
+			os: drivers.NodeStorage.NewSession(requestID),
+			sessManager: ls.AISessionManager,
+		}
+
+		start := time.Now()
+		resp, err := processAudioToText(ctx, params, req)
+		if err != nil {
+			var serviceUnavailableErr *ServiceUnavailableError
+			var badRequestErr *BadRequestError
+			if errors.As(err, &serviceUnavailableErr) {
+				respondJsonError(ctx, w, err, http.StatusServiceUnavailable)
+				return
+			}
+			if errors.As(err, &badRequestErr) {
+				respondJsonError(ctx, w, err, http.StatusBadRequest)
+				return
+			}
+			respondJsonError(ctx, w, err, http.StatusInternalServerError)
+			return
+		}
+
+		took := time.Since(start)
+		clog.V(common.VERBOSE).Infof(ctx, "Processed AudioToText request model_id=%v took=%v", *req.ModelId, took)
+
+		w.Header().Set("Content-Type", "application/json")
+		w.WriteHeader(http.StatusOK)
+		_ = json.NewEncoder(w).Encode(resp)
+	})
+}
+
 func (ls *LivepeerServer) ImageToVideoResult() http.Handler {
 	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 		remoteAddr := getRemoteAddr(r)
diff --git a/server/ai_process.go b/server/ai_process.go
index ed58e3106c..be8571dc26 100644
--- a/server/ai_process.go
+++ b/server/ai_process.go
@@ -28,6 +28,7 @@ const defaultTextToImageModelID = "stabilityai/sdxl-turbo"
 const defaultImageToImageModelID = "stabilityai/sdxl-turbo"
 const defaultImageToVideoModelID = "stabilityai/stable-video-diffusion-img2vid-xt"
 const defaultUpscaleModelID = "stabilityai/stable-diffusion-x4-upscaler"
+const defaultAudioToTextModelID = "openai/whisper-large-v3"
 
 type ServiceUnavailableError struct {
 	err error
@@ -37,6 +38,14 @@ func (e *ServiceUnavailableError) Error() string {
 	return e.err.Error()
 }
 
+type BadRequestError struct {
+	err error
+}
+
+func (e *BadRequestError) Error() string {
+	return e.err.Error()
+}
+
 type aiRequestParams struct {
 	node *core.LivepeerNode
 	os drivers.OSSession
@@ -49,8 +58,10 @@ func processTextToImage(ctx context.Context, params aiRequestParams, req worker.
 		return nil, err
 	}
 
-	newMedia := make([]worker.Media, len(resp.Images))
-	for i, media := range resp.Images {
+	imgResp := resp.(*worker.ImageResponse)
+
+	newMedia := make([]worker.Media, len(imgResp.Images))
+	for i, media := range imgResp.Images {
 		var data bytes.Buffer
 		writer := bufio.NewWriter(&data)
 		if err := worker.ReadImageB64DataUrl(media.Url, writer); err != nil {
@@ -67,9 +78,9 @@ func processTextToImage(ctx context.Context, params aiRequestParams, req worker.
 		newMedia[i] = worker.Media{Nsfw: media.Nsfw, Seed: media.Seed, Url: newUrl}
 	}
 
-	resp.Images = newMedia
+	imgResp.Images = newMedia
 
-	return resp, nil
+	return imgResp, nil
 }
 
 func submitTextToImage(ctx context.Context, params aiRequestParams, sess *AISession, req worker.TextToImageJSONRequestBody) (*worker.ImageResponse, error) {
@@ -133,8 +144,10 @@ func processImageToImage(ctx context.Context, params aiRequestParams, req worker
 		return nil, err
 	}
 
-	newMedia := make([]worker.Media, len(resp.Images))
-	for i, media := range resp.Images {
+	imgResp := resp.(*worker.ImageResponse)
+
+	newMedia := make([]worker.Media, len(imgResp.Images))
+	for i, media := range imgResp.Images {
 		var data bytes.Buffer
 		writer := bufio.NewWriter(&data)
 		if err := worker.ReadImageB64DataUrl(media.Url, writer); err != nil {
@@ -151,9 +164,9 @@ func processImageToImage(ctx context.Context, params aiRequestParams, req worker
 		newMedia[i] = worker.Media{Nsfw: media.Nsfw, Seed: media.Seed, Url: newUrl}
 	}
 
-	resp.Images = newMedia
+	imgResp.Images = newMedia
 
-	return resp, nil
+	return imgResp, nil
 }
 
 func submitImageToImage(ctx context.Context, params aiRequestParams, sess *AISession, req worker.ImageToImageMultipartRequestBody) (*worker.ImageResponse, error) {
@@ -220,8 +233,11 @@ func processImageToVideo(ctx context.Context, params aiRequestParams, req worker
 	}
 
 	// HACK: Re-use worker.ImageResponse to return results
-	videos := make([]worker.Media, len(resp.Images))
-	for i, media := range resp.Images {
+	// TODO: Refactor to return worker.VideoResponse
+	imgResp := resp.(*worker.ImageResponse)
+
+	videos := make([]worker.Media, len(imgResp.Images))
+	for i, media := range imgResp.Images {
 		data, err := downloadSeg(ctx, media.Url)
 		if err != nil {
 			return nil, err
@@ -241,9 +257,9 @@ func processImageToVideo(ctx context.Context, params aiRequestParams, req worker
 
 	}
 
-	resp.Images = videos
+	imgResp.Images = videos
 
-	return resp, nil
+	return imgResp, nil
 }
 
 func submitImageToVideo(ctx context.Context, params aiRequestParams, sess *AISession, req worker.ImageToVideoMultipartRequestBody) (*worker.ImageResponse, error) {
@@ -314,8 +330,10 @@ func processUpscale(ctx context.Context, params aiRequestParams, req worker.Upsc
 		return nil, err
 	}
 
-	newMedia := make([]worker.Media, len(resp.Images))
-	for i, media := range resp.Images {
+	imgResp := resp.(*worker.ImageResponse)
+
+	newMedia := make([]worker.Media, len(imgResp.Images))
+	for i, media := range imgResp.Images {
 		var data bytes.Buffer
 		writer := bufio.NewWriter(&data)
 		if err := worker.ReadImageB64DataUrl(media.Url, writer); err != nil {
@@ -332,9 +350,9 @@ func processUpscale(ctx context.Context, params aiRequestParams, req worker.Upsc
 		newMedia[i] = worker.Media{Nsfw: media.Nsfw, Seed: media.Seed, Url: newUrl}
 	}
 
-	resp.Images = newMedia
+	imgResp.Images = newMedia
 
-	return resp, nil
+	return imgResp, nil
 }
 
 func submitUpscale(ctx context.Context, params aiRequestParams, sess *AISession, req worker.UpscaleMultipartRequestBody) (*worker.ImageResponse, error) {
@@ -388,10 +406,78 @@ func submitUpscale(ctx context.Context, params aiRequestParams, sess *AISession,
 	return resp.JSON200, nil
 }
 
-func processAIRequest(ctx context.Context, params aiRequestParams, req interface{}) (*worker.ImageResponse, error) {
+func submitAudioToText(ctx context.Context, params aiRequestParams, sess *AISession, req worker.AudioToTextMultipartRequestBody) (*worker.TextResponse, error) {
+	var buf bytes.Buffer
+	mw, err := worker.NewAudioToTextMultipartWriter(&buf, req)
+	if err != nil {
+		return nil, err
+	}
+
+	client, err := worker.NewClientWithResponses(sess.Transcoder(), worker.WithHTTPClient(httpClient))
+	if err != nil {
+		return nil, err
+	}
+
+	durationSeconds, err := common.CalculateAudioDuration(req.Audio)
+	if err != nil {
+		return nil, err
+	}
+
+	clog.V(common.VERBOSE).Infof(ctx, "Submitting audio-to-text media with duration: %d seconds", durationSeconds)
+	setHeaders, balUpdate, err := prepareAIPayment(ctx, sess, durationSeconds*1000)
+	if err != nil {
+		return nil, err
+	}
+	defer completeBalanceUpdate(sess.BroadcastSession, balUpdate)
+
+	start := time.Now()
+	resp, err := client.AudioToTextWithBody(ctx, mw.FormDataContentType(), &buf, setHeaders)
+	took := time.Since(start)
+	if err != nil {
+		return nil, err
+	}
+	defer resp.Body.Close()
+
+	data, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return nil, err
+	}
+
+	if resp.StatusCode != 200 {
+		return nil, errors.New(string(data))
+	}
+
+	// We treat a response as "receiving change" where the change is the difference between the credit and debit for the update
+	if balUpdate != nil {
+		balUpdate.Status = ReceivedChange
+	}
+
+	var res worker.TextResponse
+	if err := json.Unmarshal(data, &res); err != nil {
+		return nil, err
+	}
+
+	// TODO: Refine this rough estimate in future iterations
+	sess.LatencyScore = took.Seconds() / float64(durationSeconds)
+
+	return &res, nil
+}
+
+func processAudioToText(ctx context.Context, params aiRequestParams, req worker.AudioToTextMultipartRequestBody) (*worker.TextResponse, error) {
+	resp, err := processAIRequest(ctx, params, req)
+	if err != nil {
+		return nil, err
+	}
+
+	txtResp := resp.(*worker.TextResponse)
+
+	return txtResp, nil
+}
+
+func processAIRequest(ctx context.Context, params aiRequestParams, req interface{}) (interface{}, error) {
 	var cap core.Capability
 	var modelID string
-	var submitFn func(context.Context, aiRequestParams, *AISession) (*worker.ImageResponse, error)
+	var submitFn func(context.Context, aiRequestParams, *AISession) (interface{}, error)
 
 	switch v := req.(type) {
 	case worker.TextToImageJSONRequestBody:
@@ -400,7 +486,7 @@ func processAIRequest(ctx context.Context, params aiRequestParams, req interface
 		if v.ModelId != nil {
 			modelID = *v.ModelId
 		}
-		submitFn = func(ctx context.Context, params aiRequestParams, sess *AISession) (*worker.ImageResponse, error) {
+		submitFn = func(ctx context.Context, params aiRequestParams, sess *AISession) (interface{}, error) {
 			return submitTextToImage(ctx, params, sess, v)
 		}
 	case worker.ImageToImageMultipartRequestBody:
@@ -409,7 +495,7 @@ func processAIRequest(ctx context.Context, params aiRequestParams, req interface
 		if v.ModelId != nil {
 			modelID = *v.ModelId
 		}
-		submitFn = func(ctx context.Context, params aiRequestParams, sess *AISession) (*worker.ImageResponse, error) {
+		submitFn = func(ctx context.Context, params aiRequestParams, sess *AISession) (interface{}, error) {
 			return submitImageToImage(ctx, params, sess, v)
 		}
 	case worker.ImageToVideoMultipartRequestBody:
@@ -418,7 +504,7 @@ func processAIRequest(ctx context.Context, params aiRequestParams, req interface
 		if v.ModelId != nil {
 			modelID = *v.ModelId
 		}
-		submitFn = func(ctx context.Context, params aiRequestParams, sess *AISession) (*worker.ImageResponse, error) {
+		submitFn = func(ctx context.Context, params aiRequestParams, sess *AISession) (interface{}, error) {
 			return submitImageToVideo(ctx, params, sess, v)
 		}
 	case worker.UpscaleMultipartRequestBody:
@@ -427,14 +513,23 @@ func processAIRequest(ctx context.Context, params aiRequestParams, req interface
 		if v.ModelId != nil {
 			modelID = *v.ModelId
 		}
-		submitFn = func(ctx context.Context, params aiRequestParams, sess *AISession) (*worker.ImageResponse, error) {
+		submitFn = func(ctx context.Context, params aiRequestParams, sess *AISession) (interface{}, error) {
 			return submitUpscale(ctx, params, sess, v)
 		}
+	case worker.AudioToTextMultipartRequestBody:
+		cap = core.Capability_AudioToText
+		modelID = defaultAudioToTextModelID
+		if v.ModelId != nil {
+			modelID = *v.ModelId
+		}
+		submitFn = func(ctx context.Context, params aiRequestParams, sess *AISession) (interface{}, error) {
+			return submitAudioToText(ctx, params, sess, v)
+		}
 	default:
-		return nil, errors.New("unknown AI request type")
+		return nil, fmt.Errorf("unsupported request type %T", req)
 	}
 
-	var resp *worker.ImageResponse
+	var resp interface{}
 
 	cctx, cancel := context.WithTimeout(ctx, processingRetryTimeout)
 	defer cancel()
@@ -465,14 +560,16 @@ func processAIRequest(ctx context.Context, params aiRequestParams, req interface
 		}
 
 		clog.Infof(ctx, "Error submitting request cap=%v modelID=%v try=%v orch=%v err=%v", cap, modelID, tries, sess.Transcoder(), err)
-
 		params.sessManager.Remove(ctx, sess)
+
+		if errors.Is(err, common.ErrAudioDurationCalculation) {
+			return nil, &BadRequestError{err}
+		}
 	}
 
 	if resp == nil {
 		return nil, &ServiceUnavailableError{err: errors.New("no orchestrators available")}
 	}
-
 	return resp, nil
 }
 
diff --git a/server/rpc.go b/server/rpc.go
index 9c24d3336a..0fc46494f8 100644
--- a/server/rpc.go
+++ b/server/rpc.go
@@ -67,6 +67,7 @@ type Orchestrator interface {
 	ImageToImage(ctx context.Context, req worker.ImageToImageMultipartRequestBody) (*worker.ImageResponse, error)
 	ImageToVideo(ctx context.Context, req worker.ImageToVideoMultipartRequestBody) (*worker.ImageResponse, error)
 	Upscale(ctx context.Context, req worker.UpscaleMultipartRequestBody) (*worker.ImageResponse, error)
+	AudioToText(ctx context.Context, req worker.AudioToTextMultipartRequestBody) (*worker.TextResponse, error)
 }
 
 // Balance describes methods for a session's balance maintenance
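
Usage note (not part of the patch): the new gateway route takes a multipart form mirroring worker.AudioToTextMultipartRequestBody, i.e. an audio file and a model_id. A minimal sketch of exercising it once the patch is applied — the port, file name, and form field names are assumptions based on the default go-livepeer HTTP address and the ai-worker OpenAPI bindings, so verify them locally:

  curl -X POST http://localhost:8935/audio-to-text \
    -F model_id="openai/whisper-large-v3" \
    -F audio=@sample.mp3

If the clip duration cannot be computed, the request is rejected with HTTP 400 through the new BadRequestError path; otherwise the payment is sized from the audio duration converted to milliseconds.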