Skip to content

Commit

Permalink
Fix dictionary encoding and decoding (#234)
Browse files Browse the repository at this point in the history
* Make `map` behave like an unpacked repeated field
* Fix _encoded_size
---------

Co-authored-by: Drvi <[email protected]>
  • Loading branch information
JamieMair and Drvi authored Aug 14, 2023
1 parent 548017c commit 851224a
Show file tree
Hide file tree
Showing 7 changed files with 324 additions and 184 deletions.
90 changes: 40 additions & 50 deletions src/codec/decode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,75 +30,65 @@ function decode(d::AbstractProtoDecoder, ::Type{T}) where {T <: Union{Enum{Int32
end
decode(d::AbstractProtoDecoder, ::Type{T}) where {T <: Union{Float64,Float32}} = read(d.io, T)
function decode!(d::AbstractProtoDecoder, buffer::Dict{K,V}) where {K,V<:_ScalarTypesEnum}
len = vbyte_decode(d.io, UInt32)
endpos = position(d.io) + len
while position(d.io) < endpos
field_number, wire_type = decode_tag(d)
key = decode(d, K)
field_number, wire_type = decode_tag(d)
val = decode(d, V)
buffer[key] = val
end
@assert position(d.io) == endpos
pair_len = vbyte_decode(d.io, UInt32)
pair_end_pos = position(d.io) + pair_len
field_number, wire_type = decode_tag(d)
key = decode(d, K)
field_number, wire_type = decode_tag(d)
val = decode(d, V)
@assert position(d.io) == pair_end_pos
buffer[key] = val
nothing
end

function decode!(d::AbstractProtoDecoder, buffer::Dict{K,V}) where {K,V}
len = vbyte_decode(d.io, UInt32)
endpos = position(d.io) + len
while position(d.io) < endpos
field_number, wire_type = decode_tag(d)
key = decode(d, K)
field_number, wire_type = decode_tag(d)
val = decode(d, Ref{V})
buffer[key] = val
end
@assert position(d.io) == endpos
pair_len = vbyte_decode(d.io, UInt32)
pair_end_pos = position(d.io) + pair_len
field_number, wire_type = decode_tag(d)
key = decode(d, K)
field_number, wire_type = decode_tag(d)
val = decode(d, Ref{V})
@assert position(d.io) == pair_end_pos
buffer[key] = val
nothing
end

for T in (:(:fixed), :(:zigzag))
@eval function decode!(d::AbstractProtoDecoder, buffer::Dict{K,V}, ::Type{Val{Tuple{Nothing,$(T)}}}) where {K,V}
len = vbyte_decode(d.io, UInt32)
endpos = position(d.io) + len
while position(d.io) < endpos
field_number, wire_type = decode_tag(d)
key = decode(d, K)
field_number, wire_type = decode_tag(d)
val = decode(d, V, Val{$(T)})
buffer[key] = val
end
@assert position(d.io) == endpos
pair_len = vbyte_decode(d.io, UInt32)
pair_end_pos = position(d.io) + pair_len
field_number, wire_type = decode_tag(d)
key = decode(d, K)
field_number, wire_type = decode_tag(d)
val = decode(d, V, Val{$(T)})
@assert position(d.io) == pair_end_pos
buffer[key] = val
nothing
end

@eval function decode!(d::AbstractProtoDecoder, buffer::Dict{K,V}, ::Type{Val{Tuple{$(T),Nothing}}}) where {K,V}
len = vbyte_decode(d.io, UInt32)
endpos = position(d.io) + len
while position(d.io) < endpos
field_number, wire_type = decode_tag(d)
key = decode(d, K, Val{$(T)})
field_number, wire_type = decode_tag(d)
val = decode(d, V)
buffer[key] = val
end
@assert position(d.io) == endpos
pair_len = vbyte_decode(d.io, UInt32)
pair_end_pos = position(d.io) + pair_len
field_number, wire_type = decode_tag(d)
key = decode(d, K, Val{$(T)})
field_number, wire_type = decode_tag(d)
val = decode(d, V)
@assert position(d.io) == pair_end_pos
buffer[key] = val
nothing
end
end

for T in (:(:fixed), :(:zigzag)), S in (:(:fixed), :(:zigzag))
@eval function decode!(d::AbstractProtoDecoder, buffer::Dict{K,V}, ::Type{Val{Tuple{$(T),$(S)}}}) where {K,V}
len = vbyte_decode(d.io, UInt32)
endpos = position(d.io) + len
while position(d.io) < endpos
field_number, wire_type = decode_tag(d)
key = decode(d, K, Val{$(T)})
field_number, wire_type = decode_tag(d)
val = decode(d, V, Val{$(S)})
buffer[key] = val
end
@assert position(d.io) == endpos
pair_len = vbyte_decode(d.io, UInt32)
pair_end_pos = position(d.io) + pair_len
field_number, wire_type = decode_tag(d)
key = decode(d, K, Val{$(T)})
field_number, wire_type = decode_tag(d)
val = decode(d, V, Val{$(S)})
@assert position(d.io) == pair_end_pos
buffer[key] = val
nothing
end
end
Expand Down
56 changes: 28 additions & 28 deletions src/codec/encode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -139,40 +139,52 @@ function _encode(io::IO, x::Vector{T}) where {T<:Union{UInt32,UInt64,Int32,Int64
return nothing
end

function _encode(_e::ProtoEncoder, x::Dict{K,V}) where {K,V}
maybe_ensure_room(_e.io, 2length(x))
function encode(e::ProtoEncoder, i::Int, x::Dict{K,V}) where {K,V}
maybe_ensure_room(e.io, 2*(length(x)+1))
for (k, v) in x
encode(_e, 1, k)
encode(_e, 2, v)
# encode header for key-value pair message
encode_tag(e, i, LENGTH_DELIMITED)
vbyte_encode(e.io, UInt32(_encoded_size(k, 1) + _encoded_size(v, 2)))
encode(e, 1, k)
encode(e, 2, v)
end
nothing
end

for T in (:(:fixed), :(:zigzag))
@eval function _encode(_e::ProtoEncoder, x::Dict{K,V}, ::Type{Val{Tuple{$(T),Nothing}}}) where {K,V}
maybe_ensure_room(_e.io, 2length(x))
@eval function encode(e::ProtoEncoder, i::Int, x::Dict{K,V}, ::Type{Val{Tuple{$(T),Nothing}}}) where {K,V}
maybe_ensure_room(e.io, 2*(length(x)+1))
for (k, v) in x
encode(_e, 1, k, Val{$(T)})
encode(_e, 2, v)
# encode header for key-value pair message
encode_tag(e, i, LENGTH_DELIMITED)
vbyte_encode(e.io, UInt32(_encoded_size(k, 1, Val{$(T)}) + _encoded_size(v, 2)))
encode(e, 1, k, Val{$(T)})
encode(e, 2, v)
end
nothing
end
@eval function _encode(_e::ProtoEncoder, x::Dict{K,V}, ::Type{Val{Tuple{Nothing,$(T)}}}) where {K,V}
maybe_ensure_room(_e.io, 2length(x))
@eval function encode(e::ProtoEncoder, i::Int, x::Dict{K,V}, ::Type{Val{Tuple{Nothing,$(T)}}}) where {K,V}
maybe_ensure_room(e.io, 2*(length(x)+1))
for (k, v) in x
encode(_e, 1, k)
encode(_e, 2, v, Val{$(T)})
# encode header for key-value pair message
encode_tag(e, i, LENGTH_DELIMITED)
vbyte_encode(e.io, UInt32(_encoded_size(k, 1) + _encoded_size(v, 2, Val{$(T)})))
encode(e, 1, k)
encode(e, 2, v, Val{$(T)})
end
nothing
end
end

for T in (:(:fixed), :(:zigzag)), S in (:(:fixed), :(:zigzag))
@eval function _encode(_e::AbstractProtoEncoder, x::Dict{K,V}, ::Type{Val{Tuple{$(T),$(S)}}}) where {K,V}
maybe_ensure_room(_e.io, 2length(x))
@eval function encode(e::AbstractProtoEncoder, i::Int, x::Dict{K,V}, ::Type{Val{Tuple{$(T),$(S)}}}) where {K,V}
maybe_ensure_room(e.io, 2*(length(x)+1))
for (k, v) in x
encode(_e, 1, k, Val{$(T)})
encode(_e, 2, v, Val{$(S)})
# encode header for key-value pair message
encode_tag(e, i, LENGTH_DELIMITED)
vbyte_encode(e.io, UInt32(_encoded_size(k, 1, Val{$(T)}) + _encoded_size(v, 2, Val{$(S)})))
encode(e, 1, k, Val{$(T)})
encode(e, 2, v, Val{$(S)})
end
nothing
end
Expand Down Expand Up @@ -269,18 +281,6 @@ function encode(e::AbstractProtoEncoder, i::Int, x::Vector{T}) where {T<:Union{U
return nothing
end

function encode(e::AbstractProtoEncoder, i::Int, x::Dict{K,V}) where {K,V}
encode_tag(e, i, LENGTH_DELIMITED)
_with_size(_encode, e.io, e, x)
return nothing
end

function encode(e::AbstractProtoEncoder, i::Int, x::Dict{K,V}, ::Type{W}) where {K,V,W}
encode_tag(e, i, LENGTH_DELIMITED)
_with_size(_encode, e.io, e, x, W)
return nothing
end

function encode(e::AbstractProtoEncoder, i::Int, x::Vector{T}, ::Type{Val{:zigzag}}) where {T<:Union{Int32,Int64}}
encode_tag(e, i, LENGTH_DELIMITED)
_with_size(_encode, e.io, e.io, x, Val{:zigzag})
Expand Down
38 changes: 30 additions & 8 deletions src/codec/encoded_size.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ _encoded_size(x::T, ::Type{Val{:fixed}}) where {T<:Union{Int32,UInt32,Int64,UInt

# For Length-Delimited fields we don't include the encoded number of bytes
# unless we also provide the field number in which case we encode both the
# tag and lenght
# tag and length
_with_size(n::Int) = (n + _encoded_size(n))
_encoded_size(x::String) = sizeof(x)

Expand All @@ -29,19 +29,41 @@ _encoded_size(xs::AbstractVector{T}) where {T<:Union{UInt8,Bool,Float64,Float32}
_encoded_size(xs::AbstractVector{T}) where {T<:Union{String,AbstractVector{UInt8}}} = sum(x->_with_size(_encoded_size(x)), xs, init=0)
_encoded_size(xs::AbstractVector{T}, ::Type{Val{:fixed}}) where {T<:Union{Int32,UInt32,Int64,UInt64}} = sizeof(xs)

# Dicts add dummy tags to both keys and values
_encoded_size(d::AbstractDict) = mapreduce(x->_encoded_size(x.first, 1) + _encoded_size(x.second, 2), +, d, init=0)
# Dicts add dummy tags to both keys and values and to each pair
# _encoded_size(::AbstractDict) does not include the "pair" tag and field number
# those are added in the _encoded_size(::AbstractDict, ::Int) methods below because the field number
# is not known at this point
function _encoded_size(d::AbstractDict)
mapreduce(x->begin
total_size = _encoded_size(x.first, 1) + _encoded_size(x.second, 2)
return _varint_size(total_size) + total_size
end, +, d, init=0)
end
_encoded_size(xs::AbstractDict, i::Int) = _encoded_size(i << 3) * length(xs) + _encoded_size(xs)

for T in (:(:fixed), :(:zigzag))
@eval _encoded_size(d::AbstractDict, ::Type{Val{Tuple{$(T),Nothing}}}) = mapreduce(x->_encoded_size(x.first, 1, Val{$(T)}) + _encoded_size(x.second, 2), +, d, init=0)
@eval _encoded_size(d::AbstractDict, ::Type{Val{Tuple{Nothing,$(T)}}}) = mapreduce(x->_encoded_size(x.first, 1) + _encoded_size(x.second, 2, Val{$(T)}), +, d, init=0)
@eval _encoded_size(d::AbstractDict, ::Type{Val{Tuple{$(T),Nothing}}}) = mapreduce(x->begin
total_size = _encoded_size(x.first, 1, Val{$(T)}) + _encoded_size(x.second, 2)
return _varint_size(total_size) + total_size
end, +, d, init=0)
@eval _encoded_size(xs::AbstractDict, i::Int, ::Type{Val{Tuple{$(T),Nothing}}}) = _encoded_size(i << 3) * length(xs) + _encoded_size(xs, Val{Tuple{$(T),Nothing}})

@eval _encoded_size(d::AbstractDict, ::Type{Val{Tuple{Nothing,$(T)}}}) = mapreduce(x->begin
total_size = _encoded_size(x.first, 1) + _encoded_size(x.second, 2, Val{$(T)})
return _varint_size(total_size) + total_size
end, +, d, init=0)
@eval _encoded_size(xs::AbstractDict, i::Int, ::Type{Val{Tuple{Nothing,$(T)}}}) = _encoded_size(i << 3) * length(xs) + _encoded_size(xs, Val{Tuple{Nothing,$(T)}})

@eval _encoded_size(xs::Union{AbstractDict,AbstractVector}, i::Int, ::Type{Val{$(T)}}) = _encoded_size(i << 3) + _with_size(_encoded_size(xs, Val{$(T)}))
@eval _encoded_size(xs::Union{Int32,Int64,UInt64,UInt32}, i::Int, ::Type{Val{$(T)}}) = _encoded_size(i << 3) + _encoded_size(xs, Val{$(T)})
@eval _encoded_size(xs::AbstractVector, i::Int, ::Type{Val{$(T)}}) = _encoded_size(i << 3) + _with_size(_encoded_size(xs, Val{$(T)}))
@eval _encoded_size(xs::Union{Int32,Int64,UInt64,UInt32}, i::Int, ::Type{Val{$(T)}}) = _encoded_size(i << 3) + _encoded_size(xs, Val{$(T)})
end

for T in (:(:fixed), :(:zigzag)), S in (:(:fixed), :(:zigzag))
@eval _encoded_size(d::AbstractDict, ::Type{Val{Tuple{$(T),$(S)}}}) = mapreduce(x->_encoded_size(x.first, 1, Val{$(S)}) + _encoded_size(x.second, 2, Val{$(S)}), +, d, init=0)
@eval _encoded_size(d::AbstractDict, ::Type{Val{Tuple{$(T),$(S)}}}) = mapreduce(x->begin
total_size = _encoded_size(x.first, 1, Val{$(T)}) + _encoded_size(x.second, 2, Val{$(S)})
return _varint_size(total_size) + total_size
end, +, d, init=0)
@eval _encoded_size(xs::AbstractDict, i::Int, ::Type{Val{Tuple{$(T),$(S)}}}) = _encoded_size(i << 3) * length(xs) + _encoded_size(xs, Val{Tuple{$(T),$(S)}})
end

# These methods handle fields that refer to messages/groups
Expand Down
Loading

0 comments on commit 851224a

Please sign in to comment.