Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

1128 improve color matching #1166

Draft
wants to merge 9 commits into
base: master
Choose a base branch
from
215 changes: 212 additions & 3 deletions adam/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@
from abc import abstractmethod
import logging
from math import sqrt
from typing import Tuple
from typing import Tuple, Iterable

from typing_extensions import Protocol

from adam.perception.perception_utils import dist
from attr import attrs, attrib
from attr.validators import instance_of

Expand Down Expand Up @@ -106,8 +107,131 @@ def match_score(self, value: float) -> float:
consistently across different matcher types.
"""
standard_deviation = sqrt(self.sample_variance)
return 2.0 * norm.cdf(
self._mean - abs(value - self._mean), loc=self._mean, scale=standard_deviation
return (
1.0
if standard_deviation == 0 and value == self._mean
else 2.0
* norm.cdf(
self._mean - abs(value - self._mean),
loc=self._mean,
scale=standard_deviation,
)
)

@staticmethod
def _calculate_new_values(
mean: float,
sum_of_squared_differences: float,
n_observations: int,
observation: float,
) -> Tuple[float, float]:
new_mean = mean + (observation - mean) / (n_observations + 1)
new_sum_squared = sum_of_squared_differences + (observation - mean) * (
observation - new_mean
)
return new_mean, new_sum_squared

def update_on_observation(self, value: float) -> None:
"""
Update the matcher's distribution to account for the given value.
"""
# With some help from Wikipedia. :)
# https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm
new_mean, new_sum_squared = self._calculate_new_values(
self._mean, self._sum_of_squared_differences, self._n_observations, value
)

self._mean = new_mean
self._sum_of_squared_differences = new_sum_squared
self._n_observations += 1

def merge(self, other: "ContinuousValueMatcher") -> None:
# pylint: disable=protected-access
# Pylint doesn't realize the "client class" whose private members we're accessing is this
# same class
if isinstance(other, GaussianContinuousValueMatcher):
if self._n_observations == 1:
# Treat our own mean as a single observation (because it is) and calculate the new
# mean and sum of squares from the other matcher's perspective.
new_mean, new_sum_squared = self._calculate_new_values(
other._mean,
other._sum_of_squared_differences,
other._n_observations,
self._mean,
)
self._mean = new_mean
self._sum_of_squared_differences = new_sum_squared
self._n_observations += other._n_observations
elif other._n_observations == 1:
self.update_on_observation(other.mean)
else:
raise ValueError(
f"Cannot merge two matchers that both have multiple observations (self with "
f"{self._n_observations} and other with {other._n_observations})."
)
else:
raise ValueError(
f"Cannot merge {type(self)} with matcher of foreign type {type(other)}"
)


@attrs
class GaussianContinuousValueMatcher(ContinuousValueMatcher):
"""
Implements soft value matching where we pretend values come from a Gaussian distribution.
"""

_mean: float = attrib(validator=instance_of(float))
_sum_of_squared_differences: float = attrib(
validator=instance_of(float), init=False, default=0.0
)
"""
Also called M_{2,n} e.g. on the Wikipedia page.
https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm
"""
_n_observations: int = attrib(validator=instance_of(int), init=False, default=1)

@property
def mean(self) -> float:
return self._mean

@property
def sample_variance(self) -> float:
if self._n_observations < 2:
return float("nan")
else:
return self._sum_of_squared_differences / (self._n_observations - 1)

@property
def n_observations(self) -> int:
return self._n_observations

@staticmethod
def from_observation(value) -> "GaussianContinuousValueMatcher":
"""
Return a new Gaussian continuous matcher created from the given single observation.

This exists for clarity more than anything.
"""
return GaussianContinuousValueMatcher(value)

def match_score(self, value: float) -> float:
"""
Return a score representing how closely the given value matches this distribution.

This score should fall into the interval [0, 1] so that learners can threshold scores
consistently across different matcher types.
"""
standard_deviation = sqrt(self.sample_variance)
return (
1.0
if standard_deviation == 0 and value == self._mean
else 2.0
* norm.cdf(
self._mean - abs(value - self._mean),
loc=self._mean,
scale=standard_deviation,
)
)

@staticmethod
Expand Down Expand Up @@ -165,3 +289,88 @@ def merge(self, other: "ContinuousValueMatcher") -> None:
raise ValueError(
f"Cannot merge {type(self)} with matcher of foreign type {type(other)}"
)


@attrs
class MultidimensionalGaussianContinuousValueMatcher(GaussianContinuousValueMatcher):
"""
Extend Gaussian continuous value matcher to >1 dimensional data
"""
_root_coordinates: Iterable[float] = attrib(validator = instance_of(tuple), init=True)
_mean: float = attrib(validator=instance_of(float), init=False, default=0.0)

@staticmethod
def from_observation(point: Iterable[float]) -> "MultidimensionalGaussianContinuousValueMatcher":
"""
Return a new Multidimensional Gaussian continuous matcher created from the given single point.

