Skip to content

Commit

Permalink
Merge feature enhancements for PullModel and Generate functions
Browse files Browse the repository at this point in the history
- Integrated improvements to `PullModel` to provide visual feedback during model retrieval, enhancing user experience.
- Merged updates to `Generate` function to support streaming responses via channels, enabling real-time response handling.
- Expanded `GeneratePayload` struct with additional fields (`Suffix`, `Images`, `System`, `Template`, `Context`, `Stream`, `Raw`, `KeepAlive`) to allow greater customization in API requests.
- Updated main package to demonstrate new `Generate` functionality with streaming and concurrent request handling.
  • Loading branch information
prasad89 authored Nov 5, 2024
2 parents fd456dd + 5cc7944 commit c5f8bd9
Show file tree
Hide file tree
Showing 3 changed files with 119 additions and 92 deletions.
31 changes: 24 additions & 7 deletions examples/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,26 +17,43 @@ func main() {
client, err := golamify.NewClient(&config)
if err != nil {
fmt.Print(err)
return
}

var wg sync.WaitGroup

for i := 1; i < 100; i++ {
for i := 1; i <= 1; i++ {
wg.Add(1)

go func(i int) {
defer wg.Done()

prompt := fmt.Sprintf("What is the square of %d?", i)
resp, err := golamify.Generate(client, "llama3.2:1b", prompt)
if err != nil {
fmt.Println(err)

payload := golamify.GeneratePayload{
Model: "llama3.2:1b",
Prompt: prompt,
Stream: new(bool),
}

fmt.Println(i,resp.Response)
responseChannel, errorChannel := golamify.Generate(client, &payload)

for {
select {
case response, ok := <-responseChannel:
if !ok {
return
}
fmt.Print(response["response"])

case err, ok := <-errorChannel:
if ok && err != nil {
fmt.Println("Error:", err)
}
}
}
}(i)
}

wg.Wait()

return
}
132 changes: 68 additions & 64 deletions pkg/golamify/generate.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package golamify

import (
"bufio"
"bytes"
"encoding/json"
"fmt"
Expand All @@ -9,82 +10,85 @@ import (
)

type GeneratePayload struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
Stream bool `json:"stream"`
Model string `json:"model" validate:"required"`
Prompt string `json:"prompt" validate:"required"`
Suffix string `json:"suffix,omitempty"`
Images []string `json:"images,omitempty"`
System string `json:"system,omitempty"`
Template string `json:"template,omitempty"`
Context string `json:"context,omitempty"`
Stream *bool `json:"stream,omitempty"`
Raw *bool `json:"raw,omitempty"`
KeepAlive string `json:"keep_alive,omitempty"`
}

type GenerateResponse struct {
Model string `json:"model"`
CreatedAt string `json:"created_at"`
Response string `json:"response"`
Done bool `json:"done"`
DoneReason string `json:"done_reason"`
Context []int `json:"context"`
TotalDuration int64 `json:"total_duration"`
LoadDuration int64 `json:"load_duration"`
PromptEvalCount int64 `json:"prompt_eval_count"`
PromptEvalDuration int64 `json:"prompt_eval_duration"`
EvalCount int64 `json:"eval_count"`
EvalDuration int64 `json:"eval_duration"`
}
func Generate(c *Client, payload *GeneratePayload) (<-chan map[string]interface{}, <-chan error) {
responseChannel := make(chan map[string]interface{})
errorChannel := make(chan error, 1)

func Generate(c *Client, model string, prompt string) (*GenerateResponse, error) {
statusCode, err := ShowModel(model, c)
if err != nil {
return nil, fmt.Errorf("error showing model: %w", err)
}
go func() {
defer close(responseChannel)
defer close(errorChannel)

if statusCode == http.StatusNotFound {
pullStatus, err := PullModel(model, c)
statusCode, err := ShowModel(payload.Model, c)
if err != nil {
return nil, fmt.Errorf("failed to pull model: %w", err)
}
if pullStatus != http.StatusOK {
return nil, fmt.Errorf("failed to pull model, received status: %d", pullStatus)
errorChannel <- fmt.Errorf("error showing model: %w", err)
return
}
} else if statusCode != http.StatusOK {
return nil, fmt.Errorf("unexpected status code from ShowModel: %d", statusCode)
}

payload := GeneratePayload{
Model: model,
Prompt: prompt,
Stream: false,
}
if statusCode == http.StatusNotFound {
err := PullModel(payload.Model, c)
if err != nil {
errorChannel <- fmt.Errorf("failed to pull model: %w", err)
return
}
}

body, err := json.Marshal(payload)
if err != nil {
return nil, fmt.Errorf("failed to encode request payload: %w", err)
}
body, err := json.Marshal(payload)
if err != nil {
errorChannel <- fmt.Errorf("failed to encode request payload: %w", err)
return
}

req, err := http.NewRequest("POST", c.config.OllamaHost+"/api/generate", bytes.NewBuffer(body))
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("User-Agent", c.userAgent)
req.Header.Set("Content-Type", "application/json")
req, err := http.NewRequest("POST", c.config.OllamaHost+"/api/generate", bytes.NewBuffer(body))
if err != nil {
errorChannel <- fmt.Errorf("failed to create request: %w", err)
return
}
req.Header.Set("User-Agent", c.userAgent)
req.Header.Set("Content-Type", "application/json")

resp, err := c.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to connect to generate endpoint: %w", err)
}
defer resp.Body.Close()
resp, err := c.httpClient.Do(req)
if err != nil {
errorChannel <- fmt.Errorf("failed to connect to generate endpoint: %w", err)
return
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
respBody, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("generate endpoint is not reachable, received status: %s, body: %s", resp.Status, string(respBody))
}
if resp.StatusCode != http.StatusOK {
respBody, _ := io.ReadAll(resp.Body)
errorChannel <- fmt.Errorf("generate endpoint is not reachable, received status: %s, body: %s", resp.Status, string(respBody))
return
}

respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response body: %w", err)
}
scanner := bufio.NewScanner(resp.Body)
for scanner.Scan() {
var generated map[string]interface{}
err := json.Unmarshal(scanner.Bytes(), &generated)
if err != nil {
errorChannel <- fmt.Errorf("error parsing JSON: %w", err)
continue
}
responseChannel <- generated
if done, exists := generated["done"].(bool); exists && done {
break
}
}

var generateResponse GenerateResponse
if err := json.Unmarshal(respBody, &generateResponse); err != nil {
return nil, fmt.Errorf("failed to decode response body: %w", err)
}
if err := scanner.Err(); err != nil {
errorChannel <- fmt.Errorf("error reading response body: %w", err)
}
}()

return &generateResponse, nil
return responseChannel, errorChannel
}
48 changes: 27 additions & 21 deletions pkg/golamify/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,9 @@ import (
"fmt"
"io"
"net/http"
"time"
)

type PullResponse struct {
Status string `json:"status"`
Digest string `json:"digest"`
Total string `json:"total"`
Error string `josn:"error"`
}

func ShowModel(model string, c *Client) (int, error) {
payload := map[string]string{"name": model}

Expand All @@ -39,41 +33,53 @@ func ShowModel(model string, c *Client) (int, error) {
return resp.StatusCode, nil
}

func PullModel(model string, c *Client) (int, error) {
func PullModel(model string, c *Client) error {
payload := map[string]string{"name": model}

body, err := json.Marshal(payload)
if err != nil {
return 0, fmt.Errorf("failed to encode request payload: %w", err)
return fmt.Errorf("failed to encode request payload: %w", err)
}

req, err := http.NewRequest("POST", c.config.OllamaHost+"/api/pull", bytes.NewBuffer(body))
if err != nil {
return 0, fmt.Errorf("failed to create request: %w", err)
return fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("User-Agent", c.userAgent)
req.Header.Set("Content-Type", "application/json")

resp, err := c.httpClient.Do(req)
if err != nil {
return 0, fmt.Errorf("failed to connect to pull endpoint: %w", err)
return fmt.Errorf("failed to connect to pull endpoint: %w", err)
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
respBody, _ := io.ReadAll(resp.Body)
return resp.StatusCode, fmt.Errorf("pull endpoint is not reachable, received status: %s, body: %s", resp.Status, string(respBody))
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("failed to read error response body: %w", err)
}
return fmt.Errorf("pull endpoint is not reachable, received status: %s, body: %s", resp.Status, string(respBody))
}

respBody, err := io.ReadAll(resp.Body)
if err != nil {
return 0, fmt.Errorf("failed to read response body: %w", err)
}
fmt.Printf("Pulling model: %s", model)

decoder := json.NewDecoder(resp.Body)
for decoder.More() {
var message struct {
Error string `json:"error,omitempty"`
}
if err := decoder.Decode(&message); err != nil {
return fmt.Errorf("failed to decode JSON message: %w", err)
}

var pullResponse PullResponse
if err := json.Unmarshal(respBody, &pullResponse); err != nil {
return 0, fmt.Errorf("failed to decode response body: %w", err)
if message.Error != "" {
return fmt.Errorf("'%s': %s", model, message.Error)
}

fmt.Print("...")
time.Sleep(5 * time.Second)
}

return resp.StatusCode, nil
return nil
}

0 comments on commit c5f8bd9

Please sign in to comment.