Skip to content

Commit cb54f3b

Browse files
Merge pull request gptscript-ai#292 from ibuildthecloud/main
feat: add workspace functions
2 parents 0a7d53a + 7132a36 commit cb54f3b

File tree

7 files changed

+184
-48
lines changed

7 files changed

+184
-48
lines changed

pkg/builtin/builtin.go

+100-13
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,31 @@ import (
2626
)
2727

2828
var tools = map[string]types.Tool{
29+
"sys.workspace.ls": {
30+
Parameters: types.Parameters{
31+
Description: "Lists the contents of a directory relative to the current workspace",
32+
Arguments: types.ObjectSchema(
33+
"dir", "The directory to list"),
34+
},
35+
BuiltinFunc: SysWorkspaceLs,
36+
},
37+
"sys.workspace.write": {
38+
Parameters: types.Parameters{
39+
Description: "Write the contents to a file relative to the current workspace",
40+
Arguments: types.ObjectSchema(
41+
"filename", "The name of the file to write to",
42+
"content", "The content to write"),
43+
},
44+
BuiltinFunc: SysWorkspaceWrite,
45+
},
46+
"sys.workspace.read": {
47+
Parameters: types.Parameters{
48+
Description: "Reads the contents of a file relative to the current workspace",
49+
Arguments: types.ObjectSchema(
50+
"filename", "The name of the file to read"),
51+
},
52+
BuiltinFunc: SysWorkspaceRead,
53+
},
2954
"sys.ls": {
3055
Parameters: types.Parameters{
3156
Description: "Lists the contents of a directory",
@@ -297,19 +322,46 @@ func SysExec(ctx context.Context, env []string, input string) (string, error) {
297322
return string(out), err
298323
}
299324

325+
func getWorkspaceDir(envs []string) (string, error) {
326+
for _, env := range envs {
327+
dir, ok := strings.CutPrefix(env, "GPTSCRIPT_WORKSPACE_DIR=")
328+
if ok && dir != "" {
329+
return dir, nil
330+
}
331+
}
332+
return "", fmt.Errorf("no workspace directory found in env")
333+
}
334+
335+
func SysWorkspaceLs(_ context.Context, env []string, input string) (string, error) {
336+
dir, err := getWorkspaceDir(env)
337+
if err != nil {
338+
return "", err
339+
}
340+
return sysLs(dir, input)
341+
}
342+
300343
func SysLs(_ context.Context, _ []string, input string) (string, error) {
344+
return sysLs("", input)
345+
}
346+
347+
func sysLs(base, input string) (string, error) {
301348
var params struct {
302349
Dir string `json:"dir,omitempty"`
303350
}
304351
if err := json.Unmarshal([]byte(input), &params); err != nil {
305352
return "", err
306353
}
307354

308-
if params.Dir == "" {
309-
params.Dir = "."
355+
dir := params.Dir
356+
if dir == "" {
357+
dir = "."
358+
}
359+
360+
if base != "" {
361+
dir = filepath.Join(base, dir)
310362
}
311363

312-
entries, err := os.ReadDir(params.Dir)
364+
entries, err := os.ReadDir(dir)
313365
if errors.Is(err, fs.ErrNotExist) {
314366
return fmt.Sprintf("directory does not exist: %s", params.Dir), nil
315367
} else if err != nil {
@@ -328,20 +380,38 @@ func SysLs(_ context.Context, _ []string, input string) (string, error) {
328380
return strings.Join(result, "\n"), nil
329381
}
330382

383+
func SysWorkspaceRead(ctx context.Context, env []string, input string) (string, error) {
384+
dir, err := getWorkspaceDir(env)
385+
if err != nil {
386+
return "", err
387+
}
388+
389+
return sysRead(ctx, dir, env, input)
390+
}
391+
331392
func SysRead(ctx context.Context, env []string, input string) (string, error) {
393+
return sysRead(ctx, "", env, input)
394+
}
395+
396+
func sysRead(ctx context.Context, base string, env []string, input string) (string, error) {
332397
var params struct {
333398
Filename string `json:"filename,omitempty"`
334399
}
335400
if err := json.Unmarshal([]byte(input), &params); err != nil {
336401
return "", err
337402
}
338403

404+
file := params.Filename
405+
if base != "" {
406+
file = filepath.Join(base, file)
407+
}
408+
339409
// Lock the file to prevent concurrent writes from other tool calls.
340-
locker.RLock(params.Filename)
341-
defer locker.RUnlock(params.Filename)
410+
locker.RLock(file)
411+
defer locker.RUnlock(file)
342412

343-
log.Debugf("Reading file %s", params.Filename)
344-
data, err := os.ReadFile(params.Filename)
413+
log.Debugf("Reading file %s", file)
414+
data, err := os.ReadFile(file)
345415
if errors.Is(err, fs.ErrNotExist) {
346416
return fmt.Sprintf("The file %s does not exist", params.Filename), nil
347417
} else if err != nil {
@@ -354,7 +424,19 @@ func SysRead(ctx context.Context, env []string, input string) (string, error) {
354424
return string(data), nil
355425
}
356426

427+
func SysWorkspaceWrite(ctx context.Context, env []string, input string) (string, error) {
428+
dir, err := getWorkspaceDir(env)
429+
if err != nil {
430+
return "", err
431+
}
432+
return sysWrite(ctx, dir, env, input)
433+
}
434+
357435
func SysWrite(ctx context.Context, env []string, input string) (string, error) {
436+
return sysWrite(ctx, "", env, input)
437+
}
438+
439+
func sysWrite(ctx context.Context, base string, env []string, input string) (string, error) {
358440
var params struct {
359441
Filename string `json:"filename,omitempty"`
360442
Content string `json:"content,omitempty"`
@@ -363,28 +445,33 @@ func SysWrite(ctx context.Context, env []string, input string) (string, error) {
363445
return "", err
364446
}
365447

448+
file := params.Filename
449+
if base != "" {
450+
file = filepath.Join(base, file)
451+
}
452+
366453
// Lock the file to prevent concurrent writes from other tool calls.
367-
locker.Lock(params.Filename)
368-
defer locker.Unlock(params.Filename)
454+
locker.Lock(file)
455+
defer locker.Unlock(file)
369456

370-
dir := filepath.Dir(params.Filename)
457+
dir := filepath.Dir(file)
371458
if _, err := os.Stat(dir); errors.Is(err, fs.ErrNotExist) {
372459
log.Debugf("Creating dir %s", dir)
373460
if err := os.MkdirAll(dir, 0755); err != nil {
374461
return "", fmt.Errorf("creating dir %s: %w", dir, err)
375462
}
376463
}
377464

378-
if _, err := os.Stat(params.Filename); err == nil {
465+
if _, err := os.Stat(file); err == nil {
379466
if err := confirm.Promptf(ctx, "Overwrite: %s", params.Filename); err != nil {
380467
return "", err
381468
}
382469
}
383470

384471
data := []byte(params.Content)
385-
log.Debugf("Wrote %d bytes to file %s", len(data), params.Filename)
472+
log.Debugf("Wrote %d bytes to file %s", len(data), file)
386473

387-
return "", os.WriteFile(params.Filename, data, 0644)
474+
return "", os.WriteFile(file, data, 0644)
388475
}
389476

390477
func SysAppend(ctx context.Context, env []string, input string) (string, error) {

pkg/chat/chat.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ func Start(ctx context.Context, prevState runner.ChatState, chatter Chatter, prg
3535
prompter Prompter
3636
)
3737

38-
prompter, err := newReadlinePrompter()
38+
prompter, err := newReadlinePrompter(prg)
3939
if err != nil {
4040
return err
4141
}

pkg/chat/readline.go

+8-2
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"github.com/adrg/xdg"
1010
"github.com/chzyer/readline"
1111
"github.com/fatih/color"
12+
"github.com/gptscript-ai/gptscript/pkg/hash"
1213
"github.com/gptscript-ai/gptscript/pkg/mvl"
1314
)
1415

@@ -18,8 +19,13 @@ type readlinePrompter struct {
1819
readliner *readline.Instance
1920
}
2021

21-
func newReadlinePrompter() (*readlinePrompter, error) {
22-
historyFile, err := xdg.CacheFile("gptscript/chat.history")
22+
func newReadlinePrompter(prg GetProgram) (*readlinePrompter, error) {
23+
targetProgram, err := prg()
24+
if err != nil {
25+
return nil, err
26+
}
27+
28+
historyFile, err := xdg.CacheFile(fmt.Sprintf("gptscript/chat-%s.history", hash.ID(targetProgram.EntryToolID)))
2329
if err != nil {
2430
historyFile = ""
2531
}

pkg/cli/gptscript.go

+2
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ type GPTScript struct {
6161
CredentialOverride string `usage:"Credentials to override (ex: --credential-override github.com/example/cred-tool:API_TOKEN=1234)"`
6262
ChatState string `usage:"The chat state to continue, or null to start a new chat and return the state"`
6363
ForceChat bool `usage:"Force an interactive chat session if even the top level tool is not a chat tool"`
64+
Workspace string `usage:"Directory to use for the workspace, if specified it will not be deleted on exit"`
6465

6566
readData []byte
6667
}
@@ -123,6 +124,7 @@ func (r *GPTScript) NewGPTScriptOpts() (gptscript.Options, error) {
123124
Quiet: r.Quiet,
124125
Env: os.Environ(),
125126
CredentialContext: r.CredentialContext,
127+
Workspace: r.Workspace,
126128
}
127129

128130
if r.Ports != "" {

pkg/gptscript/gptscript.go

+51-9
Original file line numberDiff line numberDiff line change
@@ -2,33 +2,41 @@ package gptscript
22

33
import (
44
"context"
5+
"fmt"
56
"os"
67

78
"github.com/gptscript-ai/gptscript/pkg/builtin"
89
"github.com/gptscript-ai/gptscript/pkg/cache"
910
"github.com/gptscript-ai/gptscript/pkg/engine"
11+
"github.com/gptscript-ai/gptscript/pkg/hash"
1012
"github.com/gptscript-ai/gptscript/pkg/llm"
1113
"github.com/gptscript-ai/gptscript/pkg/monitor"
14+
"github.com/gptscript-ai/gptscript/pkg/mvl"
1215
"github.com/gptscript-ai/gptscript/pkg/openai"
1316
"github.com/gptscript-ai/gptscript/pkg/remote"
1417
"github.com/gptscript-ai/gptscript/pkg/repos/runtimes"
1518
"github.com/gptscript-ai/gptscript/pkg/runner"
1619
"github.com/gptscript-ai/gptscript/pkg/types"
1720
)
1821

22+
var log = mvl.Package()
23+
1924
type GPTScript struct {
20-
Registry *llm.Registry
21-
Runner *runner.Runner
25+
Registry *llm.Registry
26+
Runner *runner.Runner
27+
WorkspacePath string
28+
DeleteWorkspaceOnClose bool
2229
}
2330

2431
type Options struct {
2532
Cache cache.Options
2633
OpenAI openai.Options
2734
Monitor monitor.Options
2835
Runner runner.Options
29-
CredentialContext string `usage:"Context name in which to store credentials" default:"default"`
30-
Quiet *bool `usage:"No output logging (set --quiet=false to force on even when there is no TTY)" short:"q"`
31-
Env []string `usage:"-"`
36+
CredentialContext string
37+
Quiet *bool
38+
Workspace string
39+
Env []string
3240
}
3341

3442
func complete(opts *Options) (result *Options) {
@@ -89,21 +97,55 @@ func New(opts *Options) (*GPTScript, error) {
8997
}
9098

9199
return &GPTScript{
92-
Registry: registry,
93-
Runner: runner,
100+
Registry: registry,
101+
Runner: runner,
102+
WorkspacePath: opts.Workspace,
103+
DeleteWorkspaceOnClose: opts.Workspace == "",
94104
}, nil
95105
}
96106

97-
func (g *GPTScript) Chat(ctx context.Context, prevState runner.ChatState, prg types.Program, env []string, input string) (runner.ChatResponse, error) {
98-
return g.Runner.Chat(ctx, prevState, prg, env, input)
107+
func (g *GPTScript) getEnv(env []string) ([]string, error) {
108+
if g.WorkspacePath == "" {
109+
var err error
110+
g.WorkspacePath, err = os.MkdirTemp("", "gptscript-workspace-*")
111+
if err != nil {
112+
return nil, err
113+
}
114+
}
115+
if err := os.MkdirAll(g.WorkspacePath, 0700); err != nil {
116+
return nil, err
117+
}
118+
return append([]string{
119+
fmt.Sprintf("GPTSCRIPT_WORKSPACE_DIR=%s", g.WorkspacePath),
120+
fmt.Sprintf("GPTSCRIPT_WORKSPACE_ID=%s", hash.ID(g.WorkspacePath)),
121+
}, env...), nil
122+
}
123+
124+
func (g *GPTScript) Chat(ctx context.Context, prevState runner.ChatState, prg types.Program, envs []string, input string) (runner.ChatResponse, error) {
125+
envs, err := g.getEnv(envs)
126+
if err != nil {
127+
return runner.ChatResponse{}, err
128+
}
129+
130+
return g.Runner.Chat(ctx, prevState, prg, envs, input)
99131
}
100132

101133
func (g *GPTScript) Run(ctx context.Context, prg types.Program, envs []string, input string) (string, error) {
134+
envs, err := g.getEnv(envs)
135+
if err != nil {
136+
return "", err
137+
}
138+
102139
return g.Runner.Run(ctx, prg, envs, input)
103140
}
104141

105142
func (g *GPTScript) Close() {
106143
g.Runner.Close()
144+
if g.DeleteWorkspaceOnClose && g.WorkspacePath != "" {
145+
if err := os.RemoveAll(g.WorkspacePath); err != nil {
146+
log.Errorf("failed to delete workspace %s: %s", g.WorkspacePath, err)
147+
}
148+
}
107149
}
108150

109151
func (g *GPTScript) GetModel() engine.Model {

pkg/parser/parser.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ func isParam(line string, tool *types.Tool) (_ bool, err error) {
7979
value = strings.TrimSpace(value)
8080
switch normalize(key) {
8181
case "name":
82-
tool.Parameters.Name = strings.ToLower(value)
82+
tool.Parameters.Name = value
8383
case "modelprovider":
8484
tool.Parameters.ModelProvider = true
8585
case "model", "modelname":

0 commit comments

Comments
 (0)