Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add training class for Ubuntu corpus #405

Merged
merged 5 commits into from
Nov 24, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,4 @@ docs/_build/
*.iml

examples/settings.py
examples/ubuntu_dialogs*
197 changes: 182 additions & 15 deletions chatterbot/trainers.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,39 @@
from .conversation import Statement, Response
import logging
from .conversation import Statement, Response


class Trainer(object):
"""
Base class for all other trainer classes.
"""

def __init__(self, storage, **kwargs):
self.storage = storage
self.logger = logging.getLogger(__name__)

def train(self, *args, **kwargs):
"""
This class must be overridden by a class the inherits from 'Trainer'.
"""
raise self.TrainerInitializationException()

def get_or_create(self, statement_text):
"""
Return a statement if it exists.
Create and return the statement if it does not exist.
"""
statement = self.storage.find(statement_text)

if not statement:
statement = Statement(statement_text)

return statement

class TrainerInitializationException(Exception):
"""
Exception raised when a base class has not overridden
the required methods on the Trainer base class.
"""

def __init__(self, value=None):
default = (
Expand Down Expand Up @@ -44,20 +66,16 @@ def export_for_training(self, file_path='./export.json'):


class ListTrainer(Trainer):
"""
Allaows a chat bot to be trained using a list of strings
where the list represents a conversation.
"""

def get_or_create(self, statement_text):
def train(self, conversation):
"""
Return a statement if it exists.
Create and return the statement if it does not exist.
Train the chat bot based on the provided list of
statements that represents a single conversation.
"""
statement = self.storage.find(statement_text)

if not statement:
statement = Statement(statement_text)

return statement

def train(self, conversation):
statement_history = []

for text in conversation:
Expand All @@ -73,6 +91,10 @@ def train(self, conversation):


class ChatterBotCorpusTrainer(Trainer):
"""
Allows the chat bot to be trained using data from the
ChatterBot dialog corpus.
"""

def __init__(self, storage, **kwargs):
super(ChatterBotCorpusTrainer, self).__init__(storage, **kwargs)
Expand All @@ -96,6 +118,10 @@ def train(self, *corpora):


class TwitterTrainer(Trainer):
"""
Allows the chat bot to be trained using data
gathered from Twitter.
"""

def __init__(self, storage, **kwargs):
super(TwitterTrainer, self).__init__(storage, **kwargs)
Expand Down Expand Up @@ -167,15 +193,156 @@ def get_statements(self):
status = self.api.GetStatus(tweet.in_reply_to_status_id)
statement.add_response(Response(status.text))
statements.append(statement)
except TwitterError as e:
self.logger.warning(str(e))
except TwitterError as error:
self.logger.warning(str(error))

self.logger.info('Adding {} tweets with responses'.format(len(statements)))

return statements

def train(self):
for i in range(0, 10):
for _ in range(0, 10):
statements = self.get_statements()
for statement in statements:
self.storage.update(statement, force=True)


class UbuntuCorpusTrainer(Trainer):
"""
Allow chatbots to be trained with the data from
the Ubuntu Dialog Corpus.
"""

def __init__(self, storage, **kwargs):
super(UbuntuCorpusTrainer, self).__init__(storage, **kwargs)
import os

self.data_download_url = kwargs.get(
'ubuntu_corpus_data_download_url',
'http://cs.mcgill.ca/~jpineau/datasets/ubuntu-corpus-1.0/ubuntu_dialogs.tgz'
)

self.data_directory = kwargs.get(
'ubuntu_corpus_data_directory',
'./data/'
)

# Create the data directory if it does not already exist
if not os.path.exists(self.data_directory):
os.makedirs(self.data_directory)

def download(self, url, show_status=True):
"""
Download a file from the given url.
Show a progress indicator for the download status.
Based on: http://stackoverflow.com/a/15645088/1547223
"""
import os
import sys
import requests

file_name = url.split('/')[-1]
file_path = os.path.join(self.data_directory, file_name)

# Do not download the data if it already exists
if os.path.exists(file_path):
self.logger.info('File is already downloaded')
return file_path

with open(file_path, 'wb') as open_file:
print('Downloading %s' % file_name)
response = requests.get(url, stream=True)
total_length = response.headers.get('content-length')

if total_length is None:
# No content length header
open_file.write(response.content)
else:
download = 0
total_length = int(total_length)
for data in response.iter_content(chunk_size=4096):
download += len(data)
open_file.write(data)
if show_status:
done = int(50 * download / total_length)
sys.stdout.write('\r[%s%s]' % ('=' * done, ' ' * (50 - done)))
sys.stdout.flush()

return file_path

def extract(self, file_path):
"""
Extract a tar file at the specified file path.
"""
import os
import tarfile

dir_name = os.path.split(file_path)[-1].split('.')[0]

extracted_file_directory = os.path.join(
self.data_directory,
dir_name
)

# Do not extract if the extracted directory already exists
if os.path.isdir(extracted_file_directory):
return False

self.logger.info('Starting file extraction')

def track_progress(members):
for member in members:
# this will be the current file being extracted
yield member
print('Extracting {}'.format(member.path))

with tarfile.open(file_path) as tar:
tar.extractall(path=self.data_directory, members=track_progress(tar))

self.logger.info('File extraction complete')

return True

def train(self):
import glob
import csv
import os

# Download and extract the Ubuntu dialog corpus
corpus_download_path = self.download(self.data_download_url)

self.extract(corpus_download_path)

extracted_corpus_path = os.path.join(
self.data_directory,
os.path.split(corpus_download_path)[-1].split('.')[0],
'**', '*.tsv'
)

for file in glob.iglob(extracted_corpus_path):
self.logger.info('Training from: {}'.format(file))

with open(file, 'r') as tsv:
reader = csv.reader(tsv, delimiter='\t')

statement_history = []

for row in reader:
if len(row) > 0:
text = row[3]
statement = self.get_or_create(text)
print(text, len(row))

statement.add_extra_data('datetime', row[0])
statement.add_extra_data('speaker', row[1])

if row[2].strip():
statement.add_extra_data('addressing_speaker', row[2])

if statement_history:
statement.add_response(
Response(statement_history[-1].text)
)

statement_history.append(statement)
self.storage.update(statement, force=True)
14 changes: 14 additions & 0 deletions docs/training.rst
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,20 @@ Twitter training example
:language: python


Training with the Ubuntu dialog corpus
======================================

.. autofunction:: chatterbot.trainers.UbuntuCorpusTrainer

This training class makes it possible to train your chat bot using the Ubuntu
dialog corpus. Becaue of the file size of the Ubuntu dialog corpus, the download
and training process may take a considerable amount of time.

This training class will handle the process of downloading the compressed corpus
file and extracting it. If the file has already been downloaded, it will not be
downloaded again. If the file is already extracted, it will not be extracted again.


Creating a new training class
=============================

Expand Down
23 changes: 23 additions & 0 deletions examples/ubuntu_corpus_training_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from chatterbot import ChatBot
import logging


'''
This is an example showing how to train a chat bot using the
Ubuntu Corpus of conversation dialog.
'''

# Enable info level logging
logging.basicConfig(level=logging.INFO)

chatbot = ChatBot(
'Example Bot',
trainer='chatterbot.trainers.UbuntuCorpusTrainer'
)

# Start by training our bot with the Ubuntu corpus data
chatbot.train()

# Now let's get a response to a greeting
response = chatbot.get_response('How are you doing today?')
print(response)
Loading