Skip to content

Commit

Permalink
feat: call graph computer (without closures) (#782)
Browse files Browse the repository at this point in the history
### Summary of Changes

Implement a first version of the call graph computer that can handle
everything but closures. These require the partial evaluator and will be
done later.

---------

Co-authored-by: megalinter-bot <[email protected]>
  • Loading branch information
lars-reimann and megalinter-bot committed Nov 20, 2023
1 parent b909cb8 commit 34bf182
Show file tree
Hide file tree
Showing 76 changed files with 1,375 additions and 26 deletions.
26 changes: 26 additions & 0 deletions packages/safe-ds-lang/src/language/flow/model.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import { SdsCallable } from '../generated/ast.js';
import { stream, Stream } from 'langium';

export class CallGraph {
constructor(
readonly root: SdsCallable | undefined,
readonly children: CallGraph[],
readonly isRecursive: boolean = false,
) {}

/**
* Traverses the call graph depth-first in pre-order and returns a stream of all callables that are called directly
* or indirectly.
*/
streamCalledCallables(): Stream<SdsCallable | undefined> {
return stream(this.streamCalledCallablesGenerator());
}

private *streamCalledCallablesGenerator(): Generator<SdsCallable | undefined, void> {
yield this.root;

for (const child of this.children) {
yield* child.streamCalledCallablesGenerator();
}
}
}
330 changes: 327 additions & 3 deletions packages/safe-ds-lang/src/language/flow/safe-ds-call-graph-computer.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,60 @@
import { AstNode, type AstNodeLocator, getDocument, streamAst, WorkspaceCache } from 'langium';
import { isSdsCall, type SdsCall } from '../generated/ast.js';
import {
AstNode,
type AstNodeLocator,
getContainerOfType,
getDocument,
isNamed,
stream,
streamAst,
WorkspaceCache,
} from 'langium';
import {
isSdsAnnotation,
isSdsBlockLambda,
isSdsCall,
isSdsCallable,
isSdsCallableType,
isSdsClass,
isSdsEnumVariant,
isSdsExpressionLambda,
isSdsFunction,
isSdsParameter,
isSdsSegment,
SdsArgument,
SdsBlockLambda,
SdsCall,
SdsCallable,
SdsClass,
SdsEnumVariant,
SdsExpression,
SdsExpressionLambda,
SdsFunction,
SdsParameter,
SdsSegment,
} from '../generated/ast.js';
import type { SafeDsNodeMapper } from '../helpers/safe-ds-node-mapper.js';
import type { SafeDsServices } from '../safe-ds-module.js';
import {
BlockLambdaClosure,
EvaluatedCallable,
ExpressionLambdaClosure,
NamedCallable,
ParameterSubstitutions,
substitutionsAreEqual,
UnknownEvaluatedNode,
} from '../partialEvaluation/model.js';
import { CallGraph } from './model.js';
import { getArguments, getParameters } from '../helpers/nodeProperties.js';
import { SafeDsTypeComputer } from '../typing/safe-ds-type-computer.js';
import { CallableType, StaticType } from '../typing/model.js';
import { isEmpty } from '../../helpers/collectionUtils.js';
import { SafeDsPartialEvaluator } from '../partialEvaluation/safe-ds-partial-evaluator.js';

export class SafeDsCallGraphComputer {
private readonly astNodeLocator: AstNodeLocator;
private readonly nodeMapper: SafeDsNodeMapper;
private readonly partialEvaluator: SafeDsPartialEvaluator;
private readonly typeComputer: SafeDsTypeComputer;

/**
* Stores the calls inside the node with the given ID.
Expand All @@ -15,11 +64,268 @@ export class SafeDsCallGraphComputer {
constructor(services: SafeDsServices) {
this.astNodeLocator = services.workspace.AstNodeLocator;
this.nodeMapper = services.helpers.NodeMapper;
this.partialEvaluator = services.evaluation.PartialEvaluator;
this.typeComputer = services.types.TypeComputer;

this.callCache = new WorkspaceCache(services.shared);
}

getCalls(node: AstNode): SdsCall[] {
/**
* Returns whether the given call is recursive using the given parameter substitutions.
*
* @param node
* The call to check.
*
* @param substitutions
* The parameter substitutions to use. These are **not** the argument of the call, but the values of the parameters
* of any containing callables, i.e. the context of the call.
*/
isRecursive(node: SdsCall, substitutions: ParameterSubstitutions = NO_SUBSTITUTIONS): boolean {
return this.getCallGraph(node, substitutions).isRecursive;
}

/**
* Returns a stream of all callables that are called directly or indirectly by the given call. The call graph is
* traversed depth-first. If a callable is called recursively, it is only included once again.
*
* @param node
* The call to check.
*
* @param substitutions
* The parameter substitutions to use. These are **not** the argument of the call, but the values of the parameters
* of any containing callables, i.e. the context of the call.
*/
getCallGraph(node: SdsCall, substitutions: ParameterSubstitutions = NO_SUBSTITUTIONS): CallGraph {
const call = this.createSyntheticCallForCall(node, substitutions);
return this.getCallGraphWithRecursionCheck(call, []);
}

private getCallGraphWithRecursionCheck(syntheticCall: SyntheticCall, visited: SyntheticCall[]): CallGraph {
const evaluatedCallable = syntheticCall.callable;

// Handle unknown callables & recursive calls
if (!evaluatedCallable) {
return new CallGraph(undefined, [], false);
} else if (visited.some((it) => it.equals(syntheticCall))) {
return new CallGraph(evaluatedCallable.callable, [], true);
}

// Visit all calls in the callable
const newVisited = [...visited, syntheticCall];
const children = this.getExecutedCalls(syntheticCall).map((it) => {
return this.getCallGraphWithRecursionCheck(it, newVisited);
});

return new CallGraph(
evaluatedCallable.callable,
children,
children.some((it) => it.isRecursive),
);
}

private getExecutedCalls(syntheticCall: SyntheticCall): SyntheticCall[] {
if (!syntheticCall.callable) {
/* c8 ignore next 2 */
return [];
}

const callable = syntheticCall.callable.callable;
const substitutions = syntheticCall.substitutions;

if (isSdsBlockLambda(callable) || isSdsExpressionLambda(callable) || isSdsSegment(callable)) {
return this.getExecutedCallsInPipelineCallable(callable, substitutions);
} else if (isSdsClass(callable) || isSdsEnumVariant(callable) || isSdsFunction(callable)) {
return this.getExecutedCallsInStubCallable(callable, substitutions);
} else {
/* c8 ignore next 2 */
return [];
}
}

private getExecutedCallsInPipelineCallable(
callable: SdsBlockLambda | SdsExpressionLambda | SdsSegment,
substitutions: ParameterSubstitutions,
): SyntheticCall[] {
const callsInDefaultValues = getParameters(callable).flatMap((it) => {
// The default value is only executed if no argument is passed for the parameter
if (it.defaultValue && !substitutions.has(it)) {
return this.getCalls(it.defaultValue);
} else {
return [];
}
});

let callsInBody: SdsCall[];
if (isSdsBlockLambda(callable)) {
callsInBody = this.getCalls(callable.body);
} else if (isSdsExpressionLambda(callable)) {
callsInBody = this.getCalls(callable.result);
} else {
callsInBody = this.getCalls(callable.body);
}

return [...callsInDefaultValues, ...callsInBody]
.filter((it) => getContainerOfType(it, isSdsCallable) === callable)
.map((it) => this.createSyntheticCallForCall(it, substitutions));
}

private getExecutedCallsInStubCallable(
callable: SdsClass | SdsEnumVariant | SdsFunction,
substitutions: ParameterSubstitutions,
): SyntheticCall[] {
const callsInDefaultValues = getParameters(callable).flatMap((parameter) => {
// The default value is only executed if no argument is passed for the parameter
if (parameter.defaultValue && !substitutions.has(parameter)) {
// We assume all calls in the default value are executed
const calls = this.getCalls(parameter.defaultValue);
if (!isEmpty(calls)) {
return calls.map((call) => this.createSyntheticCallForCall(call, substitutions));
}

// We assume a single callable as default value is executed
const evaluatedCallable = this.getEvaluatedCallable(parameter.defaultValue, substitutions);
if (evaluatedCallable) {
return [this.createSyntheticCallForEvaluatedCallable(evaluatedCallable)];
}
}

return [];
});

const callablesInSubstitutions = stream(substitutions.values()).flatMap((it) => {
if (it instanceof EvaluatedCallable) {
return [this.createSyntheticCallForEvaluatedCallable(it)];
}

return [];
});

return [...callsInDefaultValues, ...callablesInSubstitutions];
}

private createSyntheticCallForCall(call: SdsCall, substitutions: ParameterSubstitutions): SyntheticCall {
const evaluatedCallable = this.getEvaluatedCallable(call.receiver, substitutions);
const newSubstitutions = this.getNewSubstitutions(evaluatedCallable, getArguments(call), substitutions);
return new SyntheticCall(evaluatedCallable, newSubstitutions);
}

private createSyntheticCallForEvaluatedCallable(evaluatedCallable: EvaluatedCallable): SyntheticCall {
return new SyntheticCall(evaluatedCallable, evaluatedCallable.substitutionsOnCreation);
}

private getEvaluatedCallable(
expression: SdsExpression,
substitutions: ParameterSubstitutions,
): EvaluatedCallable | undefined {
// TODO use the partial evaluator here; necessary for closures
// const value = this.partialEvaluator.evaluate(expression, substitutions);
// if (value instanceof EvaluatedCallable) {
// return value;
// }
//
// return undefined;

let callableOrParameter = this.getCallableOrParameter(expression);

if (!callableOrParameter || isSdsAnnotation(callableOrParameter) || isSdsCallableType(callableOrParameter)) {
return undefined;
} else if (isSdsParameter(callableOrParameter)) {
// Parameter is set explicitly
const substitution = substitutions.get(callableOrParameter);
if (substitution) {
if (substitution instanceof EvaluatedCallable) {
return substitution;
} else {
/* c8 ignore next 2 */
return undefined;
}
}

// Parameter might have a default value
if (!callableOrParameter.defaultValue) {
return undefined;
}
return this.getEvaluatedCallable(callableOrParameter.defaultValue, substitutions);
} else if (isNamed(callableOrParameter)) {
return new NamedCallable(callableOrParameter);
} else if (isSdsBlockLambda(callableOrParameter)) {
return new BlockLambdaClosure(callableOrParameter, substitutions);
} else if (isSdsExpressionLambda(callableOrParameter)) {
return new ExpressionLambdaClosure(callableOrParameter, substitutions);
} else {
/* c8 ignore next 2 */
return undefined;
}
}

private getCallableOrParameter(expression: SdsExpression): SdsCallable | SdsParameter | undefined {
const type = this.typeComputer.computeType(expression);

if (type instanceof CallableType) {
return type.parameter ?? type.callable;
} else if (type instanceof StaticType) {
const declaration = type.instanceType.declaration;
if (isSdsCallable(declaration)) {
return declaration;
}
}

return undefined;
}

private getNewSubstitutions(
callable: EvaluatedCallable | undefined,
args: SdsArgument[],
substitutions: ParameterSubstitutions,
): ParameterSubstitutions {
if (!callable) {
return NO_SUBSTITUTIONS;
}

// Substitutions on creation
const substitutionsOnCreation = callable.substitutionsOnCreation;

// Substitutions on call
const parameters = getParameters(callable?.callable);
const substitutionsOnCall = new Map(
args.flatMap((it) => {
// Ignore arguments that don't get assigned to a parameter
const parameterIndex = this.nodeMapper.argumentToParameter(it)?.$containerIndex ?? -1;
if (parameterIndex === -1) {
/* c8 ignore next 2 */
return [];
}

// argumentToParameter returns parameters of callable types. We have to remap this to parameter of the
// actual callable.
const parameter = parameters[parameterIndex];
if (!parameter) {
/* c8 ignore next 2 */
return [];
}

const value = this.getEvaluatedCallable(it.value, substitutions);
if (!value) {
// We still have to remember that a value was passed, so the default value is not used
return [[parameter, UnknownEvaluatedNode]];
}

return [[parameter, value]];
}),
);

return new Map([...substitutionsOnCreation, ...substitutionsOnCall]);
}

/**
* Returns all calls inside the given node. If the node is a call, it is included as well.
*/
getCalls(node: AstNode | undefined): SdsCall[] {
if (!node) {
/* c8 ignore next 2 */
return [];
}

const key = this.getNodeId(node);
return this.callCache.get(key, () => streamAst(node).filter(isSdsCall).toArray());
}
Expand All @@ -30,3 +336,21 @@ export class SafeDsCallGraphComputer {
return `${documentUri}~${nodePath}`;
}
}

class SyntheticCall {
constructor(
readonly callable: EvaluatedCallable | undefined,
readonly substitutions: ParameterSubstitutions,
) {}

equals(other: SyntheticCall): boolean {
if (!this.callable) {
/* c8 ignore next 2 */
return !other.callable && substitutionsAreEqual(this.substitutions, other.substitutions);
}

return this.callable.equals(other.callable) && substitutionsAreEqual(this.substitutions, other.substitutions);
}
}

const NO_SUBSTITUTIONS: ParameterSubstitutions = new Map();
Loading

0 comments on commit 34bf182

Please sign in to comment.