Skip to content

Commit

Permalink
feat: Method calls with a by-proof (#5662)
Browse files Browse the repository at this point in the history
Fixes #5582. To enable this, we split method calls into two separate
calls on the boogie level.

```
TODAY:
  // check termination...
  call z := Call$$MyMethod(u);

AFTER SPLITTING Call$$, we have:
  // check termination...
  call Call_Pre$$MyMethod(u);
  call z := Call_Post$$MyMethod(u);

FOR CALL-BY, WE HAVE:
   // check termination...
   if (*) {
     // include the proof here
     LemmaAboutP();
     call Call_Pre$$MyMethod(u);
     assume false;
   }
   call z := Call_Post$$MyMethod(u);
```
  • Loading branch information
fabiomadge authored Aug 12, 2024
1 parent f5e5106 commit f7a0a55
Show file tree
Hide file tree
Showing 32 changed files with 438 additions and 158 deletions.
19 changes: 16 additions & 3 deletions Source/DafnyCore/AST/Grammar/Printer/Printer.Statement.cs
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,7 @@ public void PrintStatement(Statement stmt, int indent) {
wr.Write(" ");
}
PrintUpdateRHS(s, indent);
wr.Write(";");
PrintBy(s);

} else if (stmt is CallStmt) {
// Most calls are printed from their concrete syntax given in the input. However, recursive calls to
Expand Down Expand Up @@ -395,8 +395,7 @@ public void PrintStatement(Statement stmt, int indent) {
wr.Write(" ");
PrintUpdateRHS(s.Update, indent);
}
wr.Write(";");

PrintBy(s.Update);
} else if (stmt is VarDeclPattern) {
var s = (VarDeclPattern)stmt;
if (s.tok is AutoGeneratedToken) {
Expand Down Expand Up @@ -455,6 +454,20 @@ public void PrintStatement(Statement stmt, int indent) {
} else {
Contract.Assert(false); throw new cce.UnreachableException(); // unexpected statement
}

void PrintBy(ConcreteUpdateStatement statement) {
BlockStmt proof = statement switch {
UpdateStmt updateStmt => updateStmt.Proof,
AssignOrReturnStmt returnStmt => returnStmt.Proof,
_ => null
};
if (proof != null) {
wr.Write(" by ");
PrintStatement(proof, indent);
} else {
wr.Write(";");
}
}
}

private void PrintHideReveal(HideRevealStmt revealStmt) {
Expand Down
11 changes: 8 additions & 3 deletions Source/DafnyCore/AST/Statements/Assignment/AssignOrReturnStmt.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ public class AssignOrReturnStmt : ConcreteUpdateStatement, ICloneable<AssignOrRe
public readonly ExprRhs Rhs; // this is the unresolved RHS, and thus can also be a method call
public readonly List<AssignmentRhs> Rhss;
public readonly AttributedToken KeywordToken;
public readonly BlockStmt Proof;
[FilledInDuringResolution] public readonly List<Statement> ResolvedStatements = new List<Statement>();
public override IEnumerable<Statement> SubStatements => ResolvedStatements;
public override IToken Tok {
Expand All @@ -22,7 +23,7 @@ public override IToken Tok {
}
}

public override IEnumerable<INode> Children => ResolvedStatements;
public override IEnumerable<INode> Children => ResolvedStatements.Concat(Proof?.Children ?? new List<INode>());
public override IEnumerable<Statement> PreResolveSubStatements => Enumerable.Empty<Statement>();

[ContractInvariantMethod]
Expand All @@ -43,13 +44,14 @@ public AssignOrReturnStmt(Cloner cloner, AssignOrReturnStmt original) : base(clo
Rhs = (ExprRhs)cloner.CloneRHS(original.Rhs);
Rhss = original.Rhss.ConvertAll(cloner.CloneRHS);
KeywordToken = cloner.AttributedTok(original.KeywordToken);
Proof = cloner.CloneBlockStmt(original.Proof);

if (cloner.CloneResolvedFields) {
ResolvedStatements = original.ResolvedStatements.Select(stmt => cloner.CloneStmt(stmt, false)).ToList();
}
}

public AssignOrReturnStmt(RangeToken rangeToken, List<Expression> lhss, ExprRhs rhs, AttributedToken keywordToken, List<AssignmentRhs> rhss)
public AssignOrReturnStmt(RangeToken rangeToken, List<Expression> lhss, ExprRhs rhs, AttributedToken keywordToken, List<AssignmentRhs> rhss, BlockStmt proof = null)
: base(rangeToken, lhss) {
Contract.Requires(rangeToken != null);
Contract.Requires(lhss != null);
Expand All @@ -59,6 +61,7 @@ public AssignOrReturnStmt(RangeToken rangeToken, List<Expression> lhss, ExprRhs
Rhs = rhs;
Rhss = rhss;
KeywordToken = keywordToken;
Proof = proof;
}

public override IEnumerable<Expression> PreResolveSubExpressions {
Expand Down Expand Up @@ -190,6 +193,8 @@ public override void Resolve(ModuleResolver resolver, ResolutionContext resoluti
return;
}

ModuleResolver.ResolveByProof(resolver, Proof, resolutionContext);

Expression lhsExtract = null;
if (expectExtract) {
if (resolutionContext.CodeContext is Method caller && caller.Outs.Count == 0 && KeywordToken == null) {
Expand Down Expand Up @@ -285,7 +290,7 @@ private void DesugarElephantStatement(bool expectExtract, Expression lhsExtract,
}
}
// " temp, ... := MethodOrExpression, ...;"
UpdateStmt up = new UpdateStmt(RangeToken, lhss2, rhss2);
UpdateStmt up = new UpdateStmt(RangeToken, lhss2, rhss2, Proof);
if (expectExtract) {
up.OriginalInitialLhs = Lhss.Count == 0 ? null : Lhss[0];
}
Expand Down
15 changes: 11 additions & 4 deletions Source/DafnyCore/AST/Statements/Assignment/UpdateStmt.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ public class UpdateStmt : ConcreteUpdateStatement, ICloneable<UpdateStmt>, ICanR
public readonly List<AssignmentRhs> Rhss;
public readonly bool CanMutateKnownState;
public Expression OriginalInitialLhs = null;
public readonly BlockStmt Proof;

public override IToken Tok {
get {
Expand All @@ -22,7 +23,7 @@ public override IToken Tok {
}

[FilledInDuringResolution] public List<Statement> ResolvedStatements;
public override IEnumerable<Statement> SubStatements => Children.OfType<Statement>();
public override IEnumerable<Statement> SubStatements => Children.OfType<Statement>().Concat(Proof != null ? Proof.SubStatements : new List<Statement>());

public override IEnumerable<Expression> NonSpecificationSubExpressions =>
ResolvedStatements == null ? Rhss.SelectMany(r => r.NonSpecificationSubExpressions) : Enumerable.Empty<Expression>();
Expand All @@ -45,26 +46,29 @@ public UpdateStmt Clone(Cloner cloner) {
public UpdateStmt(Cloner cloner, UpdateStmt original) : base(cloner, original) {
Rhss = original.Rhss.Select(cloner.CloneRHS).ToList();
CanMutateKnownState = original.CanMutateKnownState;
Proof = cloner.CloneBlockStmt(original.Proof);
if (cloner.CloneResolvedFields) {
ResolvedStatements = original.ResolvedStatements.Select(stmt => cloner.CloneStmt(stmt, false)).ToList();
}
}

public UpdateStmt(RangeToken rangeToken, List<Expression> lhss, List<AssignmentRhs> rhss)
public UpdateStmt(RangeToken rangeToken, List<Expression> lhss, List<AssignmentRhs> rhss, BlockStmt proof = null)
: base(rangeToken, lhss) {
Contract.Requires(cce.NonNullElements(lhss));
Contract.Requires(cce.NonNullElements(rhss));
Contract.Requires(lhss.Count != 0 || rhss.Count == 1);
Rhss = rhss;
CanMutateKnownState = false;
Proof = proof;
}
public UpdateStmt(RangeToken rangeToken, List<Expression> lhss, List<AssignmentRhs> rhss, bool mutate)
public UpdateStmt(RangeToken rangeToken, List<Expression> lhss, List<AssignmentRhs> rhss, bool mutate, BlockStmt proof = null)
: base(rangeToken, lhss) {
Contract.Requires(cce.NonNullElements(lhss));
Contract.Requires(cce.NonNullElements(rhss));
Contract.Requires(lhss.Count != 0 || rhss.Count == 1);
Rhss = rhss;
CanMutateKnownState = mutate;
Proof = proof;
}

public override IEnumerable<Expression> PreResolveSubExpressions {
Expand Down Expand Up @@ -123,6 +127,9 @@ public override void Resolve(ModuleResolver resolver, ResolutionContext resoluti
resolver.ResolveAttributes(rhs, resolutionContext);
}

// resolve proof
ModuleResolver.ResolveByProof(resolver, Proof, resolutionContext);

// figure out what kind of UpdateStmt this is
if (firstEffectfulRhs == null) {
if (Lhss.Count == 0) {
Expand Down Expand Up @@ -178,7 +185,7 @@ public override void Resolve(ModuleResolver resolver, ResolutionContext resoluti
foreach (var ll in Lhss) {
resolvedLhss.Add(ll.Resolved);
}
CallStmt a = new CallStmt(RangeToken, resolvedLhss, methodCallInfo.Callee, methodCallInfo.ActualParameters, methodCallInfo.Tok);
CallStmt a = new CallStmt(RangeToken, resolvedLhss, methodCallInfo.Callee, methodCallInfo.ActualParameters, methodCallInfo.Tok, Proof);
a.OriginalInitialLhs = OriginalInitialLhs;
ResolvedStatements.Add(a);
}
Expand Down
9 changes: 6 additions & 3 deletions Source/DafnyCore/AST/Statements/Methods/CallStmt.cs
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,12 @@ void ObjectInvariant() {
public readonly ActualBindings Bindings;
public List<Expression> Args => Bindings.Arguments;
public Expression OriginalInitialLhs = null;
public readonly BlockStmt Proof;

public Expression Receiver { get { return MethodSelect.Obj; } }
public Method Method { get { return (Method)MethodSelect.Member; } }

public CallStmt(RangeToken rangeToken, List<Expression> lhs, MemberSelectExpr memSel, List<ActualBinding> args, IToken overrideToken = null)
public CallStmt(RangeToken rangeToken, List<Expression> lhs, MemberSelectExpr memSel, List<ActualBinding> args, IToken overrideToken = null, BlockStmt proof = null)
: base(rangeToken) {
Contract.Requires(rangeToken != null);
Contract.Requires(cce.NonNullElements(lhs));
Expand All @@ -41,6 +42,7 @@ public CallStmt(RangeToken rangeToken, List<Expression> lhs, MemberSelectExpr me
this.MethodSelect = memSel;
this.overrideToken = overrideToken;
this.Bindings = new ActualBindings(args);
Proof = proof;
}

public CallStmt Clone(Cloner cloner) {
Expand All @@ -52,14 +54,15 @@ public CallStmt(Cloner cloner, CallStmt original) : base(cloner, original) {
Lhs = original.Lhs.Select(cloner.CloneExpr).ToList();
Bindings = new ActualBindings(cloner, original.Bindings);
overrideToken = original.overrideToken;
Proof = cloner.CloneBlockStmt(original.Proof);
}

/// <summary>
/// This constructor is intended to be used when constructing a resolved CallStmt. The "args" are expected
/// to be already resolved, and are all given positionally.
/// </summary>
public CallStmt(RangeToken rangeToken, List<Expression> lhs, MemberSelectExpr memSel, List<Expression> args)
: this(rangeToken, lhs, memSel, args.ConvertAll(e => new ActualBinding(null, e))) {
public CallStmt(RangeToken rangeToken, List<Expression> lhs, MemberSelectExpr memSel, List<Expression> args, BlockStmt proof = null)
: this(rangeToken, lhs, memSel, args.ConvertAll(e => new ActualBinding(null, e)), proof: proof) {
Bindings.AcceptArgumentExpressionsAsExactParameterList();
}

Expand Down
11 changes: 1 addition & 10 deletions Source/DafnyCore/AST/Statements/Verification/AssertStmt.cs
Original file line number Diff line number Diff line change
Expand Up @@ -89,16 +89,7 @@ public override void GenResolve(INewOrOldResolver resolver, ResolutionContext co

base.GenResolve(resolver, context);

if (Proof != null) {
// clear the labels for the duration of checking the proof body, because break statements are not allowed to leave a the proof body
var prevLblStmts = resolver.EnclosingStatementLabels;
var prevLoopStack = resolver.LoopStack;
resolver.EnclosingStatementLabels = new Scope<Statement>(resolver.Options);
resolver.LoopStack = new List<Statement>();
resolver.ResolveStatement(Proof, context);
resolver.EnclosingStatementLabels = prevLblStmts;
resolver.LoopStack = prevLoopStack;
}
ModuleResolver.ResolveByProof(resolver, Proof, context);
}

public bool HasAssertOnlyAttribute(out AssertOnlyKind assertOnlyKind) {
Expand Down
30 changes: 22 additions & 8 deletions Source/DafnyCore/Dafny.atg
Original file line number Diff line number Diff line change
Expand Up @@ -2297,6 +2297,7 @@ UpdateStmt<out Statement/*!*/ s>
List<AssignmentRhs> rhss = new List<AssignmentRhs>();
Expression e;
AssignmentRhs r;
BlockStmt proof = null;
IToken x = Token.NoToken;
IToken endTok = Token.NoToken;
IToken startToken = Token.NoToken;
Expand Down Expand Up @@ -2334,10 +2335,16 @@ UpdateStmt<out Statement/*!*/ s>
{ "," Rhs<out r> (. rhss.Add(r); .)
}
)
";" (. endTok = t; .)
(
"by" BlockStmt<out proof, out _, out _>
| ";"
)
| ":" (. SemErr(ErrorId.p_invalid_colon, new RangeToken(startToken, t), "invalid statement beginning here (is a 'label' keyword missing? or a 'const' or 'var' keyword?)"); .)
| { Attribute<ref attrs> } (. endToken = t; .)
";" (. endTok = t; rhss.Add(new ExprRhs(e, attrs) { RangeToken = new RangeToken(e.StartToken, endToken) }); .)
(
"by" BlockStmt<out proof, out _, out _>
| ";"
) (. endTok = t; rhss.Add(new ExprRhs(e, attrs) { RangeToken = new RangeToken(e.StartToken, endToken) }); .)
| { Attribute<ref attrs> } (. endToken = t; .)
(. endTok = t; rhss.Add(new ExprRhs(e, attrs) { RangeToken = new RangeToken(e.StartToken, endToken) });
SemErr(ErrorId.p_missing_semicolon, new RangeToken(startToken, t), "missing semicolon at end of statement");
Expand All @@ -2352,18 +2359,21 @@ UpdateStmt<out Statement/*!*/ s>
{ Attribute<ref attrs> } (. exceptionRhs = new ExprRhs(exceptionExpr, attrs); .)
{ "," Rhs<out r> (. rhss.Add(r); .)
}
";" (. endTok = t; .)
(
"by" BlockStmt<out proof, out _, out _>
| ";" (. endTok = t; .)
)
)
(. var rangeToken = new RangeToken(startToken, t);
if (suchThat != null) {
s = new AssignSuchThatStmt(rangeToken, lhss, suchThat, suchThatAssume, null);
} else if (exceptionRhs != null) {
s = new AssignOrReturnStmt(rangeToken, lhss, exceptionRhs, keywordToken, rhss);
s = new AssignOrReturnStmt(rangeToken, lhss, exceptionRhs, keywordToken, rhss, proof);
} else {
if (lhss.Count == 0 && rhss.Count == 0) {
s = new BlockStmt(rangeToken, new List<Statement>()); // error, give empty statement
} else {
s = new UpdateStmt(rangeToken, lhss, rhss);
s = new UpdateStmt(rangeToken, lhss, rhss, proof);
}
}
.)
Expand Down Expand Up @@ -2463,6 +2473,7 @@ VarDeclStatement<.out Statement/*!*/ s.>
ExprRhs exceptionRhs = null;
Attributes attrs = null;
Attributes tokenAttrs = null;
BlockStmt proof = null;
IToken endTok;
IToken startToken = null;
s = dummyStmt;
Expand Down Expand Up @@ -2499,7 +2510,10 @@ VarDeclStatement<.out Statement/*!*/ s.>
{ "," Rhs<out r> (. rhss.Add(r); .)
}
]
SYNC ";" (. endTok = t; .)
SYNC ( "by"
BlockStmt<out proof, out _, out _>
| ";" (. endTok = t; .)
)
(. ConcreteUpdateStatement update;
var lhsExprs = new List<Expression>();
if (isGhost || (rhss.Count == 0 && exceptionRhs == null && suchThat == null)) { // explicitly ghost or no init
Expand All @@ -2516,11 +2530,11 @@ VarDeclStatement<.out Statement/*!*/ s.>
if (suchThat != null) {
update = new AssignSuchThatStmt(updateRangeToken, lhsExprs, suchThat, suchThatAssume, attrs);
} else if (exceptionRhs != null) {
update = new AssignOrReturnStmt(updateRangeToken, lhsExprs, exceptionRhs, keywordToken, rhss);
update = new AssignOrReturnStmt(updateRangeToken, lhsExprs, exceptionRhs, keywordToken, rhss, proof);
} else if (rhss.Count == 0) {
update = null;
} else {
update = new UpdateStmt(updateRangeToken, lhsExprs, rhss);
update = new UpdateStmt(updateRangeToken, lhsExprs, rhss, proof);
}
s = new VarDeclStmt(rangeToken, lhss, update);
.)
Expand Down
3 changes: 3 additions & 0 deletions Source/DafnyCore/Resolver/GhostInterestVisitor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,9 @@ public void Visit(Statement stmt, bool mustBeErasable, [CanBeNull] string proofC
var s = (UpdateStmt)stmt;
s.ResolvedStatements.ForEach(ss => Visit(ss, mustBeErasable, proofContext));
s.IsGhost = s.ResolvedStatements.All(ss => ss.IsGhost);
if (s.Proof != null) {
Visit(s.Proof, true, "a call-by body");
}

} else if (stmt is AssignOrReturnStmt) {
var s = (AssignOrReturnStmt)stmt;
Expand Down
15 changes: 15 additions & 0 deletions Source/DafnyCore/Resolver/ModuleResolver.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3311,6 +3311,21 @@ internal LetExpr LetVarIn(IToken tok, string name, Type tp, Expression rhs, Expr
return LetPatIn(tok, lhs, rhs, body);
}

internal static void ResolveByProof(INewOrOldResolver resolver, BlockStmt proof, ResolutionContext resolutionContext) {
if (proof == null) {
return;
}

// clear the labels for the duration of checking the proof body, because break statements are not allowed to leave the proof body
var prevLblStmts = resolver.EnclosingStatementLabels;
var prevLoopStack = resolver.LoopStack;
resolver.EnclosingStatementLabels = new Scope<Statement>(resolver.Options);
resolver.LoopStack = new List<Statement>();
resolver.ResolveStatement(proof, resolutionContext);
resolver.EnclosingStatementLabels = prevLblStmts;
resolver.LoopStack = prevLoopStack;
}

/// <summary>
/// If expr.Lhs != null: Desugars "var x: T :- E; F" into "var temp := E; if temp.IsFailure() then temp.PropagateFailure() else var x: T := temp.Extract(); F"
/// If expr.Lhs == null: Desugars " :- E; F" into "var temp := E; if temp.IsFailure() then temp.PropagateFailure() else F"
Expand Down
Loading

0 comments on commit f7a0a55

Please sign in to comment.