Skip to content

Commit

Permalink
Merge branch 'main' into igor/ci4
Browse files Browse the repository at this point in the history
  • Loading branch information
coffeeinprogress authored Aug 17, 2024
2 parents 3190993 + f55f57c commit 307d2eb
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 1 deletion.
1 change: 1 addition & 0 deletions .unreleased/features/rewrite.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Handle expressions such as S \in SUBSET [ a : Int ] by rewriting the expression into \A r \in S: DOMAIN r = {"a"} /\ r.a \in Int
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,34 @@ class ExprOptimizer(nameGen: UniqueNameGenerator, tracker: TransformationTracker
}
apply(tla.and(domEq +: fieldsEq: _*).as(b))
}

// S ∈ SUBSET { ["a" ↦ x] : x ∈ T }
case memEx @ OperEx(TlaSetOper.in, setRec,
OperEx(TlaSetOper.powerset,
OperEx(TlaSetOper.map, OperEx(TlaFunOper.rec, fieldsAndValues @ _*), varsAndSets @ _*)))
if fieldsAndValues.length == varsAndSets.length =>
val (fields, values) = TlaOper.deinterleave(fieldsAndValues)
val (vars, sets) = TlaOper.deinterleave(varsAndSets)
assert(fields.length == vars.length)
if (values.zip(vars).exists(p => p._1 != p._2)) {
// The set has a more general form: { [f_1 |-> e_1, ..., f_k |-> e_k]: x_1 \in S_1, ..., x_k \in S_k }, where
// e_1, ..., e_k are expressions over x_1, ..., x_k.
// We do not know how to optimize it.
memEx
} else {
val strSetT = SetT1(StrT1)
val b = BoolT1

val domType = getElemType(setRec)
val r = tla.name(nameGen.newName()).as(domType)

val domEq = tla.eql(tla.dom(r).as(SetT1(domType)), tla.enumSet(fields: _*).as(strSetT)).as(b)

val fieldsEq = fields.zip(values.zip(sets)).map { case (key, (value, set)) =>
tla.in(tla.appFun(r, key).as(value.typeTag.asTlaType1()), set).as(b)
}
apply(tla.forall(r, setRec, tla.and(domEq +: fieldsEq: _*).as(b)).as(b))
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,22 +87,100 @@ class TestExprOptimizer extends AnyFunSuite with BeforeAndAfterEach {

// An optimization for set membership over sets of records. Note that this is the standard form produced by Keramelizer.
test("""r \in { [a |-> x, b |-> y]: x \in S, y \in T } becomes DOMAIN r = { "a", "b" } /\ r.a \in S /\ r.b \in T""") {
// ... [a |-> x, b |-> y] ...
val recT = RecT1("a" -> IntT1, "b" -> IntT1)
val recSetT = SetT1(recT)
// ... x \in S, y \in T ...
val record =
enumFun(str("a"), name("x").as(intT), str("b"), name("y").as(intT)).as(recT)
// ... S ...
val S = name("S").as(intSetT)
// ... T ...
val T = name("T").as(intSetT)
// { ... }
val recSetT = SetT1(recT)
val recordSet = map(record, name("x").as(intT), S, name("y").as(intT), T).as(recSetT)
// r ...
val r = name("r").as(recT)
// ... \in ...
val input = in(r, recordSet).as(boolT)

// ~~>

// DOMAIN r = { "a", "b" }
val strSetT = SetT1(StrT1)
val domEq = eql(dom(r).as(strSetT), enumSet(str("a"), str("b")).as(strSetT)).as(boolT)
// r.a \in S
val memA = in(appFun(r, str("a")).as(intT), S).as(boolT)
// r.b \in T
val memB = in(appFun(r, str("b")).as(intT), T).as(boolT)
// ... /\ ... /\ ...
val expected = and(domEq, memA, memB).as(boolT)
val output = optimizer.apply(input)

assert(expected == output)
}

// An optimization for set membership of the powerset of a record where the record has infinite co-domains.
test("""S \in SUBSET [a : T] ~~> \A r \in S: DOMAIN r = { "a" } /\ r.a \in T""") {

// ... { [a |-> x] : x \in T } ...
val recT = RecT1("a" -> IntT1)
val record =
enumFun(str("a"), name("x").as(intT)).as(recT)
val T = name("T").as(intSetT)
val recSetT = SetT1(recT)
val recordSet = map(record, name("x").as(intT), T).as(recSetT)

// ... SUBSET ...
val powSetT = powSet(recordSet).as(recSetT)

// S ...
val s = name("S").as(recSetT)

// ... \in ...
val input = in(s, powSetT).as(boolT)
val output = optimizer.apply(input)

// ~~>

// DOMAIN r = { "a" }
val r = name("t_1").as(recT)
val strSetT = SetT1(StrT1)
val domEq = eql(dom(r).as(strSetT), enumSet(str("a")).as(strSetT)).as(boolT)

// r.a \in T
val memA = in(appFun(r, str("a")).as(intT), T).as(boolT)

// ... /\ ...
val conjunct = and(domEq, memA).as(boolT)

// \A ...
val expected = forall(r, s, conjunct).as(boolT)

assert(expected == output)
}

test("""S \in SUBSET [a : T, b : U] ~~> \A r \in S: DOMAIN r = { "a", "b" } /\ r.a \in T /\ r.b \in U""") {
val recT = RecT1("a" -> IntT1, "b" -> IntT1)
val record =
enumFun(str("a"), name("x").as(intT), str("b"), name("y").as(intT)).as(recT)
val T = name("T").as(intSetT)
val U = name("U").as(intSetT)
val recSetT = SetT1(recT)
val recordSet = map(record, name("x").as(intT), T, name("y").as(intT), U).as(recSetT)
val powSetT = powSet(recordSet).as(recSetT)
val s = name("S").as(recSetT)
val input = in(s, powSetT).as(boolT)
val output = optimizer.apply(input)

val r = name("t_1").as(recT)
val strSetT = SetT1(StrT1)
val domEq = eql(dom(r).as(strSetT), enumSet(str("a"), str("b")).as(strSetT)).as(boolT)
val memA = in(appFun(r, str("a")).as(intT), T).as(boolT)
val memB = in(appFun(r, str("b")).as(intT), U).as(boolT)
val conjunct = and(domEq, memA, memB).as(boolT)
val expected = forall(r, s, conjunct).as(boolT)

assert(expected == output)
}

Expand Down

0 comments on commit 307d2eb

Please sign in to comment.