-
Notifications
You must be signed in to change notification settings - Fork 1.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add NSInputStream.source() and BufferedSource.inputStream() functions for Apple's NSInputStream #1123
base: master
Are you sure you want to change the base?
Add NSInputStream.source() and BufferedSource.inputStream() functions for Apple's NSInputStream #1123
Changes from 15 commits
e0548f2
5578fd8
b95de57
ddb66c3
a6c3e16
7a8c292
b8d1ba9
c319152
506fa4f
aa0b6d3
9ac4e4d
7e9cbfa
72efc44
0f74ac7
3191c97
031a3a1
3571012
41c38a5
c59ea79
fb883f9
e554a72
f917c2e
a40e1c7
97192bb
bbddc69
8b3fe7e
f5d63c0
4b7600f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
/* | ||
* Copyright (C) 2020 Square, Inc. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
package okio | ||
|
||
import kotlinx.cinterop.CPointer | ||
import kotlinx.cinterop.CPointerVar | ||
import kotlinx.cinterop.UnsafeNumber | ||
import kotlinx.cinterop.addressOf | ||
import kotlinx.cinterop.convert | ||
import kotlinx.cinterop.pointed | ||
import kotlinx.cinterop.reinterpret | ||
import kotlinx.cinterop.usePinned | ||
import kotlinx.cinterop.value | ||
import platform.Foundation.NSData | ||
import platform.Foundation.NSError | ||
import platform.Foundation.NSInputStream | ||
import platform.Foundation.NSLocalizedDescriptionKey | ||
import platform.Foundation.NSUnderlyingErrorKey | ||
import platform.darwin.NSInteger | ||
import platform.darwin.NSUInteger | ||
import platform.darwin.NSUIntegerVar | ||
import platform.posix.memcpy | ||
import platform.posix.uint8_tVar | ||
|
||
fun BufferedSource.inputStream(): NSInputStream = BufferedSourceInputStream(this) | ||
|
||
/** Returns an input stream that reads from this source. */ | ||
@OptIn(UnsafeNumber::class) | ||
private class BufferedSourceInputStream( | ||
private val bufferedSource: BufferedSource | ||
) : NSInputStream(NSData()) { | ||
|
||
private var error: NSError? = null | ||
|
||
override fun streamError(): NSError? = error | ||
|
||
override fun open() { | ||
// no-op | ||
} | ||
|
||
override fun read(buffer: CPointer<uint8_tVar>?, maxLength: NSUInteger): NSInteger { | ||
try { | ||
val internalBuffer = bufferedSource.buffer | ||
|
||
if (bufferedSource is RealBufferedSource) { | ||
if (bufferedSource.closed) throw IOException("closed") | ||
|
||
if (internalBuffer.size == 0L) { | ||
val count = bufferedSource.source.read(internalBuffer, Segment.SIZE.toLong()) | ||
if (count == -1L) return 0 | ||
} | ||
} | ||
|
||
val toRead = minOf(maxLength.toInt(), internalBuffer.size).toInt() | ||
return internalBuffer.readNative(buffer, toRead).convert() | ||
} catch (e: Exception) { | ||
error = e.toNSError() | ||
return -1 | ||
} | ||
} | ||
|
||
override fun getBuffer( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don’t think this is a big deal for now, but eventually it might be a problem .. . . Using usePinned to get the address of a byte within a buffer is fine, but I worry that the caller might need usePinned when they later read from the buffer. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. (This is only an issue if Kotlin’s GC ever learns to relocate objects) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I wondered if this could be problematic, but couldn't think of a way to pass the unpinning to the caller. I figure typical usage would read the buffer immediately after calling the function and wouldn't hold the reference longer. Still a minute chance GC occurs right between unpinning and reading though, and even tinier chance the address becomes invalid. Copying the buffer would nullify the usefulness of the API and violate the documented behavior of returning in 0(1). I figured having it implemented is probably preferable to not, but alternatively we could just return false. I wonder what the behavior would be to
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The memory leak question is also a good one. At the moment this API doesn't tell the GC about this reference to the ByteArray, so it could be collected (or recycled) before the caller gets to use it. Maybe we should pin it here, and/or take other steps to avoid recycling, then unpin either on the next call to this method or on close(). That would work and it wouldn't leak. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Great idea! I'll make that change. |
||
buffer: CPointer<CPointerVar<uint8_tVar>>?, | ||
length: CPointer<NSUIntegerVar>? | ||
): Boolean { | ||
if (bufferedSource.buffer.size > 0) { | ||
bufferedSource.buffer.head?.let { s -> | ||
s.data.usePinned { | ||
buffer?.pointed?.value = it.addressOf(s.pos).reinterpret() | ||
length?.pointed?.value = (s.limit - s.pos).convert() | ||
return true | ||
} | ||
} | ||
} | ||
return false | ||
} | ||
|
||
override fun hasBytesAvailable(): Boolean = bufferedSource.buffer.size > 0 | ||
|
||
override fun close() = bufferedSource.close() | ||
|
||
override fun description(): String = "$bufferedSource.inputStream()" | ||
|
||
private fun Exception.toNSError(): NSError { | ||
return NSError( | ||
"Kotlin", | ||
0, | ||
mapOf( | ||
NSLocalizedDescriptionKey to message, | ||
NSUnderlyingErrorKey to this | ||
) | ||
) | ||
} | ||
|
||
private fun Buffer.readNative(sink: CPointer<uint8_tVar>?, maxLength: Int): Int { | ||
val s = head ?: return 0 | ||
val toCopy = minOf(maxLength, s.limit - s.pos) | ||
s.data.usePinned { | ||
memcpy(sink, it.addressOf(s.pos), toCopy.convert()) | ||
} | ||
|
||
s.pos += toCopy | ||
size -= toCopy.toLong() | ||
|
||
if (s.pos == s.limit) { | ||
head = s.pop() | ||
SegmentPool.recycle(s) | ||
} | ||
|
||
return toCopy | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
/* | ||
* Copyright (C) 2020 Square, Inc. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
package okio | ||
|
||
import kotlinx.cinterop.UnsafeNumber | ||
import kotlinx.cinterop.addressOf | ||
import kotlinx.cinterop.convert | ||
import kotlinx.cinterop.reinterpret | ||
import kotlinx.cinterop.usePinned | ||
import platform.Foundation.NSInputStream | ||
import platform.darwin.UInt8Var | ||
|
||
/** Returns a source that reads from `in`. */ | ||
fun NSInputStream.source(): Source = NSInputStreamSource(this) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this file should be called There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @swankjesse after merging the latest upstream changes, the linter doesn't like the lowercase filename |
||
|
||
@OptIn(UnsafeNumber::class) | ||
private open class NSInputStreamSource( | ||
private val input: NSInputStream | ||
) : Source { | ||
|
||
init { | ||
input.open() | ||
} | ||
|
||
override fun read(sink: Buffer, byteCount: Long): Long { | ||
if (byteCount == 0L) return 0L | ||
require(byteCount >= 0L) { "byteCount < 0: $byteCount" } | ||
val tail = sink.writableSegment(1) | ||
val maxToCopy = minOf(byteCount, Segment.SIZE - tail.limit) | ||
val bytesRead = tail.data.usePinned { | ||
val bytes = it.addressOf(tail.limit).reinterpret<UInt8Var>() | ||
input.read(bytes, maxToCopy.convert()).toLong() | ||
} | ||
if (bytesRead < 0) throw IOException(input.streamError?.localizedDescription) | ||
if (bytesRead == 0L) { | ||
if (tail.pos == tail.limit) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nice |
||
// We allocated a tail segment, but didn't end up needing it. Recycle! | ||
sink.head = tail.pop() | ||
SegmentPool.recycle(tail) | ||
} | ||
return -1 | ||
} | ||
tail.limit += bytesRead.toInt() | ||
sink.size += bytesRead | ||
return bytesRead.convert() | ||
} | ||
|
||
override fun close() = input.close() | ||
|
||
override fun timeout() = Timeout.NONE | ||
|
||
override fun toString() = "source($input)" | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,115 @@ | ||
/* | ||
* Copyright (C) 2020 Square, Inc. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
package okio | ||
|
||
import kotlinx.cinterop.CPointerVar | ||
import kotlinx.cinterop.UnsafeNumber | ||
import kotlinx.cinterop.addressOf | ||
import kotlinx.cinterop.alloc | ||
import kotlinx.cinterop.convert | ||
import kotlinx.cinterop.get | ||
import kotlinx.cinterop.memScoped | ||
import kotlinx.cinterop.ptr | ||
import kotlinx.cinterop.reinterpret | ||
import kotlinx.cinterop.usePinned | ||
import kotlinx.cinterop.value | ||
import platform.Foundation.NSInputStream | ||
import platform.darwin.NSUIntegerVar | ||
import platform.darwin.UInt8Var | ||
import kotlin.test.Test | ||
import kotlin.test.assertEquals | ||
import kotlin.test.assertFalse | ||
import kotlin.test.assertNotNull | ||
import kotlin.test.assertTrue | ||
|
||
@OptIn(UnsafeNumber::class) | ||
class AppleBufferedSourceTest { | ||
@Test fun bufferInputStream() { | ||
val source = Buffer() | ||
source.writeUtf8("abc") | ||
testInputStream(source.inputStream()) | ||
} | ||
|
||
@Test fun realBufferedSourceInputStream() { | ||
val source = Buffer() | ||
source.writeUtf8("abc") | ||
testInputStream(RealBufferedSource(source).inputStream()) | ||
} | ||
|
||
private fun testInputStream(nsis: NSInputStream) { | ||
nsis.open() | ||
val byteArray = ByteArray(4) | ||
byteArray.usePinned { | ||
val cPtr = it.addressOf(0).reinterpret<UInt8Var>() | ||
|
||
byteArray.fill(-5) | ||
assertEquals(3, nsis.read(cPtr, 4)) | ||
assertEquals("[97, 98, 99, -5]", byteArray.contentToString()) | ||
|
||
byteArray.fill(-7) | ||
assertEquals(0, nsis.read(cPtr, 4)) | ||
assertEquals("[-7, -7, -7, -7]", byteArray.contentToString()) | ||
} | ||
} | ||
|
||
@Test fun nsInputStreamGetBuffer() { | ||
val source = Buffer() | ||
source.writeUtf8("abc") | ||
|
||
val nsis = source.inputStream() | ||
nsis.open() | ||
assertTrue(nsis.hasBytesAvailable) | ||
|
||
memScoped { | ||
val bufferPtr = alloc<CPointerVar<UInt8Var>>() | ||
val lengthPtr = alloc<NSUIntegerVar>() | ||
assertTrue(nsis.getBuffer(bufferPtr.ptr, lengthPtr.ptr)) | ||
|
||
val length = lengthPtr.value | ||
assertNotNull(length) | ||
assertEquals(3.convert(), length) | ||
|
||
val buffer = bufferPtr.value | ||
assertNotNull(buffer) | ||
assertEquals('a'.code.toUByte(), buffer[0]) | ||
assertEquals('b'.code.toUByte(), buffer[1]) | ||
assertEquals('c'.code.toUByte(), buffer[2]) | ||
} | ||
} | ||
|
||
@Test fun nsInputStreamClose() { | ||
val buffer = Buffer() | ||
buffer.writeUtf8("abc") | ||
val source = RealBufferedSource(buffer) | ||
assertFalse(source.closed) | ||
|
||
val nsis = source.inputStream() | ||
nsis.open() | ||
nsis.close() | ||
assertTrue(source.closed) | ||
|
||
val byteArray = ByteArray(4) | ||
byteArray.usePinned { | ||
val cPtr = it.addressOf(0).reinterpret<UInt8Var>() | ||
|
||
byteArray.fill(-5) | ||
assertEquals(-1, nsis.read(cPtr, 4)) | ||
assertNotNull(nsis.streamError) | ||
assertEquals("closed", nsis.streamError?.localizedDescription) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nice |
||
assertEquals("[-5, -5, -5, -5]", byteArray.contentToString()) | ||
} | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please rewrite this if clause as
That will guarantee the buffer has at least one byte in it after!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh, nice. I see
RealBufferedSource.commonExhausted()
is this exact logic. I went ahead and replaced this same logic in other places as well 3571012. Let me know if you'd rather have that reverted though.