Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Check changed files with pytype on github actions #3571

Merged
merged 13 commits into from
Oct 22, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 23 additions & 3 deletions .github/workflows/lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ jobs:
uses: actions/checkout@v2
- name: Setting up python
uses: actions/setup-python@v2
# We have to explicitly fetch the base branch as well
- name: Fetching Base Branch
# We have to explicitly fetch the base branch as well
run: git fetch --no-tags --prune --depth=1 origin "${BASE_REF?}:${BASE_REF?}"
- name: Install yapf
run: python3 -m pip install yapf
Expand All @@ -65,6 +65,26 @@ jobs:
printf "You can fix the lint errors above by running\n"
printf " git diff -U0 "${BASE_REF?}" | python3 third_party/format_diff/format_diff.py yapf -i\n"

pytype:
runs-on: ubuntu-18.04
env:
BASE_REF: ${{ github.base_ref }}
steps:
- name: Checking out repository
uses: actions/checkout@v2
- name: Setting up python
uses: actions/setup-python@v2
with:
# Pytype does not support python3.9, which this action defaults to.
python-version: '3.8'
- name: Fetching Base Branch
# We have to explicitly fetch the base branch as well
run: git fetch --no-tags --prune --depth=1 origin "${BASE_REF?}:${BASE_REF?}"
- name: Install pytype
run: python3 -m pip install pytype
- name: Run pytype on changed files
run: ./build_tools/pytype/check_diff.sh "${BASE_REF?}"

clang-format:
runs-on: ubuntu-18.04
env:
Expand All @@ -79,8 +99,8 @@ jobs:
chmod +x /tmp/git-clang-format
- name: Checking out repository
uses: actions/checkout@v2
# We have to explicitly fetch the base branch as well
- name: Fetching Base Branch
# We have to explicitly fetch the base branch as well
run: git fetch --no-tags --prune --depth=1 origin "${BASE_REF?}:${BASE_REF?}"
- name: Running clang-format on changed source files
run: |
Expand All @@ -102,8 +122,8 @@ jobs:
steps:
- name: Checking out repository
uses: actions/checkout@v2
# We have to explicitly fetch the base branch as well
- name: Fetching Base Branch
# We have to explicitly fetch the base branch as well
run: git fetch --no-tags --prune --depth=1 origin "${BASE_REF?}:${BASE_REF?}"
- name: Checking tabs
run: ./scripts/check_tabs.sh "${BASE_REF?}"
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Python
*.pyc
**/.ipynb_checkpoints/
.pytype/

# Visual Studio files
.vs/
Expand Down
96 changes: 96 additions & 0 deletions build_tools/pytype/check_diff.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
#!/bin/bash
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Uses git diff to run pytype on changed files.
# Example Usage:
# Defaults to comparing against 'main'.
# ./build_tools/pytype/check_diff.sh
# A specific branch can be specified.
# ./build_tools/pytype/check_diff.sh google
# Or all python files outside of './third_party/' can be checked.
# ./build_tools/pytype/check_diff.sh all

DIFF_TARGET="${1:-main}"
echo "Running pycheck against '${DIFF_TARGET?}'"

if [[ "${DIFF_TARGET?}" = "all" ]]; then
phoenix-meadowlark marked this conversation as resolved.
Show resolved Hide resolved
FILES=$(find -name "*\.py" -not -path "./third_party/*")
else
FILES=$(git diff --name-only "${DIFF_TARGET?}" | grep '.*\.py')
fi


# We seperate the python files into multiple pytype calls because otherwise
# Ninja gets confused. See https://github.com/google/pytype/issues/198
BASE=$(echo "${FILES?}" | grep -vP '^(\./)?integrations/*')
IREE_TF=$(echo "${FILES?}" | \
grep -P '^(\./)?integrations/tensorflow/bindings/python/pyiree/tf/.*')
IREE_XLA=$(echo "${FILES?}" | \
grep -P '^(\./)?integrations/tensorflow/bindings/python/pyiree/xla/.*')
COMPILER=$(echo "${FILES?}" | \
grep -P '^(\./)?integrations/tensorflow/compiler/.*')
E2E=$(echo "${FILES?}" | grep -P '^(\./)?integrations/tensorflow/e2e/.*')

