Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Minor updates for timestamp accuracy. #42

Merged
merged 4 commits into from
Jun 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ name: CI

on:
push:
branches: [ "master" ]
branches: [ "master", "dev" ]
pull_request:
branches: [ "master" ]

Expand Down
15 changes: 8 additions & 7 deletions openlrc/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,23 @@
# All rights reserved.
import abc
import re
from typing import Optional, Tuple, List
from typing import Optional, Tuple, List, Type, Union

from openlrc.chatbot import route_chatbot
from openlrc.chatbot import route_chatbot, GPTBot, ClaudeBot
from openlrc.context import TranslationContext, TranslateInfo
from openlrc.logger import logger
from openlrc.prompter import BaseTranslatePrompter, ContextReviewPrompter, POTENTIAL_PREFIX_COMBOS, \
ProofreaderPrompter, PROOFREAD_PREFIX


class Agent(abc.ABC):
TEMPERATURE = 0.5
TEMPERATURE = 1
"""
Base class for all agents.
"""

def _initialize_chatbot(self, chatbot_model: str, fee_limit: float, proxy: str, base_url_config: Optional[dict]):
chatbot_cls: Union[Type[ClaudeBot], Type[GPTBot]]
chatbot_cls, model_name = route_chatbot(chatbot_model)
return chatbot_cls(model=model_name, fee_limit=fee_limit, proxy=proxy, retry=3,
temperature=self.TEMPERATURE, base_url_config=base_url_config)
Expand All @@ -28,10 +29,10 @@ class ChunkedTranslatorAgent(Agent):
Translate the well-defined chunked text to the target language and send it to the chatbot for further processing.
"""

TEMPERATURE = 0.9
TEMPERATURE = 1.0

def __init__(self, src_lang, target_lang, info: TranslateInfo = TranslateInfo(),
chatbot_model: str = 'gpt-3.5-turbo', fee_limit: float = 0.2, proxy: str = None,
chatbot_model: str = 'gpt-3.5-turbo', fee_limit: float = 0.25, proxy: str = None,
base_url_config: Optional[dict] = None):
super().__init__()
self.chatbot_model = chatbot_model
Expand Down Expand Up @@ -108,7 +109,7 @@ class ContextReviewerAgent(Agent):
TEMPERATURE = 0.8

def __init__(self, src_lang, target_lang, info: TranslateInfo = TranslateInfo(),
chatbot_model: str = 'gpt-3.5-turbo', fee_limit: float = 0.2, proxy: str = None,
chatbot_model: str = 'gpt-3.5-turbo', fee_limit: float = 0.25, proxy: str = None,
base_url_config: Optional[dict] = None):
super().__init__()
self.src_lang = src_lang
Expand Down Expand Up @@ -139,7 +140,7 @@ class ProofreaderAgent(Agent):
TEMPERATURE = 0.8

def __init__(self, src_lang, target_lang, info: TranslateInfo = TranslateInfo(),
chatbot_model: str = 'gpt-3.5-turbo', fee_limit: float = 0.2, proxy: str = None,
chatbot_model: str = 'gpt-3.5-turbo', fee_limit: float = 0.25, proxy: str = None,
base_url_config: Optional[dict] = None):
super().__init__()
self.src_lang = src_lang
Expand Down
18 changes: 8 additions & 10 deletions openlrc/chatbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def route_chatbot(model):
class ChatBot:
pricing = None

def __init__(self, pricing, temperature=1, top_p=1, retry=8, max_async=16, fee_limit=0.2):
def __init__(self, pricing, temperature=1, top_p=1, retry=8, max_async=16, fee_limit=0.25):
self.pricing = pricing
self._model = None

Expand Down Expand Up @@ -172,6 +172,9 @@ class GPTBot(ChatBot):
def __init__(self, model='gpt-3.5-turbo-0125', temperature=1, top_p=1, retry=8, max_async=16, json_mode=False,
fee_limit=0.05, proxy=None, base_url_config=None):

# clamp temperature to [0, 2], the valid range for the OpenAI API
temperature = max(0, min(2, temperature))

super().__init__(self.pricing, temperature, top_p, retry, max_async, fee_limit)

self.async_client = AsyncGPTClient(
Expand All @@ -181,12 +184,7 @@ def __init__(self, model='gpt-3.5-turbo-0125', temperature=1, top_p=1, retry=8,
)

self.model = model
self.temperature = temperature
self.top_p = top_p
self.retry = retry
self.max_async = max_async
self.json_mode = json_mode
self.fee_limit = fee_limit

def __exit__(self, exc_type, exc_val, exc_tb):
self.async_client.close()
Expand Down Expand Up @@ -252,9 +250,12 @@ class ClaudeBot(ChatBot):
'claude-3-haiku-20240307': (0.25, 1.25)
}

def __init__(self, model='claude-3-sonnet-20240229', temperature=1, top_p=1, retry=8, max_async=16, fee_limit=0.2,
def __init__(self, model='claude-3-sonnet-20240229', temperature=1, top_p=1, retry=8, max_async=16, fee_limit=0.25,
proxy=None, base_url_config=None):

# clamp temperature to [0, 1], the valid range for the Anthropic API
temperature = max(0, min(1, temperature))

super().__init__(self.pricing, temperature, top_p, retry, max_async, fee_limit)

self.async_client = AsyncAnthropic(
Expand All @@ -266,9 +267,6 @@ def __init__(self, model='claude-3-sonnet-20240229', temperature=1, top_p=1, ret
)

self.model = model
self.retry = retry
self.max_async = max_async
self.fee_limit = fee_limit

def update_fee(self, response: Message):
prompt_price, completion_price = all_pricing[self.model]
Expand Down
2 changes: 1 addition & 1 deletion openlrc/openlrc.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class LRCer:
"""

