Skip to content

Commit

Permalink
Merge pull request #38 from timweiland/sum_partial_deriv
Browse files Browse the repository at this point in the history
`PartialDerivative` rework
  • Loading branch information
timweiland authored Jul 12, 2023
2 parents bcb2d12 + d8a151c commit 5036ae9
Show file tree
Hide file tree
Showing 23 changed files with 634 additions and 802 deletions.
9 changes: 4 additions & 5 deletions experiments/0000_poisson_dirichlet_1d.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,7 @@
" ddf_xs=-f_X_pde,\n",
" df_xs=(\n",
" linpde_gp.linfuncops.diffops.PartialDerivative(\n",
" domain_shape=(),\n",
" domain_index=(),\n",
" linpde_gp.linfuncops.diffops.MultiIndex(1)\n",
" )(u)(X_pde).mean\n",
" ),\n",
" color=\"C3\",\n",
Expand Down Expand Up @@ -924,7 +923,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "linpde-gp",
"display_name": "Python 3.10.9 ('rp')",
"language": "python",
"name": "python3"
},
Expand All @@ -938,11 +937,11 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.4"
"version": "3.10.9"
},
"vscode": {
"interpreter": {
"hash": "88aade6ae3c887346ad7959dbc8c013e14bde92b1226dcb94dccc773c12fdf89"
"hash": "e6e8c2dcdbd56f97393d936234f742b50420975831114bbb57672bab77d2e717"
}
}
},
Expand Down
18 changes: 13 additions & 5 deletions src/linpde_gp/linfuncops/_arithmetic.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import functools
import operator
from typing import Generic, TypeVar

import numpy as np
import probnum as pn
Expand Down Expand Up @@ -56,8 +57,11 @@ def __repr__(self) -> str:
return f"{self._scalar} * {self._linfuncop}"


class SumLinearFunctionOperator(LinearFunctionOperator):
def __init__(self, *summands: LinearFunctionOperator) -> None:
T = TypeVar("T", bound=LinearFunctionOperator)


class SumLinearFunctionOperator(LinearFunctionOperator, Generic[T]):
def __init__(self, *summands: T) -> None:
self._summands = tuple(summands)

input_domain_shape = self._summands[0].input_domain_shape
Expand Down Expand Up @@ -88,7 +92,7 @@ def __init__(self, *summands: LinearFunctionOperator) -> None:
)

@property
def summands(self) -> tuple[LinearFunctionOperator, ...]:
def summands(self) -> tuple[T, ...]:
return self._summands

@functools.singledispatchmethod
Expand All @@ -105,8 +109,8 @@ def __repr__(self):
return " + ".join(str(summand) for summand in self._summands)


class CompositeLinearFunctionOperator(LinearFunctionOperator):
def __init__(self, *linfuncops: LinearFunctionOperator) -> None:
class CompositeLinearFunctionOperator(LinearFunctionOperator, Generic[T]):
def __init__(self, *linfuncops: T) -> None:
assert all(
L0.input_shapes == L1.output_shapes
for L0, L1 in zip(linfuncops[:-1], linfuncops[1:])
Expand All @@ -127,6 +131,10 @@ def __call__(self, f, /, **kwargs):
f,
)

@property
def linfuncops(self) -> tuple[T, ...]:
return self._linfuncops

def __repr__(self) -> str:
return " @ ".join(repr(linfuncop) for linfuncop in self._linfuncops)

Expand Down
7 changes: 5 additions & 2 deletions src/linpde_gp/linfuncops/_select_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ class SelectOutput(LinearFunctionOperator):
def __init__(
self,
input_shapes: tuple[ShapeLike, ShapeLike],
idx,
idx: tuple[int, ...] | int,
) -> None:
self._idx = idx

Expand All @@ -23,9 +23,12 @@ def __init__(
)

@property
def idx(self):
def idx(self) -> tuple[int, ...] | int:
return self._idx

@functools.singledispatchmethod
def __call__(self, f, /, **kwargs):
return super().__call__(f, **kwargs)

def __repr__(self) -> str:
return f"SelectOutput(idx={self.idx})"
7 changes: 2 additions & 5 deletions src/linpde_gp/linfuncops/diffops/__init__.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
from ._arithmetic import ScaledLinearDifferentialOperator
from ._coefficients import MultiIndex, PartialDerivativeCoefficients
from ._derivative import Derivative
from ._directional_derivative import (
DirectionalDerivative,
PartialDerivative,
TimeDerivative,
)
from ._directional_derivative import DirectionalDerivative
from ._heat import HeatOperator
from ._laplacian import Laplacian, SpatialLaplacian, WeightedLaplacian
from ._lindiffop import LinearDifferentialOperator
from ._partial_derivative import PartialDerivative, TimeDerivative

# isort: off
from . import _functions
Expand Down
3 changes: 3 additions & 0 deletions src/linpde_gp/linfuncops/diffops/_arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,6 @@ def __rmul__(self, other) -> LinearDifferentialOperator:
@functools.singledispatchmethod
def weak_form(self, test_basis, /):
return self._scalar * self._lindiffop.weak_form(test_basis)

