Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add enjim dataset(s) #9

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@ dependencies = [
"langdetect>=1.0.9",
"tokenizers>=0.13.2",
"ansicolors>=1.1.8",
"PyYAML>=6.0",
"transformers==4.26.1",
"bbcode>=1.1.0",
"torch>=1.13.1",
"pillow>=9.4.0"
]

[project.optional-dependencies]
Expand Down Expand Up @@ -50,3 +55,10 @@ based_on_style = "google"

[tool.mypy]
ignore_missing_imports = true

[tool.poetry]
readme = "README.md"
name = 'toolbox'
version = '0.1.0'
description = 'Code for ingesting data from several sources, formatting it and creating a training dataset.'
authors = ["0x000011b"]
74 changes: 74 additions & 0 deletions resources/config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
enjim:
# The maximum number of characters in a user's persona before it is summarized.
max_persona_chars: 1000

# The max characters in a single statement from a human or bot before it's broken up.
max_utterance_chars: 1500

# If there are more than this many chars in a single post, the entire thread is discarded.
cutoff_utterance_chars: 10000

# Percentage of non-OP posts necessary for thread to be considered a valid multi-person roleplay.
minimum_external_participation: 0.25

# The minimum number of posts in a thread.
min_posts_cutoff: 10

# The approximate number of characters the summarization model can handle at once.
summary_char_limit: 3000

# Char limit for scenarios before summarizing.
max_scenario_chars: 1000

# The Named Entity Recognition model to use. Used for figuring out which character a user is playing as.
ner_model: "Jean-Baptiste/roberta-large-ner-english"

# The image recognition model to use.
img_recognition_model: "nlpconnect/vit-gpt2-image-captioning"

# The location of the initial cache file.
cache_db: 'https://files.catbox.moe/kow8f0.db'

sources:
secretworld:
path: 'https://files.catbox.moe/XXXXXXX.db'
preset_id: '5359865'
character_forums: ['1244856']
roleplay_forums: ['1244866']
other_forums: ['1244872']
ESO:
path: 'https://files.catbox.moe/XXXXXXX.db'
preset_id: '9324623'
character_forums: [ '7284244', '7265336', '7265338' ]
roleplay_forums: [ '1956980' ]
other_forums: [ '1957130', '2273977', '6443584']
GW2:
path: 'https://files.catbox.moe/XXXXXXX.db'
preset_id: '2737230'
character_forums: [ '2927425', '2927462', '2927458', '2927434', '2927444', '2927463', '2927468', '2927460', '2927457', '2927469', '2927459']
roleplay_forums: [ '673041' ]
other_forums: [ '2927474', '673031', '2927475' ]
SWTOR-Starforge:
path: 'https://files.catbox.moe/XXXXXXX.db'
preset_id:
character_forums: [ '3282101' ]
roleplay_forums: [ '3282365', '3286053' ]
other_forums: [ '3282394', '3282102' ]
SWTOR-Malgus:
path: 'https://files.catbox.moe/XXXXXXX.db'
preset_id:
character_forums: [ '8659181', '8775611', '9268554' ]
roleplay_forums: [ '9468751' ]
other_forums: [ '8656831', '8659204' ]
Star-Trek-Online:
path: 'https://files.catbox.moe/XXXXXXX.db'
preset_id:
character_forums: [ '2115071', '2115080' ]
roleplay_forums: [ '2552224', '2115149' ]
other_forums: [ '1481092', '2002200', '7533699']
Aion:
path: 'https://files.catbox.moe/XXXXXXX.db'
preset_id: '16093153'
character_forums: ['3246488', '3275437']
roleplay_forums: ['3246509']
other_forums: ['3246508', '3246501', '3246503']
181 changes: 181 additions & 0 deletions toolbox/datasets/enjim.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
import logging
import sqlite3
import typing as t
from collections import defaultdict
from dataclasses import dataclass
from functools import lru_cache
from os.path import isfile

import mashumaro
import requests
from transformers import pipeline, AutoTokenizer, AutoModelForTokenClassification

from toolbox.datasets import BaseDataset
from toolbox.parsers.bb_code import BBCtoMD
from toolbox.utils.chunking import right_size
from toolbox.utils.dataset import get_data_path, get_config


