@@ -13,7 +13,6 @@ import (
 	openai "github.com/gptscript-ai/chat-completion-client"
 	"github.com/gptscript-ai/gptscript/pkg/cache"
-	gcontext "github.com/gptscript-ai/gptscript/pkg/context"
 	"github.com/gptscript-ai/gptscript/pkg/counter"
 	"github.com/gptscript-ai/gptscript/pkg/credentials"
 	"github.com/gptscript-ai/gptscript/pkg/hash"
@@ -303,9 +302,9 @@ func toMessages(request types.CompletionRequest, compat bool) (result []openai.C
 	return
 }
 
-func (c *Client) Call(ctx context.Context, messageRequest types.CompletionRequest, status chan<- types.CompletionStatus) (*types.CompletionMessage, error) {
+func (c *Client) Call(ctx context.Context, messageRequest types.CompletionRequest, env []string, status chan<- types.CompletionStatus) (*types.CompletionMessage, error) {
 	if err := c.ValidAuth(); err != nil {
-		if err := c.RetrieveAPIKey(ctx); err != nil {
+		if err := c.RetrieveAPIKey(ctx, env); err != nil {
 			return nil, err
 		}
 	}
@@ -401,15 +400,15 @@ func (c *Client) Call(ctx context.Context, messageRequest types.CompletionReques
 	if err != nil {
 		return nil, err
 	} else if !ok {
-		result, err = c.call(ctx, request, id, status)
+		result, err = c.call(ctx, request, id, env, status)
 
 		// If we got back a context length exceeded error, keep retrying and shrinking the message history until we pass.
 		var apiError *openai.APIError
 		if errors.As(err, &apiError) && apiError.Code == "context_length_exceeded" && messageRequest.Chat {
 			// Decrease maxTokens by 10% to make garbage collection more aggressive.
 			// The retry loop will further decrease maxTokens if needed.
 			maxTokens := decreaseTenPercent(messageRequest.MaxTokens)
-			result, err = c.contextLimitRetryLoop(ctx, request, id, maxTokens, status)
+			result, err = c.contextLimitRetryLoop(ctx, request, id, env, maxTokens, status)
 		}
 		if err != nil {
 			return nil, err
@@ -443,7 +442,7 @@ func (c *Client) Call(ctx context.Context, messageRequest types.CompletionReques
 	return &result, nil
 }
 
-func (c *Client) contextLimitRetryLoop(ctx context.Context, request openai.ChatCompletionRequest, id string, maxTokens int, status chan<- types.CompletionStatus) (types.CompletionMessage, error) {
+func (c *Client) contextLimitRetryLoop(ctx context.Context, request openai.ChatCompletionRequest, id string, env []string, maxTokens int, status chan<- types.CompletionStatus) (types.CompletionMessage, error) {
 	var (
 		response types.CompletionMessage
 		err      error
@@ -452,7 +451,7 @@ func (c *Client) contextLimitRetryLoop(ctx context.Context, request openai.ChatC
 	for range 10 { // maximum 10 tries
 		// Try to drop older messages again, with a decreased max tokens.
 		request.Messages = dropMessagesOverCount(maxTokens, request.Messages)
-		response, err = c.call(ctx, request, id, status)
+		response, err = c.call(ctx, request, id, env, status)
 		if err == nil {
 			return response, nil
 		}
@@ -542,7 +541,7 @@ func override(left, right string) string {
 	return left
 }
 
-func (c *Client) call(ctx context.Context, request openai.ChatCompletionRequest, transactionID string, partial chan<- types.CompletionStatus) (types.CompletionMessage, error) {
+func (c *Client) call(ctx context.Context, request openai.ChatCompletionRequest, transactionID string, env []string, partial chan<- types.CompletionStatus) (types.CompletionMessage, error) {
 	streamResponse := os.Getenv("GPTSCRIPT_INTERNAL_OPENAI_STREAMING") != "false"
 
 	partial <- types.CompletionStatus{
@@ -553,11 +552,27 @@ func (c *Client) call(ctx context.Context, request openai.ChatCompletionRequest,
 		},
 	}
 
+	var (
+		headers          map[string]string
+		modelProviderEnv []string
+	)
+	for _, e := range env {
+		if strings.HasPrefix(e, "GPTSCRIPT_MODEL_PROVIDER_") {
+			modelProviderEnv = append(modelProviderEnv, e)
+		}
+	}
+
+	if len(modelProviderEnv) > 0 {
+		headers = map[string]string{
+			"X-GPTScript-Env": strings.Join(modelProviderEnv, ","),
+		}
+	}
+
 	slog.Debug("calling openai", "message", request.Messages)
 
 	if !streamResponse {
 		request.StreamOptions = nil
-		resp, err := c.c.CreateChatCompletion(ctx, request)
+		resp, err := c.c.CreateChatCompletion(ctx, request, headers)
 		if err != nil {
 			return types.CompletionMessage{}, err
 		}
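The env filtering added in this hunk is self-contained enough to exercise in isolation. Below is a minimal standalone sketch of the same logic; the helper name filterModelProviderEnv and the main function are illustrative, not part of this change:

package main

import (
	"fmt"
	"strings"
)

// filterModelProviderEnv mirrors the loop added to call above: it keeps only
// the KEY=value pairs prefixed with GPTSCRIPT_MODEL_PROVIDER_ and, when any
// match, joins them into a single X-GPTScript-Env header value.
func filterModelProviderEnv(env []string) map[string]string {
	var modelProviderEnv []string
	for _, e := range env {
		if strings.HasPrefix(e, "GPTSCRIPT_MODEL_PROVIDER_") {
			modelProviderEnv = append(modelProviderEnv, e)
		}
	}
	if len(modelProviderEnv) == 0 {
		// headers stays nil, so no extra header is sent to the provider.
		return nil
	}
	return map[string]string{
		"X-GPTScript-Env": strings.Join(modelProviderEnv, ","),
	}
}

func main() {
	env := []string{"PATH=/usr/bin", "GPTSCRIPT_MODEL_PROVIDER_TOKEN=abc"}
	fmt.Println(filterModelProviderEnv(env))
	// Prints: map[X-GPTScript-Env:GPTSCRIPT_MODEL_PROVIDER_TOKEN=abc]
}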
@@ -582,7 +597,7 @@ func (c *Client) call(ctx context.Context, request openai.ChatCompletionRequest,
 		}), nil
 	}
 
-	stream, err := c.c.CreateChatCompletionStream(ctx, request)
+	stream, err := c.c.CreateChatCompletionStream(ctx, request, headers)
 	if err != nil {
 		return types.CompletionMessage{}, err
 	}
@@ -614,8 +629,8 @@ func (c *Client) call(ctx context.Context, request openai.ChatCompletionRequest,
 	}
 }
 
-func (c *Client) RetrieveAPIKey(ctx context.Context) error {
-	k, err := prompt.GetModelProviderCredential(ctx, c.credStore, BuiltinCredName, "OPENAI_API_KEY", "Please provide your OpenAI API key:", gcontext.GetEnv(ctx))
+func (c *Client) RetrieveAPIKey(ctx context.Context, env []string) error {
+	k, err := prompt.GetModelProviderCredential(ctx, c.credStore, BuiltinCredName, "OPENAI_API_KEY", "Please provide your OpenAI API key:", env)
 	if err != nil {
 		return err
 	}
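Taken together, these signature changes thread the environment explicitly from Call down through contextLimitRetryLoop, call, and RetrieveAPIKey, rather than recovering it from the context via the removed gcontext.GetEnv. A hedged caller-side sketch, assuming a constructed *Client and this package's types; os.Environ() is one plausible source for env, not something this diff prescribes:

	status := make(chan types.CompletionStatus)
	go func() {
		// A real caller consumes progress updates; here we just drain them.
		for range status {
		}
	}()
	// Only GPTSCRIPT_MODEL_PROVIDER_* entries from env are forwarded to the
	// model provider, via the X-GPTScript-Env header built in call.
	result, err := client.Call(ctx, messageRequest, os.Environ(), status)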