Skip to content

Commit

Permalink
Merge pull request scala#10425 from RustedBones/equals-converters
Browse files Browse the repository at this point in the history
Define equality between converter wrappers
  • Loading branch information
lrytz authored Jun 20, 2023
2 parents 0f5746f + 0b0478a commit 7db02fe
Show file tree
Hide file tree
Showing 2 changed files with 165 additions and 47 deletions.
46 changes: 38 additions & 8 deletions src/library/scala/collection/convert/JavaCollectionWrappers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -34,18 +34,33 @@ private[collection] object JavaCollectionWrappers extends Serializable {
def hasMoreElements = underlying.hasNext
def nextElement() = underlying.next()
override def remove() = throw new UnsupportedOperationException
override def equals(other: Any): Boolean = other match {
case that: IteratorWrapper[_] => this.underlying == that.underlying
case _ => false
}
override def hashCode: Int = underlying.hashCode()
}

@SerialVersionUID(3L)
class JIteratorWrapper[A](val underlying: ju.Iterator[A]) extends AbstractIterator[A] with Iterator[A] with Serializable {
def hasNext = underlying.hasNext
def next() = underlying.next
override def equals(other: Any): Boolean = other match {
case that: JIteratorWrapper[_] => this.underlying == that.underlying
case _ => false
}
override def hashCode: Int = underlying.hashCode()
}

@SerialVersionUID(3L)
class JEnumerationWrapper[A](val underlying: ju.Enumeration[A]) extends AbstractIterator[A] with Iterator[A] with Serializable {
def hasNext = underlying.hasMoreElements
def next() = underlying.nextElement
override def equals(other: Any): Boolean = other match {
case that: JEnumerationWrapper[_] => this.underlying == that.underlying
case _ => false
}
override def hashCode: Int = underlying.hashCode()
}

