Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add torch conditionals: IfElse #940

Closed
wants to merge 73 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
73 commits
Select commit Hold shift + click to select a range
3c4f73d
Add IfElse
Jul 17, 2024
bfb97ea
Remove space
Jul 17, 2024
6ad1c5c
Implement Dot and BatchedDot in PyTensor (#878)
HangenYuu Jul 18, 2024
cac9feb
Add `OpFromGraph` wrapper around `alloc_diag` (#915)
jessegrabowski Jul 18, 2024
ad27dc7
Bump actions/upload-artifact from 3 to 4 (#560)
dependabot[bot] Jul 18, 2024
f489cf4
Added rewrite for matrix inv(inv(x)) -> x (#893)
tanish1729 Jul 19, 2024
981688c
Implement `pad` (#748)
jessegrabowski Jul 19, 2024
a601a27
Update away from torch.where
Jul 21, 2024
aab9fae
[pre-commit.ci] pre-commit autoupdate
pre-commit-ci[bot] Jul 22, 2024
739d97d
Removed unused config options
Armavica Jul 19, 2024
b9f2dde
Remove add_experimental_configvars
Armavica Jul 19, 2024
f9f5c5b
Remove default in_c_key and change for cast_policy
Armavica Jul 19, 2024
158a7d0
Fix typo in docstring
Armavica Jul 19, 2024
7a0175a
Simplify _ChangeFlagDecorator
Armavica Jul 19, 2024
d9ed1e2
Fix typo amblibm -> amdlibm
Armavica Jul 19, 2024
9f4b89d
Remove unused ContextsParam
Armavica Jul 19, 2024
d455460
Simplify config.add(linker)
Armavica Jul 19, 2024
367351f
Fixed dead wiki links (#950)
HangenYuu Jul 25, 2024
58fec45
Implement nlinalg Ops in PyTorch (#920)
twaclaw Jul 26, 2024
7fd8cbd
Update for m1
Jul 17, 2024
a5587a7
Add new env file
Jul 21, 2024
a09fa75
Update comment
Jul 21, 2024
d6254af
Update environment-osx-arm64.yml
twiecki Jul 22, 2024
23427a0
[pre-commit.ci] pre-commit autoupdate
pre-commit-ci[bot] Jul 29, 2024
f25a624
Implement Einsum
jessegrabowski Apr 19, 2024
b65d08c
Skip tri test in latest version of JAX
ricardoV94 Aug 4, 2024
da91dc7
Corrected the reference from 'an PyTensor' to 'a PyTensor' in the con…
abhishekshah5486 Aug 5, 2024
0ae3cfe
[pre-commit.ci] pre-commit autoupdate
pre-commit-ci[bot] Aug 5, 2024
48450b0
Fix test to allow for n_outs>1
Aug 9, 2024
fd27b6a
Remove test value
Aug 9, 2024
7fffec6
Pickle error message changed (#966)
twiecki Aug 10, 2024
29183c7
Add building of pyodide universal wheels (#918)
twiecki Aug 10, 2024
4d0103b
Removed types examples and introduced tensor (#968)
Krupakar-Reddy-S Aug 12, 2024
f62401a
maintanance: unpin scipy
ferrine Aug 13, 2024
dd8895d
mypy: fix graph.py
ferrine Aug 14, 2024
a3f0a4e
mypy: fix graph/basic.py
ferrine Aug 14, 2024
79232b2
Implement Dot and BatchedDot in PyTensor (#878)
HangenYuu Jul 18, 2024
143ded6
Add `OpFromGraph` wrapper around `alloc_diag` (#915)
jessegrabowski Jul 18, 2024
8c30780
Bump actions/upload-artifact from 3 to 4 (#560)
dependabot[bot] Jul 18, 2024
297bdd4
Added rewrite for matrix inv(inv(x)) -> x (#893)
tanish1729 Jul 19, 2024
a4e014e
Implement `pad` (#748)
jessegrabowski Jul 19, 2024
8d25c14
[pre-commit.ci] pre-commit autoupdate
pre-commit-ci[bot] Jul 22, 2024
6fcc37c
Removed unused config options
Armavica Jul 19, 2024
39612d1
Remove add_experimental_configvars
Armavica Jul 19, 2024
9571d4f
Remove default in_c_key and change for cast_policy
Armavica Jul 19, 2024
ab4f150
Fix typo in docstring
Armavica Jul 19, 2024
153d209
Simplify _ChangeFlagDecorator
Armavica Jul 19, 2024
3aaf756
Fix typo amblibm -> amdlibm
Armavica Jul 19, 2024
1b2802e
Remove unused ContextsParam
Armavica Jul 19, 2024
9c6748f
Simplify config.add(linker)
Armavica Jul 19, 2024
9973e03
Fixed dead wiki links (#950)
HangenYuu Jul 25, 2024
286c8fc
Implement nlinalg Ops in PyTorch (#920)
twaclaw Jul 26, 2024
70c902b
Update for m1
Jul 17, 2024
bd607f3
Add new env file
Jul 21, 2024
3249ae2
Update comment
Jul 21, 2024
d2ad1ed
Update environment-osx-arm64.yml
twiecki Jul 22, 2024
f11df4a
[pre-commit.ci] pre-commit autoupdate
pre-commit-ci[bot] Jul 29, 2024
a7c099c
Implement Einsum
jessegrabowski Apr 19, 2024
6112f82
Skip tri test in latest version of JAX
ricardoV94 Aug 4, 2024
cd8585d
Corrected the reference from 'an PyTensor' to 'a PyTensor' in the con…
abhishekshah5486 Aug 5, 2024
bd38216
[pre-commit.ci] pre-commit autoupdate
pre-commit-ci[bot] Aug 5, 2024
521b8ca
Pickle error message changed (#966)
twiecki Aug 10, 2024
917cc55
Add building of pyodide universal wheels (#918)
twiecki Aug 10, 2024
e879b0c
Removed types examples and introduced tensor (#968)
Krupakar-Reddy-S Aug 12, 2024
3523d79
maintanance: unpin scipy
ferrine Aug 13, 2024
400323f
mypy: fix graph.py
ferrine Aug 14, 2024
f0214a1
mypy: fix graph/basic.py
ferrine Aug 14, 2024
9f3a938
Add IfElse
Jul 17, 2024
d36d4ce
Remove space
Jul 17, 2024
9adbbe2
Update away from torch.where
Jul 21, 2024
2766457
Fix test to allow for n_outs>1
Aug 9, 2024
ef9277b
Remove test value
Aug 9, 2024
d4aaeaf
Merge branch 'branches' of github.com:Ch0ronomato/pytensor into branches
Aug 15, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 59 additions & 14 deletions .github/workflows/pypi.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,13 @@ jobs:
- name: Build SDist
run: pipx run build --sdist

- uses: actions/upload-artifact@v3
- uses: actions/upload-artifact@v4
with:
name: sdist
path: dist/*.tar.gz

build_wheels:
name: Build ${{ matrix.python-version }} wheels on ${{ matrix.platform }}
name: Build wheels for ${{ matrix.platform }}
runs-on: ${{ matrix.platform }}
strategy:
matrix:
Expand All @@ -51,19 +52,52 @@ jobs:
- name: Build wheels
uses: pypa/[email protected]

- uses: actions/upload-artifact@v3
- uses: actions/upload-artifact@v4
with:
name: wheels-${{ matrix.platform }}
path: ./wheelhouse/*.whl

build_universal_wheel:
name: Build universal wheel for Pyodide
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0

- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: '3.11'

- name: Install dependencies
run: pip install numpy versioneer wheel

- name: Build universal wheel
run: |
PYODIDE=1 python setup.py bdist_wheel --universal

- uses: actions/upload-artifact@v4
with:
name: universal_wheel
path: dist/*.whl

check_dist:
name: Check dist
needs: [make_sdist,build_wheels]
runs-on: ubuntu-22.04
steps:
- uses: actions/download-artifact@v3
- uses: actions/download-artifact@v4
with:
name: artifact
name: sdist
path: dist

- uses: actions/download-artifact@v4
with:
pattern: wheels-*
path: dist
merge-multiple: true

- name: Check SDist
run: |
mkdir -p test-sdist
Expand All @@ -83,12 +117,23 @@ jobs:
runs-on: ubuntu-latest
if: github.event_name == 'release' && github.event.action == 'published'
steps:
- uses: actions/download-artifact@v3
with:
name: artifact
path: dist

- uses: pypa/[email protected]
with:
user: __token__
password: ${{ secrets.pypi_password }}
- uses: actions/download-artifact@v4
with:
name: sdist
path: dist

- uses: actions/download-artifact@v4
with:
pattern: wheels-*
path: dist
merge-multiple: true

- uses: actions/download-artifact@v4
with:
name: universal_wheel
path: dist

- uses: pypa/[email protected]
with:
user: __token__
password: ${{ secrets.pypi_password }}
9 changes: 5 additions & 4 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -187,9 +187,9 @@ jobs:
FLOAT32: ${{ matrix.float32 }}

- name: Upload coverage file
uses: actions/upload-artifact@v3
uses: actions/upload-artifact@v4
with:
name: coverage
name: coverage-${{ steps.matrix-id.outputs.id }}
path: coverage/coverage-${{ steps.matrix-id.outputs.id }}.xml

benchmarks:
Expand Down Expand Up @@ -273,10 +273,11 @@ jobs:
python -m pip install -U coverage>=5.1 coveralls

- name: Download coverage file
uses: actions/download-artifact@v3
uses: actions/download-artifact@v4
with:
name: coverage
pattern: coverage-*
path: coverage
merge-multiple: true

- name: Upload coverage to Codecov
uses: codecov/codecov-action@v4
Expand Down
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ repos:
)$
- id: check-merge-conflict
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.5.2
rev: v0.5.6
hooks:
- id: ruff
args: ["--fix", "--output-format=full"]
Expand Down
2 changes: 1 addition & 1 deletion CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ For issues a minimal working example (MWE) is strongly recommended when relevant
(fixing a typo in the documentation does not require a MWE). For discussions,
MWEs are generally required. All MWEs must be implemented using PyTensor. Please
do not submit MWEs if they are not implemented in PyTensor. In certain cases,
pseudocode may be acceptable, but an PyTensor implementation is always preferable.
pseudocode may be acceptable, but a PyTensor implementation is always preferable.

## Quick links

Expand Down
4 changes: 2 additions & 2 deletions doc/introduction.rst
Original file line number Diff line number Diff line change
Expand Up @@ -157,9 +157,9 @@ to extend PyTensor, please feel free to ask.
install
tutorial/index

.. _LISA: https://mila.umontreal.ca/
.. _LISA: https://mila.quebec/en
.. _Greek mathematician: http://en.wikipedia.org/wiki/Theano_(mathematician)
.. _numpy: http://numpy.scipy.org/
.. _numpy: https://numpy.org/
.. _BLAS: http://en.wikipedia.org/wiki/Basic_Linear_Algebra_Subprograms

.. _sympy: http://www.sympy.org/
Expand Down
45 changes: 1 addition & 44 deletions doc/library/config.rst
Original file line number Diff line number Diff line change
Expand Up @@ -103,14 +103,6 @@ import ``pytensor`` and print the config variable, as in:

String value: either ``'cpu'``

.. attribute:: force_device

Bool value: either ``True`` or ``False``

Default: ``False``

This flag's value cannot be modified during the program execution.

.. attribute:: print_active_device

Bool value: either ``True`` or ``False``
Expand Down Expand Up @@ -139,16 +131,6 @@ import ``pytensor`` and print the config variable, as in:
equal to ``float64`` is created.
This can be used to help find upcasts to ``float64`` in user code.

.. attribute:: deterministic

String value: either ``'default'``, ``'more'``

Default: ``'default'``

If ``more``, sometimes PyTensor will select :class:`Op` implementations that
are more "deterministic", but slower. See the ``dnn.conv.algo*``
flags for more cases.

.. attribute:: allow_gc

Bool value: either ``True`` or ``False``
Expand Down Expand Up @@ -373,7 +355,7 @@ import ``pytensor`` and print the config variable, as in:

When ``True``, ignore the first call to an PyTensor function while profiling.

.. attribute:: config.lib__amblibm
.. attribute:: config.lib__amdlibm

Bool value: either ``True`` or ``False``

Expand Down Expand Up @@ -412,16 +394,6 @@ import ``pytensor`` and print the config variable, as in:
ignore it (i.e. ``'ignore'``).
We suggest never using ``'ignore'`` except during testing.

.. attribute:: assert_no_cpu_op

String value: ``'ignore'`` or ``'warn'`` or ``'raise'`` or ``'pdb'``

Default: ``'ignore'``

If there is a CPU :class:`Op` in the computational graph, depending on its value,
this flag can either raise a warning, an exception or drop into the frame
with ``pdb``.

.. attribute:: on_shape_error

String value: ``'warn'`` or ``'raise'``
Expand Down Expand Up @@ -797,18 +769,3 @@ import ``pytensor`` and print the config variable, as in:
The verbosity level of the meta-rewriter: ``0`` for silent, ``1`` to only
warn when PyTensor cannot meta-rewrite an :class:`Op`, ``2`` for full output (e.g.
timings and the rewrites selected).


.. attribute:: config.metaopt__optimizer_excluding

Default: ``""``

A list of rewrite tags that we don't want included in the meta-rewriter.
Multiple tags are separate by ``':'``.

.. attribute:: config.metaopt__optimizer_including

Default: ``""``

A list of rewriter tags to be included during meta-rewriting.
Multiple tags are separate by ``':'``.
14 changes: 7 additions & 7 deletions doc/links.rst
Original file line number Diff line number Diff line change
Expand Up @@ -39,18 +39,18 @@ This is a sort of memo for developers and would-be developers.

.. _git: http://git-scm.com/
.. _pytest: http://docs.pytest.org/en/latest/
.. _numpy: http://numpy.scipy.org/
.. _numpy: https://numpy.org/
.. _python: http://www.python.org
.. _scipy: http://scipy.org/

.. _autodiff: http://www.autodiff.org
.. _boost.python: http://www.boost.org/doc/libs/1_38_0/libs/python/doc/index.html
.. _boost.python: https://www.boost.org/doc/libs/1_85_0/libs/python/doc/html/index.html
.. _cython: http://www.cython.org/
.. _liboil: http://liboil.freedesktop.org/wiki/
.. _llvm: http://llvm.org/
.. _networkx: http://networkx.lanl.gov/
.. _pypy: http://codespeak.net/pypy/dist/pypy/doc/
.. _networkx: https://networkx.org/
.. _pypy: https://doc.pypy.org/en/latest/
.. _swig: http://www.swig.org/
.. _unpython: http://code.google.com/p/unpython/
.. _pycppad: http://www.seanet.com/~bradbell/pycppad/index.xml
.. _shedskin: http://shed-skin.blogspot.com/
.. _unpython: https://code.google.com/archive/p/unpython/
.. _pycppad: https://github.com/Simple-Robotics/pycppad
.. _shedskin: https://shedskin.github.io/shedskin/
Loading
Loading