Skip to content

Commit 5695e3a

Browse files
committed
feat: Auth support
- Added support for Basic and API token Waiting on amikos-tech/chromadb-chart#39 to be implemented to add the integration tests. Refs: #2
1 parent a873a3b commit 5695e3a

File tree

2 files changed

+226
-3
lines changed

2 files changed

+226
-3
lines changed

chroma.go

+120-2
Original file line numberDiff line numberDiff line change
@@ -58,11 +58,129 @@ type Client struct {
5858
ApiClient *openapiclient.APIClient //nolint
5959
}
6060

61-
func NewClient(basePath string) *Client {
61+
type AuthType string
62+
63+
const (
64+
BASIC AuthType = "basic"
65+
TokenAuthorization AuthType = "authorization"
66+
TokenXChromaToken AuthType = "xchromatoken"
67+
)
68+
69+
type AuthMethod interface {
70+
GetCredentials() map[string]string
71+
GetType() AuthType
72+
}
73+
74+
type BasicAuth struct {
75+
Username string
76+
Password string
77+
}
78+
79+
func (b BasicAuth) GetCredentials() map[string]string {
80+
return map[string]string{
81+
"username": b.Username,
82+
"password": b.Password,
83+
}
84+
}
85+
86+
func (b BasicAuth) GetType() AuthType {
87+
return BASIC
88+
}
89+
90+
func NewBasicAuth(username string, password string) ClientAuthCredentials {
91+
return ClientAuthCredentials{
92+
AuthMethod: BasicAuth{
93+
Username: username,
94+
Password: password,
95+
},
96+
}
97+
}
98+
99+
type AuthorizationTokenAuth struct {
100+
Token string
101+
}
102+
103+
func (t AuthorizationTokenAuth) GetType() AuthType {
104+
return TokenAuthorization
105+
}
106+
107+
func (t AuthorizationTokenAuth) GetCredentials() map[string]string {
108+
return map[string]string{
109+
"Authorization": "Bearer " + t.Token,
110+
}
111+
}
112+
113+
type XChromaTokenAuth struct {
114+
Token string
115+
}
116+
117+
func (t XChromaTokenAuth) GetType() AuthType {
118+
return TokenXChromaToken
119+
}
120+
121+
func (t XChromaTokenAuth) GetCredentials() map[string]string {
122+
return map[string]string{
123+
"X-Chroma-Token": t.Token,
124+
}
125+
}
126+
127+
type ClientAuthCredentials struct {
128+
AuthMethod AuthMethod
129+
}
130+
131+
func NewTokenAuth(token string, authType AuthType) ClientAuthCredentials {
132+
switch {
133+
case authType == TokenAuthorization:
134+
return ClientAuthCredentials{
135+
AuthMethod: AuthorizationTokenAuth{
136+
Token: token,
137+
},
138+
}
139+
case authType == TokenXChromaToken:
140+
return ClientAuthCredentials{
141+
AuthMethod: XChromaTokenAuth{
142+
Token: token,
143+
},
144+
}
145+
default:
146+
panic("Invalid auth type")
147+
}
148+
}
149+
150+
type ClientConfig struct {
151+
BasePath string
152+
DefaultHeaders *map[string]string
153+
ClientAuthCredentials *ClientAuthCredentials
154+
}
155+
156+
func NewClientConfig(basePath string, defaultHeaders *map[string]string, clientAuthCredentials *ClientAuthCredentials) ClientConfig {
157+
return ClientConfig{
158+
BasePath: basePath,
159+
DefaultHeaders: defaultHeaders,
160+
ClientAuthCredentials: clientAuthCredentials,
161+
}
162+
}
163+
164+
func NewClient(config ClientConfig) *Client {
62165
configuration := openapiclient.NewConfiguration()
166+
if config.ClientAuthCredentials != nil {
167+
// combine config.DefaultHeaders and config.AuthMethod.GetCredentials() maps
168+
var headers = make(map[string]string)
169+
if config.DefaultHeaders != nil {
170+
for k, v := range *config.DefaultHeaders {
171+
headers[k] = v
172+
}
173+
}
174+
for k, v := range config.ClientAuthCredentials.AuthMethod.GetCredentials() {
175+
headers[k] = v
176+
}
177+
configuration.DefaultHeader = headers
178+
} else if config.DefaultHeaders != nil {
179+
configuration.DefaultHeader = *config.DefaultHeaders
180+
}
63181
configuration.Servers = openapiclient.ServerConfigurations{
64182
{
65-
URL: basePath,
183+
URL: config.BasePath,
66184
Description: "No description provided",
67185
},
68186
}

test/chroma_client_test.go

+106-1
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,9 @@ func Test_chroma_client(t *testing.T) {
2626
if chromaURL == "" {
2727
chromaURL = "http://localhost:8000"
2828
}
29-
client := chroma.NewClient(chromaURL)
29+
30+
clientConfig := chroma.NewClientConfig(chromaURL, nil, nil)
31+
client := chroma.NewClient(clientConfig)
3032

3133
t.Run("Test Heartbeat", func(t *testing.T) {
3234
resp, err := client.Heartbeat()
@@ -746,3 +748,106 @@ func Test_chroma_client(t *testing.T) {
746748
require.Nil(t, addError)
747749
})
748750
}
751+
752+
func Test_chroma_client_with_basic(t *testing.T) {
753+
chromaURL := os.Getenv("CHROMA_URL")
754+
if chromaURL == "" {
755+
chromaURL = "http://localhost:8003"
756+
}
757+
clientAuth := chroma.NewBasicAuth("test", "test")
758+
759+
clientConfig := chroma.NewClientConfig(chromaURL, nil, &clientAuth)
760+
client := chroma.NewClient(clientConfig)
761+
762+
t.Run("Test Heartbeat", func(t *testing.T) {
763+
resp, err := client.Heartbeat()
764+
765+
require.Nil(t, err)
766+
require.NotNil(t, resp)
767+
assert.Truef(t, resp["nanosecond heartbeat"] > 0, "Heartbeat should be greater than 0")
768+
})
769+
}
770+
771+
func Test_chroma_client_with_authorization_token(t *testing.T) {
772+
chromaURL := os.Getenv("CHROMA_URL")
773+
if chromaURL == "" {
774+
chromaURL = "http://localhost:8001"
775+
}
776+
clientAuth := chroma.NewTokenAuth("test", chroma.TokenAuthorization)
777+
778+
clientConfig := chroma.NewClientConfig(chromaURL, nil, &clientAuth)
779+
client := chroma.NewClient(clientConfig)
780+
781+
t.Run("Test List Collections", func(t *testing.T) {
782+
collectionName1 := "test-collection1"
783+
collectionName2 := "test-collection2"
784+
metadata := map[string]string{}
785+
apiKey := os.Getenv("OPENAI_API_KEY")
786+
if apiKey == "" {
787+
err := godotenv.Load("../.env")
788+
if err != nil {
789+
assert.Failf(t, "Error loading .env file", "%s", err)
790+
}
791+
apiKey = os.Getenv("OPENAI_API_KEY")
792+
}
793+
embeddingFunction := openai.NewOpenAIEmbeddingFunction(apiKey)
794+
distanceFunction := chroma.L2
795+
_, errRest := client.Reset()
796+
if errRest != nil {
797+
assert.Fail(t, fmt.Sprintf("Error resetting database: %s", errRest))
798+
}
799+
_, _ = client.CreateCollection(collectionName1, chroma.MapToAPI(metadata), true, embeddingFunction, distanceFunction)
800+
_, _ = client.CreateCollection(collectionName2, chroma.MapToAPI(metadata), true, embeddingFunction, distanceFunction)
801+
collections, gcerr := client.ListCollections()
802+
require.Nil(t, gcerr)
803+
assert.Equal(t, 2, len(collections))
804+
names := make([]string, len(collections))
805+
for i, person := range collections {
806+
names[i] = person.Name
807+
}
808+
assert.Contains(t, names, collectionName1)
809+
assert.Contains(t, names, collectionName2)
810+
})
811+
}
812+
813+
func Test_chroma_client_with_x_token(t *testing.T) {
814+
chromaURL := os.Getenv("CHROMA_URL")
815+
if chromaURL == "" {
816+
chromaURL = "http://localhost:8002"
817+
}
818+
clientAuth := chroma.NewTokenAuth("test", chroma.TokenXChromaToken)
819+
820+
clientConfig := chroma.NewClientConfig(chromaURL, nil, &clientAuth)
821+
client := chroma.NewClient(clientConfig)
822+
823+
t.Run("Test List Collections", func(t *testing.T) {
824+
collectionName1 := "test-collection1"
825+
collectionName2 := "test-collection2"
826+
metadata := map[string]string{}
827+
apiKey := os.Getenv("OPENAI_API_KEY")
828+
if apiKey == "" {
829+
err := godotenv.Load("../.env")
830+
if err != nil {
831+
assert.Failf(t, "Error loading .env file", "%s", err)
832+
}
833+
apiKey = os.Getenv("OPENAI_API_KEY")
834+
}
835+
embeddingFunction := openai.NewOpenAIEmbeddingFunction(apiKey)
836+
distanceFunction := chroma.L2
837+
_, errRest := client.Reset()
838+
if errRest != nil {
839+
assert.Fail(t, fmt.Sprintf("Error resetting database: %s", errRest))
840+
}
841+
_, _ = client.CreateCollection(collectionName1, chroma.MapToAPI(metadata), true, embeddingFunction, distanceFunction)
842+
_, _ = client.CreateCollection(collectionName2, chroma.MapToAPI(metadata), true, embeddingFunction, distanceFunction)
843+
collections, gcerr := client.ListCollections()
844+
require.Nil(t, gcerr)
845+
assert.Equal(t, 2, len(collections))
846+
names := make([]string, len(collections))
847+
for i, person := range collections {
848+
names[i] = person.Name
849+
}
850+
assert.Contains(t, names, collectionName1)
851+
assert.Contains(t, names, collectionName2)
852+
})
853+
}

0 commit comments

Comments
 (0)