Skip to content

Commit

Permalink
Add input video support.
Browse files Browse the repository at this point in the history
  • Loading branch information
zh-plus committed Jun 25, 2023
1 parent 77edaba commit 727b375
Show file tree
Hide file tree
Showing 9 changed files with 123 additions and 6 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,4 @@ jobs:
poetry run pip install git+https://github.com/m-bain/whisperx.git
- name: Test with pytest
run: |
poetry run pytest tests --disable-warnings
poetry run pytest tests --disable-warnings --rootdir=tests
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ lrcer.run('./data/test.mp3', target_lang='zh-cn') # Generate translated ./data/
# Multiple files
lrcer.run(['./data/test1.mp3', './data/test2.mp3'], target_lang='zh-cn')
# Note we run the transcription sequentially, but run the translation concurrently for each file.
# Path can contain video
lrcer.run(['./data/test_audio.mp3', './data/test_video.mp4'], target_lang='zh-cn')
```

## Todo
Expand Down
8 changes: 8 additions & 0 deletions openlrc/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# Copyright (C) 2023. Hao Zheng
# All rights reserved.

class SameLanguageException(Exception):
"""
Raised when the source language and target language are the same.
Expand All @@ -17,3 +20,8 @@ class ChatBotException(Exception):

def __init__(self, message):
super().__init__(message)


class FfmpegException(Exception):
def __init__(self, message):
super().__init__(message)
27 changes: 23 additions & 4 deletions openlrc/openlrc.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
from openlrc.subtitle import Subtitle
from openlrc.transcribe import Transcriber
from openlrc.translate import GPTTranslator
from openlrc.utils import Timer, change_ext, extend_filename, get_audio_duration, format_timestamp
from openlrc.utils import Timer, change_ext, extend_filename, get_audio_duration, format_timestamp, extract_audio, \
get_file_type, get_filename


class LRCer:
Expand All @@ -34,6 +35,7 @@ def __init__(self, model_name='large-v2', compute_type='float16', fee_limit=0.1,
self.transcriber = Transcriber(model_name=model_name, compute_type=compute_type)
self.fee_limit = fee_limit
self.api_fee = 0 # Can be updated in different thread, operation should be thread-safe
self.from_video = set()

self._lock = Lock()
self.exception = None
Expand Down Expand Up @@ -130,17 +132,21 @@ def translation_worker(self, transcription_queue, target_lang, prompter, audio_t
update_name=True) # xxx.json

final_subtitle.to_lrc()
if get_filename(output_filename) in self.from_video:
final_subtitle.to_srt()
logger.info(f'Translation fee til now: {self.api_fee:.4f} USD')

def run(self, audio_paths, target_lang='zh-cn', prompter='base_trans', audio_type='Anime'):
def run(self, paths, target_lang='zh-cn', prompter='base_trans', audio_type='Anime'):
"""
Split the translation into 2 phases: transcription and translation. They're running in parallel.
Firstly, transcribe the audios one-by-one. At the same time, translation threads are created and waiting for
the transcription results. After all the transcriptions are done, the translation threads will start to
translate the transcribed texts.
"""
if isinstance(audio_paths, str):
audio_paths = [audio_paths]
if isinstance(paths, str):
paths = [paths]

audio_paths = self.pre_process(paths)

logger.info(f'Working on {len(audio_paths)} audio files: {pformat(audio_paths)}')

Expand Down Expand Up @@ -185,6 +191,19 @@ def to_json(segments, name, lang):

return result

def pre_process(self, paths):
# Check if path is audio or video
for i, path in enumerate(paths):
if not os.path.exists(path) or not os.path.isfile(path):
raise FileNotFoundError(f'File not found: {path}')

paths[i] = extract_audio(path)

if get_file_type(path) == 'video':
self.from_video.add(get_filename(path))

return paths

@staticmethod
def post_process(transcribed_sub, output_name=None, remove_files=None, t2m=False, update_name=False):
optimizer = SubtitleOptimizer(transcribed_sub)
Expand Down
45 changes: 45 additions & 0 deletions openlrc/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,53 @@
from typing import List, Dict, Any

import audioread
import ffmpeg
import tiktoken
import torch

from openlrc.exceptions import FfmpegException
from openlrc.logger import logger


def extract_audio(path: str) -> str:
"""
Extract audio from video.
:return: Audio path
"""
file_type = get_file_type(path)
if file_type == 'audio':
return path

probe = ffmpeg.probe(path)
audio_streams = next((stream for stream in probe['streams'] if stream['codec_type'] == 'audio'), None)
sample_rate = audio_streams['sample_rate']
logger.info(f'File {path}: Audio sample rate: {sample_rate}')

audio, err = (
ffmpeg.input(path).
output("pipe:", format='wav', acodec='pcm_s16le', ar=sample_rate, loglevel='quiet').
run(capture_stdout=True)
)

if err:
raise RuntimeError(f'ffmpeg error: {err}')

audio_path = change_ext(path, 'wav')
with open(audio_path, 'wb') as f:
f.write(audio)

return audio_path


def get_file_type(path: str) -> str:
try:
video_stream = ffmpeg.probe(path, select_streams='v')['streams']
except Exception as e:
raise FfmpegException(f'ffmpeg error: {e}')

return ['audio', 'video'][len(video_stream) > 0]


def get_audio_duration(path: str) -> float:
return audioread.audio_open(path).duration

Expand All @@ -36,6 +77,10 @@ def get_messages_token_number(messages: List[Dict[str, Any]], model: str = "gpt-
return total


def get_filename(path: str) -> str:
return splitext(path)[0]


def change_ext(filename: str, ext: str) -> str:
"""Change the extension of a filename."""
return f'{splitext(filename)[0]}.{ext}'
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ zhconv = "^1.4.3"
punctuators = "^0.0.5"
colorlog = "^6.7.0"
pytest = "^7.4.0"
ffmpeg-python = "^0.2.0"


[[tool.poetry.source]]
Expand Down
Binary file added tests/data/test_video.mp4
Binary file not shown.
2 changes: 2 additions & 0 deletions tests/test_opt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# Copyright (C) 2023. Hao Zheng
# All rights reserved.
41 changes: 40 additions & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,47 @@
import pytest
import torch

from openlrc.exceptions import FfmpegException
from openlrc.utils import format_timestamp, parse_timestamp, get_text_token_number, get_messages_token_number, \
change_ext, extend_filename, release_memory
change_ext, extend_filename, release_memory, extract_audio, get_file_type


@pytest.fixture
def video_file():
return 'data/test_video.mp4'


@pytest.fixture
def audio_file():
return 'data/test_video.wav'


def test_extract_audio(video_file, audio_file):
# Test extracting audio from a video file
extracted_audio_file = extract_audio(video_file)
assert extracted_audio_file == audio_file

# Test extracting audio from an audio file
extracted_audio_file = extract_audio(audio_file)
assert extracted_audio_file == audio_file

# Test extracting audio from an unsupported file type
with pytest.raises(FfmpegException):
extract_audio('unsupported_file.xyz')


def test_get_file_type(video_file, audio_file):
# Test getting the file type of video file
file_type = get_file_type(video_file)
assert file_type == 'video'

# Test getting the file type of audio file
file_type = get_file_type(audio_file)
assert file_type == 'audio'

# Test getting the file type of unsupported file type
with pytest.raises(FfmpegException):
get_file_type('unsupported_file.xyz')


def test_lrc_format():
Expand Down

0 comments on commit 727b375

Please sign in to comment.