Make all zips explicitly strict or non-strict #850
base: main
ada9880
to
a5de1b6
Compare
@Armavica this may be crazy work, but can we get a separate commit where we make the non-strict zips? That way it's easier to evaluate whether each one sounds correct or may be a bug somewhere. Or I guess I can just ctrl+f for it.
@ricardoV94 Yes, I was planning to present this PR(s) in several steps:
|
Sounds good @Armavica
dc0aa6e
to
36868a5
Compare
Codecov Report
Attention: Patch coverage is
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #850      +/-   ##
==========================================
+ Coverage   81.04%   81.05%   +0.01%
==========================================
  Files         170      170
  Lines       46962    46983      +21
  Branches    11507    11510       +3
==========================================
+ Hits        38059    38082      +23
- Misses       6694     6695       +1
+ Partials     2209     2206       -3
Is this ready for review? Seems like all tests are passing now.
There are still |
5746335
to
e3965ef
Compare
Actually, I find it difficult to make more progress here, so I am signalling this for review.
@@ -93,7 +93,7 @@ def _validate_updates(
 )
 else:
 update = outputs
-for i, u in zip(init, update, strict=False):
+for i, u in zip(init[: len(update)], update, strict=True):
Can we use strict=False here?
@@ -1745,7 +1745,7 @@ def setup_method(self):
 self.random_stream = np.random.default_rng(utt.fetch_seed())

 self.inputs_shapes = [(8, 1, 12, 12), (1, 1, 5, 5), (1, 1, 5, 6), (1, 1, 6, 6)]
-self.filters_shapes = [(5, 1, 2, 2), (1, 1, 3, 3)]
+self.filters_shapes = [(5, 1, 2, 2), (1, 1, 3, 3)] * 2
Was this a bug?
Yes, it was zipping inputs_shapes and filters_shapes together, so the last two inputs_shapes entries were never being used in the tests.
@@ -648,7 +648,7 @@ def local_subtensor_of_alloc(fgraph, node):
 # Slices to take from val
 val_slices = []

-for i, (sl, dim) in enumerate(zip(slices, dims, strict=False)):
+for i, (sl, dim) in enumerate(zip(slices, dims[: len(slices)], strict=True)):
Like my previous comment, I find this less readable. The strict=False indicates clearly that we don't expect the sequences to necessarily have the same length?
But if they're not the same length why are we zipping them? Are we sure they're always ordered correctly?
Because they are ordered correctly, yes, and presumably what comes after doesn't matter. It's quite common in Subtensor operations / rewrites.
I don't mind reverting this, but just to argue a bit in favor of strict=True, I think an additional advantage of this approach is that it makes it clearer which one of the two lists is supposed to be shorter. I personally find that I understand more about what is happening here when I read this version compared to strict=False.
if len(shape) != x.type.ndim:
    return _specify_shape(x, *shape)

new_shape_matches = all(
    s == xts for (s, xts) in zip(shape, x.type.shape, strict=True) if s is not None
)
if new_shape_matches:
This is awkward, the use of strict=False seems fine.
I am surprised, this really looks better to me. What do you think of:
if len(shape) != x.type.ndim:
    return _specify_shape(x, *shape)
if all(s in (None, xts) for (s, xts) in zip(shape, x.type.shape, strict=True)):
My problem is the double call to SpecifyShape.
Also, we already established in the comment that if there are different lengths the function is going to raise, so the strict=False follows naturally?
6a15a26
to
9007ce1
Compare
Description
Adds the strict=True argument to all zips when it doesn't produce mistakes in the test suite (464 of them), and strict=False to the others (28 of them).
There remain 10 non-strict zips that I find difficult to understand.
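As a rough aid for auditing the remaining calls, zips without an explicit strict argument can be located with the standard ast module. The helper name and sample source below are hypothetical and not part of this PR. The sketch only catches direct calls to the bare name zip, so aliased or attribute-style calls would be missed.

```python
import ast


def find_zips_without_strict(source: str) -> list[int]:
    """Return line numbers of zip(...) calls lacking an explicit strict= keyword."""
    tree = ast.parse(source)
    line_numbers = []
    for node in ast.walk(tree):
        if (
            isinstance(node, ast.Call)
            and isinstance(node.func, ast.Name)
            and node.func.id == "zip"
            and not any(kw.arg == "strict" for kw in node.keywords)
        ):
            line_numbers.append(node.lineno)
    return sorted(line_numbers)


# Made-up source: only the first zip call has no strict= keyword.
sample = """\
for a, b in zip(xs, ys):
    pass
for a, b in zip(xs, ys, strict=True):
    pass
pairs = list(zip(xs, ys, strict=False))
"""
print(find_zips_without_strict(sample))
```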
Related Issue
Checklist
Type of change