Skip to content

Commit

Permalink
refactor IO internals
Browse files Browse the repository at this point in the history
  • Loading branch information
JonasIsensee committed Sep 19, 2024
1 parent 72acce8 commit 05023d4
Show file tree
Hide file tree
Showing 8 changed files with 190 additions and 211 deletions.
21 changes: 19 additions & 2 deletions src/JLD2.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,27 @@ jlsizeof(x) = Base.sizeof(x)
# Package-local aliases for the `Base` pointer primitives.
# `jlunsafe_store!` stores `x` at pointer `p`; `jlunsafe_load` loads the value at `p`.
function jlunsafe_store!(p, x)
    return Base.unsafe_store!(p, x)
end

function jlunsafe_load(p)
    return Base.unsafe_load(p)
end

include("mmapio.jl")
include("bufferedio.jl")
"""
MemoryBackedIO <: IO
Abstract type for IO objects that are backed by memory in such a way that
one can use pointer based `unsafe_load` and `unsafe_store!` operations
after ensuring that there is enough memory allocated.
This is used for dispatch to [`MmapIO`](@ref), [`BufferedWriter`](@ref) and `BufferedReader`.
It needs to provide:
- `getproperty(io, :curptr)` to get the current pointer
- `ensureroom(io, nb)` to ensure that there are at least nb bytes available
- `position(io)` to get the current (zero-based) position
- `seek(io, pos)` to set the current position (zero-based)
"""
abstract type MemoryBackedIO <: IO end

include("macros_utils.jl")
include("types.jl")
include("mmapio.jl")
include("bufferedio.jl")
include("julia_compat.jl")
include("file_header.jl")
include("Lookup3.jl")
Expand Down
166 changes: 89 additions & 77 deletions src/bufferedio.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,149 +2,161 @@
# BufferedIO
#

# NOTE(review): not referenced anywhere in the visible code; presumably a default buffer
# length in bytes for the buffered IO types below — confirm it is still needed.
const DEFAULT_BUFFER_SIZE = 1024

"""
    BufferedWriter{io} <: MemoryBackedIO

Writer that stages bytes in an in-memory `buffer` and hands them to the wrapped stream
`f` in a single write when `finish!` is called.

# Fields
- `f`: the wrapped stream.
- `buffer`: staging area for the bytes to be written.
- `file_position`: position of `f` recorded at construction; `finish!` seeks `f` back
  to it before writing `buffer`.
- `curptr`: raw pointer to the next byte to be written inside `buffer`. It has to be
  re-derived whenever `buffer` is resized, which is why the struct is mutable.
- `extensible`: whether `ensureroom` may grow `buffer` instead of throwing.
"""
mutable struct BufferedWriter{io} <: MemoryBackedIO
    f::io
    buffer::Vector{UInt8}
    file_position::Int64
    curptr::Ptr{Nothing}
    extensible::Bool
end

"""
    BufferedWriter(io, buffer_size::Int = 0; extensible::Bool = false)

Create a `BufferedWriter` on top of `io` with a staging buffer of `buffer_size` bytes.

The current position of `io` is remembered as the place the buffer will later be written
to, and `io` itself is advanced by `buffer_size` bytes right away. With
`extensible = true` the buffer may grow on demand (see `ensureroom`).
"""
function BufferedWriter(io, buffer_size::Int = 0; extensible::Bool=false)
    pos = position(io)
    skip(io, buffer_size)
    buf = Vector{UInt8}(undef, buffer_size)
    # The write cursor starts at the first byte of the freshly allocated buffer.
    return BufferedWriter(io, buf, pos, Ptr{Nothing}(pointer(buf)), extensible)
end
# Compact display: print only the type name, never the buffer contents.
function Base.show(io::IO, ::BufferedWriter)
    return print(io, "BufferedWriter")
end

"""
    ensureroom(io::BufferedWriter, n::Integer)

Make sure `n` more bytes can be stored at the current buffer position.

If the buffer is too short and `io.extensible` is set, the buffer is grown to exactly
`bufferpos(io) + n` bytes and `io.curptr` is re-derived from the resized buffer.
Otherwise an `InternalError` is thrown.
"""
function ensureroom(io::BufferedWriter, n::Integer)
    pos = bufferpos(io)
    needed = pos + n
    needed > length(io.buffer) || return nothing
    io.extensible || throw(InternalError("BufferedWriter: not enough room"))
    # Grow to exactly the required length. Growing by `n` on top of the current length
    # would leave unused slack whenever the cursor is not at the end of the buffer, and
    # `finish!` insists that the buffer is filled to its very end.
    resize!(io.buffer, needed)
    # `resize!` may move the underlying memory, so the raw pointer must be refreshed.
    io.curptr = pointer(io.buffer, pos + 1)
    return nothing
end

# Position in the wrapped stream's coordinates: where the buffer starts in the file
# plus the offset of the cursor inside the buffer.
function Base.position(io::BufferedWriter)
    return io.file_position + bufferpos(io)
end

