Skip to content

Commit

Permalink
Merge pull request scala#10352 from som-snytt/issue/12757-glblub-loop
Browse files Browse the repository at this point in the history
More tailrec in handling long seq
  • Loading branch information
lrytz authored Jun 20, 2023
2 parents 0842f23 + b9a4518 commit 0f5746f
Show file tree
Hide file tree
Showing 7 changed files with 125 additions and 46 deletions.
3 changes: 2 additions & 1 deletion src/partest/scala/tools/partest/nest/Runner.scala
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,8 @@ class Runner(val testInfo: TestInfo, val suiteRunner: AbstractRunner) {
// We'll let the checkfile diffing report this failure
Files.write(log.toPath, stackTraceString(t).getBytes(Charset.defaultCharset()), CREATE, APPEND)
case t: Throwable =>
Files.write(log.toPath, t.getMessage.getBytes(Charset.defaultCharset()), CREATE, APPEND)
val data = (if (t.getMessage != null) t.getMessage else t.getClass.getName).getBytes(Charset.defaultCharset())
Files.write(log.toPath, data, CREATE, APPEND)
throw t
}
}
Expand Down
3 changes: 2 additions & 1 deletion src/reflect/scala/reflect/internal/Printers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,8 @@ trait Printers extends api.Printers { self: SymbolTable =>
out.write(indentString, 0, indentMargin)
}

