From d1d617949481b466d89e543ddb955663996b1cf6 Mon Sep 17 00:00:00 2001
From: Grant Linville <grant@acorn.io>
Date: Mon, 4 Nov 2024 15:05:20 -0500
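Subject: [PATCH] chore: sdkserver: update dataset methods for the rewrite

Changes to the SDK server's dataset API:

- Rename the request field datasetToolRepo to datasetTool.
- Drop workspaceID from dataset requests; it is no longer required and
  is no longer passed to gptscript as the Workspace option.
- Remove the POST /datasets/create and POST /datasets/add-element
  routes along with their createDataset and addDatasetElement handlers.
- addDatasetElements: stop requiring datasetID, add name, description
  and per-element binaryContents fields to the parsed arguments, and
  forward the request input to the tool unchanged instead of
  re-encoding it.
- getDatasetElement: rename the input field element to name and load
  the "Get Element" tool instead of "Get Element SDK".
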
Signed-off-by: Grant Linville <grant@acorn.io>
---
 pkg/sdkserver/datasets.go | 186 +++++---------------------------------
 pkg/sdkserver/routes.go   |   2 -
 2 files changed, 23 insertions(+), 165 deletions(-)

diff --git a/pkg/sdkserver/datasets.go b/pkg/sdkserver/datasets.go
index 5db90bf7..e922cd97 100644
--- a/pkg/sdkserver/datasets.go
+++ b/pkg/sdkserver/datasets.go
@@ -11,24 +11,21 @@ import (
 )
 
 func (s *server) getDatasetTool(req datasetRequest) string {
-	if req.DatasetToolRepo != "" {
-		return req.DatasetToolRepo
+	if req.DatasetTool != "" {
+		return req.DatasetTool
 	}
 
 	return s.datasetTool
 }
 
 type datasetRequest struct {
-	Input           string   `json:"input"`
-	WorkspaceID     string   `json:"workspaceID"`
-	DatasetToolRepo string   `json:"datasetToolRepo"`
-	Env             []string `json:"env"`
+	Input       string   `json:"input"`
+	DatasetTool string   `json:"datasetTool"`
+	Env         []string `json:"env"`
 }
 
 func (r datasetRequest) validate(requireInput bool) error {
-	if r.WorkspaceID == "" {
-		return fmt.Errorf("workspaceID is required")
-	} else if requireInput && r.Input == "" {
+	if requireInput && r.Input == "" {
 		return fmt.Errorf("input is required")
 	} else if len(r.Env) == 0 {
 		return fmt.Errorf("env is required")
@@ -38,10 +35,9 @@ func (r datasetRequest) validate(requireInput bool) error {
 
 func (r datasetRequest) opts(o gptscript.Options) gptscript.Options {
 	opts := gptscript.Options{
-		Cache:     o.Cache,
-		Monitor:   o.Monitor,
-		Runner:    o.Runner,
-		Workspace: r.WorkspaceID,
+		Cache:   o.Cache,
+		Monitor: o.Monitor,
+		Runner:  o.Runner,
 	}
 	return opts
 }
@@ -84,148 +80,19 @@ func (s *server) listDatasets(w http.ResponseWriter, r *http.Request) {
 	writeResponse(logger, w, map[string]any{"stdout": result})
 }
 
-type createDatasetArgs struct {
-	Name        string `json:"datasetName"`
-	Description string `json:"datasetDescription"`
-}
-
-func (a createDatasetArgs) validate() error {
-	if a.Name == "" {
-		return fmt.Errorf("datasetName is required")
-	}
-	return nil
-}
-
-func (s *server) createDataset(w http.ResponseWriter, r *http.Request) {
-	logger := gcontext.GetLogger(r.Context())
-
-	var req datasetRequest
-	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
-		writeError(logger, w, http.StatusBadRequest, fmt.Errorf("failed to decode request body: %w", err))
-		return
-	}
-
-	if err := req.validate(true); err != nil {
-		writeError(logger, w, http.StatusBadRequest, err)
-		return
-	}
-
-	g, err := gptscript.New(r.Context(), req.opts(s.gptscriptOpts))
-	if err != nil {
-		writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to initialize gptscript: %w", err))
-		return
-	}
-
-	var args createDatasetArgs
-	if err := json.Unmarshal([]byte(req.Input), &args); err != nil {
-		writeError(logger, w, http.StatusBadRequest, fmt.Errorf("failed to unmarshal input: %w", err))
-		return
-	}
-
-	if err := args.validate(); err != nil {
-		writeError(logger, w, http.StatusBadRequest, err)
-		return
-	}
-
-	prg, err := loader.Program(r.Context(), s.getDatasetTool(req), "Create Dataset", loader.Options{
-		Cache: g.Cache,
-	})
-
-	if err != nil {
-		writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to load program: %w", err))
-		return
-	}
-
-	result, err := g.Run(r.Context(), prg, req.Env, req.Input)
-	if err != nil {
-		writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to run program: %w", err))
-		return
-	}
-
-	writeResponse(logger, w, map[string]any{"stdout": result})
-}
-
-type addDatasetElementArgs struct {
-	DatasetID          string `json:"datasetID"`
-	ElementName        string `json:"elementName"`
-	ElementDescription string `json:"elementDescription"`
-	ElementContent     string `json:"elementContent"`
-}
-
-func (a addDatasetElementArgs) validate() error {
-	if a.DatasetID == "" {
-		return fmt.Errorf("datasetID is required")
-	}
-	if a.ElementName == "" {
-		return fmt.Errorf("elementName is required")
-	}
-	if a.ElementContent == "" {
-		return fmt.Errorf("elementContent is required")
-	}
-	return nil
-}
-
-func (s *server) addDatasetElement(w http.ResponseWriter, r *http.Request) {
-	logger := gcontext.GetLogger(r.Context())
-
-	var req datasetRequest
-	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
-		writeError(logger, w, http.StatusBadRequest, fmt.Errorf("failed to decode request body: %w", err))
-		return
-	}
-
-	if err := req.validate(true); err != nil {
-		writeError(logger, w, http.StatusBadRequest, err)
-		return
-	}
-
-	g, err := gptscript.New(r.Context(), req.opts(s.gptscriptOpts))
-	if err != nil {
-		writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to initialize gptscript: %w", err))
-		return
-	}
-
-	var args addDatasetElementArgs
-	if err := json.Unmarshal([]byte(req.Input), &args); err != nil {
-		writeError(logger, w, http.StatusBadRequest, fmt.Errorf("failed to unmarshal input: %w", err))
-		return
-	}
-
-	if err := args.validate(); err != nil {
-		writeError(logger, w, http.StatusBadRequest, err)
-		return
-	}
-
-	prg, err := loader.Program(r.Context(), s.getDatasetTool(req), "Add Element", loader.Options{
-		Cache: g.Cache,
-	})
-	if err != nil {
-		writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to load program: %w", err))
-		return
-	}
-
-	result, err := g.Run(r.Context(), prg, req.Env, req.Input)
-	if err != nil {
-		writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to run program: %w", err))
-		return
-	}
-
-	writeResponse(logger, w, map[string]any{"stdout": result})
-}
-
 type addDatasetElementsArgs struct {
-	DatasetID string `json:"datasetID"`
-	Elements  []struct {
-		Name        string `json:"name"`
-		Description string `json:"description"`
-		Contents    string `json:"contents"`
-	}
+	DatasetID   string `json:"datasetID"`
+	Name        string `json:"name"`
+	Description string `json:"description"`
+	Elements    []struct {
+		Name           string `json:"name"`
+		Description    string `json:"description"`
+		Contents       string `json:"contents"`
+		BinaryContents []byte `json:"binaryContents"`
+	} `json:"elements"`
 }
 
 func (a addDatasetElementsArgs) validate() error {
-	if a.DatasetID == "" {
-		return fmt.Errorf("datasetID is required")
-	}
 	if len(a.Elements) == 0 {
 		return fmt.Errorf("elements is required")
 	}
@@ -271,13 +138,7 @@ func (s *server) addDatasetElements(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 
-	elementsJSON, err := json.Marshal(args.Elements)
-	if err != nil {
-		writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to marshal elements: %w", err))
-		return
-	}
-
-	result, err := g.Run(r.Context(), prg, req.Env, fmt.Sprintf(`{"datasetID":%q, "elements":%q}`, args.DatasetID, string(elementsJSON)))
+	result, err := g.Run(r.Context(), prg, req.Env, req.Input)
 	if err != nil {
 		writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to run program: %w", err))
 		return
@@ -347,15 +208,14 @@ func (s *server) listDatasetElements(w http.ResponseWriter, r *http.Request) {
 
 type getDatasetElementArgs struct {
 	DatasetID string `json:"datasetID"`
-	Element   string `json:"element"`
+	Name      string `json:"name"`
 }
 
 func (a getDatasetElementArgs) validate() error {
 	if a.DatasetID == "" {
 		return fmt.Errorf("datasetID is required")
-	}
-	if a.Element == "" {
-		return fmt.Errorf("element is required")
+	} else if a.Name == "" {
+		return fmt.Errorf("name is required")
 	}
 	return nil
 }
@@ -391,7 +251,7 @@ func (s *server) getDatasetElement(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 
-	prg, err := loader.Program(r.Context(), s.getDatasetTool(req), "Get Element SDK", loader.Options{
+	prg, err := loader.Program(r.Context(), s.getDatasetTool(req), "Get Element", loader.Options{
 		Cache: g.Cache,
 	})
 	if err != nil {
diff --git a/pkg/sdkserver/routes.go b/pkg/sdkserver/routes.go
index ea7fdb09..73bf5d58 100644
--- a/pkg/sdkserver/routes.go
+++ b/pkg/sdkserver/routes.go
@@ -69,10 +69,8 @@ func (s *server) addRoutes(mux *http.ServeMux) {
 	mux.HandleFunc("POST /credentials/delete", s.deleteCredential)
 
 	mux.HandleFunc("POST /datasets", s.listDatasets)
-	mux.HandleFunc("POST /datasets/create", s.createDataset)
 	mux.HandleFunc("POST /datasets/list-elements", s.listDatasetElements)
 	mux.HandleFunc("POST /datasets/get-element", s.getDatasetElement)
-	mux.HandleFunc("POST /datasets/add-element", s.addDatasetElement)
 	mux.HandleFunc("POST /datasets/add-elements", s.addDatasetElements)
 
 	mux.HandleFunc("POST /workspaces/create", s.createWorkspace)