Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Automate registry registration for tasks #116

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 38 additions & 31 deletions lm_eval/tasks/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from pprint import pprint
from typing import List, Union
import inspect

import sacrebleu
import lm_eval.base
Expand Down Expand Up @@ -59,8 +60,8 @@
from .ja import marc_ja
from .ja import jcola
from .ja import jblimp
from .ja import wikilingua
from .ja import xwinograd
from .ja import wikilingua_ja
from .ja import xwinograd_ja
from .ja import xlsum_ja
from .ja import jaqket_v1
from .ja import jaqket_v2
Expand Down Expand Up @@ -90,11 +91,8 @@
for ts in sacrebleu.get_available_testsets()
}


########################################
# All tasks
########################################

# Ideally this would be removed and handled based entirely on module names,
# but the name process is irregular, so it can only be transitioned gradually.

TASK_REGISTRY = {
# GLUE
Expand Down Expand Up @@ -323,33 +321,42 @@
# "storycloze_2016": storycloze.StoryCloze2016,
# "storycloze_2018": storycloze.StoryCloze2018,
# "sat": sat.SATAnalogies,
# JGLUE
"jsquad": jsquad.JSQuAD,
**jsquad.construct_tasks(),
"jaquad": jaquad.JaQuAD,
**jaquad.construct_tasks(),
"jcommonsenseqa": jcommonsenseqa.JCommonsenseQA,
**jcommonsenseqa.construct_tasks(),
"jnli": jnli.JNLIWithFintanPrompt,
**jnli.construct_tasks(),
"marc_ja": marc_ja.MARCJaWithFintanPrompt,
**marc_ja.construct_tasks(),
"jcola": jcola.JCoLA,
**jcola.construct_tasks(),
"jblimp": jblimp.JBlimp,
**wikilingua.construct_tasks(),
"xwinograd_ja": xwinograd.XWinogradJA,
"xlsum_ja": xlsum_ja.XLSumJa,
**xlsum_ja.construct_tasks(),
"jaqket_v1": jaqket_v1.JAQKETV1,
**jaqket_v1.construct_tasks(),
"jaqket_v2": jaqket_v2.JAQKETV2,
**jaqket_v2.construct_tasks(),
"mgsm": mgsm.MGSM,
**mgsm.construct_tasks(),
}


def register_tasks():
"""Automatically register subclasses of Task.

Currently this is only guaranteed to work for Japanese tasks. Ideally it
would be updated to handle legacy tasks and avoid manual registration.
"""
qq = []
qq.extend(lm_eval.base.Task.__subclasses__())
while qq:
cls = qq.pop()
# add subclasses to recur
qq.extend(cls.__subclasses__())

# get the shortname using the module
mod = inspect.getmodule(cls)
# XXX skip non-japanese modules
parts = mod.__name__.split(".")
if parts[-2] != "ja":
continue

name = parts[-1]
# only the first one gets added as a plain name
if name not in TASK_REGISTRY:
TASK_REGISTRY[name] = cls

if hasattr(cls, "PROMPT_VERSION"):
# Note that anything with a prompt version has a VERSION
key = f"{name}-{cls.VERSION}-{cls.PROMPT_VERSION}"
TASK_REGISTRY[key] = cls


register_tasks()

ALL_TASKS = sorted(list(TASK_REGISTRY))


Expand Down
File renamed without changes.
File renamed without changes.
Loading