Skip to content

Commit

Permalink
Change mockito format in BypassMergeSortShuffleWriterSuite
Browse files Browse the repository at this point in the history
  • Loading branch information
mccheah committed Jul 25, 2019
1 parent 56fa450 commit b8b7b8d
Showing 1 changed file with 44 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ import org.mockito.{Mock, MockitoAnnotations}
import org.mockito.Answers.RETURNS_SMART_NULLS
import org.mockito.ArgumentMatchers.{any, anyInt}
import org.mockito.Mockito._
import org.mockito.invocation.InvocationOnMock
import org.scalatest.BeforeAndAfterEach

import org.apache.spark._
Expand Down Expand Up @@ -60,66 +59,67 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte

override def beforeEach(): Unit = {
super.beforeEach()
MockitoAnnotations.initMocks(this)
tempDir = Utils.createTempDir()
outputFile = File.createTempFile("shuffle", null, tempDir)
taskMetrics = new TaskMetrics
MockitoAnnotations.initMocks(this)
shuffleHandle = new BypassMergeSortShuffleHandle[Int, Int](
shuffleId = 0,
numMaps = 2,
dependency = dependency
)
val memoryManager = new TestMemoryManager(conf)
val taskMemoryManager = new TaskMemoryManager(memoryManager, 0)
when(dependency.partitioner).thenReturn(new HashPartitioner(7))
when(dependency.serializer).thenReturn(new JavaSerializer(conf))
when(taskContext.taskMetrics()).thenReturn(taskMetrics)
when(blockResolver.getDataFile(0, 0)).thenReturn(outputFile)
doAnswer { (invocationOnMock: InvocationOnMock) =>
val tmp = invocationOnMock.getArguments()(3).asInstanceOf[File]
if (tmp != null) {
outputFile.delete
tmp.renameTo(outputFile)
when(blockManager.diskBlockManager).thenReturn(diskBlockManager)
when(taskContext.taskMemoryManager()).thenReturn(taskMemoryManager)

when(blockResolver.writeIndexFileAndCommit(
anyInt, anyInt, any(classOf[Array[Long]]), any(classOf[File])))
.thenAnswer { invocationOnMock =>
val tmp = invocationOnMock.getArguments()(3).asInstanceOf[File]
if (tmp != null) {
outputFile.delete
tmp.renameTo(outputFile)
}
null
}
null
}.when(blockResolver)
.writeIndexFileAndCommit(anyInt, anyInt, any(classOf[Array[Long]]), any(classOf[File]))

when(blockManager.diskBlockManager).thenReturn(diskBlockManager)
when(blockManager.getDiskWriter(
any[BlockId],
any[File],
any[SerializerInstance],
anyInt(),
any[ShuffleWriteMetrics]))
.thenAnswer { invocation =>
val args = invocation.getArguments
val manager = new SerializerManager(new JavaSerializer(conf), conf)
new DiskBlockObjectWriter(
args(1).asInstanceOf[File],
manager,
args(2).asInstanceOf[SerializerInstance],
args(3).asInstanceOf[Int],
syncWrites = false,
args(4).asInstanceOf[ShuffleWriteMetrics],
blockId = args(0).asInstanceOf[BlockId])
}

when(diskBlockManager.createTempShuffleBlock())
.thenAnswer { _ =>
val blockId = new TempShuffleBlockId(UUID.randomUUID)
val file = new File(tempDir, blockId.name)
blockIdToFileMap.put(blockId, file)
temporaryFilesCreated += file
(blockId, file)
}

doAnswer((invocation: InvocationOnMock) => {
val args = invocation.getArguments
val manager = new SerializerManager(new JavaSerializer(conf), conf)
new DiskBlockObjectWriter(
args(1).asInstanceOf[File],
manager,
args(2).asInstanceOf[SerializerInstance],
args(3).asInstanceOf[Int],
syncWrites = false,
args(4).asInstanceOf[ShuffleWriteMetrics],
blockId = args(0).asInstanceOf[BlockId]
)
}).when(blockManager)
.getDiskWriter(
any[BlockId],
any[File],
any[SerializerInstance],
anyInt(),
any[ShuffleWriteMetrics])

doAnswer((_: InvocationOnMock) => {
val blockId = new TempShuffleBlockId(UUID.randomUUID)
val file = new File(tempDir, blockId.name)
blockIdToFileMap.put(blockId, file)
temporaryFilesCreated += file
(blockId, file)
}).when(diskBlockManager).createTempShuffleBlock()

doAnswer((invocation: InvocationOnMock) => {
when(diskBlockManager.getFile(any[BlockId])).thenAnswer { invocation =>
blockIdToFileMap(invocation.getArguments.head.asInstanceOf[BlockId])
}).when(diskBlockManager).getFile(any[BlockId])
}

val memoryManager = new TestMemoryManager(conf)
val taskMemoryManager = new TaskMemoryManager(memoryManager, 0)
when(taskContext.taskMemoryManager()).thenReturn(taskMemoryManager)
shuffleExecutorComponents = new LocalDiskShuffleExecutorComponents(
conf, blockManager, blockResolver)
}
Expand Down

0 comments on commit b8b7b8d

Please sign in to comment.