function check_files() {
# $1: previous return code
# $2...: files to check
if [[ -z "${@:2}" ]]; then
echo "No files to check."
echo
return "${1?}"
fi

# We disable import-error because pytype doesn't have access to bazel.
# We disable pyi-error because of the way the bindings imports work.
echo "${@:2}" | \
xargs python3 -m pytype --disable=import-error,pyi-error -j $(nproc)
phoenix-meadowlark marked this conversation as resolved.
Show resolved Hide resolved
EXIT_CODE="$?"
echo
if [[ "${EXIT_CODE?}" -gt "${1?}" ]]; then
return "${EXIT_CODE?}"
else
return "${1?}"
fi
}

MAX_CODE=0

echo "Checking .py files outside of integrations/"
check_files "${MAX_CODE?}" "${BASE?}"
MAX_CODE="$?"

echo "Checking .py files in integrations/tensorflow/bindings/python/pyiree/tf/.*"
check_files "${MAX_CODE?}" "${IREE_TF?}"
MAX_CODE="$?"

echo "Checking .py files in integrations/tensorflow/bindings/python/pyiree/xla/.*"
check_files "${MAX_CODE?}" "${IREE_XLA?}"
MAX_CODE="$?"

echo "Checking .py files in integrations/tensorflow/compiler/.*"
check_files "${MAX_CODE?}" "${COMPILER?}"
MAX_CODE="$?"

echo "Checking .py files in integrations/tensorflow/e2e/.*"
check_files "${MAX_CODE?}" "${E2E?}"
MAX_CODE="$?"


if [[ "${MAX_CODE?}" -ne "0" ]]; then
echo "One or more pytype checks failed."
echo "You can view these errors locally by running"
echo " ./build_tools/pytype/check_diff.sh ${DIFF_TARGET?}"
exit "${MAX_CODE?}"
fi
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ class Trace:
"""Stores the inputs and outputs of a series of calls to a module."""

def __init__(self,
module: tf_utils.CompiledModule,
module: Union[tf_utils.CompiledModule, None],
function: Union[Callable[["TracedModule"], None], None],
_load_dict: Dict[str, Any] = None):
"""Extracts metadata from module and function and initializes.
Expand Down Expand Up @@ -563,7 +563,7 @@ def __init__(self, module: tf_utils.CompiledModule, trace: Trace):
self._module = module
self._trace = trace

def _trace_call(self, method: Callable[..., Any], method_name: str):
def _trace_call(self, method: tf_utils._FunctionWrapper, method_name: str):
"""Decorates a CompiledModule method to capture its inputs and outputs."""

def call(*args, **kwargs):
Expand Down Expand Up @@ -611,8 +611,8 @@ def __getattr__(self, attr):


def compile_tf_module(
module_class: Type[tf.Module], exported_names: Sequence[str] = ()
) -> Callable[[Any], Any]:
module_class: Type[tf.Module],
exported_names: Sequence[str] = ()) -> Modules:
"""Compiles module_class to each backend that we test.

