Skip to content

Commit

Permalink
Merge pull request #115 from ericbn/add-type-alias-support
Browse files Browse the repository at this point in the history
Support traversal of TypeAlias for Python 3.12
  • Loading branch information
bwhmather authored Apr 19, 2024
2 parents dcb59fd + 20b3f63 commit 2025243
Show file tree
Hide file tree
Showing 9 changed files with 379 additions and 54 deletions.
38 changes: 19 additions & 19 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ jobs:
name: "Unit Tests"
strategy:
matrix:
python-version: ["3.8", "3.9", "3.10", "3.11"]
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
runs-on: ["ubuntu-22.04", "windows-2019", "macos-11"]
runs-on: ${{ matrix.runs-on }}
steps:
Expand All @@ -23,7 +23,7 @@ jobs:
run: |
python -m pip install --upgrade pip
pip install pytest
pip install pyyaml==6.0
pip install pyyaml==6.0.1
pip install -e .[test]
- name: Run tests
run: |
Expand All @@ -34,15 +34,15 @@ jobs:
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v2
- name: Set up Python 3.11
- name: Set up Python 3.12
uses: actions/setup-python@v2
with:
python-version: "3.11"
python-version: "3.12"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install pytest pytest-cov coveralls
pip install pyyaml==6.0
pip install pyyaml==6.0.1
pip install -e .[test]
- name: Run tests
run: |
Expand All @@ -58,10 +58,10 @@ jobs:
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v2
- name: Set up Python 3.11
- name: Set up Python 3.12
uses: actions/setup-python@v2
with:
python-version: "3.11"
python-version: "3.12"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
Expand All @@ -75,10 +75,10 @@ jobs:
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v2
- name: Set up Python 3.11
- name: Set up Python 3.12
uses: actions/setup-python@v2
with:
python-version: "3.11"
python-version: "3.12"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
Expand All @@ -92,10 +92,10 @@ jobs:
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v2
- name: Set up Python 3.11
- name: Set up Python 3.12
uses: actions/setup-python@v2
with:
python-version: "3.11"
python-version: "3.12"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
Expand All @@ -109,10 +109,10 @@ jobs:
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v2
- name: Set up Python 3.11
- name: Set up Python 3.12
uses: actions/setup-python@v2
with:
python-version: "3.11"
python-version: "3.12"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
Expand All @@ -126,16 +126,16 @@ jobs:
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v2
- name: Set up Python 3.11
- name: Set up Python 3.12
uses: actions/setup-python@v2
with:
python-version: "3.11"
python-version: "3.12"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -e .[test]
pip install pytest
pip install pyyaml==6.0
pip install pyyaml==6.0.1
pip install pylint
- name: Run pylint
run: |
Expand All @@ -146,16 +146,16 @@ jobs:
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v2
- name: Set up Python 3.11
- name: Set up Python 3.12
uses: actions/setup-python@v2
with:
python-version: "3.11"
python-version: "3.12"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install mypy
pip install pytest
pip install pyyaml==6.0
pip install pyyaml==6.0.1
pip install types-PyYAML
pip install types-setuptools
- name: Run mypy
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/release.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v3
with:
python-version: 3.11
python-version: 3.12
- name: Install dependencies
run: |
python -m pip install --upgrade pip
Expand Down
32 changes: 32 additions & 0 deletions src/ssort/_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ def _iter_child_nodes_of_function_def(
if node.returns is not None:
yield node.returns
yield from node.body
if sys.version_info >= (3, 12):
yield from node.type_params


@iter_child_nodes.register(ast.ClassDef)
Expand All @@ -58,6 +60,8 @@ def _iter_child_nodes_of_class_def(node: ast.ClassDef) -> Iterable[ast.AST]:
yield from node.bases
yield from node.keywords
yield from node.body
if sys.version_info >= (3, 12):
yield from node.type_params


@iter_child_nodes.register(ast.Return)
Expand Down Expand Up @@ -496,3 +500,31 @@ def _iter_child_nodes_of_type_ignore(
node: ast.TypeIgnore,
) -> Iterable[ast.AST]:
return ()


if sys.version_info >= (3, 12):

@iter_child_nodes.register(ast.TypeAlias)
def _iter_child_nodes_of_type_alias(
node: ast.TypeAlias,
) -> Iterable[ast.AST]:
yield node.name
yield from node.type_params
yield node.value

@iter_child_nodes.register(ast.TypeVar)
def _iter_child_nodes_of_type_var(node: ast.TypeVar) -> Iterable[ast.AST]:
if node.bound is not None:
yield node.bound

@iter_child_nodes.register(ast.ParamSpec)
def _iter_child_nodes_of_param_spec(
node: ast.ParamSpec,
) -> Iterable[ast.AST]:
return ()

@iter_child_nodes.register(ast.TypeVarTuple)
def _iter_child_nodes_of_type_var_tuple(
node: ast.TypeVarTuple,
) -> Iterable[ast.AST]:
return ()
9 changes: 9 additions & 0 deletions src/ssort/_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,15 @@ def split_class(statement):
assert token.type == NAME

token = next(tokens)
if token.string == "[":
token = next(tokens)
depth = 1
while depth:
if token.string == "[":
depth += 1
if token.string == "]":
depth -= 1
token = next(tokens)
if token.string == "(":
token = next(tokens)
depth = 1
Expand Down
93 changes: 69 additions & 24 deletions src/ssort/_requirements.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import ast
import dataclasses
import enum
import sys
from typing import Iterable

from ssort._ast import iter_child_nodes
Expand Down Expand Up @@ -32,6 +33,13 @@ def get_requirements(node: ast.AST) -> Iterable[Requirement]:
yield from get_requirements(child)


def _get_requirements_from_nodes(
nodes: Iterable[ast.AST],
) -> Iterable[Requirement]:
for node in nodes:
yield from get_requirements(node)


def _get_scope_from_arguments(args: ast.arguments) -> set[str]:
scope: set[str] = set()
scope.update(arg.arg for arg in args.posonlyargs)
Expand All @@ -44,20 +52,54 @@ def _get_scope_from_arguments(args: ast.arguments) -> set[str]:
return scope


if sys.version_info >= (3, 12):

def _get_scope_from_type_params(
type_params: list[ast.type_param],
) -> set[str]:
return set(type_param.name for type_param in type_params) # type: ignore[attr-defined]

@get_requirements.register(ast.TypeAlias)
def _get_requirements_for_type_alias(
node: ast.TypeAlias,
) -> Iterable[Requirement]:
scope = _get_scope_from_type_params(node.type_params)
for requirement in _get_requirements_from_nodes(node.type_params):
if requirement.name not in scope:
yield requirement

scope.add(node.name.id)
for requirement in get_requirements(node.value):
if not requirement.deferred:
requirement = dataclasses.replace(requirement, deferred=True)
if requirement.name not in scope:
yield requirement


@get_requirements.register(ast.FunctionDef)
@get_requirements.register(ast.AsyncFunctionDef)
def _get_requirements_for_function_def(
node: ast.FunctionDef | ast.AsyncFunctionDef,
) -> Iterable[Requirement]:
for decorator in node.decorator_list:
yield from get_requirements(decorator)
yield from _get_requirements_from_nodes(node.decorator_list)

yield from get_requirements(node.args)
scope: set[str] = set()
if sys.version_info >= (3, 12):
scope.update(_get_scope_from_type_params(node.type_params))
for requirement in _get_requirements_from_nodes(node.type_params):
if requirement.name not in scope:
yield requirement

for requirement in get_requirements(node.args):
if requirement.name not in scope:
yield requirement

if node.returns is not None:
yield from get_requirements(node.returns)
for requirement in get_requirements(node.returns):
if requirement.name not in scope:
yield requirement

scope = _get_scope_from_arguments(node.args)
scope.update(_get_scope_from_arguments(node.args))

requirements = []
for statement in node.body:
Expand All @@ -80,13 +122,20 @@ def _get_requirements_for_function_def(
def _get_requirements_for_class_def(
node: ast.ClassDef,
) -> Iterable[Requirement]:
for decorator in node.decorator_list:
yield from get_requirements(decorator)
yield from _get_requirements_from_nodes(node.decorator_list)

for base in node.bases:
yield from get_requirements(base)
scope: set[str] = set()
if sys.version_info >= (3, 12):
scope.update(_get_scope_from_type_params(node.type_params))
for requirement in _get_requirements_from_nodes(node.type_params):
if requirement.name not in scope:
yield requirement

scope = set(CLASS_BUILTINS)
for requirement in _get_requirements_from_nodes(node.bases):
if requirement.name not in scope:
yield requirement

scope.update(CLASS_BUILTINS)

for statement in node.body:
for stmt_dep in get_requirements(statement):
Expand All @@ -106,15 +155,13 @@ def _get_requirements_for_for(
yield from get_requirements(node.target)
yield from get_requirements(node.iter)

for stmt in node.body:
for requirement in get_requirements(stmt):
if requirement.name not in bindings:
yield requirement
for requirement in _get_requirements_from_nodes(node.body):
if requirement.name not in bindings:
yield requirement

for stmt in node.orelse:
for requirement in get_requirements(stmt):
if requirement.name not in bindings:
yield requirement
for requirement in _get_requirements_from_nodes(node.orelse):
if requirement.name not in bindings:
yield requirement


@get_requirements.register(ast.With)
Expand All @@ -124,13 +171,11 @@ def _get_requirements_for_with(
) -> Iterable[Requirement]:
bindings = set(get_bindings(node))

for item in node.items:
yield from get_requirements(item)
yield from _get_requirements_from_nodes(node.items)

for stmt in node.body:
for requirement in get_requirements(stmt):
if requirement.name not in bindings:
yield requirement
for requirement in _get_requirements_from_nodes(node.body):
if requirement.name not in bindings:
yield requirement


@get_requirements.register(ast.Global)
Expand Down
Loading

0 comments on commit 2025243

Please sign in to comment.