Skip to content

Commit

Permalink
Automate registry registration for tasks (#116)
Browse files Browse the repository at this point in the history
* Automate registry creation

Rather than manually build the registry, this iterates over all subclasses
of Task to build it. This makes registry entries less manual and less
error-prone, and removes redundant code.

This is not complete. There are two main issues:

1. Registry manipulation functions (with names like `create_tasks`)
   should be removed.

2. For older English tasks with multiple versions, tasks aren't added to
   the registry properly.

Neither of these is complicated; they will be addressed in further commits.

* Leave legacy registry code alone

This restores the manual creation of the registry for non-Japanese
tasks. While it's possible to register them automatically, the mapping
of the module or class name to the task name varies considerably, so it
would require a lot of special casing to get it right.

---------

Co-authored-by: Paul O'Leary McCann <[email protected]>
  • Loading branch information
polm-stability and polm authored Nov 21, 2023
1 parent e68527f commit effdbea
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 31 deletions.
69 changes: 38 additions & 31 deletions lm_eval/tasks/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from pprint import pprint
from typing import List, Union
import inspect

import sacrebleu
import lm_eval.base
Expand Down Expand Up @@ -59,8 +60,8 @@
from .ja import marc_ja
from .ja import jcola
from .ja import jblimp
from .ja import wikilingua
from .ja import xwinograd
from .ja import wikilingua_ja
from .ja import xwinograd_ja
from .ja import xlsum_ja
from .ja import jaqket_v1
from .ja import jaqket_v2
Expand Down Expand Up @@ -90,11 +91,8 @@
for ts in sacrebleu.get_available_testsets()
}


########################################
# All tasks
########################################

# Ideally this would be removed and handled based entirely on module names,
# but the naming process is irregular, so it can only be transitioned gradually.

TASK_REGISTRY = {
# GLUE
Expand Down Expand Up @@ -323,33 +321,42 @@
# "storycloze_2016": storycloze.StoryCloze2016,
# "storycloze_2018": storycloze.StoryCloze2018,
# "sat": sat.SATAnalogies,
# JGLUE
"jsquad": jsquad.JSQuAD,
**jsquad.construct_tasks(),
"jaquad": jaquad.JaQuAD,
**jaquad.construct_tasks(),
"jcommonsenseqa": jcommonsenseqa.JCommonsenseQA,
**jcommonsenseqa.construct_tasks(),
"jnli": jnli.JNLIWithFintanPrompt,
**jnli.construct_tasks(),
"marc_ja": marc_ja.MARCJaWithFintanPrompt,
**marc_ja.construct_tasks(),
"jcola": jcola.JCoLA,
**jcola.construct_tasks(),
"jblimp": jblimp.JBlimp,
**wikilingua.construct_tasks(),
"xwinograd_ja": xwinograd.XWinogradJA,
"xlsum_ja": xlsum_ja.XLSumJa,
**xlsum_ja.construct_tasks(),
"jaqket_v1": jaqket_v1.JAQKETV1,
**jaqket_v1.construct_tasks(),
"jaqket_v2": jaqket_v2.JAQKETV2,
**jaqket_v2.construct_tasks(),
"mgsm": mgsm.MGSM,
**mgsm.construct_tasks(),
}


def register_tasks():
    """Automatically register subclasses of Task.

    Currently this is only guaranteed to work for Japanese tasks. Ideally it
    would be updated to handle legacy tasks and avoid manual registration.

    Every direct and indirect subclass of ``lm_eval.base.Task`` is visited.
    A class is registered only when the second-to-last component of its
    module's dotted name is ``"ja"``; the last component is used as the task
    name. ``TASK_REGISTRY`` is updated in place:

    * ``name`` is added only if it is not already a key, so the first class
      visited for a module wins the plain name.
    * ``f"{name}-{VERSION}-{PROMPT_VERSION}"`` is set for every class that
      has a ``PROMPT_VERSION`` attribute.

    Classes whose module cannot be resolved, or whose module name has no
    package component, are skipped rather than raising.
    """
    pending = list(lm_eval.base.Task.__subclasses__())
    while pending:
        # Pop from the end: this traversal order decides which class is seen
        # first for a given module, and therefore which one gets the plain name.
        cls = pending.pop()
        # Queue the subclasses so the whole hierarchy is walked.
        pending.extend(cls.__subclasses__())

        # The task short name is derived from the defining module.
        mod = inspect.getmodule(cls)
        if mod is None:
            # inspect.getmodule() returns None when the module cannot be
            # determined; such a class cannot be named, so skip it.
            continue

        parts = mod.__name__.split(".")
        # XXX skip non-japanese modules. A top-level module (no dot in its
        # name) has no parent package, so it cannot be a "ja" module either.
        if len(parts) < 2 or parts[-2] != "ja":
            continue

        name = parts[-1]
        # Only the first one gets added as a plain name.
        if name not in TASK_REGISTRY:
            TASK_REGISTRY[name] = cls

        if hasattr(cls, "PROMPT_VERSION"):
            # Note that anything with a prompt version has a VERSION.
            key = f"{name}-{cls.VERSION}-{cls.PROMPT_VERSION}"
            TASK_REGISTRY[key] = cls


# Add the automatically discovered tasks to TASK_REGISTRY before the list of
# task names is computed from it.
register_tasks()

# Sorted task names; iterating a dict yields its keys, so no list() is needed.
ALL_TASKS = sorted(TASK_REGISTRY)


Expand Down
File renamed without changes.
File renamed without changes.

0 comments on commit effdbea

Please sign in to comment.