Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix sqlite 8-bit bytestrings with unicode coercion #1099

Merged
merged 4 commits into from
Dec 4, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions chatterbot/conversation/statement.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,21 @@ class Statement(StatementMixin):
"""

def __init__(self, text, **kwargs):
import sys

# Try not to allow non-string types to be passed to statements
try:
text = str(text)
except UnicodeEncodeError:
pass

# Prefer decoded utf8-strings in Python 2.7
if sys.version_info[0] < 3:
try:
text = text.decode('utf-8')
except UnicodeEncodeError:
pass

self.text = text
self.tags = kwargs.pop('tags', [])
self.in_response_to = kwargs.pop('in_response_to', [])
Expand Down
11 changes: 6 additions & 5 deletions chatterbot/ext/sqlalchemy_app/models.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from sqlalchemy import Table, Column, Integer, String, DateTime, ForeignKey, PickleType
from sqlalchemy import Table, Column, Integer, DateTime, ForeignKey, PickleType
from sqlalchemy.orm import relationship
from sqlalchemy.sql import func
from sqlalchemy.ext.declarative import declared_attr, declarative_base
from chatterbot.ext.sqlalchemy_app.types import UnicodeString
from chatterbot.conversation.statement import StatementMixin


Expand Down Expand Up @@ -40,15 +41,15 @@ class Tag(Base):
A tag that describes a statement.
"""

name = Column(String)
name = Column(UnicodeString)


class Statement(Base, StatementMixin):
"""
A Statement represents a sentence or phrase.
"""

text = Column(String, unique=True)
text = Column(UnicodeString, unique=True)

tags = relationship(
'Tag',
Expand Down Expand Up @@ -90,7 +91,7 @@ class Response(Base):
Response, contains responses related to a given statement.
"""

text = Column(String)
text = Column(UnicodeString)

created_at = Column(
DateTime(timezone=True),
Expand All @@ -99,7 +100,7 @@ class Response(Base):

occurrence = Column(Integer, default=1)

statement_text = Column(String, ForeignKey('statement.text'))
statement_text = Column(UnicodeString, ForeignKey('statement.text'))

statement_table = relationship(
'Statement',
Expand Down
21 changes: 21 additions & 0 deletions chatterbot/ext/sqlalchemy_app/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from sqlalchemy.types import TypeDecorator, Unicode


class UnicodeString(TypeDecorator):
"""
Type for unicode strings.
"""

impl = Unicode

def process_bind_param(self, value, dialect):
"""
Coerce Python bytestrings to unicode before
saving them to the database.
"""
import sys

if sys.version_info[0] < 3:
if isinstance(value, str):
value = value.decode('utf-8')
return value
1 change: 1 addition & 0 deletions chatterbot/input/input_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def process_input_statement(self, *args, **kwargs):
Return an existing statement object (if one exists).
"""
input_statement = self.process_input(*args, **kwargs)

self.logger.info('Received input statement: {}'.format(input_statement.text))

existing_statement = self.chatbot.storage.find(input_statement.text)
Expand Down
32 changes: 32 additions & 0 deletions tests/training_tests/test_list_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,38 @@ def test_training_with_unicode_characters(self):

self.assertEqual(response, conversation[2])

def test_training_with_emoji_characters(self):
"""
Ensure that the training method adds statements containing emojis.
"""
conversation = [
u'Hi, how are you? 😃',
u'I am just dandy 👍',
u'Superb! 🎆'
]

self.chatbot.train(conversation)

response = self.chatbot.get_response(conversation[1])

self.assertEqual(response, conversation[2])

def test_training_with_unicode_bytestring(self):
"""
Test training with an 8-bit bytestring.
"""
conversation = [
'Hi, how are you?',
'\xe4\xbd\xa0\xe5\xa5\xbd\xe5\x90\x97',
'Superb!'
]

self.chatbot.train(conversation)

response = self.chatbot.get_response(conversation[1])

self.assertEqual(response, conversation[2])

def test_similar_sentence_gets_same_response_multiple_times(self):
"""
Tests if the bot returns the same response for the same
Expand Down