Skip to content

Commit

Permalink
Make the initial_utterance customizable.
Browse files Browse the repository at this point in the history
  • Loading branch information
radi-cho committed Mar 22, 2023
1 parent 1b244a6 commit 7b7b8a8
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 6 deletions.
2 changes: 1 addition & 1 deletion src/datasetGPT/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class DatasetGenerator:
options_configs: List[Dict[str, Any]]
"""Possible combinations of the provided options."""
generator_index: int = 0
"""Index of the next item be returned by the generator."""
"""Index of the next item to be returned by the generator."""

def __init__(self, config: DatasetGeneratorConfig) -> None:
self.config = config
Expand Down
8 changes: 8 additions & 0 deletions src/datasetGPT/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,12 @@ def datasetGPT() -> None:
type=str,
required=True,
help="Agent role description.")
@click.option("--initial-utterance",
"-u",
"initial_utterance",
type=str,
default="Hello!",
help="An utterance to be provisioned to the first agent. For many use cases a \"Hello\" is enough.")
@click.option("--interruption",
"-i",
"interruption",
Expand Down Expand Up @@ -101,6 +107,7 @@ def conversations(
openai_api_key: str,
agent1: str,
agent2: str,
initial_utterance: str,
num_samples: int,
interruption: str,
end_phrase: str,
Expand All @@ -117,6 +124,7 @@ def conversations(
generator_config = ConversationsGeneratorConfig(openai_api_key=openai_api_key,
agent1=agent1,
agent2=agent2,
initial_utterance=initial_utterance,
num_samples=num_samples,
interruption=interruption,
end_phrase=end_phrase,
Expand Down
9 changes: 4 additions & 5 deletions src/datasetGPT/conversations.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ class ConversationsGeneratorConfig:
"""Description of the first agent used to construct its system message."""
agent2: str
"""Description of the second agent used to construct its system message."""
initial_utterance: str
"""Utterance to be provisioned to the first agent."""
num_samples: int = 1
"""Number of conversations to generate for each options combination."""
interruption: str = "length"
Expand Down Expand Up @@ -95,10 +97,7 @@ def end_phrase_interruption(self, agent: str, message: str) -> None:
if self.config.end_phrase in message:
raise StopIteration()

def generate_item(
self,
initial_utterance: str = "Hello!"
) -> Dict[str, Union[List[List[Any]], float, int]]:
def generate_item(self) -> Dict[str, Union[List[List[Any]], float, int]]:
"""Run two chains to talk with one another and record the chat history."""
if self.generator_index >= len(self.options_configs):
raise StopIteration()
Expand All @@ -116,7 +115,7 @@ def generate_item(

utterances = []

chain1_inp = initial_utterance
chain1_inp = self.config.initial_utterance
for _ in range(conversation_config["length"]):
chain1_out = chain1.predict(input=chain1_inp)
utterances.append(["agent1", chain1_out])
Expand Down

0 comments on commit 7b7b8a8

Please sign in to comment.