def __repr__(self) -> str:
return f"{self._scalar} * {self._lindiffop}"
15 changes: 15 additions & 0 deletions src/linpde_gp/linfuncops/diffops/_coefficients.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ def from_index(
def order(self) -> int:
return np.sum(self._multi_index)

@functools.cached_property
def is_mixed(self) -> bool:
return np.count_nonzero(self._multi_index) > 1

@property
def array(self) -> np.ndarray:
return self._multi_index
Expand All @@ -54,6 +58,9 @@ def __eq__(self, __o: object) -> bool:
return NotImplemented
return np.all(self.array == __o.array)

def __repr__(self) -> str:
return f"MultiIndex({self._multi_index.tolist()})"


class PartialDerivativeCoefficients(Mapping[ShapeType, Mapping[MultiIndex, float]]):
r"""Partial derivative coefficients of a linear differential operator.
Expand Down Expand Up @@ -116,6 +123,14 @@ def __init__(
def num_entries(self) -> int:
return self._num_entries

@functools.cached_property
def has_mixed(self) -> bool:
return any(
multi_index.is_mixed
for codomain_idx in self._coefficient_dict
for multi_index in self._coefficient_dict[codomain_idx]
)

@property
def input_domain_shape(self) -> ShapeType:
return self._input_domain_shape
Expand Down
35 changes: 4 additions & 31 deletions src/linpde_gp/linfuncops/diffops/_derivative.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
import functools
from typing import Callable

import jax
import probnum as pn

import linpde_gp # pylint: disable=unused-import # for type hints

from ._coefficients import MultiIndex, PartialDerivativeCoefficients
from ._lindiffop import LinearDifferentialOperator
from ._coefficients import MultiIndex
from ._partial_derivative import PartialDerivative


class Derivative(LinearDifferentialOperator):
class Derivative(PartialDerivative):
def __init__(
self,
order: int,
Expand All @@ -19,36 +17,11 @@ def __init__(
raise ValueError(f"Order must be >= 0, but got {order}.")

super().__init__(
coefficients=PartialDerivativeCoefficients(
{(): {MultiIndex(order): 1.0}}, (), ()
),
input_shapes=((), ()),
MultiIndex(order),
)

self._order = order

@property
def order(self) -> int:
return self._order

@functools.singledispatchmethod
def __call__(self, f, /, **kwargs):
if self.order == 0:
return f
return super().__call__(f, **kwargs)

def _jax_fallback(self, f: Callable, /, *, argnum: int = 0, **kwargs) -> Callable:
@jax.jit
def _f_deriv(*args):
def _f_arg(arg):
return f(*args[:argnum], arg, *args[argnum + 1 :])

_, deriv = jax.jvp(_f_arg, (args[argnum],), (1.0,))

return deriv

return _f_deriv

@functools.singledispatchmethod
def weak_form(
self, test_basis: pn.functions.Function, /
Expand Down
46 changes: 1 addition & 45 deletions src/linpde_gp/linfuncops/diffops/_directional_derivative.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import jax
import numpy as np
import probnum as pn
from probnum.typing import ArrayLike, ShapeLike
from probnum.typing import ArrayLike

import linpde_gp # pylint: disable=unused-import # for type hints

Expand Down Expand Up @@ -59,47 +59,3 @@ def weak_form(
self, test_basis: pn.functions.Function, /
) -> "linpde_gp.linfunctls.LinearFunctional":
raise NotImplementedError()


class PartialDerivative(DirectionalDerivative):
def __init__(
self,
domain_shape: ShapeLike,
domain_index,
) -> None:
direction = np.zeros(domain_shape)
direction[domain_index] = 1.0

super().__init__(direction)

@functools.singledispatchmethod
def __call__(self, f, /, **kwargs):
return super().__call__(f, **kwargs)

@functools.singledispatchmethod
def weak_form(
self, test_basis: pn.functions.Function, /
) -> "linpde_gp.linfunctls.LinearFunctional":
raise NotImplementedError()


class TimeDerivative(PartialDerivative):
def __init__(self, domain_shape: ShapeLike) -> None:
domain_shape = pn.utils.as_shape(domain_shape)

assert len(domain_shape) <= 1

super().__init__(
domain_shape,
domain_index=() if domain_shape == () else (0,),
)

@functools.singledispatchmethod
def __call__(self, f, /, **kwargs):
return super().__call__(f, **kwargs)

@functools.singledispatchmethod
def weak_form(
self, test_basis: pn.functions.Function, /
) -> "linpde_gp.linfunctls.LinearFunctional":
raise NotImplementedError()
4 changes: 2 additions & 2 deletions src/linpde_gp/linfuncops/diffops/_functions.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from linpde_gp import functions

from ._derivative import Derivative
from ._directional_derivative import DirectionalDerivative
from ._laplacian import Laplacian, SpatialLaplacian
from ._partial_derivative import PartialDerivative


@Derivative.__call__.register # pylint: disable=no-member
@DirectionalDerivative.__call__.register # pylint: disable=no-member
@Laplacian.__call__.register # pylint: disable=no-member
@PartialDerivative.__call__.register # pylint: disable=no-member
@SpatialLaplacian.__call__.register # pylint: disable=no-member
def _(self, f: functions.Constant, /) -> functions.Zero:
assert f.input_shape == self.input_domain_shape
Expand Down
2 changes: 1 addition & 1 deletion src/linpde_gp/linfuncops/diffops/_heat.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
from probnum.typing import FloatLike, ShapeLike

from .._arithmetic import SumLinearFunctionOperator
from ._directional_derivative import TimeDerivative
from ._laplacian import WeightedLaplacian
from ._partial_derivative import TimeDerivative


class HeatOperator(SumLinearFunctionOperator):
Expand Down
Loading

0 comments on commit 5036ae9

Please sign in to comment.