Skip to content

Commit

Permalink
improvements
Browse files Browse the repository at this point in the history
Signed-off-by: Grant Linville <[email protected]>
  • Loading branch information
g-linville committed Dec 16, 2024
1 parent 94ffd42 commit 2b7cb50
Show file tree
Hide file tree
Showing 6 changed files with 27 additions and 34 deletions.
23 changes: 16 additions & 7 deletions pkg/engine/daemon.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ type Ports struct {

type Certs struct {
daemonCerts map[string]certs.CertAndKey
daemonLock sync.Mutex
clientCert certs.CertAndKey
lock sync.Mutex
}

func IsDaemonRunning(url string) bool {
Expand Down Expand Up @@ -157,8 +158,8 @@ func (e *Engine) startDaemon(tool types.Tool) (string, error) {
url = fmt.Sprintf("https://127.0.0.1:%d%s", port, path)

// Generate a certificate for the daemon, unless one already exists.
certificates.daemonLock.Lock()
defer certificates.daemonLock.Unlock()
certificates.lock.Lock()
defer certificates.lock.Unlock()
cert, exists := certificates.daemonCerts[tool.ID]
if !exists {
var err error
Expand All @@ -173,12 +174,21 @@ func (e *Engine) startDaemon(tool types.Tool) (string, error) {
certificates.daemonCerts[tool.ID] = cert
}

// Set the client certificate if there isn't one already.
if len(certificates.clientCert.Cert) == 0 {
gptscriptCert, err := certs.GenerateGPTScriptCert()
if err != nil {
return "", fmt.Errorf("failed to generate GPTScript certificate: %v", err)
}
certificates.clientCert = gptscriptCert
}

cmd, stop, err := e.newCommand(ctx, []string{
fmt.Sprintf("PORT=%d", port),
fmt.Sprintf("CERT=%s", base64.StdEncoding.EncodeToString(cert.Cert)),
fmt.Sprintf("PRIVATE_KEY=%s", base64.StdEncoding.EncodeToString(cert.Key)),
fmt.Sprintf("GPTSCRIPT_PORT=%d", port),
fmt.Sprintf("GPTSCRIPT_CERT=%s", base64.StdEncoding.EncodeToString(e.GPTScriptCert.Cert)),
fmt.Sprintf("GPTSCRIPT_CERT=%s", base64.StdEncoding.EncodeToString(certificates.clientCert.Cert)),
},
tool,
"{}",
Expand Down Expand Up @@ -241,7 +251,7 @@ func (e *Engine) startDaemon(tool types.Tool) (string, error) {
}()

// Build HTTP client for checking the health of the daemon
clientCert, err := tls.X509KeyPair(e.GPTScriptCert.Cert, e.GPTScriptCert.Key)
tlsClientCert, err := tls.X509KeyPair(certificates.clientCert.Cert, certificates.clientCert.Key)
if err != nil {
return "", fmt.Errorf("failed to create client certificate: %v", err)
}
Expand All @@ -254,7 +264,7 @@ func (e *Engine) startDaemon(tool types.Tool) (string, error) {
httpClient := &http.Client{
Transport: &http.Transport{
TLSClientConfig: &tls.Config{
Certificates: []tls.Certificate{clientCert},
Certificates: []tls.Certificate{tlsClientCert},
RootCAs: pool,
InsecureSkipVerify: false,
},
Expand All @@ -271,7 +281,6 @@ func (e *Engine) startDaemon(tool types.Tool) (string, error) {
}()
return url, nil
}
_ = resp.Body.Close()
select {
case <-killedCtx.Done():
return url, fmt.Errorf("daemon failed to start: %w", context.Cause(killedCtx))
Expand Down
2 changes: 0 additions & 2 deletions pkg/engine/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"strings"
"sync"

"github.com/gptscript-ai/gptscript/pkg/certs"
"github.com/gptscript-ai/gptscript/pkg/counter"
"github.com/gptscript-ai/gptscript/pkg/types"
"github.com/gptscript-ai/gptscript/pkg/version"
Expand All @@ -23,7 +22,6 @@ type RuntimeManager interface {
}

type Engine struct {
GPTScriptCert certs.CertAndKey
Model Model
RuntimeManager RuntimeManager
Env []string
Expand Down
9 changes: 5 additions & 4 deletions pkg/engine/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,10 @@ func (e *Engine) runHTTP(ctx context.Context, prg *types.Program, tool types.Too
toolURL = parsed.String()

// Find the certificate corresponding to this daemon tool
certificates.daemonLock.Lock()
certificates.lock.Lock()
daemonCert, exists := certificates.daemonCerts[referencedTool.ID]
certificates.daemonLock.Unlock()
clientCert := certificates.clientCert
certificates.lock.Unlock()

if !exists {
return nil, fmt.Errorf("missing daemon certificate for [%s]", referencedTool.ID)
Expand All @@ -79,14 +80,14 @@ func (e *Engine) runHTTP(ctx context.Context, prg *types.Program, tool types.Too
return nil, fmt.Errorf("failed to append daemon certificate for [%s]", referencedTool.ID)
}

clientCert, err := tls.X509KeyPair(e.GPTScriptCert.Cert, e.GPTScriptCert.Key)
tlsClientCert, err := tls.X509KeyPair(clientCert.Cert, clientCert.Key)
if err != nil {
return nil, fmt.Errorf("failed to create client certificate: %v", err)
}

// Create TLS config for use in the HTTP client later
tlsConfigForDaemonRequest = &tls.Config{
Certificates: []tls.Certificate{clientCert},
Certificates: []tls.Certificate{tlsClientCert},
RootCAs: pool,
InsecureSkipVerify: false,
}
Expand Down
14 changes: 4 additions & 10 deletions pkg/gptscript/gptscript.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import (

"github.com/gptscript-ai/gptscript/pkg/builtin"
"github.com/gptscript-ai/gptscript/pkg/cache"
"github.com/gptscript-ai/gptscript/pkg/certs"
"github.com/gptscript-ai/gptscript/pkg/config"
context2 "github.com/gptscript-ai/gptscript/pkg/context"
"github.com/gptscript-ai/gptscript/pkg/credentials"
Expand Down Expand Up @@ -108,12 +107,7 @@ func New(ctx context.Context, o ...Options) (*GPTScript, error) {
opts.Runner.RuntimeManager = runtimes.Default(cacheClient.CacheDir(), opts.SystemToolsDir)
}

gptscriptCert, err := certs.GenerateGPTScriptCert()
if err != nil {
return nil, err
}

simplerRunner, err := newSimpleRunner(cacheClient, opts.Runner.RuntimeManager, opts.Env, gptscriptCert)
simplerRunner, err := newSimpleRunner(cacheClient, opts.Runner.RuntimeManager, opts.Env)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -146,7 +140,7 @@ func New(ctx context.Context, o ...Options) (*GPTScript, error) {
opts.Runner.MonitorFactory = monitor.NewConsole(opts.Monitor, monitor.Options{DebugMessages: *opts.Quiet})
}

runner, err := runner.New(registry, credStore, gptscriptCert, opts.Runner)
runner, err := runner.New(registry, credStore, opts.Runner)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -291,8 +285,8 @@ type simpleRunner struct {
env []string
}

func newSimpleRunner(cache *cache.Client, rm engine.RuntimeManager, env []string, gptscriptCert certs.CertAndKey) (*simpleRunner, error) {
runner, err := runner.New(noopModel{}, credentials.NoopStore{}, gptscriptCert, runner.Options{
func newSimpleRunner(cache *cache.Client, rm engine.RuntimeManager, env []string) (*simpleRunner, error) {
runner, err := runner.New(noopModel{}, credentials.NoopStore{}, runner.Options{
RuntimeManager: rm,
MonitorFactory: simpleMonitorFactory{},
})
Expand Down
7 changes: 1 addition & 6 deletions pkg/runner/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import (
"time"

"github.com/gptscript-ai/gptscript/pkg/builtin"
"github.com/gptscript-ai/gptscript/pkg/certs"
context2 "github.com/gptscript-ai/gptscript/pkg/context"
"github.com/gptscript-ai/gptscript/pkg/credentials"
"github.com/gptscript-ai/gptscript/pkg/engine"
Expand Down Expand Up @@ -96,10 +95,9 @@ type Runner struct {
credOverrides []string
credStore credentials.CredentialStore
sequential bool
gptscriptCert certs.CertAndKey
}

func New(client engine.Model, credStore credentials.CredentialStore, gptscriptCert certs.CertAndKey, opts ...Options) (*Runner, error) {
func New(client engine.Model, credStore credentials.CredentialStore, opts ...Options) (*Runner, error) {
opt := complete(opts...)

runner := &Runner{
Expand All @@ -111,7 +109,6 @@ func New(client engine.Model, credStore credentials.CredentialStore, gptscriptCe
credStore: credStore,
sequential: opt.Sequential,
auth: opt.Authorizer,
gptscriptCert: gptscriptCert,
}

if opt.StartPort != 0 {
Expand Down Expand Up @@ -414,7 +411,6 @@ func (r *Runner) start(callCtx engine.Context, state *State, monitor Monitor, en
RuntimeManager: runtimeWithLogger(callCtx, monitor, r.runtimeManager),
Progress: progress,
Env: env,
GPTScriptCert: r.gptscriptCert,
}

callCtx.Ctx = context2.AddPauseFuncToCtx(callCtx.Ctx, monitor.Pause)
Expand Down Expand Up @@ -597,7 +593,6 @@ func (r *Runner) resume(callCtx engine.Context, monitor Monitor, env []string, s
RuntimeManager: runtimeWithLogger(callCtx, monitor, r.runtimeManager),
Progress: progress,
Env: env,
GPTScriptCert: r.gptscriptCert,
}

var contentInput string
Expand Down
6 changes: 1 addition & 5 deletions pkg/tests/tester/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (
"testing"

"github.com/adrg/xdg"
"github.com/gptscript-ai/gptscript/pkg/certs"
"github.com/gptscript-ai/gptscript/pkg/credentials"
"github.com/gptscript-ai/gptscript/pkg/loader"
"github.com/gptscript-ai/gptscript/pkg/repos/runtimes"
Expand Down Expand Up @@ -199,10 +198,7 @@ func NewRunner(t *testing.T) *Runner {

rm := runtimes.Default(cacheDir, "")

gptscriptCert, err := certs.GenerateGPTScriptCert()
require.NoError(t, err)

run, err := runner.New(c, credentials.NoopStore{}, gptscriptCert, runner.Options{
run, err := runner.New(c, credentials.NoopStore{}, runner.Options{
Sequential: true,
RuntimeManager: rm,
})
Expand Down

0 comments on commit 2b7cb50

Please sign in to comment.