Skip to content

Commit

Permalink
Prepare zlib's inflater for Kotlin/Native (#1422)
Browse files Browse the repository at this point in the history
* Prepare zlib's inflater for Kotlin/Native

* Don't check the result of deflateEnd

It is different for different zlib versions. In particular,
it returns Z_DATA_ERROR if the stream is closed without being
used.
  • Loading branch information
squarejesse authored Feb 7, 2024
1 parent 062048a commit 1d5f262
Show file tree
Hide file tree
Showing 4 changed files with 335 additions and 3 deletions.
5 changes: 2 additions & 3 deletions okio/src/nativeMain/kotlin/okio/Deflater.kt
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ import platform.zlib.deflateEnd
import platform.zlib.deflateInit2
import platform.zlib.z_stream_s

private val emptyByteArray = byteArrayOf()
internal val emptyByteArray = byteArrayOf()

/**
* Deflate using Kotlin/Native's built-in zlib bindings. This uses the raw deflate format and omits
Expand Down Expand Up @@ -145,8 +145,7 @@ internal class Deflater : Closeable {
if (closed) return
closed = true

val deflateEndResult = deflateEnd(zStream.ptr)
check(deflateEndResult == Z_OK)
deflateEnd(zStream.ptr)
nativeHeap.free(zStream)
}
}
114 changes: 114 additions & 0 deletions okio/src/nativeMain/kotlin/okio/Inflater.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
/*
* Copyright (C) 2024 Square, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package okio

import kotlinx.cinterop.CPointer
import kotlinx.cinterop.UByteVar
import kotlinx.cinterop.addressOf
import kotlinx.cinterop.alloc
import kotlinx.cinterop.free
import kotlinx.cinterop.nativeHeap
import kotlinx.cinterop.ptr
import kotlinx.cinterop.usePinned
import platform.zlib.Z_BUF_ERROR
import platform.zlib.Z_DATA_ERROR
import platform.zlib.Z_NO_FLUSH
import platform.zlib.Z_OK
import platform.zlib.Z_STREAM_END
import platform.zlib.inflateEnd
import platform.zlib.inflateInit2
import platform.zlib.z_stream_s

/**
* Inflate using Kotlin/Native's built-in zlib bindings.
*
* The API is symmetric with [Deflater].
*/
internal class Inflater : Closeable {
private val zStream: z_stream_s = nativeHeap.alloc<z_stream_s> {
zalloc = null
zfree = null
opaque = null
check(
inflateInit2(
strm = ptr,
windowBits = -15, // Default value for raw deflate.
) == Z_OK,
)
}

var source: ByteArray = emptyByteArray
var sourcePos: Int = 0
var sourceLimit: Int = 0

var target: ByteArray = emptyByteArray
var targetPos: Int = 0
var targetLimit: Int = 0

private var closed = false

/**
* Returns true if no further calls to [inflate] are required because the source stream is
* finished. Otherwise, ensure there's input data in [source] and output space in [target] and
* call this again.
*/
fun inflate(): Boolean {
check(!closed) { "closed" }
require(0 <= sourcePos && sourcePos <= sourceLimit && sourceLimit <= source.size)
require(0 <= targetPos && targetPos <= targetLimit && targetLimit <= target.size)

source.usePinned { pinnedSource ->
target.usePinned { pinnedTarget ->
val sourceByteCount = sourceLimit - sourcePos
zStream.next_in = when {
sourceByteCount > 0 -> pinnedSource.addressOf(sourcePos) as CPointer<UByteVar>
else -> null
}
zStream.avail_in = sourceByteCount.toUInt()

val targetByteCount = targetLimit - targetPos
zStream.next_out = when {
targetByteCount > 0 -> pinnedTarget.addressOf(targetPos) as CPointer<UByteVar>
else -> null
}
zStream.avail_out = targetByteCount.toUInt()

val inflateResult = platform.zlib.inflate(zStream.ptr, Z_NO_FLUSH)

sourcePos += sourceByteCount - zStream.avail_in.toInt()
targetPos += targetByteCount - zStream.avail_out.toInt()

return when (inflateResult) {
Z_OK -> false
Z_BUF_ERROR -> false // Non-fatal but the caller needs to update source and/or target.
Z_STREAM_END -> true
Z_DATA_ERROR -> throw ProtocolException("Z_DATA_ERROR")

// One of Z_NEED_DICT, Z_STREAM_ERROR, Z_MEM_ERROR.
else -> throw ProtocolException("unexpected inflate result: $inflateResult")
}
}
}
}

override fun close() {
if (closed) return
closed = true

inflateEnd(zStream.ptr)
nativeHeap.free(zStream)
}
}
18 changes: 18 additions & 0 deletions okio/src/nativeTest/kotlin/okio/DeflaterTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package okio

import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.assertFailsWith
import kotlin.test.assertFalse
import kotlin.test.assertTrue
import okio.ByteString.Companion.decodeBase64
Expand Down Expand Up @@ -174,4 +175,21 @@ class DeflaterTest {

deflater.close()
}

@Test
fun cannotDeflateAfterClose() {
val deflater = Deflater()
deflater.close()

assertFailsWith<IllegalStateException> {
deflater.deflate()
}
}

