Skip to content

Commit

Permalink
Format with Black (facebookresearch#69)
Browse files Browse the repository at this point in the history
  • Loading branch information
Mubaraq Sani authored and Mubaraq Sani committed Aug 28, 2023
1 parent 1c07b35 commit 0627451
Show file tree
Hide file tree
Showing 6 changed files with 15 additions and 21 deletions.
10 changes: 2 additions & 8 deletions examples/speech_to_text/counter_in_tgt_lang_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,11 @@ class CounterInTargetLanguage(SpeechToTextAgent):
def __init__(self, args):
super().__init__(args)
self.wait_seconds = args.wait_seconds
# if args is not None:
# with open(args.tgt_lang, "r") as file:
# tgt_lang = file.read()
# self.tgt_lang = tgt_lang

@staticmethod
def add_args(parser):
parser.add_argument("--wait-seconds", default=1, type=int)
parser.add_argument(
"--tgt-lang"
)
parser.add_argument("--tgt-lang")

def policy(self, states: Optional[AgentStates] = None):
if states is None:
Expand All @@ -51,4 +45,4 @@ def policy(self, states: Optional[AgentStates] = None):
return WriteAction(
content=prediction,
finished=states.source_finished,
)
)
12 changes: 7 additions & 5 deletions simuleval/data/dataloader/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@ class GenericDataloader:
"""

def __init__(
self, source_list: List[str], target_list: Union[List[str], List[None]],
self,
source_list: List[str],
target_list: Union[List[str], List[None]],
) -> None:
self.source_list = source_list
self.target_list = target_list
Expand All @@ -53,9 +55,10 @@ def get_target(self, index: int) -> Any:
return self.preprocess_target(self.target_list[index])

def __getitem__(self, index: int) -> Dict[str, Any]:
return {"source": self.get_source(index),
"target": self.get_target(index),
}
return {
"source": self.get_source(index),
"target": self.get_target(index),
}

def preprocess_source(self, source: Any) -> Any:
raise NotImplementedError
Expand Down Expand Up @@ -97,4 +100,3 @@ def add_args(parser: ArgumentParser):
default=1,
help="Source segment size, For text the unit is # token, for speech is ms",
)

6 changes: 2 additions & 4 deletions simuleval/data/dataloader/s2t_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,7 @@ def get_source_audio_path(self, index: int):

@classmethod
def from_files(
cls, source: Union[Path, str], target: Union[Path, str],
tgt_lang: str
cls, source: Union[Path, str], target: Union[Path, str], tgt_lang: str
) -> SpeechToTextDataloader:
with open(source) as f:
source_list = [line.strip() for line in f]
Expand All @@ -104,8 +103,7 @@ def from_args(cls, args: Namespace):
class SpeechToSpeechDataloader(SpeechToTextDataloader):
@classmethod
def from_files(
cls, source: Union[Path, str], target: Union[Path, str],
tgt_lang: str
cls, source: Union[Path, str], target: Union[Path, str], tgt_lang: str
) -> SpeechToSpeechDataloader:
with open(source) as f:
source_list = [line.strip() for line in f]
Expand Down
1 change: 1 addition & 0 deletions simuleval/data/dataloader/t2t_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

tgt_lang = "en"


@register_dataloader("text-to-text")
class TextToTextDataloader(GenericDataloader):
def __init__(
Expand Down
2 changes: 1 addition & 1 deletion simuleval/data/segments.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,4 +52,4 @@ def segment_from_json_string(string: str):
elif info_dict["data_type"] == "speech":
return SpeechSegment.from_json(string)
else:
return EmptySegment.from_json(string)
return EmptySegment.from_json(string)
5 changes: 2 additions & 3 deletions simuleval/evaluator/instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ def __init__(
self,
index: int,
dataloader: Optional[SpeechToTextDataloader],
args: Optional[Namespace],
args: Optional[Namespace],
):
super().__init__(index, dataloader, args)
self.args = args
Expand Down Expand Up @@ -284,7 +284,7 @@ def send_source(self, segment_size=10):
content=samples,
sample_rate=self.audio_info.samplerate,
finished=is_finished,
tgt_lang=self.tgt_lang
tgt_lang=self.tgt_lang,
)

else:
Expand Down Expand Up @@ -457,4 +457,3 @@ def __init__(self, info: str) -> None:
self.source_length = self.info.get("source_length") # just for testing!
self.finish_prediction = True
self.metrics = {}

0 comments on commit 0627451

Please sign in to comment.