Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Python tree-sitter to CAST porting: Loops #745

Merged
merged 7 commits into from
Jan 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 34 additions & 3 deletions skema/program_analysis/CAST/python/node_helper.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import itertools
from typing import List, Dict
from skema.program_analysis.CAST2FN.model.cast import SourceRef

Expand All @@ -24,6 +25,34 @@
"not"
]

# Node types accepted as the loop target, i.e. the LEFT part of
#   for LEFT in RIGHT:
FOR_LOOP_LEFT_TYPES = [
    "identifier",
    "tuple_pattern",
    "pattern_list",
    "list_pattern",
]

# Node types accepted as the iterable, i.e. the RIGHT part of
#   for LEFT in RIGHT:
FOR_LOOP_RIGHT_TYPES = [
    "call",
    "identifier",
    "list",
    "tuple",
]

# Node types accepted as the condition of a while loop.
WHILE_COND_TYPES = [
    "boolean_operator",
    "call",
    "comparison_operator",
]

class NodeHelper():
def __init__(self, source: str, source_file_name: str):
self.source = source
Expand All @@ -32,14 +61,16 @@ def __init__(self, source: str, source_file_name: str):
# get_identifier optimization variables
self.source_lines = source.splitlines(keepends=True)
self.line_lengths = [len(line) for line in self.source_lines]
self.line_length_sums = [sum(self.line_lengths[:i+1]) for i in range(len(self.source_lines))]
self.line_length_sums = [0] + list(itertools.accumulate(self.line_lengths))

def get_identifier(self, node: Node) -> str:
"""Given a node, return the identifier it represents. ie. The code between node.start_point and node.end_point"""
start_line, start_column = node.start_point
end_line, end_column = node.end_point

start_index = self.line_length_sums[start_line-1] + start_column
# Edge case for when an identifier is on the very first line of the code
# We can't index into the line_length_sums
start_index = self.line_length_sums[start_line] + start_column
if start_line == end_line:
end_index = start_index + (end_column-start_column)
else:
Expand Down
225 changes: 220 additions & 5 deletions skema/program_analysis/CAST/python/ts2cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@
ModelIf,
RecordDef,
Attribute,
ScalarType
ScalarType,
StructureType
)

from skema.program_analysis.CAST.python.node_helper import (
Expand All @@ -35,7 +36,10 @@
get_first_child_index,
get_last_child_index,
get_control_children,
get_non_control_children
get_non_control_children,
FOR_LOOP_LEFT_TYPES,
FOR_LOOP_RIGHT_TYPES,
WHILE_COND_TYPES
)
from skema.program_analysis.CAST.python.util import (
generate_dummy_source_refs,
Expand Down Expand Up @@ -71,6 +75,9 @@ def __init__(self, source_file_path: str, from_file = True):
)
)

# Additional variables used in generation
self.var_count = 0

# Tree walking structures
self.variable_context = VariableContext()
self.node_helper = NodeHelper(self.source, self.source_file_name)
Expand All @@ -82,6 +89,7 @@ def __init__(self, source_file_path: str, from_file = True):
def generate_cast(self) -> List[CAST]:
'''Interface for generating CAST.'''
module = self.run(self.tree.root_node)
module.name = self.source_file_name
return CAST([generate_dummy_source_refs(module)], "Python")

def run(self, root) -> List[Module]:
Expand Down Expand Up @@ -115,12 +123,18 @@ def visit(self, node: Node):
return self.visit_assignment(node)
elif node.type == "identifier":
return self.visit_identifier(node)
elif node.type =="unary_operator":
elif node.type == "unary_operator":
return self.visit_unary_op(node)
elif node.type =="binary_operator":
elif node.type == "binary_operator":
return self.visit_binary_op(node)
elif node.type in ["integer"]:
elif node.type in ["integer", "list"]:
return self.visit_literal(node)
elif node.type in ["list_pattern", "pattern_list", "tuple_pattern"]:
return self.visit_pattern(node)
elif node.type == "while_statement":
return self.visit_while(node)
elif node.type == "for_statement":
return self.visit_for(node)
else:
return self._visit_passthrough(node)

Expand Down Expand Up @@ -224,6 +238,21 @@ def visit_call(self, node: Node) -> Call:
elif isinstance(cast, AstNode):
func_args.append(cast)

if func_name.val.name == "range":
start_step_value = LiteralValue(
ScalarType.INTEGER,
value="1",
source_code_data_type=["Python", PYTHON_VERSION, str(type(1))],
source_refs=[ref]
)
# Add a step value
if len(func_args) == 2:
func_args.append(start_step_value)
# Add a start and step value
elif len(func_args) == 1:
func_args.insert(0, start_step_value)
func_args.append(start_step_value)

