Skip to content

Commit

Permalink
Fix mongodb conversation response ordering
Browse files Browse the repository at this point in the history
  • Loading branch information
gunthercox committed Mar 4, 2018
1 parent 69e8834 commit 01c8178
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 14 deletions.
21 changes: 16 additions & 5 deletions chatterbot/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,23 @@ class RepetitiveResponseFilter(Filter):

def filter_selection(self, chatterbot, conversation_id):

text_of_recent_responses = []
responses = chatterbot.storage.get_statements_for_conversation(
conversation_id
)

text_of_recent_responses = set()
text_of_all_responses = []

for response in responses:

# Use the latest 1 repetitive responses
if len(text_of_recent_responses) >= 1:
break

if response.text in text_of_all_responses:
text_of_recent_responses.add(response.text)

# TODO: Add a larger quantity of response history
latest_response = chatterbot.storage.get_latest_response(conversation_id)
if latest_response:
text_of_recent_responses.append(latest_response.text)
text_of_all_responses.append(response.text)

# Return the query with no changes if there are no statements to exclude
if not text_of_recent_responses:
Expand Down
32 changes: 27 additions & 5 deletions chatterbot/storage/mongodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,20 @@ def create_conversation(self):
conversation_id = self.conversations.insert_one({}).inserted_id
return conversation_id

def get_statements_for_conversation(self, conversation_id):
"""
Return all statements in the specified conversation.
"""
from pymongo import DESCENDING

statements = list(self.statements.find({
'conversations.id': conversation_id
}).sort('conversations.created_at', DESCENDING))

return [
self.mongo_to_object(statement) for statement in statements
]

def get_latest_response(self, conversation_id):
"""
Returns the latest response in a conversation if it exists.
Expand All @@ -291,10 +305,16 @@ def get_latest_response(self, conversation_id):
'conversations.id': conversation_id
}).sort('conversations.created_at', DESCENDING))

if not statements:
return None
statement = None

return self.mongo_to_object(statements[-2])
if len(statements) >= 2:
statement = self.mongo_to_object(statements[1])

# Handle the case of the first statement in the list
elif len(statements) == 1:
statement = self.mongo_to_object(statements[0])

return statement

def add_to_conversation(self, conversation_id, statement, response):
"""
Expand All @@ -312,7 +332,8 @@ def add_to_conversation(self, conversation_id, statement, response):
'created_at': datetime.utcnow()
}
}
}
},
upsert=True
)
self.statements.update_one(
{
Expand All @@ -326,7 +347,8 @@ def add_to_conversation(self, conversation_id, statement, response):
'created_at': datetime.utcnow() + timedelta(milliseconds=1)
}
}
}
},
upsert=True
)

def get_random(self):
Expand Down
19 changes: 15 additions & 4 deletions chatterbot/trainers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,11 @@ def __init__(self, storage, **kwargs):
self.chatbot = kwargs.get('chatbot')
self.storage = storage
self.logger = logging.getLogger(__name__)

self.show_training_progress = kwargs.get('show_training_progress', True)

self.training_conversation_id = self.storage.create_conversation()

def get_preprocessed_statement(self, input_statement):
"""
Preprocess the input statement.
Expand Down Expand Up @@ -113,9 +116,13 @@ def train(self, conversation):
statement.add_response(
Response(previous_statement_text)
)
self.storage.add_to_conversation(
self.training_conversation_id,
statement,
Statement(text=previous_statement_text)
)

previous_statement_text = statement.text
self.storage.update(statement)


class ChatterBotCorpusTrainer(Trainer):
Expand Down Expand Up @@ -163,9 +170,15 @@ def train(self, *corpus_paths):
statement.add_response(
Response(previous_statement_text)
)
self.storage.add_to_conversation(
self.training_conversation_id,
statement,
Response(text=previous_statement_text)
)
else:
self.storage.update(statement)

previous_statement_text = statement.text
self.storage.update(statement)


class TwitterTrainer(Trainer):
Expand Down Expand Up @@ -421,6 +434,4 @@ def train(self):
statement.add_response(
Response(previous_statement_text)
)

previous_statement_text = statement.text
self.storage.update(statement)

0 comments on commit 01c8178

Please sign in to comment.