From 24c5665b2d1ae90c4a3d5c70bc51dcfb0e8480c6 Mon Sep 17 00:00:00 2001 From: Gunther Cox Date: Tue, 10 Oct 2017 20:41:17 -0400 Subject: [PATCH] Add tagging when training with SQL storage --- chatterbot/conversation/statement.py | 25 +++++++++- chatterbot/ext/sqlalchemy_app/models.py | 25 +++++++--- chatterbot/storage/sql_storage.py | 50 ++++++++++++------- chatterbot/trainers.py | 1 + .../test_chatterbot_corpus_training.py | 11 ++++ 5 files changed, 85 insertions(+), 27 deletions(-) diff --git a/chatterbot/conversation/statement.py b/chatterbot/conversation/statement.py index 82c79b73b..a1938dd9f 100644 --- a/chatterbot/conversation/statement.py +++ b/chatterbot/conversation/statement.py @@ -2,7 +2,27 @@ from .response import Response -class Statement(object): +class StatementMixin(object): + """ + This class has shared methods used to + normalize different statement models. + """ + + def get_tags(self): + """ + Return the list of tags for this statement. + """ + return self.tags + + def add_tags(self, tags): + """ + Add a list of strings to the statement as tags. + """ + for tag in tags: + self.tags.append(tag) + + +class Statement(StatementMixin): """ A statement represents a single spoken entity, sentence or phrase that someone can say. 
@@ -17,6 +37,7 @@ def __init__(self, text, **kwargs): pass self.text = text + self.tags = kwargs.pop('tags', []) self.in_response_to = kwargs.pop('in_response_to', []) self.extra_data = kwargs.pop('extra_data', {}) @@ -80,7 +101,7 @@ def add_response(self, response): """ if not isinstance(response, Response): raise Statement.InvalidTypeException( - 'A {} was recieved when a {} instance was expected'.format( + 'A {} was received when a {} instance was expected'.format( type(response), type(Response('')) ) diff --git a/chatterbot/ext/sqlalchemy_app/models.py b/chatterbot/ext/sqlalchemy_app/models.py index fe8cb547d..121a6587a 100644 --- a/chatterbot/ext/sqlalchemy_app/models.py +++ b/chatterbot/ext/sqlalchemy_app/models.py @@ -2,6 +2,7 @@ from sqlalchemy.orm import relationship from sqlalchemy.sql import func from sqlalchemy.ext.declarative import declared_attr, declarative_base +from chatterbot.conversation.statement import StatementMixin class ModelBase(object): @@ -42,7 +43,7 @@ class Tag(Base): name = Column(String) -class Statement(Base): +class Statement(Base, StatementMixin): """ A Statement represents a sentence or phrase. """ @@ -62,12 +63,25 @@ class Statement(Base): back_populates='statement_table' ) + def get_tags(self): + """ + Return a list of tags for this statement. 
+ """ + return [tag.name for tag in self.tags] + def get_statement(self): from chatterbot.conversation import Statement as StatementObject + from chatterbot.conversation import Response as ResponseObject - statement = StatementObject(self.text, extra_data=self.extra_data) + statement = StatementObject( + self.text, + tags=[tag.name for tag in self.tags], + extra_data=self.extra_data + ) for response in self.in_response_to: - statement.add_response(response.get_response()) + statement.add_response( + ResponseObject(text=response.text, occurrence=response.occurrence) + ) return statement @@ -94,11 +108,6 @@ class Response(Base): uselist=False ) - def get_response(self): - from chatterbot.conversation import Response as ResponseObject - occ = {'occurrence': self.occurrence} - return ResponseObject(text=self.text, **occ) - conversation_association_table = Table( 'conversation_association', diff --git a/chatterbot/storage/sql_storage.py b/chatterbot/storage/sql_storage.py index 2c3ad603b..8fbf2d2b2 100644 --- a/chatterbot/storage/sql_storage.py +++ b/chatterbot/storage/sql_storage.py @@ -103,6 +103,13 @@ def get_conversation_model(self): from chatterbot.ext.sqlalchemy_app.models import Conversation return Conversation + def get_tag_model(self): + """ + Return the tag model. + """ + from chatterbot.ext.sqlalchemy_app.models import Tag + return Tag + def count(self): """ Return the number of entries in the database. 
@@ -226,6 +233,7 @@ def update(self, statement): """ Statement = self.get_model('statement') Response = self.get_model('response') + Tag = self.get_model('tag') if statement: session = self.Session() @@ -238,25 +246,33 @@ def update(self, statement): record.extra_data = dict(statement.extra_data) - if statement.in_response_to: - # Get or create the response records as needed - for response in statement.in_response_to: - _response = session.query(Response).filter_by( + for _tag in statement.tags: + tag = session.query(Tag).filter_by(name=_tag).first() + + if not tag: + # Create the record + tag = Tag(name=_tag) + + record.tags.append(tag) + + # Get or create the response records as needed + for response in statement.in_response_to: + _response = session.query(Response).filter_by( + text=response.text, + statement_text=statement.text + ).first() + + if _response: + _response.occurrence += 1 + else: + # Create the record + _response = Response( text=response.text, - statement_text=statement.text - ).first() + statement_text=statement.text, + occurrence=response.occurrence + ) - if _response: - _response.occurrence += 1 - else: - # Create the record - _response = Response( - text=response.text, - statement_text=statement.text, - occurrence=response.occurrence - ) - - record.in_response_to.append(_response) + record.in_response_to.append(_response) session.add(record) diff --git a/chatterbot/trainers.py b/chatterbot/trainers.py index 0b73eec13..0a4b59a1f 100644 --- a/chatterbot/trainers.py +++ b/chatterbot/trainers.py @@ -131,6 +131,7 @@ def train(self, *corpus_paths): for text in conversation: statement = self.get_or_create(text) + statement.add_tags(corpus.categories) if previous_statement_text: statement.add_response( diff --git a/tests/training_tests/test_chatterbot_corpus_training.py b/tests/training_tests/test_chatterbot_corpus_training.py index d1a82658e..be7ffca1d 100644 --- a/tests/training_tests/test_chatterbot_corpus_training.py +++ 
b/tests/training_tests/test_chatterbot_corpus_training.py @@ -3,6 +3,9 @@ class ChatterBotCorpusTrainingTestCase(ChatBotTestCase): + """ + Test case for training with data from the ChatterBot Corpus. + """ def setUp(self): super(ChatterBotCorpusTrainingTestCase, self).setUp() @@ -12,8 +15,16 @@ def test_train_with_english_greeting_corpus(self): self.chatbot.train('chatterbot.corpus.english.greetings') statement = self.chatbot.storage.find('Hello') + self.assertIsNotNone(statement) + def test_train_with_english_greeting_corpus_tags(self): + self.chatbot.train('chatterbot.corpus.english.greetings') + + statement = self.chatbot.storage.find('Hello') + + self.assertIn('greetings', statement.get_tags()) + def test_train_with_multiple_corpora(self): self.chatbot.train( 'chatterbot.corpus.english.greetings',