Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

safety improvement to Memory #167

Draft
wants to merge 2 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions src/buffer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,13 @@ mutable struct Buffer
# the total number of transcoded bytes
transcoded::Int64

# a flag to indicate if writing to it should be an error
# an application could choose to copy the data field and clear this flag
immutable::Bool

function Buffer(data::Vector{UInt8}, marginpos::Integer=length(data)+1)
@assert 1 <= marginpos <= length(data)+1
return new(data, 0, 1, marginpos, 0)
return new(data, 0, 1, marginpos, 0, false)
end
end

Expand All @@ -37,7 +41,9 @@ function Buffer(size::Integer = 0)
end

function Buffer(data::Base.CodeUnits{UInt8}, args...)
return Buffer(Vector{UInt8}(data), args...)
buf = Buffer(unsafe_wrap(Vector{UInt8}, String(data)), args...)
buf.immutable = true # application should try not to accidentally corrupt this String
return buf
end

function Base.length(buf::Buffer)
Expand All @@ -53,7 +59,7 @@ function buffersize(buf::Buffer)
end

function buffermem(buf::Buffer)
return Memory(bufferptr(buf), buffersize(buf))
return Memory(buf.data, buf.bufferpos, buffersize(buf))
end

function marginptr(buf::Buffer)
Expand All @@ -65,7 +71,8 @@ function marginsize(buf::Buffer)
end

function marginmem(buf::Buffer)
return Memory(marginptr(buf), marginsize(buf))
buf.immutable && throw(ArgumentError("Buffer is immutable"))
return Memory(buf.data, buf.marginpos, marginsize(buf))
end

function ismarked(buf::Buffer)
Expand Down Expand Up @@ -126,6 +133,7 @@ end
# Make margin with ≥`minsize` and return the size of it.
# If eager is true, it tries to move data even when the buffer has enough margin.
function makemargin!(buf::Buffer, minsize::Integer; eager::Bool = false)
buf.immutable && throw(ArgumentError("Buffer is immutable"))
@assert minsize ≥ 0
if buffersize(buf) == 0 && buf.markpos == 0
buf.bufferpos = buf.marginpos = 1
Expand Down Expand Up @@ -171,6 +179,7 @@ end

# Write a byte.
function writebyte!(buf::Buffer, b::UInt8)
buf.immutable && throw(ArgumentError("Buffer is immutable"))
buf.data[buf.marginpos] = b
supplied!(buf, 1)
return 1
Expand Down
8 changes: 4 additions & 4 deletions src/codec.jl
Original file line number Diff line number Diff line change
Expand Up @@ -106,21 +106,21 @@

Return the expected size of the transcoded `input` with `codec`.

The default method returns `input.size`.
The default method returns `length(input)`.
"""
function expectedsize(codec::Codec, input::Memory)::Int
return input.size
return length(input)

Check warning on line 112 in src/codec.jl

View check run for this annotation

Codecov / codecov/patch

src/codec.jl#L112

Added line #L112 was not covered by tests
end

"""
minoutsize(codec::Codec, input::Memory)::Int

Return the minimum output size to be ensured when calling `process`.

The default method returns `max(1, div(input.size, 4))`.
The default method returns `max(1, div(length(input), 4))`.
"""
function minoutsize(codec::Codec, input::Memory)::Int
return max(1, div(input.size, 4))
return max(1, div(length(input), 4))
end

"""
Expand Down
48 changes: 24 additions & 24 deletions src/memory.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,39 +2,39 @@
# ======

"""
A contiguous memory.
A contiguous view into other memory.

This type works like a `Vector` method.
This type works like a `SubVector` method.
"""
struct Memory
ptr::Ptr{UInt8}
size::UInt
end

function Memory(data::ByteData)
return Memory(pointer(data), sizeof(data))
struct Memory # <: AbstractVector{UInt8}
data::Vector{UInt8}
first::Int
size::Int
function Memory(data::Vector{UInt8})
return new(data, 1, sizeof(data))
end
function Memory(data::Vector{UInt8}, first, length)
checkbounds(data, first:(first - 1 + length))
return new(data, first, length)
end
end

@inline function Base.length(mem::Memory)
return mem.size
function Memory(data::Base.CodeUnits{UInt8}, args...)
return Memory(unsafe_wrap(Vector{UInt8}, String(data)), args...)