# Move the write cursor to the file position `offset`.
# Positions before the start of the buffer are rejected; positions past its end go
# through `ensureroom`, which grows an extensible buffer or throws.
function Base.seek(io::BufferedWriter, offset::Integer)
    target = offset - io.file_position
    if target < 0
        throw(ArgumentError("cannot seek before start of buffer"))
    end
    # `ensureroom` keeps the buffer-relative position unchanged, so the distance to
    # travel can be computed once up front.
    delta = target - bufferpos(io)
    ensureroom(io, delta)
    io.curptr += delta
end

"""
    finish!(io::BufferedWriter)

Write the staged buffer to the wrapped stream at the position recorded when `io` was
created. Returns `nothing`.

Throws an `InternalError` unless the buffer has been filled exactly to its end.
"""
function finish!(io::BufferedWriter)
    f = io.f
    bufferpos(io) == length(io.buffer) ||
        throw(InternalError("BufferedWriter: buffer not written to end; position is $(bufferpos(io)) but length is $(length(io.buffer))"))
    seek(f, io.file_position)
    jlwrite(f, io.buffer)
    return nothing
end

"""
    _write(io::BufferedWriter, x)

Store the in-memory representation of `x` at the current buffer position and advance
the cursor by `jlsizeof(x)` bytes. Room is requested through `ensureroom` first.

Returns the number of bytes written.
"""
function _write(io::BufferedWriter, x)
    n = jlsizeof(x)
    ensureroom(io, n)
    jlunsafe_store!(Ptr{typeof(x)}(io.curptr), x)
    io.curptr += n
    # Report the byte count (as the pre-refactor code did), not the advanced pointer.
    return n
end

# Single bytes are stored through the pointer-based `_write`.
function jlwrite(io::BufferedWriter, x::Union{UInt8,Int8})
    return _write(io, x)
end

"""
    Base.unsafe_write(io::BufferedWriter, x::Ptr{UInt8}, n::UInt64)

Copy `n` bytes starting at `x` to the current buffer position and advance the cursor.
Room is requested through `ensureroom` first.

Returns `n`, the number of bytes written.
"""
function Base.unsafe_write(io::BufferedWriter, x::Ptr{UInt8}, n::UInt64)
    ensureroom(io, n)
    unsafe_copyto!(Ptr{UInt8}(io.curptr), x, n)
    io.curptr += n
    # Callers of `unsafe_write` expect the byte count, not the advanced pointer.
    return n
end

# Position in the wrapped stream's coordinates. The cursor offset comes from
# `bufferpos`, since the writer tracks its location with `curptr`.
Base.position(io::BufferedWriter) = io.file_position + bufferpos(io)

"""
    BufferedReader{io} <: MemoryBackedIO

Reader that pulls bytes from the wrapped stream `f` into an in-memory `buffer` and
serves reads from there.

# Fields
- `f`: the wrapped stream.
- `buffer`: all bytes read from `f` so far.
- `file_position`: position of `f` recorded at construction, i.e. the file offset of
  the first buffered byte.
- `curptr`: raw pointer to the next byte to be read inside `buffer`. It has to be
  re-derived whenever `buffer` is resized, which is why the struct is mutable.
"""
mutable struct BufferedReader{io} <: MemoryBackedIO
    f::io
    buffer::Vector{UInt8}
    file_position::Int64
    curptr::Ptr{Nothing}
end

"""
    BufferedReader(io)

Create a `BufferedReader` on top of `io`, starting with an empty buffer anchored at the
current position of `io`.
"""
function BufferedReader(io)
    buf = Vector{UInt8}()
    return BufferedReader(io, buf, position(io), Ptr{Nothing}(pointer(buf)))
end
# Compact display: print only the type name, never the buffer contents.
function Base.show(io::IO, ::BufferedReader)
    return print(io, "BufferedReader")
end

#Base.bytesavailable(io::BufferedReader) = bytesavailable(io.f) + length(io.buffer) - bufferpos(io)

"""
    readmore!(io::BufferedReader, n::Integer)

Append at least `n` more bytes from the wrapped stream to the end of the buffer.

The read position inside the buffer is left unchanged; only `io.curptr` is re-derived
because resizing may move the buffer. A short read surfaces as the error raised by
`unsafe_read` on the wrapped stream.
"""
function readmore!(io::BufferedReader, n::Integer)
    f = io.f
    # TODO: check whether taking everything `bytesavailable` reports reads far too much.
    amount = max(bytesavailable(f), n)
    buffer = io.buffer
    pos = bufferpos(io)
    oldlen = length(buffer)
    resize!(buffer, oldlen + amount)
    # `resize!` may move the underlying memory, so the raw pointer must be refreshed.
    io.curptr = pointer(buffer, pos + 1)
    # New bytes go after the data that is already buffered, NOT to the read cursor:
    # the cursor may still sit in front of unread bytes, which must not be overwritten.
    unsafe_read(f, pointer(buffer, oldlen + 1), amount)
end

