From 2cef1ce9c1853fa7c5466b4a150b5c752b3f2abc Mon Sep 17 00:00:00 2001 From: Gunther Cox Date: Thu, 28 Sep 2017 22:37:21 -0400 Subject: [PATCH] Add get model methods to each adapter --- chatterbot/storage/jsonfile.py | 12 ++++++++ chatterbot/storage/mongodb.py | 26 +++++++++++++++- chatterbot/storage/sql_storage.py | 43 +++++++++++++++++++++------ chatterbot/storage/storage_adapter.py | 12 -------- 4 files changed, 71 insertions(+), 22 deletions(-) diff --git a/chatterbot/storage/jsonfile.py b/chatterbot/storage/jsonfile.py index 567db099c..838cc3431 100644 --- a/chatterbot/storage/jsonfile.py +++ b/chatterbot/storage/jsonfile.py @@ -37,6 +37,18 @@ def __init__(self, **kwargs): self.adapter_supports_queries = False + def get_statement_model(self): + """ + Return the class for the statement model. + """ + from chatterbot.conversation.statement import Statement + + # Create a storage-aware statement + statement = Statement + statement.storage = self + + return statement + def _keys(self): # The value has to be cast as a list for Python 3 compatibility return list(self.database[0].keys()) diff --git a/chatterbot/storage/mongodb.py b/chatterbot/storage/mongodb.py index aaa9e4937..9c6e332d8 100644 --- a/chatterbot/storage/mongodb.py +++ b/chatterbot/storage/mongodb.py @@ -1,5 +1,4 @@ from chatterbot.storage import StorageAdapter -from chatterbot.conversation import Response class Query(object): @@ -113,6 +112,30 @@ def __init__(self, **kwargs): self.base_query = Query() + def get_statement_model(self): + """ + Return the class for the statement model. + """ + from chatterbot.conversation.statement import Statement + + # Create a storage-aware statement + statement = Statement + statement.storage = self + + return statement + + def get_response_model(self): + """ + Return the class for the response model. + """ + from chatterbot.conversation.response import Response + + # Create a storage-aware response + response = Response + response.storage = self + + return response + def count(self): return self.statements.count() @@ -140,6 +163,7 @@ def deserialize_responses(self, response_list): the list converted to Response objects. """ Statement = self.get_model('statement') + Response = self.get_model('response') proxy_statement = Statement('') for response in response_list: diff --git a/chatterbot/storage/sql_storage.py b/chatterbot/storage/sql_storage.py index 9bdfa6dd5..35c0ff53e 100644 --- a/chatterbot/storage/sql_storage.py +++ b/chatterbot/storage/sql_storage.py @@ -1,4 +1,3 @@ -import random from chatterbot.storage import StorageAdapter @@ -79,11 +78,32 @@ def set_sqlite_pragma(dbapi_connection, connection_record): # ChatterBot's internal query builder is not yet supported for this adapter self.adapter_supports_queries = False + def get_statement_model(self): + """ + Return the statement model. + """ + from chatterbot.ext.sqlalchemy_app.models import Statement + return Statement + + def get_response_model(self): + """ + Return the response model. + """ + from chatterbot.ext.sqlalchemy_app.models import Response + return Response + + def get_conversation_model(self): + """ + Return the conversation model. + """ + from chatterbot.ext.sqlalchemy_app.models import Conversation + return Conversation + def count(self): """ Return the number of entries in the database. """ - from chatterbot.ext.sqlalchemy_app.models import Statement + Statement = self.get_model('statement') session = self.Session() statement_count = session.query(Statement).count() @@ -96,7 +116,7 @@ def __statement_filter(self, session, **kwargs): rtype: query """ - from chatterbot.ext.sqlalchemy_app.models import Statement + Statement = self.get_model('statement') _query = session.query(Statement) return _query.filter_by(**kwargs) @@ -138,7 +158,8 @@ def filter(self, **kwargs): all listed attributes and in which all values match for all listed attributes will be returned. """ - from chatterbot.ext.sqlalchemy_app.models import Statement, Response + Statement = self.get_model('statement') + Response = self.get_model('response') session = self.Session() @@ -199,7 +220,8 @@ def update(self, statement): Modifies an entry in the database. Creates an entry if one does not exist. """ - from chatterbot.ext.sqlalchemy_app.models import Statement, Response + Statement = self.get_model('statement') + Response = self.get_model('response') if statement: session = self.Session() @@ -240,7 +262,7 @@ def create_conversation(self): """ Create a new conversation. """ - from chatterbot.ext.sqlalchemy_app.models import Conversation + Conversation = self.get_model('conversation') session = self.Session() conversation = Conversation() @@ -260,7 +282,8 @@ def add_to_conversation(self, conversation_id, statement, response): """ Add the statement and response to the conversation. """ - from chatterbot.ext.sqlalchemy_app.models import Conversation, Statement + Statement = self.get_model('statement') + Conversation = self.get_model('conversation') session = self.Session() conversation = session.query(Conversation).get(conversation_id) @@ -296,7 +319,7 @@ def get_latest_response(self, conversation_id): Returns the latest response in a conversation if it exists. Returns None if a matching conversation cannot be found. """ - from chatterbot.ext.sqlalchemy_app.models import Statement + Statement = self.get_model('statement') session = self.Session() statement = None @@ -318,7 +341,9 @@ def get_random(self): """ Returns a random statement from the database """ - from chatterbot.ext.sqlalchemy_app.models import Statement + import random + + Statement = self.get_model('statement') session = self.Session() count = self.count() diff --git a/chatterbot/storage/storage_adapter.py b/chatterbot/storage/storage_adapter.py index b1ef0601f..50beac78b 100644 --- a/chatterbot/storage/storage_adapter.py +++ b/chatterbot/storage/storage_adapter.py @@ -33,18 +33,6 @@ def get_model(self, model_name): return get_model_method() - def get_statement_model(self): - """ - Return the class for the statement model. - """ - from chatterbot.conversation.statement import Statement - - # Create a storage-aware statement - statement = Statement - statement.storage = self - - return statement - def generate_base_query(self, chatterbot, session_id): """ Create a base query for the storage adapter.