-
Notifications
You must be signed in to change notification settings - Fork 1.2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Prepare zlib's inflater for Kotlin/Native (#1422)
* Prepare zlib's inflater for Kotlin/Native * Don't check the result of deflateEnd It is different for different zlib versions. In particular, it returns Z_DATA_ERROR if the stream is closed without being used.
- Loading branch information
1 parent
062048a
commit 1d5f262
Showing
4 changed files
with
335 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
/* | ||
* Copyright (C) 2024 Square, Inc. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
package okio | ||
|
||
import kotlinx.cinterop.CPointer | ||
import kotlinx.cinterop.UByteVar | ||
import kotlinx.cinterop.addressOf | ||
import kotlinx.cinterop.alloc | ||
import kotlinx.cinterop.free | ||
import kotlinx.cinterop.nativeHeap | ||
import kotlinx.cinterop.ptr | ||
import kotlinx.cinterop.usePinned | ||
import platform.zlib.Z_BUF_ERROR | ||
import platform.zlib.Z_DATA_ERROR | ||
import platform.zlib.Z_NO_FLUSH | ||
import platform.zlib.Z_OK | ||
import platform.zlib.Z_STREAM_END | ||
import platform.zlib.inflateEnd | ||
import platform.zlib.inflateInit2 | ||
import platform.zlib.z_stream_s | ||
|
||
/** | ||
* Inflate using Kotlin/Native's built-in zlib bindings. | ||
* | ||
* The API is symmetric with [Deflater]. | ||
*/ | ||
internal class Inflater : Closeable { | ||
private val zStream: z_stream_s = nativeHeap.alloc<z_stream_s> { | ||
zalloc = null | ||
zfree = null | ||
opaque = null | ||
check( | ||
inflateInit2( | ||
strm = ptr, | ||
windowBits = -15, // Default value for raw deflate. | ||
) == Z_OK, | ||
) | ||
} | ||
|
||
var source: ByteArray = emptyByteArray | ||
var sourcePos: Int = 0 | ||
var sourceLimit: Int = 0 | ||
|
||
var target: ByteArray = emptyByteArray | ||
var targetPos: Int = 0 | ||
var targetLimit: Int = 0 | ||
|
||
private var closed = false | ||
|
||
/** | ||
* Returns true if no further calls to [inflate] are required because the source stream is | ||
* finished. Otherwise, ensure there's input data in [source] and output space in [target] and | ||
* call this again. | ||
*/ | ||
fun inflate(): Boolean { | ||
check(!closed) { "closed" } | ||
require(0 <= sourcePos && sourcePos <= sourceLimit && sourceLimit <= source.size) | ||
require(0 <= targetPos && targetPos <= targetLimit && targetLimit <= target.size) | ||
|
||
source.usePinned { pinnedSource -> | ||
target.usePinned { pinnedTarget -> | ||
val sourceByteCount = sourceLimit - sourcePos | ||
zStream.next_in = when { | ||
sourceByteCount > 0 -> pinnedSource.addressOf(sourcePos) as CPointer<UByteVar> | ||
else -> null | ||
} | ||
zStream.avail_in = sourceByteCount.toUInt() | ||
|
||
val targetByteCount = targetLimit - targetPos | ||
zStream.next_out = when { | ||
targetByteCount > 0 -> pinnedTarget.addressOf(targetPos) as CPointer<UByteVar> | ||
else -> null | ||
} | ||
zStream.avail_out = targetByteCount.toUInt() | ||
|
||
val inflateResult = platform.zlib.inflate(zStream.ptr, Z_NO_FLUSH) | ||
|
||
sourcePos += sourceByteCount - zStream.avail_in.toInt() | ||
targetPos += targetByteCount - zStream.avail_out.toInt() | ||
|
||
return when (inflateResult) { | ||
Z_OK -> false | ||
Z_BUF_ERROR -> false // Non-fatal but the caller needs to update source and/or target. | ||
Z_STREAM_END -> true | ||
Z_DATA_ERROR -> throw ProtocolException("Z_DATA_ERROR") | ||
|
||
// One of Z_NEED_DICT, Z_STREAM_ERROR, Z_MEM_ERROR. | ||
else -> throw ProtocolException("unexpected inflate result: $inflateResult") | ||
} | ||
} | ||
} | ||
} | ||
|
||
override fun close() { | ||
if (closed) return | ||
closed = true | ||
|
||
inflateEnd(zStream.ptr) | ||
nativeHeap.free(zStream) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,201 @@ | ||
/* | ||
* Copyright (C) 2024 Square, Inc. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
package okio | ||
|
||
import kotlin.test.Test | ||
import kotlin.test.assertEquals | ||
import kotlin.test.assertFailsWith | ||
import kotlin.test.assertFalse | ||
import kotlin.test.assertTrue | ||
import okio.ByteString.Companion.decodeBase64 | ||
import okio.ByteString.Companion.decodeHex | ||
import okio.ByteString.Companion.toByteString | ||
|
||
class InflaterTest { | ||
@Test | ||
fun happyPath() { | ||
val inflater = Inflater().apply { | ||
source = "c89PUchIzSlQKC3WUShPVS9KVcjMUyjJSFXISMxLKVbIT1NIzUvPzEtNLSrWAwA=" | ||
.decodeBase64()!!.toByteArray() | ||
sourcePos = 0 | ||
sourceLimit = source.size | ||
|
||
target = ByteArray(256) | ||
targetPos = 0 | ||
targetLimit = target.size | ||
} | ||
|
||
assertTrue(inflater.inflate()) | ||
assertEquals(inflater.sourceLimit, inflater.sourcePos) | ||
|
||
val inflated = inflater.target.toByteString(0, inflater.targetPos) | ||
assertEquals( | ||
"God help us, we're in the hands of engineers.", | ||
inflated.utf8(), | ||
) | ||
|
||
inflater.close() | ||
} | ||
|
||
@Test | ||
fun inflateInParts() { | ||
val inflater = Inflater().apply { | ||
target = ByteArray(256) | ||
targetPos = 0 | ||
targetLimit = target.size | ||
} | ||
|
||
inflater.source = "c89PUchIzSlQKC3WUShPVS9KVcjMUyjJ".decodeBase64()!!.toByteArray() | ||
inflater.sourcePos = 0 | ||
inflater.sourceLimit = inflater.source.size | ||
assertFalse(inflater.inflate()) | ||
assertEquals(inflater.sourceLimit, inflater.sourcePos) | ||
|
||
inflater.source = "SFXISMxLKVbIT1NIzUvPzEtNLSrWAwA=".decodeBase64()!!.toByteArray() | ||
inflater.sourcePos = 0 | ||
inflater.sourceLimit = inflater.source.size | ||
assertTrue(inflater.inflate()) | ||
assertEquals(inflater.sourceLimit, inflater.sourcePos) | ||
|
||
val inflated = inflater.target.toByteString(0, inflater.targetPos) | ||
assertEquals( | ||
"God help us, we're in the hands of engineers.", | ||
inflated.utf8(), | ||
) | ||
|
||
inflater.close() | ||
} | ||
|
||
@Test | ||
fun inflateInsufficientSpaceInTarget() { | ||
val targetBuffer = Buffer() | ||
|
||
val inflater = Inflater().apply { | ||
source = "c89PUchIzSlQKC3WUShPVS9KVcjMUyjJSFXISMxLKVbIT1NIzUvPzEtNLSrWAwA=" | ||
.decodeBase64()!!.toByteArray() | ||
sourcePos = 0 | ||
sourceLimit = source.size | ||
} | ||
|
||
inflater.target = ByteArray(31) | ||
inflater.targetPos = 0 | ||
inflater.targetLimit = inflater.target.size | ||
assertFalse(inflater.inflate()) | ||
assertEquals(inflater.targetLimit, inflater.targetPos) | ||
targetBuffer.write(inflater.target) | ||
|
||
inflater.target = ByteArray(256) | ||
inflater.targetPos = 0 | ||
inflater.targetLimit = inflater.target.size | ||
assertTrue(inflater.inflate()) | ||
assertEquals(inflater.sourcePos, inflater.sourceLimit) | ||
targetBuffer.write(inflater.target, 0, inflater.targetPos) | ||
|
||
assertEquals( | ||
"God help us, we're in the hands of engineers.", | ||
targetBuffer.readUtf8(), | ||
) | ||
|
||
inflater.close() | ||
} | ||
|
||
@Test | ||
fun inflateEmptyContent() { | ||
val inflater = Inflater().apply { | ||
source = "AwA=".decodeBase64()!!.toByteArray() | ||
sourcePos = 0 | ||
sourceLimit = source.size | ||
|
||
target = ByteArray(256) | ||
targetPos = 0 | ||
targetLimit = target.size | ||
} | ||
|
||
assertTrue(inflater.inflate()) | ||
|
||
val inflated = inflater.target.toByteString(0, inflater.targetPos) | ||
assertEquals( | ||
"", | ||
inflated.utf8(), | ||
) | ||
|
||
inflater.close() | ||
} | ||
|
||
@Test | ||
fun inflateInPartsStartingWithEmptySource() { | ||
val inflater = Inflater().apply { | ||
target = ByteArray(256) | ||
targetPos = 0 | ||
targetLimit = target.size | ||
} | ||
|
||
inflater.source = ByteArray(256) | ||
inflater.sourcePos = 0 | ||
inflater.sourceLimit = 0 | ||
assertFalse(inflater.inflate()) | ||
|
||
inflater.source = "c89PUchIzSlQKC3WUShPVS9KVcjMUyjJSFXISMxLKVbIT1NIzUvPzEtNLSrWAwA=" | ||
.decodeBase64()!!.toByteArray() | ||
inflater.sourcePos = 0 | ||
inflater.sourceLimit = inflater.source.size | ||
assertTrue(inflater.inflate()) | ||
|
||
val inflated = inflater.target.toByteString(0, inflater.targetPos) | ||
assertEquals( | ||
"God help us, we're in the hands of engineers.", | ||
inflated.utf8(), | ||
) | ||
|
||
inflater.close() | ||
} | ||
|
||
@Test | ||
fun inflateInvalidData() { | ||
val inflater = Inflater().apply { | ||
target = ByteArray(256) | ||
targetPos = 0 | ||
targetLimit = target.size | ||
} | ||
|
||
inflater.source = "ffffffffffffffff".decodeHex().toByteArray() | ||
inflater.sourcePos = 0 | ||
inflater.sourceLimit = inflater.source.size | ||
val exception = assertFailsWith<ProtocolException> { | ||
inflater.inflate() | ||
} | ||
assertEquals("Z_DATA_ERROR", exception.message) | ||
|
||
inflater.close() | ||
} | ||
|
||
@Test | ||
fun cannotInflateAfterClose() { | ||
val inflater = Inflater() | ||
inflater.close() | ||
|
||
assertFailsWith<IllegalStateException> { | ||
inflater.inflate() | ||
} | ||
} | ||
|
||
@Test | ||
fun closeIsIdemptent() { | ||
val inflater = Inflater() | ||
inflater.close() | ||
inflater.close() | ||
} | ||
} |