@dataclass(frozen=True)
class EnjimAgent(mashumaro.DataClassDictMixin):
    """One roleplay character (or plain forum user) from an Enjin forum.

    Frozen (immutable and hashable) so instances can safely be passed to
    `lru_cache`-decorated helpers such as the summarization module's
    `summarize_char`.
    """
    # Display name of the character; falls back to the forum username when no
    # character-definition thread was found.
    name: str
    # Forum username of the person playing this character.
    user_name: str
    # Forum user id (kept as a string).
    user_id: str
    # Character description in markdown; empty string when none was found.
    persona: str


@dataclass(frozen=True)
class EnjimEpisode(mashumaro.DataClassDictMixin):
    """A single roleplay thread converted into an episode.

    Produced by `EnjimDataset.parse_thread` and consumed downstream by the
    persona dialogue module.
    """
    # Short name of the source forum (a key under `sources` in the config).
    forum_shortname: str
    # Subject line of the forum thread.
    thread_subject: str
    # Forum thread id; also used as a cache key when summarizing the scenario.
    thread_id: str
    # Characters appearing in the thread, keyed by character name.
    agents: t.Dict[str, EnjimAgent]
    # Ordered (speaker name, utterance markdown) pairs.
    posts: t.List[t.Tuple[str, str]]


