Skip to content

Commit

Permalink
[fortran] Support for Computed GO TO (#780)
Browse files Browse the repository at this point in the history
## Summary of Changes
- Adds CAST support for ingesting Computed GO TO
- Adds support for generating GroMEt for CAST files containing Computed
GO TOs.
- Adds a test script `test_goto_computed.py` that contains a unit test
for a small Fortran file that has a Computed GO TO.
- Fixes a small issue in the Annotated CAST generation in
`cast_to_annotated_cast.py`. This fix allows correct support of Computed
GO TOs.
- Updates CAST Visualizer to support visualizing GO TO and Label CAST
nodes.

### CAST Computed GO TO Conversion
Since the expression in Computed GO TOs evaluates to an index rather
than a label, we make a small conversion to the expression in the CAST
to support this. We utilize the _get Gromet function to handle this
indexing.

So:
```fortran
GO TO (100,200,300,400), x+y
```
gets converted to:
```python
_get(["100", "200", "300", "400"], x+y)
``` 

### Related issues

Resolves #698 
Resolves #701 
Resolves #773

---------

Co-authored-by: Tito Ferra <[email protected]>
  • Loading branch information
vincentraymond-ua and titomeister authored Feb 2, 2024
1 parent a9e47dc commit f899664
Show file tree
Hide file tree
Showing 5 changed files with 418 additions and 14 deletions.
50 changes: 43 additions & 7 deletions skema/program_analysis/CAST/fortran/ts2cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,15 +380,53 @@ def visit_function_call(self, node):
source_refs=[self.node_helper.get_source_ref(node)],
)

"""
(keyword_statement [6, 6] - [6, 61]
(statement_label_reference [6, 13] - [6, 16])
(statement_label_reference [6, 18] - [6, 21])
(statement_label_reference [6, 23] - [6, 26])
(statement_label_reference [6, 28] - [6, 31])
(math_expression [6, 34] - [6, 61]
left: (call_expression [6, 34] - [6, 57]
(identifier [6, 34] - [6, 37])
(argument_list [6, 37] - [6, 57]
(math_expression [6, 38] - [6, 53]
left: (math_expression [6, 38] - [6, 49]
left: (parenthesized_expression [6, 38] - [6, 45]
(math_expression [6, 39] - [6, 44]
left: (identifier [6, 39] - [6, 40])
right: (identifier [6, 43] - [6, 44])))
right: (identifier [6, 48] - [6, 49]))
right: (number_literal [6, 52] - [6, 53]))
(number_literal [6, 55] - [6, 56])))
right: (number_literal [6, 60] - [6, 61])))
"""

def visit_keyword_statement(self, node):
# NOTE: RETURN is not the only Fortran keyword. GO TO and CONTINUE are also considered keywords
identifier = self.node_helper.get_identifier(node).lower()
if node.type == "keyword_statement":
if "go to" in identifier:
statement_label_reference = get_first_child_by_type(node, "statement_label_reference")
statement_labels = [
self.node_helper.get_identifier(child)
for child in get_children_by_types(
node, ["statement_label_reference"]
)
]
# If there are multiple statement labels, then this is a COMPUTED GO TO
# Those are handled as a "_get" access into a List of statement labels with the index determined by the expression
if len(statement_labels) > 1:
expr = Call(
func=self.get_gromet_function_node("_get"),
arguments=[
CASTLiteralValue(value_type="List", value=[CASTLiteralValue(value=label, value_type="List") for label in statement_labels]),
self.visit(node.children[-1]),
],
)
return Goto(label=None, expr=expr)
return Goto(
label=self.node_helper.get_identifier(statement_label_reference),
expr=None
label=statement_labels[0],
expr=None,
)
if "continue" in identifier:
return self._visit_no_op(node)
Expand Down Expand Up @@ -419,10 +457,8 @@ def visit_keyword_statement(self, node):

def visit_statement_label(self, node):
"""Visitor for fortran statement labels"""
return Label(
label=self.node_helper.get_identifier(node)
)

return Label(label=self.node_helper.get_identifier(node))

def visit_fortran_builtin_statement(self, node):
"""Visitor for Fortran keywords that are not classified as keyword_statement by tree-sitter"""
# All of the node types that fall into this category end with _statment.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def visit_function_def(self, node: FunctionDef):

@_visit.register
def visit_goto(self, node: Goto):
expr = node.expr
expr = self.visit(node.expr) if node.expr != None else None
label = node.label
return AnnCastGoto(expr, label, node.source_refs)

