-
Notifications
You must be signed in to change notification settings - Fork 12
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Added compile fixture #37
Draft
markkraay
wants to merge
32
commits into
main
Choose a base branch
from
dev-mkraay-fix-256-fixture-for-compile
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Draft
Changes from all commits
Commits
Show all changes
32 commits
Select commit
Hold shift + click to select a range
90d48b8
added compile fixture
markkraay c3f2312
added compile fixture to where tests
markkraay 1423b20
compile fixture for unary
markkraay 7db02df
added compile fixture for strongly_typed
markkraay f128221
change `compute` to `func`
markkraay 84469e5
use `eager` instead of `lazy`
markkraay e56af4a
remove baked in args from compile fixture
markkraay c7d4757
enable more tests
markkraay 688cf3f
enable compile fixture in more ops
markkraay 546f5d1
enabled compile_fixture for test_unsqueeze
markkraay cd081d9
Also check shape in test_unsqueeze
markkraay 0e07c83
enable test_slice
markkraay c194230
fixed bug with filtering kwargs
markkraay 910911e
enabled reshape and reduce
markkraay 4e5a9b2
added comment for allclose
markkraay e6636d8
try to dynamically add markers for compile / eager
markkraay eb13a3d
enable test_flip: all tests passing
markkraay 9cf2bc6
enable test_full: failing test_shape_tensor[compile]
markkraay c35b45f
remove ir dump from test_full
markkraay 8c5ad76
enable test_iota; failing test_iota_from_shape_tensor[compile]
markkraay 18bcefc
remove extra fixture
markkraay f9f4cb3
enable test_linear; all tests passing
markkraay 82edeaf
enable test_matrix_multiplication; all tests passing
markkraay 0e02f61
enable test_reshape; all tests passing
markkraay a768de7
enable test_cast: failing test_cast[compile-*] & test_cast_from_bool[…
markkraay 623f54c
enable test_concatenate: all tests pass
markkraay c998d30
enable & reformat test_plugin: all test passing
markkraay 8ce3f62
enable test_convolution; all tests pass
markkraay a1d1735
enable test_quantize; filed issue #102
markkraay 7d7de04
enable test_functional; failing many
markkraay e146aa1
enabled test_conv_transpose; all tests passing
markkraay 0575c5b
fixed test_cast
markkraay File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
# | ||
# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
# SPDX-License-Identifier: Apache-2.0 | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
|
||
import pytest | ||
|
||
import tripy as tp | ||
|
||
|
||
@pytest.fixture(params=["compile", "eager"]) | ||
def compile_fixture(request): | ||
def wrapper(func, *args, **kwargs): | ||
def get_shape(x: tp.Tensor): | ||
x.eval() | ||
return tp.InputInfo(x.trace_tensor.shape, dtype=x.dtype) | ||
|
||
mode = request.param | ||
if mode == "compile": | ||
compiler = tp.Compiler(func) | ||
# Cast appropriate args / kwargs to use `tp.InputInfo` | ||
compile_args = tuple(map(lambda x: get_shape(x) if isinstance(x, tp.Tensor) else x, list(args))) | ||
compile_kwargs = dict((k, get_shape(v) if isinstance(v, tp.Tensor) else v) for k, v in kwargs.items()) | ||
compiled_func = compiler.compile(*compile_args, **compile_kwargs) | ||
# Remove baked in args, aka, only keep tp.Tensor's | ||
args = tuple(filter(lambda x: isinstance(x, tp.Tensor), args)) | ||
kwargs = dict(filter(lambda kv: isinstance(kv[1], tp.Tensor), kwargs.items())) | ||
return compiled_func(*args, **kwargs) | ||
elif mode == "eager": | ||
return func(*args, **kwargs) | ||
|
||
return wrapper |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
#102 (comment)
compile_fixture
should be able to specify which args are constants / parameters. For Q/DQ, scale should be Parameters.