Skip to content

Commit

Permalink
feat: update the list of supported entrypoint parameter types
Browse files Browse the repository at this point in the history
This update adds support for the "boolean", "integer", "list", and "mapping" entrypoint parameter
types that are already supported by the task execution engine. In addition, the
/workflows/jobFilesDownload endpoint service has been updated to handle all supported types when
creating the parameters.json and task engine YAML files.

The "path" and "uri" types have been removed since they were treated the same as strings.
  • Loading branch information
jkglasbrenner committed Sep 11, 2024
1 parent d85d677 commit ae3ab68
Show file tree
Hide file tree
Showing 9 changed files with 289 additions and 29 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
"""Update entry point parameter types list
Revision ID: 6a75ede23821
Revises: 4b2d781f8bb4
Create Date: 2024-09-10 16:32:40.707231
"""

from typing import Annotated, Optional

import sqlalchemy as sa
from alembic import op
from sqlalchemy.orm import (
DeclarativeBase,
Mapped,
MappedAsDataclass,
mapped_column,
sessionmaker,
)

# revision identifiers, used by Alembic.
revision = "6a75ede23821"
down_revision = "4b2d781f8bb4"
branch_labels = None
depends_on = None


# Upgrade and downgrade inserts and deletes
UPGRADE_INSERTS = ["boolean", "integer", "list", "mapping"]
UPGRADE_DELETES = ["path", "uri"]
DOWNGRADE_INSERTS = ["path", "uri"]


# Migration data models
intpk = Annotated[
int,
mapped_column(sa.BigInteger().with_variant(sa.Integer, "sqlite"), primary_key=True),
]
text_ = Annotated[str, mapped_column(sa.Text())]
bool_ = Annotated[bool, mapped_column(sa.Boolean())]
optionalstr = Annotated[Optional[str], mapped_column(sa.Text(), nullable=True)]


class UpgradeBase(DeclarativeBase, MappedAsDataclass):
pass


class DowngradeBase(DeclarativeBase, MappedAsDataclass):
pass


class EntryPointParameterTypeUpgrade(UpgradeBase):
__tablename__ = "entry_point_parameter_types"

parameter_type: Mapped[text_] = mapped_column(primary_key=True)


class EntryPointParameterUpgrade(UpgradeBase):
__tablename__ = "entry_point_parameters"

entry_point_resource_snapshot_id: Mapped[intpk] = mapped_column(init=False)
parameter_number: Mapped[intpk]
parameter_type: Mapped[text_] = mapped_column(nullable=False)
name: Mapped[text_] = mapped_column(nullable=False)
default_value: Mapped[optionalstr]


class EntryPointParameterTypeDowngrade(DowngradeBase):
__tablename__ = "entry_point_parameter_types"

parameter_type: Mapped[text_] = mapped_column(primary_key=True)


def upgrade():
bind = op.get_bind()
Session = sessionmaker(bind=bind)

with Session() as session:
for parameter_type in UPGRADE_INSERTS:
stmt = sa.select(EntryPointParameterTypeUpgrade).where(
EntryPointParameterTypeUpgrade.parameter_type == parameter_type
)

if session.scalar(stmt) is None:
session.add(
EntryPointParameterTypeUpgrade(parameter_type=parameter_type)
)

# Search for any parameters that are of type "path" or "uri" and convert them to
# "string"
to_string_params_stmt = sa.select(EntryPointParameterUpgrade).where(
EntryPointParameterUpgrade.parameter_type.in_(["path", "uri"])
)

for entry_point_parameter in session.execute(to_string_params_stmt):
entry_point_parameter.parameter_type = "string"

for parameter_type in UPGRADE_DELETES:
stmt = sa.select(EntryPointParameterTypeUpgrade).where(
EntryPointParameterTypeUpgrade.parameter_type == parameter_type
)
entry_point_parameter_type = session.scalar(stmt)

if entry_point_parameter_type is not None:
session.delete(entry_point_parameter_type)

session.commit()


def downgrade():
bind = op.get_bind()
Session = sessionmaker(bind=bind)

with Session() as session:
for parameter_type in DOWNGRADE_INSERTS:
stmt = sa.select(EntryPointParameterTypeDowngrade).where(
EntryPointParameterTypeDowngrade.parameter_type == parameter_type
)

if session.scalar(stmt) is None:
session.add(
EntryPointParameterTypeDowngrade(parameter_type=parameter_type)
)