Expand Down
32 changes: 28 additions & 4 deletions skema/program_analysis/CAST2FN/ann_cast/to_gromet_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -3241,6 +3241,15 @@ def visit_function_def(
self.resolve_placeholder_gotos()
var_environment["args"] = deepcopy(prev_arg_env)

def retrieve_labels(self, node: AnnCastCall):
# Retrieves all the labels in node
# Assumes node is a Function Call to "_get" with its labels in the first argument
labels = []
arguments = node.arguments[0]
for arg in arguments.value:
labels.append(arg.value)

return labels

@_visit.register
def visit_goto(
Expand Down Expand Up @@ -3273,7 +3282,7 @@ def visit_goto(
for var in vars:
for v in vars[var]:
if v not in added_vars:
goto_fn.opi = insert_gromet_object(goto_fn.opi, GrometPort(box=len(goto_fn.b)))
goto_fn.opi = insert_gromet_object(goto_fn.opi, GrometPort(box=len(goto_fn.b), name=v))
added_vars.append(v)

# The goto expression FN has two potential expressions, to determine the label and the index
Expand All @@ -3285,7 +3294,7 @@ def visit_goto(
if node.label in self.labels:
label_index = self.labels[node.label]
else:
label_index = -1
label_index = 0

# Multiple gotos could reference the same label that we haven't seen
# So we maintain a list that we can update later
Expand All @@ -3300,8 +3309,23 @@ def visit_goto(
goto_fn.opo = insert_gromet_object(goto_fn.opo, GrometPort(box=len(goto_fn.b), name="fn_idx"))
goto_fn.wfopo = insert_gromet_object(goto_fn.wfopo, GrometWire(src=len(goto_fn.opo), tgt=len(goto_fn.pof)))

goto_fn.bf = insert_gromet_object(goto_fn.bf, GrometBoxFunction(function_type=FunctionType.LITERAL, value=node.label))
label_bf = len(goto_fn.bf)
goto_fn.bf = insert_gromet_object(goto_fn.bf, GrometBoxFunction(function_type=FunctionType.LITERAL, value=node.label))
label_bf = len(goto_fn.bf)
else:
self.visit(node.expr, goto_fn, node)
index_comp_bf = len(goto_fn.bf)

for idx,_ in enumerate(goto_fn.opi, 1):
goto_fn.pif = insert_gromet_object(goto_fn.pif, GrometPort(box=1))
goto_fn.wfopi = insert_gromet_object(goto_fn.wfopi, GrometWire(src=len(goto_fn.pif), tgt=idx))

goto_fn.pof = insert_gromet_object(goto_fn.pof, GrometPort(box=index_comp_bf))
goto_fn.opo = insert_gromet_object(goto_fn.opo, GrometPort(box=len(goto_fn.b), name="fn_idx"))
goto_fn.wfopo = insert_gromet_object(goto_fn.wfopo, GrometWire(src=len(goto_fn.opo), tgt=len(goto_fn.pof)))

goto_fn.bf = insert_gromet_object(goto_fn.bf, GrometBoxFunction(function_type=FunctionType.LITERAL, value=GLiteralValue("None","None")))
label_bf = len(goto_fn.bf)

goto_fn.pof = insert_gromet_object(goto_fn.pof, GrometPort(box=label_bf))
goto_fn.opo = insert_gromet_object(goto_fn.opo, GrometPort(box=len(goto_fn.b), name="label"))
goto_fn.wfopo = insert_gromet_object(goto_fn.wfopo, GrometWire(src=len(goto_fn.opo), tgt=len(goto_fn.pof)))
Expand Down
24 changes: 22 additions & 2 deletions skema/program_analysis/CAST2FN/visitors/cast_to_agraph_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,10 @@ def to_agraph(self):
"""Visits the entire CAST object to populate the graph G
and returns an AGraph of the graph G as a result.
"""
self.visit_list(self.cast.nodes)
if isinstance(self.cast, list):
self.visit_list(self.cast[0].nodes)
else:
self.visit_list(self.cast.nodes)
A = nx.nx_agraph.to_agraph(self.G)
A.graph_attr.update(
{"dpi": 227, "fontsize": 20, "fontname": "Menlo", "rankdir": "TB"}
Expand Down Expand Up @@ -474,6 +477,22 @@ def _(self, node: FunctionDef):

return node_uid

@visit.register
def _(self, node: Goto):
node_uid = uuid.uuid4()
if node.expr == None:
self.G.add_node(node_uid, label=f"Goto {node.label}")
else:
self.G.add_node(node_uid, label="Goto (Computed)")

return node_uid

@visit.register
def _(self, node: Label):
node_uid = uuid.uuid4()
self.G.add_node(node_uid, label=f"Label {node.label}")
return node_uid

@visit.register
def _(self, node: Loop):
"""Visits Loop nodes. We visit the conditional expression and the
Expand Down Expand Up @@ -950,7 +969,8 @@ def _(self, node: Name):
node_uid = uuid.uuid4()

class_init = False
for n in self.cast.nodes[0].body:
body = self.cast[0].nodes[0].body if isinstance(self.cast, list) else self.cast.nodes[0].body
for n in body:
if isinstance(n, RecordDef) and n.name == node.name:
class_init = True
self.G.add_node(node_uid, label=node.name + " Init()")
Expand Down
Loading

0 comments on commit f899664

Please sign in to comment.