class EnjimDataset(BaseDataset[EnjimEpisode]):
    """Roleplay episodes extracted from Enjin forum SQLite dumps.

    For each source configured under `enjim.sources`, threads in the roleplay
    sub-forums are filtered (minimum post count, minimum non-OP participation)
    and converted into `EnjimEpisode`s. Character definitions are looked up in
    the character sub-forums, using NER over each post to guess which
    character a user is playing.
    """
    CACHE_DB = 'cache'
    # `{placeholders}` is expanded to one '?' per character-forum id before the
    # query runs (see `get_chardef`). Binding a single comma-joined string to
    # one '?' would make the IN clause compare forum_id against the whole
    # joined string, so the lookup would never match more than one forum.
    THREAD_SELECT = "SELECT t.thread_id, t.thread_subject, p.post_content, p.post_user_id, p.post_username " \
                    "FROM forum_threads t INNER JOIN forum_posts p ON p.thread_id = t.thread_id WHERE " \
                    "forum_id IN ({placeholders}) AND post_user_id = ? ORDER BY p.thread_id, post_time"
    POSTS_QUERY = "SELECT post_content, post_user_id, post_username from forum_posts where thread_id = ?"

    def __init__(self):
        self.logger = logging.getLogger(self.__class__.__name__)
        self.settings = get_config('enjim')
        self.root_data_path = get_data_path("enjim")
        # Per-source sqlite connections plus the shared cache db; populated by
        # load_sqlite().
        self.conns = None
        self.load_sqlite()
        # noinspection PyUnresolvedReferences
        self.parser = BBCtoMD(self.settings['img_recognition_model'], self.conns[self.CACHE_DB])
        tokenizer = AutoTokenizer.from_pretrained(self.settings['ner_model'])
        model = AutoModelForTokenClassification.from_pretrained(self.settings['ner_model'])
        self.nlp = pipeline('ner', model=model, tokenizer=tokenizer, aggregation_strategy="simple")

    def generator(self) -> t.Generator[EnjimEpisode, None, None]:
        """Yield one `EnjimEpisode` per roleplay thread that passes the filters."""
        self.load_sqlite()
        for shortname, configuration in self.settings['sources'].items():
            roleplay_forums = configuration['roleplay_forums']
            # Forum ids come from our own config file, not user input, so
            # interpolating them into the query string is acceptable here.
            thread_query = f'SELECT thread_id, thread_subject FROM forum_threads WHERE forum_id ' \
                           f'IN ({", ".join(roleplay_forums)}) ORDER BY thread_views DESC'
            for idx, (rp_thread_id, rp_thread_subject) in enumerate(self.conns[shortname].execute(thread_query)):
                self.logger.info('Processing thread number %s: %s.', idx, rp_thread_subject)
                thread = list(self.conns[shortname].execute(self.POSTS_QUERY, (rp_thread_id,)))
                # Not doing this at the filter stage for performance reasons.
                if len(thread) < self.settings['min_posts_cutoff']:
                    self.logger.warning('Too short of a thread; only %s posts. Skipping.', len(thread))
                    continue
                # Require enough posts by people other than the OP for the
                # thread to count as a multi-person roleplay.
                non_op_pct = sum(1 for post in thread if post[1] != thread[0][1]) / len(thread)
                if non_op_pct < self.settings['minimum_external_participation']:
                    self.logger.warning('Not enough non-OP posters in thread. Skipping.')
                    continue
                episode: t.Optional[EnjimEpisode] = self.parse_thread(thread, shortname, rp_thread_subject,
                                                                      rp_thread_id)
                if episode is not None:
                    yield episode

    def parse_thread(self, thread, shortname, rp_thread_subject, thread_id) -> t.Optional[EnjimEpisode]:
        """Convert (content, user_id, username) rows into an `EnjimEpisode`.

        Returns None when the thread should be discarded: it contains an
        over-long post, or a detected speaker name looks like an NER misfire.
        """
        user_to_agent = {}
        agents: t.Dict[str, EnjimAgent] = {}
        posts: t.List[t.Tuple[str, str]] = []
        for post_content, post_user_id, post_username in thread:
            self.logger.debug("Post: %s.", post_username)
            formatted_post = self.parser.to_markdown(post_content)
            if len(formatted_post) > self.settings['cutoff_utterance_chars']:
                self.logger.warning('Too long of an individual post. Discarding thread %s.', thread_id)
                return None
            if post_username not in user_to_agent:
                speaker = self.determine_speaker(formatted_post, post_user_id, post_username, shortname)
                # Implausibly long names are almost always NER misfires; drop
                # the whole thread rather than attribute posts wrongly.
                if len(speaker.name) > 30:
                    self.logger.error('Invalid speaker. Skipping episode thread %s.', thread_id)
                    return None
                agents[speaker.name] = speaker
                user_to_agent[post_username] = speaker.name
            speaker_name = user_to_agent[post_username]
            if len(formatted_post) > self.settings['max_utterance_chars']:
                # Break overly long posts into several utterances.
                sub_utterances = right_size(scenes=[formatted_post],
                                            max_length=self.settings['max_utterance_chars'])
                for sub_utterance in sub_utterances:
                    posts.append((speaker_name, sub_utterance))
            else:
                posts.append((speaker_name, formatted_post))
        return EnjimEpisode(forum_shortname=shortname, thread_subject=rp_thread_subject, posts=posts, agents=agents,
                            thread_id=thread_id)

    @lru_cache(maxsize=10_000)
    def get_chardef(self, user_id, char_name, shortname) -> t.Optional[EnjimAgent]:
        """Find a character-definition thread by `user_id` that mentions `char_name`.

        Returns None when no matching character thread exists.
        NOTE(review): `lru_cache` on an instance method keeps `self` alive for
        the cache's lifetime (ruff B019); tolerable here since the dataset
        object lives for the whole run anyway.
        """
        self.load_sqlite()
        self.logger.info('Making/fetching character named %s with user id %s for forum %s.', char_name, user_id,
                         shortname)
        char_forums = self.settings['sources'][shortname]['character_forums']
        # One bound parameter per forum id — see the THREAD_SELECT comment.
        query = self.THREAD_SELECT.format(placeholders=", ".join("?" * len(char_forums)))
        for thread_id, thread_subject, post_content, post_user_id, post_username in self.conns[shortname].execute(
                query, (*char_forums, user_id)):
            if ' List' not in thread_subject and (char_name in post_content or char_name in thread_subject):
                self.logger.info('Found character thread for %s.', char_name)
                persona = self.parser.to_markdown(post_content)
                character: EnjimAgent = EnjimAgent(name=self.parser.to_markdown(thread_subject),
                                                   user_name=post_username,
                                                   user_id=user_id,
                                                   persona=persona)
                return character
            self.logger.warning('Thread %s not a match for character %s.', thread_subject, char_name)
        self.logger.debug('No character found for user_id %s, character name %s, forum %s. Will not have a persona.',
                          user_id, char_name, shortname)
        return None

    def determine_speaker(self, formatted_post, post_user_id, post_username, shortname) -> EnjimAgent:
        """Guess which character the user is speaking as in this post.

        Runs NER over the post, then tries the detected person names (most
        frequent first) against the user's character-definition threads.
        Falls back to an agent named after the username with an empty persona.
        """
        entities = self.nlp(formatted_post)
        counts = defaultdict(int)
        for entity in entities:
            if entity['entity_group'] == 'PER':
                counts[entity['word']] += 1
        speaker = None
        # Most-mentioned person names first.
        for person, _count in sorted(counts.items(), key=lambda item: item[1], reverse=True):
            speaker = self.get_chardef(post_user_id, person, shortname)
            if speaker is not None:
                break
        if speaker is None:
            self.logger.warning(
                'No character found for user_id %s, post_username %s, forum %s. Will not have a persona.',
                post_user_id, post_username, shortname)
            speaker = EnjimAgent(name=self.parser.to_markdown(post_username),
                                 user_name=post_username, user_id=post_user_id,
                                 persona='')
        return speaker

    def load_sqlite(self, force_recreate=False):
        """
        Downloads the file(s) and sets up connection(s) to the database(s).

        No-op if connections already exist, unless `force_recreate` is True.
        """
        if not force_recreate and self.conns is not None and len(self.conns) > 0:
            return
        self.conns = {}
        for shortname, configuration in self.settings['sources'].items():
            path = configuration['path']
            self.conns[shortname] = setup_sqlite(self.root_data_path, shortname, path, logger=self.logger)
        # Duct tape tier way of dealing with how slow running image recognition
        # and summarization pipelines is.
        self.conns[self.CACHE_DB] = setup_sqlite(self.root_data_path, 'cache', self.settings['cache_db'],
                                                 logger=self.logger)
        self.conns[self.CACHE_DB].execute(
            'CREATE TABLE IF NOT EXISTS img_cache '
            '(img_url TEXT NOT NULL, model TEXT NOT NULL, description TEXT NOT NULL, '
            'CONSTRAINT img_pkey PRIMARY KEY (img_url, model))')
        self.conns[self.CACHE_DB].execute(
            'CREATE TABLE IF NOT EXISTS summary_cache '
            '(forum_shortname TEXT, text_id TEXT, max_length integer, summary TEXT, '
            'CONSTRAINT summ_pkey PRIMARY KEY (forum_shortname, text_id, max_length))')


