
Adding conditionals for torch #939

Open
3 of 5 tasks
Ch0ronomato opened this issue Jul 17, 2024 · 3 comments

Comments

@Ch0ronomato
Contributor

Ch0ronomato commented Jul 17, 2024

Description

Add the branching ops
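
For orientation, here is a minimal sketch of the kind of branching dispatch this asks for, assuming an IfElse-style op whose inputs are the condition followed by the true-branch values and then the false-branch values. The helper name ifelse_sketch and its n_outs argument are illustrative, not PyTensor or PyTorch API:

import torch


def ifelse_sketch(condition, *branch_values, n_outs=1):
    """Pick the true-branch or false-branch values based on a scalar condition."""
    # condition is expected to be a scalar boolean tensor (or a Python bool)
    if bool(condition):
        return branch_values[:n_outs]
    return branch_values[n_outs:]


# Example: one output per branch
cond = torch.tensor(True)
(out,) = ifelse_sketch(cond, torch.tensor(1.0), torch.tensor(2.0), n_outs=1)
assert out.item() == 1.0

Evaluating the condition eagerly picks a single branch; a torch.where-based lowering would instead compute both branches elementwise.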

@Ch0ronomato
Contributor Author

Ch0ronomato commented Jul 25, 2024

Hey @ricardoV94, could I get some clarity on ScalarLoop? I was under the impression that it might just work (I don't see any explicit tests for Numba or JAX). What work is needed for ScalarLoop? Here is an example test I wrote, which may also be invalid:

import numpy as np

from pytensor import function
from pytensor.scalar import float64, int64
from pytensor.scalar.loop import ScalarLoop


def test_ScalarOp():
    n_steps = int64("n_steps")
    x0 = float64("x0")
    const = float64("const")
    x = x0 + const

    op = ScalarLoop(init=[x0], constant=[const], update=[x])
    out = op(n_steps, x0, const)

    # pytorch_mode: the PyTorch-linker Mode used alongside the backend tests
    fn = function([n_steps, x0, const], out, mode=pytorch_mode)
    np.testing.assert_allclose(fn(5, 0, 1), 5)
    np.testing.assert_allclose(fn(5, 0, 2), 10)
    np.testing.assert_allclose(fn(4, 3, -1), -1)

The test fails with:
op = ScalarLoop(), node = ScalarLoop(n_steps, x0, const)
kwargs = {'input_storage': [[None], [None], [None]], 'output_storage': [[None]], 'storage_map': {ScalarLoop.0: [None], const: [None], x0: [None], n_steps: [None]}}
nfunc_spec = None

    @pytorch_funcify.register(ScalarOp)
    def pytorch_funcify_ScalarOp(op, node, **kwargs):
        """Return pytorch function that implements the same computation as the Scalar Op.
    
        This dispatch is expected to return a pytorch function that works on Array inputs as Elemwise does,
        even though it's dispatched on the Scalar Op.
        """
    
        nfunc_spec = getattr(op, "nfunc_spec", None)
        if nfunc_spec is None:
>           raise NotImplementedError(f"Dispatch not implemented for Scalar Op {op}")
E           NotImplementedError: Dispatch not implemented for Scalar Op ScalarLoop

pytensor/link/pytorch/dispatch/scalar.py:19: NotImplementedError

@ricardoV94
Member

You haven't seen JAX/Numba code because ScalarLoop isn't yet supported in those backends either.

I suggest checking the perform method to get an idea of how the Op works.
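
For intuition, here is a rough sketch of what that perform-style lowering could look like in the PyTorch backend, assuming the inner update graph has already been compiled to a callable inner_fn mapping the carried values plus the constants to the new carried values. scalar_loop_sketch is an illustrative helper, not the actual dispatch:

import torch


def scalar_loop_sketch(inner_fn, n_steps, init, constants):
    """Run inner_fn for n_steps iterations, threading the carried state through."""
    carry = list(init)
    for _ in range(int(n_steps)):
        carry = list(inner_fn(*carry, *constants))
    return carry[0] if len(carry) == 1 else carry


# The update x -> x + const from the test above, run for 5 steps from x0 = 0
result = scalar_loop_sketch(
    lambda x, const: (x + const,),
    n_steps=5,
    init=[torch.tensor(0.0)],
    constants=[torch.tensor(1.0)],
)
assert result.item() == 5.0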

@ricardoV94
Member

For Blockwise you should be able to apply vmap repeatedly, once for each batch dimension. If PyTorch had an equivalent to np.vectorize, that would be all we need.
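
A small sketch of that idea, assuming a core callable and a known number of leading batch dimensions (vectorize_over_batch_dims is an illustrative helper, not an existing API): wrap the core function in one torch.vmap call per batch dimension to get np.vectorize-like behaviour.

import torch


def vectorize_over_batch_dims(core_fn, n_batch_dims):
    """Apply torch.vmap once per leading batch dimension of the inputs."""
    fn = core_fn
    for _ in range(n_batch_dims):
        fn = torch.vmap(fn)
    return fn


# Example: a core matrix-vector product mapped over one batch dimension
batched_mv = vectorize_over_batch_dims(torch.mv, n_batch_dims=1)
A = torch.randn(4, 3, 3)
v = torch.randn(4, 3)
out = batched_mv(A, v)  # shape (4, 3)

Each vmap strips one leading batch dimension, so the core function only ever sees unbatched (core-dimensional) inputs.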