# Function calls only want the 'Name' part of the 'Var' that the visit returns
return Call(
func=func_name.val,
Expand Down Expand Up @@ -371,6 +400,17 @@ def visit_binary_op(self, node: Node) -> Operator:
source_refs=[ref]
)

def visit_pattern(self, node: Node) -> LiteralValue:
    """Convert a destructuring pattern node into a tuple-typed CAST literal.

    ``visit`` dispatches ``list_pattern``, ``pattern_list`` and
    ``tuple_pattern`` nodes here.

    Args:
        node: The tree-sitter pattern node whose children are the pattern
            elements.

    Returns:
        A ``LiteralValue`` with ``value_type=StructureType.TUPLE`` whose
        value is the list of CAST nodes produced for the children.
    """
    elements = []
    for child in node.children:
        child_cast = self.visit(child)
        # A visit may yield several nodes, a single node, or nothing usable;
        # only AstNode results are kept, everything else is dropped.
        if isinstance(child_cast, List):
            elements.extend(child_cast)
        elif isinstance(child_cast, AstNode):
            elements.append(child_cast)

    return LiteralValue(
        value_type=StructureType.TUPLE,
        value=elements,
        # Attach the pattern's own location, as the other literal
        # constructions in this visitor do.
        source_refs=[self.node_helper.get_source_ref(node)],
    )

def visit_identifier(self, node: Node) -> Var:
identifier = self.node_helper.get_identifier(node)

Expand Down Expand Up @@ -417,6 +457,173 @@ def visit_literal(self, node: Node) -> Any:
source_code_data_type=["Python", PYTHON_VERSION, str(type(True))],
source_refs=[literal_source_ref]
)
elif literal_type == "list":
list_items = []
for elem in node.children:
cast = self.visit(elem)
if isinstance(cast, List):
list_items.extend(cast)
elif isinstance(cast, AstNode):
list_items.append(cast)

return LiteralValue(
value_type=StructureType.LIST,
value = list_items,
source_code_data_type=["Python", PYTHON_VERSION, str(type([0]))],
source_refs=[literal_source_ref]
)
elif literal_type == "tuple":
tuple_items = []
for elem in node.children:
cast = self.visit(cast)
if isinstance(cast, List):
tuple_items.extend(cast)
elif isinstance(cast, AstNode):
tuple_items.append(cast)

return LiteralValue(
value_type=StructureType.LIST,
value = tuple_items,
source_code_data_type=["Python", PYTHON_VERSION, str(type((0)))],
source_refs=[literal_source_ref]
)



def visit_while(self, node: Node) -> Loop:
    """Convert a ``while_statement`` node into a CAST ``Loop``.

    Args:
        node: The tree-sitter ``while_statement`` node.

    Returns:
        A ``Loop`` with empty ``pre``/``post``, the visited condition as
        ``expr`` and the visited block statements as ``body``.

    Raises:
        IndexError: If the node has no child whose type is in
            ``WHILE_COND_TYPES`` or no ``block`` child (the first match is
            indexed unconditionally).
    """
    ref = self.node_helper.get_source_ref(node)

    # Push a variable context since a loop
    # can create variables that only it can see
    self.variable_context.push_context()

    loop_cond_node = get_children_by_types(node, WHILE_COND_TYPES)[0]
    # Pass a list, like every other call site, so the type is matched as a
    # whole element rather than relying on string containment.
    loop_body_nodes = get_children_by_types(node, ["block"])[0].children

    loop_cond = self.visit(loop_cond_node)

    loop_body = []
    # Use a distinct loop variable so the ``node`` parameter is not rebound.
    for body_node in loop_body_nodes:
        cast = self.visit(body_node)
        if isinstance(cast, List):
            loop_body.extend(cast)
        elif isinstance(cast, AstNode):
            loop_body.append(cast)

    self.variable_context.pop_context()

    return Loop(
        pre=[],
        expr=loop_cond,
        body=loop_body,
        post=[],
        # source_refs is a list everywhere else in this visitor.
        source_refs=[ref],
    )

