Use chatterbot-corpus to train PunktSentenceTokenizer
gunthercox committed Feb 24, 2019
1 parent 92dd93d commit 021d39e
Showing 4 changed files with 100 additions and 36 deletions.
44 changes: 32 additions & 12 deletions chatterbot/comparisons.py
@@ -4,6 +4,7 @@
"""
from chatterbot import utils
from chatterbot import languages
from chatterbot import tokenizers
from nltk.corpus import wordnet, stopwords

# Use python-Levenshtein if available
@@ -77,18 +78,14 @@ def __init__(self):

self.stopwords = None

self.word_tokenizer = None

def initialize_nltk_wordnet(self):
"""
Download required NLTK corpora if they have not already been downloaded.
"""
utils.nltk_download_corpus('corpora/wordnet')

def initialize_nltk_punkt(self):
"""
Download required NLTK corpora if they have not already been downloaded.
"""
utils.nltk_download_corpus('tokenizers/punkt')

def initialize_nltk_stopwords(self):
"""
Download required NLTK corpora if they have not already been downloaded.
@@ -104,6 +101,15 @@ def get_stopwords(self):

return self.stopwords

def get_word_tokenizer(self):
"""
Get the word tokenizer for this comparison algorithm.
"""
if self.word_tokenizer is None:
self.word_tokenizer = tokenizers.get_word_tokenizer(self.language)

return self.word_tokenizer

def compare(self, statement, other_statement):
"""
Compare the two input statements.
@@ -114,11 +120,12 @@ def compare(self, statement, other_statement):
.. _wordnet: http://www.nltk.org/howto/wordnet.html
.. _NLTK: http://www.nltk.org/
"""
from nltk import word_tokenize
import itertools

tokens1 = word_tokenize(statement.text.lower())
tokens2 = word_tokenize(other_statement.text.lower())
word_tokenizer = self.get_word_tokenizer()

tokens1 = word_tokenizer.tokenize(statement.text.lower())
tokens2 = word_tokenizer.tokenize(other_statement.text.lower())

# Get the stopwords for the current language
stop_word_set = set(self.get_stopwords())
@@ -266,6 +273,8 @@ def __init__(self):

self.lemmatizer = None

self.word_tokenizer = None

def initialize_nltk_wordnet(self):
"""
Download the NLTK wordnet corpora that are required for this algorithm
@@ -306,12 +315,23 @@ def get_lemmatizer(self):

return self.lemmatizer

def get_word_tokenizer(self):
"""
Get the word tokenizer for this comparison algorithm.
"""
if self.word_tokenizer is None:
self.word_tokenizer = tokenizers.get_word_tokenizer(self.language)

return self.word_tokenizer

def compare(self, statement, other_statement):
"""
Return the calculated similarity of two
statements based on the Jaccard index.
"""
from nltk import pos_tag, tokenize
from nltk import pos_tag

word_tokenizer = self.get_word_tokenizer()

# Get the stopwords for the current language
stopwords = self.get_stopwords()
@@ -326,8 +346,8 @@ def compare(self, statement, other_statement):
a = a.translate(self.punctuation_table)
b = b.translate(self.punctuation_table)

pos_a = pos_tag(tokenize.word_tokenize(a))
pos_b = pos_tag(tokenize.word_tokenize(b))
pos_a = pos_tag(word_tokenizer.tokenize(a))
pos_b = pos_tag(word_tokenizer.tokenize(b))

lemma_a = [
lemmatizer.lemmatize(
…
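
Both comparison classes now fetch their word tokenizer lazily through get_word_tokenizer(), which (per the new chatterbot/tokenizers.py below) currently returns NLTK's Treebank word tokenizer regardless of language. For reference, a minimal sketch of equivalent behavior using the public NLTK API; the sample sentence is illustrative only:

    from nltk.tokenize import TreebankWordTokenizer

    tokenizer = TreebankWordTokenizer()
    # Penn Treebank conventions: contractions and trailing
    # punctuation become separate tokens
    tokenizer.tokenize("it's a test.")
    # expected: ['it', "'s", 'a', 'test', '.']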
18 changes: 2 additions & 16 deletions chatterbot/tagging.py
@@ -1,8 +1,8 @@
import string
from chatterbot import languages
from chatterbot import utils
from chatterbot.tokenizers import get_sentence_tokenizer
from nltk import pos_tag
from nltk.data import load as load_data
from nltk.corpus import wordnet, stopwords
from nltk.corpus.reader.wordnet import WordNetError

@@ -32,12 +32,6 @@ def initialize_nltk_wordnet(self):
"""
utils.nltk_download_corpus('corpora/wordnet')

def initialize_nltk_punkt(self):
"""
Download required NLTK punkt corpus if it has not already been downloaded.
"""
utils.nltk_download_corpus('punkt')

def initialize_nltk_averaged_perceptron_tagger(self):
"""
Download the NLTK averaged perceptron tagger that is required for this algorithm
@@ -59,15 +53,7 @@ def tokenize_sentence(self, sentence):
Tokenize the provided sentence.
"""
if self.sentence_tokenizer is None:
try:
self.sentence_tokenizer = load_data('tokenizers/punkt/{language}.pickle'.format(
language=self.language.ENGLISH_NAME.lower()
))
except LookupError:
# Fall back to English sentence splitting rules if a language is not supported
self.sentence_tokenizer = load_data('tokenizers/punkt/{language}.pickle'.format(
language=languages.ENG.ENGLISH_NAME.lower()
))
self.sentence_tokenizer = get_sentence_tokenizer(self.language)

return self.sentence_tokenizer.tokenize(sentence)

…
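
The tagger now delegates sentence splitting to the corpus-trained Punkt model from chatterbot.tokenizers, caching it on the instance after the first call. A hypothetical usage sketch (input and expected output are illustrative, assuming typical Punkt behavior):

    from chatterbot import languages
    from chatterbot.tokenizers import get_sentence_tokenizer

    # The first call trains and pickles the tokenizer;
    # later calls load the cached pickle instead
    tokenizer = get_sentence_tokenizer(languages.ENG)
    tokenizer.tokenize('Hello there. How are you today?')
    # expected: ['Hello there.', 'How are you today?']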
62 changes: 62 additions & 0 deletions chatterbot/tokenizers.py
@@ -0,0 +1,62 @@
from pickle import dump, load
from nltk.tokenize.punkt import PunktSentenceTokenizer, PunktTrainer
from nltk.tokenize import _treebank_word_tokenizer
from chatterbot.corpus import load_corpus, list_corpus_files
from chatterbot import languages


def get_sentence_tokenizer(language):
"""
Return the sentence tokenizer callable.
"""

pickle_path = 'sentence_tokenizer.pickle'

try:
input_file = open(pickle_path, 'rb')
sentence_tokenizer = load(input_file)
input_file.close()
except FileNotFoundError:

data_file_paths = []

sentences = []

try:
# Get the paths to each file the bot will be trained with
corpus_files = list_corpus_files('chatterbot.corpus.{language}'.format(
language=language.ENGLISH_NAME.lower()
))
except LookupError:
# Fall back to English sentence splitting rules if a language is not supported
corpus_files = list_corpus_files('chatterbot.corpus.{language}'.format(
language=languages.ENG.ENGLISH_NAME.lower()
))

data_file_paths.extend(corpus_files)

for corpus, _categories, _file_path in load_corpus(*data_file_paths):
for conversation in corpus:
for text in conversation:
sentences.append(text.upper())
sentences.append(text.lower())

trainer = PunktTrainer()
trainer.INCLUDE_ALL_COLLOCS = True
trainer.train('\n'.join(sentences))

sentence_tokenizer = PunktSentenceTokenizer(trainer.get_params())

# Pickle the sentence tokenizer for future use
output_file = open(pickle_path, 'wb')
dump(sentence_tokenizer, output_file, -1)
output_file.close()

return sentence_tokenizer


def get_word_tokenizer(language):
"""
Return the word tokenizer callable.
"""
return _treebank_word_tokenizer
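
The heart of the new module is the Punkt training flow: gather corpus text, train a PunktTrainer, and build a PunktSentenceTokenizer from the learned parameters. A self-contained sketch of those same steps, with a toy training string standing in for the chatterbot corpus data:

    from nltk.tokenize.punkt import PunktSentenceTokenizer, PunktTrainer

    # Toy training text; the real code joins sentences drawn
    # from the corpus files for the requested language
    text = 'The model trains quickly. It can then split sentences.'

    trainer = PunktTrainer()
    trainer.INCLUDE_ALL_COLLOCS = True  # same setting the module uses
    trainer.train(text)

    # PunktSentenceTokenizer accepts the learned parameters
    # in place of raw training text
    tokenizer = PunktSentenceTokenizer(trainer.get_params())
    tokenizer.tokenize('Hello there. That works.')
    # expected: ['Hello there.', 'That works.']

Because the trained model is pickled to sentence_tokenizer.pickle, subsequent calls load the cached model from the working directory instead of retraining.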
12 changes: 4 additions & 8 deletions tests/test_chatbot.py
@@ -45,9 +45,8 @@ def test_get_initialization_functions(self):

self.assertIn('initialize_nltk_stopwords', functions)
self.assertIn('initialize_nltk_wordnet', functions)
self.assertIn('initialize_nltk_punkt', functions)
self.assertIn('initialize_nltk_averaged_perceptron_tagger', functions)
self.assertIsLength(functions, 4)
self.assertIsLength(functions, 3)

def test_get_initialization_functions_synset_distance(self):
"""
@@ -60,9 +59,8 @@ def test_get_initialization_functions_synset_distance(self):

self.assertIn('initialize_nltk_stopwords', functions)
self.assertIn('initialize_nltk_wordnet', functions)
self.assertIn('initialize_nltk_punkt', functions)
self.assertIn('initialize_nltk_averaged_perceptron_tagger', functions)
self.assertIsLength(functions, 4)
self.assertIsLength(functions, 3)

def test_get_initialization_functions_sentiment_comparison(self):
"""
@@ -76,9 +74,8 @@ def test_get_initialization_functions_sentiment_comparison(self):
self.assertIn('initialize_nltk_stopwords', functions)
self.assertIn('initialize_nltk_wordnet', functions)
self.assertIn('initialize_nltk_vader_lexicon', functions)
self.assertIn('initialize_nltk_punkt', functions)
self.assertIn('initialize_nltk_averaged_perceptron_tagger', functions)
self.assertIsLength(functions, 5)
self.assertIsLength(functions, 4)

def test_get_initialization_functions_jaccard_similarity(self):
"""
@@ -92,8 +89,7 @@ def test_get_initialization_functions_jaccard_similarity(self):
self.assertIn('initialize_nltk_wordnet', functions)
self.assertIn('initialize_nltk_stopwords', functions)
self.assertIn('initialize_nltk_averaged_perceptron_tagger', functions)
self.assertIn('initialize_nltk_punkt', functions)
self.assertIsLength(functions, 4)
self.assertIsLength(functions, 3)

def test_no_statements_known(self):
"""
…
