Skip to content

Commit

Permalink
Rename AbstractNodeInput to AbstractSchemaNode (#72)
Browse files Browse the repository at this point in the history
Rename the base node.
  • Loading branch information
eyurtsev authored Mar 19, 2023
1 parent 4c0ca0e commit 54572da
Show file tree
Hide file tree
Showing 10 changed files with 57 additions and 47 deletions.
6 changes: 3 additions & 3 deletions kor/encoders/csv_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,12 @@
import pandas as pd

from kor.encoders.typedefs import Encoder
from kor.nodes import AbstractInput, Object
from kor.nodes import AbstractSchemaNode, Object

DELIMITER = "|"


def _extract_top_level_fieldnames(node: AbstractInput) -> List[str]:
def _extract_top_level_fieldnames(node: AbstractSchemaNode) -> List[str]:
"""Temporary schema description for CSV extraction."""
if isinstance(node, Object):
return [attributes.id for attributes in node.attributes]
Expand All @@ -42,7 +42,7 @@ def _get_table_content(string: str) -> Optional[str]:
class CSVEncoder(Encoder):
"""CSV encoder."""

def __init__(self, node: AbstractInput) -> None:
def __init__(self, node: AbstractSchemaNode) -> None:
"""Attach node to the encoder to allow the encoder to understand schema."""
super().__init__(node)

Expand Down
4 changes: 2 additions & 2 deletions kor/encoders/typedefs.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@
import abc
from typing import Any

from kor.nodes import AbstractInput
from kor.nodes import AbstractSchemaNode


class Encoder(abc.ABC):
def __init__(self, node: AbstractInput) -> None:
def __init__(self, node: AbstractSchemaNode) -> None:
"""Attach node to the encoder to allow the encoder to understand schema."""
self.node = node

Expand Down
14 changes: 7 additions & 7 deletions kor/examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@
from typing import Any, List, Tuple

from kor.nodes import (
AbstractInput,
AbstractSchemaNode,
AbstractVisitor,
ExtractionInput,
ExtractionSchemaNode,
Object,
Option,
Selection,
Expand All @@ -29,7 +29,7 @@ def visit_option(self, node: "Option") -> List[Tuple[str, str]]:
raise AssertionError("Should never visit an Option node.")

@staticmethod
def _assemble_output(node: AbstractInput, data: Any) -> Any:
def _assemble_output(node: AbstractSchemaNode, data: Any) -> Any:
"""Assemble the output data according to the type of the node."""
if not data:
return {}
Expand Down Expand Up @@ -79,9 +79,9 @@ def visit_selection(self, node: "Selection") -> List[Tuple[str, str]]:
examples.append((null_example, ""))
return examples

def visit_default(self, node: "AbstractInput") -> List[Tuple[str, str]]:
def visit_default(self, node: "AbstractSchemaNode") -> List[Tuple[str, str]]:
"""Default visitor implementation."""
if not isinstance(node, ExtractionInput):
if not isinstance(node, ExtractionSchemaNode):
raise AssertionError()
examples = []

Expand All @@ -90,15 +90,15 @@ def visit_default(self, node: "AbstractInput") -> List[Tuple[str, str]]:
examples.append((text, value))
return examples

def visit(self, node: "AbstractInput") -> List[Tuple[str, str]]:
def visit(self, node: "AbstractSchemaNode") -> List[Tuple[str, str]]:
"""Entry-point."""
return node.accept(self)


# PUBLIC API


def generate_examples(node: AbstractInput) -> List[Tuple[str, str]]:
def generate_examples(node: AbstractSchemaNode) -> List[Tuple[str, str]]:
"""Generate examples for a given element.
A rudimentary implementation that simply concatenates all available examples
Expand Down
27 changes: 14 additions & 13 deletions kor/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,21 +39,22 @@ def visit_option(self, node: "Option") -> T:
"""Visit option node."""
return self.visit_default(node)

def visit_default(self, node: "AbstractInput") -> T:
def visit_default(self, node: "AbstractSchemaNode") -> T:
"""Default node implementation."""
raise NotImplementedError()


class AbstractInput(abc.ABC):
"""Abstract input node.
class AbstractSchemaNode(abc.ABC):
"""Abstract schema node.
Each input is expected to have a unique ID, and should
Each node is expected to have a unique ID, and should
only use alphanumeric characters.
The ID should be unique across all inputs that belong
to a given form.
The description should describe what the input is about.
The description should describe what the node represents.
It is used during prompt generation.
"""

__slots__ = "id", "description", "many"
Expand All @@ -79,7 +80,7 @@ def replace(
self,
id: Optional[str] = None, # pylint: disable=redefined-builtin
description: Optional[str] = None,
) -> "AbstractInput":
) -> "AbstractSchemaNode":
"""Wrapper around data-classes replace."""
new_object = copy.copy(self)
if id:
Expand All @@ -89,7 +90,7 @@ def replace(
return new_object


class ExtractionInput(AbstractInput, abc.ABC):
class ExtractionSchemaNode(AbstractSchemaNode, abc.ABC):
"""An abstract definition for inputs that involve extraction.
An extraction input can be associated with extraction examples.
Expand Down Expand Up @@ -119,23 +120,23 @@ def __init__(
self.examples = examples


class Number(ExtractionInput):
class Number(ExtractionSchemaNode):
"""Built-in number input."""

def accept(self, visitor: AbstractVisitor[T]) -> T:
"""Accept a visitor."""
return visitor.visit_number(self)


class Text(ExtractionInput):
class Text(ExtractionSchemaNode):
"""Built-in text input."""

def accept(self, visitor: AbstractVisitor[T]) -> T:
"""Accept a visitor."""
return visitor.visit_text(self)


class Option(AbstractInput):
class Option(AbstractSchemaNode):
"""Built-in option input must be part of a selection input."""

__slots__ = ("examples",)
Expand All @@ -157,7 +158,7 @@ def accept(self, visitor: AbstractVisitor[T]) -> T:
return visitor.visit_option(self)


class Selection(AbstractInput):
class Selection(AbstractSchemaNode):
"""Built-in selection input.
A selection input is composed of one or more options.
Expand Down Expand Up @@ -193,7 +194,7 @@ def accept(self, visitor: AbstractVisitor[T]) -> T:
return visitor.visit_selection(self)


class Object(AbstractInput):
class Object(AbstractSchemaNode):
"""A definition for an object extraction.
An extraction input can be associated with 2 different types of examples:
Expand Down Expand Up @@ -230,7 +231,7 @@ def __init__(
id: str,
description: str = "",
many: bool = True,
attributes: Sequence[Union[ExtractionInput, Selection]],
attributes: Sequence[Union[ExtractionSchemaNode, Selection]],
examples: Sequence[
Tuple[str, Mapping[str, Union[str, Sequence[str]]]]
] = tuple(),
Expand Down
8 changes: 5 additions & 3 deletions kor/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from kor.encoders import Encoder
from kor.encoders.encode import encode_examples
from kor.examples import generate_examples
from kor.nodes import AbstractInput
from kor.nodes import AbstractSchemaNode
from kor.type_descriptors import TypeDescriptor

PromptFormat = Union[Literal["openai-chat"], Literal["string"]]
Expand Down Expand Up @@ -80,12 +80,14 @@ def to_messages(self) -> List[BaseMessage]:
messages.append(HumanMessage(content=self.text))
return messages

def generate_encoded_examples(self, node: AbstractInput) -> List[Tuple[str, str]]:
def generate_encoded_examples(
self, node: AbstractSchemaNode
) -> List[Tuple[str, str]]:
"""Generate encoded examples."""
examples = generate_examples(node)
return encode_examples(examples, self.encoder)

def generate_instruction_segment(self, node: AbstractInput) -> str:
def generate_instruction_segment(self, node: AbstractSchemaNode) -> str:
"""Generate the instruction segment of the extraction."""
type_description = self.type_descriptor.describe(node)
instruction_segment = self.encoder.get_instruction_segment()
Expand Down
19 changes: 13 additions & 6 deletions kor/type_descriptors.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,14 @@
import abc
from typing import List, TypeVar

from kor.nodes import AbstractInput, AbstractVisitor, Number, Object, Selection, Text
from kor.nodes import (
AbstractSchemaNode,
AbstractVisitor,
Number,
Object,
Selection,
Text,
)

T = TypeVar("T")

Expand All @@ -20,7 +27,7 @@ class TypeDescriptor(AbstractVisitor[T], abc.ABC):
"""Interface for type descriptors."""

@abc.abstractmethod
def describe(self, node: AbstractInput) -> str:
def describe(self, node: AbstractSchemaNode) -> str:
"""Take in node and describe its type as a string."""
raise NotImplementedError()

Expand All @@ -32,7 +39,7 @@ def __init__(self) -> None:
self.depth = 0
self.code_lines: List[str] = []

def visit_default(self, node: "AbstractInput") -> None:
def visit_default(self, node: "AbstractSchemaNode") -> None:
"""Default action for a node."""
space = "* " + self.depth * " "
self.code_lines.append(
Expand All @@ -51,7 +58,7 @@ def get_type_description(self) -> str:
"""Get the type."""
return "\n".join(self.code_lines)

def describe(self, node: AbstractInput) -> str:
def describe(self, node: AbstractSchemaNode) -> str:
"""Describe the type of the given node."""
self.code_lines = []
node.accept(self)
Expand All @@ -65,7 +72,7 @@ def __init__(self) -> None:
self.depth = 0
self.code_lines: List[str] = []

def visit_default(self, node: "AbstractInput") -> None:
def visit_default(self, node: "AbstractSchemaNode") -> None:
"""Default action for a node."""
space = self.depth * " "

Expand Down Expand Up @@ -106,7 +113,7 @@ def get_type_description(self) -> str:
"""Get the type."""
return "\n".join(self.code_lines)

def describe(self, node: "AbstractInput") -> str:
def describe(self, node: "AbstractSchemaNode") -> str:
"""Describe the node type in TypeScript notation."""
self.depth = 0
self.code_lines = []
Expand Down
4 changes: 2 additions & 2 deletions tests/test_encoders/test_encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
import pytest

from kor.encoders import JSONEncoder, XMLEncoder
from kor.nodes import AbstractInput, Number, Object, Option, Selection, Text
from kor.nodes import AbstractSchemaNode, Number, Object, Option, Selection, Text


def _get_schema() -> AbstractInput:
def _get_schema() -> AbstractSchemaNode:
"""Make an abstract input node."""
option = Option(id="option", description="Option", examples=["selection"])
number = Number(id="number", description="Number", examples=[("number", "2")])
Expand Down
8 changes: 4 additions & 4 deletions tests/test_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from kor.nodes import AbstractVisitor


class ToyInput(nodes.AbstractInput):
class ToySchemaNode(nodes.AbstractSchemaNode):
"""Toy input for tests."""

def accept(self, visitor: AbstractVisitor) -> Any:
Expand All @@ -16,19 +16,19 @@ def accept(self, visitor: AbstractVisitor) -> Any:
@pytest.mark.parametrize("invalid_id", ["", "@@#", " ", "NAME", "1name", "name-name"])
def test_invalid_identifier_raises_error(invalid_id: str) -> None:
with pytest.raises(ValueError):
ToyInput(id=invalid_id, description="Toy")
ToySchemaNode(id=invalid_id, description="Toy")


@pytest.mark.parametrize("valid_id", ["name", "name_name", "_name", "n1ame"])
def test_can_instantiate_with_valid_id(valid_id: str) -> None:
"""Can instantiate an abstract input with a valid ID."""
ToyInput(id=valid_id, description="Toy")
ToySchemaNode(id=valid_id, description="Toy")


def test_extraction_input_cannot_be_instantiated() -> None:
"""ExtractionInput is abstract and should not be instantiated."""
with pytest.raises(TypeError):
nodes.ExtractionInput( # type: ignore[abstract]
nodes.ExtractionSchemaNode( # type: ignore[abstract]
id="help",
description="description",
examples=[],
Expand Down
6 changes: 3 additions & 3 deletions tests/test_type_descriptors.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest

from kor import Number, Object, Text
from kor.nodes import AbstractInput, Option, Selection
from kor.nodes import AbstractSchemaNode, Option, Selection
from kor.type_descriptors import BulletPointTypeGenerator, TypeScriptTypeGenerator

OPTION = Option(id="option", description="Option Description", examples=["selection"])
Expand Down Expand Up @@ -63,7 +63,7 @@ def test_no_obvious_crashes() -> None:
),
],
)
def test_bullet_point_descriptions(node: AbstractInput, description: str) -> None:
def test_bullet_point_descriptions(node: AbstractSchemaNode, description: str) -> None:
"""Verify bullet point descriptions."""
assert BulletPointTypeGenerator().describe(node) == description

Expand Down Expand Up @@ -102,6 +102,6 @@ def test_bullet_point_descriptions(node: AbstractInput, description: str) -> Non
),
],
)
def test_typescript_description(node: AbstractInput, description: str) -> None:
def test_typescript_description(node: AbstractSchemaNode, description: str) -> None:
"""Verify typescript descriptions."""
assert TypeScriptTypeGenerator().describe(node) == description
8 changes: 4 additions & 4 deletions tests/test_visitors.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest

from kor.nodes import (
AbstractInput,
AbstractSchemaNode,
AbstractVisitor,
Number,
Object,
Expand All @@ -14,11 +14,11 @@
class TestVisitor(AbstractVisitor[str]):
"""Toy input for tests."""

def visit_default(self, node: AbstractInput) -> str:
def visit_default(self, node: AbstractSchemaNode) -> str:
"""Verify default is invoked"""
return node.id

def visit(self, node: AbstractInput) -> str:
def visit(self, node: AbstractSchemaNode) -> str:
"""Convenience method."""
return node.accept(self)

Expand All @@ -33,6 +33,6 @@ def visit(self, node: AbstractInput) -> str:
Option(id="uid"),
],
)
def test_visit_default_is_invoked(node: AbstractInput) -> None:
def test_visit_default_is_invoked(node: AbstractSchemaNode) -> None:
visitor = TestVisitor()
assert visitor.visit(node) == "uid"

0 comments on commit 54572da

Please sign in to comment.