diff --git a/go.mod b/go.mod index 701e9b9..a1d47d0 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,3 @@ module github.com/donnie4w/gothrift -go 1.21.0 +go 1.22.4 diff --git a/thrift/binary_protocol_test.go b/thrift/binary_protocol_test.go index 88bfd26..67f9923 100644 --- a/thrift/binary_protocol_test.go +++ b/thrift/binary_protocol_test.go @@ -112,7 +112,7 @@ func generateSafeReadBytesBenchmark(askedSize int32, dataSize int) func(b *testi return func(b *testing.B) { data := make([]byte, dataSize) b.ResetTimer() - for i := 0; i < b.N; i++ { + for range b.N { safeReadBytes(askedSize, bytes.NewReader(data)) } } diff --git a/thrift/configuration.go b/thrift/configuration.go index de27edd..a9565d3 100644 --- a/thrift/configuration.go +++ b/thrift/configuration.go @@ -56,47 +56,47 @@ const ( // // For example, say you want to migrate this old code into using TConfiguration: // -// sccket, err := thrift.NewTSocketTimeout("host:port", time.Second, time.Second) -// transFactory := thrift.NewTFramedTransportFactoryMaxLength( -// thrift.NewTTransportFactory(), -// 1024 * 1024 * 256, -// ) -// protoFactory := thrift.NewTBinaryProtocolFactory(true, true) +// socket, err := thrift.NewTSocketTimeout("host:port", time.Second, time.Second) +// transFactory := thrift.NewTFramedTransportFactoryMaxLength( +// thrift.NewTTransportFactory(), +// 1024 * 1024 * 256, +// ) +// protoFactory := thrift.NewTBinaryProtocolFactory(true, true) // // This is the wrong way to do it because in the end the TConfiguration used by // socket and transFactory will be overwritten by the one used by protoFactory // because of TConfiguration propagation: // -// // bad example, DO NOT USE -// sccket := thrift.NewTSocketConf("host:port", &thrift.TConfiguration{ -// ConnectTimeout: time.Second, -// SocketTimeout: time.Second, -// }) -// transFactory := thrift.NewTFramedTransportFactoryConf( -// thrift.NewTTransportFactory(), -// &thrift.TConfiguration{ -// MaxFrameSize: 1024 * 1024 * 256, -// }, -// ) -// protoFactory := thrift.NewTBinaryProtocolFactoryConf(&thrift.TConfiguration{ -// TBinaryStrictRead: thrift.BoolPtr(true), -// TBinaryStrictWrite: thrift.BoolPtr(true), -// }) +// // bad example, DO NOT USE +// socket := thrift.NewTSocketConf("host:port", &thrift.TConfiguration{ +// ConnectTimeout: time.Second, +// SocketTimeout: time.Second, +// }) +// transFactory := thrift.NewTFramedTransportFactoryConf( +// thrift.NewTTransportFactory(), +// &thrift.TConfiguration{ +// MaxFrameSize: 1024 * 1024 * 256, +// }, +// ) +// protoFactory := thrift.NewTBinaryProtocolFactoryConf(&thrift.TConfiguration{ +// TBinaryStrictRead: thrift.BoolPtr(true), +// TBinaryStrictWrite: thrift.BoolPtr(true), +// }) // // This is the correct way to do it: // -// conf := &thrift.TConfiguration{ -// ConnectTimeout: time.Second, -// SocketTimeout: time.Second, +// conf := &thrift.TConfiguration{ +// ConnectTimeout: time.Second, +// SocketTimeout: time.Second, // -// MaxFrameSize: 1024 * 1024 * 256, +// MaxFrameSize: 1024 * 1024 * 256, // -// TBinaryStrictRead: thrift.BoolPtr(true), -// TBinaryStrictWrite: thrift.BoolPtr(true), -// } -// sccket := thrift.NewTSocketConf("host:port", conf) -// transFactory := thrift.NewTFramedTransportFactoryConf(thrift.NewTTransportFactory(), conf) -// protoFactory := thrift.NewTBinaryProtocolFactoryConf(conf) +// TBinaryStrictRead: thrift.BoolPtr(true), +// TBinaryStrictWrite: thrift.BoolPtr(true), +// } +// socket := thrift.NewTSocketConf("host:port", conf) +// transFactory := thrift.NewTFramedTransportFactoryConf(thrift.NewTTransportFactory(), conf) +// protoFactory := thrift.NewTBinaryProtocolFactoryConf(conf) // // [1]: https://github.com/apache/thrift/blob/master/doc/specs/thrift-tconfiguration.md type TConfiguration struct { @@ -132,6 +132,8 @@ type TConfiguration struct { // THeaderProtocolIDPtr and THeaderProtocolIDPtrMust helper functions // are provided to help filling this value. THeaderProtocolID *THeaderProtocolID + // The write transforms to be applied to THeaderTransport. + THeaderTransforms []THeaderTransformID // Used internally by deprecated constructors, to avoid overriding // underlying TTransport/TProtocol's cfg by accidental propagations. @@ -245,6 +247,18 @@ func (tc *TConfiguration) GetTHeaderProtocolID() THeaderProtocolID { return protoID } +// GetTHeaderTransforms returns the THeaderTransformIDs to be applied on +// THeaderTransport writing. +// +// It's nil-safe. If tc is nil, empty slice will be returned (meaning no +// transforms to be applied). +func (tc *TConfiguration) GetTHeaderTransforms() []THeaderTransformID { + if tc == nil { + return nil + } + return tc.THeaderTransforms +} + // THeaderProtocolIDPtr validates and returns the pointer to id. // // If id is not a valid THeaderProtocolID, a pointer to THeaderProtocolDefault diff --git a/thrift/debug_protocol.go b/thrift/debug_protocol.go index 72304f1..d730411 100644 --- a/thrift/debug_protocol.go +++ b/thrift/debug_protocol.go @@ -23,670 +23,670 @@ import ( "context" "log/slog" ) - - type TDebugProtocol struct { - // Required. The actual TProtocol to do the read/write. - Delegate TProtocol - - // Optional. The logger and prefix to log all the args/return values - // from Delegate TProtocol calls. - // - // If Logger is nil, StdLogger using stdlib log package with os.Stderr - // will be used. If disable logging is desired, set Logger to NopLogger - // explicitly instead of leaving it as nil/unset. - // - // Deprecated: TDebugProtocol always use slog at debug level now. - // This field will be removed in a future version. - Logger Logger - - LogPrefix string - - // Optional. An TProtocol to duplicate everything read/written from Delegate. - // - // A typical use case of this is to use TSimpleJSONProtocol wrapping - // TMemoryBuffer in a middleware to json logging requests/responses. - // - // This feature is not available from TDebugProtocolFactory. In order to - // use it you have to construct TDebugProtocol directly, or set DuplicateTo - // field after getting a TDebugProtocol from the factory. - // - // Deprecated: Please use TDuplicateToProtocol instead. - DuplicateTo TProtocol - } - - type TDebugProtocolFactory struct { - Underlying TProtocolFactory - LogPrefix string - Logger Logger - } - - // NewTDebugProtocolFactory creates a TDebugProtocolFactory. - // - // Deprecated: Please use NewTDebugProtocolFactoryWithLogger or the struct - // itself instead. This version will use the default logger from standard - // library. - func NewTDebugProtocolFactory(underlying TProtocolFactory, logPrefix string) *TDebugProtocolFactory { - return &TDebugProtocolFactory{ - Underlying: underlying, - LogPrefix: logPrefix, - Logger: StdLogger(nil), - } - } - - // NewTDebugProtocolFactoryWithLogger creates a TDebugProtocolFactory. - func NewTDebugProtocolFactoryWithLogger(underlying TProtocolFactory, logPrefix string, logger Logger) *TDebugProtocolFactory { - return &TDebugProtocolFactory{ - Underlying: underlying, - LogPrefix: logPrefix, - Logger: logger, - } - } - - func (t *TDebugProtocolFactory) GetProtocol(trans TTransport) TProtocol { - return &TDebugProtocol{ - Delegate: t.Underlying.GetProtocol(trans), - LogPrefix: t.LogPrefix, - Logger: fallbackLogger(t.Logger), - } - } - - func (tdp *TDebugProtocol) WriteMessageBegin(ctx context.Context, name string, typeId TMessageType, seqid int32) error { - err := tdp.Delegate.WriteMessageBegin(ctx, name, typeId, seqid) - slog.DebugContext( - ctx, - tdp.LogPrefix+"WriteMessageBegin", - "name", name, - "typeId", typeId, - "seqid", seqid, - "err", err, - ) - if tdp.DuplicateTo != nil { - tdp.DuplicateTo.WriteMessageBegin(ctx, name, typeId, seqid) - } - return err - } - func (tdp *TDebugProtocol) WriteMessageEnd(ctx context.Context) error { - err := tdp.Delegate.WriteMessageEnd(ctx) - slog.DebugContext( - ctx, - tdp.LogPrefix+"WriteMessageEnd", - "err", err, - ) - if tdp.DuplicateTo != nil { - tdp.DuplicateTo.WriteMessageEnd(ctx) - } - return err - } - func (tdp *TDebugProtocol) WriteStructBegin(ctx context.Context, name string) error { - err := tdp.Delegate.WriteStructBegin(ctx, name) - slog.DebugContext( - ctx, - tdp.LogPrefix+"WriteStructBegin", - "name", name, - "err", err, - ) - if tdp.DuplicateTo != nil { - tdp.DuplicateTo.WriteStructBegin(ctx, name) - } - return err - } - func (tdp *TDebugProtocol) WriteStructEnd(ctx context.Context) error { - err := tdp.Delegate.WriteStructEnd(ctx) - slog.DebugContext( - ctx, - tdp.LogPrefix+"WriteStructEnd", - "err", err, - ) - if tdp.DuplicateTo != nil { - tdp.DuplicateTo.WriteStructEnd(ctx) - } - return err - } - func (tdp *TDebugProtocol) WriteFieldBegin(ctx context.Context, name string, typeId TType, id int16) error { - err := tdp.Delegate.WriteFieldBegin(ctx, name, typeId, id) - slog.DebugContext( - ctx, - tdp.LogPrefix+"WriteFieldBegin", - "name", name, - "typeId", typeId, - "id", id, - "err", err, - ) - if tdp.DuplicateTo != nil { - tdp.DuplicateTo.WriteFieldBegin(ctx, name, typeId, id) - } - return err - } - func (tdp *TDebugProtocol) WriteFieldEnd(ctx context.Context) error { - err := tdp.Delegate.WriteFieldEnd(ctx) - slog.DebugContext( - ctx, - tdp.LogPrefix+"WriteFieldEnd", - "err", err, - ) - if tdp.DuplicateTo != nil { - tdp.DuplicateTo.WriteFieldEnd(ctx) - } - return err - } - func (tdp *TDebugProtocol) WriteFieldStop(ctx context.Context) error { - err := tdp.Delegate.WriteFieldStop(ctx) - slog.DebugContext( - ctx, - tdp.LogPrefix+"WriteFieldStop", - "err", err, - ) - if tdp.DuplicateTo != nil { - tdp.DuplicateTo.WriteFieldStop(ctx) - } - return err - } - func (tdp *TDebugProtocol) WriteMapBegin(ctx context.Context, keyType TType, valueType TType, size int) error { - err := tdp.Delegate.WriteMapBegin(ctx, keyType, valueType, size) - slog.DebugContext( - ctx, - tdp.LogPrefix+"WriteMapBegin", - "keyType", keyType, - "valueType", valueType, - "size", size, - "err", err, - ) - if tdp.DuplicateTo != nil { - tdp.DuplicateTo.WriteMapBegin(ctx, keyType, valueType, size) - } - return err - } - func (tdp *TDebugProtocol) WriteMapEnd(ctx context.Context) error { - err := tdp.Delegate.WriteMapEnd(ctx) - slog.DebugContext( - ctx, - tdp.LogPrefix+"WriteMapEnd", - "err", err, - ) - if tdp.DuplicateTo != nil { - tdp.DuplicateTo.WriteMapEnd(ctx) - } - return err - } - func (tdp *TDebugProtocol) WriteListBegin(ctx context.Context, elemType TType, size int) error { - err := tdp.Delegate.WriteListBegin(ctx, elemType, size) - slog.DebugContext( - ctx, - tdp.LogPrefix+"WriteListBegin", - "elemType", elemType, - "size", size, - "err", err, - ) - if tdp.DuplicateTo != nil { - tdp.DuplicateTo.WriteListBegin(ctx, elemType, size) - } - return err - } - func (tdp *TDebugProtocol) WriteListEnd(ctx context.Context) error { - err := tdp.Delegate.WriteListEnd(ctx) - slog.DebugContext( - ctx, - tdp.LogPrefix+"WriteListEnd", - "err", err, - ) - if tdp.DuplicateTo != nil { - tdp.DuplicateTo.WriteListEnd(ctx) - } - return err - } - func (tdp *TDebugProtocol) WriteSetBegin(ctx context.Context, elemType TType, size int) error { - err := tdp.Delegate.WriteSetBegin(ctx, elemType, size) - slog.DebugContext( - ctx, - tdp.LogPrefix+"WriteSetBegin", - "elemType", elemType, - "size", size, - "err", err, - ) - if tdp.DuplicateTo != nil { - tdp.DuplicateTo.WriteSetBegin(ctx, elemType, size) - } - return err - } - func (tdp *TDebugProtocol) WriteSetEnd(ctx context.Context) error { - err := tdp.Delegate.WriteSetEnd(ctx) - slog.DebugContext( - ctx, - tdp.LogPrefix+"WriteSetEnd", - "err", err, - ) - if tdp.DuplicateTo != nil { - tdp.DuplicateTo.WriteSetEnd(ctx) - } - return err - } - func (tdp *TDebugProtocol) WriteBool(ctx context.Context, value bool) error { - err := tdp.Delegate.WriteBool(ctx, value) - slog.DebugContext( - ctx, - tdp.LogPrefix+"WriteBool", - "value", value, - "err", err, - ) - if tdp.DuplicateTo != nil { - tdp.DuplicateTo.WriteBool(ctx, value) - } - return err - } - func (tdp *TDebugProtocol) WriteByte(ctx context.Context, value int8) error { - err := tdp.Delegate.WriteByte(ctx, value) - slog.DebugContext( - ctx, - tdp.LogPrefix+"WriteByte", - "value", value, - "err", err, - ) - if tdp.DuplicateTo != nil { - tdp.DuplicateTo.WriteByte(ctx, value) - } - return err - } - func (tdp *TDebugProtocol) WriteI16(ctx context.Context, value int16) error { - err := tdp.Delegate.WriteI16(ctx, value) - slog.DebugContext( - ctx, - tdp.LogPrefix+"WriteI16", - "value", value, - "err", err, - ) - if tdp.DuplicateTo != nil { - tdp.DuplicateTo.WriteI16(ctx, value) - } - return err - } - func (tdp *TDebugProtocol) WriteI32(ctx context.Context, value int32) error { - err := tdp.Delegate.WriteI32(ctx, value) - slog.DebugContext( - ctx, - tdp.LogPrefix+"WriteI32", - "value", value, - "err", err, - ) - if tdp.DuplicateTo != nil { - tdp.DuplicateTo.WriteI32(ctx, value) - } - return err - } - func (tdp *TDebugProtocol) WriteI64(ctx context.Context, value int64) error { - err := tdp.Delegate.WriteI64(ctx, value) - slog.DebugContext( - ctx, - tdp.LogPrefix+"WriteI64", - "value", value, - "err", err, - ) - if tdp.DuplicateTo != nil { - tdp.DuplicateTo.WriteI64(ctx, value) - } - return err - } - func (tdp *TDebugProtocol) WriteDouble(ctx context.Context, value float64) error { - err := tdp.Delegate.WriteDouble(ctx, value) - slog.DebugContext( - ctx, - tdp.LogPrefix+"WriteDouble", - "value", value, - "err", err, - ) - if tdp.DuplicateTo != nil { - tdp.DuplicateTo.WriteDouble(ctx, value) - } - return err - } - func (tdp *TDebugProtocol) WriteString(ctx context.Context, value string) error { - err := tdp.Delegate.WriteString(ctx, value) - slog.DebugContext( - ctx, - tdp.LogPrefix+"WriteString", - "value", value, - "err", err, - ) - if tdp.DuplicateTo != nil { - tdp.DuplicateTo.WriteString(ctx, value) - } - return err - } - func (tdp *TDebugProtocol) WriteBinary(ctx context.Context, value []byte) error { - err := tdp.Delegate.WriteBinary(ctx, value) - slog.DebugContext( - ctx, - tdp.LogPrefix+"WriteBinary", - "value", value, - "err", err, - ) - if tdp.DuplicateTo != nil { - tdp.DuplicateTo.WriteBinary(ctx, value) - } - return err - } - func (tdp *TDebugProtocol) WriteUUID(ctx context.Context, value Tuuid) error { - err := tdp.Delegate.WriteUUID(ctx, value) - slog.DebugContext( - ctx, - tdp.LogPrefix+"WriteUUID", - "value", value, - "err", err, - ) - if tdp.DuplicateTo != nil { - tdp.DuplicateTo.WriteUUID(ctx, value) - } - return err - } - - func (tdp *TDebugProtocol) ReadMessageBegin(ctx context.Context) (name string, typeId TMessageType, seqid int32, err error) { - name, typeId, seqid, err = tdp.Delegate.ReadMessageBegin(ctx) - slog.DebugContext( - ctx, - tdp.LogPrefix+"ReadMessageBegin", - "name", name, - "typeId", typeId, - "seqid", seqid, - "err", err, - ) - if tdp.DuplicateTo != nil { - tdp.DuplicateTo.WriteMessageBegin(ctx, name, typeId, seqid) - } - return - } - func (tdp *TDebugProtocol) ReadMessageEnd(ctx context.Context) (err error) { - err = tdp.Delegate.ReadMessageEnd(ctx) - slog.DebugContext( - ctx, - tdp.LogPrefix+"ReadMessageEnd", - "err", err, - ) - if tdp.DuplicateTo != nil { - tdp.DuplicateTo.WriteMessageEnd(ctx) - } - return - } - func (tdp *TDebugProtocol) ReadStructBegin(ctx context.Context) (name string, err error) { - name, err = tdp.Delegate.ReadStructBegin(ctx) - slog.DebugContext( - ctx, - tdp.LogPrefix+"ReadStructBegin", - "name", name, - "err", err, - ) - if tdp.DuplicateTo != nil { - tdp.DuplicateTo.WriteStructBegin(ctx, name) - } - return - } - func (tdp *TDebugProtocol) ReadStructEnd(ctx context.Context) (err error) { - err = tdp.Delegate.ReadStructEnd(ctx) - slog.DebugContext( - ctx, - tdp.LogPrefix+"ReadStructEnd", - "err", err, - ) - if tdp.DuplicateTo != nil { - tdp.DuplicateTo.WriteStructEnd(ctx) - } - return - } - func (tdp *TDebugProtocol) ReadFieldBegin(ctx context.Context) (name string, typeId TType, id int16, err error) { - name, typeId, id, err = tdp.Delegate.ReadFieldBegin(ctx) - slog.DebugContext( - ctx, - tdp.LogPrefix+"ReadFieldBegin", - "name", name, - "typeId", typeId, - "id", id, - "err", err, - ) - if tdp.DuplicateTo != nil { - tdp.DuplicateTo.WriteFieldBegin(ctx, name, typeId, id) - } - return - } - func (tdp *TDebugProtocol) ReadFieldEnd(ctx context.Context) (err error) { - err = tdp.Delegate.ReadFieldEnd(ctx) - slog.DebugContext( - ctx, - tdp.LogPrefix+"ReadFieldEnd", - "err", err, - ) - if tdp.DuplicateTo != nil { - tdp.DuplicateTo.WriteFieldEnd(ctx) - } - return - } - func (tdp *TDebugProtocol) ReadMapBegin(ctx context.Context) (keyType TType, valueType TType, size int, err error) { - keyType, valueType, size, err = tdp.Delegate.ReadMapBegin(ctx) - slog.DebugContext( - ctx, - tdp.LogPrefix+"ReadMapBegin", - "keyType", keyType, - "valueType", valueType, - "size", size, - "err", err, - ) - if tdp.DuplicateTo != nil { - tdp.DuplicateTo.WriteMapBegin(ctx, keyType, valueType, size) - } - return - } - func (tdp *TDebugProtocol) ReadMapEnd(ctx context.Context) (err error) { - err = tdp.Delegate.ReadMapEnd(ctx) - slog.DebugContext( - ctx, - tdp.LogPrefix+"ReadMapEnd", - "err", err, - ) - if tdp.DuplicateTo != nil { - tdp.DuplicateTo.WriteMapEnd(ctx) - } - return - } - func (tdp *TDebugProtocol) ReadListBegin(ctx context.Context) (elemType TType, size int, err error) { - elemType, size, err = tdp.Delegate.ReadListBegin(ctx) - slog.DebugContext( - ctx, - tdp.LogPrefix+"ReadListBegin", - "elemType", elemType, - "size", size, - "err", err, - ) - if tdp.DuplicateTo != nil { - tdp.DuplicateTo.WriteListBegin(ctx, elemType, size) - } - return - } - func (tdp *TDebugProtocol) ReadListEnd(ctx context.Context) (err error) { - err = tdp.Delegate.ReadListEnd(ctx) - slog.DebugContext( - ctx, - tdp.LogPrefix+"ReadListEnd", - "err", err, - ) - if tdp.DuplicateTo != nil { - tdp.DuplicateTo.WriteListEnd(ctx) - } - return - } - func (tdp *TDebugProtocol) ReadSetBegin(ctx context.Context) (elemType TType, size int, err error) { - elemType, size, err = tdp.Delegate.ReadSetBegin(ctx) - slog.DebugContext( - ctx, - tdp.LogPrefix+"ReadSetBegin", - "elemType", elemType, - "size", size, - "err", err, - ) - if tdp.DuplicateTo != nil { - tdp.DuplicateTo.WriteSetBegin(ctx, elemType, size) - } - return - } - func (tdp *TDebugProtocol) ReadSetEnd(ctx context.Context) (err error) { - err = tdp.Delegate.ReadSetEnd(ctx) - slog.DebugContext( - ctx, - tdp.LogPrefix+"ReadSetEnd", - "err", err, - ) - if tdp.DuplicateTo != nil { - tdp.DuplicateTo.WriteSetEnd(ctx) - } - return - } - func (tdp *TDebugProtocol) ReadBool(ctx context.Context) (value bool, err error) { - value, err = tdp.Delegate.ReadBool(ctx) - slog.DebugContext( - ctx, - tdp.LogPrefix+"ReadBool", - "value", value, - "err", err, - ) - if tdp.DuplicateTo != nil { - tdp.DuplicateTo.WriteBool(ctx, value) - } - return - } - func (tdp *TDebugProtocol) ReadByte(ctx context.Context) (value int8, err error) { - value, err = tdp.Delegate.ReadByte(ctx) - slog.DebugContext( - ctx, - tdp.LogPrefix+"ReadByte", - "value", value, - "err", err, - ) - if tdp.DuplicateTo != nil { - tdp.DuplicateTo.WriteByte(ctx, value) - } - return - } - func (tdp *TDebugProtocol) ReadI16(ctx context.Context) (value int16, err error) { - value, err = tdp.Delegate.ReadI16(ctx) - slog.DebugContext( - ctx, - tdp.LogPrefix+"ReadI16", - "value", value, - "err", err, - ) - if tdp.DuplicateTo != nil { - tdp.DuplicateTo.WriteI16(ctx, value) - } - return - } - func (tdp *TDebugProtocol) ReadI32(ctx context.Context) (value int32, err error) { - value, err = tdp.Delegate.ReadI32(ctx) - slog.DebugContext( - ctx, - tdp.LogPrefix+"ReadI32", - "value", value, - "err", err, - ) - if tdp.DuplicateTo != nil { - tdp.DuplicateTo.WriteI32(ctx, value) - } - return - } - func (tdp *TDebugProtocol) ReadI64(ctx context.Context) (value int64, err error) { - value, err = tdp.Delegate.ReadI64(ctx) - slog.DebugContext( - ctx, - tdp.LogPrefix+"ReadI64", - "value", value, - "err", err, - ) - if tdp.DuplicateTo != nil { - tdp.DuplicateTo.WriteI64(ctx, value) - } - return - } - func (tdp *TDebugProtocol) ReadDouble(ctx context.Context) (value float64, err error) { - value, err = tdp.Delegate.ReadDouble(ctx) - slog.DebugContext( - ctx, - tdp.LogPrefix+"ReadDouble", - "value", value, - "err", err, - ) - if tdp.DuplicateTo != nil { - tdp.DuplicateTo.WriteDouble(ctx, value) - } - return - } - func (tdp *TDebugProtocol) ReadString(ctx context.Context) (value string, err error) { - value, err = tdp.Delegate.ReadString(ctx) - slog.DebugContext( - ctx, - tdp.LogPrefix+"ReadString", - "value", value, - "err", err, - ) - if tdp.DuplicateTo != nil { - tdp.DuplicateTo.WriteString(ctx, value) - } - return - } - func (tdp *TDebugProtocol) ReadBinary(ctx context.Context) (value []byte, err error) { - value, err = tdp.Delegate.ReadBinary(ctx) - slog.DebugContext( - ctx, - tdp.LogPrefix+"ReadBinary", - "value", value, - "err", err, - ) - if tdp.DuplicateTo != nil { - tdp.DuplicateTo.WriteBinary(ctx, value) - } - return - } - func (tdp *TDebugProtocol) ReadUUID(ctx context.Context) (value Tuuid, err error) { - value, err = tdp.Delegate.ReadUUID(ctx) - slog.DebugContext( - ctx, - tdp.LogPrefix+"ReadUUID", - "value", value, - "err", err, - ) - if tdp.DuplicateTo != nil { - tdp.DuplicateTo.WriteUUID(ctx, value) - } - return - } - func (tdp *TDebugProtocol) Skip(ctx context.Context, fieldType TType) (err error) { - err = tdp.Delegate.Skip(ctx, fieldType) - slog.DebugContext( - ctx, - tdp.LogPrefix+"Skip", - "fieldType", fieldType, - "err", err, - ) - if tdp.DuplicateTo != nil { - tdp.DuplicateTo.Skip(ctx, fieldType) - } - return - } - func (tdp *TDebugProtocol) Flush(ctx context.Context) (err error) { - err = tdp.Delegate.Flush(ctx) - slog.DebugContext( - ctx, - tdp.LogPrefix+"Flush", - "err", err, - ) - if tdp.DuplicateTo != nil { - tdp.DuplicateTo.Flush(ctx) - } - return - } - - func (tdp *TDebugProtocol) Transport() TTransport { - return tdp.Delegate.Transport() - } - - // SetTConfiguration implements TConfigurationSetter for propagation. - func (tdp *TDebugProtocol) SetTConfiguration(conf *TConfiguration) { - PropagateTConfiguration(tdp.Delegate, conf) - PropagateTConfiguration(tdp.DuplicateTo, conf) - } - - var _ TConfigurationSetter = (*TDebugProtocol)(nil) \ No newline at end of file + +type TDebugProtocol struct { + // Required. The actual TProtocol to do the read/write. + Delegate TProtocol + + // Optional. The logger and prefix to log all the args/return values + // from Delegate TProtocol calls. + // + // If Logger is nil, StdLogger using stdlib log package with os.Stderr + // will be used. If disable logging is desired, set Logger to NopLogger + // explicitly instead of leaving it as nil/unset. + // + // Deprecated: TDebugProtocol always use slog at debug level now. + // This field will be removed in a future version. + Logger Logger + + LogPrefix string + + // Optional. An TProtocol to duplicate everything read/written from Delegate. + // + // A typical use case of this is to use TSimpleJSONProtocol wrapping + // TMemoryBuffer in a middleware to json logging requests/responses. + // + // This feature is not available from TDebugProtocolFactory. In order to + // use it you have to construct TDebugProtocol directly, or set DuplicateTo + // field after getting a TDebugProtocol from the factory. + // + // Deprecated: Please use TDuplicateToProtocol instead. + DuplicateTo TProtocol +} + +type TDebugProtocolFactory struct { + Underlying TProtocolFactory + LogPrefix string + Logger Logger +} + +// NewTDebugProtocolFactory creates a TDebugProtocolFactory. +// +// Deprecated: Please use NewTDebugProtocolFactoryWithLogger or the struct +// itself instead. This version will use the default logger from standard +// library. +func NewTDebugProtocolFactory(underlying TProtocolFactory, logPrefix string) *TDebugProtocolFactory { + return &TDebugProtocolFactory{ + Underlying: underlying, + LogPrefix: logPrefix, + Logger: StdLogger(nil), + } +} + +// NewTDebugProtocolFactoryWithLogger creates a TDebugProtocolFactory. +func NewTDebugProtocolFactoryWithLogger(underlying TProtocolFactory, logPrefix string, logger Logger) *TDebugProtocolFactory { + return &TDebugProtocolFactory{ + Underlying: underlying, + LogPrefix: logPrefix, + Logger: logger, + } +} + +func (t *TDebugProtocolFactory) GetProtocol(trans TTransport) TProtocol { + return &TDebugProtocol{ + Delegate: t.Underlying.GetProtocol(trans), + LogPrefix: t.LogPrefix, + Logger: fallbackLogger(t.Logger), + } +} + +func (tdp *TDebugProtocol) WriteMessageBegin(ctx context.Context, name string, typeId TMessageType, seqid int32) error { + err := tdp.Delegate.WriteMessageBegin(ctx, name, typeId, seqid) + slog.DebugContext( + ctx, + tdp.LogPrefix+"WriteMessageBegin", + "name", name, + "typeId", typeId, + "seqid", seqid, + "err", err, + ) + if tdp.DuplicateTo != nil { + tdp.DuplicateTo.WriteMessageBegin(ctx, name, typeId, seqid) + } + return err +} +func (tdp *TDebugProtocol) WriteMessageEnd(ctx context.Context) error { + err := tdp.Delegate.WriteMessageEnd(ctx) + slog.DebugContext( + ctx, + tdp.LogPrefix+"WriteMessageEnd", + "err", err, + ) + if tdp.DuplicateTo != nil { + tdp.DuplicateTo.WriteMessageEnd(ctx) + } + return err +} +func (tdp *TDebugProtocol) WriteStructBegin(ctx context.Context, name string) error { + err := tdp.Delegate.WriteStructBegin(ctx, name) + slog.DebugContext( + ctx, + tdp.LogPrefix+"WriteStructBegin", + "name", name, + "err", err, + ) + if tdp.DuplicateTo != nil { + tdp.DuplicateTo.WriteStructBegin(ctx, name) + } + return err +} +func (tdp *TDebugProtocol) WriteStructEnd(ctx context.Context) error { + err := tdp.Delegate.WriteStructEnd(ctx) + slog.DebugContext( + ctx, + tdp.LogPrefix+"WriteStructEnd", + "err", err, + ) + if tdp.DuplicateTo != nil { + tdp.DuplicateTo.WriteStructEnd(ctx) + } + return err +} +func (tdp *TDebugProtocol) WriteFieldBegin(ctx context.Context, name string, typeId TType, id int16) error { + err := tdp.Delegate.WriteFieldBegin(ctx, name, typeId, id) + slog.DebugContext( + ctx, + tdp.LogPrefix+"WriteFieldBegin", + "name", name, + "typeId", typeId, + "id", id, + "err", err, + ) + if tdp.DuplicateTo != nil { + tdp.DuplicateTo.WriteFieldBegin(ctx, name, typeId, id) + } + return err +} +func (tdp *TDebugProtocol) WriteFieldEnd(ctx context.Context) error { + err := tdp.Delegate.WriteFieldEnd(ctx) + slog.DebugContext( + ctx, + tdp.LogPrefix+"WriteFieldEnd", + "err", err, + ) + if tdp.DuplicateTo != nil { + tdp.DuplicateTo.WriteFieldEnd(ctx) + } + return err +} +func (tdp *TDebugProtocol) WriteFieldStop(ctx context.Context) error { + err := tdp.Delegate.WriteFieldStop(ctx) + slog.DebugContext( + ctx, + tdp.LogPrefix+"WriteFieldStop", + "err", err, + ) + if tdp.DuplicateTo != nil { + tdp.DuplicateTo.WriteFieldStop(ctx) + } + return err +} +func (tdp *TDebugProtocol) WriteMapBegin(ctx context.Context, keyType TType, valueType TType, size int) error { + err := tdp.Delegate.WriteMapBegin(ctx, keyType, valueType, size) + slog.DebugContext( + ctx, + tdp.LogPrefix+"WriteMapBegin", + "keyType", keyType, + "valueType", valueType, + "size", size, + "err", err, + ) + if tdp.DuplicateTo != nil { + tdp.DuplicateTo.WriteMapBegin(ctx, keyType, valueType, size) + } + return err +} +func (tdp *TDebugProtocol) WriteMapEnd(ctx context.Context) error { + err := tdp.Delegate.WriteMapEnd(ctx) + slog.DebugContext( + ctx, + tdp.LogPrefix+"WriteMapEnd", + "err", err, + ) + if tdp.DuplicateTo != nil { + tdp.DuplicateTo.WriteMapEnd(ctx) + } + return err +} +func (tdp *TDebugProtocol) WriteListBegin(ctx context.Context, elemType TType, size int) error { + err := tdp.Delegate.WriteListBegin(ctx, elemType, size) + slog.DebugContext( + ctx, + tdp.LogPrefix+"WriteListBegin", + "elemType", elemType, + "size", size, + "err", err, + ) + if tdp.DuplicateTo != nil { + tdp.DuplicateTo.WriteListBegin(ctx, elemType, size) + } + return err +} +func (tdp *TDebugProtocol) WriteListEnd(ctx context.Context) error { + err := tdp.Delegate.WriteListEnd(ctx) + slog.DebugContext( + ctx, + tdp.LogPrefix+"WriteListEnd", + "err", err, + ) + if tdp.DuplicateTo != nil { + tdp.DuplicateTo.WriteListEnd(ctx) + } + return err +} +func (tdp *TDebugProtocol) WriteSetBegin(ctx context.Context, elemType TType, size int) error { + err := tdp.Delegate.WriteSetBegin(ctx, elemType, size) + slog.DebugContext( + ctx, + tdp.LogPrefix+"WriteSetBegin", + "elemType", elemType, + "size", size, + "err", err, + ) + if tdp.DuplicateTo != nil { + tdp.DuplicateTo.WriteSetBegin(ctx, elemType, size) + } + return err +} +func (tdp *TDebugProtocol) WriteSetEnd(ctx context.Context) error { + err := tdp.Delegate.WriteSetEnd(ctx) + slog.DebugContext( + ctx, + tdp.LogPrefix+"WriteSetEnd", + "err", err, + ) + if tdp.DuplicateTo != nil { + tdp.DuplicateTo.WriteSetEnd(ctx) + } + return err +} +func (tdp *TDebugProtocol) WriteBool(ctx context.Context, value bool) error { + err := tdp.Delegate.WriteBool(ctx, value) + slog.DebugContext( + ctx, + tdp.LogPrefix+"WriteBool", + "value", value, + "err", err, + ) + if tdp.DuplicateTo != nil { + tdp.DuplicateTo.WriteBool(ctx, value) + } + return err +} +func (tdp *TDebugProtocol) WriteByte(ctx context.Context, value int8) error { + err := tdp.Delegate.WriteByte(ctx, value) + slog.DebugContext( + ctx, + tdp.LogPrefix+"WriteByte", + "value", value, + "err", err, + ) + if tdp.DuplicateTo != nil { + tdp.DuplicateTo.WriteByte(ctx, value) + } + return err +} +func (tdp *TDebugProtocol) WriteI16(ctx context.Context, value int16) error { + err := tdp.Delegate.WriteI16(ctx, value) + slog.DebugContext( + ctx, + tdp.LogPrefix+"WriteI16", + "value", value, + "err", err, + ) + if tdp.DuplicateTo != nil { + tdp.DuplicateTo.WriteI16(ctx, value) + } + return err +} +func (tdp *TDebugProtocol) WriteI32(ctx context.Context, value int32) error { + err := tdp.Delegate.WriteI32(ctx, value) + slog.DebugContext( + ctx, + tdp.LogPrefix+"WriteI32", + "value", value, + "err", err, + ) + if tdp.DuplicateTo != nil { + tdp.DuplicateTo.WriteI32(ctx, value) + } + return err +} +func (tdp *TDebugProtocol) WriteI64(ctx context.Context, value int64) error { + err := tdp.Delegate.WriteI64(ctx, value) + slog.DebugContext( + ctx, + tdp.LogPrefix+"WriteI64", + "value", value, + "err", err, + ) + if tdp.DuplicateTo != nil { + tdp.DuplicateTo.WriteI64(ctx, value) + } + return err +} +func (tdp *TDebugProtocol) WriteDouble(ctx context.Context, value float64) error { + err := tdp.Delegate.WriteDouble(ctx, value) + slog.DebugContext( + ctx, + tdp.LogPrefix+"WriteDouble", + "value", value, + "err", err, + ) + if tdp.DuplicateTo != nil { + tdp.DuplicateTo.WriteDouble(ctx, value) + } + return err +} +func (tdp *TDebugProtocol) WriteString(ctx context.Context, value string) error { + err := tdp.Delegate.WriteString(ctx, value) + slog.DebugContext( + ctx, + tdp.LogPrefix+"WriteString", + "value", value, + "err", err, + ) + if tdp.DuplicateTo != nil { + tdp.DuplicateTo.WriteString(ctx, value) + } + return err +} +func (tdp *TDebugProtocol) WriteBinary(ctx context.Context, value []byte) error { + err := tdp.Delegate.WriteBinary(ctx, value) + slog.DebugContext( + ctx, + tdp.LogPrefix+"WriteBinary", + "value", value, + "err", err, + ) + if tdp.DuplicateTo != nil { + tdp.DuplicateTo.WriteBinary(ctx, value) + } + return err +} +func (tdp *TDebugProtocol) WriteUUID(ctx context.Context, value Tuuid) error { + err := tdp.Delegate.WriteUUID(ctx, value) + slog.DebugContext( + ctx, + tdp.LogPrefix+"WriteUUID", + "value", value, + "err", err, + ) + if tdp.DuplicateTo != nil { + tdp.DuplicateTo.WriteUUID(ctx, value) + } + return err +} + +func (tdp *TDebugProtocol) ReadMessageBegin(ctx context.Context) (name string, typeId TMessageType, seqid int32, err error) { + name, typeId, seqid, err = tdp.Delegate.ReadMessageBegin(ctx) + slog.DebugContext( + ctx, + tdp.LogPrefix+"ReadMessageBegin", + "name", name, + "typeId", typeId, + "seqid", seqid, + "err", err, + ) + if tdp.DuplicateTo != nil { + tdp.DuplicateTo.WriteMessageBegin(ctx, name, typeId, seqid) + } + return +} +func (tdp *TDebugProtocol) ReadMessageEnd(ctx context.Context) (err error) { + err = tdp.Delegate.ReadMessageEnd(ctx) + slog.DebugContext( + ctx, + tdp.LogPrefix+"ReadMessageEnd", + "err", err, + ) + if tdp.DuplicateTo != nil { + tdp.DuplicateTo.WriteMessageEnd(ctx) + } + return +} +func (tdp *TDebugProtocol) ReadStructBegin(ctx context.Context) (name string, err error) { + name, err = tdp.Delegate.ReadStructBegin(ctx) + slog.DebugContext( + ctx, + tdp.LogPrefix+"ReadStructBegin", + "name", name, + "err", err, + ) + if tdp.DuplicateTo != nil { + tdp.DuplicateTo.WriteStructBegin(ctx, name) + } + return +} +func (tdp *TDebugProtocol) ReadStructEnd(ctx context.Context) (err error) { + err = tdp.Delegate.ReadStructEnd(ctx) + slog.DebugContext( + ctx, + tdp.LogPrefix+"ReadStructEnd", + "err", err, + ) + if tdp.DuplicateTo != nil { + tdp.DuplicateTo.WriteStructEnd(ctx) + } + return +} +func (tdp *TDebugProtocol) ReadFieldBegin(ctx context.Context) (name string, typeId TType, id int16, err error) { + name, typeId, id, err = tdp.Delegate.ReadFieldBegin(ctx) + slog.DebugContext( + ctx, + tdp.LogPrefix+"ReadFieldBegin", + "name", name, + "typeId", typeId, + "id", id, + "err", err, + ) + if tdp.DuplicateTo != nil { + tdp.DuplicateTo.WriteFieldBegin(ctx, name, typeId, id) + } + return +} +func (tdp *TDebugProtocol) ReadFieldEnd(ctx context.Context) (err error) { + err = tdp.Delegate.ReadFieldEnd(ctx) + slog.DebugContext( + ctx, + tdp.LogPrefix+"ReadFieldEnd", + "err", err, + ) + if tdp.DuplicateTo != nil { + tdp.DuplicateTo.WriteFieldEnd(ctx) + } + return +} +func (tdp *TDebugProtocol) ReadMapBegin(ctx context.Context) (keyType TType, valueType TType, size int, err error) { + keyType, valueType, size, err = tdp.Delegate.ReadMapBegin(ctx) + slog.DebugContext( + ctx, + tdp.LogPrefix+"ReadMapBegin", + "keyType", keyType, + "valueType", valueType, + "size", size, + "err", err, + ) + if tdp.DuplicateTo != nil { + tdp.DuplicateTo.WriteMapBegin(ctx, keyType, valueType, size) + } + return +} +func (tdp *TDebugProtocol) ReadMapEnd(ctx context.Context) (err error) { + err = tdp.Delegate.ReadMapEnd(ctx) + slog.DebugContext( + ctx, + tdp.LogPrefix+"ReadMapEnd", + "err", err, + ) + if tdp.DuplicateTo != nil { + tdp.DuplicateTo.WriteMapEnd(ctx) + } + return +} +func (tdp *TDebugProtocol) ReadListBegin(ctx context.Context) (elemType TType, size int, err error) { + elemType, size, err = tdp.Delegate.ReadListBegin(ctx) + slog.DebugContext( + ctx, + tdp.LogPrefix+"ReadListBegin", + "elemType", elemType, + "size", size, + "err", err, + ) + if tdp.DuplicateTo != nil { + tdp.DuplicateTo.WriteListBegin(ctx, elemType, size) + } + return +} +func (tdp *TDebugProtocol) ReadListEnd(ctx context.Context) (err error) { + err = tdp.Delegate.ReadListEnd(ctx) + slog.DebugContext( + ctx, + tdp.LogPrefix+"ReadListEnd", + "err", err, + ) + if tdp.DuplicateTo != nil { + tdp.DuplicateTo.WriteListEnd(ctx) + } + return +} +func (tdp *TDebugProtocol) ReadSetBegin(ctx context.Context) (elemType TType, size int, err error) { + elemType, size, err = tdp.Delegate.ReadSetBegin(ctx) + slog.DebugContext( + ctx, + tdp.LogPrefix+"ReadSetBegin", + "elemType", elemType, + "size", size, + "err", err, + ) + if tdp.DuplicateTo != nil { + tdp.DuplicateTo.WriteSetBegin(ctx, elemType, size) + } + return +} +func (tdp *TDebugProtocol) ReadSetEnd(ctx context.Context) (err error) { + err = tdp.Delegate.ReadSetEnd(ctx) + slog.DebugContext( + ctx, + tdp.LogPrefix+"ReadSetEnd", + "err", err, + ) + if tdp.DuplicateTo != nil { + tdp.DuplicateTo.WriteSetEnd(ctx) + } + return +} +func (tdp *TDebugProtocol) ReadBool(ctx context.Context) (value bool, err error) { + value, err = tdp.Delegate.ReadBool(ctx) + slog.DebugContext( + ctx, + tdp.LogPrefix+"ReadBool", + "value", value, + "err", err, + ) + if tdp.DuplicateTo != nil { + tdp.DuplicateTo.WriteBool(ctx, value) + } + return +} +func (tdp *TDebugProtocol) ReadByte(ctx context.Context) (value int8, err error) { + value, err = tdp.Delegate.ReadByte(ctx) + slog.DebugContext( + ctx, + tdp.LogPrefix+"ReadByte", + "value", value, + "err", err, + ) + if tdp.DuplicateTo != nil { + tdp.DuplicateTo.WriteByte(ctx, value) + } + return +} +func (tdp *TDebugProtocol) ReadI16(ctx context.Context) (value int16, err error) { + value, err = tdp.Delegate.ReadI16(ctx) + slog.DebugContext( + ctx, + tdp.LogPrefix+"ReadI16", + "value", value, + "err", err, + ) + if tdp.DuplicateTo != nil { + tdp.DuplicateTo.WriteI16(ctx, value) + } + return +} +func (tdp *TDebugProtocol) ReadI32(ctx context.Context) (value int32, err error) { + value, err = tdp.Delegate.ReadI32(ctx) + slog.DebugContext( + ctx, + tdp.LogPrefix+"ReadI32", + "value", value, + "err", err, + ) + if tdp.DuplicateTo != nil { + tdp.DuplicateTo.WriteI32(ctx, value) + } + return +} +func (tdp *TDebugProtocol) ReadI64(ctx context.Context) (value int64, err error) { + value, err = tdp.Delegate.ReadI64(ctx) + slog.DebugContext( + ctx, + tdp.LogPrefix+"ReadI64", + "value", value, + "err", err, + ) + if tdp.DuplicateTo != nil { + tdp.DuplicateTo.WriteI64(ctx, value) + } + return +} +func (tdp *TDebugProtocol) ReadDouble(ctx context.Context) (value float64, err error) { + value, err = tdp.Delegate.ReadDouble(ctx) + slog.DebugContext( + ctx, + tdp.LogPrefix+"ReadDouble", + "value", value, + "err", err, + ) + if tdp.DuplicateTo != nil { + tdp.DuplicateTo.WriteDouble(ctx, value) + } + return +} +func (tdp *TDebugProtocol) ReadString(ctx context.Context) (value string, err error) { + value, err = tdp.Delegate.ReadString(ctx) + slog.DebugContext( + ctx, + tdp.LogPrefix+"ReadString", + "value", value, + "err", err, + ) + if tdp.DuplicateTo != nil { + tdp.DuplicateTo.WriteString(ctx, value) + } + return +} +func (tdp *TDebugProtocol) ReadBinary(ctx context.Context) (value []byte, err error) { + value, err = tdp.Delegate.ReadBinary(ctx) + slog.DebugContext( + ctx, + tdp.LogPrefix+"ReadBinary", + "value", value, + "err", err, + ) + if tdp.DuplicateTo != nil { + tdp.DuplicateTo.WriteBinary(ctx, value) + } + return +} +func (tdp *TDebugProtocol) ReadUUID(ctx context.Context) (value Tuuid, err error) { + value, err = tdp.Delegate.ReadUUID(ctx) + slog.DebugContext( + ctx, + tdp.LogPrefix+"ReadUUID", + "value", value, + "err", err, + ) + if tdp.DuplicateTo != nil { + tdp.DuplicateTo.WriteUUID(ctx, value) + } + return +} +func (tdp *TDebugProtocol) Skip(ctx context.Context, fieldType TType) (err error) { + err = tdp.Delegate.Skip(ctx, fieldType) + slog.DebugContext( + ctx, + tdp.LogPrefix+"Skip", + "fieldType", fieldType, + "err", err, + ) + if tdp.DuplicateTo != nil { + tdp.DuplicateTo.Skip(ctx, fieldType) + } + return +} +func (tdp *TDebugProtocol) Flush(ctx context.Context) (err error) { + err = tdp.Delegate.Flush(ctx) + slog.DebugContext( + ctx, + tdp.LogPrefix+"Flush", + "err", err, + ) + if tdp.DuplicateTo != nil { + tdp.DuplicateTo.Flush(ctx) + } + return +} + +func (tdp *TDebugProtocol) Transport() TTransport { + return tdp.Delegate.Transport() +} + +// SetTConfiguration implements TConfigurationSetter for propagation. +func (tdp *TDebugProtocol) SetTConfiguration(conf *TConfiguration) { + PropagateTConfiguration(tdp.Delegate, conf) + PropagateTConfiguration(tdp.DuplicateTo, conf) +} + +var _ TConfigurationSetter = (*TDebugProtocol)(nil) diff --git a/thrift/exception.go b/thrift/exception.go index e2f1728..5b4cad9 100644 --- a/thrift/exception.go +++ b/thrift/exception.go @@ -121,20 +121,20 @@ var _ TException = wrappedTException{} // // For a endpoint defined in thrift IDL like this: // -// service MyService { -// FooResponse foo(1: FooRequest request) throws ( -// 1: Exception1 error1, -// 2: Exception2 error2, -// ) -// } +// service MyService { +// FooResponse foo(1: FooRequest request) throws ( +// 1: Exception1 error1, +// 2: Exception2 error2, +// ) +// } // // The thrift compiler generated go code for the result TStruct would be like: // -// type MyServiceFooResult struct { -// Success *FooResponse `thrift:"success,0" db:"success" json:"success,omitempty"` -// Error1 *Exception1 `thrift:"error1,1" db:"error1" json:"error1,omitempty"` -// Error2 *Exception2 `thrift:"error2,2" db:"error2" json:"error2,omitempty"` -// } +// type MyServiceFooResult struct { +// Success *FooResponse `thrift:"success,0" db:"success" json:"success,omitempty"` +// Error1 *Exception1 `thrift:"error1,1" db:"error1" json:"error1,omitempty"` +// Error2 *Exception2 `thrift:"error2,2" db:"error2" json:"error2,omitempty"` +// } // // And this function extracts the first non-nil exception out of // *MyServiceFooResult. @@ -144,7 +144,7 @@ func ExtractExceptionFromResult(result TStruct) error { return nil } typ := v.Type() - for i := 0; i < v.NumField(); i++ { + for i := range v.NumField() { if typ.Field(i).Name == "Success" { continue } diff --git a/thrift/framed_transport_test.go b/thrift/framed_transport_test.go index d23ec59..e5aa470 100644 --- a/thrift/framed_transport_test.go +++ b/thrift/framed_transport_test.go @@ -42,7 +42,7 @@ func TestTFramedTransportReuseTransport(t *testing.T) { writer := NewTFramedTransport(trans) t.Run("pair", func(t *testing.T) { - for i := 0; i < n; i++ { + for i := range n { // write if _, err := io.Copy(writer, strings.NewReader(content)); err != nil { t.Fatalf("Failed to write on #%d: %v", i, err) @@ -64,7 +64,7 @@ func TestTFramedTransportReuseTransport(t *testing.T) { t.Run("batched", func(t *testing.T) { // write - for i := 0; i < n; i++ { + for i := range n { if _, err := io.Copy(writer, strings.NewReader(content)); err != nil { t.Fatalf("Failed to write on #%d: %v", i, err) } @@ -74,7 +74,7 @@ func TestTFramedTransportReuseTransport(t *testing.T) { } // read - for i := 0; i < n; i++ { + for i := range n { const ( size = len(content) ) diff --git a/thrift/header_protocol.go b/thrift/header_protocol.go index 36777b4..bec84b8 100644 --- a/thrift/header_protocol.go +++ b/thrift/header_protocol.go @@ -119,6 +119,11 @@ func (p *THeaderProtocol) ClearWriteHeaders() { } // AddTransform add a transform for writing. +// +// Deprecated: This only applies to the next message written, and the next read +// message will cause write transforms to be reset from what's configured in +// TConfiguration. For sticky transforms, use TConfiguration.THeaderTransforms +// instead. func (p *THeaderProtocol) AddTransform(transform THeaderTransformID) error { return p.transport.AddTransform(transform) } diff --git a/thrift/header_protocol_test.go b/thrift/header_protocol_test.go index 48a69bf..dfd84f8 100644 --- a/thrift/header_protocol_test.go +++ b/thrift/header_protocol_test.go @@ -39,4 +39,24 @@ func TestReadWriteHeaderProtocol(t *testing.T) { })) }, ) + + t.Run( + "binary-zlib", + func(t *testing.T) { + ReadWriteProtocolTest(t, NewTHeaderProtocolFactoryConf(&TConfiguration{ + THeaderProtocolID: THeaderProtocolIDPtrMust(THeaderProtocolBinary), + THeaderTransforms: []THeaderTransformID{TransformZlib}, + })) + }, + ) + + t.Run( + "compact-zlib", + func(t *testing.T) { + ReadWriteProtocolTest(t, NewTHeaderProtocolFactoryConf(&TConfiguration{ + THeaderProtocolID: THeaderProtocolIDPtrMust(THeaderProtocolCompact), + THeaderTransforms: []THeaderTransformID{TransformZlib}, + })) + }, + ) } diff --git a/thrift/header_transport.go b/thrift/header_transport.go index 3aea5a9..d6d6416 100644 --- a/thrift/header_transport.go +++ b/thrift/header_transport.go @@ -128,7 +128,7 @@ var _ io.ReadCloser = (*TransformReader)(nil) // // If you don't know the closers capacity beforehand, just use // -// &TransformReader{Reader: baseReader} +// &TransformReader{Reader: baseReader} // // instead would be sufficient. func NewTransformReaderWithCapacity(baseReader io.Reader, capacity int) *TransformReader { @@ -151,6 +151,11 @@ func (tr *TransformReader) Close() error { } // AddTransform adds a transform. +// +// Deprecated: This only applies to the next message written, and the next read +// message will cause write transforms to be reset from what's configured in +// TConfiguration. For sticky transforms, use TConfiguration.THeaderTransforms +// instead. func (tr *TransformReader) AddTransform(id THeaderTransformID) error { switch id { default: @@ -206,6 +211,25 @@ func (tw *TransformWriter) Close() error { return nil } +var zlibDefaultLevelWriterPool = newPool( + func() *zlib.Writer { + return zlib.NewWriter(nil) + }, + nil, +) + +type zlibPoolCloser struct { + writer *zlib.Writer +} + +func (z *zlibPoolCloser) Close() error { + defer func() { + z.writer.Reset(nil) + zlibDefaultLevelWriterPool.put(&z.writer) + }() + return z.writer.Close() +} + // AddTransform adds a transform. func (tw *TransformWriter) AddTransform(id THeaderTransformID) error { switch id { @@ -217,9 +241,12 @@ func (tw *TransformWriter) AddTransform(id THeaderTransformID) error { case TransformNone: // no-op case TransformZlib: - writeCloser := zlib.NewWriter(tw.Writer) + writeCloser := zlibDefaultLevelWriterPool.get() + writeCloser.Reset(tw.Writer) tw.Writer = writeCloser - tw.closers = append(tw.closers, writeCloser) + tw.closers = append(tw.closers, &zlibPoolCloser{ + writer: writeCloser, + }) } return nil } @@ -300,11 +327,12 @@ func NewTHeaderTransportConf(trans TTransport, conf *TConfiguration) *THeaderTra } PropagateTConfiguration(trans, conf) return &THeaderTransport{ - transport: trans, - reader: bufio.NewReader(trans), - writeHeaders: make(THeaderMap), - protocolID: conf.GetTHeaderProtocolID(), - cfg: conf, + transport: trans, + reader: bufio.NewReader(trans), + writeHeaders: make(THeaderMap), + writeTransforms: conf.GetTHeaderTransforms(), + protocolID: conf.GetTHeaderProtocolID(), + cfg: conf, } } @@ -449,6 +477,11 @@ func (t *THeaderTransport) parseHeaders(ctx context.Context, frameSize uint32) e } t.protocolID = THeaderProtocolID(protoID) + // Reset writeTransforms to the ones from cfg, as we are going to add + // compression transforms from what we read, we don't want to accumulate + // different transforms read from different requests + t.writeTransforms = t.cfg.GetTHeaderTransforms() + var transformCount int32 transformCount, err = hp.readVarint32() if err != nil { @@ -461,12 +494,21 @@ func (t *THeaderTransport) parseHeaders(ctx context.Context, frameSize uint32) e ) t.frameReader = reader transformIDs := make([]THeaderTransformID, transformCount) - for i := 0; i < int(transformCount); i++ { + for i := range int(transformCount) { id, err := hp.readVarint32() if err != nil { return err } - transformIDs[i] = THeaderTransformID(id) + tID := THeaderTransformID(id) + transformIDs[i] = tID + + // For compression transforms, we should also add them + // to writeTransforms so that the response (assuming we + // are reading a request) would do the same compression. + switch tID { + case TransformZlib: + t.addWriteTransformsDedupe(tID) + } } // The transform IDs on the wire was added based on the order of // writing, so on the reading side we need to reverse the order. @@ -494,7 +536,7 @@ func (t *THeaderTransport) parseHeaders(ctx context.Context, frameSize uint32) e if err != nil { return err } - for i := 0; i < int(count); i++ { + for range int(count) { key, err := hp.ReadString(ctx) if err != nil { return err @@ -544,7 +586,7 @@ func (t *THeaderTransport) Read(p []byte) (read int, err error) { // the last Read finished the frame, do endOfFrame // handling here. err = t.endOfFrame() - } else if err == io.EOF { + } else if errors.Is(err, io.EOF) { err = t.endOfFrame() if err != nil { return @@ -726,6 +768,9 @@ func (t *THeaderTransport) ClearWriteHeaders() { } // AddTransform add a transform for writing. +// +// NOTE: This is provided as a low-level API, but in general you should use +// TConfiguration.THeaderTransforms to set transforms for writing instead. func (t *THeaderTransport) AddTransform(transform THeaderTransformID) error { if !supportedTransformIDs[transform] { return NewTProtocolExceptionWithType( @@ -758,6 +803,17 @@ func (t *THeaderTransport) isFramed() bool { } } +// addWriteTransformsDedupe adds id to writeTransforms only if it's not already +// there. +func (t *THeaderTransport) addWriteTransformsDedupe(id THeaderTransformID) { + for _, existingID := range t.writeTransforms { + if existingID == id { + return + } + } + t.writeTransforms = append(t.writeTransforms, id) +} + // SetTConfiguration implements TConfigurationSetter. func (t *THeaderTransport) SetTConfiguration(cfg *TConfiguration) { PropagateTConfiguration(t.transport, cfg) diff --git a/thrift/header_transport_test.go b/thrift/header_transport_test.go index 125a5fd..09a0331 100644 --- a/thrift/header_transport_test.go +++ b/thrift/header_transport_test.go @@ -316,7 +316,7 @@ func TestTHeaderTransportReuseTransport(t *testing.T) { writer := NewTHeaderTransport(trans) t.Run("pair", func(t *testing.T) { - for i := 0; i < n; i++ { + for i := range n { // write if _, err := io.Copy(writer, strings.NewReader(content)); err != nil { t.Fatalf("Failed to write on #%d: %v", i, err) @@ -338,7 +338,7 @@ func TestTHeaderTransportReuseTransport(t *testing.T) { t.Run("batched", func(t *testing.T) { // write - for i := 0; i < n; i++ { + for i := range n { if _, err := io.Copy(writer, strings.NewReader(content)); err != nil { t.Fatalf("Failed to write on #%d: %v", i, err) } @@ -348,7 +348,7 @@ func TestTHeaderTransportReuseTransport(t *testing.T) { } // read - for i := 0; i < n; i++ { + for i := range n { const ( size = len(content) ) diff --git a/thrift/json_protocol_test.go b/thrift/json_protocol_test.go index 39e52d1..1680532 100644 --- a/thrift/json_protocol_test.go +++ b/thrift/json_protocol_test.go @@ -451,7 +451,7 @@ func TestReadJSONProtocolBinary(t *testing.T) { if len(v) != len(value) { t.Fatalf("Bad value for %s value length %v, wrote: %v, received length: %v", thetype, len(value), s, len(v)) } - for i := 0; i < len(v); i++ { + for i := range v { if v[i] != value[i] { t.Fatalf("Bad value for %s at index %d value %v, wrote: %v, received: %v", thetype, i, value[i], s, v[i]) } diff --git a/thrift/logger.go b/thrift/logger.go index 722a5fa..4a0affe 100644 --- a/thrift/logger.go +++ b/thrift/logger.go @@ -38,44 +38,44 @@ import ( // // Deprecated: This is no longer used by any thrift go library code, // will be removed in the future version. - type Logger func(msg string) - - // NopLogger is a Logger implementation that does nothing. - // - // Deprecated: This is no longer used by any thrift go library code, - // will be removed in the future version. - func NopLogger(msg string) {} - - // StdLogger wraps stdlib log package into a Logger. - // - // If logger passed in is nil, it will fallback to use stderr and default flags. - // - // Deprecated: This is no longer used by any thrift go library code, - // will be removed in the future version. - func StdLogger(logger *log.Logger) Logger { - if logger == nil { - logger = log.New(os.Stderr, "", log.LstdFlags) - } - return func(msg string) { - logger.Print(msg) - } - } - - // TestLogger is a Logger implementation can be used in test codes. - // - // It fails the test when being called. - // - // Deprecated: This is no longer used by any thrift go library code, - // will be removed in the future version. - func TestLogger(tb testing.TB) Logger { - return func(msg string) { - tb.Errorf("logger called with msg: %q", msg) - } - } - - func fallbackLogger(logger Logger) Logger { - if logger == nil { - return StdLogger(nil) - } - return logger - } \ No newline at end of file +type Logger func(msg string) + +// NopLogger is a Logger implementation that does nothing. +// +// Deprecated: This is no longer used by any thrift go library code, +// will be removed in the future version. +func NopLogger(msg string) {} + +// StdLogger wraps stdlib log package into a Logger. +// +// If logger passed in is nil, it will fallback to use stderr and default flags. +// +// Deprecated: This is no longer used by any thrift go library code, +// will be removed in the future version. +func StdLogger(logger *log.Logger) Logger { + if logger == nil { + logger = log.New(os.Stderr, "", log.LstdFlags) + } + return func(msg string) { + logger.Print(msg) + } +} + +// TestLogger is a Logger implementation can be used in test codes. +// +// It fails the test when being called. +// +// Deprecated: This is no longer used by any thrift go library code, +// will be removed in the future version. +func TestLogger(tb testing.TB) Logger { + return func(msg string) { + tb.Errorf("logger called with msg: %q", msg) + } +} + +func fallbackLogger(logger Logger) Logger { + if logger == nil { + return StdLogger(nil) + } + return logger +} diff --git a/thrift/lowlevel_benchmarks_test.go b/thrift/lowlevel_benchmarks_test.go index e173655..b389388 100644 --- a/thrift/lowlevel_benchmarks_test.go +++ b/thrift/lowlevel_benchmarks_test.go @@ -41,7 +41,7 @@ func BenchmarkBinaryBool_0(b *testing.B) { b.Fatal(err) } p := binaryProtoF.GetProtocol(trans) - for i := 0; i < b.N; i++ { + for range b.N { ReadWriteBool(b, p, trans) } } @@ -52,7 +52,7 @@ func BenchmarkBinaryByte_0(b *testing.B) { b.Fatal(err) } p := binaryProtoF.GetProtocol(trans) - for i := 0; i < b.N; i++ { + for range b.N { ReadWriteByte(b, p, trans) } } @@ -63,7 +63,7 @@ func BenchmarkBinaryI16_0(b *testing.B) { b.Fatal(err) } p := binaryProtoF.GetProtocol(trans) - for i := 0; i < b.N; i++ { + for range b.N { ReadWriteI16(b, p, trans) } } @@ -74,7 +74,7 @@ func BenchmarkBinaryI32_0(b *testing.B) { b.Fatal(err) } p := binaryProtoF.GetProtocol(trans) - for i := 0; i < b.N; i++ { + for range b.N { ReadWriteI32(b, p, trans) } } @@ -84,7 +84,7 @@ func BenchmarkBinaryI64_0(b *testing.B) { b.Fatal(err) } p := binaryProtoF.GetProtocol(trans) - for i := 0; i < b.N; i++ { + for range b.N { ReadWriteI64(b, p, trans) } } @@ -94,7 +94,7 @@ func BenchmarkBinaryDouble_0(b *testing.B) { b.Fatal(err) } p := binaryProtoF.GetProtocol(trans) - for i := 0; i < b.N; i++ { + for range b.N { ReadWriteDouble(b, p, trans) } } @@ -104,7 +104,7 @@ func BenchmarkBinaryString_0(b *testing.B) { b.Fatal(err) } p := binaryProtoF.GetProtocol(trans) - for i := 0; i < b.N; i++ { + for range b.N { ReadWriteString(b, p, trans) } } @@ -114,7 +114,7 @@ func BenchmarkBinaryBinary_0(b *testing.B) { b.Fatal(err) } p := binaryProtoF.GetProtocol(trans) - for i := 0; i < b.N; i++ { + for range b.N { ReadWriteBinary(b, p, trans) } } @@ -125,7 +125,7 @@ func BenchmarkBinaryBool_1(b *testing.B) { b.Fatal(err) } p := binaryProtoF.GetProtocol(trans) - for i := 0; i < b.N; i++ { + for range b.N { ReadWriteBool(b, p, trans) } } @@ -136,7 +136,7 @@ func BenchmarkBinaryByte_1(b *testing.B) { b.Fatal(err) } p := binaryProtoF.GetProtocol(trans) - for i := 0; i < b.N; i++ { + for range b.N { ReadWriteByte(b, p, trans) } } @@ -147,7 +147,7 @@ func BenchmarkBinaryI16_1(b *testing.B) { b.Fatal(err) } p := binaryProtoF.GetProtocol(trans) - for i := 0; i < b.N; i++ { + for range b.N { ReadWriteI16(b, p, trans) } } @@ -158,7 +158,7 @@ func BenchmarkBinaryI32_1(b *testing.B) { b.Fatal(err) } p := binaryProtoF.GetProtocol(trans) - for i := 0; i < b.N; i++ { + for range b.N { ReadWriteI32(b, p, trans) } } @@ -168,7 +168,7 @@ func BenchmarkBinaryI64_1(b *testing.B) { b.Fatal(err) } p := binaryProtoF.GetProtocol(trans) - for i := 0; i < b.N; i++ { + for range b.N { ReadWriteI64(b, p, trans) } } @@ -178,7 +178,7 @@ func BenchmarkBinaryDouble_1(b *testing.B) { b.Fatal(err) } p := binaryProtoF.GetProtocol(trans) - for i := 0; i < b.N; i++ { + for range b.N { ReadWriteDouble(b, p, trans) } } @@ -188,7 +188,7 @@ func BenchmarkBinaryString_1(b *testing.B) { b.Fatal(err) } p := binaryProtoF.GetProtocol(trans) - for i := 0; i < b.N; i++ { + for range b.N { ReadWriteString(b, p, trans) } } @@ -198,7 +198,7 @@ func BenchmarkBinaryBinary_1(b *testing.B) { b.Fatal(err) } p := binaryProtoF.GetProtocol(trans) - for i := 0; i < b.N; i++ { + for range b.N { ReadWriteBinary(b, p, trans) } } @@ -209,7 +209,7 @@ func BenchmarkBinaryBool_2(b *testing.B) { b.Fatal(err) } p := binaryProtoF.GetProtocol(trans) - for i := 0; i < b.N; i++ { + for range b.N { ReadWriteBool(b, p, trans) } } @@ -220,7 +220,7 @@ func BenchmarkBinaryByte_2(b *testing.B) { b.Fatal(err) } p := binaryProtoF.GetProtocol(trans) - for i := 0; i < b.N; i++ { + for range b.N { ReadWriteByte(b, p, trans) } } @@ -231,7 +231,7 @@ func BenchmarkBinaryI16_2(b *testing.B) { b.Fatal(err) } p := binaryProtoF.GetProtocol(trans) - for i := 0; i < b.N; i++ { + for range b.N { ReadWriteI16(b, p, trans) } } @@ -242,7 +242,7 @@ func BenchmarkBinaryI32_2(b *testing.B) { b.Fatal(err) } p := binaryProtoF.GetProtocol(trans) - for i := 0; i < b.N; i++ { + for range b.N { ReadWriteI32(b, p, trans) } } @@ -252,7 +252,7 @@ func BenchmarkBinaryI64_2(b *testing.B) { b.Fatal(err) } p := binaryProtoF.GetProtocol(trans) - for i := 0; i < b.N; i++ { + for range b.N { ReadWriteI64(b, p, trans) } } @@ -262,7 +262,7 @@ func BenchmarkBinaryDouble_2(b *testing.B) { b.Fatal(err) } p := binaryProtoF.GetProtocol(trans) - for i := 0; i < b.N; i++ { + for range b.N { ReadWriteDouble(b, p, trans) } } @@ -272,7 +272,7 @@ func BenchmarkBinaryString_2(b *testing.B) { b.Fatal(err) } p := binaryProtoF.GetProtocol(trans) - for i := 0; i < b.N; i++ { + for range b.N { ReadWriteString(b, p, trans) } } @@ -282,7 +282,7 @@ func BenchmarkBinaryBinary_2(b *testing.B) { b.Fatal(err) } p := binaryProtoF.GetProtocol(trans) - for i := 0; i < b.N; i++ { + for range b.N { ReadWriteBinary(b, p, trans) } } @@ -293,7 +293,7 @@ func BenchmarkCompactBool_0(b *testing.B) { b.Fatal(err) } p := compactProtoF.GetProtocol(trans) - for i := 0; i < b.N; i++ { + for range b.N { ReadWriteBool(b, p, trans) } } @@ -304,7 +304,7 @@ func BenchmarkCompactByte_0(b *testing.B) { b.Fatal(err) } p := compactProtoF.GetProtocol(trans) - for i := 0; i < b.N; i++ { + for range b.N { ReadWriteByte(b, p, trans) } } @@ -315,7 +315,7 @@ func BenchmarkCompactI16_0(b *testing.B) { b.Fatal(err) } p := compactProtoF.GetProtocol(trans) - for i := 0; i < b.N; i++ { + for range b.N { ReadWriteI16(b, p, trans) } } @@ -326,7 +326,7 @@ func BenchmarkCompactI32_0(b *testing.B) { b.Fatal(err) } p := compactProtoF.GetProtocol(trans) - for i := 0; i < b.N; i++ { + for range b.N { ReadWriteI32(b, p, trans) } } @@ -336,7 +336,7 @@ func BenchmarkCompactI64_0(b *testing.B) { b.Fatal(err) } p := compactProtoF.GetProtocol(trans) - for i := 0; i < b.N; i++ { + for range b.N { ReadWriteI64(b, p, trans) } } @@ -346,7 +346,7 @@ func BenchmarkCompactDouble0(b *testing.B) { b.Fatal(err) } p := compactProtoF.GetProtocol(trans) - for i := 0; i < b.N; i++ { + for range b.N { ReadWriteDouble(b, p, trans) } } @@ -356,7 +356,7 @@ func BenchmarkCompactString0(b *testing.B) { b.Fatal(err) } p := compactProtoF.GetProtocol(trans) - for i := 0; i < b.N; i++ { + for range b.N { ReadWriteString(b, p, trans) } } @@ -366,7 +366,7 @@ func BenchmarkCompactBinary0(b *testing.B) { b.Fatal(err) } p := compactProtoF.GetProtocol(trans) - for i := 0; i < b.N; i++ { + for range b.N { ReadWriteBinary(b, p, trans) } } @@ -377,7 +377,7 @@ func BenchmarkCompactBool_1(b *testing.B) { b.Fatal(err) } p := compactProtoF.GetProtocol(trans) - for i := 0; i < b.N; i++ { + for range b.N { ReadWriteBool(b, p, trans) } } @@ -388,7 +388,7 @@ func BenchmarkCompactByte_1(b *testing.B) { b.Fatal(err) } p := compactProtoF.GetProtocol(trans) - for i := 0; i < b.N; i++ { + for range b.N { ReadWriteByte(b, p, trans) } } @@ -399,7 +399,7 @@ func BenchmarkCompactI16_1(b *testing.B) { b.Fatal(err) } p := compactProtoF.GetProtocol(trans) - for i := 0; i < b.N; i++ { + for range b.N { ReadWriteI16(b, p, trans) } } @@ -410,7 +410,7 @@ func BenchmarkCompactI32_1(b *testing.B) { b.Fatal(err) } p := compactProtoF.GetProtocol(trans) - for i := 0; i < b.N; i++ { + for range b.N { ReadWriteI32(b, p, trans) } } @@ -420,7 +420,7 @@ func BenchmarkCompactI64_1(b *testing.B) { b.Fatal(err) } p := compactProtoF.GetProtocol(trans) - for i := 0; i < b.N; i++ { + for range b.N { ReadWriteI64(b, p, trans) } } @@ -430,7 +430,7 @@ func BenchmarkCompactDouble1(b *testing.B) { b.Fatal(err) } p := compactProtoF.GetProtocol(trans) - for i := 0; i < b.N; i++ { + for range b.N { ReadWriteDouble(b, p, trans) } } @@ -440,7 +440,7 @@ func BenchmarkCompactString1(b *testing.B) { b.Fatal(err) } p := compactProtoF.GetProtocol(trans) - for i := 0; i < b.N; i++ { + for range b.N { ReadWriteString(b, p, trans) } } @@ -450,7 +450,7 @@ func BenchmarkCompactBinary1(b *testing.B) { b.Fatal(err) } p := compactProtoF.GetProtocol(trans) - for i := 0; i < b.N; i++ { + for range b.N { ReadWriteBinary(b, p, trans) } } @@ -461,7 +461,7 @@ func BenchmarkCompactBool_2(b *testing.B) { b.Fatal(err) } p := compactProtoF.GetProtocol(trans) - for i := 0; i < b.N; i++ { + for range b.N { ReadWriteBool(b, p, trans) } } @@ -472,7 +472,7 @@ func BenchmarkCompactByte_2(b *testing.B) { b.Fatal(err) } p := compactProtoF.GetProtocol(trans) - for i := 0; i < b.N; i++ { + for range b.N { ReadWriteByte(b, p, trans) } } @@ -483,7 +483,7 @@ func BenchmarkCompactI16_2(b *testing.B) { b.Fatal(err) } p := compactProtoF.GetProtocol(trans) - for i := 0; i < b.N; i++ { + for range b.N { ReadWriteI16(b, p, trans) } } @@ -494,7 +494,7 @@ func BenchmarkCompactI32_2(b *testing.B) { b.Fatal(err) } p := compactProtoF.GetProtocol(trans) - for i := 0; i < b.N; i++ { + for range b.N { ReadWriteI32(b, p, trans) } } @@ -504,7 +504,7 @@ func BenchmarkCompactI64_2(b *testing.B) { b.Fatal(err) } p := compactProtoF.GetProtocol(trans) - for i := 0; i < b.N; i++ { + for range b.N { ReadWriteI64(b, p, trans) } } @@ -514,7 +514,7 @@ func BenchmarkCompactDouble2(b *testing.B) { b.Fatal(err) } p := compactProtoF.GetProtocol(trans) - for i := 0; i < b.N; i++ { + for range b.N { ReadWriteDouble(b, p, trans) } } @@ -524,7 +524,7 @@ func BenchmarkCompactString2(b *testing.B) { b.Fatal(err) } p := compactProtoF.GetProtocol(trans) - for i := 0; i < b.N; i++ { + for range b.N { ReadWriteString(b, p, trans) } } @@ -534,7 +534,7 @@ func BenchmarkCompactBinary2(b *testing.B) { b.Fatal(err) } p := compactProtoF.GetProtocol(trans) - for i := 0; i < b.N; i++ { + for range b.N { ReadWriteBinary(b, p, trans) } } diff --git a/thrift/protocol.go b/thrift/protocol.go index 2ee14ca..68cfe4a 100644 --- a/thrift/protocol.go +++ b/thrift/protocol.go @@ -146,7 +146,7 @@ func Skip(ctx context.Context, self TProtocol, fieldType TType, maxDepth int) (e if err != nil { return err } - for i := 0; i < size; i++ { + for range size { err := Skip(ctx, self, keyType, maxDepth-1) if err != nil { return err @@ -163,7 +163,7 @@ func Skip(ctx context.Context, self TProtocol, fieldType TType, maxDepth int) (e if err != nil { return err } - for i := 0; i < size; i++ { + for range size { err := Skip(ctx, self, elemType, maxDepth-1) if err != nil { return err @@ -175,7 +175,7 @@ func Skip(ctx context.Context, self TProtocol, fieldType TType, maxDepth int) (e if err != nil { return err } - for i := 0; i < size; i++ { + for range size { err := Skip(ctx, self, elemType, maxDepth-1) if err != nil { return err diff --git a/thrift/protocol_test.go b/thrift/protocol_test.go index 1093c94..4fac801 100644 --- a/thrift/protocol_test.go +++ b/thrift/protocol_test.go @@ -45,7 +45,7 @@ var ( func init() { protocol_bdata = make([]byte, PROTOCOL_BINARY_DATA_SIZE) - for i := 0; i < PROTOCOL_BINARY_DATA_SIZE; i++ { + for i := range PROTOCOL_BINARY_DATA_SIZE { protocol_bdata[i] = byte((i + 'a') % 255) } BOOL_VALUES = []bool{false, true, false, false, true} @@ -531,7 +531,7 @@ func ReadWriteBinary(t testing.TB, p TProtocol, trans TTransport) { if len(v) != len(value) { t.Errorf("%s: %T %T len(v) != len(value)... %d != %d", "ReadWriteBinary", p, trans, len(v), len(value)) } else { - for i := 0; i < len(v); i++ { + for i := range v { if v[i] != value[i] { t.Errorf("%s: %T %T %s != %s", "ReadWriteBinary", p, trans, v, value) } diff --git a/thrift/serializer_test.go b/thrift/serializer_test.go index 425ce06..19879c5 100644 --- a/thrift/serializer_test.go +++ b/thrift/serializer_test.go @@ -328,7 +328,7 @@ func BenchmarkSerializer(b *testing.B) { b.Run( c.Label, func(b *testing.B) { - for i := 0; i < b.N; i++ { + for range b.N { s := c.Serializer() m := MyTestStruct{} str, _ := s.WriteString(context.Background(), &m) diff --git a/thrift/serializer_types_test.go b/thrift/serializer_types_test.go index 4d1e992..d960016 100644 --- a/thrift/serializer_types_test.go +++ b/thrift/serializer_types_test.go @@ -319,7 +319,7 @@ func (p *MyTestStruct) readField9(ctx context.Context, iprot TProtocol) error { } tMap := make(map[string]string, size) p.StringMap = tMap - for i := 0; i < size; i++ { + for range size { var _key0 string if v, err := iprot.ReadString(ctx); err != nil { return PrependError("error reading field 0: ", err) @@ -347,7 +347,7 @@ func (p *MyTestStruct) readField10(ctx context.Context, iprot TProtocol) error { } tSlice := make([]string, 0, size) p.StringList = tSlice - for i := 0; i < size; i++ { + for range size { var _elem2 string if v, err := iprot.ReadString(ctx); err != nil { return PrependError("error reading field 0: ", err) @@ -369,7 +369,7 @@ func (p *MyTestStruct) readField11(ctx context.Context, iprot TProtocol) error { } tSet := make(map[string]struct{}, size) p.StringSet = tSet - for i := 0; i < size; i++ { + for range size { var _elem3 string if v, err := iprot.ReadString(ctx); err != nil { return PrependError("error reading field 0: ", err) diff --git a/thrift/simple_json_protocol.go b/thrift/simple_json_protocol.go index 4273983..ec12991 100644 --- a/thrift/simple_json_protocol.go +++ b/thrift/simple_json_protocol.go @@ -32,1321 +32,1321 @@ import ( "strconv" "strings" ) - - type _ParseContext int - - const ( - _CONTEXT_INVALID _ParseContext = iota - _CONTEXT_IN_TOPLEVEL // 1 - _CONTEXT_IN_LIST_FIRST // 2 - _CONTEXT_IN_LIST // 3 - _CONTEXT_IN_OBJECT_FIRST // 4 - _CONTEXT_IN_OBJECT_NEXT_KEY // 5 - _CONTEXT_IN_OBJECT_NEXT_VALUE // 6 - ) - - func (p _ParseContext) String() string { - switch p { - case _CONTEXT_IN_TOPLEVEL: - return "TOPLEVEL" - case _CONTEXT_IN_LIST_FIRST: - return "LIST-FIRST" - case _CONTEXT_IN_LIST: - return "LIST" - case _CONTEXT_IN_OBJECT_FIRST: - return "OBJECT-FIRST" - case _CONTEXT_IN_OBJECT_NEXT_KEY: - return "OBJECT-NEXT-KEY" - case _CONTEXT_IN_OBJECT_NEXT_VALUE: - return "OBJECT-NEXT-VALUE" - } - return "UNKNOWN-PARSE-CONTEXT" - } - - type jsonContextStack []_ParseContext - - func (s *jsonContextStack) push(v _ParseContext) { - *s = append(*s, v) - } - - func (s jsonContextStack) peek() (v _ParseContext, ok bool) { - l := len(s) - if l <= 0 { - return - } - return s[l-1], true - } - - func (s *jsonContextStack) pop() (v _ParseContext, ok bool) { - l := len(*s) - if l <= 0 { - return - } - v = (*s)[l-1] - *s = (*s)[0 : l-1] - return v, true - } - - var errEmptyJSONContextStack = NewTProtocolExceptionWithType(INVALID_DATA, errors.New("Unexpected empty json protocol context stack")) - - // Simple JSON protocol implementation for thrift. - // - // This protocol produces/consumes a simple output format - // suitable for parsing by scripting languages. It should not be - // confused with the full-featured TJSONProtocol. - type TSimpleJSONProtocol struct { - trans TTransport - - cfg *TConfiguration - - parseContextStack jsonContextStack - dumpContext jsonContextStack - - writer *bufio.Writer - reader *bufio.Reader - } - - // Deprecated: Use NewTSimpleJSONProtocolConf instead.: - func NewTSimpleJSONProtocol(t TTransport) *TSimpleJSONProtocol { - return NewTSimpleJSONProtocolConf(t, &TConfiguration{ - noPropagation: true, - }) - } - - func NewTSimpleJSONProtocolConf(t TTransport, conf *TConfiguration) *TSimpleJSONProtocol { - PropagateTConfiguration(t, conf) - v := &TSimpleJSONProtocol{ - trans: t, - cfg: conf, - writer: bufio.NewWriter(t), - reader: bufio.NewReader(t), - } - v.resetContextStack() - return v - } - - // Factory - type TSimpleJSONProtocolFactory struct { - cfg *TConfiguration - } - - func (p *TSimpleJSONProtocolFactory) GetProtocol(trans TTransport) TProtocol { - return NewTSimpleJSONProtocolConf(trans, p.cfg) - } - - // SetTConfiguration implements TConfigurationSetter for propagation. - func (p *TSimpleJSONProtocolFactory) SetTConfiguration(conf *TConfiguration) { - p.cfg = conf - } - - // Deprecated: Use NewTSimpleJSONProtocolFactoryConf instead. - func NewTSimpleJSONProtocolFactory() *TSimpleJSONProtocolFactory { - return &TSimpleJSONProtocolFactory{ - cfg: &TConfiguration{ - noPropagation: true, - }, - } - } - - func NewTSimpleJSONProtocolFactoryConf(conf *TConfiguration) *TSimpleJSONProtocolFactory { - return &TSimpleJSONProtocolFactory{ - cfg: conf, - } - } - - var ( - JSON_COMMA []byte - JSON_COLON []byte - JSON_LBRACE []byte - JSON_RBRACE []byte - JSON_LBRACKET []byte - JSON_RBRACKET []byte - JSON_QUOTE byte - JSON_QUOTE_BYTES []byte - JSON_NULL []byte - JSON_TRUE []byte - JSON_FALSE []byte - JSON_INFINITY string - JSON_NEGATIVE_INFINITY string - JSON_NAN string - JSON_INFINITY_BYTES []byte - JSON_NEGATIVE_INFINITY_BYTES []byte - JSON_NAN_BYTES []byte - ) - - func init() { - JSON_COMMA = []byte{','} - JSON_COLON = []byte{':'} - JSON_LBRACE = []byte{'{'} - JSON_RBRACE = []byte{'}'} - JSON_LBRACKET = []byte{'['} - JSON_RBRACKET = []byte{']'} - JSON_QUOTE = '"' - JSON_QUOTE_BYTES = []byte{'"'} - JSON_NULL = []byte{'n', 'u', 'l', 'l'} - JSON_TRUE = []byte{'t', 'r', 'u', 'e'} - JSON_FALSE = []byte{'f', 'a', 'l', 's', 'e'} - JSON_INFINITY = "Infinity" - JSON_NEGATIVE_INFINITY = "-Infinity" - JSON_NAN = "NaN" - JSON_INFINITY_BYTES = []byte{'I', 'n', 'f', 'i', 'n', 'i', 't', 'y'} - JSON_NEGATIVE_INFINITY_BYTES = []byte{'-', 'I', 'n', 'f', 'i', 'n', 'i', 't', 'y'} - JSON_NAN_BYTES = []byte{'N', 'a', 'N'} - } - - func jsonQuote(s string) string { - b, _ := json.Marshal(s) - s1 := string(b) - return s1 - } - - func jsonUnquote(s string) (string, bool) { - s1 := new(string) - err := json.Unmarshal([]byte(s), s1) - return *s1, err == nil - } - - func mismatch(expected, actual string) error { - return fmt.Errorf("Expected '%s' but found '%s' while parsing JSON.", expected, actual) - } - - func (p *TSimpleJSONProtocol) WriteMessageBegin(ctx context.Context, name string, typeId TMessageType, seqId int32) error { - p.resetContextStack() // THRIFT-3735 - if e := p.OutputListBegin(); e != nil { - return e - } - if e := p.WriteString(ctx, name); e != nil { - return e - } - if e := p.WriteByte(ctx, int8(typeId)); e != nil { - return e - } - if e := p.WriteI32(ctx, seqId); e != nil { - return e - } - return nil - } - - func (p *TSimpleJSONProtocol) WriteMessageEnd(ctx context.Context) error { - return p.OutputListEnd() - } - - func (p *TSimpleJSONProtocol) WriteStructBegin(ctx context.Context, name string) error { - if e := p.OutputObjectBegin(); e != nil { - return e - } - return nil - } - - func (p *TSimpleJSONProtocol) WriteStructEnd(ctx context.Context) error { - return p.OutputObjectEnd() - } - - func (p *TSimpleJSONProtocol) WriteFieldBegin(ctx context.Context, name string, typeId TType, id int16) error { - if e := p.WriteString(ctx, name); e != nil { - return e - } - return nil - } - - func (p *TSimpleJSONProtocol) WriteFieldEnd(ctx context.Context) error { - return nil - } - - func (p *TSimpleJSONProtocol) WriteFieldStop(ctx context.Context) error { return nil } - - func (p *TSimpleJSONProtocol) WriteMapBegin(ctx context.Context, keyType TType, valueType TType, size int) error { - if e := p.OutputListBegin(); e != nil { - return e - } - if e := p.WriteByte(ctx, int8(keyType)); e != nil { - return e - } - if e := p.WriteByte(ctx, int8(valueType)); e != nil { - return e - } - return p.WriteI32(ctx, int32(size)) - } - - func (p *TSimpleJSONProtocol) WriteMapEnd(ctx context.Context) error { - return p.OutputListEnd() - } - - func (p *TSimpleJSONProtocol) WriteListBegin(ctx context.Context, elemType TType, size int) error { - return p.OutputElemListBegin(elemType, size) - } - - func (p *TSimpleJSONProtocol) WriteListEnd(ctx context.Context) error { - return p.OutputListEnd() - } - - func (p *TSimpleJSONProtocol) WriteSetBegin(ctx context.Context, elemType TType, size int) error { - return p.OutputElemListBegin(elemType, size) - } - - func (p *TSimpleJSONProtocol) WriteSetEnd(ctx context.Context) error { - return p.OutputListEnd() - } - - func (p *TSimpleJSONProtocol) WriteBool(ctx context.Context, b bool) error { - return p.OutputBool(b) - } - - func (p *TSimpleJSONProtocol) WriteByte(ctx context.Context, b int8) error { - return p.WriteI32(ctx, int32(b)) - } - - func (p *TSimpleJSONProtocol) WriteI16(ctx context.Context, v int16) error { - return p.WriteI32(ctx, int32(v)) - } - - func (p *TSimpleJSONProtocol) WriteI32(ctx context.Context, v int32) error { - return p.OutputI64(int64(v)) - } - - func (p *TSimpleJSONProtocol) WriteI64(ctx context.Context, v int64) error { - return p.OutputI64(int64(v)) - } - - func (p *TSimpleJSONProtocol) WriteDouble(ctx context.Context, v float64) error { - return p.OutputF64(v) - } - - func (p *TSimpleJSONProtocol) WriteString(ctx context.Context, v string) error { - return p.OutputString(v) - } - - func (p *TSimpleJSONProtocol) WriteBinary(ctx context.Context, v []byte) error { - // JSON library only takes in a string, - // not an arbitrary byte array, to ensure bytes are transmitted - // efficiently we must convert this into a valid JSON string - // therefore we use base64 encoding to avoid excessive escaping/quoting - if e := p.OutputPreValue(); e != nil { - return e - } - if _, e := p.write(JSON_QUOTE_BYTES); e != nil { - return NewTProtocolException(e) - } - writer := base64.NewEncoder(base64.StdEncoding, p.writer) - if _, e := writer.Write(v); e != nil { - p.writer.Reset(p.trans) // THRIFT-3735 - return NewTProtocolException(e) - } - if e := writer.Close(); e != nil { - return NewTProtocolException(e) - } - if _, e := p.write(JSON_QUOTE_BYTES); e != nil { - return NewTProtocolException(e) - } - return p.OutputPostValue() - } - - func (p *TSimpleJSONProtocol) WriteUUID(ctx context.Context, v Tuuid) error { - return p.OutputString(v.String()) - } - - // Reading methods. - func (p *TSimpleJSONProtocol) ReadMessageBegin(ctx context.Context) (name string, typeId TMessageType, seqId int32, err error) { - p.resetContextStack() // THRIFT-3735 - if isNull, err := p.ParseListBegin(); isNull || err != nil { - return name, typeId, seqId, err - } - if name, err = p.ReadString(ctx); err != nil { - return name, typeId, seqId, err - } - bTypeId, err := p.ReadByte(ctx) - typeId = TMessageType(bTypeId) - if err != nil { - return name, typeId, seqId, err - } - if seqId, err = p.ReadI32(ctx); err != nil { - return name, typeId, seqId, err - } - return name, typeId, seqId, nil - } - - func (p *TSimpleJSONProtocol) ReadMessageEnd(ctx context.Context) error { - return p.ParseListEnd() - } - - func (p *TSimpleJSONProtocol) ReadStructBegin(ctx context.Context) (name string, err error) { - _, err = p.ParseObjectStart() - return "", err - } - - func (p *TSimpleJSONProtocol) ReadStructEnd(ctx context.Context) error { - return p.ParseObjectEnd() - } - - func (p *TSimpleJSONProtocol) ReadFieldBegin(ctx context.Context) (string, TType, int16, error) { - if err := p.ParsePreValue(); err != nil { - return "", STOP, 0, err - } - b, _ := p.reader.Peek(1) - if len(b) > 0 { - switch b[0] { - case JSON_RBRACE[0]: - return "", STOP, 0, nil - case JSON_QUOTE: - p.reader.ReadByte() - name, err := p.ParseStringBody() - // simplejson is not meant to be read back into thrift - // - see http://wiki.apache.org/thrift/ThriftUsageJava - // - use JSON instead - if err != nil { - return name, STOP, 0, err - } - return name, STOP, -1, p.ParsePostValue() - } - e := fmt.Errorf("Expected \"}\" or '\"', but found: '%s'", string(b)) - return "", STOP, 0, NewTProtocolExceptionWithType(INVALID_DATA, e) - } - return "", STOP, 0, NewTProtocolException(io.EOF) - } - - func (p *TSimpleJSONProtocol) ReadFieldEnd(ctx context.Context) error { - return nil - } - - func (p *TSimpleJSONProtocol) ReadMapBegin(ctx context.Context) (keyType TType, valueType TType, size int, e error) { - if isNull, e := p.ParseListBegin(); isNull || e != nil { - return VOID, VOID, 0, e - } - - // read keyType - bKeyType, e := p.ReadByte(ctx) - keyType = TType(bKeyType) - if e != nil { - return keyType, valueType, size, e - } - - // read valueType - bValueType, e := p.ReadByte(ctx) - valueType = TType(bValueType) - if e != nil { - return keyType, valueType, size, e - } - - // read size - iSize, err := p.ReadI64(ctx) - if err != nil { - return keyType, valueType, 0, err - } - err = checkSizeForProtocol(int32(size), p.cfg) - if err != nil { - return keyType, valueType, 0, err - } - size = int(iSize) - return keyType, valueType, size, err - } - - func (p *TSimpleJSONProtocol) ReadMapEnd(ctx context.Context) error { - return p.ParseListEnd() - } - - func (p *TSimpleJSONProtocol) ReadListBegin(ctx context.Context) (elemType TType, size int, e error) { - return p.ParseElemListBegin() - } - - func (p *TSimpleJSONProtocol) ReadListEnd(ctx context.Context) error { - return p.ParseListEnd() - } - - func (p *TSimpleJSONProtocol) ReadSetBegin(ctx context.Context) (elemType TType, size int, e error) { - return p.ParseElemListBegin() - } - - func (p *TSimpleJSONProtocol) ReadSetEnd(ctx context.Context) error { - return p.ParseListEnd() - } - - func (p *TSimpleJSONProtocol) ReadBool(ctx context.Context) (bool, error) { - var value bool - - if err := p.ParsePreValue(); err != nil { - return value, err - } - f, _ := p.reader.Peek(1) - if len(f) > 0 { - switch f[0] { - case JSON_TRUE[0]: - b := make([]byte, len(JSON_TRUE)) - _, err := p.reader.Read(b) - if err != nil { - return false, NewTProtocolException(err) - } - if string(b) == string(JSON_TRUE) { - value = true - } else { - e := fmt.Errorf("Expected \"true\" but found: %s", string(b)) - return value, NewTProtocolExceptionWithType(INVALID_DATA, e) - } - case JSON_FALSE[0]: - b := make([]byte, len(JSON_FALSE)) - _, err := p.reader.Read(b) - if err != nil { - return false, NewTProtocolException(err) - } - if string(b) == string(JSON_FALSE) { - value = false - } else { - e := fmt.Errorf("Expected \"false\" but found: %s", string(b)) - return value, NewTProtocolExceptionWithType(INVALID_DATA, e) - } - case JSON_NULL[0]: - b := make([]byte, len(JSON_NULL)) - _, err := p.reader.Read(b) - if err != nil { - return false, NewTProtocolException(err) - } - if string(b) == string(JSON_NULL) { - value = false - } else { - e := fmt.Errorf("Expected \"null\" but found: %s", string(b)) - return value, NewTProtocolExceptionWithType(INVALID_DATA, e) - } - default: - e := fmt.Errorf("Expected \"true\", \"false\", or \"null\" but found: %s", string(f)) - return value, NewTProtocolExceptionWithType(INVALID_DATA, e) - } - } - return value, p.ParsePostValue() - } - - func (p *TSimpleJSONProtocol) ReadByte(ctx context.Context) (int8, error) { - v, err := p.ReadI64(ctx) - return int8(v), err - } - - func (p *TSimpleJSONProtocol) ReadI16(ctx context.Context) (int16, error) { - v, err := p.ReadI64(ctx) - return int16(v), err - } - - func (p *TSimpleJSONProtocol) ReadI32(ctx context.Context) (int32, error) { - v, err := p.ReadI64(ctx) - return int32(v), err - } - - func (p *TSimpleJSONProtocol) ReadI64(ctx context.Context) (int64, error) { - v, _, err := p.ParseI64() - return v, err - } - - func (p *TSimpleJSONProtocol) ReadDouble(ctx context.Context) (float64, error) { - v, _, err := p.ParseF64() - return v, err - } - - func (p *TSimpleJSONProtocol) ReadString(ctx context.Context) (string, error) { - var v string - if err := p.ParsePreValue(); err != nil { - return v, err - } - f, _ := p.reader.Peek(1) - if len(f) > 0 && f[0] == JSON_QUOTE { - p.reader.ReadByte() - value, err := p.ParseStringBody() - v = value - if err != nil { - return v, err - } - } else if len(f) > 0 && f[0] == JSON_NULL[0] { - b := make([]byte, len(JSON_NULL)) - _, err := p.reader.Read(b) - if err != nil { - return v, NewTProtocolException(err) - } - if string(b) != string(JSON_NULL) { - e := fmt.Errorf("Expected a JSON string, found unquoted data started with %s", string(b)) - return v, NewTProtocolExceptionWithType(INVALID_DATA, e) - } - } else { - e := fmt.Errorf("Expected a JSON string, found unquoted data started with %s", string(f)) - return v, NewTProtocolExceptionWithType(INVALID_DATA, e) - } - return v, p.ParsePostValue() - } - - func (p *TSimpleJSONProtocol) ReadBinary(ctx context.Context) ([]byte, error) { - var v []byte - if err := p.ParsePreValue(); err != nil { - return nil, err - } - f, _ := p.reader.Peek(1) - if len(f) > 0 && f[0] == JSON_QUOTE { - p.reader.ReadByte() - value, err := p.ParseBase64EncodedBody() - v = value - if err != nil { - return v, err - } - } else if len(f) > 0 && f[0] == JSON_NULL[0] { - b := make([]byte, len(JSON_NULL)) - _, err := p.reader.Read(b) - if err != nil { - return v, NewTProtocolException(err) - } - if string(b) != string(JSON_NULL) { - e := fmt.Errorf("Expected a JSON string, found unquoted data started with %s", string(b)) - return v, NewTProtocolExceptionWithType(INVALID_DATA, e) - } - } else { - e := fmt.Errorf("Expected a JSON string, found unquoted data started with %s", string(f)) - return v, NewTProtocolExceptionWithType(INVALID_DATA, e) - } - - return v, p.ParsePostValue() - } - - func (p *TSimpleJSONProtocol) ReadUUID(ctx context.Context) (v Tuuid, err error) { - var s string - s, err = p.ReadString(ctx) - if err != nil { - return v, err - } - v, err = ParseTuuid(s) - return v, NewTProtocolExceptionWithType(INVALID_DATA, err) - } - - func (p *TSimpleJSONProtocol) Flush(ctx context.Context) (err error) { - return NewTProtocolException(p.writer.Flush()) - } - - func (p *TSimpleJSONProtocol) Skip(ctx context.Context, fieldType TType) (err error) { - return SkipDefaultDepth(ctx, p, fieldType) - } - - func (p *TSimpleJSONProtocol) Transport() TTransport { - return p.trans - } - - func (p *TSimpleJSONProtocol) OutputPreValue() error { - cxt, ok := p.dumpContext.peek() - if !ok { - return errEmptyJSONContextStack - } - switch cxt { - case _CONTEXT_IN_LIST, _CONTEXT_IN_OBJECT_NEXT_KEY: - if _, e := p.write(JSON_COMMA); e != nil { - return NewTProtocolException(e) - } - case _CONTEXT_IN_OBJECT_NEXT_VALUE: - if _, e := p.write(JSON_COLON); e != nil { - return NewTProtocolException(e) - } - } - return nil - } - - func (p *TSimpleJSONProtocol) OutputPostValue() error { - cxt, ok := p.dumpContext.peek() - if !ok { - return errEmptyJSONContextStack - } - switch cxt { - case _CONTEXT_IN_LIST_FIRST: - p.dumpContext.pop() - p.dumpContext.push(_CONTEXT_IN_LIST) - case _CONTEXT_IN_OBJECT_FIRST: - p.dumpContext.pop() - p.dumpContext.push(_CONTEXT_IN_OBJECT_NEXT_VALUE) - case _CONTEXT_IN_OBJECT_NEXT_KEY: - p.dumpContext.pop() - p.dumpContext.push(_CONTEXT_IN_OBJECT_NEXT_VALUE) - case _CONTEXT_IN_OBJECT_NEXT_VALUE: - p.dumpContext.pop() - p.dumpContext.push(_CONTEXT_IN_OBJECT_NEXT_KEY) - } - return nil - } - - func (p *TSimpleJSONProtocol) OutputBool(value bool) error { - if e := p.OutputPreValue(); e != nil { - return e - } - var v string - if value { - v = string(JSON_TRUE) - } else { - v = string(JSON_FALSE) - } - cxt, ok := p.dumpContext.peek() - if !ok { - return errEmptyJSONContextStack - } - switch cxt { - case _CONTEXT_IN_OBJECT_FIRST, _CONTEXT_IN_OBJECT_NEXT_KEY: - v = jsonQuote(v) - } - if e := p.OutputStringData(v); e != nil { - return e - } - return p.OutputPostValue() - } - - func (p *TSimpleJSONProtocol) OutputNull() error { - if e := p.OutputPreValue(); e != nil { - return e - } - if _, e := p.write(JSON_NULL); e != nil { - return NewTProtocolException(e) - } - return p.OutputPostValue() - } - - func (p *TSimpleJSONProtocol) OutputF64(value float64) error { - if e := p.OutputPreValue(); e != nil { - return e - } - var v string - if math.IsNaN(value) { - v = string(JSON_QUOTE) + JSON_NAN + string(JSON_QUOTE) - } else if math.IsInf(value, 1) { - v = string(JSON_QUOTE) + JSON_INFINITY + string(JSON_QUOTE) - } else if math.IsInf(value, -1) { - v = string(JSON_QUOTE) + JSON_NEGATIVE_INFINITY + string(JSON_QUOTE) - } else { - cxt, ok := p.dumpContext.peek() - if !ok { - return errEmptyJSONContextStack - } - v = strconv.FormatFloat(value, 'g', -1, 64) - switch cxt { - case _CONTEXT_IN_OBJECT_FIRST, _CONTEXT_IN_OBJECT_NEXT_KEY: - v = string(JSON_QUOTE) + v + string(JSON_QUOTE) - } - } - if e := p.OutputStringData(v); e != nil { - return e - } - return p.OutputPostValue() - } - - func (p *TSimpleJSONProtocol) OutputI64(value int64) error { - if e := p.OutputPreValue(); e != nil { - return e - } - cxt, ok := p.dumpContext.peek() - if !ok { - return errEmptyJSONContextStack - } - v := strconv.FormatInt(value, 10) - switch cxt { - case _CONTEXT_IN_OBJECT_FIRST, _CONTEXT_IN_OBJECT_NEXT_KEY: - v = jsonQuote(v) - } - if e := p.OutputStringData(v); e != nil { - return e - } - return p.OutputPostValue() - } - - func (p *TSimpleJSONProtocol) OutputString(s string) error { - if e := p.OutputPreValue(); e != nil { - return e - } - if e := p.OutputStringData(jsonQuote(s)); e != nil { - return e - } - return p.OutputPostValue() - } - - func (p *TSimpleJSONProtocol) OutputStringData(s string) error { - _, e := p.write([]byte(s)) - return NewTProtocolException(e) - } - - func (p *TSimpleJSONProtocol) OutputObjectBegin() error { - if e := p.OutputPreValue(); e != nil { - return e - } - if _, e := p.write(JSON_LBRACE); e != nil { - return NewTProtocolException(e) - } - p.dumpContext.push(_CONTEXT_IN_OBJECT_FIRST) - return nil - } - - func (p *TSimpleJSONProtocol) OutputObjectEnd() error { - if _, e := p.write(JSON_RBRACE); e != nil { - return NewTProtocolException(e) - } - _, ok := p.dumpContext.pop() - if !ok { - return errEmptyJSONContextStack - } - if e := p.OutputPostValue(); e != nil { - return e - } - return nil - } - - func (p *TSimpleJSONProtocol) OutputListBegin() error { - if e := p.OutputPreValue(); e != nil { - return e - } - if _, e := p.write(JSON_LBRACKET); e != nil { - return NewTProtocolException(e) - } - p.dumpContext.push(_CONTEXT_IN_LIST_FIRST) - return nil - } - - func (p *TSimpleJSONProtocol) OutputListEnd() error { - if _, e := p.write(JSON_RBRACKET); e != nil { - return NewTProtocolException(e) - } - _, ok := p.dumpContext.pop() - if !ok { - return errEmptyJSONContextStack - } - if e := p.OutputPostValue(); e != nil { - return e - } - return nil - } - - func (p *TSimpleJSONProtocol) OutputElemListBegin(elemType TType, size int) error { - if e := p.OutputListBegin(); e != nil { - return e - } - if e := p.OutputI64(int64(elemType)); e != nil { - return e - } - if e := p.OutputI64(int64(size)); e != nil { - return e - } - return nil - } - - func (p *TSimpleJSONProtocol) ParsePreValue() error { - if e := p.readNonSignificantWhitespace(); e != nil { - return NewTProtocolException(e) - } - cxt, ok := p.parseContextStack.peek() - if !ok { - return errEmptyJSONContextStack - } - b, _ := p.reader.Peek(1) - switch cxt { - case _CONTEXT_IN_LIST: - if len(b) > 0 { - switch b[0] { - case JSON_RBRACKET[0]: - return nil - case JSON_COMMA[0]: - p.reader.ReadByte() - if e := p.readNonSignificantWhitespace(); e != nil { - return NewTProtocolException(e) - } - return nil - default: - e := fmt.Errorf("Expected \"]\" or \",\" in list context, but found \"%s\"", string(b)) - return NewTProtocolExceptionWithType(INVALID_DATA, e) - } - } - case _CONTEXT_IN_OBJECT_NEXT_KEY: - if len(b) > 0 { - switch b[0] { - case JSON_RBRACE[0]: - return nil - case JSON_COMMA[0]: - p.reader.ReadByte() - if e := p.readNonSignificantWhitespace(); e != nil { - return NewTProtocolException(e) - } - return nil - default: - e := fmt.Errorf("Expected \"}\" or \",\" in object context, but found \"%s\"", string(b)) - return NewTProtocolExceptionWithType(INVALID_DATA, e) - } - } - case _CONTEXT_IN_OBJECT_NEXT_VALUE: - if len(b) > 0 { - switch b[0] { - case JSON_COLON[0]: - p.reader.ReadByte() - if e := p.readNonSignificantWhitespace(); e != nil { - return NewTProtocolException(e) - } - return nil - default: - e := fmt.Errorf("Expected \":\" in object context, but found \"%s\"", string(b)) - return NewTProtocolExceptionWithType(INVALID_DATA, e) - } - } - } - return nil - } - - func (p *TSimpleJSONProtocol) ParsePostValue() error { - if e := p.readNonSignificantWhitespace(); e != nil { - return NewTProtocolException(e) - } - cxt, ok := p.parseContextStack.peek() - if !ok { - return errEmptyJSONContextStack - } - switch cxt { - case _CONTEXT_IN_LIST_FIRST: - p.parseContextStack.pop() - p.parseContextStack.push(_CONTEXT_IN_LIST) - case _CONTEXT_IN_OBJECT_FIRST, _CONTEXT_IN_OBJECT_NEXT_KEY: - p.parseContextStack.pop() - p.parseContextStack.push(_CONTEXT_IN_OBJECT_NEXT_VALUE) - case _CONTEXT_IN_OBJECT_NEXT_VALUE: - p.parseContextStack.pop() - p.parseContextStack.push(_CONTEXT_IN_OBJECT_NEXT_KEY) - } - return nil - } - - func (p *TSimpleJSONProtocol) readNonSignificantWhitespace() error { - for { - b, _ := p.reader.Peek(1) - if len(b) < 1 { - return nil - } - switch b[0] { - case ' ', '\r', '\n', '\t': - p.reader.ReadByte() - continue - } - break - } - return nil - } - - func (p *TSimpleJSONProtocol) ParseStringBody() (string, error) { - line, err := p.reader.ReadString(JSON_QUOTE) - if err != nil { - return "", NewTProtocolException(err) - } - if endsWithoutEscapedQuote(line) { - v, ok := jsonUnquote(string(JSON_QUOTE) + line) - if !ok { - return "", NewTProtocolException(err) - } - return v, nil - } - s, err := p.ParseQuotedStringBody() - if err != nil { - return "", NewTProtocolException(err) - } - str := string(JSON_QUOTE) + line + s - v, ok := jsonUnquote(str) - if !ok { - e := fmt.Errorf("Unable to parse as JSON string %s", str) - return "", NewTProtocolExceptionWithType(INVALID_DATA, e) - } - return v, nil - } - - func (p *TSimpleJSONProtocol) ParseQuotedStringBody() (string, error) { - var sb strings.Builder - - for { - line, err := p.reader.ReadString(JSON_QUOTE) - if err != nil { - return "", NewTProtocolException(err) - } - sb.WriteString(line) - if endsWithoutEscapedQuote(line) { - return sb.String(), nil - } - } - } - - func endsWithoutEscapedQuote(s string) bool { - l := len(s) - i := 1 - for ; i < l; i++ { - if s[l-i-1] != '\\' { - break - } - } - return i&0x01 == 1 - } - - func (p *TSimpleJSONProtocol) ParseBase64EncodedBody() ([]byte, error) { - line, err := p.reader.ReadBytes(JSON_QUOTE) - if err != nil { - return line, NewTProtocolException(err) - } - line2 := line[0 : len(line)-1] - l := len(line2) - if (l % 4) != 0 { - pad := 4 - (l % 4) - fill := [...]byte{'=', '=', '='} - line2 = append(line2, fill[:pad]...) - l = len(line2) - } - output := make([]byte, base64.StdEncoding.DecodedLen(l)) - n, err := base64.StdEncoding.Decode(output, line2) - return output[0:n], NewTProtocolException(err) - } - - func (p *TSimpleJSONProtocol) ParseI64() (int64, bool, error) { - if err := p.ParsePreValue(); err != nil { - return 0, false, err - } - var value int64 - var isnull bool - if p.safePeekContains(JSON_NULL) { - p.reader.Read(make([]byte, len(JSON_NULL))) - isnull = true - } else { - num, err := p.readNumeric() - isnull = (num == nil) - if !isnull { - value = num.Int64() - } - if err != nil { - return value, isnull, err - } - } - return value, isnull, p.ParsePostValue() - } - - func (p *TSimpleJSONProtocol) ParseF64() (float64, bool, error) { - if err := p.ParsePreValue(); err != nil { - return 0, false, err - } - var value float64 - var isnull bool - if p.safePeekContains(JSON_NULL) { - p.reader.Read(make([]byte, len(JSON_NULL))) - isnull = true - } else { - num, err := p.readNumeric() - isnull = (num == nil) - if !isnull { - value = num.Float64() - } - if err != nil { - return value, isnull, err - } - } - return value, isnull, p.ParsePostValue() - } - - func (p *TSimpleJSONProtocol) ParseObjectStart() (bool, error) { - if err := p.ParsePreValue(); err != nil { - return false, err - } - var b []byte - b, err := p.reader.Peek(1) - if err != nil { - return false, err - } - if len(b) > 0 && b[0] == JSON_LBRACE[0] { - p.reader.ReadByte() - p.parseContextStack.push(_CONTEXT_IN_OBJECT_FIRST) - return false, nil - } else if p.safePeekContains(JSON_NULL) { - return true, nil - } - e := fmt.Errorf("Expected '{' or null, but found '%s'", string(b)) - return false, NewTProtocolExceptionWithType(INVALID_DATA, e) - } - - func (p *TSimpleJSONProtocol) ParseObjectEnd() error { - if isNull, err := p.readIfNull(); isNull || err != nil { - return err - } - cxt, _ := p.parseContextStack.peek() - if (cxt != _CONTEXT_IN_OBJECT_FIRST) && (cxt != _CONTEXT_IN_OBJECT_NEXT_KEY) { - e := fmt.Errorf("Expected to be in the Object Context, but not in Object Context (%d)", cxt) - return NewTProtocolExceptionWithType(INVALID_DATA, e) - } - line, err := p.reader.ReadString(JSON_RBRACE[0]) - if err != nil { - return NewTProtocolException(err) - } - for _, char := range line { - switch char { - default: - e := fmt.Errorf("Expecting end of object \"}\", but found: \"%s\"", line) - return NewTProtocolExceptionWithType(INVALID_DATA, e) - case ' ', '\n', '\r', '\t', '}': - // do nothing - } - } - p.parseContextStack.pop() - return p.ParsePostValue() - } - - func (p *TSimpleJSONProtocol) ParseListBegin() (isNull bool, err error) { - if e := p.ParsePreValue(); e != nil { - return false, e - } - var b []byte - b, err = p.reader.Peek(1) - if err != nil { - return false, err - } - if len(b) >= 1 && b[0] == JSON_LBRACKET[0] { - p.parseContextStack.push(_CONTEXT_IN_LIST_FIRST) - p.reader.ReadByte() - isNull = false - } else if p.safePeekContains(JSON_NULL) { - isNull = true - } else { - err = fmt.Errorf("Expected \"null\" or \"[\", received %q", b) - } - return isNull, NewTProtocolExceptionWithType(INVALID_DATA, err) - } - - func (p *TSimpleJSONProtocol) ParseElemListBegin() (elemType TType, size int, e error) { - if isNull, e := p.ParseListBegin(); isNull || e != nil { - return VOID, 0, e - } - bElemType, _, err := p.ParseI64() - elemType = TType(bElemType) - if err != nil { - return elemType, size, err - } - nSize, _, err := p.ParseI64() - if err != nil { - return elemType, 0, err - } - err = checkSizeForProtocol(int32(nSize), p.cfg) - if err != nil { - return elemType, 0, err - } - size = int(nSize) - return elemType, size, nil - } - - func (p *TSimpleJSONProtocol) ParseListEnd() error { - if isNull, err := p.readIfNull(); isNull || err != nil { - return err - } - cxt, _ := p.parseContextStack.peek() - if cxt != _CONTEXT_IN_LIST { - e := fmt.Errorf("Expected to be in the List Context, but not in List Context (%d)", cxt) - return NewTProtocolExceptionWithType(INVALID_DATA, e) - } - line, err := p.reader.ReadString(JSON_RBRACKET[0]) - if err != nil { - return NewTProtocolException(err) - } - for _, char := range line { - switch char { - default: - e := fmt.Errorf("Expecting end of list \"]\", but found: \"%v\"", line) - return NewTProtocolExceptionWithType(INVALID_DATA, e) - case ' ', '\n', '\r', '\t', rune(JSON_RBRACKET[0]): - // do nothing - } - } - p.parseContextStack.pop() - if cxt, ok := p.parseContextStack.peek(); !ok { - return errEmptyJSONContextStack - } else if cxt == _CONTEXT_IN_TOPLEVEL { - return nil - } - return p.ParsePostValue() - } - - func (p *TSimpleJSONProtocol) readIfNull() (bool, error) { - cont := true - for cont { - b, _ := p.reader.Peek(1) - if len(b) < 1 { - return false, nil - } - switch b[0] { - default: - return false, nil - case JSON_NULL[0]: - cont = false - case ' ', '\n', '\r', '\t': - p.reader.ReadByte() - } - } - if p.safePeekContains(JSON_NULL) { - p.reader.Read(make([]byte, len(JSON_NULL))) - return true, nil - } - return false, nil - } - - func (p *TSimpleJSONProtocol) readQuoteIfNext() { - b, _ := p.reader.Peek(1) - if len(b) > 0 && b[0] == JSON_QUOTE { - p.reader.ReadByte() - } - } - - func (p *TSimpleJSONProtocol) readNumeric() (Numeric, error) { - isNull, err := p.readIfNull() - if isNull || err != nil { - return NUMERIC_NULL, err - } - hasDecimalPoint := false - nextCanBeSign := true - hasE := false - MAX_LEN := 40 - buf := bytes.NewBuffer(make([]byte, 0, MAX_LEN)) - continueFor := true - inQuotes := false - for continueFor { - c, err := p.reader.ReadByte() - if err != nil { - if err == io.EOF { - break - } - return NUMERIC_NULL, NewTProtocolException(err) - } - switch c { - case '0', '1', '2', '3', '4', '5', '6', '7', '8', '9': - buf.WriteByte(c) - nextCanBeSign = false - case '.': - if hasDecimalPoint { - e := fmt.Errorf("Unable to parse number with multiple decimal points '%s.'", buf.String()) - return NUMERIC_NULL, NewTProtocolExceptionWithType(INVALID_DATA, e) - } - if hasE { - e := fmt.Errorf("Unable to parse number with decimal points in the exponent '%s.'", buf.String()) - return NUMERIC_NULL, NewTProtocolExceptionWithType(INVALID_DATA, e) - } - buf.WriteByte(c) - hasDecimalPoint, nextCanBeSign = true, false - case 'e', 'E': - if hasE { - e := fmt.Errorf("Unable to parse number with multiple exponents '%s%c'", buf.String(), c) - return NUMERIC_NULL, NewTProtocolExceptionWithType(INVALID_DATA, e) - } - buf.WriteByte(c) - hasE, nextCanBeSign = true, true - case '-', '+': - if !nextCanBeSign { - e := fmt.Errorf("Negative sign within number") - return NUMERIC_NULL, NewTProtocolExceptionWithType(INVALID_DATA, e) - } - buf.WriteByte(c) - nextCanBeSign = false - case ' ', 0, '\t', '\n', '\r', JSON_RBRACE[0], JSON_RBRACKET[0], JSON_COMMA[0], JSON_COLON[0]: - p.reader.UnreadByte() - continueFor = false - case JSON_NAN[0]: - if buf.Len() == 0 { - buffer := make([]byte, len(JSON_NAN)) - buffer[0] = c - _, e := p.reader.Read(buffer[1:]) - if e != nil { - return NUMERIC_NULL, NewTProtocolException(e) - } - if JSON_NAN != string(buffer) { - e := mismatch(JSON_NAN, string(buffer)) - return NUMERIC_NULL, NewTProtocolExceptionWithType(INVALID_DATA, e) - } - if inQuotes { - p.readQuoteIfNext() - } - return NAN, nil - } else { - e := fmt.Errorf("Unable to parse number starting with character '%c'", c) - return NUMERIC_NULL, NewTProtocolExceptionWithType(INVALID_DATA, e) - } - case JSON_INFINITY[0]: - if buf.Len() == 0 || (buf.Len() == 1 && buf.Bytes()[0] == '+') { - buffer := make([]byte, len(JSON_INFINITY)) - buffer[0] = c - _, e := p.reader.Read(buffer[1:]) - if e != nil { - return NUMERIC_NULL, NewTProtocolException(e) - } - if JSON_INFINITY != string(buffer) { - e := mismatch(JSON_INFINITY, string(buffer)) - return NUMERIC_NULL, NewTProtocolExceptionWithType(INVALID_DATA, e) - } - if inQuotes { - p.readQuoteIfNext() - } - return INFINITY, nil - } else if buf.Len() == 1 && buf.Bytes()[0] == JSON_NEGATIVE_INFINITY[0] { - buffer := make([]byte, len(JSON_NEGATIVE_INFINITY)) - buffer[0] = JSON_NEGATIVE_INFINITY[0] - buffer[1] = c - _, e := p.reader.Read(buffer[2:]) - if e != nil { - return NUMERIC_NULL, NewTProtocolException(e) - } - if JSON_NEGATIVE_INFINITY != string(buffer) { - e := mismatch(JSON_NEGATIVE_INFINITY, string(buffer)) - return NUMERIC_NULL, NewTProtocolExceptionWithType(INVALID_DATA, e) - } - if inQuotes { - p.readQuoteIfNext() - } - return NEGATIVE_INFINITY, nil - } else { - e := fmt.Errorf("Unable to parse number starting with character '%c' due to existing buffer %s", c, buf.String()) - return NUMERIC_NULL, NewTProtocolExceptionWithType(INVALID_DATA, e) - } - case JSON_QUOTE: - if !inQuotes { - inQuotes = true - } - default: - e := fmt.Errorf("Unable to parse number starting with character '%c'", c) - return NUMERIC_NULL, NewTProtocolExceptionWithType(INVALID_DATA, e) - } - } - if buf.Len() == 0 { - e := fmt.Errorf("Unable to parse number from empty string ''") - return NUMERIC_NULL, NewTProtocolExceptionWithType(INVALID_DATA, e) - } - return NewNumericFromJSONString(buf.String(), false), nil - } - - // Safely peeks into the buffer, reading only what is necessary - func (p *TSimpleJSONProtocol) safePeekContains(b []byte) bool { - for i := 0; i < len(b); i++ { - a, _ := p.reader.Peek(i + 1) - if len(a) < (i+1) || a[i] != b[i] { - return false - } - } - return true - } - - // Reset the context stack to its initial state. - func (p *TSimpleJSONProtocol) resetContextStack() { - p.parseContextStack = jsonContextStack{_CONTEXT_IN_TOPLEVEL} - p.dumpContext = jsonContextStack{_CONTEXT_IN_TOPLEVEL} - } - - func (p *TSimpleJSONProtocol) write(b []byte) (int, error) { - n, err := p.writer.Write(b) - if err != nil { - p.writer.Reset(p.trans) // THRIFT-3735 - } - return n, err - } - - // SetTConfiguration implements TConfigurationSetter for propagation. - func (p *TSimpleJSONProtocol) SetTConfiguration(conf *TConfiguration) { - PropagateTConfiguration(p.trans, conf) - p.cfg = conf - } - - // Reset resets this protocol's internal state. - // - // It's useful when a single protocol instance is reused after errors, to make - // sure the next use will not be in a bad state to begin with. An example is - // when it's used in serializer/deserializer pools. - func (p *TSimpleJSONProtocol) Reset() { - p.resetContextStack() - p.writer.Reset(p.trans) - p.reader.Reset(p.trans) - } - - var ( - _ TConfigurationSetter = (*TSimpleJSONProtocol)(nil) - _ TConfigurationSetter = (*TSimpleJSONProtocolFactory)(nil) - ) \ No newline at end of file + +type _ParseContext int + +const ( + _CONTEXT_INVALID _ParseContext = iota + _CONTEXT_IN_TOPLEVEL // 1 + _CONTEXT_IN_LIST_FIRST // 2 + _CONTEXT_IN_LIST // 3 + _CONTEXT_IN_OBJECT_FIRST // 4 + _CONTEXT_IN_OBJECT_NEXT_KEY // 5 + _CONTEXT_IN_OBJECT_NEXT_VALUE // 6 +) + +func (p _ParseContext) String() string { + switch p { + case _CONTEXT_IN_TOPLEVEL: + return "TOPLEVEL" + case _CONTEXT_IN_LIST_FIRST: + return "LIST-FIRST" + case _CONTEXT_IN_LIST: + return "LIST" + case _CONTEXT_IN_OBJECT_FIRST: + return "OBJECT-FIRST" + case _CONTEXT_IN_OBJECT_NEXT_KEY: + return "OBJECT-NEXT-KEY" + case _CONTEXT_IN_OBJECT_NEXT_VALUE: + return "OBJECT-NEXT-VALUE" + } + return "UNKNOWN-PARSE-CONTEXT" +} + +type jsonContextStack []_ParseContext + +func (s *jsonContextStack) push(v _ParseContext) { + *s = append(*s, v) +} + +func (s jsonContextStack) peek() (v _ParseContext, ok bool) { + l := len(s) + if l <= 0 { + return + } + return s[l-1], true +} + +func (s *jsonContextStack) pop() (v _ParseContext, ok bool) { + l := len(*s) + if l <= 0 { + return + } + v = (*s)[l-1] + *s = (*s)[0 : l-1] + return v, true +} + +var errEmptyJSONContextStack = NewTProtocolExceptionWithType(INVALID_DATA, errors.New("Unexpected empty json protocol context stack")) + +// Simple JSON protocol implementation for thrift. +// +// This protocol produces/consumes a simple output format +// suitable for parsing by scripting languages. It should not be +// confused with the full-featured TJSONProtocol. +type TSimpleJSONProtocol struct { + trans TTransport + + cfg *TConfiguration + + parseContextStack jsonContextStack + dumpContext jsonContextStack + + writer *bufio.Writer + reader *bufio.Reader +} + +// Deprecated: Use NewTSimpleJSONProtocolConf instead.: +func NewTSimpleJSONProtocol(t TTransport) *TSimpleJSONProtocol { + return NewTSimpleJSONProtocolConf(t, &TConfiguration{ + noPropagation: true, + }) +} + +func NewTSimpleJSONProtocolConf(t TTransport, conf *TConfiguration) *TSimpleJSONProtocol { + PropagateTConfiguration(t, conf) + v := &TSimpleJSONProtocol{ + trans: t, + cfg: conf, + writer: bufio.NewWriter(t), + reader: bufio.NewReader(t), + } + v.resetContextStack() + return v +} + +// Factory +type TSimpleJSONProtocolFactory struct { + cfg *TConfiguration +} + +func (p *TSimpleJSONProtocolFactory) GetProtocol(trans TTransport) TProtocol { + return NewTSimpleJSONProtocolConf(trans, p.cfg) +} + +// SetTConfiguration implements TConfigurationSetter for propagation. +func (p *TSimpleJSONProtocolFactory) SetTConfiguration(conf *TConfiguration) { + p.cfg = conf +} + +// Deprecated: Use NewTSimpleJSONProtocolFactoryConf instead. +func NewTSimpleJSONProtocolFactory() *TSimpleJSONProtocolFactory { + return &TSimpleJSONProtocolFactory{ + cfg: &TConfiguration{ + noPropagation: true, + }, + } +} + +func NewTSimpleJSONProtocolFactoryConf(conf *TConfiguration) *TSimpleJSONProtocolFactory { + return &TSimpleJSONProtocolFactory{ + cfg: conf, + } +} + +var ( + JSON_COMMA []byte + JSON_COLON []byte + JSON_LBRACE []byte + JSON_RBRACE []byte + JSON_LBRACKET []byte + JSON_RBRACKET []byte + JSON_QUOTE byte + JSON_QUOTE_BYTES []byte + JSON_NULL []byte + JSON_TRUE []byte + JSON_FALSE []byte + JSON_INFINITY string + JSON_NEGATIVE_INFINITY string + JSON_NAN string + JSON_INFINITY_BYTES []byte + JSON_NEGATIVE_INFINITY_BYTES []byte + JSON_NAN_BYTES []byte +) + +func init() { + JSON_COMMA = []byte{','} + JSON_COLON = []byte{':'} + JSON_LBRACE = []byte{'{'} + JSON_RBRACE = []byte{'}'} + JSON_LBRACKET = []byte{'['} + JSON_RBRACKET = []byte{']'} + JSON_QUOTE = '"' + JSON_QUOTE_BYTES = []byte{'"'} + JSON_NULL = []byte{'n', 'u', 'l', 'l'} + JSON_TRUE = []byte{'t', 'r', 'u', 'e'} + JSON_FALSE = []byte{'f', 'a', 'l', 's', 'e'} + JSON_INFINITY = "Infinity" + JSON_NEGATIVE_INFINITY = "-Infinity" + JSON_NAN = "NaN" + JSON_INFINITY_BYTES = []byte{'I', 'n', 'f', 'i', 'n', 'i', 't', 'y'} + JSON_NEGATIVE_INFINITY_BYTES = []byte{'-', 'I', 'n', 'f', 'i', 'n', 'i', 't', 'y'} + JSON_NAN_BYTES = []byte{'N', 'a', 'N'} +} + +func jsonQuote(s string) string { + b, _ := json.Marshal(s) + s1 := string(b) + return s1 +} + +func jsonUnquote(s string) (string, bool) { + s1 := new(string) + err := json.Unmarshal([]byte(s), s1) + return *s1, err == nil +} + +func mismatch(expected, actual string) error { + return fmt.Errorf("Expected '%s' but found '%s' while parsing JSON.", expected, actual) +} + +func (p *TSimpleJSONProtocol) WriteMessageBegin(ctx context.Context, name string, typeId TMessageType, seqId int32) error { + p.resetContextStack() // THRIFT-3735 + if e := p.OutputListBegin(); e != nil { + return e + } + if e := p.WriteString(ctx, name); e != nil { + return e + } + if e := p.WriteByte(ctx, int8(typeId)); e != nil { + return e + } + if e := p.WriteI32(ctx, seqId); e != nil { + return e + } + return nil +} + +func (p *TSimpleJSONProtocol) WriteMessageEnd(ctx context.Context) error { + return p.OutputListEnd() +} + +func (p *TSimpleJSONProtocol) WriteStructBegin(ctx context.Context, name string) error { + if e := p.OutputObjectBegin(); e != nil { + return e + } + return nil +} + +func (p *TSimpleJSONProtocol) WriteStructEnd(ctx context.Context) error { + return p.OutputObjectEnd() +} + +func (p *TSimpleJSONProtocol) WriteFieldBegin(ctx context.Context, name string, typeId TType, id int16) error { + if e := p.WriteString(ctx, name); e != nil { + return e + } + return nil +} + +func (p *TSimpleJSONProtocol) WriteFieldEnd(ctx context.Context) error { + return nil +} + +func (p *TSimpleJSONProtocol) WriteFieldStop(ctx context.Context) error { return nil } + +func (p *TSimpleJSONProtocol) WriteMapBegin(ctx context.Context, keyType TType, valueType TType, size int) error { + if e := p.OutputListBegin(); e != nil { + return e + } + if e := p.WriteByte(ctx, int8(keyType)); e != nil { + return e + } + if e := p.WriteByte(ctx, int8(valueType)); e != nil { + return e + } + return p.WriteI32(ctx, int32(size)) +} + +func (p *TSimpleJSONProtocol) WriteMapEnd(ctx context.Context) error { + return p.OutputListEnd() +} + +func (p *TSimpleJSONProtocol) WriteListBegin(ctx context.Context, elemType TType, size int) error { + return p.OutputElemListBegin(elemType, size) +} + +func (p *TSimpleJSONProtocol) WriteListEnd(ctx context.Context) error { + return p.OutputListEnd() +} + +func (p *TSimpleJSONProtocol) WriteSetBegin(ctx context.Context, elemType TType, size int) error { + return p.OutputElemListBegin(elemType, size) +} + +func (p *TSimpleJSONProtocol) WriteSetEnd(ctx context.Context) error { + return p.OutputListEnd() +} + +func (p *TSimpleJSONProtocol) WriteBool(ctx context.Context, b bool) error { + return p.OutputBool(b) +} + +func (p *TSimpleJSONProtocol) WriteByte(ctx context.Context, b int8) error { + return p.WriteI32(ctx, int32(b)) +} + +func (p *TSimpleJSONProtocol) WriteI16(ctx context.Context, v int16) error { + return p.WriteI32(ctx, int32(v)) +} + +func (p *TSimpleJSONProtocol) WriteI32(ctx context.Context, v int32) error { + return p.OutputI64(int64(v)) +} + +func (p *TSimpleJSONProtocol) WriteI64(ctx context.Context, v int64) error { + return p.OutputI64(int64(v)) +} + +func (p *TSimpleJSONProtocol) WriteDouble(ctx context.Context, v float64) error { + return p.OutputF64(v) +} + +func (p *TSimpleJSONProtocol) WriteString(ctx context.Context, v string) error { + return p.OutputString(v) +} + +func (p *TSimpleJSONProtocol) WriteBinary(ctx context.Context, v []byte) error { + // JSON library only takes in a string, + // not an arbitrary byte array, to ensure bytes are transmitted + // efficiently we must convert this into a valid JSON string + // therefore we use base64 encoding to avoid excessive escaping/quoting + if e := p.OutputPreValue(); e != nil { + return e + } + if _, e := p.write(JSON_QUOTE_BYTES); e != nil { + return NewTProtocolException(e) + } + writer := base64.NewEncoder(base64.StdEncoding, p.writer) + if _, e := writer.Write(v); e != nil { + p.writer.Reset(p.trans) // THRIFT-3735 + return NewTProtocolException(e) + } + if e := writer.Close(); e != nil { + return NewTProtocolException(e) + } + if _, e := p.write(JSON_QUOTE_BYTES); e != nil { + return NewTProtocolException(e) + } + return p.OutputPostValue() +} + +func (p *TSimpleJSONProtocol) WriteUUID(ctx context.Context, v Tuuid) error { + return p.OutputString(v.String()) +} + +// Reading methods. +func (p *TSimpleJSONProtocol) ReadMessageBegin(ctx context.Context) (name string, typeId TMessageType, seqId int32, err error) { + p.resetContextStack() // THRIFT-3735 + if isNull, err := p.ParseListBegin(); isNull || err != nil { + return name, typeId, seqId, err + } + if name, err = p.ReadString(ctx); err != nil { + return name, typeId, seqId, err + } + bTypeId, err := p.ReadByte(ctx) + typeId = TMessageType(bTypeId) + if err != nil { + return name, typeId, seqId, err + } + if seqId, err = p.ReadI32(ctx); err != nil { + return name, typeId, seqId, err + } + return name, typeId, seqId, nil +} + +func (p *TSimpleJSONProtocol) ReadMessageEnd(ctx context.Context) error { + return p.ParseListEnd() +} + +func (p *TSimpleJSONProtocol) ReadStructBegin(ctx context.Context) (name string, err error) { + _, err = p.ParseObjectStart() + return "", err +} + +func (p *TSimpleJSONProtocol) ReadStructEnd(ctx context.Context) error { + return p.ParseObjectEnd() +} + +func (p *TSimpleJSONProtocol) ReadFieldBegin(ctx context.Context) (string, TType, int16, error) { + if err := p.ParsePreValue(); err != nil { + return "", STOP, 0, err + } + b, _ := p.reader.Peek(1) + if len(b) > 0 { + switch b[0] { + case JSON_RBRACE[0]: + return "", STOP, 0, nil + case JSON_QUOTE: + p.reader.ReadByte() + name, err := p.ParseStringBody() + // simplejson is not meant to be read back into thrift + // - see http://wiki.apache.org/thrift/ThriftUsageJava + // - use JSON instead + if err != nil { + return name, STOP, 0, err + } + return name, STOP, -1, p.ParsePostValue() + } + e := fmt.Errorf("Expected \"}\" or '\"', but found: '%s'", string(b)) + return "", STOP, 0, NewTProtocolExceptionWithType(INVALID_DATA, e) + } + return "", STOP, 0, NewTProtocolException(io.EOF) +} + +func (p *TSimpleJSONProtocol) ReadFieldEnd(ctx context.Context) error { + return nil +} + +func (p *TSimpleJSONProtocol) ReadMapBegin(ctx context.Context) (keyType TType, valueType TType, size int, e error) { + if isNull, e := p.ParseListBegin(); isNull || e != nil { + return VOID, VOID, 0, e + } + + // read keyType + bKeyType, e := p.ReadByte(ctx) + keyType = TType(bKeyType) + if e != nil { + return keyType, valueType, size, e + } + + // read valueType + bValueType, e := p.ReadByte(ctx) + valueType = TType(bValueType) + if e != nil { + return keyType, valueType, size, e + } + + // read size + iSize, err := p.ReadI64(ctx) + if err != nil { + return keyType, valueType, 0, err + } + err = checkSizeForProtocol(int32(size), p.cfg) + if err != nil { + return keyType, valueType, 0, err + } + size = int(iSize) + return keyType, valueType, size, err +} + +func (p *TSimpleJSONProtocol) ReadMapEnd(ctx context.Context) error { + return p.ParseListEnd() +} + +func (p *TSimpleJSONProtocol) ReadListBegin(ctx context.Context) (elemType TType, size int, e error) { + return p.ParseElemListBegin() +} + +func (p *TSimpleJSONProtocol) ReadListEnd(ctx context.Context) error { + return p.ParseListEnd() +} + +func (p *TSimpleJSONProtocol) ReadSetBegin(ctx context.Context) (elemType TType, size int, e error) { + return p.ParseElemListBegin() +} + +func (p *TSimpleJSONProtocol) ReadSetEnd(ctx context.Context) error { + return p.ParseListEnd() +} + +func (p *TSimpleJSONProtocol) ReadBool(ctx context.Context) (bool, error) { + var value bool + + if err := p.ParsePreValue(); err != nil { + return value, err + } + f, _ := p.reader.Peek(1) + if len(f) > 0 { + switch f[0] { + case JSON_TRUE[0]: + b := make([]byte, len(JSON_TRUE)) + _, err := p.reader.Read(b) + if err != nil { + return false, NewTProtocolException(err) + } + if string(b) == string(JSON_TRUE) { + value = true + } else { + e := fmt.Errorf("Expected \"true\" but found: %s", string(b)) + return value, NewTProtocolExceptionWithType(INVALID_DATA, e) + } + case JSON_FALSE[0]: + b := make([]byte, len(JSON_FALSE)) + _, err := p.reader.Read(b) + if err != nil { + return false, NewTProtocolException(err) + } + if string(b) == string(JSON_FALSE) { + value = false + } else { + e := fmt.Errorf("Expected \"false\" but found: %s", string(b)) + return value, NewTProtocolExceptionWithType(INVALID_DATA, e) + } + case JSON_NULL[0]: + b := make([]byte, len(JSON_NULL)) + _, err := p.reader.Read(b) + if err != nil { + return false, NewTProtocolException(err) + } + if string(b) == string(JSON_NULL) { + value = false + } else { + e := fmt.Errorf("Expected \"null\" but found: %s", string(b)) + return value, NewTProtocolExceptionWithType(INVALID_DATA, e) + } + default: + e := fmt.Errorf("Expected \"true\", \"false\", or \"null\" but found: %s", string(f)) + return value, NewTProtocolExceptionWithType(INVALID_DATA, e) + } + } + return value, p.ParsePostValue() +} + +func (p *TSimpleJSONProtocol) ReadByte(ctx context.Context) (int8, error) { + v, err := p.ReadI64(ctx) + return int8(v), err +} + +func (p *TSimpleJSONProtocol) ReadI16(ctx context.Context) (int16, error) { + v, err := p.ReadI64(ctx) + return int16(v), err +} + +func (p *TSimpleJSONProtocol) ReadI32(ctx context.Context) (int32, error) { + v, err := p.ReadI64(ctx) + return int32(v), err +} + +func (p *TSimpleJSONProtocol) ReadI64(ctx context.Context) (int64, error) { + v, _, err := p.ParseI64() + return v, err +} + +func (p *TSimpleJSONProtocol) ReadDouble(ctx context.Context) (float64, error) { + v, _, err := p.ParseF64() + return v, err +} + +func (p *TSimpleJSONProtocol) ReadString(ctx context.Context) (string, error) { + var v string + if err := p.ParsePreValue(); err != nil { + return v, err + } + f, _ := p.reader.Peek(1) + if len(f) > 0 && f[0] == JSON_QUOTE { + p.reader.ReadByte() + value, err := p.ParseStringBody() + v = value + if err != nil { + return v, err + } + } else if len(f) > 0 && f[0] == JSON_NULL[0] { + b := make([]byte, len(JSON_NULL)) + _, err := p.reader.Read(b) + if err != nil { + return v, NewTProtocolException(err) + } + if string(b) != string(JSON_NULL) { + e := fmt.Errorf("Expected a JSON string, found unquoted data started with %s", string(b)) + return v, NewTProtocolExceptionWithType(INVALID_DATA, e) + } + } else { + e := fmt.Errorf("Expected a JSON string, found unquoted data started with %s", string(f)) + return v, NewTProtocolExceptionWithType(INVALID_DATA, e) + } + return v, p.ParsePostValue() +} + +func (p *TSimpleJSONProtocol) ReadBinary(ctx context.Context) ([]byte, error) { + var v []byte + if err := p.ParsePreValue(); err != nil { + return nil, err + } + f, _ := p.reader.Peek(1) + if len(f) > 0 && f[0] == JSON_QUOTE { + p.reader.ReadByte() + value, err := p.ParseBase64EncodedBody() + v = value + if err != nil { + return v, err + } + } else if len(f) > 0 && f[0] == JSON_NULL[0] { + b := make([]byte, len(JSON_NULL)) + _, err := p.reader.Read(b) + if err != nil { + return v, NewTProtocolException(err) + } + if string(b) != string(JSON_NULL) { + e := fmt.Errorf("Expected a JSON string, found unquoted data started with %s", string(b)) + return v, NewTProtocolExceptionWithType(INVALID_DATA, e) + } + } else { + e := fmt.Errorf("Expected a JSON string, found unquoted data started with %s", string(f)) + return v, NewTProtocolExceptionWithType(INVALID_DATA, e) + } + + return v, p.ParsePostValue() +} + +func (p *TSimpleJSONProtocol) ReadUUID(ctx context.Context) (v Tuuid, err error) { + var s string + s, err = p.ReadString(ctx) + if err != nil { + return v, err + } + v, err = ParseTuuid(s) + return v, NewTProtocolExceptionWithType(INVALID_DATA, err) +} + +func (p *TSimpleJSONProtocol) Flush(ctx context.Context) (err error) { + return NewTProtocolException(p.writer.Flush()) +} + +func (p *TSimpleJSONProtocol) Skip(ctx context.Context, fieldType TType) (err error) { + return SkipDefaultDepth(ctx, p, fieldType) +} + +func (p *TSimpleJSONProtocol) Transport() TTransport { + return p.trans +} + +func (p *TSimpleJSONProtocol) OutputPreValue() error { + cxt, ok := p.dumpContext.peek() + if !ok { + return errEmptyJSONContextStack + } + switch cxt { + case _CONTEXT_IN_LIST, _CONTEXT_IN_OBJECT_NEXT_KEY: + if _, e := p.write(JSON_COMMA); e != nil { + return NewTProtocolException(e) + } + case _CONTEXT_IN_OBJECT_NEXT_VALUE: + if _, e := p.write(JSON_COLON); e != nil { + return NewTProtocolException(e) + } + } + return nil +} + +func (p *TSimpleJSONProtocol) OutputPostValue() error { + cxt, ok := p.dumpContext.peek() + if !ok { + return errEmptyJSONContextStack + } + switch cxt { + case _CONTEXT_IN_LIST_FIRST: + p.dumpContext.pop() + p.dumpContext.push(_CONTEXT_IN_LIST) + case _CONTEXT_IN_OBJECT_FIRST: + p.dumpContext.pop() + p.dumpContext.push(_CONTEXT_IN_OBJECT_NEXT_VALUE) + case _CONTEXT_IN_OBJECT_NEXT_KEY: + p.dumpContext.pop() + p.dumpContext.push(_CONTEXT_IN_OBJECT_NEXT_VALUE) + case _CONTEXT_IN_OBJECT_NEXT_VALUE: + p.dumpContext.pop() + p.dumpContext.push(_CONTEXT_IN_OBJECT_NEXT_KEY) + } + return nil +} + +func (p *TSimpleJSONProtocol) OutputBool(value bool) error { + if e := p.OutputPreValue(); e != nil { + return e + } + var v string + if value { + v = string(JSON_TRUE) + } else { + v = string(JSON_FALSE) + } + cxt, ok := p.dumpContext.peek() + if !ok { + return errEmptyJSONContextStack + } + switch cxt { + case _CONTEXT_IN_OBJECT_FIRST, _CONTEXT_IN_OBJECT_NEXT_KEY: + v = jsonQuote(v) + } + if e := p.OutputStringData(v); e != nil { + return e + } + return p.OutputPostValue() +} + +func (p *TSimpleJSONProtocol) OutputNull() error { + if e := p.OutputPreValue(); e != nil { + return e + } + if _, e := p.write(JSON_NULL); e != nil { + return NewTProtocolException(e) + } + return p.OutputPostValue() +} + +func (p *TSimpleJSONProtocol) OutputF64(value float64) error { + if e := p.OutputPreValue(); e != nil { + return e + } + var v string + if math.IsNaN(value) { + v = string(JSON_QUOTE) + JSON_NAN + string(JSON_QUOTE) + } else if math.IsInf(value, 1) { + v = string(JSON_QUOTE) + JSON_INFINITY + string(JSON_QUOTE) + } else if math.IsInf(value, -1) { + v = string(JSON_QUOTE) + JSON_NEGATIVE_INFINITY + string(JSON_QUOTE) + } else { + cxt, ok := p.dumpContext.peek() + if !ok { + return errEmptyJSONContextStack + } + v = strconv.FormatFloat(value, 'g', -1, 64) + switch cxt { + case _CONTEXT_IN_OBJECT_FIRST, _CONTEXT_IN_OBJECT_NEXT_KEY: + v = string(JSON_QUOTE) + v + string(JSON_QUOTE) + } + } + if e := p.OutputStringData(v); e != nil { + return e + } + return p.OutputPostValue() +} + +func (p *TSimpleJSONProtocol) OutputI64(value int64) error { + if e := p.OutputPreValue(); e != nil { + return e + } + cxt, ok := p.dumpContext.peek() + if !ok { + return errEmptyJSONContextStack + } + v := strconv.FormatInt(value, 10) + switch cxt { + case _CONTEXT_IN_OBJECT_FIRST, _CONTEXT_IN_OBJECT_NEXT_KEY: + v = jsonQuote(v) + } + if e := p.OutputStringData(v); e != nil { + return e + } + return p.OutputPostValue() +} + +func (p *TSimpleJSONProtocol) OutputString(s string) error { + if e := p.OutputPreValue(); e != nil { + return e + } + if e := p.OutputStringData(jsonQuote(s)); e != nil { + return e + } + return p.OutputPostValue() +} + +func (p *TSimpleJSONProtocol) OutputStringData(s string) error { + _, e := p.write([]byte(s)) + return NewTProtocolException(e) +} + +func (p *TSimpleJSONProtocol) OutputObjectBegin() error { + if e := p.OutputPreValue(); e != nil { + return e + } + if _, e := p.write(JSON_LBRACE); e != nil { + return NewTProtocolException(e) + } + p.dumpContext.push(_CONTEXT_IN_OBJECT_FIRST) + return nil +} + +func (p *TSimpleJSONProtocol) OutputObjectEnd() error { + if _, e := p.write(JSON_RBRACE); e != nil { + return NewTProtocolException(e) + } + _, ok := p.dumpContext.pop() + if !ok { + return errEmptyJSONContextStack + } + if e := p.OutputPostValue(); e != nil { + return e + } + return nil +} + +func (p *TSimpleJSONProtocol) OutputListBegin() error { + if e := p.OutputPreValue(); e != nil { + return e + } + if _, e := p.write(JSON_LBRACKET); e != nil { + return NewTProtocolException(e) + } + p.dumpContext.push(_CONTEXT_IN_LIST_FIRST) + return nil +} + +func (p *TSimpleJSONProtocol) OutputListEnd() error { + if _, e := p.write(JSON_RBRACKET); e != nil { + return NewTProtocolException(e) + } + _, ok := p.dumpContext.pop() + if !ok { + return errEmptyJSONContextStack + } + if e := p.OutputPostValue(); e != nil { + return e + } + return nil +} + +func (p *TSimpleJSONProtocol) OutputElemListBegin(elemType TType, size int) error { + if e := p.OutputListBegin(); e != nil { + return e + } + if e := p.OutputI64(int64(elemType)); e != nil { + return e + } + if e := p.OutputI64(int64(size)); e != nil { + return e + } + return nil +} + +func (p *TSimpleJSONProtocol) ParsePreValue() error { + if e := p.readNonSignificantWhitespace(); e != nil { + return NewTProtocolException(e) + } + cxt, ok := p.parseContextStack.peek() + if !ok { + return errEmptyJSONContextStack + } + b, _ := p.reader.Peek(1) + switch cxt { + case _CONTEXT_IN_LIST: + if len(b) > 0 { + switch b[0] { + case JSON_RBRACKET[0]: + return nil + case JSON_COMMA[0]: + p.reader.ReadByte() + if e := p.readNonSignificantWhitespace(); e != nil { + return NewTProtocolException(e) + } + return nil + default: + e := fmt.Errorf("Expected \"]\" or \",\" in list context, but found \"%s\"", string(b)) + return NewTProtocolExceptionWithType(INVALID_DATA, e) + } + } + case _CONTEXT_IN_OBJECT_NEXT_KEY: + if len(b) > 0 { + switch b[0] { + case JSON_RBRACE[0]: + return nil + case JSON_COMMA[0]: + p.reader.ReadByte() + if e := p.readNonSignificantWhitespace(); e != nil { + return NewTProtocolException(e) + } + return nil + default: + e := fmt.Errorf("Expected \"}\" or \",\" in object context, but found \"%s\"", string(b)) + return NewTProtocolExceptionWithType(INVALID_DATA, e) + } + } + case _CONTEXT_IN_OBJECT_NEXT_VALUE: + if len(b) > 0 { + switch b[0] { + case JSON_COLON[0]: + p.reader.ReadByte() + if e := p.readNonSignificantWhitespace(); e != nil { + return NewTProtocolException(e) + } + return nil + default: + e := fmt.Errorf("Expected \":\" in object context, but found \"%s\"", string(b)) + return NewTProtocolExceptionWithType(INVALID_DATA, e) + } + } + } + return nil +} + +func (p *TSimpleJSONProtocol) ParsePostValue() error { + if e := p.readNonSignificantWhitespace(); e != nil { + return NewTProtocolException(e) + } + cxt, ok := p.parseContextStack.peek() + if !ok { + return errEmptyJSONContextStack + } + switch cxt { + case _CONTEXT_IN_LIST_FIRST: + p.parseContextStack.pop() + p.parseContextStack.push(_CONTEXT_IN_LIST) + case _CONTEXT_IN_OBJECT_FIRST, _CONTEXT_IN_OBJECT_NEXT_KEY: + p.parseContextStack.pop() + p.parseContextStack.push(_CONTEXT_IN_OBJECT_NEXT_VALUE) + case _CONTEXT_IN_OBJECT_NEXT_VALUE: + p.parseContextStack.pop() + p.parseContextStack.push(_CONTEXT_IN_OBJECT_NEXT_KEY) + } + return nil +} + +func (p *TSimpleJSONProtocol) readNonSignificantWhitespace() error { + for { + b, _ := p.reader.Peek(1) + if len(b) < 1 { + return nil + } + switch b[0] { + case ' ', '\r', '\n', '\t': + p.reader.ReadByte() + continue + } + break + } + return nil +} + +func (p *TSimpleJSONProtocol) ParseStringBody() (string, error) { + line, err := p.reader.ReadString(JSON_QUOTE) + if err != nil { + return "", NewTProtocolException(err) + } + if endsWithoutEscapedQuote(line) { + v, ok := jsonUnquote(string(JSON_QUOTE) + line) + if !ok { + return "", NewTProtocolException(err) + } + return v, nil + } + s, err := p.ParseQuotedStringBody() + if err != nil { + return "", NewTProtocolException(err) + } + str := string(JSON_QUOTE) + line + s + v, ok := jsonUnquote(str) + if !ok { + e := fmt.Errorf("Unable to parse as JSON string %s", str) + return "", NewTProtocolExceptionWithType(INVALID_DATA, e) + } + return v, nil +} + +func (p *TSimpleJSONProtocol) ParseQuotedStringBody() (string, error) { + var sb strings.Builder + + for { + line, err := p.reader.ReadString(JSON_QUOTE) + if err != nil { + return "", NewTProtocolException(err) + } + sb.WriteString(line) + if endsWithoutEscapedQuote(line) { + return sb.String(), nil + } + } +} + +func endsWithoutEscapedQuote(s string) bool { + l := len(s) + i := 1 + for ; i < l; i++ { + if s[l-i-1] != '\\' { + break + } + } + return i&0x01 == 1 +} + +func (p *TSimpleJSONProtocol) ParseBase64EncodedBody() ([]byte, error) { + line, err := p.reader.ReadBytes(JSON_QUOTE) + if err != nil { + return line, NewTProtocolException(err) + } + line2 := line[0 : len(line)-1] + l := len(line2) + if (l % 4) != 0 { + pad := 4 - (l % 4) + fill := [...]byte{'=', '=', '='} + line2 = append(line2, fill[:pad]...) + l = len(line2) + } + output := make([]byte, base64.StdEncoding.DecodedLen(l)) + n, err := base64.StdEncoding.Decode(output, line2) + return output[0:n], NewTProtocolException(err) +} + +func (p *TSimpleJSONProtocol) ParseI64() (int64, bool, error) { + if err := p.ParsePreValue(); err != nil { + return 0, false, err + } + var value int64 + var isnull bool + if p.safePeekContains(JSON_NULL) { + p.reader.Read(make([]byte, len(JSON_NULL))) + isnull = true + } else { + num, err := p.readNumeric() + isnull = (num == nil) + if !isnull { + value = num.Int64() + } + if err != nil { + return value, isnull, err + } + } + return value, isnull, p.ParsePostValue() +} + +func (p *TSimpleJSONProtocol) ParseF64() (float64, bool, error) { + if err := p.ParsePreValue(); err != nil { + return 0, false, err + } + var value float64 + var isnull bool + if p.safePeekContains(JSON_NULL) { + p.reader.Read(make([]byte, len(JSON_NULL))) + isnull = true + } else { + num, err := p.readNumeric() + isnull = (num == nil) + if !isnull { + value = num.Float64() + } + if err != nil { + return value, isnull, err + } + } + return value, isnull, p.ParsePostValue() +} + +func (p *TSimpleJSONProtocol) ParseObjectStart() (bool, error) { + if err := p.ParsePreValue(); err != nil { + return false, err + } + var b []byte + b, err := p.reader.Peek(1) + if err != nil { + return false, err + } + if len(b) > 0 && b[0] == JSON_LBRACE[0] { + p.reader.ReadByte() + p.parseContextStack.push(_CONTEXT_IN_OBJECT_FIRST) + return false, nil + } else if p.safePeekContains(JSON_NULL) { + return true, nil + } + e := fmt.Errorf("Expected '{' or null, but found '%s'", string(b)) + return false, NewTProtocolExceptionWithType(INVALID_DATA, e) +} + +func (p *TSimpleJSONProtocol) ParseObjectEnd() error { + if isNull, err := p.readIfNull(); isNull || err != nil { + return err + } + cxt, _ := p.parseContextStack.peek() + if (cxt != _CONTEXT_IN_OBJECT_FIRST) && (cxt != _CONTEXT_IN_OBJECT_NEXT_KEY) { + e := fmt.Errorf("Expected to be in the Object Context, but not in Object Context (%d)", cxt) + return NewTProtocolExceptionWithType(INVALID_DATA, e) + } + line, err := p.reader.ReadString(JSON_RBRACE[0]) + if err != nil { + return NewTProtocolException(err) + } + for _, char := range line { + switch char { + default: + e := fmt.Errorf("Expecting end of object \"}\", but found: \"%s\"", line) + return NewTProtocolExceptionWithType(INVALID_DATA, e) + case ' ', '\n', '\r', '\t', '}': + // do nothing + } + } + p.parseContextStack.pop() + return p.ParsePostValue() +} + +func (p *TSimpleJSONProtocol) ParseListBegin() (isNull bool, err error) { + if e := p.ParsePreValue(); e != nil { + return false, e + } + var b []byte + b, err = p.reader.Peek(1) + if err != nil { + return false, err + } + if len(b) >= 1 && b[0] == JSON_LBRACKET[0] { + p.parseContextStack.push(_CONTEXT_IN_LIST_FIRST) + p.reader.ReadByte() + isNull = false + } else if p.safePeekContains(JSON_NULL) { + isNull = true + } else { + err = fmt.Errorf("Expected \"null\" or \"[\", received %q", b) + } + return isNull, NewTProtocolExceptionWithType(INVALID_DATA, err) +} + +func (p *TSimpleJSONProtocol) ParseElemListBegin() (elemType TType, size int, e error) { + if isNull, e := p.ParseListBegin(); isNull || e != nil { + return VOID, 0, e + } + bElemType, _, err := p.ParseI64() + elemType = TType(bElemType) + if err != nil { + return elemType, size, err + } + nSize, _, err := p.ParseI64() + if err != nil { + return elemType, 0, err + } + err = checkSizeForProtocol(int32(nSize), p.cfg) + if err != nil { + return elemType, 0, err + } + size = int(nSize) + return elemType, size, nil +} + +func (p *TSimpleJSONProtocol) ParseListEnd() error { + if isNull, err := p.readIfNull(); isNull || err != nil { + return err + } + cxt, _ := p.parseContextStack.peek() + if cxt != _CONTEXT_IN_LIST { + e := fmt.Errorf("Expected to be in the List Context, but not in List Context (%d)", cxt) + return NewTProtocolExceptionWithType(INVALID_DATA, e) + } + line, err := p.reader.ReadString(JSON_RBRACKET[0]) + if err != nil { + return NewTProtocolException(err) + } + for _, char := range line { + switch char { + default: + e := fmt.Errorf("Expecting end of list \"]\", but found: \"%v\"", line) + return NewTProtocolExceptionWithType(INVALID_DATA, e) + case ' ', '\n', '\r', '\t', rune(JSON_RBRACKET[0]): + // do nothing + } + } + p.parseContextStack.pop() + if cxt, ok := p.parseContextStack.peek(); !ok { + return errEmptyJSONContextStack + } else if cxt == _CONTEXT_IN_TOPLEVEL { + return nil + } + return p.ParsePostValue() +} + +func (p *TSimpleJSONProtocol) readIfNull() (bool, error) { + cont := true + for cont { + b, _ := p.reader.Peek(1) + if len(b) < 1 { + return false, nil + } + switch b[0] { + default: + return false, nil + case JSON_NULL[0]: + cont = false + case ' ', '\n', '\r', '\t': + p.reader.ReadByte() + } + } + if p.safePeekContains(JSON_NULL) { + p.reader.Read(make([]byte, len(JSON_NULL))) + return true, nil + } + return false, nil +} + +func (p *TSimpleJSONProtocol) readQuoteIfNext() { + b, _ := p.reader.Peek(1) + if len(b) > 0 && b[0] == JSON_QUOTE { + p.reader.ReadByte() + } +} + +func (p *TSimpleJSONProtocol) readNumeric() (Numeric, error) { + isNull, err := p.readIfNull() + if isNull || err != nil { + return NUMERIC_NULL, err + } + hasDecimalPoint := false + nextCanBeSign := true + hasE := false + MAX_LEN := 40 + buf := bytes.NewBuffer(make([]byte, 0, MAX_LEN)) + continueFor := true + inQuotes := false + for continueFor { + c, err := p.reader.ReadByte() + if err != nil { + if errors.Is(err, io.EOF) { + break + } + return NUMERIC_NULL, NewTProtocolException(err) + } + switch c { + case '0', '1', '2', '3', '4', '5', '6', '7', '8', '9': + buf.WriteByte(c) + nextCanBeSign = false + case '.': + if hasDecimalPoint { + e := fmt.Errorf("Unable to parse number with multiple decimal points '%s.'", buf.String()) + return NUMERIC_NULL, NewTProtocolExceptionWithType(INVALID_DATA, e) + } + if hasE { + e := fmt.Errorf("Unable to parse number with decimal points in the exponent '%s.'", buf.String()) + return NUMERIC_NULL, NewTProtocolExceptionWithType(INVALID_DATA, e) + } + buf.WriteByte(c) + hasDecimalPoint, nextCanBeSign = true, false + case 'e', 'E': + if hasE { + e := fmt.Errorf("Unable to parse number with multiple exponents '%s%c'", buf.String(), c) + return NUMERIC_NULL, NewTProtocolExceptionWithType(INVALID_DATA, e) + } + buf.WriteByte(c) + hasE, nextCanBeSign = true, true + case '-', '+': + if !nextCanBeSign { + e := fmt.Errorf("Negative sign within number") + return NUMERIC_NULL, NewTProtocolExceptionWithType(INVALID_DATA, e) + } + buf.WriteByte(c) + nextCanBeSign = false + case ' ', 0, '\t', '\n', '\r', JSON_RBRACE[0], JSON_RBRACKET[0], JSON_COMMA[0], JSON_COLON[0]: + p.reader.UnreadByte() + continueFor = false + case JSON_NAN[0]: + if buf.Len() == 0 { + buffer := make([]byte, len(JSON_NAN)) + buffer[0] = c + _, e := p.reader.Read(buffer[1:]) + if e != nil { + return NUMERIC_NULL, NewTProtocolException(e) + } + if JSON_NAN != string(buffer) { + e := mismatch(JSON_NAN, string(buffer)) + return NUMERIC_NULL, NewTProtocolExceptionWithType(INVALID_DATA, e) + } + if inQuotes { + p.readQuoteIfNext() + } + return NAN, nil + } else { + e := fmt.Errorf("Unable to parse number starting with character '%c'", c) + return NUMERIC_NULL, NewTProtocolExceptionWithType(INVALID_DATA, e) + } + case JSON_INFINITY[0]: + if buf.Len() == 0 || (buf.Len() == 1 && buf.Bytes()[0] == '+') { + buffer := make([]byte, len(JSON_INFINITY)) + buffer[0] = c + _, e := p.reader.Read(buffer[1:]) + if e != nil { + return NUMERIC_NULL, NewTProtocolException(e) + } + if JSON_INFINITY != string(buffer) { + e := mismatch(JSON_INFINITY, string(buffer)) + return NUMERIC_NULL, NewTProtocolExceptionWithType(INVALID_DATA, e) + } + if inQuotes { + p.readQuoteIfNext() + } + return INFINITY, nil + } else if buf.Len() == 1 && buf.Bytes()[0] == JSON_NEGATIVE_INFINITY[0] { + buffer := make([]byte, len(JSON_NEGATIVE_INFINITY)) + buffer[0] = JSON_NEGATIVE_INFINITY[0] + buffer[1] = c + _, e := p.reader.Read(buffer[2:]) + if e != nil { + return NUMERIC_NULL, NewTProtocolException(e) + } + if JSON_NEGATIVE_INFINITY != string(buffer) { + e := mismatch(JSON_NEGATIVE_INFINITY, string(buffer)) + return NUMERIC_NULL, NewTProtocolExceptionWithType(INVALID_DATA, e) + } + if inQuotes { + p.readQuoteIfNext() + } + return NEGATIVE_INFINITY, nil + } else { + e := fmt.Errorf("Unable to parse number starting with character '%c' due to existing buffer %s", c, buf.String()) + return NUMERIC_NULL, NewTProtocolExceptionWithType(INVALID_DATA, e) + } + case JSON_QUOTE: + if !inQuotes { + inQuotes = true + } + default: + e := fmt.Errorf("Unable to parse number starting with character '%c'", c) + return NUMERIC_NULL, NewTProtocolExceptionWithType(INVALID_DATA, e) + } + } + if buf.Len() == 0 { + e := fmt.Errorf("Unable to parse number from empty string ''") + return NUMERIC_NULL, NewTProtocolExceptionWithType(INVALID_DATA, e) + } + return NewNumericFromJSONString(buf.String(), false), nil +} + +// Safely peeks into the buffer, reading only what is necessary +func (p *TSimpleJSONProtocol) safePeekContains(b []byte) bool { + for i := range b { + a, _ := p.reader.Peek(i + 1) + if len(a) < (i+1) || a[i] != b[i] { + return false + } + } + return true +} + +// Reset the context stack to its initial state. +func (p *TSimpleJSONProtocol) resetContextStack() { + p.parseContextStack = jsonContextStack{_CONTEXT_IN_TOPLEVEL} + p.dumpContext = jsonContextStack{_CONTEXT_IN_TOPLEVEL} +} + +func (p *TSimpleJSONProtocol) write(b []byte) (int, error) { + n, err := p.writer.Write(b) + if err != nil { + p.writer.Reset(p.trans) // THRIFT-3735 + } + return n, err +} + +// SetTConfiguration implements TConfigurationSetter for propagation. +func (p *TSimpleJSONProtocol) SetTConfiguration(conf *TConfiguration) { + PropagateTConfiguration(p.trans, conf) + p.cfg = conf +} + +// Reset resets this protocol's internal state. +// +// It's useful when a single protocol instance is reused after errors, to make +// sure the next use will not be in a bad state to begin with. An example is +// when it's used in serializer/deserializer pools. +func (p *TSimpleJSONProtocol) Reset() { + p.resetContextStack() + p.writer.Reset(p.trans) + p.reader.Reset(p.trans) +} + +var ( + _ TConfigurationSetter = (*TSimpleJSONProtocol)(nil) + _ TConfigurationSetter = (*TSimpleJSONProtocolFactory)(nil) +) diff --git a/thrift/simple_json_protocol_test.go b/thrift/simple_json_protocol_test.go index 89753c6..002a231 100644 --- a/thrift/simple_json_protocol_test.go +++ b/thrift/simple_json_protocol_test.go @@ -497,7 +497,7 @@ func TestReadSimpleJSONProtocolBinary(t *testing.T) { if len(v) != len(value) { t.Fatalf("Bad value for %s value length %v, wrote: %v, received length: %v", thetype, len(value), s, len(v)) } - for i := 0; i < len(v); i++ { + for i := range v { if v[i] != value[i] { t.Fatalf("Bad value for %s at index %d value %v, wrote: %v, received: %v", thetype, i, value[i], s, v[i]) } diff --git a/thrift/simple_server.go b/thrift/simple_server.go index b3ab08f..a8634fc 100644 --- a/thrift/simple_server.go +++ b/thrift/simple_server.go @@ -37,378 +37,378 @@ import ( // implementations can change its value to control the behavior. // // If it's changed to <=0, the feature will be disabled. - var ServerConnectivityCheckInterval = time.Millisecond * 5 - - // ServerStopTimeout defines max stop wait duration used by - // server stop to avoid hanging too long to wait for all client connections to be closed gracefully. - // - // It's defined as a variable instead of constant, so that thrift server - // implementations can change its value to control the behavior. - // - // If it's set to <=0, the feature will be disabled(by default), and the server will wait for - // for all the client connections to be closed gracefully. - var ServerStopTimeout = time.Duration(0) - - /* - * This is not a typical TSimpleServer as it is not blocked after accept a socket. - * It is more like a TThreadedServer that can handle different connections in different goroutines. - * This will work if golang user implements a conn-pool like thing in client side. - */ - type TSimpleServer struct { - closed atomic.Int32 - wg sync.WaitGroup - mu sync.Mutex - stopChan chan struct{} - - processorFactory TProcessorFactory - serverTransport TServerTransport - inputTransportFactory TTransportFactory - outputTransportFactory TTransportFactory - inputProtocolFactory TProtocolFactory - outputProtocolFactory TProtocolFactory - - // Headers to auto forward in THeaderProtocol - forwardHeaders []string - - logContext atomic.Pointer[context.Context] - } - - func NewTSimpleServer2(processor TProcessor, serverTransport TServerTransport) *TSimpleServer { - return NewTSimpleServerFactory2(NewTProcessorFactory(processor), serverTransport) - } - - func NewTSimpleServer4(processor TProcessor, serverTransport TServerTransport, transportFactory TTransportFactory, protocolFactory TProtocolFactory) *TSimpleServer { - return NewTSimpleServerFactory4(NewTProcessorFactory(processor), - serverTransport, - transportFactory, - protocolFactory, - ) - } - - func NewTSimpleServer6(processor TProcessor, serverTransport TServerTransport, inputTransportFactory TTransportFactory, outputTransportFactory TTransportFactory, inputProtocolFactory TProtocolFactory, outputProtocolFactory TProtocolFactory) *TSimpleServer { - return NewTSimpleServerFactory6(NewTProcessorFactory(processor), - serverTransport, - inputTransportFactory, - outputTransportFactory, - inputProtocolFactory, - outputProtocolFactory, - ) - } - - func NewTSimpleServerFactory2(processorFactory TProcessorFactory, serverTransport TServerTransport) *TSimpleServer { - return NewTSimpleServerFactory6(processorFactory, - serverTransport, - NewTTransportFactory(), - NewTTransportFactory(), - NewTBinaryProtocolFactoryDefault(), - NewTBinaryProtocolFactoryDefault(), - ) - } - - func NewTSimpleServerFactory4(processorFactory TProcessorFactory, serverTransport TServerTransport, transportFactory TTransportFactory, protocolFactory TProtocolFactory) *TSimpleServer { - return NewTSimpleServerFactory6(processorFactory, - serverTransport, - transportFactory, - transportFactory, - protocolFactory, - protocolFactory, - ) - } - - func NewTSimpleServerFactory6(processorFactory TProcessorFactory, serverTransport TServerTransport, inputTransportFactory TTransportFactory, outputTransportFactory TTransportFactory, inputProtocolFactory TProtocolFactory, outputProtocolFactory TProtocolFactory) *TSimpleServer { - return &TSimpleServer{ - processorFactory: processorFactory, - serverTransport: serverTransport, - inputTransportFactory: inputTransportFactory, - outputTransportFactory: outputTransportFactory, - inputProtocolFactory: inputProtocolFactory, - outputProtocolFactory: outputProtocolFactory, - stopChan: make(chan struct{}), - } - } - - func (p *TSimpleServer) ProcessorFactory() TProcessorFactory { - return p.processorFactory - } - - func (p *TSimpleServer) ServerTransport() TServerTransport { - return p.serverTransport - } - - func (p *TSimpleServer) InputTransportFactory() TTransportFactory { - return p.inputTransportFactory - } - - func (p *TSimpleServer) OutputTransportFactory() TTransportFactory { - return p.outputTransportFactory - } - - func (p *TSimpleServer) InputProtocolFactory() TProtocolFactory { - return p.inputProtocolFactory - } - - func (p *TSimpleServer) OutputProtocolFactory() TProtocolFactory { - return p.outputProtocolFactory - } - - func (p *TSimpleServer) Listen() error { - return p.serverTransport.Listen() - } - - // SetForwardHeaders sets the list of header keys that will be auto forwarded - // while using THeaderProtocol. - // - // "forward" means that when the server is also a client to other upstream - // thrift servers, the context object user gets in the processor functions will - // have both read and write headers set, with write headers being forwarded. - // Users can always override the write headers by calling SetWriteHeaderList - // before calling thrift client functions. - func (p *TSimpleServer) SetForwardHeaders(headers []string) { - size := len(headers) - if size == 0 { - p.forwardHeaders = nil - return - } - - keys := make([]string, size) - copy(keys, headers) - p.forwardHeaders = keys - } - - // SetLogger sets the logger used by this TSimpleServer. - // - // If no logger was set before Serve is called, a default logger using standard - // log library will be used. - // - // Deprecated: The logging inside TSimpleServer is now done via slog on error - // level, this does nothing now. It will be removed in a future version. - func (p *TSimpleServer) SetLogger(_ Logger) {} - - // SetLogContext sets the context to be used when logging errors inside - // TSimpleServer. - // - // If this is not called before calling Serve, context.Background() will be - // used. - func (p *TSimpleServer) SetLogContext(ctx context.Context) { - p.logContext.Store(&ctx) - } - - func (p *TSimpleServer) innerAccept() (int32, error) { - client, err := p.serverTransport.Accept() - p.mu.Lock() - defer p.mu.Unlock() - closed := p.closed.Load() - if closed != 0 { - return closed, nil - } - if err != nil { - return 0, err - } - if client != nil { - ctx, cancel := context.WithCancel(context.Background()) - p.wg.Add(2) - - go func() { - defer p.wg.Done() - defer cancel() - if err := p.processRequests(client); err != nil { - ctx := p.logContext.Load() - slog.ErrorContext(*ctx, "error processing request", "err", err) - } - }() - - go func() { - defer p.wg.Done() - select { - case <-ctx.Done(): - // client exited, do nothing - case <-p.stopChan: - // TSimpleServer.Close called, close the client connection - client.Close() - } - }() - } - return 0, nil - } - - func (p *TSimpleServer) AcceptLoop() error { - for { - closed, err := p.innerAccept() - if err != nil { - return err - } - if closed != 0 { - return nil - } - } - } - - func (p *TSimpleServer) Serve() error { - p.logContext.CompareAndSwap(nil, Pointer(context.Background())) - - err := p.Listen() - if err != nil { - return err - } - p.AcceptLoop() - return nil - } - - func (p *TSimpleServer) Stop() error { - p.mu.Lock() - defer p.mu.Unlock() - - if !p.closed.CompareAndSwap(0, 1) { - // Already closed - return nil - } - p.serverTransport.Interrupt() - - ctx, cancel := context.WithCancel(context.Background()) - go func() { - defer cancel() - p.wg.Wait() - }() - - if ServerStopTimeout > 0 { - timer := time.NewTimer(ServerStopTimeout) - select { - case <-timer.C: - case <-ctx.Done(): - } - close(p.stopChan) - timer.Stop() - } - - <-ctx.Done() - p.stopChan = make(chan struct{}) - return nil - } - - // If err is actually EOF or NOT_OPEN, return nil, otherwise return err as-is. - func treatEOFErrorsAsNil(err error) error { - if err == nil { - return nil - } - if errors.Is(err, io.EOF) { - return nil - } - var te TTransportException - // NOT_OPEN returned by processor.Process is usually caused by client - // abandoning the connection (e.g. client side time out, or just client - // closes connections from the pool because of shutting down). - // Those logs will be very noisy, so suppress those logs as well. - if errors.As(err, &te) && (te.TypeId() == END_OF_FILE || te.TypeId() == NOT_OPEN) { - return nil - } - return err - } - - func (p *TSimpleServer) processRequests(client TTransport) (err error) { - defer func() { - err = treatEOFErrorsAsNil(err) - }() - - processor := p.processorFactory.GetProcessor(client) - inputTransport, err := p.inputTransportFactory.GetTransport(client) - if err != nil { - return err - } - inputProtocol := p.inputProtocolFactory.GetProtocol(inputTransport) - var outputTransport TTransport - var outputProtocol TProtocol - - // for THeaderProtocol, we must use the same protocol instance for - // input and output so that the response is in the same dialect that - // the server detected the request was in. - headerProtocol, ok := inputProtocol.(*THeaderProtocol) - if ok { - outputProtocol = inputProtocol - } else { - oTrans, err := p.outputTransportFactory.GetTransport(client) - if err != nil { - return err - } - outputTransport = oTrans - outputProtocol = p.outputProtocolFactory.GetProtocol(outputTransport) - } - - if inputTransport != nil { - defer inputTransport.Close() - } - if outputTransport != nil { - defer outputTransport.Close() - } - for { - if p.closed.Load() != 0 { - return nil - } - - ctx := SetResponseHelper( - defaultCtx, - TResponseHelper{ - THeaderResponseHelper: NewTHeaderResponseHelper(outputProtocol), - }, - ) - if headerProtocol != nil { - // We need to call ReadFrame here, otherwise we won't - // get any headers on the AddReadTHeaderToContext call. - // - // ReadFrame is safe to be called multiple times so it - // won't break when it's called again later when we - // actually start to read the message. - if err := headerProtocol.ReadFrame(ctx); err != nil { - return err - } - ctx = AddReadTHeaderToContext(ctx, headerProtocol.GetReadHeaders()) - ctx = SetWriteHeaderList(ctx, p.forwardHeaders) - } - - ok, err := processor.Process(ctx, inputProtocol, outputProtocol) - if errors.Is(err, ErrAbandonRequest) { - err := client.Close() - if errors.Is(err, net.ErrClosed) { - // In this case, it's kinda expected to get - // net.ErrClosed, treat that as no-error - return nil - } - return err - } - if errors.As(err, new(TTransportException)) && err != nil { - return err - } - var tae TApplicationException - if errors.As(err, &tae) && tae.TypeId() == UNKNOWN_METHOD { - continue - } - if !ok { - break - } - } - return nil - } - - // ErrAbandonRequest is a special error that server handler implementations can - // return to indicate that the request has been abandoned. - // - // TSimpleServer and compiler generated Process functions will check for this - // error, and close the client connection instead of trying to write the error - // back to the client. - // - // It shall only be used when the server handler implementation know that the - // client already abandoned the request (by checking that the passed in context - // is already canceled, for example). - // - // It also implements the interface defined by errors.Unwrap and always unwrap - // to context.Canceled error. - var ErrAbandonRequest = abandonRequestError{} - - type abandonRequestError struct{} - - func (abandonRequestError) Error() string { - return "request abandoned" - } - - func (abandonRequestError) Unwrap() error { - return context.Canceled - } \ No newline at end of file +var ServerConnectivityCheckInterval = time.Millisecond * 5 + +// ServerStopTimeout defines max stop wait duration used by +// server stop to avoid hanging too long to wait for all client connections to be closed gracefully. +// +// It's defined as a variable instead of constant, so that thrift server +// implementations can change its value to control the behavior. +// +// If it's set to <=0, the feature will be disabled(by default), and the server will wait for +// for all the client connections to be closed gracefully. +var ServerStopTimeout = time.Duration(0) + +/* + * This is not a typical TSimpleServer as it is not blocked after accept a socket. + * It is more like a TThreadedServer that can handle different connections in different goroutines. + * This will work if golang user implements a conn-pool like thing in client side. + */ +type TSimpleServer struct { + closed atomic.Int32 + wg sync.WaitGroup + mu sync.Mutex + stopChan chan struct{} + + processorFactory TProcessorFactory + serverTransport TServerTransport + inputTransportFactory TTransportFactory + outputTransportFactory TTransportFactory + inputProtocolFactory TProtocolFactory + outputProtocolFactory TProtocolFactory + + // Headers to auto forward in THeaderProtocol + forwardHeaders []string + + logContext atomic.Pointer[context.Context] +} + +func NewTSimpleServer2(processor TProcessor, serverTransport TServerTransport) *TSimpleServer { + return NewTSimpleServerFactory2(NewTProcessorFactory(processor), serverTransport) +} + +func NewTSimpleServer4(processor TProcessor, serverTransport TServerTransport, transportFactory TTransportFactory, protocolFactory TProtocolFactory) *TSimpleServer { + return NewTSimpleServerFactory4(NewTProcessorFactory(processor), + serverTransport, + transportFactory, + protocolFactory, + ) +} + +func NewTSimpleServer6(processor TProcessor, serverTransport TServerTransport, inputTransportFactory TTransportFactory, outputTransportFactory TTransportFactory, inputProtocolFactory TProtocolFactory, outputProtocolFactory TProtocolFactory) *TSimpleServer { + return NewTSimpleServerFactory6(NewTProcessorFactory(processor), + serverTransport, + inputTransportFactory, + outputTransportFactory, + inputProtocolFactory, + outputProtocolFactory, + ) +} + +func NewTSimpleServerFactory2(processorFactory TProcessorFactory, serverTransport TServerTransport) *TSimpleServer { + return NewTSimpleServerFactory6(processorFactory, + serverTransport, + NewTTransportFactory(), + NewTTransportFactory(), + NewTBinaryProtocolFactoryDefault(), + NewTBinaryProtocolFactoryDefault(), + ) +} + +func NewTSimpleServerFactory4(processorFactory TProcessorFactory, serverTransport TServerTransport, transportFactory TTransportFactory, protocolFactory TProtocolFactory) *TSimpleServer { + return NewTSimpleServerFactory6(processorFactory, + serverTransport, + transportFactory, + transportFactory, + protocolFactory, + protocolFactory, + ) +} + +func NewTSimpleServerFactory6(processorFactory TProcessorFactory, serverTransport TServerTransport, inputTransportFactory TTransportFactory, outputTransportFactory TTransportFactory, inputProtocolFactory TProtocolFactory, outputProtocolFactory TProtocolFactory) *TSimpleServer { + return &TSimpleServer{ + processorFactory: processorFactory, + serverTransport: serverTransport, + inputTransportFactory: inputTransportFactory, + outputTransportFactory: outputTransportFactory, + inputProtocolFactory: inputProtocolFactory, + outputProtocolFactory: outputProtocolFactory, + stopChan: make(chan struct{}), + } +} + +func (p *TSimpleServer) ProcessorFactory() TProcessorFactory { + return p.processorFactory +} + +func (p *TSimpleServer) ServerTransport() TServerTransport { + return p.serverTransport +} + +func (p *TSimpleServer) InputTransportFactory() TTransportFactory { + return p.inputTransportFactory +} + +func (p *TSimpleServer) OutputTransportFactory() TTransportFactory { + return p.outputTransportFactory +} + +func (p *TSimpleServer) InputProtocolFactory() TProtocolFactory { + return p.inputProtocolFactory +} + +func (p *TSimpleServer) OutputProtocolFactory() TProtocolFactory { + return p.outputProtocolFactory +} + +func (p *TSimpleServer) Listen() error { + return p.serverTransport.Listen() +} + +// SetForwardHeaders sets the list of header keys that will be auto forwarded +// while using THeaderProtocol. +// +// "forward" means that when the server is also a client to other upstream +// thrift servers, the context object user gets in the processor functions will +// have both read and write headers set, with write headers being forwarded. +// Users can always override the write headers by calling SetWriteHeaderList +// before calling thrift client functions. +func (p *TSimpleServer) SetForwardHeaders(headers []string) { + size := len(headers) + if size == 0 { + p.forwardHeaders = nil + return + } + + keys := make([]string, size) + copy(keys, headers) + p.forwardHeaders = keys +} + +// SetLogger sets the logger used by this TSimpleServer. +// +// If no logger was set before Serve is called, a default logger using standard +// log library will be used. +// +// Deprecated: The logging inside TSimpleServer is now done via slog on error +// level, this does nothing now. It will be removed in a future version. +func (p *TSimpleServer) SetLogger(_ Logger) {} + +// SetLogContext sets the context to be used when logging errors inside +// TSimpleServer. +// +// If this is not called before calling Serve, context.Background() will be +// used. +func (p *TSimpleServer) SetLogContext(ctx context.Context) { + p.logContext.Store(&ctx) +} + +func (p *TSimpleServer) innerAccept() (int32, error) { + client, err := p.serverTransport.Accept() + p.mu.Lock() + defer p.mu.Unlock() + closed := p.closed.Load() + if closed != 0 { + return closed, nil + } + if err != nil { + return 0, err + } + if client != nil { + ctx, cancel := context.WithCancel(context.Background()) + p.wg.Add(2) + + go func() { + defer p.wg.Done() + defer cancel() + if err := p.processRequests(client); err != nil { + ctx := p.logContext.Load() + slog.ErrorContext(*ctx, "error processing request", "err", err) + } + }() + + go func() { + defer p.wg.Done() + select { + case <-ctx.Done(): + // client exited, do nothing + case <-p.stopChan: + // TSimpleServer.Close called, close the client connection + client.Close() + } + }() + } + return 0, nil +} + +func (p *TSimpleServer) AcceptLoop() error { + for { + closed, err := p.innerAccept() + if err != nil { + return err + } + if closed != 0 { + return nil + } + } +} + +func (p *TSimpleServer) Serve() error { + p.logContext.CompareAndSwap(nil, Pointer(context.Background())) + + err := p.Listen() + if err != nil { + return err + } + p.AcceptLoop() + return nil +} + +func (p *TSimpleServer) Stop() error { + p.mu.Lock() + defer p.mu.Unlock() + + if !p.closed.CompareAndSwap(0, 1) { + // Already closed + return nil + } + p.serverTransport.Interrupt() + + ctx, cancel := context.WithCancel(context.Background()) + go func() { + defer cancel() + p.wg.Wait() + }() + + if ServerStopTimeout > 0 { + timer := time.NewTimer(ServerStopTimeout) + select { + case <-timer.C: + case <-ctx.Done(): + } + close(p.stopChan) + timer.Stop() + } + + <-ctx.Done() + p.stopChan = make(chan struct{}) + return nil +} + +// If err is actually EOF or NOT_OPEN, return nil, otherwise return err as-is. +func treatEOFErrorsAsNil(err error) error { + if err == nil { + return nil + } + if errors.Is(err, io.EOF) { + return nil + } + var te TTransportException + // NOT_OPEN returned by processor.Process is usually caused by client + // abandoning the connection (e.g. client side time out, or just client + // closes connections from the pool because of shutting down). + // Those logs will be very noisy, so suppress those logs as well. + if errors.As(err, &te) && (te.TypeId() == END_OF_FILE || te.TypeId() == NOT_OPEN) { + return nil + } + return err +} + +func (p *TSimpleServer) processRequests(client TTransport) (err error) { + defer func() { + err = treatEOFErrorsAsNil(err) + }() + + processor := p.processorFactory.GetProcessor(client) + inputTransport, err := p.inputTransportFactory.GetTransport(client) + if err != nil { + return err + } + inputProtocol := p.inputProtocolFactory.GetProtocol(inputTransport) + var outputTransport TTransport + var outputProtocol TProtocol + + // for THeaderProtocol, we must use the same protocol instance for + // input and output so that the response is in the same dialect that + // the server detected the request was in. + headerProtocol, ok := inputProtocol.(*THeaderProtocol) + if ok { + outputProtocol = inputProtocol + } else { + oTrans, err := p.outputTransportFactory.GetTransport(client) + if err != nil { + return err + } + outputTransport = oTrans + outputProtocol = p.outputProtocolFactory.GetProtocol(outputTransport) + } + + if inputTransport != nil { + defer inputTransport.Close() + } + if outputTransport != nil { + defer outputTransport.Close() + } + for { + if p.closed.Load() != 0 { + return nil + } + + ctx := SetResponseHelper( + defaultCtx, + TResponseHelper{ + THeaderResponseHelper: NewTHeaderResponseHelper(outputProtocol), + }, + ) + if headerProtocol != nil { + // We need to call ReadFrame here, otherwise we won't + // get any headers on the AddReadTHeaderToContext call. + // + // ReadFrame is safe to be called multiple times so it + // won't break when it's called again later when we + // actually start to read the message. + if err := headerProtocol.ReadFrame(ctx); err != nil { + return err + } + ctx = AddReadTHeaderToContext(ctx, headerProtocol.GetReadHeaders()) + ctx = SetWriteHeaderList(ctx, p.forwardHeaders) + } + + ok, err := processor.Process(ctx, inputProtocol, outputProtocol) + if errors.Is(err, ErrAbandonRequest) { + err := client.Close() + if errors.Is(err, net.ErrClosed) { + // In this case, it's kinda expected to get + // net.ErrClosed, treat that as no-error + return nil + } + return err + } + if errors.As(err, new(TTransportException)) && err != nil { + return err + } + var tae TApplicationException + if errors.As(err, &tae) && tae.TypeId() == UNKNOWN_METHOD { + continue + } + if !ok { + break + } + } + return nil +} + +// ErrAbandonRequest is a special error that server handler implementations can +// return to indicate that the request has been abandoned. +// +// TSimpleServer and compiler generated Process functions will check for this +// error, and close the client connection instead of trying to write the error +// back to the client. +// +// It shall only be used when the server handler implementation know that the +// client already abandoned the request (by checking that the passed in context +// is already canceled, for example). +// +// It also implements the interface defined by errors.Unwrap and always unwrap +// to context.Canceled error. +var ErrAbandonRequest = abandonRequestError{} + +type abandonRequestError struct{} + +func (abandonRequestError) Error() string { + return "request abandoned" +} + +func (abandonRequestError) Unwrap() error { + return context.Canceled +} diff --git a/thrift/slog.go b/thrift/slog.go index 9648cbb..22545d8 100644 --- a/thrift/slog.go +++ b/thrift/slog.go @@ -27,29 +27,29 @@ import ( // SlogTStructWrapper is a wrapper used by the compiler to wrap TStruct and // TException to be better logged by slog. - type SlogTStructWrapper struct { - Type string `json:"type"` - Value TStruct `json:"value"` - } - - var ( - _ fmt.Stringer = SlogTStructWrapper{} - _ json.Marshaler = SlogTStructWrapper{} - ) - - func (w SlogTStructWrapper) MarshalJSON() ([]byte, error) { - // Use an alias to avoid infinite recursion - type alias SlogTStructWrapper - return json.Marshal(alias(w)) - } - - func (w SlogTStructWrapper) String() string { - var sb strings.Builder - sb.WriteString(w.Type) - if err := json.NewEncoder(&sb).Encode(w.Value); err != nil { - // Should not happen, but just in case - return fmt.Sprintf("%s: %v", w.Type, w.Value) - } - // json encoder will write an additional \n at the end, get rid of it - return strings.TrimSuffix(sb.String(), "\n") - } \ No newline at end of file +type SlogTStructWrapper struct { + Type string `json:"type"` + Value TStruct `json:"value"` +} + +var ( + _ fmt.Stringer = SlogTStructWrapper{} + _ json.Marshaler = SlogTStructWrapper{} +) + +func (w SlogTStructWrapper) MarshalJSON() ([]byte, error) { + // Use an alias to avoid infinite recursion + type alias SlogTStructWrapper + return json.Marshal(alias(w)) +} + +func (w SlogTStructWrapper) String() string { + var sb strings.Builder + sb.WriteString(w.Type) + if err := json.NewEncoder(&sb).Encode(w.Value); err != nil { + // Should not happen, but just in case + return fmt.Sprintf("%s: %v", w.Type, w.Value) + } + // json encoder will write an additional \n at the end, get rid of it + return strings.TrimSuffix(sb.String(), "\n") +} diff --git a/thrift/slog_test.go b/thrift/slog_test.go index ace5884..f4155d4 100644 --- a/thrift/slog_test.go +++ b/thrift/slog_test.go @@ -24,14 +24,14 @@ import ( "strings" "testing" ) - - func TestSlogTStructWrapperJSON(t *testing.T) { - // This test just ensures that we don't have infinite recursion when - // json encoding it. More comprehensive tests are under lib/go/test. - v := SlogTStructWrapper{Type: "foo"} - var sb strings.Builder - if err := json.NewEncoder(&sb).Encode(v); err != nil { - t.Fatal(err) - } - t.Log(strings.TrimSuffix(sb.String(), "\n")) - } \ No newline at end of file + +func TestSlogTStructWrapperJSON(t *testing.T) { + // This test just ensures that we don't have infinite recursion when + // json encoding it. More comprehensive tests are under lib/go/test. + v := SlogTStructWrapper{Type: "foo"} + var sb strings.Builder + if err := json.NewEncoder(&sb).Encode(v); err != nil { + t.Fatal(err) + } + t.Log(strings.TrimSuffix(sb.String(), "\n")) +} diff --git a/thrift/socket.go b/thrift/socket.go index 8009810..2185fb1 100644 --- a/thrift/socket.go +++ b/thrift/socket.go @@ -24,218 +24,218 @@ import ( "net" "time" ) - - type TSocket struct { - conn *socketConn - addr net.Addr - cfg *TConfiguration - } - - // tcpAddr is a naive implementation of net.Addr that does nothing extra. - type tcpAddr string - - var _ net.Addr = tcpAddr("") - - func (ta tcpAddr) Network() string { - return "tcp" - } - - func (ta tcpAddr) String() string { - return string(ta) - } - - // Deprecated: Use NewTSocketConf instead. - func NewTSocket(hostPort string) (*TSocket, error) { - return NewTSocketConf(hostPort, &TConfiguration{ - noPropagation: true, - }), nil - } - - // NewTSocketConf creates a net.Conn-backed TTransport, given a host and port. - // - // Example: - // - // trans := thrift.NewTSocketConf("localhost:9090", &TConfiguration{ - // ConnectTimeout: time.Second, // Use 0 for no timeout - // SocketTimeout: time.Second, // Use 0 for no timeout - // }) - func NewTSocketConf(hostPort string, conf *TConfiguration) *TSocket { - return NewTSocketFromAddrConf(tcpAddr(hostPort), conf) - } - - // Deprecated: Use NewTSocketConf instead. - func NewTSocketTimeout(hostPort string, connTimeout time.Duration, soTimeout time.Duration) (*TSocket, error) { - return NewTSocketConf(hostPort, &TConfiguration{ - ConnectTimeout: connTimeout, - SocketTimeout: soTimeout, - - noPropagation: true, - }), nil - } - - // NewTSocketFromAddrConf creates a TSocket from a net.Addr - func NewTSocketFromAddrConf(addr net.Addr, conf *TConfiguration) *TSocket { - return &TSocket{ - addr: addr, - cfg: conf, - } - } - - // Deprecated: Use NewTSocketFromAddrConf instead. - func NewTSocketFromAddrTimeout(addr net.Addr, connTimeout time.Duration, soTimeout time.Duration) *TSocket { - return NewTSocketFromAddrConf(addr, &TConfiguration{ - ConnectTimeout: connTimeout, - SocketTimeout: soTimeout, - - noPropagation: true, - }) - } - - // NewTSocketFromConnConf creates a TSocket from an existing net.Conn. - func NewTSocketFromConnConf(conn net.Conn, conf *TConfiguration) *TSocket { - return &TSocket{ - conn: wrapSocketConn(conn), - addr: conn.RemoteAddr(), - cfg: conf, - } - } - - // Deprecated: Use NewTSocketFromConnConf instead. - func NewTSocketFromConnTimeout(conn net.Conn, socketTimeout time.Duration) *TSocket { - return NewTSocketFromConnConf(conn, &TConfiguration{ - SocketTimeout: socketTimeout, - - noPropagation: true, - }) - } - - // SetTConfiguration implements TConfigurationSetter. - // - // It can be used to set connect and socket timeouts. - func (p *TSocket) SetTConfiguration(conf *TConfiguration) { - p.cfg = conf - } - - // Sets the connect timeout - func (p *TSocket) SetConnTimeout(timeout time.Duration) error { - if p.cfg == nil { - p.cfg = &TConfiguration{ - noPropagation: true, - } - } - p.cfg.ConnectTimeout = timeout - return nil - } - - // Sets the socket timeout - func (p *TSocket) SetSocketTimeout(timeout time.Duration) error { - if p.cfg == nil { - p.cfg = &TConfiguration{ - noPropagation: true, - } - } - p.cfg.SocketTimeout = timeout - return nil - } - - func (p *TSocket) pushDeadline(read, write bool) { - var t time.Time - if timeout := p.cfg.GetSocketTimeout(); timeout > 0 { - t = time.Now().Add(time.Duration(timeout)) - } - if read && write { - p.conn.SetDeadline(t) - } else if read { - p.conn.SetReadDeadline(t) - } else if write { - p.conn.SetWriteDeadline(t) - } - } - - // Connects the socket, creating a new socket object if necessary. - func (p *TSocket) Open() error { - if p.conn.isValid() { - return NewTTransportException(ALREADY_OPEN, "Socket already connected.") - } - if p.addr == nil { - return NewTTransportException(NOT_OPEN, "Cannot open nil address.") - } - if len(p.addr.Network()) == 0 { - return NewTTransportException(NOT_OPEN, "Cannot open bad network name.") - } - if len(p.addr.String()) == 0 { - return NewTTransportException(NOT_OPEN, "Cannot open bad address.") - } - var err error - if p.conn, err = createSocketConnFromReturn(net.DialTimeout( - p.addr.Network(), - p.addr.String(), - p.cfg.GetConnectTimeout(), - )); err != nil { - return &tTransportException{ - typeId: NOT_OPEN, - err: err, - msg: err.Error(), - } - } - p.addr = p.conn.RemoteAddr() - return nil - } - - // Retrieve the underlying net.Conn - func (p *TSocket) Conn() net.Conn { - return p.conn - } - - // Returns true if the connection is open - func (p *TSocket) IsOpen() bool { - return p.conn.IsOpen() - } - - // Closes the socket. - func (p *TSocket) Close() error { - return p.conn.Close() - } - - //Returns the remote address of the socket. - func (p *TSocket) Addr() net.Addr { - return p.addr - } - - func (p *TSocket) Read(buf []byte) (int, error) { - if !p.conn.isValid() { - return 0, NewTTransportException(NOT_OPEN, "Connection not open") - } - p.pushDeadline(true, false) - // NOTE: Calling any of p.IsOpen, p.conn.read0, or p.conn.IsOpen between - // p.pushDeadline and p.conn.Read could cause the deadline set inside - // p.pushDeadline being reset, thus need to be avoided. - n, err := p.conn.Read(buf) - return n, NewTTransportExceptionFromError(err) - } - - func (p *TSocket) Write(buf []byte) (int, error) { - if !p.conn.isValid() { - return 0, NewTTransportException(NOT_OPEN, "Connection not open") - } - p.pushDeadline(false, true) - return p.conn.Write(buf) - } - - func (p *TSocket) Flush(ctx context.Context) error { - return nil - } - - func (p *TSocket) Interrupt() error { - if !p.conn.isValid() { - return nil - } - return p.conn.Close() - } - - func (p *TSocket) RemainingBytes() (num_bytes uint64) { - const maxSize = ^uint64(0) - return maxSize // the truth is, we just don't know unless framed is used - } - - var _ TConfigurationSetter = (*TSocket)(nil) \ No newline at end of file + +type TSocket struct { + conn *socketConn + addr net.Addr + cfg *TConfiguration +} + +// tcpAddr is a naive implementation of net.Addr that does nothing extra. +type tcpAddr string + +var _ net.Addr = tcpAddr("") + +func (ta tcpAddr) Network() string { + return "tcp" +} + +func (ta tcpAddr) String() string { + return string(ta) +} + +// Deprecated: Use NewTSocketConf instead. +func NewTSocket(hostPort string) (*TSocket, error) { + return NewTSocketConf(hostPort, &TConfiguration{ + noPropagation: true, + }), nil +} + +// NewTSocketConf creates a net.Conn-backed TTransport, given a host and port. +// +// Example: +// +// trans := thrift.NewTSocketConf("localhost:9090", &TConfiguration{ +// ConnectTimeout: time.Second, // Use 0 for no timeout +// SocketTimeout: time.Second, // Use 0 for no timeout +// }) +func NewTSocketConf(hostPort string, conf *TConfiguration) *TSocket { + return NewTSocketFromAddrConf(tcpAddr(hostPort), conf) +} + +// Deprecated: Use NewTSocketConf instead. +func NewTSocketTimeout(hostPort string, connTimeout time.Duration, soTimeout time.Duration) (*TSocket, error) { + return NewTSocketConf(hostPort, &TConfiguration{ + ConnectTimeout: connTimeout, + SocketTimeout: soTimeout, + + noPropagation: true, + }), nil +} + +// NewTSocketFromAddrConf creates a TSocket from a net.Addr +func NewTSocketFromAddrConf(addr net.Addr, conf *TConfiguration) *TSocket { + return &TSocket{ + addr: addr, + cfg: conf, + } +} + +// Deprecated: Use NewTSocketFromAddrConf instead. +func NewTSocketFromAddrTimeout(addr net.Addr, connTimeout time.Duration, soTimeout time.Duration) *TSocket { + return NewTSocketFromAddrConf(addr, &TConfiguration{ + ConnectTimeout: connTimeout, + SocketTimeout: soTimeout, + + noPropagation: true, + }) +} + +// NewTSocketFromConnConf creates a TSocket from an existing net.Conn. +func NewTSocketFromConnConf(conn net.Conn, conf *TConfiguration) *TSocket { + return &TSocket{ + conn: wrapSocketConn(conn), + addr: conn.RemoteAddr(), + cfg: conf, + } +} + +// Deprecated: Use NewTSocketFromConnConf instead. +func NewTSocketFromConnTimeout(conn net.Conn, socketTimeout time.Duration) *TSocket { + return NewTSocketFromConnConf(conn, &TConfiguration{ + SocketTimeout: socketTimeout, + + noPropagation: true, + }) +} + +// SetTConfiguration implements TConfigurationSetter. +// +// It can be used to set connect and socket timeouts. +func (p *TSocket) SetTConfiguration(conf *TConfiguration) { + p.cfg = conf +} + +// Sets the connect timeout +func (p *TSocket) SetConnTimeout(timeout time.Duration) error { + if p.cfg == nil { + p.cfg = &TConfiguration{ + noPropagation: true, + } + } + p.cfg.ConnectTimeout = timeout + return nil +} + +// Sets the socket timeout +func (p *TSocket) SetSocketTimeout(timeout time.Duration) error { + if p.cfg == nil { + p.cfg = &TConfiguration{ + noPropagation: true, + } + } + p.cfg.SocketTimeout = timeout + return nil +} + +func (p *TSocket) pushDeadline(read, write bool) { + var t time.Time + if timeout := p.cfg.GetSocketTimeout(); timeout > 0 { + t = time.Now().Add(time.Duration(timeout)) + } + if read && write { + p.conn.SetDeadline(t) + } else if read { + p.conn.SetReadDeadline(t) + } else if write { + p.conn.SetWriteDeadline(t) + } +} + +// Connects the socket, creating a new socket object if necessary. +func (p *TSocket) Open() error { + if p.conn.isValid() { + return NewTTransportException(ALREADY_OPEN, "Socket already connected.") + } + if p.addr == nil { + return NewTTransportException(NOT_OPEN, "Cannot open nil address.") + } + if len(p.addr.Network()) == 0 { + return NewTTransportException(NOT_OPEN, "Cannot open bad network name.") + } + if len(p.addr.String()) == 0 { + return NewTTransportException(NOT_OPEN, "Cannot open bad address.") + } + var err error + if p.conn, err = createSocketConnFromReturn(net.DialTimeout( + p.addr.Network(), + p.addr.String(), + p.cfg.GetConnectTimeout(), + )); err != nil { + return &tTransportException{ + typeId: NOT_OPEN, + err: err, + msg: err.Error(), + } + } + p.addr = p.conn.RemoteAddr() + return nil +} + +// Retrieve the underlying net.Conn +func (p *TSocket) Conn() net.Conn { + return p.conn +} + +// Returns true if the connection is open +func (p *TSocket) IsOpen() bool { + return p.conn.IsOpen() +} + +// Closes the socket. +func (p *TSocket) Close() error { + return p.conn.Close() +} + +//Returns the remote address of the socket. +func (p *TSocket) Addr() net.Addr { + return p.addr +} + +func (p *TSocket) Read(buf []byte) (int, error) { + if !p.conn.isValid() { + return 0, NewTTransportException(NOT_OPEN, "Connection not open") + } + p.pushDeadline(true, false) + // NOTE: Calling any of p.IsOpen, p.conn.read0, or p.conn.IsOpen between + // p.pushDeadline and p.conn.Read could cause the deadline set inside + // p.pushDeadline being reset, thus need to be avoided. + n, err := p.conn.Read(buf) + return n, NewTTransportExceptionFromError(err) +} + +func (p *TSocket) Write(buf []byte) (int, error) { + if !p.conn.isValid() { + return 0, NewTTransportException(NOT_OPEN, "Connection not open") + } + p.pushDeadline(false, true) + return p.conn.Write(buf) +} + +func (p *TSocket) Flush(ctx context.Context) error { + return nil +} + +func (p *TSocket) Interrupt() error { + if !p.conn.isValid() { + return nil + } + return p.conn.Close() +} + +func (p *TSocket) RemainingBytes() (num_bytes uint64) { + const maxSize = ^uint64(0) + return maxSize // the truth is, we just don't know unless framed is used +} + +var _ TConfigurationSetter = (*TSocket)(nil) diff --git a/thrift/ssl_server_socket.go b/thrift/ssl_server_socket.go index 907afca..3f05ad9 100644 --- a/thrift/ssl_server_socket.go +++ b/thrift/ssl_server_socket.go @@ -93,6 +93,9 @@ func (p *TSSLServerSocket) Open() error { } func (p *TSSLServerSocket) Addr() net.Addr { + if p.listener != nil { + return p.listener.Addr() + } return p.addr } diff --git a/thrift/transport_test.go b/thrift/transport_test.go index 309cc28..b6263b8 100644 --- a/thrift/transport_test.go +++ b/thrift/transport_test.go @@ -36,7 +36,7 @@ var ( func init() { transport_bdata = make([]byte, TRANSPORT_BINARY_DATA_SIZE) - for i := 0; i < TRANSPORT_BINARY_DATA_SIZE; i++ { + for i := range TRANSPORT_BINARY_DATA_SIZE { transport_bdata[i] = byte((i + 'a') % 255) } transport_header = map[string]string{"key": "User-Agent", diff --git a/thrift/zlib_transport.go b/thrift/zlib_transport.go index bee71a2..cefe1f9 100644 --- a/thrift/zlib_transport.go +++ b/thrift/zlib_transport.go @@ -26,112 +26,112 @@ import ( ) // TZlibTransportFactory is a factory for TZlibTransport instances - type TZlibTransportFactory struct { - level int - factory TTransportFactory - } - - // TZlibTransport is a TTransport implementation that makes use of zlib compression. - type TZlibTransport struct { - reader io.ReadCloser - transport TTransport - writer *zlib.Writer - } - - // GetTransport constructs a new instance of NewTZlibTransport - func (p *TZlibTransportFactory) GetTransport(trans TTransport) (TTransport, error) { - if p.factory != nil { - // wrap other factory - var err error - trans, err = p.factory.GetTransport(trans) - if err != nil { - return nil, err - } - } - return NewTZlibTransport(trans, p.level) - } - - // NewTZlibTransportFactory constructs a new instance of NewTZlibTransportFactory - func NewTZlibTransportFactory(level int) *TZlibTransportFactory { - return &TZlibTransportFactory{level: level, factory: nil} - } - - // NewTZlibTransportFactoryWithFactory constructs a new instance of TZlibTransportFactory - // as a wrapper over existing transport factory - func NewTZlibTransportFactoryWithFactory(level int, factory TTransportFactory) *TZlibTransportFactory { - return &TZlibTransportFactory{level: level, factory: factory} - } - - // NewTZlibTransport constructs a new instance of TZlibTransport - func NewTZlibTransport(trans TTransport, level int) (*TZlibTransport, error) { - w, err := zlib.NewWriterLevel(trans, level) - if err != nil { - return nil, err - } - - return &TZlibTransport{ - writer: w, - transport: trans, - }, nil - } - - // Close closes the reader and writer (flushing any unwritten data) and closes - // the underlying transport. - func (z *TZlibTransport) Close() error { - if z.reader != nil { - if err := z.reader.Close(); err != nil { - return err - } - } - if err := z.writer.Close(); err != nil { - return err - } - return z.transport.Close() - } - - // Flush flushes the writer and its underlying transport. - func (z *TZlibTransport) Flush(ctx context.Context) error { - if err := z.writer.Flush(); err != nil { - return err - } - return z.transport.Flush(ctx) - } - - // IsOpen returns true if the transport is open - func (z *TZlibTransport) IsOpen() bool { - return z.transport.IsOpen() - } - - // Open opens the transport for communication - func (z *TZlibTransport) Open() error { - return z.transport.Open() - } - - func (z *TZlibTransport) Read(p []byte) (int, error) { - if z.reader == nil { - r, err := zlib.NewReader(z.transport) - if err != nil { - return 0, NewTTransportExceptionFromError(err) - } - z.reader = r - } - - return z.reader.Read(p) - } - - // RemainingBytes returns the size in bytes of the data that is still to be - // read. - func (z *TZlibTransport) RemainingBytes() uint64 { - return z.transport.RemainingBytes() - } - - func (z *TZlibTransport) Write(p []byte) (int, error) { - return z.writer.Write(p) - } - - // SetTConfiguration implements TConfigurationSetter for propagation. - func (z *TZlibTransport) SetTConfiguration(conf *TConfiguration) { - PropagateTConfiguration(z.transport, conf) - } - - var _ TConfigurationSetter = (*TZlibTransport)(nil) \ No newline at end of file +type TZlibTransportFactory struct { + level int + factory TTransportFactory +} + +// TZlibTransport is a TTransport implementation that makes use of zlib compression. +type TZlibTransport struct { + reader io.ReadCloser + transport TTransport + writer *zlib.Writer +} + +// GetTransport constructs a new instance of NewTZlibTransport +func (p *TZlibTransportFactory) GetTransport(trans TTransport) (TTransport, error) { + if p.factory != nil { + // wrap other factory + var err error + trans, err = p.factory.GetTransport(trans) + if err != nil { + return nil, err + } + } + return NewTZlibTransport(trans, p.level) +} + +// NewTZlibTransportFactory constructs a new instance of NewTZlibTransportFactory +func NewTZlibTransportFactory(level int) *TZlibTransportFactory { + return &TZlibTransportFactory{level: level, factory: nil} +} + +// NewTZlibTransportFactoryWithFactory constructs a new instance of TZlibTransportFactory +// as a wrapper over existing transport factory +func NewTZlibTransportFactoryWithFactory(level int, factory TTransportFactory) *TZlibTransportFactory { + return &TZlibTransportFactory{level: level, factory: factory} +} + +// NewTZlibTransport constructs a new instance of TZlibTransport +func NewTZlibTransport(trans TTransport, level int) (*TZlibTransport, error) { + w, err := zlib.NewWriterLevel(trans, level) + if err != nil { + return nil, err + } + + return &TZlibTransport{ + writer: w, + transport: trans, + }, nil +} + +// Close closes the reader and writer (flushing any unwritten data) and closes +// the underlying transport. +func (z *TZlibTransport) Close() error { + if z.reader != nil { + if err := z.reader.Close(); err != nil { + return err + } + } + if err := z.writer.Close(); err != nil { + return err + } + return z.transport.Close() +} + +// Flush flushes the writer and its underlying transport. +func (z *TZlibTransport) Flush(ctx context.Context) error { + if err := z.writer.Flush(); err != nil { + return err + } + return z.transport.Flush(ctx) +} + +// IsOpen returns true if the transport is open +func (z *TZlibTransport) IsOpen() bool { + return z.transport.IsOpen() +} + +// Open opens the transport for communication +func (z *TZlibTransport) Open() error { + return z.transport.Open() +} + +func (z *TZlibTransport) Read(p []byte) (int, error) { + if z.reader == nil { + r, err := zlib.NewReader(z.transport) + if err != nil { + return 0, NewTTransportExceptionFromError(err) + } + z.reader = r + } + + return z.reader.Read(p) +} + +// RemainingBytes returns the size in bytes of the data that is still to be +// read. +func (z *TZlibTransport) RemainingBytes() uint64 { + return z.transport.RemainingBytes() +} + +func (z *TZlibTransport) Write(p []byte) (int, error) { + return z.writer.Write(p) +} + +// SetTConfiguration implements TConfigurationSetter for propagation. +func (z *TZlibTransport) SetTConfiguration(conf *TConfiguration) { + PropagateTConfiguration(z.transport, conf) +} + +var _ TConfigurationSetter = (*TZlibTransport)(nil)