def printSeq[a](ls: List[a])(printelem: a => Unit)(printsep: => Unit): Unit =
@tailrec
final def printSeq[A](ls: List[A])(printelem: A => Unit)(printsep: => Unit): Unit =
ls match {
case List() =>
case List(x) => printelem(x)
Expand Down
82 changes: 52 additions & 30 deletions src/reflect/scala/reflect/internal/tpe/GlbLubs.scala
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ package reflect
package internal
package tpe

import scala.collection.mutable
import scala.collection.mutable, mutable.ListBuffer
import scala.annotation.tailrec
import Variance._

Expand Down Expand Up @@ -101,20 +101,20 @@ private[internal] trait GlbLubs {

def headOf(ix: Int) = baseTypeSeqs(ix).rawElem(ices(ix))

val pretypes: mutable.ListBuffer[Type] = mutable.ListBuffer.empty[Type]
val pretypes: ListBuffer[Type] = ListBuffer.empty[Type]

var isFinished = false
while (! isFinished && ices(0) < baseTypeSeqs(0).length){
while (!isFinished && ices(0) < baseTypeSeqs(0).length) {
lubListDepth = lubListDepth.incr
// Step 1: run through the List with these variables:
// 1) Is there any empty list? Are they equal or are we taking the smallest?
// isFinished: tsBts.exists(typeListIsEmpty)
// Is the frontier made up of types with the same symbol?
var isUniformFrontier = true
var isUniformFrontier = true
var sym = headOf(0).typeSymbol
// var tsYs = tsBts
var ix = 0
while (! isFinished && ix < baseTypeSeqs.length){
while (!isFinished && ix < baseTypeSeqs.length) {
if (ices(ix) == baseTypeSeqs(ix).length)
isFinished = true
else {
Expand All @@ -130,7 +130,7 @@ private[internal] trait GlbLubs {
// the invariant holds, i.e., the one that conveys most information regarding subtyping. Before
// merging, strip targs that refer to bound tparams (when we're computing the lub of type
// constructors.) Also filter out all types that are a subtype of some other type.
if (! isFinished){
if (!isFinished) {
// ts0 is the 1-dimensional frontier of symbols cutting through 2-dimensional tsBts.
// Invariant: all symbols "under" (closer to the first row) the frontier
// are smaller (according to _.isLess) than the ones "on and beyond" the frontier
Expand All @@ -145,7 +145,7 @@ private[internal] trait GlbLubs {
}

if (isUniformFrontier) {
val ts1 = elimSub(ts0, depth) map elimHigherOrderTypeParam
val ts1 = elimSub(ts0, depth).map(elimHigherOrderTypeParam)
mergePrefixAndArgs(ts1, Covariant, depth) match {
case NoType =>
case tp => pretypes += tp
Expand All @@ -165,11 +165,12 @@ private[internal] trait GlbLubs {
jx += 1
}
if (printLubs) {
val str = baseTypeSeqs.zipWithIndex.map({ case (tps, idx) =>
tps.toList.drop(ices(idx)).map(" " + _ + "\n").mkString(" (" + idx + ")\n", "", "\n")
}).mkString("")

println("Frontier(\n" + str + ")")
println {
baseTypeSeqs.zipWithIndex.map { case (tps, idx) =>
tps.toList.drop(ices(idx)).map(" " + _).mkString(" (" + idx + ")\n", "\n", "\n")
}
.mkString("Frontier(\n", "", ")")
}
printLubMatrixAux(lubListDepth)
}
}
Expand Down Expand Up @@ -198,36 +199,57 @@ private[internal] trait GlbLubs {

/** From a list of types, retain only maximal types as determined by the partial order `po`. */
private def maxTypes(ts: List[Type])(po: (Type, Type) => Boolean): List[Type] = {
def loop(ts: List[Type]): List[Type] = ts match {
def stacked(ts: List[Type]): List[Type] = ts match {
case t :: ts1 =>
val ts2 = loop(ts1.filterNot(po(_, t)))
val ts2 = stacked(ts1.filterNot(po(_, t)))
if (ts2.exists(po(t, _))) ts2 else t :: ts2
case Nil => Nil
}

// The order here matters because type variables and
// wildcards can act both as subtypes and supertypes.
val (ts2, ts1) = partitionConserve(ts) { tp =>
isWildCardOrNonGroundTypeVarCollector.collect(tp).isDefined
// loop thru tails, filtering for survivors of po test with the current element, which is saved for later culling
@tailrec
def loop(survivors: List[Type], toCull: List[Type]): List[Type] = survivors match {
case h :: rest =>
loop(rest.filterNot(po(_, h)), h :: toCull)
case _ =>
// unwind the stack of saved elements, accumulating a result containing elements surviving po (in swapped order)
def sieve(res: List[Type], remaining: List[Type]): List[Type] = remaining match {
case h :: tail =>
val res1 = if (res.exists(po(h, _))) res else h :: res
sieve(res1, tail)
case _ => res
}
toCull match {
case _ :: Nil => toCull
case _ => sieve(Nil, toCull)
}
}

loop(ts1 ::: ts2)
// The order here matters because type variables and wildcards can act both as subtypes and supertypes.
val sorted = {
val (wilds, ts1) = partitionConserve(ts)(isWildCardOrNonGroundTypeVarCollector.collect(_).isDefined)
ts1 ::: wilds
}
if (sorted.lengthCompare(5) > 0) loop(sorted, Nil)
else stacked(sorted)
}

/** Eliminate from list of types all elements which are a supertype
* of some other element of the list. */
* of some other element of the list. */
private def elimSuper(ts: List[Type]): List[Type] =
maxTypes(ts)((t1, t2) => t2 <:< t1)
if (ts.lengthCompare(1) <= 0) ts
else maxTypes(ts)((t1, t2) => t2 <:< t1)

/** Eliminate from list of types all elements which are a subtype
* of some other element of the list. */
@tailrec private def elimSub(ts: List[Type], depth: Depth): List[Type] = {
val ts1 = maxTypes(ts)(isSubType(_, _, depth.decr))
if (ts1.lengthCompare(1) <= 0) ts1 else {
val ts2 = ts1.mapConserve(t => elimAnonymousClass(t.dealiasWiden))
if (ts1 eq ts2) ts1 else elimSub(ts2, depth)
* of some other element of the list. */
@tailrec private def elimSub(ts: List[Type], depth: Depth): List[Type] =
if (ts.lengthCompare(1) <= 0) ts else {
val ts1 = maxTypes(ts)(isSubType(_, _, depth.decr))
if (ts1.lengthCompare(1) <= 0) ts1 else {
val ts2 = ts1.mapConserve(t => elimAnonymousClass(t.dealiasWiden))
if (ts1 eq ts2) ts1 else elimSub(ts2, depth)
}
}
}

/** Does this set of types have the same weak lub as
* it does regular lub? This is exposed so lub callers
Expand Down Expand Up @@ -491,7 +513,7 @@ private[internal] trait GlbLubs {
val (ts, tparams) = stripExistentialsAndTypeVars(ts0)
val glbOwner = commonOwner(ts)
val ts1 = {
val res = mutable.ListBuffer.empty[Type]
val res = ListBuffer.empty[Type]
def loop(ty: Type): Unit = ty match {
case RefinedType(ps, _) => ps.foreach(loop)
case _ => res += ty
Expand All @@ -508,7 +530,7 @@ private[internal] trait GlbLubs {
def glbsym(proto: Symbol): Symbol = {
val prototp = glbThisType.memberInfo(proto)
val symtypes: List[Type] = {
val res = mutable.ListBuffer.empty[Type]
val res = ListBuffer.empty[Type]
ts foreach { t =>
t.nonPrivateMember(proto.name).alternatives foreach { alt =>
val mi = glbThisType.memberInfo(alt)
Expand Down
32 changes: 18 additions & 14 deletions src/reflect/scala/reflect/internal/tpe/TypeConstraints.scala
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ private[internal] trait TypeConstraints {
import definitions._

/** A log of type variable with their original constraints. Used in order
* to undo constraints in the case of isSubType/isSameType failure.
*/
* to undo constraints in the case of isSubType/isSameType failure.
*/
private lazy val _undoLog = new UndoLog
def undoLog = _undoLog

Expand Down Expand Up @@ -54,9 +54,9 @@ private[internal] trait TypeConstraints {
}

/** No sync necessary, because record should only
* be called from within an undo or undoUnless block,
* which is already synchronized.
*/
* be called from within an undo or undoUnless block,
* which is already synchronized.
*/
private[reflect] def record(tv: TypeVar) = {
log ::= UndoPair(tv, tv.constr.cloneInternal)
}
Expand Down Expand Up @@ -96,11 +96,11 @@ private[internal] trait TypeConstraints {
*/

/** Guard these lists against AnyClass and NothingClass appearing,
* else loBounds.isEmpty will have different results for an empty
* constraint and one with Nothing as a lower bound. [Actually
* guarding addLoBound/addHiBound somehow broke raw types so it
* only guards against being created with them.]
*/
* else loBounds.isEmpty will have different results for an empty
* constraint and one with Nothing as a lower bound. [Actually
* guarding addLoBound/addHiBound somehow broke raw types so it
* only guards against being created with them.]
*/
private[this] var lobounds = lo0 filterNot (_.isNothing)
private[this] var hibounds = hi0 filterNot (_.isAny)
private[this] var numlo = numlo0
Expand All @@ -124,15 +124,21 @@ private[internal] trait TypeConstraints {
// See pos/t6367 and pos/t6499 for the competing test cases.
val mustConsider = tp.typeSymbol match {
case NothingClass => true
case _ => !(lobounds contains tp)
case _ => !lobounds.contains(tp)
}
if (mustConsider) {
def justTwoStrings: Boolean = (
tp.typeSymbol == StringClass && tp.isInstanceOf[ConstantType] &&
lobounds.lengthCompare(1) == 0 && lobounds.head.typeSymbol == StringClass
)
if (isNumericBound && isNumericValueType(tp)) {
if (numlo == NoType || isNumericSubType(numlo, tp))
numlo = tp
else if (!isNumericSubType(tp, numlo))
numlo = numericLoBound
}
else if (justTwoStrings)
lobounds = tp.widen :: Nil // don't accumulate strings; we know they are not exactly the same bc mustConsider
else lobounds ::= tp
}
}
Expand Down Expand Up @@ -222,7 +228,7 @@ private[internal] trait TypeConstraints {

@inline def toBound(hi: Boolean, tparam: Symbol) = if (hi) tparam.info.upperBound else tparam.info.lowerBound

def solveOne(tvar: TypeVar, isContravariant: Boolean): Unit = {
def solveOne(tvar: TypeVar, isContravariant: Boolean): Unit =
if (tvar.constr.inst == NoType) {
tvar.constr.inst = null // mark tvar as being solved

Expand Down Expand Up @@ -252,7 +258,6 @@ private[internal] trait TypeConstraints {
}
}


if (!(otherTypeVarBeingSolved || containsSymbol(bound, tparam))) {
val boundSym = bound.typeSymbol
if (up) {
Expand Down Expand Up @@ -284,7 +289,6 @@ private[internal] trait TypeConstraints {
// debuglog(s"$tvar setInst $newInst")
tvar setInst newInst
}
}

// println("solving "+tvars+"/"+tparams+"/"+(tparams map (_.info)))
foreachWithIndex(tvars)((tvar, i) => solveOne(tvar, areContravariant(i)))
Expand Down
16 changes: 16 additions & 0 deletions test/files/run/t12757.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@

import scala.tools.partest.DirectTest

object Test extends DirectTest {
def header = """|object Test extends App {
| val myStrings: List[String] = List(""".stripMargin.linesIterator
def footer = """| )
| println(myStrings.mkString(","))
|}""".stripMargin.linesIterator
def values = Iterator.tabulate(4000)(i => s" \"$i\",")
def code = (header ++ values ++ footer).mkString("\n")

override def extraSettings: String = "-usejavacp -J-Xms256k"

def show() = assert(compile())
}
16 changes: 16 additions & 0 deletions test/files/run/t12757b.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@

import scala.tools.partest.DirectTest

object Test extends DirectTest {
def header = """|object Test extends App {
| val myInts: List[Int] = List(""".stripMargin.linesIterator
def footer = """| )
| println(myInts.mkString(","))
|}""".stripMargin.linesIterator
def values = Iterator.tabulate(4000)(i => s" $i,")
def code = (header ++ values ++ footer).mkString("\n")

override def extraSettings: String = "-usejavacp -J-Xms256k"

def show() = assert(compile())
}
19 changes: 19 additions & 0 deletions test/files/run/t12757c.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@

import scala.tools.partest.DirectTest

object Test extends DirectTest {
def header = """|object Test extends App {
| val myStrings = List(
| 42,
| Test,
|""".stripMargin.linesIterator
def footer = """| )
| println(myStrings.mkString(","))
|}""".stripMargin.linesIterator
def values = Iterator.tabulate(4000)(i => s" \"$i\",")
def code = (header ++ values ++ footer).mkString("\n")

override def extraSettings: String = "-usejavacp -J-Xms256k"

def show() = assert(compile())
}

0 comments on commit 0f5746f

Please sign in to comment.