Skip to content

Commit

Permalink
Merge pull request #5 from arcaven/validate-models
Browse files Browse the repository at this point in the history
validate llm model, bugfix for make install
  • Loading branch information
RoseSecurity authored Sep 30, 2024
2 parents 2a11aff + 59a4668 commit 2d61cfc
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 1 deletion.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ build: ## Build Kuzco
env $(if $(GOOS),GOOS=$(GOOS)) $(if $(GOARCH),GOARCH=$(GOARCH)) $(GO) build -o build/$(BINARY_NAME) -ldflags "-X 'github.com/RoseSecurity/kuzco/cmd.Version=local'" main.go

install: ## Install dependencies
$(GO) install ./...@latest
$(GO) install ./...

clean: ## Clean up build artifacts
$(GO) clean
Expand Down
6 changes: 6 additions & 0 deletions cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,12 @@ func runAnalyzer(cmd *cobra.Command, args []string) {
return
}

// Validate that the specified model exists in Ollama
if err := internal.ValidateModel(model, addr); err != nil {
fmt.Fprintf(os.Stderr, "Model validation error: %v\n", err)
os.Exit(1)
}

// Proceed with the main logic if all required flags are set
if err := internal.Run(filePath, model, addr); err != nil {
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
Expand Down
44 changes: 44 additions & 0 deletions internal/llm.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,16 @@ type LlamaResponse struct {
Recommendations string `json:"response"`
}

// Model represents the structure of a model in the response
type Model struct {
Name string `json:"name"`
}

// ModelsResponse represents the structure of the response from /api/tags
type ModelsResponse struct {
Models []Model `json:"models"`
}

func GetRecommendations(resourceType string, unusedAttrs []string, model string, addr string) (string, error) {
prompt := fmt.Sprintf(`Unused attributes for Terraform resource '%s': %v
Expand Down Expand Up @@ -96,3 +106,37 @@ resource "type" "name" {

return llamaResp.Recommendations, nil
}

// ValidateModel checks if the specified model exists in Ollama
func ValidateModel(model, addr string) error {
// Get a list of available models from Ollama
resp, err := http.Get(fmt.Sprintf("%s/api/tags", addr))
if err != nil {
return fmt.Errorf("error fetching models from Ollama: %v", err)
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
return fmt.Errorf("failed to retrieve models: status code %d", resp.StatusCode)
}

// Parse the response body into ModelsResponse
var modelsResp ModelsResponse
if err := json.NewDecoder(resp.Body).Decode(&modelsResp); err != nil {
return fmt.Errorf("error decoding models response: %v", err)
}

// Check if the requested model is in the list of models
for _, availableModel := range modelsResp.Models {
if availableModel.Name == model {
return nil
}
}

// If model is not found, return an error and list available models
var availableModelNames []string
for _, availableModel := range modelsResp.Models {
availableModelNames = append(availableModelNames, availableModel.Name)
}
return fmt.Errorf("model '%s' not found. Available models: %v", model, availableModelNames)
}

0 comments on commit 2d61cfc

Please sign in to comment.