def __init__(self, whisper_model='large-v3', compute_type='float16', chatbot_model: str = 'gpt-3.5-turbo',
fee_limit=0.2, consumer_thread=4, asr_options=None, vad_options=None, preprocess_options=None,
fee_limit=0.25, consumer_thread=4, asr_options=None, vad_options=None, preprocess_options=None,
proxy=None, base_url_config=None, glossary: Union[dict, str, Path] = None, retry_model=None):
self.chatbot_model = chatbot_model
self.fee_limit = fee_limit
Expand Down
42 changes: 37 additions & 5 deletions openlrc/opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from openlrc.logger import logger
from openlrc.subtitle import Subtitle
from openlrc.utils import extend_filename
from openlrc.utils import extend_filename, format_timestamp

# Thresholds for different languages
CUT_LONG_THRESHOLD = {
Expand Down Expand Up @@ -81,18 +81,42 @@ def merge_short(self, threshold=1.2):
merged_element = None
new_elements.append(element)
else:
merged_element = self._merge_elements(merged_element, element)
if not merged_element:
merged_element = element
continue

# Merge into the previous element if it is the closer neighbor and the gap is at most 3s
previous_gap = merged_element.start - new_elements[-1].start
next_gap = element.start - merged_element.end
if previous_gap <= next_gap and previous_gap <= 3:
previous_element = new_elements.pop()
merged_element.text = previous_element.text + merged_element.text
merged_element.start = previous_element.start
new_elements.append(merged_element)
merged_element = element
elif next_gap <= previous_gap and next_gap <= 3:
merged_element.text += element.text
merged_element.end = element.end
new_elements.append(merged_element)
merged_element = None
else:
new_elements.append(merged_element)
merged_element = element

self.subtitle.segments = new_elements

def _finalize_merge(self, new_elements, merged_element, element):
if merged_element.duration < 1.5:
if element.start - merged_element.end < merged_element.start - new_elements[-1].end:
previous_gap = merged_element.start - new_elements[-1].end
next_gap = element.start - merged_element.end
if previous_gap <= next_gap and previous_gap <= 3:
new_elements[-1].text += merged_element.text
new_elements[-1].end = merged_element.end
elif next_gap <= previous_gap and next_gap <= 3:
element.text = merged_element.text + element.text
element.start = merged_element.start
else:
new_elements[-1].text += merged_element.text
new_elements[-1].end = merged_element.end
new_elements.append(merged_element)
else:
new_elements.append(merged_element)

Expand Down Expand Up @@ -204,10 +228,18 @@ def perform_all(self, steps: Optional[List[str]] = None, extend_time=False):
if extend_time:
self.extend_time()

# Final check to warn users about any remaining issues
self.check()

def save(self, output_name: Optional[str] = None, update_name=False):
"""
Save the optimized subtitle to a file.
"""
optimized_name = extend_filename(self.filename, '_optimized') if not output_name else output_name
self.subtitle.save(optimized_name, update_name=update_name)
logger.info(f'Optimized json file saved to {optimized_name}')

def check(self):
for element in self.subtitle.segments:
if element.duration >= 10:
logger.warning(f'Duration of text "{element.text}" at {format_timestamp(element.start)} exceeds 10')
45 changes: 33 additions & 12 deletions openlrc/transcribe.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,9 +113,19 @@ def is_punct(char):
latter_words = seg_entry.words[len(former_words):]

if not latter_words:
# Directly split using the hard-mid
former_words = seg_entry.words[:len(seg_entry.words) // 2]
latter_words = seg_entry.words[len(seg_entry.words) // 2:]
# Directly split using the largest word-word gap
gaps = [-1]
for k in range(len(seg_entry.words) - 1):
gaps.append(seg_entry.words[k + 1].start - seg_entry.words[k].end)
max_gap = max(gaps)
split_idx = gaps.index(max_gap) # TODO: Multiple largest or Multiple long gap

if max_gap >= 2: # Split using the max gap
former_words = seg_entry.words[:split_idx]
latter_words = seg_entry.words[split_idx:]
else: # Split using hard-mid
former_words = seg_entry.words[:len(seg_entry.words) // 2]
latter_words = seg_entry.words[len(seg_entry.words) // 2:]

former = seg_from_words(seg_entry, seg_entry.id, former_words, seg_entry.tokens[:len(former_words)])
latter = seg_from_words(seg_entry, seg_entry.id + 1, latter_words, seg_entry.tokens[len(former_words):])
Expand Down Expand Up @@ -162,16 +172,27 @@ def is_punct(char):
entry = seg_from_words(segment, id_cnt, split_words,
segment.tokens[word_start: word_start + len(split_words)])

# Check if the sentence is too long in words
if len(split) < (45 if lang in self.continuous_scripted else 90) or len(entry.words) == 1:
# split if duration > 10s
if entry.end - entry.start > 10:
segmented_entries = mid_split(entry)
def recursive_segment(entry):
if len(entry.text) < (45 if lang in self.continuous_scripted else 90) or len(entry.words) == 1:
if entry.end - entry.start > 10:
# split if duration > 10s
segmented_entries = mid_split(entry)
further_segmented = []
for segment in segmented_entries:
further_segmented.extend(recursive_segment(segment))
else:
return [entry]
else:
segmented_entries = [entry]
else:
# Split them in the middle
segmented_entries = mid_split(entry)
# Split them in the middle
segmented_entries = mid_split(entry)
further_segmented = []
for segment in segmented_entries:
further_segmented.extend(recursive_segment(segment))

return further_segmented

# Check if the sentence is too long in words
segmented_entries = recursive_segment(entry)

sentences.extend(segmented_entries)
id_cnt += len(segmented_entries)
Expand Down
2 changes: 1 addition & 1 deletion openlrc/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def translate(self, texts: Union[str, List[str]], src_lang: str, target_lang: st
class LLMTranslator(Translator):
CHUNK_SIZE = 30

def __init__(self, chatbot_model: str = 'gpt-3.5-turbo', fee_limit: float = 0.2, chunk_size: int = CHUNK_SIZE,
def __init__(self, chatbot_model: str = 'gpt-3.5-turbo', fee_limit: float = 0.25, chunk_size: int = CHUNK_SIZE,
intercept_line: Optional[int] = None, proxy: Optional[str] = None,
base_url_config: Optional[dict] = None,
retry_model: Optional[str] = None):
Expand Down
13 changes: 13 additions & 0 deletions tests/test_chatbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,3 +145,16 @@ def test_route_chatbot_error(self):
chatbot_model = 'openai: invalid_model_name'
with self.assertRaises(ValueError):
route_chatbot(chatbot_model + 'error')

def test_temperature_clamp(self):
chatbot1 = GPTBot(temperature=10, top_p=1, retry=8, max_async=16)
chatbot2 = GPTBot(temperature=-1, top_p=1, retry=8, max_async=16)
chatbot3 = ClaudeBot(temperature=2, top_p=1, retry=8, max_async=16)
chatbot4 = ClaudeBot(temperature=-1, top_p=1, retry=8, max_async=16)

self.assertEqual(chatbot1.temperature, 2)
self.assertEqual(chatbot2.temperature, 0)
self.assertEqual(chatbot3.temperature, 1)
self.assertEqual(chatbot4.temperature, 0)

# TODO: Retry_bot testing
Loading