# Make sure `n` bytes can be read at the current buffer position, fetching more from the
# wrapped stream if necessary. When exactly `n` bytes are left no refill is needed, hence
# the strict comparison (a refill at that point could hit a spurious end-of-file).
ensureroom(io::BufferedReader, n::Integer) =
    (bufferpos(io) + n > length(io.buffer)) && readmore!(io, n)

"""
    _read(io::BufferedReader, T::DataType)

Load a value of type `T` from the current buffer position and advance the cursor by
`jlsizeof(T)` bytes.
"""
function _read(io::BufferedReader, T::DataType)
    n = jlsizeof(T)
    ensureroom(io, n)
    v = jlunsafe_load(Ptr{T}(io.curptr))
    io.curptr += n
    return v
end
# Scalar reads of single bytes and plain types are served by the pointer-based `_read`.
jlread(io::BufferedReader, T::Type{UInt8}) = _read(io, T)
jlread(io::BufferedReader, T::Type{Int8}) = _read(io, T)
jlread(io::BufferedReader, T::PlainType) = _read(io, T)

"""
    jlread(io::BufferedReader, ::Type{T}, n::Int) where T

Read `n` consecutive values of type `T` from the current buffer position into a new
`Vector{T}` and advance the cursor by `jlsizeof(T) * n` bytes.
"""
function jlread(io::BufferedReader, ::Type{T}, n::Int) where T
    m = jlsizeof(T) * n
    ensureroom(io, m)
    arr = Vector{T}(undef, n)
    unsafe_copyto!(pointer(arr), Ptr{T}(io.curptr), n)
    io.curptr += m
    return arr
end
# Accept any integer count by converting it to `Int` and deferring to the `n::Int` method.
function jlread(io::BufferedReader, ::Type{T}, n::Integer) where {T}
    return jlread(io, T, Int(n))
end

# Position in the wrapped stream's coordinates: file offset of the first buffered byte
# plus the offset of the read cursor inside the buffer.
Base.position(io::BufferedReader) = io.file_position + bufferpos(io)

"""
bufferpos(io::Union{BufferedReader, BufferedWriter})
Get the current position in the buffer.
"""
function bufferpos(io::Union{BufferedReader, BufferedWriter})
    # Zero-based offset of the cursor: distance between the raw cursor pointer and the
    # first byte of the buffer.
    return Int(io.curptr - pointer(io.buffer))
end

"""
    adjust_position!(io::BufferedReader, position::Integer)

Move the read cursor to the zero-based buffer offset `position` and return `position`.

If `position` lies beyond the buffered data, the missing bytes are fetched from the
wrapped stream first. Throws an `ArgumentError` for negative positions.
"""
function adjust_position!(io::BufferedReader, position::Integer)
    position < 0 && throw(ArgumentError("cannot seek before start of buffer"))
    if position > length(io.buffer)
        readmore!(io, position - length(io.buffer))
    end
    io.curptr = pointer(io.buffer, position+1)
    return position
end

# `seek` takes an absolute file position, `skip` an offset relative to the current read
# cursor; both are translated to a buffer offset and applied by `adjust_position!`.
Base.seek(io::BufferedReader, offset::Integer) = adjust_position!(io, offset - io.file_position)
Base.skip(io::BufferedReader, offset::Integer) = adjust_position!(io, bufferpos(io) + offset)

# Leave the wrapped stream positioned right after the last byte consumed via the buffer.
finish!(io::BufferedReader) = seek(io.f, io.file_position + bufferpos(io))

"""
    truncate_and_close(io::IOStream, endpos::Integer)

Truncate `io` to `endpos` bytes and close it afterwards.
"""
function truncate_and_close(io::IOStream, endpos::Integer)
    truncate(io, endpos)
    return close(io)
end

# Closing a `BufferedReader` is a no-op; the wrapped stream is not touched.
function Base.close(::BufferedReader)
    return nothing
end


# We sometimes need to compute checksums. We do this by first calling
# `begin_checksum_read` or `begin_checksum_write` when starting to handle whatever needs
# checksumming, and calling `end_checksum` afterwards. Note that we never compute nested
# checksums, but we may compute multiple checksums simultaneously.

# Start a checksummed read: every byte subsequently read through the returned
# `BufferedReader` is kept in its buffer so that `end_checksum` can hash it.
begin_checksum_read(io::IO) = BufferedReader(io)

# Start a checksummed write of `sz` bytes: the bytes are staged in a `BufferedWriter`
# of that size so that `end_checksum` can hash them before they reach the file.
function begin_checksum_write(io::IOStream, sz::Integer)
    writer = BufferedWriter(io, sz)
    return writer
end
"""
    end_checksum(io::Union{BufferedReader,BufferedWriter})

Hash the buffered bytes from the start of the buffer up to the current buffer position
with `Lookup3.hash`, then call `finish!` on `io`. Returns the checksum.
"""
function end_checksum(io::Union{BufferedReader,BufferedWriter})
    # Hash before `finish!` so the checksum reflects the cursor position at this call.
    ret = Lookup3.hash(io.buffer, 1, bufferpos(io))
    finish!(io)
    return ret
end
Expand Down
Loading

0 comments on commit 05023d4

Please sign in to comment.