From 49b03186017bddb2dd5b767d7ef86005a3a61acf Mon Sep 17 00:00:00 2001 From: Rick Staa Date: Fri, 13 Sep 2024 11:06:00 +0200 Subject: [PATCH] refactor: update worker classes This commit ensures that the go-livepeer code uses the new worker classes that were defined in https://github.com/livepeer/ai-worker/pull/191. --- ai/file_worker.go | 8 +++---- core/ai.go | 12 +++++----- core/orchestrator.go | 24 +++++++++---------- server/ai_http.go | 36 ++++++++++++++--------------- server/ai_mediaserver.go | 12 +++++----- server/ai_process.go | 50 ++++++++++++++++++++-------------------- server/rpc.go | 12 +++++----- server/rpc_test.go | 24 +++++++++---------- 8 files changed, 89 insertions(+), 89 deletions(-) diff --git a/ai/file_worker.go b/ai/file_worker.go index e9eb85c641..1ce5478204 100644 --- a/ai/file_worker.go +++ b/ai/file_worker.go @@ -17,7 +17,7 @@ func NewFileWorker(files map[string]string) *FileWorker { return &FileWorker{files: files} } -func (w *FileWorker) TextToImage(ctx context.Context, req worker.TextToImageJSONRequestBody) (*worker.ImageResponse, error) { +func (w *FileWorker) TextToImage(ctx context.Context, req worker.GenTextToImageJSONRequestBody) (*worker.ImageResponse, error) { fname, ok := w.files["text-to-image"] if !ok { return nil, errors.New("text-to-image response file not found") @@ -36,7 +36,7 @@ func (w *FileWorker) TextToImage(ctx context.Context, req worker.TextToImageJSON return &resp, nil } -func (w *FileWorker) ImageToImage(ctx context.Context, req worker.ImageToImageMultipartRequestBody) (*worker.ImageResponse, error) { +func (w *FileWorker) ImageToImage(ctx context.Context, req worker.GenImageToImageMultipartRequestBody) (*worker.ImageResponse, error) { fname, ok := w.files["image-to-image"] if !ok { return nil, errors.New("image-to-image response file not found") @@ -55,7 +55,7 @@ func (w *FileWorker) ImageToImage(ctx context.Context, req worker.ImageToImageMu return &resp, nil } -func (w *FileWorker) ImageToVideo(ctx context.Context, req worker.ImageToVideoMultipartRequestBody) (*worker.VideoResponse, error) { +func (w *FileWorker) ImageToVideo(ctx context.Context, req worker.GenImageToVideoMultipartRequestBody) (*worker.VideoResponse, error) { fname, ok := w.files["image-to-video"] if !ok { return nil, errors.New("image-to-video response file not found") @@ -74,7 +74,7 @@ func (w *FileWorker) ImageToVideo(ctx context.Context, req worker.ImageToVideoMu return &resp, nil } -func (w *FileWorker) Upscale(ctx context.Context, req worker.UpscaleMultipartRequestBody) (*worker.ImageResponse, error) { +func (w *FileWorker) Upscale(ctx context.Context, req worker.GenUpscaleMultipartRequestBody) (*worker.ImageResponse, error) { fname, ok := w.files["upscale"] if !ok { return nil, errors.New("upscale response file not found") diff --git a/core/ai.go b/core/ai.go index 887e1f8c8c..31f331e49e 100644 --- a/core/ai.go +++ b/core/ai.go @@ -17,12 +17,12 @@ import ( var errPipelineNotAvailable = errors.New("pipeline not available") type AI interface { - TextToImage(context.Context, worker.TextToImageJSONRequestBody) (*worker.ImageResponse, error) - ImageToImage(context.Context, worker.ImageToImageMultipartRequestBody) (*worker.ImageResponse, error) - ImageToVideo(context.Context, worker.ImageToVideoMultipartRequestBody) (*worker.VideoResponse, error) - Upscale(context.Context, worker.UpscaleMultipartRequestBody) (*worker.ImageResponse, error) - AudioToText(context.Context, worker.AudioToTextMultipartRequestBody) (*worker.TextResponse, error) - SegmentAnything2(context.Context, worker.SegmentAnything2MultipartRequestBody) (*worker.MasksResponse, error) + TextToImage(context.Context, worker.GenTextToImageJSONRequestBody) (*worker.ImageResponse, error) + ImageToImage(context.Context, worker.GenImageToImageMultipartRequestBody) (*worker.ImageResponse, error) + ImageToVideo(context.Context, worker.GenImageToVideoMultipartRequestBody) (*worker.VideoResponse, error) + Upscale(context.Context, worker.GenUpscaleMultipartRequestBody) (*worker.ImageResponse, error) + AudioToText(context.Context, worker.GenAudioToTextMultipartRequestBody) (*worker.TextResponse, error) + SegmentAnything2(context.Context, worker.GenSegmentAnything2MultipartRequestBody) (*worker.MasksResponse, error) Warm(context.Context, string, string, worker.RunnerEndpoint, worker.OptimizationFlags) error Stop(context.Context) error HasCapacity(pipeline, modelID string) bool diff --git a/core/orchestrator.go b/core/orchestrator.go index ba9f470bd1..f8e343ae32 100644 --- a/core/orchestrator.go +++ b/core/orchestrator.go @@ -110,27 +110,27 @@ func (orch *orchestrator) TranscoderResults(tcID int64, res *RemoteTranscoderRes orch.node.TranscoderManager.transcoderResults(tcID, res) } -func (orch *orchestrator) TextToImage(ctx context.Context, req worker.TextToImageJSONRequestBody) (*worker.ImageResponse, error) { +func (orch *orchestrator) TextToImage(ctx context.Context, req worker.GenTextToImageJSONRequestBody) (*worker.ImageResponse, error) { return orch.node.textToImage(ctx, req) } -func (orch *orchestrator) ImageToImage(ctx context.Context, req worker.ImageToImageMultipartRequestBody) (*worker.ImageResponse, error) { +func (orch *orchestrator) ImageToImage(ctx context.Context, req worker.GenImageToImageMultipartRequestBody) (*worker.ImageResponse, error) { return orch.node.imageToImage(ctx, req) } -func (orch *orchestrator) ImageToVideo(ctx context.Context, req worker.ImageToVideoMultipartRequestBody) (*worker.ImageResponse, error) { +func (orch *orchestrator) ImageToVideo(ctx context.Context, req worker.GenImageToVideoMultipartRequestBody) (*worker.ImageResponse, error) { return orch.node.imageToVideo(ctx, req) } -func (orch *orchestrator) Upscale(ctx context.Context, req worker.UpscaleMultipartRequestBody) (*worker.ImageResponse, error) { +func (orch *orchestrator) Upscale(ctx context.Context, req worker.GenUpscaleMultipartRequestBody) (*worker.ImageResponse, error) { return orch.node.upscale(ctx, req) } -func (orch *orchestrator) AudioToText(ctx context.Context, req worker.AudioToTextMultipartRequestBody) (*worker.TextResponse, error) { +func (orch *orchestrator) AudioToText(ctx context.Context, req worker.GenAudioToTextMultipartRequestBody) (*worker.TextResponse, error) { return orch.node.AudioToText(ctx, req) } -func (orch *orchestrator) SegmentAnything2(ctx context.Context, req worker.SegmentAnything2MultipartRequestBody) (*worker.MasksResponse, error) { +func (orch *orchestrator) SegmentAnything2(ctx context.Context, req worker.GenSegmentAnything2MultipartRequestBody) (*worker.MasksResponse, error) { return orch.node.SegmentAnything2(ctx, req) } @@ -951,27 +951,27 @@ func (n *LivepeerNode) serveTranscoder(stream net.Transcoder_RegisterTranscoderS } } -func (n *LivepeerNode) textToImage(ctx context.Context, req worker.TextToImageJSONRequestBody) (*worker.ImageResponse, error) { +func (n *LivepeerNode) textToImage(ctx context.Context, req worker.GenTextToImageJSONRequestBody) (*worker.ImageResponse, error) { return n.AIWorker.TextToImage(ctx, req) } -func (n *LivepeerNode) imageToImage(ctx context.Context, req worker.ImageToImageMultipartRequestBody) (*worker.ImageResponse, error) { +func (n *LivepeerNode) imageToImage(ctx context.Context, req worker.GenImageToImageMultipartRequestBody) (*worker.ImageResponse, error) { return n.AIWorker.ImageToImage(ctx, req) } -func (n *LivepeerNode) upscale(ctx context.Context, req worker.UpscaleMultipartRequestBody) (*worker.ImageResponse, error) { +func (n *LivepeerNode) upscale(ctx context.Context, req worker.GenUpscaleMultipartRequestBody) (*worker.ImageResponse, error) { return n.AIWorker.Upscale(ctx, req) } -func (n *LivepeerNode) AudioToText(ctx context.Context, req worker.AudioToTextMultipartRequestBody) (*worker.TextResponse, error) { +func (n *LivepeerNode) AudioToText(ctx context.Context, req worker.GenAudioToTextMultipartRequestBody) (*worker.TextResponse, error) { return n.AIWorker.AudioToText(ctx, req) } -func (n *LivepeerNode) SegmentAnything2(ctx context.Context, req worker.SegmentAnything2MultipartRequestBody) (*worker.MasksResponse, error) { +func (n *LivepeerNode) SegmentAnything2(ctx context.Context, req worker.GenSegmentAnything2MultipartRequestBody) (*worker.MasksResponse, error) { return n.AIWorker.SegmentAnything2(ctx, req) } -func (n *LivepeerNode) imageToVideo(ctx context.Context, req worker.ImageToVideoMultipartRequestBody) (*worker.ImageResponse, error) { +func (n *LivepeerNode) imageToVideo(ctx context.Context, req worker.GenImageToVideoMultipartRequestBody) (*worker.ImageResponse, error) { // We might support generating more than one video in the future (i.e. multiple input images/prompts) numVideos := 1 diff --git a/server/ai_http.go b/server/ai_http.go index a11164bc9b..3f0bb97d9e 100644 --- a/server/ai_http.go +++ b/server/ai_http.go @@ -56,7 +56,7 @@ func (h *lphttp) TextToImage() http.Handler { remoteAddr := getRemoteAddr(r) ctx := clog.AddVal(r.Context(), clog.ClientIP, remoteAddr) - var req worker.TextToImageJSONRequestBody + var req worker.GenTextToImageJSONRequestBody if err := json.NewDecoder(r.Body).Decode(&req); err != nil { respondWithError(w, err.Error(), http.StatusBadRequest) return @@ -79,7 +79,7 @@ func (h *lphttp) ImageToImage() http.Handler { return } - var req worker.ImageToImageMultipartRequestBody + var req worker.GenImageToImageMultipartRequestBody if err := runtime.BindMultipart(&req, *multiRdr); err != nil { respondWithError(w, err.Error(), http.StatusInternalServerError) return @@ -102,7 +102,7 @@ func (h *lphttp) ImageToVideo() http.Handler { return } - var req worker.ImageToVideoMultipartRequestBody + var req worker.GenImageToVideoMultipartRequestBody if err := runtime.BindMultipart(&req, *multiRdr); err != nil { respondWithError(w, err.Error(), http.StatusInternalServerError) return @@ -125,7 +125,7 @@ func (h *lphttp) Upscale() http.Handler { return } - var req worker.UpscaleMultipartRequestBody + var req worker.GenUpscaleMultipartRequestBody if err := runtime.BindMultipart(&req, *multiRdr); err != nil { respondWithError(w, err.Error(), http.StatusInternalServerError) return @@ -148,7 +148,7 @@ func (h *lphttp) AudioToText() http.Handler { return } - var req worker.AudioToTextMultipartRequestBody + var req worker.GenAudioToTextMultipartRequestBody if err := runtime.BindMultipart(&req, *multiRdr); err != nil { respondWithError(w, err.Error(), http.StatusInternalServerError) return @@ -171,7 +171,7 @@ func (h *lphttp) SegmentAnything2() http.Handler { return } - var req worker.SegmentAnything2MultipartRequestBody + var req worker.GenSegmentAnything2MultipartRequestBody if err := runtime.BindMultipart(&req, *multiRdr); err != nil { respondWithError(w, err.Error(), http.StatusInternalServerError) return @@ -202,7 +202,7 @@ func handleAIRequest(ctx context.Context, w http.ResponseWriter, r *http.Request var outPixels int64 switch v := req.(type) { - case worker.TextToImageJSONRequestBody: + case worker.GenTextToImageJSONRequestBody: pipeline = "text-to-image" cap = core.Capability_TextToImage modelID = *v.ModelId @@ -226,7 +226,7 @@ func handleAIRequest(ctx context.Context, w http.ResponseWriter, r *http.Request } outPixels = height * width * numImages - case worker.ImageToImageMultipartRequestBody: + case worker.GenImageToImageMultipartRequestBody: pipeline = "image-to-image" cap = core.Capability_ImageToImage modelID = *v.ModelId @@ -251,7 +251,7 @@ func handleAIRequest(ctx context.Context, w http.ResponseWriter, r *http.Request } outPixels = int64(config.Height) * int64(config.Width) * numImages - case worker.UpscaleMultipartRequestBody: + case worker.GenUpscaleMultipartRequestBody: pipeline = "upscale" cap = core.Capability_Upscale modelID = *v.ModelId @@ -270,7 +270,7 @@ func handleAIRequest(ctx context.Context, w http.ResponseWriter, r *http.Request return } outPixels = int64(config.Height) * int64(config.Width) - case worker.ImageToVideoMultipartRequestBody: + case worker.GenImageToVideoMultipartRequestBody: pipeline = "image-to-video" cap = core.Capability_ImageToVideo modelID = *v.ModelId @@ -291,7 +291,7 @@ func handleAIRequest(ctx context.Context, w http.ResponseWriter, r *http.Request frames := int64(25) outPixels = height * width * int64(frames) - case worker.AudioToTextMultipartRequestBody: + case worker.GenAudioToTextMultipartRequestBody: pipeline = "audio-to-text" cap = core.Capability_AudioToText modelID = *v.ModelId @@ -305,7 +305,7 @@ func handleAIRequest(ctx context.Context, w http.ResponseWriter, r *http.Request return } outPixels *= 1000 // Convert to milliseconds - case worker.SegmentAnything2MultipartRequestBody: + case worker.GenSegmentAnything2MultipartRequestBody: pipeline = "segment-anything-2" cap = core.Capability_SegmentAnything2 modelID = *v.ModelId @@ -382,20 +382,20 @@ func handleAIRequest(ctx context.Context, w http.ResponseWriter, r *http.Request if monitor.Enabled { var latencyScore float64 switch v := req.(type) { - case worker.TextToImageJSONRequestBody: + case worker.GenTextToImageJSONRequestBody: latencyScore = CalculateTextToImageLatencyScore(took, v, outPixels) - case worker.ImageToImageMultipartRequestBody: + case worker.GenImageToImageMultipartRequestBody: latencyScore = CalculateImageToImageLatencyScore(took, v, outPixels) - case worker.ImageToVideoMultipartRequestBody: + case worker.GenImageToVideoMultipartRequestBody: latencyScore = CalculateImageToVideoLatencyScore(took, v, outPixels) - case worker.UpscaleMultipartRequestBody: + case worker.GenUpscaleMultipartRequestBody: latencyScore = CalculateUpscaleLatencyScore(took, v, outPixels) - case worker.AudioToTextMultipartRequestBody: + case worker.GenAudioToTextMultipartRequestBody: durationSeconds, err := common.CalculateAudioDuration(v.Audio) if err == nil { latencyScore = CalculateAudioToTextLatencyScore(took, durationSeconds) } - case worker.SegmentAnything2MultipartRequestBody: + case worker.GenSegmentAnything2MultipartRequestBody: latencyScore = CalculateSegmentAnything2LatencyScore(took, outPixels) } diff --git a/server/ai_mediaserver.go b/server/ai_mediaserver.go index a600f3b176..078fa05ee9 100644 --- a/server/ai_mediaserver.go +++ b/server/ai_mediaserver.go @@ -81,7 +81,7 @@ func (ls *LivepeerServer) TextToImage() http.Handler { requestID := string(core.RandomManifestID()) ctx = clog.AddVal(ctx, "request_id", requestID) - var req worker.TextToImageJSONRequestBody + var req worker.GenTextToImageJSONRequestBody if err := json.NewDecoder(r.Body).Decode(&req); err != nil { respondJsonError(ctx, w, err, http.StatusBadRequest) return @@ -129,7 +129,7 @@ func (ls *LivepeerServer) ImageToImage() http.Handler { return } - var req worker.ImageToImageMultipartRequestBody + var req worker.GenImageToImageMultipartRequestBody if err := runtime.BindMultipart(&req, *multiRdr); err != nil { respondJsonError(ctx, w, err, http.StatusBadRequest) return @@ -177,7 +177,7 @@ func (ls *LivepeerServer) ImageToVideo() http.Handler { return } - var req worker.ImageToVideoMultipartRequestBody + var req worker.GenImageToVideoMultipartRequestBody if err := runtime.BindMultipart(&req, *multiRdr); err != nil { respondJsonError(ctx, w, err, http.StatusBadRequest) return @@ -287,7 +287,7 @@ func (ls *LivepeerServer) Upscale() http.Handler { return } - var req worker.UpscaleMultipartRequestBody + var req worker.GenUpscaleMultipartRequestBody if err := runtime.BindMultipart(&req, *multiRdr); err != nil { respondJsonError(ctx, w, err, http.StatusBadRequest) return @@ -335,7 +335,7 @@ func (ls *LivepeerServer) AudioToText() http.Handler { return } - var req worker.AudioToTextMultipartRequestBody + var req worker.GenAudioToTextMultipartRequestBody if err := runtime.BindMultipart(&req, *multiRdr); err != nil { respondJsonError(ctx, w, err, http.StatusBadRequest) return @@ -388,7 +388,7 @@ func (ls *LivepeerServer) SegmentAnything2() http.Handler { return } - var req worker.SegmentAnything2MultipartRequestBody + var req worker.GenSegmentAnything2MultipartRequestBody if err := runtime.BindMultipart(&req, *multiRdr); err != nil { respondJsonError(ctx, w, err, http.StatusBadRequest) return diff --git a/server/ai_process.go b/server/ai_process.go index d69beb0610..664f683094 100644 --- a/server/ai_process.go +++ b/server/ai_process.go @@ -56,7 +56,7 @@ type aiRequestParams struct { } // CalculateTextToImageLatencyScore computes the time taken per pixel for an text-to-image request. -func CalculateTextToImageLatencyScore(took time.Duration, req worker.TextToImageJSONRequestBody, outPixels int64) float64 { +func CalculateTextToImageLatencyScore(took time.Duration, req worker.GenTextToImageJSONRequestBody, outPixels int64) float64 { if outPixels <= 0 { return 0 } @@ -75,7 +75,7 @@ func CalculateTextToImageLatencyScore(took time.Duration, req worker.TextToImage return took.Seconds() / float64(outPixels) / numInferenceSteps } -func processTextToImage(ctx context.Context, params aiRequestParams, req worker.TextToImageJSONRequestBody) (*worker.ImageResponse, error) { +func processTextToImage(ctx context.Context, params aiRequestParams, req worker.GenTextToImageJSONRequestBody) (*worker.ImageResponse, error) { resp, err := processAIRequest(ctx, params, req) if err != nil { return nil, err @@ -106,7 +106,7 @@ func processTextToImage(ctx context.Context, params aiRequestParams, req worker. return imgResp, nil } -func submitTextToImage(ctx context.Context, params aiRequestParams, sess *AISession, req worker.TextToImageJSONRequestBody) (*worker.ImageResponse, error) { +func submitTextToImage(ctx context.Context, params aiRequestParams, sess *AISession, req worker.GenTextToImageJSONRequestBody) (*worker.ImageResponse, error) { client, err := worker.NewClientWithResponses(sess.Transcoder(), worker.WithHTTPClient(httpClient)) if err != nil { @@ -146,7 +146,7 @@ func submitTextToImage(ctx context.Context, params aiRequestParams, sess *AISess defer completeBalanceUpdate(sess.BroadcastSession, balUpdate) start := time.Now() - resp, err := client.TextToImageWithResponse(ctx, req, setHeaders) + resp, err := client.GenTextToImageWithResponse(ctx, req, setHeaders) took := time.Since(start) // TODO: Refine this rough estimate in future iterations. @@ -182,7 +182,7 @@ func submitTextToImage(ctx context.Context, params aiRequestParams, sess *AISess } // CalculateImageToImageLatencyScore computes the time taken per pixel for an image-to-image request. -func CalculateImageToImageLatencyScore(took time.Duration, req worker.ImageToImageMultipartRequestBody, outPixels int64) float64 { +func CalculateImageToImageLatencyScore(took time.Duration, req worker.GenImageToImageMultipartRequestBody, outPixels int64) float64 { if outPixels <= 0 { return 0 } @@ -201,7 +201,7 @@ func CalculateImageToImageLatencyScore(took time.Duration, req worker.ImageToIma return took.Seconds() / float64(outPixels) / numInferenceSteps } -func processImageToImage(ctx context.Context, params aiRequestParams, req worker.ImageToImageMultipartRequestBody) (*worker.ImageResponse, error) { +func processImageToImage(ctx context.Context, params aiRequestParams, req worker.GenImageToImageMultipartRequestBody) (*worker.ImageResponse, error) { resp, err := processAIRequest(ctx, params, req) if err != nil { return nil, err @@ -232,7 +232,7 @@ func processImageToImage(ctx context.Context, params aiRequestParams, req worker return imgResp, nil } -func submitImageToImage(ctx context.Context, params aiRequestParams, sess *AISession, req worker.ImageToImageMultipartRequestBody) (*worker.ImageResponse, error) { +func submitImageToImage(ctx context.Context, params aiRequestParams, sess *AISession, req worker.GenImageToImageMultipartRequestBody) (*worker.ImageResponse, error) { // TODO: Default values for the number of images is currently hardcoded. // These should be managed by the nethttpmiddleware. Refer to issue LIV-412 for more details. defaultNumImages := 1 @@ -286,7 +286,7 @@ func submitImageToImage(ctx context.Context, params aiRequestParams, sess *AISes defer completeBalanceUpdate(sess.BroadcastSession, balUpdate) start := time.Now() - resp, err := client.ImageToImageWithBodyWithResponse(ctx, mw.FormDataContentType(), &buf, setHeaders) + resp, err := client.GenImageToImageWithBodyWithResponse(ctx, mw.FormDataContentType(), &buf, setHeaders) took := time.Since(start) // TODO: Refine this rough estimate in future iterations. @@ -322,7 +322,7 @@ func submitImageToImage(ctx context.Context, params aiRequestParams, sess *AISes } // CalculateImageToVideoLatencyScore computes the time taken per pixel for an image-to-video request. -func CalculateImageToVideoLatencyScore(took time.Duration, req worker.ImageToVideoMultipartRequestBody, outPixels int64) float64 { +func CalculateImageToVideoLatencyScore(took time.Duration, req worker.GenImageToVideoMultipartRequestBody, outPixels int64) float64 { if outPixels <= 0 { return 0 } @@ -337,7 +337,7 @@ func CalculateImageToVideoLatencyScore(took time.Duration, req worker.ImageToVid return took.Seconds() / float64(outPixels) / numInferenceSteps } -func processImageToVideo(ctx context.Context, params aiRequestParams, req worker.ImageToVideoMultipartRequestBody) (*worker.ImageResponse, error) { +func processImageToVideo(ctx context.Context, params aiRequestParams, req worker.GenImageToVideoMultipartRequestBody) (*worker.ImageResponse, error) { resp, err := processAIRequest(ctx, params, req) if err != nil { return nil, err @@ -373,7 +373,7 @@ func processImageToVideo(ctx context.Context, params aiRequestParams, req worker return imgResp, nil } -func submitImageToVideo(ctx context.Context, params aiRequestParams, sess *AISession, req worker.ImageToVideoMultipartRequestBody) (*worker.ImageResponse, error) { +func submitImageToVideo(ctx context.Context, params aiRequestParams, sess *AISession, req worker.GenImageToVideoMultipartRequestBody) (*worker.ImageResponse, error) { var buf bytes.Buffer mw, err := worker.NewImageToVideoMultipartWriter(&buf, req) if err != nil { @@ -463,7 +463,7 @@ func submitImageToVideo(ctx context.Context, params aiRequestParams, sess *AISes } // CalculateUpscaleLatencyScore computes the time taken per pixel for an upscale request. -func CalculateUpscaleLatencyScore(took time.Duration, req worker.UpscaleMultipartRequestBody, outPixels int64) float64 { +func CalculateUpscaleLatencyScore(took time.Duration, req worker.GenUpscaleMultipartRequestBody, outPixels int64) float64 { if outPixels <= 0 { return 0 } @@ -478,7 +478,7 @@ func CalculateUpscaleLatencyScore(took time.Duration, req worker.UpscaleMultipar return took.Seconds() / float64(outPixels) / numInferenceSteps } -func processUpscale(ctx context.Context, params aiRequestParams, req worker.UpscaleMultipartRequestBody) (*worker.ImageResponse, error) { +func processUpscale(ctx context.Context, params aiRequestParams, req worker.GenUpscaleMultipartRequestBody) (*worker.ImageResponse, error) { resp, err := processAIRequest(ctx, params, req) if err != nil { return nil, err @@ -509,7 +509,7 @@ func processUpscale(ctx context.Context, params aiRequestParams, req worker.Upsc return imgResp, nil } -func submitUpscale(ctx context.Context, params aiRequestParams, sess *AISession, req worker.UpscaleMultipartRequestBody) (*worker.ImageResponse, error) { +func submitUpscale(ctx context.Context, params aiRequestParams, sess *AISession, req worker.GenUpscaleMultipartRequestBody) (*worker.ImageResponse, error) { var buf bytes.Buffer mw, err := worker.NewUpscaleMultipartWriter(&buf, req) if err != nil { @@ -553,7 +553,7 @@ func submitUpscale(ctx context.Context, params aiRequestParams, sess *AISession, defer completeBalanceUpdate(sess.BroadcastSession, balUpdate) start := time.Now() - resp, err := client.UpscaleWithBodyWithResponse(ctx, mw.FormDataContentType(), &buf, setHeaders) + resp, err := client.GenUpscaleWithBodyWithResponse(ctx, mw.FormDataContentType(), &buf, setHeaders) took := time.Since(start) if err != nil { if monitor.Enabled { @@ -596,7 +596,7 @@ func CalculateSegmentAnything2LatencyScore(took time.Duration, outPixels int64) return took.Seconds() / float64(outPixels) } -func processSegmentAnything2(ctx context.Context, params aiRequestParams, req worker.SegmentAnything2MultipartRequestBody) (*worker.MasksResponse, error) { +func processSegmentAnything2(ctx context.Context, params aiRequestParams, req worker.GenSegmentAnything2MultipartRequestBody) (*worker.MasksResponse, error) { resp, err := processAIRequest(ctx, params, req) if err != nil { return nil, err @@ -651,7 +651,7 @@ func submitSegmentAnything2(ctx context.Context, params aiRequestParams, sess *A defer completeBalanceUpdate(sess.BroadcastSession, balUpdate) start := time.Now() - resp, err := client.SegmentAnything2WithBodyWithResponse(ctx, mw.FormDataContentType(), &buf, setHeaders) + resp, err := client.GenSegmentAnything2WithBodyWithResponse(ctx, mw.FormDataContentType(), &buf, setHeaders) took := time.Since(start) if err != nil { if monitor.Enabled { @@ -694,7 +694,7 @@ func CalculateAudioToTextLatencyScore(took time.Duration, durationSeconds int64) return took.Seconds() / float64(durationSeconds) } -func processAudioToText(ctx context.Context, params aiRequestParams, req worker.AudioToTextMultipartRequestBody) (*worker.TextResponse, error) { +func processAudioToText(ctx context.Context, params aiRequestParams, req worker.GenAudioToTextMultipartRequestBody) (*worker.TextResponse, error) { resp, err := processAIRequest(ctx, params, req) if err != nil { return nil, err @@ -705,7 +705,7 @@ func processAudioToText(ctx context.Context, params aiRequestParams, req worker. return txtResp, nil } -func submitAudioToText(ctx context.Context, params aiRequestParams, sess *AISession, req worker.AudioToTextMultipartRequestBody) (*worker.TextResponse, error) { +func submitAudioToText(ctx context.Context, params aiRequestParams, sess *AISession, req worker.GenAudioToTextMultipartRequestBody) (*worker.TextResponse, error) { var buf bytes.Buffer mw, err := worker.NewAudioToTextMultipartWriter(&buf, req) if err != nil { @@ -798,7 +798,7 @@ func processAIRequest(ctx context.Context, params aiRequestParams, req interface var submitFn func(context.Context, aiRequestParams, *AISession) (interface{}, error) switch v := req.(type) { - case worker.TextToImageJSONRequestBody: + case worker.GenTextToImageJSONRequestBody: cap = core.Capability_TextToImage modelID = defaultTextToImageModelID if v.ModelId != nil { @@ -807,7 +807,7 @@ func processAIRequest(ctx context.Context, params aiRequestParams, req interface submitFn = func(ctx context.Context, params aiRequestParams, sess *AISession) (interface{}, error) { return submitTextToImage(ctx, params, sess, v) } - case worker.ImageToImageMultipartRequestBody: + case worker.GenImageToImageMultipartRequestBody: cap = core.Capability_ImageToImage modelID = defaultImageToImageModelID if v.ModelId != nil { @@ -816,7 +816,7 @@ func processAIRequest(ctx context.Context, params aiRequestParams, req interface submitFn = func(ctx context.Context, params aiRequestParams, sess *AISession) (interface{}, error) { return submitImageToImage(ctx, params, sess, v) } - case worker.ImageToVideoMultipartRequestBody: + case worker.GenImageToVideoMultipartRequestBody: cap = core.Capability_ImageToVideo modelID = defaultImageToVideoModelID if v.ModelId != nil { @@ -825,7 +825,7 @@ func processAIRequest(ctx context.Context, params aiRequestParams, req interface submitFn = func(ctx context.Context, params aiRequestParams, sess *AISession) (interface{}, error) { return submitImageToVideo(ctx, params, sess, v) } - case worker.UpscaleMultipartRequestBody: + case worker.GenUpscaleMultipartRequestBody: cap = core.Capability_Upscale modelID = defaultUpscaleModelID if v.ModelId != nil { @@ -834,7 +834,7 @@ func processAIRequest(ctx context.Context, params aiRequestParams, req interface submitFn = func(ctx context.Context, params aiRequestParams, sess *AISession) (interface{}, error) { return submitUpscale(ctx, params, sess, v) } - case worker.AudioToTextMultipartRequestBody: + case worker.GenAudioToTextMultipartRequestBody: cap = core.Capability_AudioToText modelID = defaultAudioToTextModelID if v.ModelId != nil { @@ -843,7 +843,7 @@ func processAIRequest(ctx context.Context, params aiRequestParams, req interface submitFn = func(ctx context.Context, params aiRequestParams, sess *AISession) (interface{}, error) { return submitAudioToText(ctx, params, sess, v) } - case worker.SegmentAnything2MultipartRequestBody: + case worker.GenSegmentAnything2MultipartRequestBody: cap = core.Capability_SegmentAnything2 modelID = defaultSegmentAnything2ModelID if v.ModelId != nil { diff --git a/server/rpc.go b/server/rpc.go index db797b9eba..6c1365ccd6 100644 --- a/server/rpc.go +++ b/server/rpc.go @@ -63,12 +63,12 @@ type Orchestrator interface { DebitFees(addr ethcommon.Address, manifestID core.ManifestID, price *net.PriceInfo, pixels int64) Capabilities() *net.Capabilities AuthToken(sessionID string, expiration int64) *net.AuthToken - TextToImage(ctx context.Context, req worker.TextToImageJSONRequestBody) (*worker.ImageResponse, error) - ImageToImage(ctx context.Context, req worker.ImageToImageMultipartRequestBody) (*worker.ImageResponse, error) - ImageToVideo(ctx context.Context, req worker.ImageToVideoMultipartRequestBody) (*worker.ImageResponse, error) - Upscale(ctx context.Context, req worker.UpscaleMultipartRequestBody) (*worker.ImageResponse, error) - AudioToText(ctx context.Context, req worker.AudioToTextMultipartRequestBody) (*worker.TextResponse, error) - SegmentAnything2(ctx context.Context, req worker.SegmentAnything2MultipartRequestBody) (*worker.MasksResponse, error) + TextToImage(ctx context.Context, req worker.GenTextToImageJSONRequestBody) (*worker.ImageResponse, error) + ImageToImage(ctx context.Context, req worker.GenImageToImageMultipartRequestBody) (*worker.ImageResponse, error) + ImageToVideo(ctx context.Context, req worker.GenImageToVideoMultipartRequestBody) (*worker.ImageResponse, error) + Upscale(ctx context.Context, req worker.GenUpscaleMultipartRequestBody) (*worker.ImageResponse, error) + AudioToText(ctx context.Context, req worker.GenAudioToTextMultipartRequestBody) (*worker.TextResponse, error) + SegmentAnything2(ctx context.Context, req worker.GenSegmentAnything2MultipartRequestBody) (*worker.MasksResponse, error) } // Balance describes methods for a session's balance maintenance diff --git a/server/rpc_test.go b/server/rpc_test.go index 197e2baa6b..6f710a319d 100644 --- a/server/rpc_test.go +++ b/server/rpc_test.go @@ -187,22 +187,22 @@ func (r *stubOrchestrator) TranscoderSecret() string { func (r *stubOrchestrator) PriceInfoForCaps(sender ethcommon.Address, manifestID core.ManifestID, caps *net.Capabilities) (*net.PriceInfo, error) { return &net.PriceInfo{PricePerUnit: 4, PixelsPerUnit: 1}, nil } -func (r *stubOrchestrator) TextToImage(ctx context.Context, req worker.TextToImageJSONRequestBody) (*worker.ImageResponse, error) { +func (r *stubOrchestrator) TextToImage(ctx context.Context, req worker.GenTextToImageJSONRequestBody) (*worker.ImageResponse, error) { return nil, nil } -func (r *stubOrchestrator) ImageToImage(ctx context.Context, req worker.ImageToImageMultipartRequestBody) (*worker.ImageResponse, error) { +func (r *stubOrchestrator) ImageToImage(ctx context.Context, req worker.GenImageToImageMultipartRequestBody) (*worker.ImageResponse, error) { return nil, nil } -func (r *stubOrchestrator) ImageToVideo(ctx context.Context, req worker.ImageToVideoMultipartRequestBody) (*worker.ImageResponse, error) { +func (r *stubOrchestrator) ImageToVideo(ctx context.Context, req worker.GenImageToVideoMultipartRequestBody) (*worker.ImageResponse, error) { return nil, nil } -func (r *stubOrchestrator) Upscale(ctx context.Context, req worker.UpscaleMultipartRequestBody) (*worker.ImageResponse, error) { +func (r *stubOrchestrator) Upscale(ctx context.Context, req worker.GenUpscaleMultipartRequestBody) (*worker.ImageResponse, error) { return nil, nil } -func (r *stubOrchestrator) AudioToText(ctx context.Context, req worker.AudioToTextMultipartRequestBody) (*worker.TextResponse, error) { +func (r *stubOrchestrator) AudioToText(ctx context.Context, req worker.GenAudioToTextMultipartRequestBody) (*worker.TextResponse, error) { return nil, nil } -func (r *stubOrchestrator) SegmentAnything2(ctx context.Context, req worker.SegmentAnything2MultipartRequestBody) (*worker.MasksResponse, error) { +func (r *stubOrchestrator) SegmentAnything2(ctx context.Context, req worker.GenSegmentAnything2MultipartRequestBody) (*worker.MasksResponse, error) { return nil, nil } func (r *stubOrchestrator) CheckAICapacity(pipeline, modelID string) bool { @@ -1373,22 +1373,22 @@ func (o *mockOrchestrator) AuthToken(sessionID string, expiration int64) *net.Au func (r *mockOrchestrator) PriceInfoForCaps(sender ethcommon.Address, manifestID core.ManifestID, caps *net.Capabilities) (*net.PriceInfo, error) { return &net.PriceInfo{PricePerUnit: 4, PixelsPerUnit: 1}, nil } -func (r *mockOrchestrator) TextToImage(ctx context.Context, req worker.TextToImageJSONRequestBody) (*worker.ImageResponse, error) { +func (r *mockOrchestrator) TextToImage(ctx context.Context, req worker.GenTextToImageJSONRequestBody) (*worker.ImageResponse, error) { return nil, nil } -func (r *mockOrchestrator) ImageToImage(ctx context.Context, req worker.ImageToImageMultipartRequestBody) (*worker.ImageResponse, error) { +func (r *mockOrchestrator) ImageToImage(ctx context.Context, req worker.GenImageToImageMultipartRequestBody) (*worker.ImageResponse, error) { return nil, nil } -func (r *mockOrchestrator) ImageToVideo(ctx context.Context, req worker.ImageToVideoMultipartRequestBody) (*worker.ImageResponse, error) { +func (r *mockOrchestrator) ImageToVideo(ctx context.Context, req worker.GenImageToVideoMultipartRequestBody) (*worker.ImageResponse, error) { return nil, nil } -func (r *mockOrchestrator) Upscale(ctx context.Context, req worker.UpscaleMultipartRequestBody) (*worker.ImageResponse, error) { +func (r *mockOrchestrator) Upscale(ctx context.Context, req worker.GenUpscaleMultipartRequestBody) (*worker.ImageResponse, error) { return nil, nil } -func (r *mockOrchestrator) AudioToText(ctx context.Context, req worker.AudioToTextMultipartRequestBody) (*worker.TextResponse, error) { +func (r *mockOrchestrator) AudioToText(ctx context.Context, req worker.GenAudioToTextMultipartRequestBody) (*worker.TextResponse, error) { return nil, nil } -func (r *mockOrchestrator) SegmentAnything2(ctx context.Context, req worker.SegmentAnything2MultipartRequestBody) (*worker.MasksResponse, error) { +func (r *mockOrchestrator) SegmentAnything2(ctx context.Context, req worker.GenSegmentAnything2MultipartRequestBody) (*worker.MasksResponse, error) { return nil, nil } func (r *mockOrchestrator) CheckAICapacity(pipeline, modelID string) bool {