Make all zips explicitly strict or non-strict #850

Open
wants to merge 12 commits into
base: main

Conversation

Armavica
Member

@Armavica Armavica commented Jun 24, 2024

Description

  • First commit: add strict=True to every zip where doing so causes no failures in the test suite (464 of them), and strict=False to the others (28 of them)
  • Second commit: enable the ruff rule requiring an explicit strict argument on all zips
  • Rest of the commits: turn the non-strict zips into strict zips (18 of them for now)

There remain 10 non-strict zips that I find difficult to understand.

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

@Armavica Armavica force-pushed the zip-strict branch 2 times, most recently from ada9880 to a5de1b6 Compare June 26, 2024 22:40
@ricardoV94
Member

ricardoV94 commented Jun 27, 2024

@Armavica this may be crazy work, but can we get a separate commit where we add the non-strict zips? That way it's easier to evaluate whether each one sounds correct or may be a bug somewhere.

Or I guess I can just ctrl+f for it

@Armavica
Member Author

Armavica commented Jun 28, 2024

@ricardoV94 Yes, I was planning to present this PR(s) in several steps:

  • Commits that add strict=True to zips without tests failing
  • Commits that add strict=True to zips and fix their failures (bugs)
  • Commits that add strict=False to the remaining zips that need it, or rewrite them so they can be made strict

How does that sound to you?

@ricardoV94
Member

Sounds good @Armavica

@Armavica Armavica force-pushed the zip-strict branch 5 times, most recently from dc0aa6e to 36868a5 Compare June 29, 2024 06:32

codecov bot commented Jun 29, 2024

Codecov Report

Attention: Patch coverage is 91.70306% with 19 lines in your changes missing coverage. Please review.

Project coverage is 81.05%. Comparing base (05d376f) to head (e3965ef).
Report is 2 commits behind head on main.

Additional details and impacted files

@@            Coverage Diff             @@
##             main     #850      +/-   ##
==========================================
+ Coverage   81.04%   81.05%   +0.01%     
==========================================
  Files         170      170              
  Lines       46962    46983      +21     
  Branches    11507    11510       +3     
==========================================
+ Hits        38059    38082      +23     
- Misses       6694     6695       +1     
+ Partials     2209     2206       -3     
Files Coverage Δ
pytensor/compile/builders.py 88.42% <100.00%> (ø)
pytensor/compile/function/pfunc.py 82.92% <100.00%> (ø)
pytensor/compile/function/types.py 79.94% <100.00%> (ø)
pytensor/gradient.py 77.37% <100.00%> (ø)
pytensor/graph/basic.py 88.60% <100.00%> (ø)
pytensor/graph/op.py 87.89% <ø> (ø)
pytensor/graph/replace.py 84.21% <100.00%> (ø)
pytensor/graph/rewriting/basic.py 70.34% <100.00%> (ø)
pytensor/link/c/basic.py 87.48% <100.00%> (ø)
pytensor/link/c/cmodule.py 56.88% <100.00%> (ø)
... and 55 more

... and 2 files with indirect coverage changes

@jessegrabowski
Member

Is this ready for review? Seems like all tests are passing now.

@Armavica
Member Author

Armavica commented Jul 7, 2024

> Is this ready for review? Seems like all tests are passing now.

There are still 10 instances of non-strict zips that produce errors if I make them strict. I need to investigate them one by one to see whether that's expected behaviour or not.

@Armavica Armavica force-pushed the zip-strict branch 2 times, most recently from 5746335 to e3965ef Compare July 7, 2024 13:05
@Armavica
Member Author

Armavica commented Jul 7, 2024

Actually, I find it difficult to make more progress here, so I am signalling this for review.
I added 464 easy strict=True, 18 less immediate ones, and there are still 10 strict=False that I find the most difficult to understand. I think that they could be handled in another PR. This one introduces 464+18 = 482 safeguards, which I think is a good score :)

@Armavica Armavica marked this pull request as ready for review July 7, 2024 13:36
@@ -93,7 +93,7 @@ def _validate_updates(
)
else:
update = outputs
-for i, u in zip(init, update, strict=False):
+for i, u in zip(init[: len(update)], update, strict=True):
Member

Can we use strict=False here?

@@ -1745,7 +1745,7 @@ def setup_method(self):
self.random_stream = np.random.default_rng(utt.fetch_seed())

self.inputs_shapes = [(8, 1, 12, 12), (1, 1, 5, 5), (1, 1, 5, 6), (1, 1, 6, 6)]
-self.filters_shapes = [(5, 1, 2, 2), (1, 1, 3, 3)]
+self.filters_shapes = [(5, 1, 2, 2), (1, 1, 3, 3)] * 2
Member

Was this a bug?

Member Author

Yes, it was zipping inputs_shapes and filters_shapes together, so the last two inputs_shapes were never being used in the tests.

@@ -648,7 +648,7 @@ def local_subtensor_of_alloc(fgraph, node):
# Slices to take from val
val_slices = []

-for i, (sl, dim) in enumerate(zip(slices, dims, strict=False)):
+for i, (sl, dim) in enumerate(zip(slices, dims[: len(slices)], strict=True)):
Member

Like my previous comment, I find this less readable. The strict=False indicates clearly that we don't expect the sequences to necessarily have the same length?

Member

But if they're not the same length why are we zipping them? Are we sure they're always ordered correctly?

Member

Because they are ordered correctly, yes, and presumably what comes after doesn't matter. It's quite common in Subtensor operations / rewrites.

Member Author

@Armavica Armavica Jul 8, 2024

I don't mind reverting this, but just to argue a bit in favor of strict=True, I think an additional advantage of this approach is that it makes it clearer which of the two lists is supposed to be shorter. I personally find that I understand more about what is happening here when I read this version compared to strict=False.

Comment on lines +594 to +589
if len(shape) != x.type.ndim:
return _specify_shape(x, *shape)

new_shape_matches = all(
s == xts for (s, xts) in zip(shape, x.type.shape, strict=True) if s is not None
)
if new_shape_matches:
Member

@ricardoV94 ricardoV94 Jul 8, 2024

This is awkward, the use of strict=False seems fine

Member Author

I am surprised, this really looks better to me. What do you think of:

    if len(shape) != x.type.ndim:
        return _specify_shape(x, *shape)

    if all(s in (None, xts) for (s, xts) in zip(shape, x.type.shape, strict=True)):

Member

My problem is the double call to SpecifyShape.

Also, we already established in the comment that if the lengths differ the function is going to raise, so the strict=False follows naturally?