session.commit()
4 changes: 3 additions & 1 deletion src/dioptra/restapi/v1/entrypoints/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,9 @@ class EntrypointParameterSchema(Schema):
attribute="parameter_type",
metadata=dict(description="Data type of the Entrypoint parameter."),
required=True,
validate=validate.OneOf(["string", "float", "path", "uri"]),
validate=validate.OneOf(
["string", "float", "integer", "boolean", "list", "mapping"]
),
)


Expand Down
27 changes: 8 additions & 19 deletions src/dioptra/restapi/v1/workflows/lib/export_job_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,10 @@

from dioptra.restapi.db import models

from .type_coercions import GlobalParameterType, coerce_to_type

LOGGER: BoundLogger = structlog.stdlib.get_logger()

FLOAT_PARAM_TYPE: Final[str] = "float"
JSON_FILE_ENCODING: Final[str] = "utf-8"


Expand All @@ -47,27 +48,15 @@ def export_job_parameters(
log = logger or LOGGER.new() # noqa: F841
job_params_json_path = Path(base_dir, "parameters").with_suffix(".json")

job_parameters: dict[str, str | int | float | None] = {}
job_parameters: dict[str, GlobalParameterType] = {}
for param_value in job_param_values:
if param_value.parameter.parameter_type == FLOAT_PARAM_TYPE:
job_parameters[param_value.parameter.name] = _convert_to_number(
param_value.value
)

else:
job_parameters[param_value.parameter.name] = param_value.value
value = coerce_to_type(
x=param_value.value,
type_name=param_value.parameter.parameter_type,
)
job_parameters[param_value.parameter.name] = value

with job_params_json_path.open("wt", encoding=JSON_FILE_ENCODING) as f:
json.dump(job_parameters, f, indent=2)

return job_params_json_path


def _convert_to_number(number: str | None) -> int | float | None:
if number is None:
return None

if number.isnumeric():
return int(number)

return float(number)
40 changes: 39 additions & 1 deletion src/dioptra/restapi/v1/workflows/lib/export_task_engine_yaml.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,22 @@
from dioptra.restapi.db import models
from dioptra.task_engine.type_registry import BUILTIN_TYPES

from .type_coercions import (
BOOLEAN_PARAM_TYPE,
FLOAT_PARAM_TYPE,
INTEGER_PARAM_TYPE,
STRING_PARAM_TYPE,
coerce_to_type,
)

LOGGER: BoundLogger = structlog.stdlib.get_logger()

EXPLICIT_GLOBAL_TYPES: Final[set[str]] = {
STRING_PARAM_TYPE,
BOOLEAN_PARAM_TYPE,
INTEGER_PARAM_TYPE,
FLOAT_PARAM_TYPE,
}
YAML_FILE_ENCODING: Final[str] = "utf-8"
YAML_EXPORT_SETTINGS: Final[dict[str, Any]] = {
"indent": 2,
Expand Down Expand Up @@ -107,7 +121,19 @@ def extract_parameters(
A dictionary of the entrypoint's parameters.
"""
log = logger or LOGGER.new() # noqa: F841
return {param.name: param.default_value for param in entrypoint.parameters}
parameters: dict[str, Any] = {}
for param in entrypoint.parameters:
default_value = param.default_value
parameters[param.name] = {
"default": coerce_to_type(x=default_value, type_name=param.parameter_type)
}

if param.parameter_type in EXPLICIT_GLOBAL_TYPES:
parameters[param.name]["type"] = (
_convert_parameter_type_to_task_engine_type(param.parameter_type)
)

return parameters


def extract_tasks(
Expand Down Expand Up @@ -212,3 +238,15 @@ def _build_task_outputs(
{output_param.name: output_param.parameter_type.name}
for output_param in output_parameters
]


def _convert_parameter_type_to_task_engine_type(parameter_type: str) -> Any:
conversion_map = {
"boolean": "boolean",
"string": "string",
"float": "number",
"integer": "integer",
"list": {"list": "any"},
"mapping": {"mapping": ["string", "any"]},
}
return conversion_map[parameter_type]
90 changes: 90 additions & 0 deletions src/dioptra/restapi/v1/workflows/lib/type_coercions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# This Software (Dioptra) is being made available as a public service by the
# National Institute of Standards and Technology (NIST), an Agency of the United
# States Department of Commerce. This software was developed in part by employees of
# NIST and in part by NIST contractors. Copyright in portions of this software that
# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant
# to Title 17 United States Code Section 105, works of NIST employees are not
# subject to copyright protection in the United States. However, NIST may hold
# international copyright in software created by its employees and domestic
# copyright (or licensing rights) in portions of software that were assigned or
# licensed to NIST. To the extent that NIST holds copyright in this software, it is
# being made available under the Creative Commons Attribution 4.0 International
# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts
# of the software developed or licensed by NIST.
#
# ACCESS THE FULL CC BY 4.0 LICENSE HERE:
# https://creativecommons.org/licenses/by/4.0/legalcode
import json
from typing import Any, Final, cast

JsonType = dict[str, Any] | list[Any]
GlobalParameterType = str | float | int | bool | dict[str, Any] | list[Any] | None


BOOLEAN_PARAM_TYPE: Final[str] = "boolean"
FLOAT_PARAM_TYPE: Final[str] = "float"
INTEGER_PARAM_TYPE: Final[str] = "integer"
LIST_PARAM_TYPE: Final[str] = "list"
MAPPING_PARAM_TYPE: Final[str] = "mapping"
STRING_PARAM_TYPE: Final[str] = "string"


def coerce_to_type(x: str | None, type_name: str) -> GlobalParameterType:
coerce_fn_registry = {
BOOLEAN_PARAM_TYPE: to_boolean_type,
FLOAT_PARAM_TYPE: to_float_type,
INTEGER_PARAM_TYPE: to_integer_type,
LIST_PARAM_TYPE: to_list_type,
MAPPING_PARAM_TYPE: to_mapping_type,
STRING_PARAM_TYPE: to_string_type,
}

if type_name not in coerce_fn_registry:
raise ValueError(f"Invalid parameter type: {type_name}.")

if x is None:
return None

coerce_fn = coerce_fn_registry[type_name]
return cast(GlobalParameterType, coerce_fn(x))


def to_string_type(x: str) -> str:
return x


def to_boolean_type(x: str) -> bool:
if x.lower() not in {"true", "false"}:
raise ValueError(f"Not a boolean: {x}")

return x.lower() == "true"


def to_float_type(x: str) -> float:
# TODO: Handle coercion failures
return float(x)


def to_integer_type(x: str) -> int:
# TODO: Handle coercion failures
return int(x)


def to_list_type(x: str) -> list[Any]:
# TODO: Handle coercion failures
x_coerced = cast(JsonType, json.loads(x))

if not isinstance(x_coerced, list):
raise ValueError(f"Not a list: {x}")

return x_coerced


def to_mapping_type(x: str) -> dict[str, Any]:
# TODO: Handle coercion failures
x_coerced = cast(JsonType, json.loads(x))

if not isinstance(x_coerced, dict):
raise ValueError(f"Not a mapping: {x}")

return x_coerced
8 changes: 5 additions & 3 deletions src/frontend/src/dialogs/EditParamDialog.vue
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,10 @@
const typeOptions = reactive([
'string',
'float',
'path',
'url',
'integer',
'boolean',
'list',
'mapping',
])
watch(showDialog, (newVal) => {
Expand All @@ -89,4 +91,4 @@
</script>
</script>
6 changes: 4 additions & 2 deletions src/frontend/src/views/CreateEntryPoint.vue
Original file line number Diff line number Diff line change
Expand Up @@ -246,8 +246,10 @@
const typeOptions = ref([
'string',
'float',
'path',
'uri',
'integer',
'boolean',
'list',
'mapping',
])
const basicInfoForm = ref(null)
Expand Down
6 changes: 4 additions & 2 deletions tests/unit/restapi/lib/db/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,10 @@
ENTRY_POINT_PARAMETER_TYPES: Final[list[dict[str, str]]] = [
{"parameter_type": "string"},
{"parameter_type": "float"},
{"parameter_type": "path"},
{"parameter_type": "uri"},
{"parameter_type": "integer"},
{"parameter_type": "boolean"},
{"parameter_type": "list"},
{"parameter_type": "mapping"},
]
JOB_STATUS_TYPES: Final[list[dict[str, str]]] = [
{"status": "queued"},
Expand Down
12 changes: 11 additions & 1 deletion tests/unit/restapi/v1/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -578,7 +578,17 @@ def registered_entrypoints(
{
"name": "entrypoint_param_3",
"defaultValue": "/path",
"parameterType": "path",
"parameterType": "string",
},
{
"name": "entrypoint_param_4",
"defaultValue": "1",
"parameterType": "integer",
},
{
"name": "entrypoint_param_5",
"defaultValue": "['a', 'b', {'c': 1}]",
"parameterType": "list",
},
]
plugin_ids = [registered_plugin_with_files["plugin"]["id"]]
Expand Down

0 comments on commit ae3ab68

Please sign in to comment.