From d1d617949481b466d89e543ddb955663996b1cf6 Mon Sep 17 00:00:00 2001 From: Grant Linville <grant@acorn.io> Date: Mon, 4 Nov 2024 15:05:20 -0500 Subject: [PATCH] chore: sdkserver: update dataset methods for the rewrite Signed-off-by: Grant Linville <grant@acorn.io> --- pkg/sdkserver/datasets.go | 186 +++++--------------------------------- pkg/sdkserver/routes.go | 2 - 2 files changed, 23 insertions(+), 165 deletions(-) diff --git a/pkg/sdkserver/datasets.go b/pkg/sdkserver/datasets.go index 5db90bf7..e922cd97 100644 --- a/pkg/sdkserver/datasets.go +++ b/pkg/sdkserver/datasets.go @@ -11,24 +11,21 @@ import ( ) func (s *server) getDatasetTool(req datasetRequest) string { - if req.DatasetToolRepo != "" { - return req.DatasetToolRepo + if req.DatasetTool != "" { + return req.DatasetTool } return s.datasetTool } type datasetRequest struct { - Input string `json:"input"` - WorkspaceID string `json:"workspaceID"` - DatasetToolRepo string `json:"datasetToolRepo"` - Env []string `json:"env"` + Input string `json:"input"` + DatasetTool string `json:"datasetTool"` + Env []string `json:"env"` } func (r datasetRequest) validate(requireInput bool) error { - if r.WorkspaceID == "" { - return fmt.Errorf("workspaceID is required") - } else if requireInput && r.Input == "" { + if requireInput && r.Input == "" { return fmt.Errorf("input is required") } else if len(r.Env) == 0 { return fmt.Errorf("env is required") @@ -38,10 +35,9 @@ func (r datasetRequest) validate(requireInput bool) error { func (r datasetRequest) opts(o gptscript.Options) gptscript.Options { opts := gptscript.Options{ - Cache: o.Cache, - Monitor: o.Monitor, - Runner: o.Runner, - Workspace: r.WorkspaceID, + Cache: o.Cache, + Monitor: o.Monitor, + Runner: o.Runner, } return opts } @@ -84,148 +80,19 @@ func (s *server) listDatasets(w http.ResponseWriter, r *http.Request) { writeResponse(logger, w, map[string]any{"stdout": result}) } -type createDatasetArgs struct { - Name string `json:"datasetName"` - Description string `json:"datasetDescription"` -} - -func (a createDatasetArgs) validate() error { - if a.Name == "" { - return fmt.Errorf("datasetName is required") - } - return nil -} - -func (s *server) createDataset(w http.ResponseWriter, r *http.Request) { - logger := gcontext.GetLogger(r.Context()) - - var req datasetRequest - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - writeError(logger, w, http.StatusBadRequest, fmt.Errorf("failed to decode request body: %w", err)) - return - } - - if err := req.validate(true); err != nil { - writeError(logger, w, http.StatusBadRequest, err) - return - } - - g, err := gptscript.New(r.Context(), req.opts(s.gptscriptOpts)) - if err != nil { - writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to initialize gptscript: %w", err)) - return - } - - var args createDatasetArgs - if err := json.Unmarshal([]byte(req.Input), &args); err != nil { - writeError(logger, w, http.StatusBadRequest, fmt.Errorf("failed to unmarshal input: %w", err)) - return - } - - if err := args.validate(); err != nil { - writeError(logger, w, http.StatusBadRequest, err) - return - } - - prg, err := loader.Program(r.Context(), s.getDatasetTool(req), "Create Dataset", loader.Options{ - Cache: g.Cache, - }) - - if err != nil { - writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to load program: %w", err)) - return - } - - result, err := g.Run(r.Context(), prg, req.Env, req.Input) - if err != nil { - writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to run program: %w", err)) - return - } - - writeResponse(logger, w, map[string]any{"stdout": result}) -} - -type addDatasetElementArgs struct { - DatasetID string `json:"datasetID"` - ElementName string `json:"elementName"` - ElementDescription string `json:"elementDescription"` - ElementContent string `json:"elementContent"` -} - -func (a addDatasetElementArgs) validate() error { - if a.DatasetID == "" { - return fmt.Errorf("datasetID is required") - } - if a.ElementName == "" { - return fmt.Errorf("elementName is required") - } - if a.ElementContent == "" { - return fmt.Errorf("elementContent is required") - } - return nil -} - -func (s *server) addDatasetElement(w http.ResponseWriter, r *http.Request) { - logger := gcontext.GetLogger(r.Context()) - - var req datasetRequest - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - writeError(logger, w, http.StatusBadRequest, fmt.Errorf("failed to decode request body: %w", err)) - return - } - - if err := req.validate(true); err != nil { - writeError(logger, w, http.StatusBadRequest, err) - return - } - - g, err := gptscript.New(r.Context(), req.opts(s.gptscriptOpts)) - if err != nil { - writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to initialize gptscript: %w", err)) - return - } - - var args addDatasetElementArgs - if err := json.Unmarshal([]byte(req.Input), &args); err != nil { - writeError(logger, w, http.StatusBadRequest, fmt.Errorf("failed to unmarshal input: %w", err)) - return - } - - if err := args.validate(); err != nil { - writeError(logger, w, http.StatusBadRequest, err) - return - } - - prg, err := loader.Program(r.Context(), s.getDatasetTool(req), "Add Element", loader.Options{ - Cache: g.Cache, - }) - if err != nil { - writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to load program: %w", err)) - return - } - - result, err := g.Run(r.Context(), prg, req.Env, req.Input) - if err != nil { - writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to run program: %w", err)) - return - } - - writeResponse(logger, w, map[string]any{"stdout": result}) -} - type addDatasetElementsArgs struct { - DatasetID string `json:"datasetID"` - Elements []struct { - Name string `json:"name"` - Description string `json:"description"` - Contents string `json:"contents"` - } + DatasetID string `json:"datasetID"` + Name string `json:"name"` + Description string `json:"description"` + Elements []struct { + Name string `json:"name"` + Description string `json:"description"` + Contents string `json:"contents"` + BinaryContents []byte `json:"binaryContents"` + } `json:"elements"` } func (a addDatasetElementsArgs) validate() error { - if a.DatasetID == "" { - return fmt.Errorf("datasetID is required") - } if len(a.Elements) == 0 { return fmt.Errorf("elements is required") } @@ -271,13 +138,7 @@ func (s *server) addDatasetElements(w http.ResponseWriter, r *http.Request) { return } - elementsJSON, err := json.Marshal(args.Elements) - if err != nil { - writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to marshal elements: %w", err)) - return - } - - result, err := g.Run(r.Context(), prg, req.Env, fmt.Sprintf(`{"datasetID":%q, "elements":%q}`, args.DatasetID, string(elementsJSON))) + result, err := g.Run(r.Context(), prg, req.Env, req.Input) if err != nil { writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to run program: %w", err)) return @@ -347,15 +208,14 @@ func (s *server) listDatasetElements(w http.ResponseWriter, r *http.Request) { type getDatasetElementArgs struct { DatasetID string `json:"datasetID"` - Element string `json:"element"` + Name string `json:"name"` } func (a getDatasetElementArgs) validate() error { if a.DatasetID == "" { return fmt.Errorf("datasetID is required") - } - if a.Element == "" { - return fmt.Errorf("element is required") + } else if a.Name == "" { + return fmt.Errorf("name is required") } return nil } @@ -391,7 +251,7 @@ func (s *server) getDatasetElement(w http.ResponseWriter, r *http.Request) { return } - prg, err := loader.Program(r.Context(), s.getDatasetTool(req), "Get Element SDK", loader.Options{ + prg, err := loader.Program(r.Context(), s.getDatasetTool(req), "Get Element", loader.Options{ Cache: g.Cache, }) if err != nil { diff --git a/pkg/sdkserver/routes.go b/pkg/sdkserver/routes.go index ea7fdb09..73bf5d58 100644 --- a/pkg/sdkserver/routes.go +++ b/pkg/sdkserver/routes.go @@ -69,10 +69,8 @@ func (s *server) addRoutes(mux *http.ServeMux) { mux.HandleFunc("POST /credentials/delete", s.deleteCredential) mux.HandleFunc("POST /datasets", s.listDatasets) - mux.HandleFunc("POST /datasets/create", s.createDataset) mux.HandleFunc("POST /datasets/list-elements", s.listDatasetElements) mux.HandleFunc("POST /datasets/get-element", s.getDatasetElement) - mux.HandleFunc("POST /datasets/add-element", s.addDatasetElement) mux.HandleFunc("POST /datasets/add-elements", s.addDatasetElements) mux.HandleFunc("POST /workspaces/create", s.createWorkspace)