Skip to content

Commit

Permalink
Merge pull request #37 from marvinpfoertner/shape-validation
Browse files Browse the repository at this point in the history
Shape validation for `LinearFunction{al,Operator}.__call__(CovarianceFunction)`
  • Loading branch information
marvinpfoertner authored Jun 27, 2023
2 parents 9e49f25 + 2b18168 commit bcb2d12
Show file tree
Hide file tree
Showing 11 changed files with 790 additions and 632 deletions.
18 changes: 0 additions & 18 deletions src/linpde_gp/linfuncops/_identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,21 +40,3 @@ def _(self, f: pn.randprocs.RandomProcess, /) -> pn.randprocs.RandomProcess:
raise ValueError()

return f

@__call__.register(pn.randprocs.covfuncs.CovarianceFunction)
def _(
    self, k: pn.randprocs.covfuncs.CovarianceFunction, /, argnum: int = 0
) -> pn.randprocs.covfuncs.CovarianceFunction:
    """Validate that the identity operator is applicable to argument `argnum`
    of the covariance function and return the covariance function unchanged."""
    if argnum == 0:
        in_shape, out_shape = k.input_shape_0, k.output_shape_0
    elif argnum == 1:
        in_shape, out_shape = k.input_shape_1, k.output_shape_1
    else:
        raise ValueError()

    # The selected argument's shapes must match the domain/codomain of the
    # functions the identity operator acts on.
    if in_shape != self.input_domain_shape:
        raise ValueError()

    if out_shape != self.input_codomain_shape:
        raise ValueError()

    return k
3 changes: 2 additions & 1 deletion src/linpde_gp/randprocs/covfuncs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
from ._parametric import ParametricCovarianceFunction
from ._stack import StackCovarianceFunction
from ._tensor_product import TensorProduct, TensorProductGrid
from ._utils import validate_covfunc_transformation
from ._wendland import WendlandCovarianceFunction, WendlandFunction, WendlandPolynomial
from ._zero import Zero

from . import _linfunctls, linfuncops # isort: skip
from . import linfuncops, linfunctls # isort: skip
24 changes: 0 additions & 24 deletions src/linpde_gp/randprocs/covfuncs/_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,13 @@

import abc
from collections.abc import Callable
import functools
import operator
from typing import Optional

from jax import numpy as jnp
import numpy as np
from probnum.randprocs.covfuncs import CovarianceFunction
from probnum.typing import ArrayLike

from ... import linfuncops

CovarianceFunction.input_size = property(
lambda self: functools.reduce(operator.mul, self.input_shape, 1)
)