def setup_sqlite(root_data_path, shortname, url_path, logger=logging.getLogger()):
data_path = f"{root_data_path}/{shortname}.db"
if not isfile(data_path) and len(url_path) > 0:
logger.info('Downloading dataset %s from %s.', shortname, url_path)
dataset = requests.get(url_path)
with open(data_path, 'wb') as f:
f.write(dataset.content)
try:
return sqlite3.connect(data_path)
except Exception as error:
logger.error('Error while connecting with db %s, error: %s.', shortname, error)
raise error
finally:
logger.info('Added db %s.', shortname)
76 changes: 76 additions & 0 deletions toolbox/modules/enjim_pdm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import logging
from functools import lru_cache
from re import search
from typing import Optional, Generator, Tuple, List, Dict

from transformers import pipeline

from toolbox.core.models import Episode, Turn
from toolbox.datasets.enjim import EnjimDataset, EnjimAgent, setup_sqlite
from toolbox.modules import BaseModule
from toolbox.modules.registry import ModuleRegistry
# NOTE(review): `toolbox.modules.registry` does not exist yet, so the
# `ModuleRegistry` import above crashes at load time. Either implement the
# registry module or remove this import together with the
# `metaclass=ModuleRegistry` reference on the class below — it is not used
# anywhere else.
from toolbox.parsers.bb_code import BBCtoMD
from toolbox.utils.dataset import get_data_path, get_config


