Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Method calls with a by-proof #5662

Merged
merged 21 commits into from
Aug 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 16 additions & 3 deletions Source/DafnyCore/AST/Grammar/Printer/Printer.Statement.cs
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,7 @@ public void PrintStatement(Statement stmt, int indent) {
wr.Write(" ");
}
PrintUpdateRHS(s, indent);
wr.Write(";");
PrintBy(s);

} else if (stmt is CallStmt) {
// Most calls are printed from their concrete syntax given in the input. However, recursive calls to
Expand Down Expand Up @@ -395,8 +395,7 @@ public void PrintStatement(Statement stmt, int indent) {
wr.Write(" ");
PrintUpdateRHS(s.Update, indent);
}
wr.Write(";");

PrintBy(s.Update);
} else if (stmt is VarDeclPattern) {
var s = (VarDeclPattern)stmt;
if (s.tok is AutoGeneratedToken) {
Expand Down Expand Up @@ -455,6 +454,20 @@ public void PrintStatement(Statement stmt, int indent) {
} else {
Contract.Assert(false); throw new cce.UnreachableException(); // unexpected statement
}

void PrintBy(ConcreteUpdateStatement statement) {
BlockStmt proof = statement switch {
UpdateStmt updateStmt => updateStmt.Proof,
AssignOrReturnStmt returnStmt => returnStmt.Proof,
_ => null
};
if (proof != null) {
wr.Write(" by ");
PrintStatement(proof, indent);
} else {
wr.Write(";");
}
}
}

private void PrintHideReveal(HideRevealStmt revealStmt) {
Expand Down
11 changes: 8 additions & 3 deletions Source/DafnyCore/AST/Statements/Assignment/AssignOrReturnStmt.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ public class AssignOrReturnStmt : ConcreteUpdateStatement, ICloneable<AssignOrRe
public readonly ExprRhs Rhs; // this is the unresolved RHS, and thus can also be a method call
public readonly List<AssignmentRhs> Rhss;
public readonly AttributedToken KeywordToken;
public readonly BlockStmt Proof;
keyboardDrummer marked this conversation as resolved.
Show resolved Hide resolved
[FilledInDuringResolution] public readonly List<Statement> ResolvedStatements = new List<Statement>();
public override IEnumerable<Statement> SubStatements => ResolvedStatements;
public override IToken Tok {
Expand All @@ -22,7 +23,7 @@ public override IToken Tok {
}
}

public override IEnumerable<INode> Children => ResolvedStatements;
public override IEnumerable<INode> Children => ResolvedStatements.Concat(Proof?.Children ?? new List<INode>());
public override IEnumerable<Statement> PreResolveSubStatements => Enumerable.Empty<Statement>();

[ContractInvariantMethod]
Expand All @@ -43,13 +44,14 @@ public AssignOrReturnStmt(Cloner cloner, AssignOrReturnStmt original) : base(clo
Rhs = (ExprRhs)cloner.CloneRHS(original.Rhs);
Rhss = original.Rhss.ConvertAll(cloner.CloneRHS);
KeywordToken = cloner.AttributedTok(original.KeywordToken);
Proof = cloner.CloneBlockStmt(original.Proof);

if (cloner.CloneResolvedFields) {
ResolvedStatements = original.ResolvedStatements.Select(stmt => cloner.CloneStmt(stmt, false)).ToList();
}
}

public AssignOrReturnStmt(RangeToken rangeToken, List<Expression> lhss, ExprRhs rhs, AttributedToken keywordToken, List<AssignmentRhs> rhss)
public AssignOrReturnStmt(RangeToken rangeToken, List<Expression> lhss, ExprRhs rhs, AttributedToken keywordToken, List<AssignmentRhs> rhss, BlockStmt proof = null)
: base(rangeToken, lhss) {
Contract.Requires(rangeToken != null);
Contract.Requires(lhss != null);
Expand All @@ -59,6 +61,7 @@ public AssignOrReturnStmt(RangeToken rangeToken, List<Expression> lhss, ExprRhs
Rhs = rhs;
Rhss = rhss;
KeywordToken = keywordToken;
Proof = proof;
}

public override IEnumerable<Expression> PreResolveSubExpressions {
Expand Down Expand Up @@ -190,6 +193,8 @@ public override void Resolve(ModuleResolver resolver, ResolutionContext resoluti
return;
}

ModuleResolver.ResolveByProof(resolver, Proof, resolutionContext);

Expression lhsExtract = null;
if (expectExtract) {
if (resolutionContext.CodeContext is Method caller && caller.Outs.Count == 0 && KeywordToken == null) {
Expand Down Expand Up @@ -285,7 +290,7 @@ private void DesugarElephantStatement(bool expectExtract, Expression lhsExtract,
}
}
// " temp, ... := MethodOrExpression, ...;"
UpdateStmt up = new UpdateStmt(RangeToken, lhss2, rhss2);
UpdateStmt up = new UpdateStmt(RangeToken, lhss2, rhss2, Proof);
if (expectExtract) {
up.OriginalInitialLhs = Lhss.Count == 0 ? null : Lhss[0];
}
Expand Down
15 changes: 11 additions & 4 deletions Source/DafnyCore/AST/Statements/Assignment/UpdateStmt.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ public class UpdateStmt : ConcreteUpdateStatement, ICloneable<UpdateStmt>, ICanR
public readonly List<AssignmentRhs> Rhss;
public readonly bool CanMutateKnownState;
public Expression OriginalInitialLhs = null;
public readonly BlockStmt Proof;

public override IToken Tok {
get {
Expand All @@ -22,7 +23,7 @@ public override IToken Tok {
}

[FilledInDuringResolution] public List<Statement> ResolvedStatements;
public override IEnumerable<Statement> SubStatements => Children.OfType<Statement>();
public override IEnumerable<Statement> SubStatements => Children.OfType<Statement>().Concat(Proof != null ? Proof.SubStatements : new List<Statement>());

public override IEnumerable<Expression> NonSpecificationSubExpressions =>
ResolvedStatements == null ? Rhss.SelectMany(r => r.NonSpecificationSubExpressions) : Enumerable.Empty<Expression>();
Expand All @@ -45,26 +46,29 @@ public UpdateStmt Clone(Cloner cloner) {
public UpdateStmt(Cloner cloner, UpdateStmt original) : base(cloner, original) {
Rhss = original.Rhss.Select(cloner.CloneRHS).ToList();
CanMutateKnownState = original.CanMutateKnownState;
Proof = cloner.CloneBlockStmt(original.Proof);
if (cloner.CloneResolvedFields) {
ResolvedStatements = original.ResolvedStatements.Select(stmt => cloner.CloneStmt(stmt, false)).ToList();
}
}

public UpdateStmt(RangeToken rangeToken, List<Expression> lhss, List<AssignmentRhs> rhss)
public UpdateStmt(RangeToken rangeToken, List<Expression> lhss, List<AssignmentRhs> rhss, BlockStmt proof = null)
: base(rangeToken, lhss) {
Contract.Requires(cce.NonNullElements(lhss));
Contract.Requires(cce.NonNullElements(rhss));
Contract.Requires(lhss.Count != 0 || rhss.Count == 1);
Rhss = rhss;
CanMutateKnownState = false;
Proof = proof;
}
public UpdateStmt(RangeToken rangeToken, List<Expression> lhss, List<AssignmentRhs> rhss, bool mutate)
public UpdateStmt(RangeToken rangeToken, List<Expression> lhss, List<AssignmentRhs> rhss, bool mutate, BlockStmt proof = null)
: base(rangeToken, lhss) {
Contract.Requires(cce.NonNullElements(lhss));
Contract.Requires(cce.NonNullElements(rhss));
Contract.Requires(lhss.Count != 0 || rhss.Count == 1);
Rhss = rhss;
CanMutateKnownState = mutate;
Proof = proof;
}

public override IEnumerable<Expression> PreResolveSubExpressions {
Expand Down Expand Up @@ -123,6 +127,9 @@ public override void Resolve(ModuleResolver resolver, ResolutionContext resoluti
resolver.ResolveAttributes(rhs, resolutionContext);
}

// resolve proof
keyboardDrummer marked this conversation as resolved.
Show resolved Hide resolved
ModuleResolver.ResolveByProof(resolver, Proof, resolutionContext);

// figure out what kind of UpdateStmt this is
if (firstEffectfulRhs == null) {
if (Lhss.Count == 0) {
Expand Down Expand Up @@ -178,7 +185,7 @@ public override void Resolve(ModuleResolver resolver, ResolutionContext resoluti
foreach (var ll in Lhss) {
resolvedLhss.Add(ll.Resolved);
}
CallStmt a = new CallStmt(RangeToken, resolvedLhss, methodCallInfo.Callee, methodCallInfo.ActualParameters, methodCallInfo.Tok);
CallStmt a = new CallStmt(RangeToken, resolvedLhss, methodCallInfo.Callee, methodCallInfo.ActualParameters, methodCallInfo.Tok, Proof);
a.OriginalInitialLhs = OriginalInitialLhs;
ResolvedStatements.Add(a);
}
Expand Down
9 changes: 6 additions & 3 deletions Source/DafnyCore/AST/Statements/Methods/CallStmt.cs
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,12 @@ void ObjectInvariant() {
public readonly ActualBindings Bindings;
public List<Expression> Args => Bindings.Arguments;
public Expression OriginalInitialLhs = null;
public readonly BlockStmt Proof;
keyboardDrummer marked this conversation as resolved.
Show resolved Hide resolved

public Expression Receiver { get { return MethodSelect.Obj; } }
public Method Method { get { return (Method)MethodSelect.Member; } }

public CallStmt(RangeToken rangeToken, List<Expression> lhs, MemberSelectExpr memSel, List<ActualBinding> args, IToken overrideToken = null)
public CallStmt(RangeToken rangeToken, List<Expression> lhs, MemberSelectExpr memSel, List<ActualBinding> args, IToken overrideToken = null, BlockStmt proof = null)
: base(rangeToken) {
Contract.Requires(rangeToken != null);
Contract.Requires(cce.NonNullElements(lhs));
Expand All @@ -41,6 +42,7 @@ public CallStmt(RangeToken rangeToken, List<Expression> lhs, MemberSelectExpr me
this.MethodSelect = memSel;
this.overrideToken = overrideToken;
this.Bindings = new ActualBindings(args);
Proof = proof;
}

public CallStmt Clone(Cloner cloner) {
Expand All @@ -52,14 +54,15 @@ public CallStmt(Cloner cloner, CallStmt original) : base(cloner, original) {
Lhs = original.Lhs.Select(cloner.CloneExpr).ToList();
Bindings = new ActualBindings(cloner, original.Bindings);
overrideToken = original.overrideToken;
Proof = cloner.CloneBlockStmt(original.Proof);
}

/// <summary>
/// This constructor is intended to be used when constructing a resolved CallStmt. The "args" are expected
/// to be already resolved, and are all given positionally.
/// </summary>
public CallStmt(RangeToken rangeToken, List<Expression> lhs, MemberSelectExpr memSel, List<Expression> args)
: this(rangeToken, lhs, memSel, args.ConvertAll(e => new ActualBinding(null, e))) {
public CallStmt(RangeToken rangeToken, List<Expression> lhs, MemberSelectExpr memSel, List<Expression> args, BlockStmt proof = null)
: this(rangeToken, lhs, memSel, args.ConvertAll(e => new ActualBinding(null, e)), proof: proof) {
Bindings.AcceptArgumentExpressionsAsExactParameterList();
}

Expand Down
11 changes: 1 addition & 10 deletions Source/DafnyCore/AST/Statements/Verification/AssertStmt.cs
Original file line number Diff line number Diff line change
Expand Up @@ -89,16 +89,7 @@ public override void GenResolve(INewOrOldResolver resolver, ResolutionContext co

base.GenResolve(resolver, context);

if (Proof != null) {
// clear the labels for the duration of checking the proof body, because break statements are not allowed to leave a the proof body
var prevLblStmts = resolver.EnclosingStatementLabels;
var prevLoopStack = resolver.LoopStack;
resolver.EnclosingStatementLabels = new Scope<Statement>(resolver.Options);
resolver.LoopStack = new List<Statement>();
resolver.ResolveStatement(Proof, context);
resolver.EnclosingStatementLabels = prevLblStmts;
resolver.LoopStack = prevLoopStack;
}
ModuleResolver.ResolveByProof(resolver, Proof, context);
}

public bool HasAssertOnlyAttribute(out AssertOnlyKind assertOnlyKind) {
Expand Down
30 changes: 22 additions & 8 deletions Source/DafnyCore/Dafny.atg
Original file line number Diff line number Diff line change
Expand Up @@ -2297,6 +2297,7 @@ UpdateStmt<out Statement/*!*/ s>
List<AssignmentRhs> rhss = new List<AssignmentRhs>();
Expression e;
AssignmentRhs r;
BlockStmt proof = null;
IToken x = Token.NoToken;
IToken endTok = Token.NoToken;
IToken startToken = Token.NoToken;
Expand Down Expand Up @@ -2334,10 +2335,16 @@ UpdateStmt<out Statement/*!*/ s>
{ "," Rhs<out r> (. rhss.Add(r); .)
}
)
";" (. endTok = t; .)
(
"by" BlockStmt<out proof, out _, out _>
| ";"
)
| ":" (. SemErr(ErrorId.p_invalid_colon, new RangeToken(startToken, t), "invalid statement beginning here (is a 'label' keyword missing? or a 'const' or 'var' keyword?)"); .)
| { Attribute<ref attrs> } (. endToken = t; .)
";" (. endTok = t; rhss.Add(new ExprRhs(e, attrs) { RangeToken = new RangeToken(e.StartToken, endToken) }); .)
(
"by" BlockStmt<out proof, out _, out _>
| ";"
) (. endTok = t; rhss.Add(new ExprRhs(e, attrs) { RangeToken = new RangeToken(e.StartToken, endToken) }); .)
| { Attribute<ref attrs> } (. endToken = t; .)
(. endTok = t; rhss.Add(new ExprRhs(e, attrs) { RangeToken = new RangeToken(e.StartToken, endToken) });
SemErr(ErrorId.p_missing_semicolon, new RangeToken(startToken, t), "missing semicolon at end of statement");
Expand All @@ -2352,18 +2359,21 @@ UpdateStmt<out Statement/*!*/ s>
{ Attribute<ref attrs> } (. exceptionRhs = new ExprRhs(exceptionExpr, attrs); .)
{ "," Rhs<out r> (. rhss.Add(r); .)
}
";" (. endTok = t; .)
(
"by" BlockStmt<out proof, out _, out _>
| ";" (. endTok = t; .)
)
)
(. var rangeToken = new RangeToken(startToken, t);
if (suchThat != null) {
s = new AssignSuchThatStmt(rangeToken, lhss, suchThat, suchThatAssume, null);
} else if (exceptionRhs != null) {
s = new AssignOrReturnStmt(rangeToken, lhss, exceptionRhs, keywordToken, rhss);
s = new AssignOrReturnStmt(rangeToken, lhss, exceptionRhs, keywordToken, rhss, proof);
} else {
if (lhss.Count == 0 && rhss.Count == 0) {
s = new BlockStmt(rangeToken, new List<Statement>()); // error, give empty statement
} else {
s = new UpdateStmt(rangeToken, lhss, rhss);
s = new UpdateStmt(rangeToken, lhss, rhss, proof);
}
}
.)
Expand Down Expand Up @@ -2463,6 +2473,7 @@ VarDeclStatement<.out Statement/*!*/ s.>
ExprRhs exceptionRhs = null;
Attributes attrs = null;
Attributes tokenAttrs = null;
BlockStmt proof = null;
IToken endTok;
IToken startToken = null;
s = dummyStmt;
Expand Down Expand Up @@ -2499,7 +2510,10 @@ VarDeclStatement<.out Statement/*!*/ s.>
{ "," Rhs<out r> (. rhss.Add(r); .)
}
]
SYNC ";" (. endTok = t; .)
SYNC ( "by"
BlockStmt<out proof, out _, out _>
| ";" (. endTok = t; .)
)
(. ConcreteUpdateStatement update;
var lhsExprs = new List<Expression>();
if (isGhost || (rhss.Count == 0 && exceptionRhs == null && suchThat == null)) { // explicitly ghost or no init
Expand All @@ -2516,11 +2530,11 @@ VarDeclStatement<.out Statement/*!*/ s.>
if (suchThat != null) {
update = new AssignSuchThatStmt(updateRangeToken, lhsExprs, suchThat, suchThatAssume, attrs);
} else if (exceptionRhs != null) {
update = new AssignOrReturnStmt(updateRangeToken, lhsExprs, exceptionRhs, keywordToken, rhss);
update = new AssignOrReturnStmt(updateRangeToken, lhsExprs, exceptionRhs, keywordToken, rhss, proof);
} else if (rhss.Count == 0) {
update = null;
} else {
update = new UpdateStmt(updateRangeToken, lhsExprs, rhss);
update = new UpdateStmt(updateRangeToken, lhsExprs, rhss, proof);
}
s = new VarDeclStmt(rangeToken, lhss, update);
.)
Expand Down
3 changes: 3 additions & 0 deletions Source/DafnyCore/Resolver/GhostInterestVisitor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,9 @@ public void Visit(Statement stmt, bool mustBeErasable, [CanBeNull] string proofC
var s = (UpdateStmt)stmt;
s.ResolvedStatements.ForEach(ss => Visit(ss, mustBeErasable, proofContext));
s.IsGhost = s.ResolvedStatements.All(ss => ss.IsGhost);
if (s.Proof != null) {
Visit(s.Proof, true, "a call-by body");
}

} else if (stmt is AssignOrReturnStmt) {
var s = (AssignOrReturnStmt)stmt;
Expand Down
15 changes: 15 additions & 0 deletions Source/DafnyCore/Resolver/ModuleResolver.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3311,6 +3311,21 @@ internal LetExpr LetVarIn(IToken tok, string name, Type tp, Expression rhs, Expr
return LetPatIn(tok, lhs, rhs, body);
}

internal static void ResolveByProof(INewOrOldResolver resolver, BlockStmt proof, ResolutionContext resolutionContext) {
if (proof == null) {
return;
}

// clear the labels for the duration of checking the proof body, because break statements are not allowed to leave the proof body
var prevLblStmts = resolver.EnclosingStatementLabels;
var prevLoopStack = resolver.LoopStack;
resolver.EnclosingStatementLabels = new Scope<Statement>(resolver.Options);
resolver.LoopStack = new List<Statement>();
resolver.ResolveStatement(proof, resolutionContext);
resolver.EnclosingStatementLabels = prevLblStmts;
resolver.LoopStack = prevLoopStack;
}

/// <summary>
/// If expr.Lhs != null: Desugars "var x: T :- E; F" into "var temp := E; if temp.IsFailure() then temp.PropagateFailure() else var x: T := temp.Extract(); F"
/// If expr.Lhs == null: Desugars " :- E; F" into "var temp := E; if temp.IsFailure() then temp.PropagateFailure() else F"
Expand Down
Loading