diff --git a/cli/parsefile.go b/cli/parsefile.go index 0960ad0..d735b97 100644 --- a/cli/parsefile.go +++ b/cli/parsefile.go @@ -90,7 +90,7 @@ func main() { } } - pkt := make(packet.Packet, packet.PacketSize) + var pkt packet.Packet var numPackets uint64 ebps := make(map[uint64]ebp.EncoderBoundaryPoint) scte35PIDs := make(map[uint16]bool) @@ -107,7 +107,7 @@ func main() { } for { - if _, err := io.ReadFull(reader, pkt); err != nil { + if _, err := io.ReadFull(reader, pkt[:]); err != nil { if err == io.EOF || err == io.ErrUnexpectedEOF { break } @@ -116,13 +116,13 @@ func main() { } numPackets++ if *dumpSCTE35 { - currPID, err := packet.Pid(pkt) + currPID, err := packet.Pid(&pkt) if err != nil { fmt.Printf("Cannot get packet PID for %d\n", currPID) continue } if scte35PIDs[currPID] { - pay, err := packet.Payload(pkt) + pay, err := packet.Payload(&pkt) if err != nil { fmt.Printf("Cannot get payload for packet number %d on PID %d Error=%s\n", numPackets, currPID, err) continue @@ -138,7 +138,7 @@ func main() { } if *showEbp { - ebpBytes, err := adaptationfield.EncoderBoundaryPoint(pkt) + ebpBytes, err := adaptationfield.EncoderBoundaryPoint(&pkt) if err != nil { // Not an EBP continue @@ -154,7 +154,7 @@ func main() { } if *showPacketNumberOfPID != 0 { pid := uint16(*showPacketNumberOfPID) - pktPid, err := packet.Pid(pkt) + pktPid, err := packet.Pid(&pkt) if err != nil { continue } diff --git a/packet/accumulator.go b/packet/accumulator.go index 4ccff28..e2d92d9 100644 --- a/packet/accumulator.go +++ b/packet/accumulator.go @@ -30,48 +30,37 @@ import ( "github.com/Comcast/gots" ) -var ( - emptyByteArray []byte -) - -type doneFunc func([]byte) (bool, error) - type accumulator struct { - f doneFunc - packets []Packet + f func([]byte) (bool, error) + packets []*Packet } -// NewAccumulator creates a new packet accumulator -// that is done when the provided doneFunc returns true. -// PacketAccumulator is not thread safe -func NewAccumulator(f doneFunc) Accumulator { +// NewAccumulator creates a new packet accumulator that is done when +// the provided function returns done as true. +func NewAccumulator(f func(data []byte) (done bool, err error)) Accumulator { return &accumulator{f: f} } // Add a packet to the accumulator. If the added packet completes // the accumulation, based on the provided doneFunc, true is returned. // Returns an error if the packet is not valid. -func (a *accumulator) Add(pkt Packet) (bool, error) { +func (a *accumulator) Add(pkt []byte) (bool, error) { if badLen(pkt) { return false, gots.ErrInvalidPacketLength } + var pp Packet + copy(pp[:], pkt) // technically we could get a packet without a payload. Check this and // return false if we get one - p, err := ContainsPayload(pkt) - if err != nil { + p, err := ContainsPayload(&pp) + if !p || err != nil { return false, err - } else if !p { - return false, nil } - if payloadUnitStartIndicator(pkt) { - a.packets = make([]Packet, 0) - } else if len(a.packets) == 0 { + if !payloadUnitStartIndicator(&pp) && len(a.packets) == 0 { // First packet must have payload unit start indicator return false, gots.ErrNoPayloadUnitStartIndicator } - pktCopy := make(Packet, PacketSize) - copy(pktCopy, pkt) - a.packets = append(a.packets, pktCopy) + a.packets = append(a.packets, &pp) b, err := a.Parse() if err != nil { return false, err @@ -91,14 +80,14 @@ func (a *accumulator) Parse() ([]byte, error) { for _, pkt := range a.packets { pay, err := Payload(pkt) if err != nil { - return emptyByteArray, err + return nil, err } buf.Write(pay) } return buf.Bytes(), nil } -func (a *accumulator) Packets() []Packet { +func (a *accumulator) Packets() []*Packet { return a.packets } diff --git a/packet/accumulator_test.go b/packet/accumulator_test.go index d38692d..0f8b460 100644 --- a/packet/accumulator_test.go +++ b/packet/accumulator_test.go @@ -38,7 +38,7 @@ func ExamplePacketAccumulator() { secondPacket, _ := hex.DecodeString("47006411f0002b59bc22ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff") - packets := []Packet{firstPacket, secondPacket} + packets := [][]byte{firstPacket, secondPacket} // Just a simple func to accumulate two packets dFunc := func(b []byte) (bool, error) { if len(b) <= PacketSize { @@ -48,9 +48,11 @@ func ExamplePacketAccumulator() { } acc := NewAccumulator(dFunc) - done, err := acc.Add(packets[0]) - for i := 1; !done; i++ { - done, err = acc.Add(packets[i]) + for _, pkt := range packets { + done, err := acc.Add(pkt) + if done { + break + } if err != nil { fmt.Printf("%v\n", err) } diff --git a/packet/adaptationfield/adaptationfield.go b/packet/adaptationfield/adaptationfield.go index c727501..6fe23fc 100644 --- a/packet/adaptationfield/adaptationfield.go +++ b/packet/adaptationfield/adaptationfield.go @@ -5,64 +5,59 @@ import ( "github.com/Comcast/gots/packet" ) -var emptyByteSlice []byte - // Length returns the length of the adaptation field in bytes -func Length(packet packet.Packet) uint8 { - return uint8(packet[4]) +func Length(pkt *packet.Packet) uint8 { + return uint8(pkt[4]) } // IsDiscontinuous returns the discontinuity indicator for this adaptation field -func IsDiscontinuous(packet packet.Packet) bool { - return (packet[5] & 0x80) != 0 +func IsDiscontinuous(pkt *packet.Packet) bool { + return pkt[5]&0x80 != 0 } // IsRandomAccess returns the random access indicator for this adaptation field -func IsRandomAccess(packet packet.Packet) bool { - return (packet[5] & 0x40) != 0 +func IsRandomAccess(pkt *packet.Packet) bool { + return pkt[5]&0x40 != 0 } // IsESHigherPriority returns true if this elementary stream is // high priority. Corresponds to the elementary stream // priority indicator. -func IsESHigherPriority(packet packet.Packet) bool { - return (packet[5] & 0x20) != 0 +func IsESHigherPriority(pkt *packet.Packet) bool { + return pkt[5]&0x20 != 0 } // HasPCR returns true when the PCR flag is set -func HasPCR(packet packet.Packet) bool { - return (packet[5] & 0x10) != 0 +func HasPCR(pkt *packet.Packet) bool { + return pkt[5]&0x10 != 0 } // HasOPCR returns true when the OPCR flag is set -func HasOPCR(packet packet.Packet) bool { - return (packet[5] & 0x08) != 0 +func HasOPCR(pkt *packet.Packet) bool { + return pkt[5]&0x08 != 0 } // HasSplicingPoint returns true when the splicing countdown field is present -func HasSplicingPoint(packet packet.Packet) bool { - return (packet[5] & 0x04) != 0 +func HasSplicingPoint(pkt *packet.Packet) bool { + return pkt[5]&0x04 != 0 } // HasTransportPrivateData returns true when the private data field is present -func HasTransportPrivateData(packet packet.Packet) bool { - return (packet[5] & 0x02) != 0 +func HasTransportPrivateData(pkt *packet.Packet) bool { + return pkt[5]&0x02 != 0 } // HasAdaptationFieldExtension returns true if this adaptation field contains an extension field -func HasAdaptationFieldExtension(packet packet.Packet) bool { - return (packet[5] & 0x01) != 0 +func HasAdaptationFieldExtension(pkt *packet.Packet) bool { + return pkt[5]&0x01 != 0 } // EncoderBoundaryPoint returns the byte array located in the optional TransportPrivateData of the (also optional) // AdaptationField of the Packet. If either of these optional fields are missing an empty byte array is returned with an error -func EncoderBoundaryPoint(pkt packet.Packet) ([]byte, error) { - if badLen(pkt) { - return emptyByteSlice, gots.ErrInvalidPacketLength - } +func EncoderBoundaryPoint(pkt *packet.Packet) ([]byte, error) { hasAdapt, err := packet.ContainsAdaptationField(pkt) if err != nil { - return emptyByteSlice, nil + return nil, nil } if hasAdapt && Length(pkt) > 0 && HasTransportPrivateData(pkt) { ebp, err := TransportPrivateData(pkt) @@ -78,81 +73,64 @@ func EncoderBoundaryPoint(pkt packet.Packet) ([]byte, error) { // First 33 bits are PCR base. // Next 6 bits are reserved. // Final 9 bits are PCR extension. -func PCR(packet packet.Packet) ([]byte, error) { - if !HasPCR(packet) { - return emptyByteSlice, gots.ErrNoPCR +func PCR(pkt *packet.Packet) ([]byte, error) { + if !HasPCR(pkt) { + return nil, gots.ErrNoPCR } offset := 6 - return packet[offset : offset+6], nil + return pkt[offset : offset+6], nil } // OPCR is the Original Program Clock Reference. // First 33 bits are original PCR base. // Next 6 bits are reserved. // Final 9 bits are original PCR extension. -func OPCR(packet packet.Packet) ([]byte, error) { - if badLen(packet) { - return emptyByteSlice, gots.ErrInvalidPacketLength - } - if !HasOPCR(packet) { - return emptyByteSlice, gots.ErrNoOPCR +func OPCR(pkt *packet.Packet) ([]byte, error) { + if !HasOPCR(pkt) { + return nil, gots.ErrNoOPCR } offset := 6 - if HasPCR(packet) { + if HasPCR(pkt) { offset += 6 } - return packet[offset : offset+6], nil + return pkt[offset : offset+6], nil } // SpliceCountdown returns a count of how many packets after this one until // a splice point occurs or an error if none exist. This function calls // HasSplicingPoint to check for the existence of a splice countdown. -func SpliceCountdown(packet packet.Packet) (uint8, error) { - if badLen(packet) { - return 0, gots.ErrInvalidPacketLength - } - if !HasSplicingPoint(packet) { +func SpliceCountdown(pkt *packet.Packet) (uint8, error) { + if !HasSplicingPoint(pkt) { return 0, gots.ErrNoSplicePoint } offset := 6 - if HasPCR(packet) { + if HasPCR(pkt) { offset += 6 } - if HasOPCR(packet) { + if HasOPCR(pkt) { offset += 6 } - return packet[offset], nil + return pkt[offset], nil } // TransportPrivateData returns the private data from this adaptation field // or an empty array and an error if there is none. This function calls // HasTransportPrivateData to check for the existence of private data. -func TransportPrivateData(packet packet.Packet) ([]byte, error) { - if badLen(packet) { - return emptyByteSlice, gots.ErrInvalidPacketLength - } - if !HasTransportPrivateData(packet) { - return emptyByteSlice, gots.ErrNoPrivateTransportData +func TransportPrivateData(pkt *packet.Packet) ([]byte, error) { + if !HasTransportPrivateData(pkt) { + return nil, gots.ErrNoPrivateTransportData } offset := 6 - if HasPCR(packet) { + if HasPCR(pkt) { offset += 6 } - if HasOPCR(packet) { + if HasOPCR(pkt) { offset += 6 } - if HasSplicingPoint(packet) { + if HasSplicingPoint(pkt) { offset++ } - dataLength := uint8(packet[offset]) + dataLength := uint8(pkt[offset]) offset++ - return packet[uint8(offset) : uint8(offset)+dataLength], nil -} - -// badLen returns true if the packet has invalid length -func badLen(pkt packet.Packet) bool { - if len(pkt) != packet.PacketSize { - return true - } - return false + return pkt[uint8(offset) : uint8(offset)+dataLength], nil } diff --git a/packet/adaptationfield/create.go b/packet/adaptationfield/create.go index 414bb5f..fddb8ef 100644 --- a/packet/adaptationfield/create.go +++ b/packet/adaptationfield/create.go @@ -4,18 +4,17 @@ import "github.com/Comcast/gots/packet" func SetPrivateData(pkt *packet.Packet, af []byte) { offset := 6 - if HasPCR(*pkt) { + if HasPCR(pkt) { offset += 6 } - if HasOPCR(*pkt) { + if HasOPCR(pkt) { offset += 6 } - if HasSplicingPoint(*pkt) { + if HasSplicingPoint(pkt) { offset++ } - (*pkt)[offset] = byte(0x04) // data length + pkt[offset] = byte(0x04) // data length offset++ - for i, b := range af { - (*pkt)[offset+i] = b - } + // FIXME(kortschak): Handle len(af) != 4. + copy(pkt[offset:offset+4], af) } diff --git a/packet/create.go b/packet/create.go index b24022e..782c255 100644 --- a/packet/create.go +++ b/packet/create.go @@ -34,7 +34,7 @@ var ( var ( // TestPatPacket is a minimal PAT packet for testing. It contains a single program stream with no payload. - TestPatPacket = []byte{ + TestPatPacket = Packet{ 0x47, 0x40, 0x00, 0x10, 0x00, 0x00, 0xb0, 0x0d, 0x00, 0x01, 0xcb, 0x00, 0x00, 0x00, 0x01, 0xe0, 0x64, 0x68, 0xd6, 0x84, 0x2e, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, @@ -51,7 +51,7 @@ var ( 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff} - TestPmtPacket = []byte{ + TestPmtPacket = Packet{ 0x47, 0x40, 0x64, 0x10, 0x00, 0x02, 0xb0, 0x2d, 0x00, 0x01, 0xcb, 0x00, 0x00, 0xe0, 0x65, 0xf0, 0x06, 0x05, 0x04, 0x43, 0x55, 0x45, 0x49, 0x1b, 0xe0, 0x65, 0xf0, 0x05, 0x0e, 0x03, 0x00, 0x04, 0xb0, 0x0f, 0xe0, 0x66, @@ -79,8 +79,8 @@ var ( // WithContinuousAF, // WithPUSI), // cc) -func Create(pid uint16, options ...func(*Packet)) Packet { - var pkt Packet = make([]byte, 188) +func Create(pid uint16, options ...func(*Packet)) *Packet { + var pkt Packet setPid(&pkt, pid) for _, option := range options { option(&pkt) @@ -88,13 +88,13 @@ func Create(pid uint16, options ...func(*Packet)) Packet { for _, option := range required { option(&pkt) } - return pkt + return &pkt } // CreateTestPacket creates a test packet with the given PID, continuity counter, payload unit start indicator and payload flag // This is a convenience function for often used packet creatio options functions -func CreateTestPacket(pid uint16, cc uint8, pusi, hasPay bool) Packet { - var pkt Packet +func CreateTestPacket(pid uint16, cc uint8, pusi, hasPay bool) *Packet { + var pkt *Packet if hasPay && pusi { pkt, _ = SetCC( Create(pid, @@ -116,13 +116,13 @@ func CreateTestPacket(pid uint16, cc uint8, pusi, hasPay bool) Packet { } // CreateDCPacket creates a new packet with a discontinuous adapataion field and the given PID and CC -func CreateDCPacket(pid uint16, cc uint8) Packet { +func CreateDCPacket(pid uint16, cc uint8) *Packet { pkt, _ := SetCC(Create(pid, WithDiscontinuousAF, WithHasPayloadFlag), cc) return pkt } // CreatePacketWithPayload creates a new packet with the given PID, CC and payload -func CreatePacketWithPayload(pid uint16, cc uint8, pay []byte) Packet { +func CreatePacketWithPayload(pid uint16, cc uint8, pay []byte) *Packet { pkt, _ := SetCC( Create( pid, @@ -137,38 +137,38 @@ func CreatePacketWithPayload(pid uint16, cc uint8, pay []byte) Packet { return pkt } func setPid(pkt *Packet, pid uint16) { - (*pkt)[1] = byte(pid >> 8 & 0x1f) - (*pkt)[2] = byte(pid & 0xff) + pkt[1] = byte(pid >> 8 & 0x1f) + pkt[2] = byte(pid & 0xff) } // WithHasPayloadFlag is an option function for creating a packet with a payload flag func WithHasPayloadFlag(pkt *Packet) { - (*pkt)[3] = byte((*pkt)[3] | 0x10) + pkt[3] |= 0x10 } // WithHasAdaptationFieldFlag is an option function for creating a packet with an adaptation field func WithHasAdaptationFieldFlag(pkt *Packet) { - (*pkt)[3] = (*pkt)[3] | 0x20 + pkt[3] |= 0x02 } // WithAFPrivateDataFlag is an option function for creating a packet with a adaptation field private data flag func WithAFPrivateDataFlag(pkt *Packet) { - (*pkt)[5] = (*pkt)[5] | 0x02 + pkt[5] |= 0x02 } // WithPUSI is an option function for creating a packet with the payload unit start indicator flag set func WithPUSI(pkt *Packet) { - (*pkt)[1] = byte((*pkt)[1] | 0x40) + pkt[1] |= 0x40 } // WithContinuousAF is an option function for creating a packet with a continuous adaptation field func WithContinuousAF(pkt *Packet) { - (*pkt)[5] = byte((*pkt)[5] & 0x7f) + pkt[5] |= 0x7f } // WithDisconinuousAF is an option function for creating a packet with a discontinuous adaptation field func WithDiscontinuousAF(pkt *Packet) { - (*pkt)[5] = byte((*pkt)[5] | 0x80) + pkt[5] |= 0x80 } // WithPES is an option function for creating a packet with a PES header @@ -197,20 +197,20 @@ func WithPES(pkt *Packet, pts uint64) { // InsertPTS insterts a given pts time into a byte slice func InsertPTS(b []byte, pts uint64) { - b[0] = byte(pts >> 29 & 0x07) - b[1] = byte(pts >> 22 & 0x4f) - b[2] = byte(pts >> 14 & 0xff) - b[3] = byte(pts >> 7 & 0x7f) b[4] = byte(pts&0xff) << 1 + b[3] = byte(pts >> 7 & 0x7f) + b[2] = byte(pts >> 14 & 0xff) + b[1] = byte(pts >> 22 & 0x4f) + b[0] = byte(pts >> 29 & 0x07) } // SetPayload sets the payload of a given packet func SetPayload(pkt *Packet, pay []byte) int { - start := payloadStart((*pkt)) + start := payloadStart(pkt) i := start j := 0 for i < PacketSize && j < len(pay) { - (*pkt)[i] = pay[j] + pkt[i] = pay[j] i++ j++ } @@ -219,5 +219,5 @@ func SetPayload(pkt *Packet, pay []byte) int { // Required parts of a packet func setSyncByte(pkt *Packet) { - (*pkt)[0] = SyncByte + pkt[0] = SyncByte } diff --git a/packet/doc.go b/packet/doc.go index 3138e7a..f9252fd 100644 --- a/packet/doc.go +++ b/packet/doc.go @@ -35,18 +35,18 @@ const ( ) // Packet is the basic unit in a transport stream. -type Packet []byte +type Packet [PacketSize]byte // Accumulator is used to gather multiple packets // and return their concatenated payloads. // Accumulator is not thread safe. type Accumulator interface { // Add adds a packet to the accumulator and returns true if done. - Add(Packet) (bool, error) + Add([]byte) (bool, error) // Parse returns the concatenated payloads of all the packets that have been added to the accumulator Parse() ([]byte, error) // Packets returns the accumulated packets - Packets() []Packet + Packets() []*Packet // Reset clears all packets in the accumulator Reset() } diff --git a/packet/packet.go b/packet/packet.go index e3b8f14..32c910c 100644 --- a/packet/packet.go +++ b/packet/packet.go @@ -24,108 +24,70 @@ SOFTWARE. package packet -import ( - "bytes" - - "github.com/Comcast/gots" -) - -var emptyByteSlice []byte +import "github.com/Comcast/gots" // PayloadUnitStartIndicator (PUSI) is a flag that indicates the start of PES data // or PSI (Program-Specific Information) such as AT, CAT, PMT or NIT. The PUSI // flag is contained in the second bit of the second byte of the Packet. -func PayloadUnitStartIndicator(packet Packet) (bool, error) { - if badLen(packet) { - return false, gots.ErrInvalidPacketLength - } +func PayloadUnitStartIndicator(packet *Packet) (bool, error) { return payloadUnitStartIndicator(packet), nil } -func payloadUnitStartIndicator(packet Packet) bool { +func payloadUnitStartIndicator(packet *Packet) bool { return packet[1]&0x040 != 0 } // PID is the Packet Identifier. Each table or elementary stream in the // transport stream is identified by a PID. The PID is contained in the 13 // bits that span the last 5 bits of second byte and all bits in the byte. -func Pid(packet Packet) (uint16, error) { - if badLen(packet) { - return 0, gots.ErrInvalidPacketLength - } +func Pid(packet *Packet) (uint16, error) { return pid(packet), nil } -func pid(packet Packet) uint16 { +func pid(packet *Packet) uint16 { return uint16(packet[1]&0x1f)<<8 | uint16(packet[2]) } // ContainsPayload is a flag that indicates the packet has a payload. The flag is // contained in the 3rd bit of the 4th byte of the Packet. -func ContainsPayload(packet Packet) (bool, error) { - if badLen(packet) { - return false, gots.ErrInvalidPacketLength - } +func ContainsPayload(packet *Packet) (bool, error) { return containsPayload(packet), nil } -func containsPayload(packet Packet) bool { +func containsPayload(packet *Packet) bool { return packet[3]&0x10 != 0 } // ContainsAdaptationField is a flag that indicates the packet has an adaptation field. -func ContainsAdaptationField(packet Packet) (bool, error) { - if badLen(packet) { - return false, gots.ErrInvalidPacketLength - } +func ContainsAdaptationField(packet *Packet) (bool, error) { return hasAdaptField(packet), nil } -func hasAdaptField(packet Packet) bool { +func hasAdaptField(packet *Packet) bool { return packet[3]&0x20 != 0 } // ContinuityCounter is a 4-bit sequence number of payload packets. Incremented // only when a payload is present (see ContainsPayload() above). -func ContinuityCounter(packet Packet) (uint8, error) { - if badLen(packet) { - return 0, gots.ErrInvalidPacketLength - } +func ContinuityCounter(packet *Packet) (uint8, error) { return packet[3] & uint8(0x0f), nil } // IsNull returns true if the provided packet is a Null packet // (i.e., PID == 0x1fff (8191)). -func IsNull(packet Packet) (bool, error) { - if badLen(packet) { - return false, gots.ErrInvalidPacketLength - } - - if pid(packet) == NullPacketPid { - return true, nil - } - return false, nil +func IsNull(packet *Packet) (bool, error) { + return pid(packet) == NullPacketPid, nil } // IsPat returns true if the provided packet is a PAT -func IsPat(packet Packet) (bool, error) { - if badLen(packet) { - return false, gots.ErrInvalidPacketLength - } - - if pid(packet) == 0 { - return true, nil - } - return false, nil +func IsPat(packet *Packet) (bool, error) { + return pid(packet) == 0, nil } -// badLen returns true is the packet is of -// invalid length -func badLen(packet Packet) bool { - if len(packet) != PacketSize { - return true - } - return false +// badLen returns true if the packet is not of +// valid length +func badLen(packet []byte) bool { + return len(packet) != PacketSize } // Returns the index of the first byte of Payload data in packetBytes. -func payloadStart(packet Packet) int { +func payloadStart(packet *Packet) int { var dataOffset = int(4) // packet header bytes if hasAdaptField(packet) { afLength := int(packet[4]) @@ -137,12 +99,9 @@ func payloadStart(packet Packet) int { // Payload returns a slice containing the packet payload. If the packet // does not have a payload, an empty byte slice is returned -func Payload(packet Packet) ([]byte, error) { - if badLen(packet) { - return emptyByteSlice, gots.ErrInvalidPacketLength - } +func Payload(packet *Packet) ([]byte, error) { if !containsPayload(packet) { - return emptyByteSlice, gots.ErrNoPayload + return nil, gots.ErrNoPayload } start := payloadStart(packet) pay := packet[start:] @@ -151,31 +110,25 @@ func Payload(packet Packet) ([]byte, error) { // IncrementCC creates a new packet where the new packet has // a continuity counter that is increased by one -func IncrementCC(packet Packet) (Packet, error) { - if badLen(packet) { - return emptyByteSlice, gots.ErrInvalidPacketLength - } - newPacket := make([]byte, len(packet)) - copy(newPacket, packet) +func IncrementCC(packet *Packet) (*Packet, error) { + var newPacket Packet + copy(newPacket[:], packet[:]) ccByte := newPacket[3] newCC := increment4BitInt(ccByte) newCCByte := (ccByte & byte(0xf0)) | newCC newPacket[3] = newCCByte - return newPacket, nil + return &newPacket, nil } // ZeroCC creates a new packet where the new packet has // a continuity counter that zero -func ZeroCC(packet Packet) (Packet, error) { - if badLen(packet) { - return emptyByteSlice, gots.ErrInvalidPacketLength - } - newPacket := make([]byte, len(packet)) - copy(newPacket, packet) +func ZeroCC(packet *Packet) (*Packet, error) { + var newPacket Packet + copy(newPacket[:], packet[:]) ccByte := newPacket[3] - newCCByte := (ccByte & byte(0xf0)) + newCCByte := ccByte & byte(0xf0) newPacket[3] = newCCByte - return newPacket, nil + return &newPacket, nil } func increment4BitInt(cc uint8) uint8 { return (cc + 1) & 0x0f @@ -183,24 +136,18 @@ func increment4BitInt(cc uint8) uint8 { // SetCC creates a new packet where the new packet has // the continuity counter provided -func SetCC(packet Packet, newCC uint8) (Packet, error) { - if badLen(packet) { - return emptyByteSlice, gots.ErrInvalidPacketLength - } - newPacket := make([]byte, len(packet)) - copy(newPacket, packet) +func SetCC(packet *Packet, newCC uint8) (*Packet, error) { + var newPacket Packet + copy(newPacket[:], packet[:]) ccByte := newPacket[3] newCCByte := (ccByte & byte(0xf0)) | newCC newPacket[3] = newCCByte - return newPacket, nil + return &newPacket, nil } // Returns a byte slice containing the PES header if the Packet contains one, // otherwise returns an error -func PESHeader(packet Packet) ([]byte, error) { - if badLen(packet) { - return emptyByteSlice, gots.ErrInvalidPacketLength - } +func PESHeader(packet *Packet) ([]byte, error) { if containsPayload(packet) && payloadUnitStartIndicator(packet) { dataOffset := payloadStart(packet) // A PES Header has a Packet Start Code Prefix of 0x000001 @@ -212,19 +159,22 @@ func PESHeader(packet Packet) ([]byte, error) { return pay, nil } } - return emptyByteSlice, gots.ErrNoPayload + return nil, gots.ErrNoPayload } // Header Returns a slice containing the Packer Header. -func Header(packet Packet) ([]byte, error) { - if badLen(packet) { - return emptyByteSlice, gots.ErrInvalidPacketLength - } +func Header(packet *Packet) ([]byte, error) { start := payloadStart(packet) - return packet[0:start], nil + return packet[:start], nil } // Equal returns true if the bytes of the two packets are equal -func Equal(a, b Packet) bool { - return bytes.Equal(a, b) +func Equal(a, b *Packet) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + return *a == *b } diff --git a/packet/packet_test.go b/packet/packet_test.go index a7d26d6..97dd89b 100644 --- a/packet/packet_test.go +++ b/packet/packet_test.go @@ -28,10 +28,24 @@ import ( "bytes" "encoding/hex" "testing" + + "github.com/Comcast/gots" ) +func parseHexString(h string) *Packet { + b, err := hex.DecodeString(h) + if err != nil { + panic("bad test: " + h) + } + pkt := new(Packet) + if copy(pkt[:], b) != PacketSize { + panic("bad test (wrong length): " + h) + } + return pkt +} + func TestPayloadUnitStartIndicatorTrue(t *testing.T) { - packet, _ := hex.DecodeString( + packet := parseHexString( "474000130000b00d0001c700000001e0642273423bffffffffffffffffffffff" + "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff" + "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff" + @@ -44,7 +58,7 @@ func TestPayloadUnitStartIndicatorTrue(t *testing.T) { } } func TestPayloadUnitStartIndicatorFalse(t *testing.T) { - packet, _ := hex.DecodeString( + packet := parseHexString( "4700673b7000ffffffffffffffffffffffffffffffffffffffffffffffffffff" + "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff" + "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff" + @@ -58,7 +72,7 @@ func TestPayloadUnitStartIndicatorFalse(t *testing.T) { } func TestPid(t *testing.T) { - packet, _ := hex.DecodeString( + packet := parseHexString( "47406618000001c000f280800523fae5b8a3fff94c801d4010210994fd959f4b" + "6108806a912e4b972d025c92429595817016dca64a18e7fc5c271bb40a0f9150" + "3c0057776bdd66c0e9ab2ba7614de80ee468cc6e860846241710cfda6dabc569" + @@ -72,7 +86,7 @@ func TestPid(t *testing.T) { } func TestPidGreaterThan255(t *testing.T) { - packet, _ := hex.DecodeString( + packet := parseHexString( "4701221B000001c000f280800523fae5b8a3fff94c801d4010210994fd959f4b" + "6108806a912e4b972d025c92429595817016dca64a18e7fc5c271bb40a0f9150" + "3c0057776bdd66c0e9ab2ba7614de80ee468cc6e860846241710cfda6dabc569" + @@ -86,7 +100,7 @@ func TestPidGreaterThan255(t *testing.T) { } func TestContainsPayloadTrue(t *testing.T) { - packet, _ := hex.DecodeString( + packet := parseHexString( "47406618000001c000f280800523fae5b8a3fff94c801d4010210994fd959f4b" + "6108806a912e4b972d025c92429595817016dca64a18e7fc5c271bb40a0f9150" + "3c0057776bdd66c0e9ab2ba7614de80ee468cc6e860846241710cfda6dabc569" + @@ -100,7 +114,7 @@ func TestContainsPayloadTrue(t *testing.T) { } func TestContainsPayloadFalse(t *testing.T) { - packet, _ := hex.DecodeString( + packet := parseHexString( "47006523b7103f5c99597ef7ffffffffffffffffffffffffffffffffffffffff" + "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff" + "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff" + @@ -114,7 +128,7 @@ func TestContainsPayloadFalse(t *testing.T) { } func TestContinuityCounter(t *testing.T) { - packet, _ := hex.DecodeString( + packet := parseHexString( "47006518dc0eff960f094176e794721d00cfedc13c1b039abf71e0f16bfeef88" + "de1d1901a576793da53551cfc53363e00be1417c08383ce8bc51efda4c4a465c" + "9aee27f76997169968829cf3343253c16243f7c21602cb2161767fda0485d4de" + @@ -128,7 +142,7 @@ func TestContinuityCounter(t *testing.T) { } func TestZeroLenthAdaptationField(t *testing.T) { - packet, _ := hex.DecodeString( + packet := parseHexString( "4701e1320034fcabf65d866a87eca0195db5ce1dcb6e0f75ba45a351722714db" + "a013cea9665e9e1866b13431429454a37cb5663ea00353624c5d1f84c9463651" + "634497dd837080b99ddf4bb26242f18d22ecd74dde47cd84041e5df3f0c57c40" + @@ -144,7 +158,7 @@ func TestZeroLenthAdaptationField(t *testing.T) { } func TestPayloadWhenPacketHasNoAdaptationField(t *testing.T) { - packet, _ := hex.DecodeString( + packet := parseHexString( "47006518dc0eff960f094176e794721d00cfedc13c1b039abf71e0f16bfeef88" + "de1d1901a576793da53551cfc53363e00be1417c08383ce8bc51efda4c4a465c" + "9aee27f76997169968829cf3343253c16243f7c21602cb2161767fda0485d4de" + @@ -166,7 +180,7 @@ func TestPayloadWhenPacketHasNoAdaptationField(t *testing.T) { } func TestPayloadWhenPacketHasAdaptationField(t *testing.T) { - packet, _ := hex.DecodeString( + packet := parseHexString( "4740653214723f5d09c67ec90ca90ad800d6ae02c11e66772d000001e0000084" + "c00a33faf9760713faf900b900000001091000000001274d401f9a6281004b60" + "2d1000003e90000ea60e8601d400057e4bbcb8280000000128ee388000000001" + @@ -187,7 +201,7 @@ func TestPayloadWhenPacketHasAdaptationField(t *testing.T) { } func TestIncrementCC(t *testing.T) { - packet, _ := hex.DecodeString( + packet := parseHexString( "4700673b7000ffffffffffffffffffffffffffffffffffffffffffffffffffff" + "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff" + "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff" + @@ -208,9 +222,12 @@ func TestIncrementCC(t *testing.T) { func TestBadLength(t *testing.T) { packet, _ := hex.DecodeString("4740653214723f5d09c67ec90ca90ad800d6ae02c11e66772d000001e0000084") - _, err := Header(packet) - - if err == nil { + acc := NewAccumulator(nil) + ok, err := acc.Add(packet) + if ok { + t.Errorf("BadLength, expected failure from new packet") + } + if err != gots.ErrInvalidPacketLength { t.Errorf("BadLength, expected error from new packet") } } @@ -231,7 +248,7 @@ func TestIncrementCCFunc(t *testing.T) { } func TestContainsAdaptationField(t *testing.T) { - packet, _ := hex.DecodeString( + packet := parseHexString( "4700663a7700ffffffffffffffffffffffffffffffffffffffffffffffffffff" + "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff" + "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff" + @@ -244,7 +261,7 @@ func TestContainsAdaptationField(t *testing.T) { } func TestEqualsNilPacket(t *testing.T) { - packet, _ := hex.DecodeString( + packet := parseHexString( "4740653214723f5d09c67ec90ca90ad800d6ae02c11e66772d000001e0000084" + "c00a33faf9760713faf900b900000001091000000001274d401f9a6281004b60" + "2d1000003e90000ea60e8601d400057e4bbcb8280000000128ee388000000001" + @@ -257,21 +274,21 @@ func TestEqualsNilPacket(t *testing.T) { } func TestEqualsIdenticalPackets(t *testing.T) { - packet, _ := hex.DecodeString( + packet := parseHexString( "4740653214723f5d09c67ec90ca90ad800d6ae02c11e66772d000001e0000084" + "c00a33faf9760713faf900b900000001091000000001274d401f9a6281004b60" + "2d1000003e90000ea60e8601d400057e4bbcb8280000000128ee388000000001" + "060007818a378085f8c00104007820100601c40411b500314741393403c2fffd" + "2980fc8080ff800000000125b80100017fb2c69de69e51f57c4a1b8623115f78" + "053598e7f47c066bf03c90c6233c0405369fd5f8e20957e40437f784") - same := packet[:] - if !Equal(packet, same) { - t.Errorf("Identical packets are different p1%v p2%v", packet, same) + same := *packet + if !Equal(packet, &same) { + t.Errorf("Identical packets are different p1%v p2%v", packet, &same) } } func TestEqualsHeadersNotEqual(t *testing.T) { - packet1, _ := hex.DecodeString( + packet1 := parseHexString( "4740653214723f5d09c67ec90ca90ad800d6ae02c11e66772d000001e0000084" + "c00a33faf9760713faf900b900000001091000000001274d401f9a6281004b60" + "2d1000003e90000ea60e8601d400057e4bbcb8280000000128ee388000000001" + @@ -280,7 +297,7 @@ func TestEqualsHeadersNotEqual(t *testing.T) { "053598e7f47c066bf03c90c6233c0405369fd5f8e20957e40437f784") // Same as above, but with the MPEG-TS headers TEI bit flipped. - packet2, _ := hex.DecodeString( + packet2 := parseHexString( "4780653214723f5d09c67ec90ca90ad800d6ae02c11e66772d000001e0000084" + "c00a33faf9760713faf900b900000001091000000001274d401f9a6281004b60" + "2d1000003e90000ea60e8601d400057e4bbcb8280000000128ee388000000001" + @@ -294,7 +311,7 @@ func TestEqualsHeadersNotEqual(t *testing.T) { } func TestNullPacketIsNull(t *testing.T) { - p, _ := hex.DecodeString( + p := parseHexString( "471fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff" + "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff" + "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff" + @@ -309,7 +326,7 @@ func TestNullPacketIsNull(t *testing.T) { } func TestNonNullPacketIsNotNull(t *testing.T) { - packet1, _ := hex.DecodeString( + packet1 := parseHexString( "4740653214723f5d09c67ec90ca90ad800d6ae02c11e66772d000001e0000084" + "c00a33faf9760713faf900b900000001091000000001274d401f9a6281004b60" + "2d1000003e90000ea60e8601d400057e4bbcb8280000000128ee388000000001" + @@ -323,7 +340,7 @@ func TestNonNullPacketIsNotNull(t *testing.T) { } func TestIsPat(t *testing.T) { - pat, _ := hex.DecodeString( + pat := parseHexString( "4740001f0000b00d0031e100000001e064bfcd282fffffffffffffffffffffff" + "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff" + "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff" + @@ -335,7 +352,7 @@ func TestIsPat(t *testing.T) { t.Error("PAT packet should be counted as a PAT") } - notPat, _ := hex.DecodeString( + notPat := parseHexString( "4740653214723f5d09c67ec90ca90ad800d6ae02c11e66772d000001e0000084" + "c00a33faf9760713faf900b900000001091000000001274d401f9a6281004b60" + "2d1000003e90000ea60e8601d400057e4bbcb8280000000128ee388000000001" + diff --git a/pes/pesheader_test.go b/pes/pesheader_test.go index 8a7f326..dac6dbd 100644 --- a/pes/pesheader_test.go +++ b/pes/pesheader_test.go @@ -30,9 +30,20 @@ import ( "github.com/Comcast/gots/packet" ) -func TestPESHeader(t *testing.T) { +func parseHexString(h string) *packet.Packet { + b, err := hex.DecodeString(h) + if err != nil { + panic("bad test: " + h) + } + pkt := new(packet.Packet) + if copy(pkt[:], b) != packet.PacketSize { + panic("bad test (wrong length): " + h) + } + return pkt +} - pkt, _ := hex.DecodeString( +func TestPESHeader(t *testing.T) { + pkt := parseHexString( "4740661a000001c006ff80800521dee9ca57fff94c801d2000210995341d9d43" + "61089848180b0884626048901425ddc09249220129d2fce728111c987e67ecb7" + "4284af5099181d8cd095b841b0c7539ad6c06260536e137615560052369fc984" + @@ -61,8 +72,7 @@ func TestPESHeader(t *testing.T) { } func TestPESHeader2(t *testing.T) { - - pkt, _ := hex.DecodeString( + pkt := parseHexString( "4740651C000001E0000084C00A39EFF33A7519EFF30B89000000010950000000" + "01060104001A20100411B500314741393403C2FFFD8080FC942FFF8000000001" + "21A81C29145C6FEB86EB239E2EE231302CF5163D32D183B7822FE37E7FB84549" + @@ -92,11 +102,10 @@ func TestPESHeader2(t *testing.T) { } func TestNewPESHeaderMissingBytes(t *testing.T) { - // Actual data from Cisco Transcoder (AMC channel). Below packet was causing // index out of bounds exception. It has the PES prefix code but we were not // checking to see if it's a PUSI to begin with - pkt, _ := hex.DecodeString( + pkt := parseHexString( "47006531b300ffffffffffffffffffffffffffffffffffffffffffffffffffff" + "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff" + "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff" + @@ -110,8 +119,7 @@ func TestNewPESHeaderMissingBytes(t *testing.T) { } func TestPESHeaderTS(t *testing.T) { - - pkt, _ := hex.DecodeString( + pkt := parseHexString( "4752a31c000001e0000080c00a210005bf21210005a7ab000001000697fffb80" + "000001b5844ffb9400000001b24741393403d4fffc8080fd8fdffa0000fa0000" + "fa0000fa0000fa0000fa0000fa0000fa0000fa0000fa0000fa0000fa0000fa00" + diff --git a/psi/pat.go b/psi/pat.go index 48d79bb..e0e32bb 100644 --- a/psi/pat.go +++ b/psi/pat.go @@ -53,8 +53,10 @@ func NewPAT(patBytes []byte) (PAT, error) { } if len(patBytes) == 188 { + var pkt packet.Packet + copy(pkt[:], patBytes) var err error - patBytes, err = packet.Payload(patBytes) + patBytes, err = packet.Payload(&pkt) if err != nil { return nil, err } @@ -112,21 +114,21 @@ func (pat pat) SPTSpmtPID() (uint16, error) { // It returns a new PAT object parsed from the packet, if found, and otherwise // returns an error. func ReadPAT(r io.Reader) (PAT, error) { - pkt := make(packet.Packet, packet.PacketSize) + var pkt packet.Packet var pat PAT for pat == nil { - if _, err := io.ReadFull(r, pkt); err != nil { + if _, err := io.ReadFull(r, pkt[:]); err != nil { if err == io.EOF || err == io.ErrUnexpectedEOF { break } return nil, err } - isPat, err := packet.IsPat(pkt) + isPat, err := packet.IsPat(&pkt) if err != nil { return nil, err } if isPat { - pay, err := packet.Payload(pkt) + pay, err := packet.Payload(&pkt) if err != nil { return nil, err } diff --git a/psi/pmt.go b/psi/pmt.go index f0187b3..2bec79b 100644 --- a/psi/pmt.go +++ b/psi/pmt.go @@ -192,7 +192,7 @@ func (p *pmt) String() string { } // FilterPMTPacketsToPids filters the PMT contents of the provided packet to the PIDs provided and returns a new packet. For example: if the provided PMT has PIDs 101, 102, and 103 and the provides PIDs are 101 and 102, the new PMT will have only descriptors for PID 101 and 102. The descriptor for PID 103 will be stripped from the new PMT packet. -func FilterPMTPacketsToPids(packets []packet.Packet, pids []uint16) []packet.Packet { +func FilterPMTPacketsToPids(packets []*packet.Packet, pids []uint16) []*packet.Packet { // make sure we have packets if len(packets) == 0 { return nil @@ -255,7 +255,7 @@ func FilterPMTPacketsToPids(packets []packet.Packet, pids []uint16) []packet.Pac // Recalculate the CRC fPMT = append(fPMT, gots.ComputeCRC(fPMT[pointerField:])...) - var filteredPMTPackets []packet.Packet + var filteredPMTPackets []*packet.Packet for _, pkt := range packets { var pktBuf bytes.Buffer header, err := packet.Header(pkt) @@ -276,8 +276,7 @@ func FilterPMTPacketsToPids(packets []packet.Packet, pids []uint16) []packet.Pac // all done break } - padPacket(&pktBuf) - filteredPMTPackets = append(filteredPMTPackets, pktBuf.Bytes()) + filteredPMTPackets = append(filteredPMTPackets, padPacket(&pktBuf)) } return filteredPMTPackets } @@ -286,7 +285,7 @@ func FilterPMTPacketsToPids(packets []packet.Packet, pids []uint16) []packet.Pac // defined by the PAT provided. Returns ErrNilPAT if pat // is nil, or any error encountered in parsing the PID // of pkt. -func IsPMT(pkt packet.Packet, pat PAT) (bool, error) { +func IsPMT(pkt *packet.Packet, pat PAT) (bool, error) { if pat == nil { return false, gots.ErrNilPAT } @@ -314,10 +313,12 @@ func safeSlice(byteArray []byte, start, end int) []byte { return byteArray[start:len(byteArray)] } -func padPacket(pkt *bytes.Buffer) { - for i := pkt.Len(); i < packet.PacketSize; i++ { - pkt.WriteByte(255) +func padPacket(buf *bytes.Buffer) *packet.Packet { + var pkt packet.Packet + for i := copy(pkt[:], buf.Bytes()); i < packet.PacketSize; i++ { + pkt[i] = 0xff } + return &pkt } func pidIn(pids []uint16, target uint16) bool { @@ -335,25 +336,25 @@ func pidIn(pids []uint16, target uint16) bool { // It returns a new PMT object parsed from the packet(s), if found, and // otherwise returns an error. func ReadPMT(r io.Reader, pid uint16) (PMT, error) { - pkt := make(packet.Packet, packet.PacketSize) + var pkt packet.Packet pmtAcc := packet.NewAccumulator(PmtAccumulatorDoneFunc) done := false var pmt PMT for !done { - if _, err := io.ReadFull(r, pkt); err != nil { + if _, err := io.ReadFull(r, pkt[:]); err != nil { if err == io.EOF || err == io.ErrUnexpectedEOF { return nil, gots.ErrPMTNotFound } return nil, err } - currPid, err := packet.Pid(pkt) + currPid, err := packet.Pid(&pkt) if err != nil { return nil, err } if currPid != pid { continue } - done, err = pmtAcc.Add(pkt) + done, err = pmtAcc.Add(pkt[:]) if err != nil { return nil, err } diff --git a/psi/pmt_test.go b/psi/pmt_test.go index e986de6..93a0d5a 100644 --- a/psi/pmt_test.go +++ b/psi/pmt_test.go @@ -33,6 +33,18 @@ import ( "github.com/Comcast/gots/packet" ) +func parseHexString(h string) *packet.Packet { + b, err := hex.DecodeString(h) + if err != nil { + panic("bad test: " + h) + } + pkt := new(packet.Packet) + if copy(pkt[:], b) != packet.PacketSize { + panic("bad test (wrong length): " + h) + } + return pkt +} + type testPmtElementaryStream struct { elementaryPid uint16 streamType uint8 @@ -416,7 +428,7 @@ func TestStringFormat(t *testing.T) { } func TestFilterPMTPacketsToPids_SinglePacketPMT(t *testing.T) { - bytes := []byte{ + bytes := packet.Packet{ 0x47, 0x40, 0x64, 0x10, 0x00, 0x02, 0xb0, 0x2d, 0x00, 0x01, 0xcb, 0x00, 0x00, 0xe0, 0x65, 0xf0, 0x06, 0x05, 0x04, 0x43, 0x55, 0x45, 0x49, 0x1b, 0xe0, 0x65, 0xf0, 0x05, 0x0e, 0x03, 0x00, 0x04, 0xb0, 0x0f, 0xe0, 0x66, @@ -435,7 +447,7 @@ func TestFilterPMTPacketsToPids_SinglePacketPMT(t *testing.T) { 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff} acc := packet.NewAccumulator(PmtAccumulatorDoneFunc) - acc.Add(bytes) + acc.Add(bytes[:]) payload, err := acc.Parse() if err != nil { t.Error(err) @@ -448,11 +460,11 @@ func TestFilterPMTPacketsToPids_SinglePacketPMT(t *testing.T) { pids := unfilteredPmt.Pids() pids = pids[:len(pids)-1] - filteredPmtPackets := FilterPMTPacketsToPids([]packet.Packet{bytes}, pids) + filteredPmtPackets := FilterPMTPacketsToPids([]*packet.Packet{&bytes}, pids) acc = packet.NewAccumulator(PmtAccumulatorDoneFunc) for _, p := range filteredPmtPackets { - acc.Add(p) + acc.Add(p[:]) } payload, err = acc.Parse() filteredPmt, err := NewPMT(payload) @@ -467,13 +479,13 @@ func TestFilterPMTPacketsToPids_SinglePacketPMT(t *testing.T) { } func TestFilterPMTPacketsToPids_MultiPacketPMT(t *testing.T) { - firstPacketBytes, _ := hex.DecodeString("474064100002b0ba0001c10000e065f00b0504435545490e03c03dd01be065f016970028046400283fe907108302808502800e03c0392087e066f0219700050445414333cc03c0c2100a04656e6700e907108302808502800e03c000f087e067f0219700050445414333cc03c0c4100a0473706100e907108302808502800e03c001e00fe068f01697000a04656e6700e907108302808502800e03c000f00fe069f01697000a0473706100e907108302808502800e03c000f086e0dc") + firstPacketBytes := parseHexString("474064100002b0ba0001c10000e065f00b0504435545490e03c03dd01be065f016970028046400283fe907108302808502800e03c0392087e066f0219700050445414333cc03c0c2100a04656e6700e907108302808502800e03c000f087e067f0219700050445414333cc03c0c4100a0473706100e907108302808502800e03c001e00fe068f01697000a04656e6700e907108302808502800e03c000f00fe069f01697000a0473706100e907108302808502800e03c000f086e0dc") - secondPacketBytes, _ := hex.DecodeString("47006411f0002b59bc22ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff") + secondPacketBytes := parseHexString("47006411f0002b59bc22ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff") acc := packet.NewAccumulator(PmtAccumulatorDoneFunc) - acc.Add(firstPacketBytes) - acc.Add(secondPacketBytes) + acc.Add(firstPacketBytes[:]) + acc.Add(secondPacketBytes[:]) payload, err := acc.Parse() if err != nil { t.Error(err) @@ -482,10 +494,10 @@ func TestFilterPMTPacketsToPids_MultiPacketPMT(t *testing.T) { wantedPids := []uint16{101, 102, 103, 104, 105, 220} filteredPids := wantedPids[:len(wantedPids)-1] - filteredPMTPackets := FilterPMTPacketsToPids([]packet.Packet{firstPacketBytes, secondPacketBytes}, filteredPids) + filteredPMTPackets := FilterPMTPacketsToPids([]*packet.Packet{firstPacketBytes, secondPacketBytes}, filteredPids) acc = packet.NewAccumulator(PmtAccumulatorDoneFunc) for _, p := range filteredPMTPackets { - acc.Add(p) + acc.Add(p[:]) } wantedPids = []uint16{101, 102, 103, 104, 105} @@ -575,7 +587,7 @@ func TestPMTIsIFrameStreamNegative(t *testing.T) { } func TestIsPMT(t *testing.T) { - patPkt, _ := hex.DecodeString("4740003001000000b00d0001c100000001e1e02d507804ffffffffffffffffff" + + patPkt := parseHexString("4740003001000000b00d0001c100000001e1e02d507804ffffffffffffffffff" + "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff" + "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff" + "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff" + @@ -589,14 +601,14 @@ func TestIsPMT(t *testing.T) { t.Error("Couldn't load the PAT") } - pmt, _ := hex.DecodeString("4741e03001000002b0480001c10000e1e1f0050e03c004751be1e1f016970028" + + pmt := parseHexString("4741e03001000002b0480001c10000e1e1f0050e03c004751be1e1f016970028" + "044d401f3fe907108302808502800e03c003350fe1e2f01697000a04656e6700" + "e907108302808502800e03c00104db121f57ffffffffffffffffffffffffffff" + "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff" + "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff" + "ffffffffffffffffffffffffffffffffffffffffffffffffffffffff") - notPMT, _ := hex.DecodeString("4741e13117f200014307ff050fdf0d45425030c8dae4dd8000000000000001e0" + + notPMT := parseHexString("4741e13117f200014307ff050fdf0d45425030c8dae4dd8000000000000001e0" + "000084d00d31000bab4111000b93cb80054700000001091000000001674d401f" + "ba202833f3e022000007d20001d4c1c040020f400041eb4d4601f18311200000" + "000168ebef20000000010600068232993c76c08000000001060447b500314741" + @@ -615,7 +627,7 @@ func TestIsPMT(t *testing.T) { func TestIsPMTErrorConditions(t *testing.T) { // Test nil PAT - pmt, _ := hex.DecodeString("4741e03001000002b0480001c10000e1e1f0050e03c004751be1e1f016970028" + + pmt := parseHexString("4741e03001000002b0480001c10000e1e1f0050e03c004751be1e1f016970028" + "044d401f3fe907108302808502800e03c003350fe1e2f01697000a04656e6700" + "e907108302808502800e03c00104db121f57ffffffffffffffffffffffffffff" + "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff" + @@ -631,14 +643,7 @@ func TestIsPMTErrorConditions(t *testing.T) { t.Error("Nil Pat should return nil pat error") } - badPMT, _ := hex.DecodeString("4741e03001000002b0480001c10000e1e1f0050e03c004751be1e1f016970028" + - "044d401f3fe907108302808502800e03c003350fe1e2f01697000a04656e6700" + - "e907108302808502800e03c00104db121f57ffffffffffffffffffffffffffff" + - "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff" + - "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff" + - "ffffffffffffffffffffffffffffffffffffffffffffffffffffff") - - patPkt, _ := hex.DecodeString("4740003001000000b00d0001c100000001e1e02d507804ffffffffffffffffff" + + patPkt := parseHexString("4740003001000000b00d0001c100000001e1e02d507804ffffffffffffffffff" + "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff" + "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff" + "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff" + @@ -651,16 +656,6 @@ func TestIsPMTErrorConditions(t *testing.T) { if pat == nil { t.Error("Couldn't load the PAT") } - - isPMTExpectFalse, errExpectBadLen := IsPMT(badPMT, pat) - - if isPMTExpectFalse == true { - t.Error("Bad PMT Length should return false") - } - - if errExpectBadLen == nil { - t.Error("Bad PMT Length should return an error, probably invalid packet length") - } } func TestReadPMTForSmoke(t *testing.T) { bs, _ := hex.DecodeString("474000100000b00d0001c100000001e256f803e71bfffffff" +