class EnjimPDM(BaseModule, metaclass=ModuleRegistry):
    """
    Persona Dialogue Module based on the Enjim dataset.
    NOTE: All the summarizing stuff is just there until whatever is going to be done with vector dbs is figured out.
    """
    # Lookup/insert statements for the on-disk summary cache, keyed by
    # (forum shortname, text id, requested max length).
    CACHE_QUERY = "SELECT summary FROM summary_cache WHERE forum_shortname = ? AND text_id = ? AND max_length = ?"
    CACHE_INSERT = "INSERT INTO summary_cache (forum_shortname, text_id, max_length, summary) VALUES (?, ?, ?, ?)"

    def __init__(self):
        self.logger = logging.getLogger(self.__class__.__name__)
        self.settings = get_config('enjim')
        self.summarizer = pipeline("summarization", model="philschmid/bart-large-cnn-samsum")
        self.cache_db = setup_sqlite(get_data_path("enjim"), EnjimDataset.CACHE_DB, self.settings['cache_db'],
                                     logger=self.logger)

    def generator(self) -> Generator[Episode, None, None]:
        """Yield `Episode`s built from Enjim threads, summarizing over-long text."""
        for episode in EnjimDataset():
            thread_subject: str = episode.thread_subject
            agents: Dict[str, EnjimAgent] = episode.agents
            posts: List[Tuple[str, str]] = episode.posts
            # The thread's OP is treated as the bot; everyone else is "human".
            bot_name = posts[0][0]
            # Utterances still containing invalid parser output get a summarization
            # pass as a cleanup fallback.
            turns = [Turn(utterance=spoken if not search(BBCtoMD.INVALID_RESULT, spoken)
                          else self.summarize(spoken, self.settings['max_scenario_chars'], None, None),
                          speaker=speaker, human_speaker=speaker != bot_name) for speaker, spoken in
                     posts]
            # Drop agents whose persona came back empty: emitting "X's Persona:"
            # with no body wastes tokens and can confuse the model at train time.
            summarized = ((ag, self.summarize_char(ag, self.settings['max_persona_chars'],
                                                   episode.forum_shortname))
                          for ag in agents.values())
            participant_personas = {ag.name: persona for ag, persona in summarized if persona}
            # Scenario is just the summary of the first post.
            world_scenario = thread_subject + ": " + posts[0][1]
            world_scenario = self.summarize(world_scenario, self.settings['max_scenario_chars'],
                                            episode.thread_id, episode.forum_shortname)
            yield Episode(turns=turns, participant_personas=participant_personas, world_scenario=world_scenario)

    @lru_cache(10_000)
    def summarize_char(self, character: EnjimAgent, max_length: int, forum_shortname):
        """Return the character's persona, summarized down to `max_length` chars when needed."""
        if len(character.persona) <= max_length and not search(BBCtoMD.INVALID_RESULT, character.persona):
            return character.persona
        return self.summarize(character.persona, max_length, character.name + character.user_id, forum_shortname)

    def summarize(self, text: str, max_length: int, thread_or_user_id: Optional[str], forum_shortname: Optional[str]):
        """Summarize `text` down to roughly `max_length` characters.

        Results are cached in the summary_cache table when `forum_shortname`
        is provided; recursive half-splitting calls pass None to skip caching.
        """
        if forum_shortname is not None:
            for row in self.cache_db.execute(self.CACHE_QUERY, (forum_shortname, thread_or_user_id, max_length)):
                # Return the cached summary string itself — returning the raw
                # sqlite row tuple made personas render as `("text", )`.
                return row[0]
        combined_summary = None
        if len(text) > self.settings['summary_char_limit']:
            # Too long for the model in one pass: split roughly in half on a
            # sentence or line boundary, summarize each half, and stitch them.
            for separator in ['. ', '\n']:
                if separator in text:
                    sents = text.split(separator)
                    halfway = int(len(sents) / 2)
                    first = self.summarize(separator.join(sents[:halfway]), max_length, None, None)
                    second = self.summarize(separator.join(sents[halfway:]), max_length, None, None)
                    combined_summary = first + separator + second
                    # NOTE(review): assumes each pass shrinks the text; could
                    # loop indefinitely if the summarizer stops compressing.
                    while len(combined_summary) > max_length:
                        combined_summary = self.summarize(combined_summary, max_length, None, None)
                    break
        if combined_summary is None:
            # Either short enough for a single pass, or no separator was found
            # to split on (previously the latter silently returned None).
            combined_summary = self.summarizer(text)[0]['summary_text']
        if forum_shortname is not None:
            self.cache_db.execute(self.CACHE_INSERT, (forum_shortname, thread_or_user_id, max_length, combined_summary))
            self.cache_db.commit()
        return combined_summary
Loading