Skip to content

Commit

Permalink
Add missing covfunc transform validations
Browse files Browse the repository at this point in the history
  • Loading branch information
timweiland committed Jul 3, 2023
1 parent 3c883e7 commit 76f5424
Showing 1 changed file with 7 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ def _(
*,
argnum: int = 0,
):
validate_covfunc_transformation(self, k, argnum)

D0 = (
self
if argnum == 0
Expand All @@ -61,6 +63,8 @@ def _(
*,
argnum: int = 0,
):
validate_covfunc_transformation(self, k, argnum)

if argnum == 0 and k.L0.order == 0:
D0 = self
D1 = k.L1
Expand All @@ -79,6 +83,8 @@ def _(

@diffops.JaxPartialDerivative.__call__.register # pylint: disable=no-member
def _(self, k: covfuncs.JaxCovarianceFunction, /, *, argnum=0):
validate_covfunc_transformation(self, k, argnum)

return covfuncs.JaxLambdaCovarianceFunction(
self._derive(k.jax, argnum=argnum), # pylint: disable=protected-access
input_shape=self.output_domain_shape,
Expand All @@ -95,6 +101,7 @@ def _partial_derivative_fallback(
D: diffops.PartialDerivative, k: pn_covfuncs.CovarianceFunction, argnum: int = 0
):
validate_covfunc_transformation(D, k, argnum)

if D.order == 0:
return k
if int(np.prod(D.input_domain_shape)) == 1:
Expand Down

0 comments on commit 76f5424

Please sign in to comment.