diff --git a/python/cog/command/ast_openapi_schema.py b/python/cog/command/ast_openapi_schema.py index f482e03bfa..d4b708bcac 100644 --- a/python/cog/command/ast_openapi_schema.py +++ b/python/cog/command/ast_openapi_schema.py @@ -50,6 +50,11 @@ "default": ["start", "output", "logs", "completed"], "items": { "$ref": "#/components/schemas/WebhookEvent" }, "type": "array" + }, + "quiet": { + "title": "Quiet", + "type": "boolean", + "default": false } }, "title": "PredictionRequest", diff --git a/python/cog/schema.py b/python/cog/schema.py index efdad2d0d7..f7bcc5a74e 100644 --- a/python/cog/schema.py +++ b/python/cog/schema.py @@ -41,8 +41,6 @@ def default_events(cls) -> List["WebhookEvent"]: class PredictionBaseModel(pydantic.BaseModel): - input: Dict[str, Any] - if PYDANTIC_V2: model_config = pydantic.ConfigDict(use_enum_values=True) # type: ignore else: @@ -66,6 +64,8 @@ class Config: class PredictionRequest(PredictionBaseModel): + input: Dict[str, Any] + id: Optional[str] = None created_at: Optional[datetime] = None @@ -77,6 +77,8 @@ class PredictionRequest(PredictionBaseModel): default=WebhookEvent.default_events(), ) + quiet: bool = False + @classmethod def with_types(cls, input_type: Type[Any]) -> Any: # [compat] Input is implicitly optional -- previous versions of the @@ -88,6 +90,8 @@ def with_types(cls, input_type: Type[Any]) -> Any: class PredictionResponse(PredictionBaseModel): + input: Optional[Dict[str, Any]] = None + output: Any = None id: Optional[str] = None diff --git a/python/cog/server/runner.py b/python/cog/server/runner.py index 6b5360f49c..b7a61ab8e0 100644 --- a/python/cog/server/runner.py +++ b/python/cog/server/runner.py @@ -259,6 +259,9 @@ def __init__( else: request_dict = prediction_request.dict() + if prediction_request.quiet: + request_dict.pop("input", None) + self._p = schema.PredictionResponse(**request_dict) self._p.status = schema.Status.PROCESSING self._output_type_multi = None diff --git a/python/tests/server/test_runner.py b/python/tests/server/test_runner.py index 619496496d..f3fab709ca 100644 --- a/python/tests/server/test_runner.py +++ b/python/tests/server/test_runner.py @@ -620,3 +620,43 @@ def test_predict_task_file_uploads_multi(): "http://example.com/hello.jpg", "http://example.com/world.jpg", ] + + +def test_predict_quiet(): + p = PredictionRequest( + input={"hello": "there"}, + id=None, + created_at=None, + output_file_prefix=None, + webhook=None, + quiet=False, + ) + t = PredictTask(p) + + assert t.result.status == Status.PROCESSING + assert t.result.output is None + assert t.result.logs == "" + assert isinstance(t.result.started_at, datetime) + t.set_output_type(multi=False) + t.append_output("giraffes") + assert t.result.output == "giraffes" + assert t.result.input == {"hello": "there"} + + p = PredictionRequest( + input={"hello": "there"}, + id=None, + created_at=None, + output_file_prefix=None, + webhook=None, + quiet=True, + ) + t = PredictTask(p) + + assert t.result.status == Status.PROCESSING + assert t.result.output is None + assert t.result.logs == "" + assert isinstance(t.result.started_at, datetime) + t.set_output_type(multi=False) + t.append_output("giraffes") + assert t.result.output == "giraffes" + assert t.result.input == None