Skip to content

Commit bda5f60

Browse files
committed
feat: add ability to pass request-specific env vars to chat completion
This enables per-request authentication in model providers. Signed-off-by: Donnie Adams <[email protected]>
1 parent 50489f2 commit bda5f60

File tree

11 files changed

+54
-52
lines changed

11 files changed

+54
-52
lines changed

go.mod

+1-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ require (
1515
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510
1616
github.com/google/uuid v1.6.0
1717
github.com/gptscript-ai/broadcaster v0.0.0-20240625175512-c43682019b86
18-
github.com/gptscript-ai/chat-completion-client v0.0.0-20240813051153-a440ada7e3c3
18+
github.com/gptscript-ai/chat-completion-client v0.0.0-20241104122544-5fe75f07c131
1919
github.com/gptscript-ai/cmd v0.0.0-20240802230653-326b7baf6fcb
2020
github.com/gptscript-ai/go-gptscript v0.9.5-rc5.0.20240927213153-2af51434b93e
2121
github.com/gptscript-ai/tui v0.0.0-20240923192013-172e51ccf1d6

go.sum

+2-2
Original file line numberDiff line numberDiff line change
@@ -200,8 +200,8 @@ github.com/gorilla/css v1.0.0 h1:BQqNyPTi50JCFMTw/b67hByjMVXZRwGha6wxVGkeihY=
200200
github.com/gorilla/css v1.0.0/go.mod h1:Dn721qIggHpt4+EFCcTLTU/vk5ySda2ReITrtgBl60c=
201201
github.com/gptscript-ai/broadcaster v0.0.0-20240625175512-c43682019b86 h1:m9yLtIEd0z1ia8qFjq3u0Ozb6QKwidyL856JLJp6nbA=
202202
github.com/gptscript-ai/broadcaster v0.0.0-20240625175512-c43682019b86/go.mod h1:lK3K5EZx4dyT24UG3yCt0wmspkYqrj4D/8kxdN3relk=
203-
github.com/gptscript-ai/chat-completion-client v0.0.0-20240813051153-a440ada7e3c3 h1:EQiFTZv+BnOWJX2B9XdF09fL2Zj7h19n1l23TpWCafc=
204-
github.com/gptscript-ai/chat-completion-client v0.0.0-20240813051153-a440ada7e3c3/go.mod h1:7P/o6/IWa1KqsntVf68hSnLKuu3+xuqm6lYhch1w4jo=
203+
github.com/gptscript-ai/chat-completion-client v0.0.0-20241104122544-5fe75f07c131 h1:y2FcmT4X8U606gUS0teX5+JWX9K/NclsLEhHiyrd+EU=
204+
github.com/gptscript-ai/chat-completion-client v0.0.0-20241104122544-5fe75f07c131/go.mod h1:7P/o6/IWa1KqsntVf68hSnLKuu3+xuqm6lYhch1w4jo=
205205
github.com/gptscript-ai/cmd v0.0.0-20240802230653-326b7baf6fcb h1:ky2J2CzBOskC7Jgm2VJAQi2x3p7FVGa+2/PcywkFJuc=
206206
github.com/gptscript-ai/cmd v0.0.0-20240802230653-326b7baf6fcb/go.mod h1:DJAo1xTht1LDkNYFNydVjTHd576TC7MlpsVRl3oloVw=
207207
github.com/gptscript-ai/go-gptscript v0.9.5-rc5.0.20240927213153-2af51434b93e h1:WpNae0NBx+Ri8RB3SxF8DhadDKU7h+jfWPQterDpbJA=

pkg/context/context.go

-11
Original file line numberDiff line numberDiff line change
@@ -46,14 +46,3 @@ func GetLogger(ctx context.Context) mvl.Logger {
4646

4747
return l
4848
}
49-
50-
type envKey struct{}
51-
52-
func WithEnv(ctx context.Context, env []string) context.Context {
53-
return context.WithValue(ctx, envKey{}, env)
54-
}
55-
56-
func GetEnv(ctx context.Context) []string {
57-
l, _ := ctx.Value(envKey{}).([]string)
58-
return l
59-
}

pkg/engine/engine.go

+2-3
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,13 @@ import (
88
"sync"
99

1010
"github.com/gptscript-ai/gptscript/pkg/config"
11-
gcontext "github.com/gptscript-ai/gptscript/pkg/context"
1211
"github.com/gptscript-ai/gptscript/pkg/counter"
1312
"github.com/gptscript-ai/gptscript/pkg/types"
1413
"github.com/gptscript-ai/gptscript/pkg/version"
1514
)
1615

1716
type Model interface {
18-
Call(ctx context.Context, messageRequest types.CompletionRequest, status chan<- types.CompletionStatus) (*types.CompletionMessage, error)
17+
Call(ctx context.Context, messageRequest types.CompletionRequest, env []string, status chan<- types.CompletionStatus) (*types.CompletionMessage, error)
1918
ProxyInfo() (string, string, error)
2019
}
2120

@@ -389,7 +388,7 @@ func (e *Engine) complete(ctx context.Context, state *State) (*Return, error) {
389388
}
390389
}()
391390