CovarianceFunction._batched_sum = ( # pylint: disable=protected-access
lambda self, a, **sum_kwargs: np.sum(
a, axis=tuple(range(-self.input_ndim, 0)), **sum_kwargs
Expand Down Expand Up @@ -155,22 +147,6 @@ def _euclidean_distances_jax(
)


@linfuncops.LinearDifferentialOperator.__call__.register # pylint: disable=no-member
def _(self, k: JaxCovarianceFunctionMixin, /, *, argnum=0):
    """Apply the differential operator to a JAX-backed covariance function.

    First delegates to the generic ``LinearDifferentialOperator.__call__``
    implementation; if that raises ``NotImplementedError``, falls back to
    differentiating the covariance function's JAX callable instead.
    """
    try:
        # Prefer the base-class implementation (may have a closed-form result).
        return super(linfuncops.LinearDifferentialOperator, self).__call__(
            k, argnum=argnum
        )
    except NotImplementedError:
        # Fallback: differentiate `k.jax` and wrap the result in a lambda
        # covariance function. `vectorize=True` because the fallback callable
        # is presumably defined point-wise — TODO confirm against
        # `_jax_fallback`'s contract.
        return JaxLambdaCovarianceFunction(
            self._jax_fallback(  # pylint: disable=protected-access
                k.jax, argnum=argnum
            ),
            input_shape=self.output_domain_shape,
            vectorize=True,
        )


class JaxLambdaCovarianceFunction(JaxCovarianceFunction):
def __init__(
self,
Expand Down
17 changes: 0 additions & 17 deletions src/linpde_gp/randprocs/covfuncs/_jax_arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
)
from probnum.typing import ArrayLike, ScalarLike, ScalarType

from ... import linfuncops
from ._jax import JaxCovarianceFunction, JaxCovarianceFunctionMixin


Expand Down Expand Up @@ -42,13 +41,6 @@ def __rmul__(self, other: ArrayLike) -> JaxCovarianceFunction:
return super().__rmul__(other)


@linfuncops.LinearFunctionOperator.__call__.register # pylint: disable=no-member
def _(
    self, k: JaxScaledCovarianceFunction, /, *, argnum: int = 0
) -> JaxScaledCovarianceFunction:
    """Apply the operator to the unscaled covariance function and re-apply
    the scalar afterwards (the operator is linear)."""
    transformed = self(k.covfunc, argnum=argnum)
    return k.scalar * transformed


class JaxSumCovarianceFunction(JaxCovarianceFunctionMixin, SumCovarianceFunction):
def __init__(self, *summands: JaxCovarianceFunction):
if not all(
Expand All @@ -67,12 +59,3 @@ def _evaluate_jax(self, x0: jnp.ndarray, x1: Optional[jnp.ndarray]) -> jnp.ndarr
operator.add,
(summand.jax(x0, x1) for summand in self.summands),
)


@linfuncops.LinearFunctionOperator.__call__.register # pylint: disable=no-member
def _(
    self, k: JaxSumCovarianceFunction, /, *, argnum: int = 0
) -> JaxSumCovarianceFunction:
    """Apply the operator summand-wise, exploiting its linearity."""
    transformed_summands = [self(summand, argnum=argnum) for summand in k.summands]
    return JaxSumCovarianceFunction(*transformed_summands)
226 changes: 0 additions & 226 deletions src/linpde_gp/randprocs/covfuncs/_linfunctls.py

This file was deleted.

38 changes: 38 additions & 0 deletions src/linpde_gp/randprocs/covfuncs/_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from probnum.randprocs import covfuncs as pn_covfuncs

from linpde_gp import linfuncops, linfunctls


def validate_covfunc_transformation(
    L: linfuncops.LinearFunctionOperator | linfunctls.LinearFunctional,
    covfunc: pn_covfuncs.CovarianceFunction,
    argnum: int,
):
    """Validate that ``L`` can be applied to argument ``argnum`` of ``covfunc``.

    Raises
    ------
    ValueError
        If ``argnum`` is not 0 or 1, or if the input or output shape of the
        selected covariance-function argument does not match the domain or
        codomain shape of the functions ``L`` operates on.
    """
    if argnum not in (0, 1):
        raise ValueError("`argnum` must either be 0 or 1.")

    # The input shape of the selected covariance-function argument must match
    # the input domain shape of the linear transformation.
    if argnum == 0:
        covfunc_input_shape_argnum = covfunc.input_shape_0
    else:
        covfunc_input_shape_argnum = covfunc.input_shape_1

    if covfunc_input_shape_argnum != L.input_domain_shape:
        raise ValueError(
            f"`{L=!r}` can not be applied to `{covfunc=!r}`, since "
            f"`{L.input_domain_shape=}` is not equal to "
            f"`covfunc.input_shape_{argnum}={covfunc_input_shape_argnum}`."
        )

    # Likewise, the output shape of the selected argument must match the input
    # codomain shape of the linear transformation.
    if argnum == 0:
        covfunc_output_shape_argnum = covfunc.output_shape_0
    else:
        covfunc_output_shape_argnum = covfunc.output_shape_1

    if covfunc_output_shape_argnum != L.input_codomain_shape:
        raise ValueError(
            f"`{L=!r}` can not be applied to `{covfunc=!r}`, since "
            f"`{L.input_codomain_shape=}` is not equal to "
            f"`covfunc.output_shape_{argnum}={covfunc_output_shape_argnum}`."
        )
Loading

0 comments on commit bcb2d12

Please sign in to comment.