diff --git a/chatterbot/conversation/statement.py b/chatterbot/conversation/statement.py index a1938dd9f..dd3e5ca7a 100644 --- a/chatterbot/conversation/statement.py +++ b/chatterbot/conversation/statement.py @@ -29,6 +29,7 @@ class Statement(StatementMixin): """ def __init__(self, text, **kwargs): + import sys # Try not to allow non-string types to be passed to statements try: @@ -36,6 +37,13 @@ def __init__(self, text, **kwargs): except UnicodeEncodeError: pass + # Prefer decoded utf8-strings in Python 2.7 + if sys.version_info[0] < 3: + try: + text = text.decode('utf-8') + except UnicodeEncodeError: + pass + self.text = text self.tags = kwargs.pop('tags', []) self.in_response_to = kwargs.pop('in_response_to', []) diff --git a/chatterbot/ext/sqlalchemy_app/models.py b/chatterbot/ext/sqlalchemy_app/models.py index 121a6587a..8be09f86e 100644 --- a/chatterbot/ext/sqlalchemy_app/models.py +++ b/chatterbot/ext/sqlalchemy_app/models.py @@ -1,7 +1,8 @@ -from sqlalchemy import Table, Column, Integer, String, DateTime, ForeignKey, PickleType +from sqlalchemy import Table, Column, Integer, DateTime, ForeignKey, PickleType from sqlalchemy.orm import relationship from sqlalchemy.sql import func from sqlalchemy.ext.declarative import declared_attr, declarative_base +from chatterbot.ext.sqlalchemy_app.types import UnicodeString from chatterbot.conversation.statement import StatementMixin @@ -40,7 +41,7 @@ class Tag(Base): A tag that describes a statement. """ - name = Column(String) + name = Column(UnicodeString) class Statement(Base, StatementMixin): @@ -48,7 +49,7 @@ class Statement(Base, StatementMixin): A Statement represents a sentence or phrase. """ - text = Column(String, unique=True) + text = Column(UnicodeString, unique=True) tags = relationship( 'Tag', @@ -90,7 +91,7 @@ class Response(Base): Response, contains responses related to a given statement. """ - text = Column(String) + text = Column(UnicodeString) created_at = Column( DateTime(timezone=True), @@ -99,7 +100,7 @@ class Response(Base): occurrence = Column(Integer, default=1) - statement_text = Column(String, ForeignKey('statement.text')) + statement_text = Column(UnicodeString, ForeignKey('statement.text')) statement_table = relationship( 'Statement', diff --git a/chatterbot/ext/sqlalchemy_app/types.py b/chatterbot/ext/sqlalchemy_app/types.py new file mode 100644 index 000000000..b48f4f6e4 --- /dev/null +++ b/chatterbot/ext/sqlalchemy_app/types.py @@ -0,0 +1,21 @@ +from sqlalchemy.types import TypeDecorator, Unicode + + +class UnicodeString(TypeDecorator): + """ + Type for unicode strings. + """ + + impl = Unicode + + def process_bind_param(self, value, dialect): + """ + Coerce Python bytestrings to unicode before + saving them to the database. + """ + import sys + + if sys.version_info[0] < 3: + if isinstance(value, str): + value = value.decode('utf-8') + return value diff --git a/chatterbot/input/input_adapter.py b/chatterbot/input/input_adapter.py index 2e764d1a5..17b1dbe14 100644 --- a/chatterbot/input/input_adapter.py +++ b/chatterbot/input/input_adapter.py @@ -19,6 +19,7 @@ def process_input_statement(self, *args, **kwargs): Return an existing statement object (if one exists). """ input_statement = self.process_input(*args, **kwargs) + self.logger.info('Received input statement: {}'.format(input_statement.text)) existing_statement = self.chatbot.storage.find(input_statement.text) diff --git a/tests/training_tests/test_list_training.py b/tests/training_tests/test_list_training.py index 74bd46056..dd63f2c07 100644 --- a/tests/training_tests/test_list_training.py +++ b/tests/training_tests/test_list_training.py @@ -115,6 +115,38 @@ def test_training_with_unicode_characters(self): self.assertEqual(response, conversation[2]) + def test_training_with_emoji_characters(self): + """ + Ensure that the training method adds statements containing emojis. + """ + conversation = [ + u'Hi, how are you? 😃', + u'I am just dandy 👍', + u'Superb! 🎆' + ] + + self.chatbot.train(conversation) + + response = self.chatbot.get_response(conversation[1]) + + self.assertEqual(response, conversation[2]) + + def test_training_with_unicode_bytestring(self): + """ + Test training with an 8-bit bytestring. + """ + conversation = [ + 'Hi, how are you?', + '\xe4\xbd\xa0\xe5\xa5\xbd\xe5\x90\x97', + 'Superb!' + ] + + self.chatbot.train(conversation) + + response = self.chatbot.get_response(conversation[1]) + + self.assertEqual(response, conversation[2]) + def test_similar_sentence_gets_same_response_multiple_times(self): """ Tests if the bot returns the same response for the same