Skip to content

Commit

Permalink
append capsules to byte slice instead of writing to an io.Writer
Browse files Browse the repository at this point in the history
  • Loading branch information
marten-seemann committed Oct 5, 2024
1 parent 730cc66 commit 2eece5e
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 48 deletions.
63 changes: 27 additions & 36 deletions capsule.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,29 +62,26 @@ func parseAddressAssignCapsule(r io.Reader) (*addressAssignCapsule, error) {
return &addressAssignCapsule{AssignedAddresses: assignedAddresses}, nil
}

func (c *addressAssignCapsule) marshal(w io.Writer) error {
func (c *addressAssignCapsule) append(b []byte) []byte {
totalLen := 0
for _, addr := range c.AssignedAddresses {
totalLen += addr.len()
}

buf := make([]byte, 0, quicvarint.Len(uint64(capsuleTypeAddressAssign))+quicvarint.Len(uint64(totalLen))+totalLen)
buf = quicvarint.Append(buf, uint64(capsuleTypeAddressAssign))
buf = quicvarint.Append(buf, uint64(totalLen))
b = quicvarint.Append(b, uint64(capsuleTypeAddressAssign))
b = quicvarint.Append(b, uint64(totalLen))

for _, addr := range c.AssignedAddresses {
buf = quicvarint.Append(buf, addr.RequestID)
b = quicvarint.Append(b, addr.RequestID)
if addr.IPPrefix.Addr().Is4() {
buf = append(buf, 4)
b = append(b, 4)
} else {
buf = append(buf, 6)
b = append(b, 6)
}
buf = append(buf, addr.IPPrefix.Addr().AsSlice()...)
buf = append(buf, byte(addr.IPPrefix.Bits()))
b = append(b, addr.IPPrefix.Addr().AsSlice()...)
b = append(b, byte(addr.IPPrefix.Bits()))
}

_, err := w.Write(buf)
return err
return b
}

func parseAddressRequestCapsule(r io.Reader) (*addressRequestCapsule, error) {
Expand All @@ -102,29 +99,26 @@ func parseAddressRequestCapsule(r io.Reader) (*addressRequestCapsule, error) {
return &addressRequestCapsule{RequestedAddresses: requestedAddresses}, nil
}

func (c *addressRequestCapsule) marshal(w io.Writer) error {
func (c *addressRequestCapsule) append(b []byte) []byte {
var totalLen int
for _, addr := range c.RequestedAddresses {
totalLen += addr.len()
}

buf := make([]byte, 0, quicvarint.Len(uint64(capsuleTypeAddressRequest))+quicvarint.Len(uint64(totalLen))+totalLen)
buf = quicvarint.Append(buf, uint64(capsuleTypeAddressRequest))
buf = quicvarint.Append(buf, uint64(totalLen))
b = quicvarint.Append(b, uint64(capsuleTypeAddressRequest))
b = quicvarint.Append(b, uint64(totalLen))

for _, addr := range c.RequestedAddresses {
buf = quicvarint.Append(buf, addr.RequestID)
b = quicvarint.Append(b, addr.RequestID)
if addr.IPPrefix.Addr().Is4() {
buf = append(buf, 4)
b = append(b, 4)
} else {
buf = append(buf, 6)
b = append(b, 6)
}
buf = append(buf, addr.IPPrefix.Addr().AsSlice()...)
buf = append(buf, byte(addr.IPPrefix.Bits()))
b = append(b, addr.IPPrefix.Addr().AsSlice()...)
b = append(b, byte(addr.IPPrefix.Bits()))
}

_, err := w.Write(buf)
return err
return b
}

func parseAddress(r io.Reader) (requestID uint64, prefix netip.Prefix, _ error) {
Expand Down Expand Up @@ -197,29 +191,26 @@ func parseRouteAdvertisementCapsule(r io.Reader) (*routeAdvertisementCapsule, er
return &routeAdvertisementCapsule{IPAddressRanges: ranges}, nil
}

func (c *routeAdvertisementCapsule) marshal(w io.Writer) error {
func (c *routeAdvertisementCapsule) append(b []byte) []byte {
var totalLen int
for _, ipRange := range c.IPAddressRanges {
totalLen += ipRange.len()
}

buf := make([]byte, 0, quicvarint.Len(uint64(capsuleTypeRouteAdvertisement))+quicvarint.Len(uint64(totalLen))+totalLen)
buf = quicvarint.Append(buf, uint64(capsuleTypeRouteAdvertisement))
buf = quicvarint.Append(buf, uint64(totalLen))
b = quicvarint.Append(b, uint64(capsuleTypeRouteAdvertisement))
b = quicvarint.Append(b, uint64(totalLen))

for _, ipRange := range c.IPAddressRanges {
if ipRange.StartIP.Is4() {
buf = append(buf, 4)
b = append(b, 4)
} else {
buf = append(buf, 6)
b = append(b, 6)
}
buf = append(buf, ipRange.StartIP.AsSlice()...)
buf = append(buf, ipRange.EndIP.AsSlice()...)
buf = append(buf, ipRange.IPProtocol)
b = append(b, ipRange.StartIP.AsSlice()...)
b = append(b, ipRange.EndIP.AsSlice()...)
b = append(b, ipRange.IPProtocol)
}

_, err := w.Write(buf)
return err
return b
}

func parseIPAddressRange(r io.Reader) (IPAddressRange, error) {
Expand Down
24 changes: 12 additions & 12 deletions capsule_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,15 +50,15 @@ func TestWriteAddressAssignCapsule(t *testing.T) {
{RequestID: 1338, IPPrefix: netip.MustParsePrefix("2001:db8::1/128")},
},
}
buf := &bytes.Buffer{}
require.NoError(t, c.marshal(buf))
typ, cr, err := http3.ParseCapsule(buf)
data := c.append(nil)
r := bytes.NewReader(data)
typ, cr, err := http3.ParseCapsule(r)
require.NoError(t, err)
require.Equal(t, capsuleTypeAddressAssign, typ)
parsed, err := parseAddressAssignCapsule(cr)
require.NoError(t, err)
require.Equal(t, c, parsed)
require.Zero(t, buf.Len())
require.Zero(t, r.Len())
}

func TestParseAddressAssignCapsuleInvalid(t *testing.T) {
Expand Down Expand Up @@ -181,15 +181,15 @@ func TestWriteAddressRequestCapsule(t *testing.T) {
{RequestID: 1338, IPPrefix: netip.MustParsePrefix("2001:db8::1/128")},
},
}
buf := &bytes.Buffer{}
require.NoError(t, c.marshal(buf))
typ, cr, err := http3.ParseCapsule(buf)
data := c.append(nil)
r := bytes.NewReader(data)
typ, cr, err := http3.ParseCapsule(r)
require.NoError(t, err)
require.Equal(t, capsuleTypeAddressRequest, typ)
parsed, err := parseAddressRequestCapsule(cr)
require.NoError(t, err)
require.Equal(t, c, parsed)
require.Zero(t, buf.Len())
require.Zero(t, r.Len())
}

func TestParseAddressRequestCapsuleInvalid(t *testing.T) {
Expand Down Expand Up @@ -237,15 +237,15 @@ func TestWriteRouteAdvertisementCapsule(t *testing.T) {
{StartIP: netip.MustParseAddr("2001:db8::1"), EndIP: netip.MustParseAddr("2001:db8::100"), IPProtocol: 37},
},
}
buf := &bytes.Buffer{}
require.NoError(t, c.marshal(buf))
typ, cr, err := http3.ParseCapsule(buf)
data := c.append(nil)
r := bytes.NewReader(data)
typ, cr, err := http3.ParseCapsule(r)
require.NoError(t, err)
require.Equal(t, capsuleTypeRouteAdvertisement, typ)
parsed, err := parseRouteAdvertisementCapsule(cr)
require.NoError(t, err)
require.Equal(t, c, parsed)
require.Zero(t, buf.Len())
require.Zero(t, r.Len())
}

func TestParseRouteAdvertisementCapsuleInvalid(t *testing.T) {
Expand Down

0 comments on commit 2eece5e

Please sign in to comment.