trait IterableWrapperTrait[A] extends ju.AbstractCollection[A] {
Expand All @@ -57,13 +72,11 @@ private[collection] object JavaCollectionWrappers extends Serializable {

@SerialVersionUID(3L)
class IterableWrapper[A](val underlying: Iterable[A]) extends ju.AbstractCollection[A] with IterableWrapperTrait[A] with Serializable {
import scala.runtime.Statics._
override def equals(other: Any): Boolean =
other match {
case other: IterableWrapper[_] => underlying.equals(other.underlying)
case _ => false
}
override def hashCode = finalizeHash(mix(mix(0xcafebabe, "IterableWrapper".hashCode), anyHash(underlying)), 1)
override def equals(other: Any): Boolean = other match {
case that: IterableWrapper[_] => this.underlying == that.underlying
case _ => false
}
override def hashCode: Int = underlying.hashCode()
}

@SerialVersionUID(3L)
Expand All @@ -74,6 +87,11 @@ private[collection] object JavaCollectionWrappers extends Serializable {
def iterator = underlying.iterator.asScala
override def iterableFactory = mutable.ArrayBuffer
override def isEmpty: Boolean = !underlying.iterator().hasNext
override def equals(other: Any): Boolean = other match {
case that: JIterableWrapper[_] => this.underlying == that.underlying
case _ => false
}
override def hashCode: Int = underlying.hashCode()
}

@SerialVersionUID(3L)
Expand All @@ -86,6 +104,11 @@ private[collection] object JavaCollectionWrappers extends Serializable {
override def knownSize: Int = if (underlying.isEmpty) 0 else super.knownSize
override def isEmpty = underlying.isEmpty
override def iterableFactory = mutable.ArrayBuffer
override def equals(other: Any): Boolean = other match {
case that: JCollectionWrapper[_] => this.underlying == that.underlying
case _ => false
}
override def hashCode: Int = underlying.hashCode()
}

@SerialVersionUID(3L)
Expand Down Expand Up @@ -254,7 +277,7 @@ private[collection] object JavaCollectionWrappers extends Serializable {
def getKey = k
def getValue = v
def setValue(v1 : V) = self.put(k, v1)

// It's important that this implementation conform to the contract
// specified in the javadocs of java.util.Map.Entry.hashCode
//
Expand Down Expand Up @@ -529,6 +552,13 @@ private[collection] object JavaCollectionWrappers extends Serializable {
} catch {
case ex: ClassCastException => null.asInstanceOf[V]
}

override def equals(other: Any): Boolean = other match {
case that: DictionaryWrapper[_, _] => this.underlying == that.underlying
case _ => false
}

override def hashCode: Int = underlying.hashCode()
}

@SerialVersionUID(3L)
Expand Down
166 changes: 127 additions & 39 deletions test/junit/scala/collection/convert/EqualsTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,23 @@ package scala.collection.convert
import org.junit.Test
import org.junit.Assert._

import java.util.{
AbstractList,
AbstractMap,
AbstractSet,
Collections,
Collection => JCollection,
HashSet => JHashSet,
List => JList,
Map => JMap,
Set => JSet
}
import java.lang.{Iterable => JIterable}
import java.util.concurrent.{ConcurrentHashMap => JCMap}
import scala.collection.{AbstractIterable, concurrent, mutable}
import scala.jdk.CollectionConverters._
import JavaCollectionWrappers._

import java.util.{AbstractList, AbstractSet, List => JList, Set => JSet}

class JTestList(vs: Int*) extends AbstractList[Int] {
def this() = this(Nil: _*)
override def size = vs.size
Expand All @@ -21,58 +33,134 @@ class JTestSet(vs: Int*) extends AbstractSet[Int] {
override def iterator = vs.iterator.asJava
}

object JTestMap {
case class JTestMapEntry(key: Int, value: String) extends JMap.Entry[Int, String] {
override def getKey: Int = key
override def getValue: String = value
override def setValue(value: String): String =
throw new UnsupportedOperationException("Cannot set value on JTestMapEntry")
}
}

class JTestMap(vs: (Int, String)*) extends AbstractMap[Int, String] {
import JTestMap._
override def entrySet(): JSet[JMap.Entry[Int, String]] = {
val entrySet = new JHashSet[JMap.Entry[Int, String]](vs.size);
vs.foreach { case (k, v) => entrySet.add(JTestMapEntry(k, v)) }
entrySet
}
}

/** Test that collection wrappers forward equals and hashCode where appropriate. */
class EqualsTest {

def jlstOf(vs: Int*): JList[Int] = new JTestList(vs: _*)
def jsetOf(vs: Int*): JSet[Int] = new JTestSet(vs: _*)

// Seq extending AbstractList inherits equals
def jListOf(vs: Int*): JList[Int] = new JTestList(vs: _*)
def jSetOf(vs: Int*): JSet[Int] = new JTestSet(vs: _*)
def jMapOf(vs: (Int, String)*): JMap[Int, String] = new JTestMap(vs: _*)

@Test def `List as JList has equals`: Unit = {
val list = List(1, 2, 3)
val jlst = new SeqWrapper(list)
assertEquals(jlstOf(1, 2, 3), jlst)
assertEquals(jlst, jlstOf(1, 2, 3))
assertTrue(jlst == jlstOf(1, 2, 3))
assertEquals(jlst.hashCode, jlst.hashCode)
// SeqWrapper extending util.AbstractList inherits equals
@Test def `Seq as JList has equals`: Unit = {
def seq = Seq(1, 2, 3)
def jList = new SeqWrapper(seq)
assertEquals(jList, jList)
assertEquals(jListOf(1, 2, 3), jList)
assertEquals(jList, jListOf(1, 2, 3))
assertTrue(jList == jListOf(1, 2, 3))
assertEquals(jList.hashCode, jList.hashCode)
}

// SetWrapper extending util.AbstractSet inherits equals
@Test def `Set as JSet has equals`: Unit = {
val set = Set(1, 2, 3)
val jset = new SetWrapper(set)
assertEquals(jsetOf(1, 2, 3), jset)
assertEquals(jset, jsetOf(1, 2, 3))
assertTrue(jset == jsetOf(1, 2, 3))
assertEquals(jset.hashCode, jset.hashCode)
def set = Set(1, 2, 3)
def jSet = new SetWrapper(set)
assertEquals(jSet, jSet)
assertEquals(jSetOf(1, 2, 3), jSet)
assertEquals(jSet, jSetOf(1, 2, 3))
assertTrue(jSet == jSetOf(1, 2, 3))
assertEquals(jSet.hashCode, jSet.hashCode)
}

// MapWrapper extending util.AbstractMap inherits equals
@Test def `Map as JMap has equals`: Unit = {
val map = Map(1 -> "one", 2 -> "two", 3 -> "three")
val jmap = new MapWrapper(map)
assertEquals(jmap, jmap)
def map = Map(1 -> "one", 2 -> "two", 3 -> "three")
def jMap = new MapWrapper(map)
assertEquals(jMap, jMap)
assertEquals(jMapOf(1 -> "one", 2 -> "two", 3 -> "three"), jMap)
assertEquals(jMap, jMapOf(1 -> "one", 2 -> "two", 3 -> "three"))
assertTrue(jMap == jMapOf(1 -> "one", 2 -> "two", 3 -> "three"))
assertEquals(jMap.hashCode, jMap.hashCode)
}

@Test def `Anything as Collection is equal to Anything`: Unit = {
def set = Set(1, 2, 3)
def jset = new IterableWrapper(set)
assertTrue(jset == jset)
assertEquals(jset, jset)
assertNotEquals(jset, set)
assertEquals(jset.hashCode, jset.hashCode)
@Test def `Iterable as JIterable does not compare equal`: Unit = {
// scala iterable without element equality defined
def iterable: Iterable[Int] = new AbstractIterable[Int] {
override def iterator: Iterator[Int] = Iterator(1, 2, 3)
}
def jIterable = new IterableWrapper(iterable)
assertNotEquals(jIterable, jIterable)
assertNotEquals(jIterable.hashCode, jIterable.hashCode)
}

@Test def `Iterator wrapper does not compare equal`: Unit = {
def it = List(1, 2, 3).iterator
def jit = new IteratorWrapper(it)
assertNotEquals(jit, jit)
assertNotEquals(jit.hashCode, jit.hashCode)
@Test def `Iterator as JIterator does not compare equal`: Unit = {
def iterator = Iterator(1, 2, 3)
def jIterator = new IteratorWrapper(iterator)
assertNotEquals(jIterator, jIterator)
assertNotEquals(jIterator.hashCode, jIterator.hashCode)
}

@Test def `Anything.asScala Iterable has case equals`: Unit = {
def vs = jlstOf(42, 27, 37)
def it = new JListWrapper(vs)
assertEquals(it, it)
assertEquals(it.hashCode, it.hashCode)
@Test def `All wrapper compare equal if underlying is equal`(): Unit = {
val jList = Collections.emptyList[String]()
assertEquals(jList.asScala, jList.asScala)

val jIterator = jList.iterator()
assertEquals(jIterator.asScala, jIterator.asScala)

val jEnumeration = Collections.emptyEnumeration[String]()
assertEquals(jEnumeration.asScala, jEnumeration.asScala)

val jIterable = jList.asInstanceOf[JIterable[String]]
assertEquals(jIterable.asScala, jIterable.asScala)

val jCollection = jList.asInstanceOf[JCollection[String]]
assertEquals(jCollection.asScala, jCollection.asScala)

val jSet = Collections.emptySet[String]()
assertEquals(jSet.asScala, jSet.asScala)

val jMap = Collections.emptyMap[String, String]()
assertEquals(jMap.asScala, jMap.asScala)

val jCMap = new JCMap[String, String]()
assertEquals(jCMap.asScala, jCMap.asScala)

val iterator = Iterator.empty[String]
assertEquals(iterator.asJava, iterator.asJava)
assertEquals(iterator.asJavaEnumeration, iterator.asJavaEnumeration)

val iterable = Iterable.empty[String]
assertEquals(iterable.asJava, iterable.asJava)
assertEquals(iterable.asJavaCollection, iterable.asJavaCollection)

val buffer = mutable.Buffer.empty[String]
assertEquals(buffer.asJava, buffer.asJava)

val seq = mutable.Seq.empty[String]
assertEquals(seq.asJava, seq.asJava)

val mutableSet = mutable.Set.empty[String]
assertEquals(mutableSet.asJava, mutableSet.asJava)

val set = Set.empty[String]
assertEquals(set.asJava, set.asJava)

val mutableMap = mutable.Map.empty[String, String]
assertEquals(mutableMap.asJava, mutableMap.asJava)
assertEquals(mutableMap.asJavaDictionary, mutableMap.asJavaDictionary)

val map = Map.empty[String, String]
assertEquals(map.asJava, map.asJava)

val concurrentMap = concurrent.TrieMap.empty[String, String]
assertEquals(concurrentMap.asJava, concurrentMap.asJava)
}
}

0 comments on commit 7db02fe

Please sign in to comment.