Skip to content

Commit

Permalink
Add tagging when training with SQL storage
Browse files Browse the repository at this point in the history
  • Loading branch information
gunthercox committed Oct 11, 2017
1 parent 45603b9 commit aa209f5
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 27 deletions.
24 changes: 22 additions & 2 deletions chatterbot/conversation/statement.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,26 @@
from .response import Response


class Statement(object):
class StatementMixin(object):
    """
    This class has shared methods used to
    normalize different statement models.

    Classes using this mixin must provide a ``tags`` attribute
    holding a mutable, list-like collection.
    """

    def get_tags(self):
        """
        Return the list of tags for this statement.

        The underlying ``tags`` collection itself is returned, not a copy.
        """
        return self.tags

    def add_tags(self, tags):
        """
        Add a list of strings to the statement as tags.

        :param tags: An iterable of tags to append, in order, to
            this statement's existing tags.
        """
        self.tags.extend(tags)

class Statement(StatementMixin):
"""
A statement represents a single spoken entity, sentence or
phrase that someone can say.
Expand All @@ -17,6 +36,7 @@ def __init__(self, text, **kwargs):
pass

self.text = text
self.tags = kwargs.pop('tags', [])
self.in_response_to = kwargs.pop('in_response_to', [])

self.extra_data = kwargs.pop('extra_data', {})
Expand Down Expand Up @@ -80,7 +100,7 @@ def add_response(self, response):
"""
if not isinstance(response, Response):
raise Statement.InvalidTypeException(
'A {} was recieved when a {} instance was expected'.format(
'A {} was received when a {} instance was expected'.format(
type(response),
type(Response(''))
)
Expand Down
25 changes: 17 additions & 8 deletions chatterbot/ext/sqlalchemy_app/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from sqlalchemy.orm import relationship
from sqlalchemy.sql import func
from sqlalchemy.ext.declarative import declared_attr, declarative_base
from chatterbot.conversation.statement import StatementMixin


class ModelBase(object):
Expand Down Expand Up @@ -42,7 +43,7 @@ class Tag(Base):
name = Column(String)


class Statement(Base):
class Statement(Base, StatementMixin):
"""
A Statement represents a sentence or phrase.
"""
Expand All @@ -62,12 +63,25 @@ class Statement(Base):
back_populates='statement_table'
)

def get_tags(self):
    """
    Return a list of tags for this statement.

    Each object in ``self.tags`` is reduced to its ``name``
    attribute, so the result is a new list of names rather
    than the related tag objects themselves.
    """
    return [tag.name for tag in self.tags]

def get_statement(self):
    """
    Build a ``chatterbot.conversation`` statement object from this record.

    The text, the names of the related tags and the extra data are
    passed to the new statement, and one response object (same text
    and occurrence) is added for every entry in ``in_response_to``.

    :returns: The newly created statement object.
    """
    # Aliased on import so these names do not clash with the
    # Statement and Response model classes defined in this module.
    from chatterbot.conversation import Statement as StatementObject
    from chatterbot.conversation import Response as ResponseObject

    statement = StatementObject(
        self.text,
        tags=[tag.name for tag in self.tags],
        extra_data=self.extra_data
    )

    # Each stored response is added exactly once.
    for response in self.in_response_to:
        statement.add_response(
            ResponseObject(text=response.text, occurrence=response.occurrence)
        )

    return statement


Expand All @@ -94,11 +108,6 @@ class Response(Base):
uselist=False
)

def get_response(self):
    """
    Build a ``chatterbot.conversation`` response object from this record.

    :returns: A response object with this record's text and occurrence.
    """
    # Aliased on import so the name does not clash with the
    # Response model class defined in this module.
    from chatterbot.conversation import Response as ResponseObject

    return ResponseObject(text=self.text, occurrence=self.occurrence)


conversation_association_table = Table(
'conversation_association',
Expand Down
50 changes: 33 additions & 17 deletions chatterbot/storage/sql_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,13 @@ def get_conversation_model(self):
from chatterbot.ext.sqlalchemy_app.models import Conversation
return Conversation

def get_tag_model(self):
    """
    Return the tag model.
    """
    from chatterbot.ext.sqlalchemy_app.models import Tag
    return Tag

def count(self):
"""
Return the number of entries in the database.
Expand Down Expand Up @@ -226,6 +233,7 @@ def update(self, statement):
"""
Statement = self.get_model('statement')
Response = self.get_model('response')
Tag = self.get_model('tag')

if statement:
session = self.Session()
Expand All @@ -238,25 +246,33 @@ def update(self, statement):

record.extra_data = dict(statement.extra_data)

if statement.in_response_to:
# Get or create the response records as needed
for response in statement.in_response_to:
_response = session.query(Response).filter_by(
for _tag in statement.tags:
tag = session.query(Tag).filter_by(name=_tag).first()

if not tag:
# Create the record
tag = Tag(name=_tag)

record.tags.append(tag)

# Get or create the response records as needed
for response in statement.in_response_to:
_response = session.query(Response).filter_by(
text=response.text,
statement_text=statement.text
).first()

if _response:
_response.occurrence += 1
else:
# Create the record
_response = Response(
text=response.text,
statement_text=statement.text
).first()
statement_text=statement.text,
occurrence=response.occurrence
)

if _response:
_response.occurrence += 1
else:
# Create the record
_response = Response(
text=response.text,
statement_text=statement.text,
occurrence=response.occurrence
)

record.in_response_to.append(_response)
record.in_response_to.append(_response)

session.add(record)

Expand Down
1 change: 1 addition & 0 deletions chatterbot/trainers.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ def train(self, *corpus_paths):

for text in conversation:
statement = self.get_or_create(text)
statement.add_tags(corpus.categories)

if previous_statement_text:
statement.add_response(
Expand Down
11 changes: 11 additions & 0 deletions tests/training_tests/test_chatterbot_corpus_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@


class ChatterBotCorpusTrainingTestCase(ChatBotTestCase):
"""
Test case for training with data from the ChatterBot Corpus.
"""

def setUp(self):
super(ChatterBotCorpusTrainingTestCase, self).setUp()
Expand All @@ -12,8 +15,16 @@ def test_train_with_english_greeting_corpus(self):
self.chatbot.train('chatterbot.corpus.english.greetings')

statement = self.chatbot.storage.find('Hello')

self.assertIsNotNone(statement)

def test_train_with_english_greeting_corpus_tags(self):
    """
    A statement trained from the english greetings corpus
    should carry the 'greetings' tag.
    """
    self.chatbot.train('chatterbot.corpus.english.greetings')

    statement = self.chatbot.storage.find('Hello')

    # Guard first so a missing statement is reported as a test
    # failure rather than an AttributeError on None.
    self.assertIsNotNone(statement)
    self.assertIn('greetings', statement.get_tags())

def test_train_with_multiple_corpora(self):
self.chatbot.train(
'chatterbot.corpus.english.greetings',
Expand Down

0 comments on commit aa209f5

Please sign in to comment.