Skip to content

Commit

Permalink
(feat) add return_timestamps as configurable in request (#228)
Browse files Browse the repository at this point in the history
This commit adds the `return_timestamps` input of the whisper model to the `audio-to-text` pipelines. This new input allows users to also get word based timestamps by setting the value to `word`.

Co-authored-by: Rick Staa <[email protected]>
  • Loading branch information
eliteprox and rickstaa authored Oct 25, 2024
1 parent cd75c8a commit 49b460f
Show file tree
Hide file tree
Showing 6 changed files with 192 additions and 62 deletions.
2 changes: 1 addition & 1 deletion runner/app/pipelines/audio_to_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@ def __init__(self, model_id: str):
max_new_tokens=128,
chunk_length_s=30,
batch_size=16,
return_timestamps=True,
**kwargs,
)

Expand All @@ -79,6 +78,7 @@ def __call__(self, audio: UploadFile, **kwargs) -> List[File]:

try:
outputs = self.tm(audio.file.read(), **kwargs)
outputs.setdefault("chunks", [])
except Exception as e:
raise InferenceError(original_exception=e)

Expand Down
31 changes: 30 additions & 1 deletion runner/app/routes/audio_to_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,24 @@
}


def parse_return_timestamps(value: str) -> Union[bool, str]:
"""Convert a string to a boolean or return the string as is. Sentence is considered
True as it is the model default value.
Args:
value: The value to parse.
Returns:
The parsed value.
"""
value_lower = value.lower()
if value_lower in ("true", "1", "sentence"):
return True
if value_lower in ("false", "0"):
return False
return value_lower


@router.post(
"/audio-to-text",
response_model=TextResponse,
Expand All @@ -77,8 +95,19 @@ async def audio_to_text(
Form(description="Hugging Face model ID used for transcription."),
] = "",
pipeline: Pipeline = Depends(get_pipeline),
return_timestamps: Annotated[
Union[str, bool],
Form(
description=(
"Return timestamps for the transcribed text. Supported values: "
"'sentence', 'word', or a boolean. Default is True ('sentence'). "
"False means no timestamps. 'word' means word-based timestamps."
)
),
] = "true",
token: HTTPAuthorizationCredentials = Depends(HTTPBearer(auto_error=False)),
):
return_timestamps = parse_return_timestamps(return_timestamps)
auth_token = os.environ.get("AUTH_TOKEN")
if auth_token:
if not token or token.credentials != auth_token:
Expand All @@ -104,7 +133,7 @@ async def audio_to_text(
)

try:
return pipeline(audio=audio)
return pipeline(audio=audio, return_timestamps=return_timestamps)
except Exception as e:
if isinstance(e, torch.cuda.OutOfMemoryError):
torch.cuda.empty_cache()
Expand Down
9 changes: 9 additions & 0 deletions runner/gateway.openapi.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,15 @@ components:
title: Model Id
description: Hugging Face model ID used for transcription.
default: ''
return_timestamps:
anyOf:
- type: string
- type: boolean
title: Return Timestamps
description: 'Return timestamps for the transcribed text. Supported values:
''sentence'', ''word'', or a boolean. Default is True (''sentence'').
False means no timestamps. ''word'' means word-based timestamps.'
default: 'true'
type: object
required:
- audio
Expand Down
9 changes: 9 additions & 0 deletions runner/openapi.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -447,6 +447,15 @@ components:
title: Model Id
description: Hugging Face model ID used for transcription.
default: ''
return_timestamps:
anyOf:
- type: string
- type: boolean
title: Return Timestamps
description: 'Return timestamps for the transcribed text. Supported values:
''sentence'', ''word'', or a boolean. Default is True (''sentence'').
False means no timestamps. ''word'' means word-based timestamps.'
default: 'true'
type: object
required:
- audio
Expand Down
6 changes: 6 additions & 0 deletions worker/multipart.go
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,12 @@ func NewAudioToTextMultipartWriter(w io.Writer, req GenAudioToTextMultipartReque
}
}

if req.ReturnTimestamps != nil {
if err := mw.WriteField("return_timestamps", *req.ReturnTimestamps); err != nil {
return nil, err
}
}

if err := mw.Close(); err != nil {
return nil, err
}
Expand Down
197 changes: 137 additions & 60 deletions worker/runner.gen.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 49b460f

Please sign in to comment.