Check warning on line 22 in src/memory.jl

View check run for this annotation

Codecov / codecov/patch

src/memory.jl#L21-L22

Added lines #L21 - L22 were not covered by tests
end

@inline function Base.lastindex(mem::Memory)
return Int(mem.size)
@inline function Base.getproperty(mem::Memory, field::Symbol)
field === :ptr && return pointer(getfield(mem, :data), getfield(mem, :first))
return getfield(mem, field)
end
Base.length(mem::Memory) = mem.size % UInt
Base.sizeof(mem::Memory) = mem.size

Check warning on line 30 in src/memory.jl

View check run for this annotation

Codecov / codecov/patch

src/memory.jl#L30

Added line #L30 was not covered by tests
Base.lastindex(mem::Memory) = mem.size

@inline function Base.checkbounds(mem::Memory, i::Integer)
function Base.checkbounds(mem::Memory, i::Integer)

Check warning on line 33 in src/memory.jl

View check run for this annotation

Codecov / codecov/patch

src/memory.jl#L33

Added line #L33 was not covered by tests
if !(1 ≤ i ≤ lastindex(mem))
throw(BoundsError(mem, i))
end
end

@inline function Base.getindex(mem::Memory, i::Integer)
@boundscheck checkbounds(mem, i)
return unsafe_load(mem.ptr, i)
end

@inline function Base.setindex!(mem::Memory, val::UInt8, i::Integer)
@boundscheck checkbounds(mem, i)
return unsafe_store!(mem.ptr, val, i)
end
Base.getindex(mem::Memory, i::Integer) = mem.data[i]
Base.setindex!(mem::Memory, val::UInt8, i::Integer) = (mem.data[i] = val; mem)
2 changes: 1 addition & 1 deletion src/noop.jl
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ initial_output_size(codec::Noop, input::Memory) = length(input)
function process(codec::Noop, input::Memory, output::Memory, error::Error)
iszero(length(input)) && return (0, 0, :end)
n = min(length(input), length(output))
unsafe_copyto!(output.ptr, input.ptr, n)
GC.@preserve input output unsafe_copyto!(output.ptr, input.ptr, n)
(n, n, :ok)
end

Expand Down
2 changes: 1 addition & 1 deletion src/stream.jl
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,7 @@ function Base.unsafe_read(stream::TranscodingStream, output::Ptr{UInt8}, nbytes:
return
end

function Base.readbytes!(stream::TranscodingStream, b::AbstractArray{UInt8}, nb=length(b))
function Base.readbytes!(stream::TranscodingStream, b::DenseArray{UInt8}, nb=length(b))
ready_to_read!(stream)
filled = 0
resized = false
Expand Down
14 changes: 7 additions & 7 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,12 @@ using TranscodingStreams:
data = Vector{UInt8}(b"foobar")
buf = Buffer(data)
@test buf isa Buffer
@test bufferptr(buf) === pointer(data)
@test bufferptr(buf) === pointer(data) === buffermem(buf).ptr
@test buffersize(buf) === 6
@test buffermem(buf) === Memory(pointer(data), 6)
@test marginptr(buf) === pointer(data) + 6
@test buffermem(buf) === Memory(data, 1, 6)
@test marginptr(buf) === pointer(data) + 6 === marginmem(buf).ptr
@test marginsize(buf) === 0
@test marginmem(buf) === Memory(pointer(data)+6, 0)
@test marginmem(buf) === Memory(data, 7, 0)

buf = Buffer(2)
writebyte!(buf, 0x34)
Expand Down Expand Up @@ -71,11 +71,11 @@ end

@testset "Memory" begin
data = Vector{UInt8}(b"foobar")
mem = TranscodingStreams.Memory(pointer(data), sizeof(data))
mem = TranscodingStreams.Memory(data, 1, sizeof(data))
@test mem isa TranscodingStreams.Memory
@test mem.ptr === pointer(data)
@test mem.size === length(mem) === UInt(sizeof(data))
@test lastindex(mem) === 6
@test length(mem) === UInt(sizeof(data))
@test mem.size === lastindex(mem) === 6
@test mem[1] === UInt8('f')
@test mem[2] === UInt8('o')
@test mem[3] === UInt8('o')
Expand Down
Loading