def visit_for(self, node: Node) -> Loop:
    """Convert a ``for_statement`` node into a CAST ``Loop`` driven by iter/next.

    The generated loop has this shape:

    * ``pre``: ``<iterator> = iter(RIGHT)`` followed by
      ``(LEFT, <iterator>, <stop>) = next(<iterator>)``
    * ``expr``: ``<stop> == False``
    * ``body``: the visited block statements followed by another
      ``(LEFT, <iterator>, <stop>) = next(<iterator>)``

    Args:
        node: The tree-sitter ``for_statement`` node.

    Returns:
        The ``Loop`` described above, with an empty ``post``.

    Raises:
        IndexError: If the node has no child matching
            ``FOR_LOOP_LEFT_TYPES``, ``FOR_LOOP_RIGHT_TYPES`` or ``block``
            (matches are indexed unconditionally).
    """
    ref = self.node_helper.get_source_ref(node)

    # "identifier" is in both type lists, so take the first match for the
    # target and the last match for the iterable.
    loop_cond_left = get_children_by_types(node, FOR_LOOP_LEFT_TYPES)[0]
    loop_cond_right = get_children_by_types(node, FOR_LOOP_RIGHT_TYPES)[-1]

    self.variable_context.push_context()
    iterator_name = self.variable_context.generate_iterator()
    stop_cond_name = self.variable_context.generate_stop_condition()
    iter_func = self.get_gromet_function_node("iter")
    next_func = self.get_gromet_function_node("next")

    loop_cond_left_cast = self.visit(loop_cond_left)
    loop_cond_right_cast = self.visit(loop_cond_right)

    def advance_iterator() -> Assignment:
        # Build `(LEFT, iterator, stop) = next(iterator)`. It is needed once
        # before the loop and once at the end of the body, and both copies
        # now carry the same type and source metadata.
        return Assignment(
            left=LiteralValue(
                "Tuple",
                [
                    loop_cond_left_cast,
                    Var(iterator_name, "Iterator"),
                    Var(stop_cond_name, "Boolean"),
                ],
                source_code_data_type=["Python", PYTHON_VERSION, "Tuple"],
                source_refs=[ref],
            ),
            right=Call(
                next_func,
                arguments=[Var(iterator_name, "Iterator")],
            ),
        )

    loop_pre = [
        Assignment(
            left=Var(iterator_name, "Iterator"),
            right=Call(
                iter_func,
                arguments=[loop_cond_right_cast],
            ),
        ),
        advance_iterator(),
    ]

    loop_expr = Operator(
        source_language="Python",
        interpreter="Python",
        version=PYTHON_VERSION,
        op="ast.Eq",
        operands=[
            stop_cond_name,
            LiteralValue(
                ScalarType.BOOLEAN,
                False,
                ["Python", PYTHON_VERSION, "boolean"],
                source_refs=[ref],
            ),
        ],
        source_refs=[ref],
    )

    # Pass a list, like the other call sites, so the type is matched as a
    # whole element rather than relying on string containment.
    loop_body_nodes = get_children_by_types(node, ["block"])[0].children
    loop_body = []
    # Use a distinct loop variable so the ``node`` parameter is not rebound.
    for body_node in loop_body_nodes:
        cast = self.visit(body_node)
        if isinstance(cast, List):
            loop_body.extend(cast)
        elif isinstance(cast, AstNode):
            loop_body.append(cast)

    # Insert an additional call to 'next' at the end of the loop body,
    # to facilitate looping in GroMEt
    loop_body.append(advance_iterator())

    self.variable_context.pop_context()
    return Loop(
        pre=loop_pre,
        expr=loop_expr,
        body=loop_body,
        post=[],
        # source_refs is a list everywhere else in this visitor.
        source_refs=[ref],
    )


def visit_name(self, node):
# First, we will check if this name is already defined, and if it is return the name node generated previously
Expand All @@ -436,6 +643,14 @@ def _visit_passthrough(self, node):
child_cast = self.visit(child)
if child_cast:
return child_cast

def get_gromet_function_node(self, func_name: str) -> Name:
    """Return the variable-context node for ``func_name``, registering it first if unknown."""
    # Ideally we would build a dummy node and hand it to the name visitor.
    # tree-sitter does not allow nodes to be created or modified, though,
    # so the lookup-or-register logic is repeated here.
    context = self.variable_context
    if not context.is_variable(func_name):
        return context.add_variable(func_name, "function", None)

    return context.get_node(func_name)

def get_name_node(node):
# Given a CAST node, if it's type Var, then we extract the name node out of it
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -592,7 +592,10 @@ def _(self, node: LiteralValue):
return node_uid
elif node.value_type == StructureType.TUPLE:
node_uid = uuid.uuid4()
self.G.add_node(node_uid, label=f"Tuple (...)")
self.G.add_node(node_uid, label=f"Tuple")
tuple_elems = self.visit_list(node.value)
for elem_uid in tuple_elems:
self.G.add_edge(node_uid, elem_uid)
return node_uid
elif node.value_type == None:
node_uid = uuid.uuid4()
Expand Down
3 changes: 3 additions & 0 deletions skema/program_analysis/tests/test_expression_cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,6 @@ def test_exp1():
assert isinstance(asg_node.right, LiteralValue)
assert asg_node.right.value_type == "Integer"
assert asg_node.right.value == '3'

if __name__ == "__main__":
cast = generate_cast(exp0())
Loading
Loading