From 59a4668079152ab2934ac9dfb27dcc862712cce6 Mon Sep 17 00:00:00 2001 From: Michael Pursifull Date: Fri, 27 Sep 2024 22:11:12 -0500 Subject: [PATCH] validate llm model, bugfix for make install --- Makefile | 2 +- cmd/root.go | 6 ++++++ internal/llm.go | 44 ++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 51 insertions(+), 1 deletion(-) diff --git a/Makefile b/Makefile index 45cfc11..91fb458 100644 --- a/Makefile +++ b/Makefile @@ -16,7 +16,7 @@ build: ## Build Kuzco env $(if $(GOOS),GOOS=$(GOOS)) $(if $(GOARCH),GOARCH=$(GOARCH)) $(GO) build -o build/$(BINARY_NAME) -ldflags "-X 'github.com/RoseSecurity/kuzco/cmd.Version=local'" main.go install: ## Install dependencies - $(GO) install ./...@latest + $(GO) install ./... clean: ## Clean up build artifacts $(GO) clean diff --git a/cmd/root.go b/cmd/root.go index d55c7bf..ffb9679 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -52,6 +52,12 @@ func runAnalyzer(cmd *cobra.Command, args []string) { return } + // Validate that the specified model exists in Ollama + if err := internal.ValidateModel(model, addr); err != nil { + fmt.Fprintf(os.Stderr, "Model validation error: %v\n", err) + os.Exit(1) + } + // Proceed with the main logic if all required flags are set if err := internal.Run(filePath, model, addr); err != nil { fmt.Fprintf(os.Stderr, "Error: %v\n", err) diff --git a/internal/llm.go b/internal/llm.go index 8b3310b..1d97c0d 100644 --- a/internal/llm.go +++ b/internal/llm.go @@ -31,6 +31,16 @@ type LlamaResponse struct { Recommendations string `json:"response"` } +// Model represents the structure of a model in the response +type Model struct { + Name string `json:"name"` +} + +// ModelsResponse represents the structure of the response from /api/tags +type ModelsResponse struct { + Models []Model `json:"models"` +} + func GetRecommendations(resourceType string, unusedAttrs []string, model string, addr string) (string, error) { prompt := fmt.Sprintf(`Unused attributes for Terraform resource '%s': %v @@ -96,3 +106,37 @@ resource "type" "name" { return llamaResp.Recommendations, nil } + +// ValidateModel checks if the specified model exists in Ollama +func ValidateModel(model, addr string) error { + // Get a list of available models from Ollama + resp, err := http.Get(fmt.Sprintf("%s/api/tags", addr)) + if err != nil { + return fmt.Errorf("error fetching models from Ollama: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("failed to retrieve models: status code %d", resp.StatusCode) + } + + // Parse the response body into ModelsResponse + var modelsResp ModelsResponse + if err := json.NewDecoder(resp.Body).Decode(&modelsResp); err != nil { + return fmt.Errorf("error decoding models response: %v", err) + } + + // Check if the requested model is in the list of models + for _, availableModel := range modelsResp.Models { + if availableModel.Name == model { + return nil + } + } + + // If model is not found, return an error and list available models + var availableModelNames []string + for _, availableModel := range modelsResp.Models { + availableModelNames = append(availableModelNames, availableModel.Name) + } + return fmt.Errorf("model '%s' not found. Available models: %v", model, availableModelNames) +}