Skip to content

Commit

Permalink
Python tree-sitter to CAST Porting: Imports (#826)
Browse files Browse the repository at this point in the history
This PR introduces support for generating CAST import statements using
tree-sitter, as part of the ongoing effort to port the Python AST-to-CAST
generation over to tree-sitter.

### Summary of Changes

- Adds handlers to python/ts2cast.py to support CAST generation for
import statements.
- Adds basic CAST generation for the attribute idiom, in order to
support function calls from imports.
- Adds test script test_import_cast.py with unit tests to maintain
consistency.
- Adds 'support' for yield and assert statements. They are currently
passed through, so nothing important gets generated
from these.

### Related issues
- 

Resolves #804

---------

Co-authored-by: Vincent Raymond <[email protected]>
  • Loading branch information
titomeister and vincentraymond-ua authored Mar 4, 2024
1 parent 08fffe3 commit 102dd72
Show file tree
Hide file tree
Showing 2 changed files with 419 additions and 6 deletions.
141 changes: 135 additions & 6 deletions skema/program_analysis/CAST/python/ts2cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,11 @@ def __init__(self, source_file_path: str, from_file = True):
# Additional variables used in generation
self.var_count = 0

# A dictionary used to keep track of aliases that imports use
# (like import x as y, or from x import y as z)
# Used to resolve aliasing in imports
self.aliases = {}

# Tree walking structures
self.variable_context = VariableContext()
self.node_helper = NodeHelper(self.source, self.source_file_name)
Expand All @@ -104,7 +109,19 @@ def run(self, root) -> List[Module]:

# TODO: node helper for ignoring comments

def check_alias(self, name):
    """Resolve an import alias back to the name it stands for.

    Looks `name` up in `self.aliases` (the alias -> original-name mapping
    filled in while visiting import statements). If an entry exists, the
    mapped name is returned; otherwise `name` comes back unchanged.
    """
    # dict.get with the key itself as the default covers both branches.
    return self.aliases.get(name, name)

def visit(self, node: Node):
# print(f"===Visiting node[{node.type}]===")
if node.type == "module":
return self.visit_module(node)
elif node.type == "parenthesized_expression":
Expand All @@ -125,6 +142,8 @@ def visit(self, node: Node):
return self.visit_comparison_op(node)
elif node.type == "assignment":
return self.visit_assignment(node)
elif node.type == "attribute":
return self.visit_attribute(node)
elif node.type == "identifier":
return self.visit_identifier(node)
elif node.type == "unary_operator":
Expand All @@ -147,6 +166,14 @@ def visit(self, node: Node):
return self.visit_while(node)
elif node.type == "for_statement":
return self.visit_for(node)
elif node.type == "import_statement":
return self.visit_import(node)
elif node.type == "import_from_statement":
return self.visit_import_from(node)
elif node.type == "yield":
return self.visit_yield(node)
elif node.type == "assert_statement":
return self.visit_assert(node)
else:
return self._visit_passthrough(node)

Expand Down Expand Up @@ -236,8 +263,10 @@ def visit_return(self, node: Node) -> ModelReturn:

def visit_call(self, node: Node) -> Call:
ref = self.node_helper.get_source_ref(node)
func_identifier = get_first_child_by_type(node, "identifier")
func_name = self.visit(func_identifier) #self.node_helper.get_identifier(func_identifier)

func_cast = self.visit(node.children[0])

func_name = get_func_name_node(func_cast)

arg_list = get_first_child_by_type(node, "argument_list")
args = get_non_control_children(arg_list)
Expand All @@ -250,7 +279,7 @@ def visit_call(self, node: Node) -> Call:
elif isinstance(cast, AstNode):
func_args.append(cast)

if func_name.val.name == "range":
if get_name_node(func_cast).name == "range":
start_step_value = CASTLiteralValue(
ScalarType.INTEGER,
value="1",
Expand All @@ -267,7 +296,7 @@ def visit_call(self, node: Node) -> Call:

# Function calls only want the 'Name' part of the 'Var' that the visit returns
return Call(
func=func_name.val,
func=func_name,
arguments=func_args,
source_refs=[ref]
)
Expand Down Expand Up @@ -500,6 +529,69 @@ def visit_literal(self, node: Node) -> Any:
source_refs=[literal_source_ref]
)

def handle_dotted_name(self, import_stmt) -> str:
    """Return the identifier text of a `dotted_name` node and visit it.

    Args:
        import_stmt: the tree-sitter node for the dotted name (e.g. the
            `os.path` part of `import os.path`).

    Returns:
        Whatever `self.node_helper.get_identifier` yields for the node,
        i.e. the name as a string. This is *not* a ModelImport; callers
        wrap the name in one themselves.
    """
    name = self.node_helper.get_identifier(import_stmt)
    # The visit result is discarded; the node is still walked so that any
    # bookkeeping done by the visitors happens for this name.
    self.visit(import_stmt)

    return name

def handle_aliased_import(self, import_stmt) -> tuple:
    """Split an `aliased_import` node (`x as y`) into its name and alias.

    Args:
        import_stmt: the tree-sitter `aliased_import` node.

    Returns:
        A 2-tuple `(name, alias)`: `name` is the result of
        `handle_dotted_name` on the node's first `dotted_name` child, and
        `alias` is the identifier text of its first `identifier` child.
        This is *not* a ModelImport; callers build that from the tuple.
    """
    dot_name = get_children_by_types(import_stmt, ["dotted_name"])[0]
    name = self.handle_dotted_name(dot_name)

    alias = get_children_by_types(import_stmt, ["identifier"])[0]
    # Visit result is discarded; the alias identifier is still walked.
    self.visit(alias)

    return (name, self.node_helper.get_identifier(alias))

def visit_import(self, node: Node):
    """Translate an `import ...` statement into a list of ModelImport nodes.

    One ModelImport is produced per imported module. For `import x as y`
    the alias is also recorded in `self.aliases` (alias -> module name) so
    later name lookups can resolve it.

    Args:
        node: the tree-sitter `import_statement` node.

    Returns:
        A list of ModelImport nodes, in the order the names are returned by
        `get_children_by_types`.
    """
    ref = self.node_helper.get_source_ref(node)
    to_ret = []

    names_list = get_children_by_types(node, ["dotted_name", "aliased_import"])
    for name in names_list:
        if name.type == "dotted_name":
            resolved_name = self.handle_dotted_name(name)
            # source_refs is passed as a list, matching the other CAST
            # constructors in this file (e.g. Call uses source_refs=[ref]).
            to_ret.append(
                ModelImport(
                    name=resolved_name,
                    alias=None,
                    symbol=None,
                    all=False,
                    source_refs=[ref],
                )
            )
        elif name.type == "aliased_import":
            module_name, alias = self.handle_aliased_import(name)
            self.aliases[alias] = module_name
            to_ret.append(
                ModelImport(
                    name=module_name,
                    alias=alias,
                    symbol=None,
                    all=False,
                    source_refs=[ref],
                )
            )

    return to_ret

def visit_import_from(self, node: Node):
    """Translate a `from m import ...` statement into ModelImport nodes.

    - `from m import *` yields a single ModelImport with `all=True`.
    - `from m import a, b as c` yields one ModelImport per imported symbol;
      for aliased symbols the alias is recorded in `self.aliases`
      (alias -> symbol name).

    Relative imports (`from . import x`, `from .pkg import *`) are handled
    too: in the tree-sitter Python grammar their module is a
    `relative_import` node rather than a `dotted_name`, so it has to be
    included when locating the module node. Without that, the first
    imported symbol would be mistaken for the module, or indexing an empty
    list would fail for `from . import *`.

    Args:
        node: the tree-sitter `import_from_statement` node.

    Returns:
        A list of ModelImport nodes; empty if no module node is found.
    """
    ref = self.node_helper.get_source_ref(node)
    to_ret = []

    # Assumes get_children_by_types returns children in source order, so the
    # module being imported from comes first (the original code relied on
    # the same ordering via names_list[0]).
    named_children = get_children_by_types(
        node, ["dotted_name", "relative_import", "aliased_import"]
    )
    if not named_children:
        return to_ret

    module_node, *imported_names = named_children
    module_name = self.node_helper.get_identifier(module_node)

    # A wildcard import shows up as its own child node type.
    if get_children_by_types(node, ["wildcard_import"]):
        to_ret.append(
            ModelImport(
                name=module_name,
                alias=None,
                symbol=None,
                all=True,
                source_refs=[ref],
            )
        )
        return to_ret

    for name in imported_names:
        if name.type == "dotted_name":
            symbol = self.handle_dotted_name(name)
            # source_refs is passed as a list, matching the other CAST
            # constructors in this file (e.g. Call uses source_refs=[ref]).
            to_ret.append(
                ModelImport(
                    name=module_name,
                    alias=None,
                    symbol=symbol,
                    all=False,
                    source_refs=[ref],
                )
            )
        elif name.type == "aliased_import":
            symbol, alias = self.handle_aliased_import(name)
            self.aliases[alias] = symbol
            to_ret.append(
                ModelImport(
                    name=module_name,
                    alias=alias,
                    symbol=symbol,
                    all=False,
                    source_refs=[ref],
                )
            )

    return to_ret

def visit_attribute(self, node: Node):
    """Translate an attribute access (`obj.attr`) into a CAST Attribute.

    Args:
        node: the tree-sitter `attribute` node. It is unpacked as exactly
            three children: the object, the `.` token, and the attribute.

    Returns:
        An Attribute whose `value` and `attr` are the visited object and
        attribute, each reduced with `get_name_node`.

    Raises:
        ValueError: if the node does not have exactly three children.
    """
    ref = self.node_helper.get_source_ref(node)
    obj, _, attr = node.children
    obj_cast = self.visit(obj)
    attr_cast = self.visit(attr)

    # source_refs is passed as a list, matching the other CAST constructors
    # in this file (e.g. Call uses source_refs=[ref]).
    return Attribute(
        value=get_name_node(obj_cast),
        attr=get_name_node(attr_cast),
        source_refs=[ref],
    )

def handle_for_clause(self, node: Node):
# Given the "for x in seq" clause of a list comprehension
# we translate it to a CAST for loop, leaving the actual
Expand Down Expand Up @@ -884,7 +976,9 @@ def visit_for(self, node: Node) -> Loop:

def visit_name(self, node):
# First, we will check if this name is already defined, and if it is return the name node generated previously
identifier = self.node_helper.get_identifier(node)
# NOTE: the call to check_alias is a crucial change, to resolve aliasing
# need to make sure nothing breaks
identifier = self.check_alias(self.node_helper.get_identifier(node))
if self.variable_context.is_variable(identifier):
return self.variable_context.get_node(identifier)

Expand All @@ -909,14 +1003,49 @@ def get_gromet_function_node(self, func_name: str) -> Name:

return self.variable_context.add_variable(func_name, "function", None)

def visit_yield(self, node):
    """Produce a placeholder for a `yield` expression.

    Yields are not translated yet: the result is a one-element list holding
    a LIST-typed CASTLiteralValue whose value is the marker string
    "YieldNotImplemented".
    """
    source_ref = self.node_helper.get_source_ref(node)
    # NOTE(review): the source ref is passed positionally and unwrapped,
    # whereas keyword uses elsewhere in this file pass source_refs=[ref].
    # Presumably this is the source_refs slot; confirm against the
    # CASTLiteralValue signature.
    placeholder = CASTLiteralValue(
        StructureType.LIST,
        "YieldNotImplemented",
        ["Python", "3.8", "List"],
        source_ref,
    )
    return [placeholder]

def visit_assert(self, node):
    """Produce a placeholder for an `assert` statement.

    Asserts are not translated yet: the result is a one-element list
    holding a LIST-typed CASTLiteralValue whose value is the marker string
    "AssertNotImplemented".
    """
    source_ref = self.node_helper.get_source_ref(node)
    # NOTE(review): the source ref is passed positionally and unwrapped,
    # whereas keyword uses elsewhere in this file pass source_refs=[ref].
    # Presumably this is the source_refs slot; confirm against the
    # CASTLiteralValue signature.
    placeholder = CASTLiteralValue(
        StructureType.LIST,
        "AssertNotImplemented",
        ["Python", "3.8", "List"],
        source_ref,
    )
    return [placeholder]


def get_name_node(node):
    """Extract the name element from a CAST node where possible.

    - If `node` is a list, its first element is the one inspected.
    - An Attribute is reduced by recursing into its `attr`.
    - A Var yields its `val`.
    - Anything else: the argument is handed back exactly as received
      (for a list argument, that is the whole list, not its first element).
    """
    candidate = node[0] if isinstance(node, list) else node

    if isinstance(candidate, Attribute):
        return get_name_node(candidate.attr)
    if isinstance(candidate, Var):
        return candidate.val

    # Fallback returns the original argument, not the unwrapped candidate.
    return node

def get_func_name_node(node):
    """Extract the element to use as a call's function from a CAST node.

    A Var is reduced to its `val`; any other node is returned unchanged.
    Unlike `get_name_node`, lists and Attribute nodes are not unwrapped.
    """
    if isinstance(node, Var):
        return node.val
    return node
Loading

0 comments on commit 102dd72

Please sign in to comment.