Args:
Expand Down Expand Up @@ -648,11 +648,10 @@ def compile_tf_module(
return _global_modules


def compile_tf_signature_def_saved_model(saved_model_dir: str,
saved_model_tags: Set[str],
module_name: str, exported_name: str,
input_names: Sequence[str],
output_names: Sequence[str]):
def compile_tf_signature_def_saved_model(
saved_model_dir: str, saved_model_tags: Set[str], module_name: str,
exported_name: str, input_names: Sequence[str],
output_names: Sequence[str]) -> Modules:
"""Compiles a SignatureDef SavedModel to each backend that we test.

Args:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,10 @@ def save_input_values(inputs: Sequence[np.ndarray],


def _setup_mlir_crash_reproducer(
function: Callable[[Any], Any],
function: Any, # pytype doesn't support arbitrary Callable[*args, **kwargs]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I found a thread about this somewhere. I think it said there is a way to write this. Maybe worth investigating, but can be a followup

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SGTM. Could you send a link?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

python/mypy#5876. I would have to understand pytype better to understand all that's going on in that thread. python/typing#264 (comment) and https://stackoverflow.com/q/57837609 may also help

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I was looking through those. The first one would be ideal if it were implemented. But the other two make assumptions about types that the function can ingest, which would be inappropriate in our case.

artifacts_dir: str,
backend_id: str,
) -> Callable[[Any], Any]:
) -> Any: # Callable[Any, Any]
"""Wraps `function` so that it a MLIR crash reproducer is saved if it crashes.

Writes to `artifacts_dir/reproducer__{backend}.mlir` in the case of a crash.
Expand Down Expand Up @@ -253,14 +253,24 @@ def _compile_module(saved_model_dir, saved_model_tags, backend_info,
exported_name, artifacts_dir)


class _FunctionWrapper(object):

def __call__(self, *args, **kwargs):
raise NotImplementedError()

def get_serialized_values(self) -> Tuple[Tuple[str], Tuple[str]]:
"""Dummy function to match _IreeFunctionWrapper's API."""
return ("",), ("",)


class CompiledModule(object):
"""Base class for the TF and IREE compiled modules."""

def __init__(
self,
module_name: str,
backend_info: "BackendInfo",
compiled_paths: Dict[str, str],
compiled_paths: Union[Dict[str, str], None],
):
"""Shared base constructor – not useful on its own.

Expand Down Expand Up @@ -344,20 +354,16 @@ def create_from_signature_def_saved_model(cls,
"""
raise NotImplementedError()

def __getattr__(self, attr: str) -> _FunctionWrapper:
raise NotImplementedError()

def iree_serializable(self):
return False

def tflite_serializable(self):
return False


class _FunctionWrapper(object):

def get_serialized_values(self) -> Tuple[Tuple[str], Tuple[str]]:
"""Dummy function to match _IreeFunctionWrapper's API."""
return (), ()


class _IreeFunctionWrapper(_FunctionWrapper):
"""Wraps an IREE function, making it callable."""

Expand Down Expand Up @@ -681,7 +687,7 @@ def _get_concrete_functions(module_class: Type[tf.Module],
instance = module_class()
functions = []
for name in exported_names:
functions.append(instance.__getattribute__(name).get_concrete_function())
functions.append(getattr(instance, name).get_concrete_function())
return functions, exported_names


Expand Down Expand Up @@ -787,7 +793,8 @@ class _TfLiteFunctionWrapper(_FunctionWrapper):
def __init__(self, interpreter: tf.lite.Interpreter):
self._interpreter = interpreter

def __call__(self, *args, **kwargs) -> Tuple[Any]:
def __call__(self, *args,
**kwargs) -> Union[Dict[str, Any], Tuple[Any], np.ndarray]:
if len(args) and len(kwargs):
raise ValueError("Passing both args and kwargs is not supported by "
"_TfLiteFunctionWrapper")
Expand Down Expand Up @@ -823,13 +830,12 @@ def __call__(self, *args, **kwargs) -> Tuple[Any]:
outputs.append(value)

# Process them to match the output of the tf.Module.
if not is_dict:
outputs = tuple(outputs)
if len(outputs) == 1:
outputs = outputs[0]
if is_dict:
return dict(outputs)
else:
outputs = dict(outputs)
return outputs
if len(outputs) == 1:
return outputs[0]
return tuple(outputs)


class TfLiteCompiledModule(CompiledModule):
Expand Down