Check class init arguments: test with checkify? #21943
Unanswered
jecampagne
asked this question in
Q&A
Replies: 2 comments 19 replies
A small bump in case someone is interested. Thanks.
So here is the solution, thanks to Equinox:

    import jax
    import equinox as eqx
    from jax.tree_util import register_pytree_node_class

    @register_pytree_node_class
    class Params:
        def __init__(self, a, b):
            jax.debug.print("__init__ Params(a={}, b={})", a, b)
            # error_if returns `a` unchanged; the returned value must be used
            # so that the check stays attached to the computation.
            a = eqx.error_if(a, a < 0, "a must be >= 0")
            self._a = a
            self._b = b

        def __repr__(self):
            return f"Params(a={self._a}, b={self._b})"

        @property
        def a(self):
            return self._a

        @property
        def b(self):
            return self._b

        def tree_flatten(self):
            children = (self._a, self._b)
            aux_data = None
            jax.debug.print("tree_flatten...")
            return (children, aux_data)

        @classmethod
        def tree_unflatten(cls, aux_data, children):
            jax.debug.print("tree_unflatten...")
            # Rebuilding through __init__ re-runs the check on unflatten.
            return cls(*children)

Here are the prints and the traceback:
Hi, I have a long-standing problem: how to validate the arguments of a class `__init__` for PyTrees.
Here is a use case where Params has two arguments (a, b) and I would like to check that a >= 0.
See below that checkify does not detect the error in the second object instantiation.
Do you have an idea how to do such a check? (N.B. for the moment we leave to the user the responsibility of instantiating with the right arguments.)
Thanks.
and I want to pass a jit test:
I get this result, without any error detected:
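For reference, the usual checkify pattern on a plain function looks roughly like the sketch below. This is not the code from the post: `safe_sum` is a hypothetical name, and the check sits in an ordinary function rather than in a PyTree `__init__`, so it does not by itself reproduce or resolve the situation described above.

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify


def safe_sum(a, b):
    # Hypothetical function: a user-level check that a is non-negative.
    checkify.check(a >= 0, "a must be >= 0")
    return jnp.sqrt(a) + b


# checkify functionalizes the check; the wrapped function returns (error, output).
checked = jax.jit(checkify.checkify(safe_sum))

err_ok, out_ok = checked(jnp.float32(4.0), jnp.float32(1.0))
err_bad, out_bad = checked(jnp.float32(-1.0), jnp.float32(1.0))

msg_ok = err_ok.get()    # None when no check failed
msg_bad = err_bad.get()  # error message string when a check failed
# err_bad.throw() would raise the error instead of returning a message.
```

Note that the error is carried in the returned `err` object; nothing is raised unless `err.throw()` is called or the message is inspected.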