This exists for clarity more than anything.
"""
return MultidimensionalGaussianContinuousValueMatcher(point)


def match_score(self, point: Iterable[float]) -> float:
"""
Return a score representing how closely the given value matches this distribution.

This score should fall into the interval [0, 1] so that learners can threshold scores
consistently across different matcher types.
"""
standard_deviation = sqrt(self.sample_variance)
value = dist(self._root_coordinates, point)
return (
1.0
if standard_deviation == 0 and value == self._mean
else 0.0 if standard_deviation == 0 else
2.0 * norm.cdf(
self._mean - abs(value - self._mean),
loc=self._mean,
scale=standard_deviation,
)
)

def update_on_observation(self, point: Iterable[float]) -> None:
"""
Update the matcher's distribution to account for the given value.
"""
value = dist(self._root_coordinates, point)
# With some help from Wikipedia. :)
# https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm
new_mean, new_sum_squared = self._calculate_new_values(
self._mean, self._sum_of_squared_differences, self._n_observations, value
)

self._mean = new_mean
self._sum_of_squared_differences = new_sum_squared
self._n_observations += 1

def merge(self, other: "ContinuousValueMatcher") -> None:
# pylint: disable=protected-access
# Pylint doesn't realize the "client class" whose private members we're accessing is this
# same class
if isinstance(other, MultidimensionalGaussianContinuousValueMatcher):
if self._n_observations == 1:
# Treat our own mean as a single observation (because it is) and calculate the new
# mean and sum of squares from the other matcher's perspective.
new_mean, new_sum_squared = self._calculate_new_values(
other._mean,
other._sum_of_squared_differences,
other._n_observations,
dist(self._root_coordinates, other._root_coordinates),
)
self._mean = new_mean
self._sum_of_squared_differences = new_sum_squared
self._n_observations += other._n_observations
self._root_coordinates = other._root_coordinates
elif other._n_observations == 1:
self.update_on_observation(other._root_coordinates)
else:
return
raise ValueError(
f"Cannot merge two matchers that both have multiple observations (self with "
f"{self._n_observations} and other with {other._n_observations})."
)
else:
raise ValueError(
f"Cannot merge {type(self)} with matcher of foreign type {type(other)}"
)
11 changes: 9 additions & 2 deletions adam/curriculum/curriculum_from_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,16 @@
PHASE_3_TRAINING_CURRICULUM_OPTIONS = [
"m4_core",
"m4_stretch",
"m5_objects_v0_with_mugs",
"m5_objects_v0_apples_oranges_bananas"
]

PHASE_3_TESTING_CURRICULUM_OPTIONS = ["m4_core_eval", "m4_stretch_eval"]
PHASE_3_TESTING_CURRICULUM_OPTIONS = [
"m4_core_eval",
"m4_stretch_eval",
"m5_objects_v0_with_mugs_eval",
"m5_objects_v0_apples_oranges_bananas_eval",
]

TRAINING_CUR = "training"
TESTING_CUR = "testing"
Expand Down Expand Up @@ -92,7 +99,7 @@ def phase3_load_from_disk( # pylint: disable=unused-argument
) as situation_description_file:
situation_description = yaml.safe_load(situation_description_file)
language_tuple = tuple(situation_description["language"].split(" "))
feature_yamls = sorted(situation_dir.glob("feature_*"))
feature_yamls = sorted(situation_dir.glob("feature*"))
situation = SimulationSituation(
language=language_tuple,
scene_images_png=sorted(situation_dir.glob("rgb_*")),
Expand Down
19 changes: 18 additions & 1 deletion adam/perception/perception_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@
CategoricalNode,
ContinuousNode,
RgbColorNode,
CielabColorNode,
GraphNode,
ObjectStroke,
StrokeGNNRecognitionNode,
Expand All @@ -110,6 +111,7 @@
CategoricalPredicate,
ContinuousPredicate,
RgbColorPredicate,
CielabColorPredicate,
ObjectStrokePredicate,
StrokeGNNRecognitionPredicate,
DistributionalContinuousPredicate,
Expand Down Expand Up @@ -585,6 +587,8 @@ def _to_dot_node(
label = f"axis:{unwrapped_perception_node.debug_name}"
elif isinstance(unwrapped_perception_node, RgbColorPerception):
label = unwrapped_perception_node.hex
elif isinstance(unwrapped_perception_node, RgbColorPerception):
label = unwrapped_perception_node.hex
elif isinstance(unwrapped_perception_node, OntologyNode):
label = unwrapped_perception_node.handle
elif isinstance(unwrapped_perception_node, Geon):
Expand All @@ -605,7 +609,13 @@ def _to_dot_node(
label = f"Stroke: [{', '.join(str(point) for point in unwrapped_perception_node.normalized_coordinates)}]"
elif isinstance(
unwrapped_perception_node,
(ContinuousNode, CategoricalNode, RgbColorNode, StrokeGNNRecognitionNode),
(
ContinuousNode,
CategoricalNode,
RgbColorNode,
CielabColorNode,
StrokeGNNRecognitionNode,
),
):
label = str(unwrapped_perception_node)
else:
Expand Down Expand Up @@ -1058,6 +1068,10 @@ def map_node(node: Any) -> "NodePredicate":
perception_node_to_pattern_node[key] = RgbColorPredicate.from_node(
node
)
elif isinstance(node, CielabColorNode):
perception_node_to_pattern_node[key] = CielabColorPredicate.from_node(
node, min_match_score=min_continuous_feature_match_score
)
elif isinstance(node, ObjectStroke):
perception_node_to_pattern_node[
key
Expand Down Expand Up @@ -2279,10 +2293,12 @@ def _translate_region(
DistributionalContinuousPredicate,
ContinuousPredicate,
RgbColorPredicate,
CielabColorPredicate,
ObjectStroke,
CategoricalNode,
ContinuousNode,
RgbColorNode,
CielabColorNode,
# Paths are rare, match them next
IsPathPredicate,
PathOperatorPredicate,
Expand Down Expand Up @@ -2315,6 +2331,7 @@ def _pattern_matching_node_order(node_node_data_tuple) -> int:
CategoricalNode,
ContinuousNode,
RgbColorNode,
CielabColorNode,
SpatialPath,
PathOperator,
OntologyNode,
Expand Down
50 changes: 50 additions & 0 deletions adam/perception/perception_graph_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

from attr import attrs, attrib
from attr.validators import instance_of, in_, optional, deep_iterable
from colormath.color_conversions import convert_color
from colormath.color_objects import sRGBColor, LabColor
from immutablecollections import ImmutableSet
from immutablecollections.converter_utils import _to_immutableset

Expand Down Expand Up @@ -126,6 +128,54 @@ def __str__(self) -> str:
return f"#{hex(self.red)[2:]}{hex(self.green)[2:]}{hex(self.blue)[2:]}"


@attrs(frozen=True, slots=True, eq=False)
class CielabColorNode(GraphNode):
"""A node representing a CIELAB perception value."""

lab_l: float = attrib(validator=instance_of(float))
lab_a: float = attrib(validator=instance_of(float))
lab_b: float = attrib(validator=instance_of(float))

def dot_label(self):
return f"CielabColorNode({self})"

def __str__(self) -> str:
return f"Lab=({self.lab_l:.2f}, {self.lab_a:.2f}, {self.lab_b:.2f})"

def to_tuple(self) -> tuple:
return self.lab_l, self.lab_a, self.lab_b

@staticmethod
def from_rgb(node: RgbColorNode) -> "CielabColorNode":
rgb_color: sRGBColor = sRGBColor(
node.red, node.green, node.blue, is_upscaled=True
)
lab_color: LabColor = convert_color(
color=rgb_color, target_cs=LabColor, target_illuminant="d65"
)
return CielabColorNode(
lab_l=lab_color.lab_l,
lab_a=lab_color.lab_a,
lab_b=lab_color.lab_b,
weight=node.weight,
)

@staticmethod
def from_colors(
red: float, green: float, blue: float, *, weight: float = 1.0
) -> "CielabColorNode":
rgb_color: sRGBColor = sRGBColor(red, green, blue, is_upscaled=True)
lab_color: LabColor = convert_color(
color=rgb_color, target_cs=LabColor, target_illuminant="d65"
)
return CielabColorNode(
lab_l=lab_color.lab_l,
lab_a=lab_color.lab_a,
lab_b=lab_color.lab_b,
weight=weight,
)


@attrs(frozen=True, slots=True, eq=False)
class StrokeGNNRecognitionNode(GraphNode):
"""A property node indicating Stroke GNN object recognition."""
Expand Down
Loading