392-
resp, err := e.Model.Call(gcontext.WithEnv(ctx, e.Env), state.Completion, progress)
391+
resp, err := e.Model.Call(ctx, state.Completion, e.Env, progress)
393392
if err != nil {
394393
return nil, err
395394
}

pkg/llm/proxy.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ func (r *Registry) ServeHTTP(w http.ResponseWriter, req *http.Request) {
5454

5555
var (
5656
model string
57-
data = map[string]any{}
57+
data map[string]any
5858
)
5959

6060
if json.Unmarshal(inBytes, &data) == nil {
@@ -65,7 +65,7 @@ func (r *Registry) ServeHTTP(w http.ResponseWriter, req *http.Request) {
6565
model = builtin.GetDefaultModel()
6666
}
6767

68-
c, err := r.getClient(req.Context(), model)
68+
c, err := r.getClient(req.Context(), model, nil)
6969
if err != nil {
7070
http.Error(w, err.Error(), http.StatusInternalServerError)
7171
return

pkg/llm/registry.go

+8-8
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ import (
1515
)
1616

1717
type Client interface {
18-
Call(ctx context.Context, messageRequest types.CompletionRequest, status chan<- types.CompletionStatus) (*types.CompletionMessage, error)
18+
Call(ctx context.Context, messageRequest types.CompletionRequest, env []string, status chan<- types.CompletionStatus) (*types.CompletionMessage, error)
1919
ListModels(ctx context.Context, providers ...string) (result []string, _ error)
2020
Supports(ctx context.Context, modelName string) (bool, error)
2121
}
@@ -78,7 +78,7 @@ func (r *Registry) fastPath(modelName string) Client {
7878
return r.clients[0]
7979
}
8080

81-
func (r *Registry) getClient(ctx context.Context, modelName string) (Client, error) {
81+
func (r *Registry) getClient(ctx context.Context, modelName string, env []string) (Client, error) {
8282
if c := r.fastPath(modelName); c != nil {
8383
return c, nil
8484
}
@@ -101,7 +101,7 @@ func (r *Registry) getClient(ctx context.Context, modelName string) (Client, err
101101

102102
if len(errs) > 0 && oaiClient != nil {
103103
// Prompt the user to enter their OpenAI API key and try again.
104-
if err := oaiClient.RetrieveAPIKey(ctx); err != nil {
104+
if err := oaiClient.RetrieveAPIKey(ctx, env); err != nil {
105105
return nil, err
106106
}
107107
ok, err := oaiClient.Supports(ctx, modelName)
@@ -119,13 +119,13 @@ func (r *Registry) getClient(ctx context.Context, modelName string) (Client, err
119119
return nil, errors.Join(errs...)
120120
}
121121

122-
func (r *Registry) Call(ctx context.Context, messageRequest types.CompletionRequest, status chan<- types.CompletionStatus) (*types.CompletionMessage, error) {
122+
func (r *Registry) Call(ctx context.Context, messageRequest types.CompletionRequest, env []string, status chan<- types.CompletionStatus) (*types.CompletionMessage, error) {
123123
if messageRequest.Model == "" {
124124
return nil, fmt.Errorf("model is required")
125125
}
126126

127127
if c := r.fastPath(messageRequest.Model); c != nil {
128-
return c.Call(ctx, messageRequest, status)
128+
return c.Call(ctx, messageRequest, env, status)
129129
}
130130

131131
var errs []error
@@ -140,20 +140,20 @@ func (r *Registry) Call(ctx context.Context, messageRequest types.CompletionRequ
140140

141141
errs = append(errs, err)
142142
} else if ok {
143-
return client.Call(ctx, messageRequest, status)
143+
return client.Call(ctx, messageRequest, env, status)
144144
}
145145
}
146146

147147
if len(errs) > 0 && oaiClient != nil {
148148
// Prompt the user to enter their OpenAI API key and try again.
149-
if err := oaiClient.RetrieveAPIKey(ctx); err != nil {
149+
if err := oaiClient.RetrieveAPIKey(ctx, env); err != nil {
150150
return nil, err
151151
}
152152
ok, err := oaiClient.Supports(ctx, messageRequest.Model)
153153
if err != nil {
154154
return nil, err
155155
} else if ok {
156-
return oaiClient.Call(ctx, messageRequest, status)
156+
return oaiClient.Call(ctx, messageRequest, env, status)
157157
}
158158
}
159159

pkg/openai/client.go

+27-12
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ import (
1313

1414
openai "github.com/gptscript-ai/chat-completion-client"
1515
"github.com/gptscript-ai/gptscript/pkg/cache"
16-
gcontext "github.com/gptscript-ai/gptscript/pkg/context"
1716
"github.com/gptscript-ai/gptscript/pkg/counter"
1817
"github.com/gptscript-ai/gptscript/pkg/credentials"
1918
"github.com/gptscript-ai/gptscript/pkg/hash"
@@ -303,9 +302,9 @@ func toMessages(request types.CompletionRequest, compat bool) (result []openai.C
303302
return
304303
}
305304

306-
func (c *Client) Call(ctx context.Context, messageRequest types.CompletionRequest, status chan<- types.CompletionStatus) (*types.CompletionMessage, error) {
305+
func (c *Client) Call(ctx context.Context, messageRequest types.CompletionRequest, env []string, status chan<- types.CompletionStatus) (*types.CompletionMessage, error) {
307306
if err := c.ValidAuth(); err != nil {
308-
if err := c.RetrieveAPIKey(ctx); err != nil {
307+
if err := c.RetrieveAPIKey(ctx, env); err != nil {
309308
return nil, err
310309
}
311310
}
@@ -401,15 +400,15 @@ func (c *Client) Call(ctx context.Context, messageRequest types.CompletionReques
401400
if err != nil {
402401
return nil, err
403402
} else if !ok {
404-
result, err = c.call(ctx, request, id, status)
403+
result, err = c.call(ctx, request, id, env, status)
405404

406405
// If we got back a context length exceeded error, keep retrying and shrinking the message history until we pass.
407406
var apiError *openai.APIError
408407
if errors.As(err, &apiError) && apiError.Code == "context_length_exceeded" && messageRequest.Chat {
409408
// Decrease maxTokens by 10% to make garbage collection more aggressive.
410409
// The retry loop will further decrease maxTokens if needed.
411410
maxTokens := decreaseTenPercent(messageRequest.MaxTokens)
412-
result, err = c.contextLimitRetryLoop(ctx, request, id, maxTokens, status)
411+
result, err = c.contextLimitRetryLoop(ctx, request, id, env, maxTokens, status)
413412
}
414413
if err != nil {
415414
return nil, err
@@ -443,7 +442,7 @@ func (c *Client) Call(ctx context.Context, messageRequest types.CompletionReques
443442
return &result, nil
444443
}
445444

446-
func (c *Client) contextLimitRetryLoop(ctx context.Context, request openai.ChatCompletionRequest, id string, maxTokens int, status chan<- types.CompletionStatus) (types.CompletionMessage, error) {
445+
func (c *Client) contextLimitRetryLoop(ctx context.Context, request openai.ChatCompletionRequest, id string, env []string, maxTokens int, status chan<- types.CompletionStatus) (types.CompletionMessage, error) {
447446
var (
448447
response types.CompletionMessage
449448
err error
@@ -452,7 +451,7 @@ func (c *Client) contextLimitRetryLoop(ctx context.Context, request openai.ChatC
452451
for range 10 { // maximum 10 tries
453452
// Try to drop older messages again, with a decreased max tokens.
454453
request.Messages = dropMessagesOverCount(maxTokens, request.Messages)
455-
response, err = c.call(ctx, request, id, status)
454+
response, err = c.call(ctx, request, id, env, status)
456455
if err == nil {
457456
return response, nil
458457
}
@@ -542,7 +541,7 @@ func override(left, right string) string {
542541
return left
543542
}
544543

545-
func (c *Client) call(ctx context.Context, request openai.ChatCompletionRequest, transactionID string, partial chan<- types.CompletionStatus) (types.CompletionMessage, error) {
544+
func (c *Client) call(ctx context.Context, request openai.ChatCompletionRequest, transactionID string, env []string, partial chan<- types.CompletionStatus) (types.CompletionMessage, error) {
546545
streamResponse := os.Getenv("GPTSCRIPT_INTERNAL_OPENAI_STREAMING") != "false"
547546

548547
partial <- types.CompletionStatus{
@@ -553,11 +552,27 @@ func (c *Client) call(ctx context.Context, request openai.ChatCompletionRequest,
553552
},
554553
}
555554

555+
var (
556+
headers map[string]string
557+
modelProviderEnv []string
558+
)
559+
for _, e := range env {
560+
if strings.HasPrefix(e, "GPTSCRIPT_MODEL_PROVIDER_") {
561+
modelProviderEnv = append(modelProviderEnv, e)
562+
}
563+
}
564+
565+
if len(modelProviderEnv) > 0 {
566+
headers = map[string]string{
567+
"X-GPTScript-Env": strings.Join(modelProviderEnv, ","),
568+
}
569+
}
570+
556571
slog.Debug("calling openai", "message", request.Messages)
557572

558573
if !streamResponse {
559574
request.StreamOptions = nil
560-
resp, err := c.c.CreateChatCompletion(ctx, request)
575+
resp, err := c.c.CreateChatCompletion(ctx, request, headers)
561576
if err != nil {
562577
return types.CompletionMessage{}, err
563578
}
@@ -582,7 +597,7 @@ func (c *Client) call(ctx context.Context, request openai.ChatCompletionRequest,
582597
}), nil
583598
}
584599

585-
stream, err := c.c.CreateChatCompletionStream(ctx, request)
600+
stream, err := c.c.CreateChatCompletionStream(ctx, request, headers)
586601
if err != nil {
587602
return types.CompletionMessage{}, err
588603
}
@@ -614,8 +629,8 @@ func (c *Client) call(ctx context.Context, request openai.ChatCompletionRequest,
614629
}
615630
}
616631

617-
func (c *Client) RetrieveAPIKey(ctx context.Context) error {
618-
k, err := prompt.GetModelProviderCredential(ctx, c.credStore, BuiltinCredName, "OPENAI_API_KEY", "Please provide your OpenAI API key:", gcontext.GetEnv(ctx))
632+
func (c *Client) RetrieveAPIKey(ctx context.Context, env []string) error {
633+
k, err := prompt.GetModelProviderCredential(ctx, c.credStore, BuiltinCredName, "OPENAI_API_KEY", "Please provide your OpenAI API key:", env)
619634
if err != nil {
620635
return err
621636
}

pkg/remote/remote.go

+9-10
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ import (
1010
"sync"
1111

1212
"github.com/gptscript-ai/gptscript/pkg/cache"
13-
gcontext "github.com/gptscript-ai/gptscript/pkg/context"
1413
"github.com/gptscript-ai/gptscript/pkg/credentials"
1514
"github.com/gptscript-ai/gptscript/pkg/engine"
1615
env2 "github.com/gptscript-ai/gptscript/pkg/env"
@@ -42,13 +41,13 @@ func New(r *runner.Runner, envs []string, cache *cache.Client, credStore credent
4241
}
4342
}
4443

45-
func (c *Client) Call(ctx context.Context, messageRequest types.CompletionRequest, status chan<- types.CompletionStatus) (*types.CompletionMessage, error) {
44+
func (c *Client) Call(ctx context.Context, messageRequest types.CompletionRequest, env []string, status chan<- types.CompletionStatus) (*types.CompletionMessage, error) {
4645
_, provider := c.parseModel(messageRequest.Model)
4746
if provider == "" {
4847
return nil, fmt.Errorf("failed to find remote model %s", messageRequest.Model)
4948
}
5049

51-
client, err := c.load(ctx, provider)
50+
client, err := c.load(ctx, provider, env...)
5251
if err != nil {
5352
return nil, err
5453
}
@@ -60,7 +59,7 @@ func (c *Client) Call(ctx context.Context, messageRequest types.CompletionReques
6059
modelName = toolName
6160
}
6261
messageRequest.Model = modelName
63-
return client.Call(ctx, messageRequest, status)
62+
return client.Call(ctx, messageRequest, env, status)
6463
}
6564

6665
func (c *Client) ListModels(ctx context.Context, providers ...string) (result []string, _ error) {
@@ -111,7 +110,7 @@ func isHTTPURL(toolName string) bool {
111110
strings.HasPrefix(toolName, "https://")
112111
}
113112

114-
func (c *Client) clientFromURL(ctx context.Context, apiURL string) (*openai.Client, error) {
113+
func (c *Client) clientFromURL(ctx context.Context, apiURL string, envs []string) (*openai.Client, error) {
115114
parsed, err := url.Parse(apiURL)
116115
if err != nil {
117116
return nil, err
@@ -121,7 +120,7 @@ func (c *Client) clientFromURL(ctx context.Context, apiURL string) (*openai.Clie
121120

122121
if key == "" && !isLocalhost(apiURL) {
123122
var err error
124-
key, err = c.retrieveAPIKey(ctx, env, apiURL)
123+
key, err = c.retrieveAPIKey(ctx, env, apiURL, envs)
125124
if err != nil {
126125
return nil, err
127126
}
@@ -134,7 +133,7 @@ func (c *Client) clientFromURL(ctx context.Context, apiURL string) (*openai.Clie
134133
})
135134
}
136135

137-
func (c *Client) load(ctx context.Context, toolName string) (*openai.Client, error) {
136+
func (c *Client) load(ctx context.Context, toolName string, env ...string) (*openai.Client, error) {
138137
c.clientsLock.Lock()
139138
defer c.clientsLock.Unlock()
140139

@@ -144,7 +143,7 @@ func (c *Client) load(ctx context.Context, toolName string) (*openai.Client, err
144143
}
145144

146145
if isHTTPURL(toolName) {
147-
remoteClient, err := c.clientFromURL(ctx, toolName)
146+
remoteClient, err := c.clientFromURL(ctx, toolName, env)
148147
if err != nil {
149148
return nil, err
150149
}
@@ -183,8 +182,8 @@ func (c *Client) load(ctx context.Context, toolName string) (*openai.Client, err
183182
return oClient, nil
184183
}
185184

186-
func (c *Client) retrieveAPIKey(ctx context.Context, env, url string) (string, error) {
187-
return prompt.GetModelProviderCredential(ctx, c.credStore, url, env, fmt.Sprintf("Please provide your API key for %s", url), append(gcontext.GetEnv(ctx), c.envs...))
185+
func (c *Client) retrieveAPIKey(ctx context.Context, env, url string, envs []string) (string, error) {
186+
return prompt.GetModelProviderCredential(ctx, c.credStore, url, env, fmt.Sprintf("Please provide your API key for %s", url), append(envs, c.envs...))
188187
}
189188

190189
func isLocalhost(url string) bool {

pkg/runner/output.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ func (r *Runner) handleOutput(callCtx engine.Context, monitor Monitor, env []str
8484
if err != nil {
8585
return nil, fmt.Errorf("marshaling input for output filter: %w", err)
8686
}
87-
res, err := r.subCall(callCtx.Ctx, callCtx, monitor, env, outputToolRef.ToolID, string(inputData), "", engine.OutputToolCategory)
87+
res, err := r.subCall(callCtx.Ctx, callCtx, monitor, env, outputToolRef.ToolID, inputData, "", engine.OutputToolCategory)
8888
if err != nil {
8989
return nil, err
9090
}

pkg/tests/judge/judge.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ func (j *Judge[T]) Equal(ctx context.Context, expected, actual T, criteria strin
112112
},
113113
},
114114
}
115-
response, err := j.client.CreateChatCompletion(ctx, request)
115+
response, err := j.client.CreateChatCompletion(ctx, request, nil)
116116
if err != nil {
117117
return false, "", fmt.Errorf("failed to create chat completion request: %w", err)
118118
}

pkg/tests/tester/runner.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ func (c *Client) ProxyInfo() (string, string, error) {
3535
return "test-auth", "test-url", nil
3636
}
3737

38-
func (c *Client) Call(_ context.Context, messageRequest types.CompletionRequest, _ chan<- types.CompletionStatus) (resp *types.CompletionMessage, respErr error) {
38+
func (c *Client) Call(_ context.Context, messageRequest types.CompletionRequest, _ []string, _ chan<- types.CompletionStatus) (resp *types.CompletionMessage, respErr error) {
3939
msgData, err := json.MarshalIndent(messageRequest, "", " ")
4040
require.NoError(c.t, err)
4141

0 commit comments

Comments
 (0)