From 960cdeaa0c248c5ec25626c8dfcb9b7acf1e6397 Mon Sep 17 00:00:00 2001 From: prasad89 Date: Tue, 5 Nov 2024 10:17:30 +0530 Subject: [PATCH 1/2] Enhance PullModel Function to Provide Visual Feedback During Model Retrieval --- examples/main.go | 1 + pkg/golamify/generate.go | 7 +----- pkg/golamify/model.go | 48 ++++++++++++++++++++++------------------ 3 files changed, 29 insertions(+), 27 deletions(-) diff --git a/examples/main.go b/examples/main.go index 16b521b..c92574c 100644 --- a/examples/main.go +++ b/examples/main.go @@ -30,6 +30,7 @@ func main() { resp, err := golamify.Generate(client, "llama3.2:1b", prompt) if err != nil { fmt.Println(err) + return } fmt.Println(i,resp.Response) diff --git a/pkg/golamify/generate.go b/pkg/golamify/generate.go index 560463e..96c2c81 100644 --- a/pkg/golamify/generate.go +++ b/pkg/golamify/generate.go @@ -36,15 +36,10 @@ func Generate(c *Client, model string, prompt string) (*GenerateResponse, error) } if statusCode == http.StatusNotFound { - pullStatus, err := PullModel(model, c) + err := PullModel(model, c) if err != nil { return nil, fmt.Errorf("failed to pull model: %w", err) } - if pullStatus != http.StatusOK { - return nil, fmt.Errorf("failed to pull model, received status: %d", pullStatus) - } - } else if statusCode != http.StatusOK { - return nil, fmt.Errorf("unexpected status code from ShowModel: %d", statusCode) } payload := GeneratePayload{ diff --git a/pkg/golamify/model.go b/pkg/golamify/model.go index 309a4fd..9e00a30 100644 --- a/pkg/golamify/model.go +++ b/pkg/golamify/model.go @@ -6,15 +6,9 @@ import ( "fmt" "io" "net/http" + "time" ) -type PullResponse struct { - Status string `json:"status"` - Digest string `json:"digest"` - Total string `json:"total"` - Error string `josn:"error"` -} - func ShowModel(model string, c *Client) (int, error) { payload := map[string]string{"name": model} @@ -39,41 +33,53 @@ func ShowModel(model string, c *Client) (int, error) { return resp.StatusCode, nil } -func PullModel(model string, c *Client) (int, error) { +func PullModel(model string, c *Client) error { payload := map[string]string{"name": model} body, err := json.Marshal(payload) if err != nil { - return 0, fmt.Errorf("failed to encode request payload: %w", err) + return fmt.Errorf("failed to encode request payload: %w", err) } req, err := http.NewRequest("POST", c.config.OllamaHost+"/api/pull", bytes.NewBuffer(body)) if err != nil { - return 0, fmt.Errorf("failed to create request: %w", err) + return fmt.Errorf("failed to create request: %w", err) } req.Header.Set("User-Agent", c.userAgent) req.Header.Set("Content-Type", "application/json") resp, err := c.httpClient.Do(req) if err != nil { - return 0, fmt.Errorf("failed to connect to pull endpoint: %w", err) + return fmt.Errorf("failed to connect to pull endpoint: %w", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - respBody, _ := io.ReadAll(resp.Body) - return resp.StatusCode, fmt.Errorf("pull endpoint is not reachable, received status: %s, body: %s", resp.Status, string(respBody)) + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("failed to read error response body: %w", err) + } + return fmt.Errorf("pull endpoint is not reachable, received status: %s, body: %s", resp.Status, string(respBody)) } - respBody, err := io.ReadAll(resp.Body) - if err != nil { - return 0, fmt.Errorf("failed to read response body: %w", err) - } + fmt.Printf("Pulling model: %s", model) + + decoder := json.NewDecoder(resp.Body) + for decoder.More() { + var message struct { + Error string `json:"error,omitempty"` + } + if err := decoder.Decode(&message); err != nil { + return fmt.Errorf("failed to decode JSON message: %w", err) + } - var pullResponse PullResponse - if err := json.Unmarshal(respBody, &pullResponse); err != nil { - return 0, fmt.Errorf("failed to decode response body: %w", err) + if message.Error != "" { + return fmt.Errorf("'%s': %s", model, message.Error) + } + + fmt.Print("...") + time.Sleep(5 * time.Second) } - return resp.StatusCode, nil + return nil } From 5cc7944fd40bd929b8277972676903fe06c5e0f3 Mon Sep 17 00:00:00 2001 From: prasad89 Date: Tue, 5 Nov 2024 16:43:28 +0530 Subject: [PATCH 2/2] Enhance Generate function with streaming support and expanded payload options --- examples/main.go | 32 +++++++--- pkg/golamify/generate.go | 127 +++++++++++++++++++++------------------ 2 files changed, 92 insertions(+), 67 deletions(-) diff --git a/examples/main.go b/examples/main.go index c92574c..d77be89 100644 --- a/examples/main.go +++ b/examples/main.go @@ -17,27 +17,43 @@ func main() { client, err := golamify.NewClient(&config) if err != nil { fmt.Print(err) + return } + var wg sync.WaitGroup - for i := 1; i < 100; i++ { + for i := 1; i <= 1; i++ { wg.Add(1) go func(i int) { defer wg.Done() prompt := fmt.Sprintf("What is the square of %d?", i) - resp, err := golamify.Generate(client, "llama3.2:1b", prompt) - if err != nil { - fmt.Println(err) - return + + payload := golamify.GeneratePayload{ + Model: "llama3.2:1b", + Prompt: prompt, + Stream: new(bool), } - fmt.Println(i,resp.Response) + responseChannel, errorChannel := golamify.Generate(client, &payload) + + for { + select { + case response, ok := <-responseChannel: + if !ok { + return + } + fmt.Print(response["response"]) + + case err, ok := <-errorChannel: + if ok && err != nil { + fmt.Println("Error:", err) + } + } + } }(i) } wg.Wait() - - return } diff --git a/pkg/golamify/generate.go b/pkg/golamify/generate.go index 96c2c81..08da484 100644 --- a/pkg/golamify/generate.go +++ b/pkg/golamify/generate.go @@ -1,6 +1,7 @@ package golamify import ( + "bufio" "bytes" "encoding/json" "fmt" @@ -9,77 +10,85 @@ import ( ) type GeneratePayload struct { - Model string `json:"model"` - Prompt string `json:"prompt"` - Stream bool `json:"stream"` + Model string `json:"model" validate:"required"` + Prompt string `json:"prompt" validate:"required"` + Suffix string `json:"suffix,omitempty"` + Images []string `json:"images,omitempty"` + System string `json:"system,omitempty"` + Template string `json:"template,omitempty"` + Context string `json:"context,omitempty"` + Stream *bool `json:"stream,omitempty"` + Raw *bool `json:"raw,omitempty"` + KeepAlive string `json:"keep_alive,omitempty"` } -type GenerateResponse struct { - Model string `json:"model"` - CreatedAt string `json:"created_at"` - Response string `json:"response"` - Done bool `json:"done"` - DoneReason string `json:"done_reason"` - Context []int `json:"context"` - TotalDuration int64 `json:"total_duration"` - LoadDuration int64 `json:"load_duration"` - PromptEvalCount int64 `json:"prompt_eval_count"` - PromptEvalDuration int64 `json:"prompt_eval_duration"` - EvalCount int64 `json:"eval_count"` - EvalDuration int64 `json:"eval_duration"` -} +func Generate(c *Client, payload *GeneratePayload) (<-chan map[string]interface{}, <-chan error) { + responseChannel := make(chan map[string]interface{}) + errorChannel := make(chan error, 1) -func Generate(c *Client, model string, prompt string) (*GenerateResponse, error) { - statusCode, err := ShowModel(model, c) - if err != nil { - return nil, fmt.Errorf("error showing model: %w", err) - } + go func() { + defer close(responseChannel) + defer close(errorChannel) - if statusCode == http.StatusNotFound { - err := PullModel(model, c) + statusCode, err := ShowModel(payload.Model, c) if err != nil { - return nil, fmt.Errorf("failed to pull model: %w", err) + errorChannel <- fmt.Errorf("error showing model: %w", err) + return } - } - payload := GeneratePayload{ - Model: model, - Prompt: prompt, - Stream: false, - } + if statusCode == http.StatusNotFound { + err := PullModel(payload.Model, c) + if err != nil { + errorChannel <- fmt.Errorf("failed to pull model: %w", err) + return + } + } - body, err := json.Marshal(payload) - if err != nil { - return nil, fmt.Errorf("failed to encode request payload: %w", err) - } + body, err := json.Marshal(payload) + if err != nil { + errorChannel <- fmt.Errorf("failed to encode request payload: %w", err) + return + } - req, err := http.NewRequest("POST", c.config.OllamaHost+"/api/generate", bytes.NewBuffer(body)) - if err != nil { - return nil, fmt.Errorf("failed to create request: %w", err) - } - req.Header.Set("User-Agent", c.userAgent) - req.Header.Set("Content-Type", "application/json") + req, err := http.NewRequest("POST", c.config.OllamaHost+"/api/generate", bytes.NewBuffer(body)) + if err != nil { + errorChannel <- fmt.Errorf("failed to create request: %w", err) + return + } + req.Header.Set("User-Agent", c.userAgent) + req.Header.Set("Content-Type", "application/json") - resp, err := c.httpClient.Do(req) - if err != nil { - return nil, fmt.Errorf("failed to connect to generate endpoint: %w", err) - } - defer resp.Body.Close() + resp, err := c.httpClient.Do(req) + if err != nil { + errorChannel <- fmt.Errorf("failed to connect to generate endpoint: %w", err) + return + } + defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { - respBody, _ := io.ReadAll(resp.Body) - return nil, fmt.Errorf("generate endpoint is not reachable, received status: %s, body: %s", resp.Status, string(respBody)) - } + if resp.StatusCode != http.StatusOK { + respBody, _ := io.ReadAll(resp.Body) + errorChannel <- fmt.Errorf("generate endpoint is not reachable, received status: %s, body: %s", resp.Status, string(respBody)) + return + } - respBody, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("failed to read response body: %w", err) - } + scanner := bufio.NewScanner(resp.Body) + for scanner.Scan() { + var generated map[string]interface{} + err := json.Unmarshal(scanner.Bytes(), &generated) + if err != nil { + errorChannel <- fmt.Errorf("error parsing JSON: %w", err) + continue + } + responseChannel <- generated + if done, exists := generated["done"].(bool); exists && done { + break + } + } - var generateResponse GenerateResponse - if err := json.Unmarshal(respBody, &generateResponse); err != nil { - return nil, fmt.Errorf("failed to decode response body: %w", err) - } + if err := scanner.Err(); err != nil { + errorChannel <- fmt.Errorf("error reading response body: %w", err) + } + }() - return &generateResponse, nil + return responseChannel, errorChannel }