@Test
fun closeIsIdemptent() {
val deflater = Deflater()
deflater.close()
deflater.close()
}
}
201 changes: 201 additions & 0 deletions okio/src/nativeTest/kotlin/okio/InflaterTest.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
/*
* Copyright (C) 2024 Square, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package okio

import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.assertFailsWith
import kotlin.test.assertFalse
import kotlin.test.assertTrue
import okio.ByteString.Companion.decodeBase64
import okio.ByteString.Companion.decodeHex
import okio.ByteString.Companion.toByteString

class InflaterTest {
@Test
fun happyPath() {
val inflater = Inflater().apply {
source = "c89PUchIzSlQKC3WUShPVS9KVcjMUyjJSFXISMxLKVbIT1NIzUvPzEtNLSrWAwA="
.decodeBase64()!!.toByteArray()
sourcePos = 0
sourceLimit = source.size

target = ByteArray(256)
targetPos = 0
targetLimit = target.size
}

assertTrue(inflater.inflate())
assertEquals(inflater.sourceLimit, inflater.sourcePos)

val inflated = inflater.target.toByteString(0, inflater.targetPos)
assertEquals(
"God help us, we're in the hands of engineers.",
inflated.utf8(),
)

inflater.close()
}

@Test
fun inflateInParts() {
val inflater = Inflater().apply {
target = ByteArray(256)
targetPos = 0
targetLimit = target.size
}

inflater.source = "c89PUchIzSlQKC3WUShPVS9KVcjMUyjJ".decodeBase64()!!.toByteArray()
inflater.sourcePos = 0
inflater.sourceLimit = inflater.source.size
assertFalse(inflater.inflate())
assertEquals(inflater.sourceLimit, inflater.sourcePos)

inflater.source = "SFXISMxLKVbIT1NIzUvPzEtNLSrWAwA=".decodeBase64()!!.toByteArray()
inflater.sourcePos = 0
inflater.sourceLimit = inflater.source.size
assertTrue(inflater.inflate())
assertEquals(inflater.sourceLimit, inflater.sourcePos)

val inflated = inflater.target.toByteString(0, inflater.targetPos)
assertEquals(
"God help us, we're in the hands of engineers.",
inflated.utf8(),
)

inflater.close()
}

@Test
fun inflateInsufficientSpaceInTarget() {
val targetBuffer = Buffer()

val inflater = Inflater().apply {
source = "c89PUchIzSlQKC3WUShPVS9KVcjMUyjJSFXISMxLKVbIT1NIzUvPzEtNLSrWAwA="
.decodeBase64()!!.toByteArray()
sourcePos = 0
sourceLimit = source.size
}

inflater.target = ByteArray(31)
inflater.targetPos = 0
inflater.targetLimit = inflater.target.size
assertFalse(inflater.inflate())
assertEquals(inflater.targetLimit, inflater.targetPos)
targetBuffer.write(inflater.target)

inflater.target = ByteArray(256)
inflater.targetPos = 0
inflater.targetLimit = inflater.target.size
assertTrue(inflater.inflate())
assertEquals(inflater.sourcePos, inflater.sourceLimit)
targetBuffer.write(inflater.target, 0, inflater.targetPos)

assertEquals(
"God help us, we're in the hands of engineers.",
targetBuffer.readUtf8(),
)

inflater.close()
}

@Test
fun inflateEmptyContent() {
val inflater = Inflater().apply {
source = "AwA=".decodeBase64()!!.toByteArray()
sourcePos = 0
sourceLimit = source.size

target = ByteArray(256)
targetPos = 0
targetLimit = target.size
}

assertTrue(inflater.inflate())

val inflated = inflater.target.toByteString(0, inflater.targetPos)
assertEquals(
"",
inflated.utf8(),
)

inflater.close()
}

@Test
fun inflateInPartsStartingWithEmptySource() {
val inflater = Inflater().apply {
target = ByteArray(256)
targetPos = 0
targetLimit = target.size
}

inflater.source = ByteArray(256)
inflater.sourcePos = 0
inflater.sourceLimit = 0
assertFalse(inflater.inflate())

inflater.source = "c89PUchIzSlQKC3WUShPVS9KVcjMUyjJSFXISMxLKVbIT1NIzUvPzEtNLSrWAwA="
.decodeBase64()!!.toByteArray()
inflater.sourcePos = 0
inflater.sourceLimit = inflater.source.size
assertTrue(inflater.inflate())

val inflated = inflater.target.toByteString(0, inflater.targetPos)
assertEquals(
"God help us, we're in the hands of engineers.",
inflated.utf8(),
)

inflater.close()
}

@Test
fun inflateInvalidData() {
val inflater = Inflater().apply {
target = ByteArray(256)
targetPos = 0
targetLimit = target.size
}

inflater.source = "ffffffffffffffff".decodeHex().toByteArray()
inflater.sourcePos = 0
inflater.sourceLimit = inflater.source.size
val exception = assertFailsWith<ProtocolException> {
inflater.inflate()
}
assertEquals("Z_DATA_ERROR", exception.message)

inflater.close()
}

@Test
fun cannotInflateAfterClose() {
val inflater = Inflater()
inflater.close()

assertFailsWith<IllegalStateException> {
inflater.inflate()
}
}

@Test
fun closeIsIdemptent() {
val inflater = Inflater()
inflater.close()
inflater.close()
}
}

0 comments on commit 1d5f262

Please sign in to comment.