diff --git a/dialer.go b/dialer.go index a1bb5a3..5b952bf 100644 --- a/dialer.go +++ b/dialer.go @@ -66,9 +66,10 @@ func (d *Dialer) Dial(ctx context.Context, addr string) (Driver, error) { timeout: d.Timeout, config: config, meta: &meta{ - trace: config.Trace, - database: config.Database, - credentials: config.Credentials, + trace: config.Trace, + database: config.Database, + credentials: config.Credentials, + requestsType: config.RequestsType, }, }).dial(ctx, addr) } diff --git a/driver.go b/driver.go index d8e57cd..0b1f5e3 100644 --- a/driver.go +++ b/driver.go @@ -149,6 +149,10 @@ type DriverConfig struct { // is, currently this option may be called as experimental. // You have been warned. PreferLocalEndpoints bool + + // RequestsType set an additional type hint to all requests. + // It is needed only for debug purposes and advanced cases. + RequestsType string } func (d *DriverConfig) withDefaults() (c DriverConfig) { diff --git a/meta.go b/meta.go index bd254c0..48ae4ab 100644 --- a/meta.go +++ b/meta.go @@ -8,15 +8,18 @@ import ( ) const ( - metaDatabase = "x-ydb-database" - metaTicket = "x-ydb-auth-ticket" - metaVersion = "x-ydb-sdk-build-info" + metaDatabase = "x-ydb-database" + metaTicket = "x-ydb-auth-ticket" + metaVersion = "x-ydb-sdk-build-info" + metaRequestType = "x-ydb-request-type" + metaTeraceID = "x-ydb-trace-id" ) type meta struct { - trace DriverTrace - credentials Credentials - database string + trace DriverTrace + credentials Credentials + database string + requestsType string once sync.Once mu sync.RWMutex @@ -25,10 +28,17 @@ type meta struct { } func (m *meta) make() metadata.MD { - return metadata.New(map[string]string{ + newMeta := metadata.New(map[string]string{ metaDatabase: m.database, metaVersion: Version, }) + if m.requestsType != "" { + newMeta.Set(metaRequestType, m.requestsType) + } + if m.token != "" { + newMeta.Set(metaTicket, m.token) + } + return newMeta } func (m *meta) md(ctx context.Context) (md metadata.MD, _ error) { @@ -52,6 +62,9 @@ func (m *meta) md(ctx context.Context) (md metadata.MD, _ error) { // Continue. case ErrCredentialsDropToken: + m.mu.Lock() + defer m.mu.Unlock() + m.token = "" return m.make(), nil case ErrCredentialsKeepToken: @@ -77,10 +90,6 @@ func (m *meta) md(ctx context.Context) (md metadata.MD, _ error) { } m.token = token - m.curr = make(metadata.MD, 3) - m.curr.Set(metaDatabase, m.database) - m.curr.Set(metaTicket, m.token) - m.curr.Set(metaVersion, Version) - + m.curr = m.make() return m.curr, nil } diff --git a/meta_test.go b/meta_test.go index 3834651..ad63b95 100644 --- a/meta_test.go +++ b/meta_test.go @@ -10,7 +10,8 @@ import ( func TestMetaErrDropToken(t *testing.T) { var call int m := &meta{ - database: "database", + database: "database", + requestsType: "requestType", credentials: CredentialsFunc(func(context.Context) (string, error) { if call == 0 { call++ @@ -26,6 +27,7 @@ func TestMetaErrDropToken(t *testing.T) { } assertMetaHasDatabase(t, md1) assertMetaHasToken(t, md1) + assertMetaHasRequestType(t, md1) md2, err := m.md(context.Background()) if err != nil { @@ -33,12 +35,14 @@ func TestMetaErrDropToken(t *testing.T) { } assertMetaHasDatabase(t, md2) assertMetaHasNoToken(t, md2) + assertMetaHasRequestType(t, md2) } func TestMetaErrKeepToken(t *testing.T) { var call int m := &meta{ - database: "database", + database: "database", + requestsType: "requestType", credentials: CredentialsFunc(func(context.Context) (string, error) { if call == 0 { call++ @@ -54,13 +58,15 @@ func TestMetaErrKeepToken(t *testing.T) { } assertMetaHasDatabase(t, md1) assertMetaHasToken(t, md1) + assertMetaHasRequestType(t, md1) md2, err := m.md(context.Background()) if err != nil { t.Fatal(err) } assertMetaHasDatabase(t, md2) - assertMetaHasToken(t, md1) + assertMetaHasToken(t, md2) + assertMetaHasRequestType(t, md2) } func assertMetaHasDatabase(t *testing.T, md metadata.MD) { @@ -78,3 +84,8 @@ func assertMetaHasNoToken(t *testing.T, md metadata.MD) { t.Errorf("unexpected token info in meta") } } +func assertMetaHasRequestType(t *testing.T, md metadata.MD) { + if len(md.Get(metaRequestType)) == 0 { + t.Errorf("no request type info in meta") + } +}