Skip to content

Commit

Permalink
Added text to image support
Browse files Browse the repository at this point in the history
  • Loading branch information
Kardbord committed Apr 27, 2023
1 parent da40227 commit ff206ae
Show file tree
Hide file tree
Showing 5 changed files with 169 additions and 1 deletion.
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,6 @@ examples/text_classification/text_classification
examples/text_generation/text_generation
examples/token_classification/token_classification
examples/translation/translation
examples/zeroshot/zeroshot
examples/zeroshot/zeroshot
examples/text_to_image/*
!examples/text_to_image/*.go
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ See the [examples](./examples) directory.
- [Table Question Answering](./examples/table_question_answering/main.go)
- [Text Classification](./examples/text_classification/main.go)
- [Text Generation](./examples/text_generation/main.go)
- [Text-To-Image](./examples/text_to_image/main.go)
- [Token Classification](./examples/token_classification/main.go)
- [Translation](./examples/translation/main.go)
- [Zero-shot Classification](./examples/zeroshot/main.go)
Expand Down
85 changes: 85 additions & 0 deletions examples/text_to_image/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
package main

import (
"bufio"
"fmt"
"image"
"image/jpeg"
"image/png"
"mime"
"os"
"time"

"github.com/TannerKvarfordt/hfapigo"
)

// HuggingFaceTokenEnv is the name of the environment variable that init
// reads to obtain the Hugging Face API key.
const HuggingFaceTokenEnv = "HUGGING_FACE_TOKEN"

// init registers the API key with hfapigo when the environment variable
// named by HuggingFaceTokenEnv holds a non-empty value; otherwise it does
// nothing.
func init() {
	if token := os.Getenv(HuggingFaceTokenEnv); token != "" {
		hfapigo.SetAPIKey(token)
	}
}

// main prompts for an image description on stdin, sends it to the
// recommended text-to-image model, and writes the generated image to
// output.<format> in the current directory. Any failure is reported on
// stderr and the process exits with a non-zero status.
func main() {
	if err := run(); err != nil {
		fmt.Fprintln(os.Stderr, err)
		os.Exit(1)
	}
}

// run implements main and returns the first error encountered instead of
// printing it, so that deferred cleanup runs before the process exits.
func run() error {
	fmt.Print("Enter an image prompt: ")

	input, err := bufio.NewReader(os.Stdin).ReadString('\n')
	if err != nil {
		return err
	}

	type result struct {
		img    image.Image
		format string
		err    error
	}
	ch := make(chan result)

	fmt.Print("Sending request")
	go func() {
		img, format, err := hfapigo.SendTextToImageRequest(hfapigo.RecommendedTextToImageModel, &hfapigo.TextToImageRequest{
			Inputs:  input,
			Options: *hfapigo.NewOptions().SetWaitForModel(true),
		})
		ch <- result{img, format, err}
	}()

	// Print a progress dot every 100ms until the request goroutine reports.
	ticker := time.NewTicker(100 * time.Millisecond)
	defer ticker.Stop()

	for {
		select {
		case <-ticker.C:
			fmt.Print(".")
		case res := <-ch:
			// Terminate the line of progress dots.
			fmt.Println()

			// Surface the request error itself rather than falling through
			// to a misleading "unknown image format" with an empty format.
			if res.err != nil {
				return res.err
			}

			filename, err := saveImage(res.img, res.format)
			if err != nil {
				return err
			}
			fmt.Printf("Wrote image to %s\n", filename)
			return nil
		}
	}
}

// saveImage encodes img as output.<format> and returns the file name.
// The format is validated before the file is created so that an
// unsupported format does not leave an empty file behind.
func saveImage(img image.Image, format string) (string, error) {
	var encode func(f *os.File) error
	switch mime.TypeByExtension("." + format) {
	case "image/jpeg":
		encode = func(f *os.File) error { return jpeg.Encode(f, img, nil) }
	case "image/png":
		encode = func(f *os.File) error { return png.Encode(f, img) }
	default:
		return "", fmt.Errorf("unknown image format: %s", format)
	}

	filename := "output." + format
	fout, err := os.Create(filename)
	if err != nil {
		return "", err
	}

	if err := encode(fout); err != nil {
		// The encode error is the one worth reporting; the close error on
		// an already-failed write adds nothing.
		_ = fout.Close()
		return "", err
	}

	// Close can report a deferred write failure, so it must be checked
	// before claiming the image was written.
	if err := fout.Close(); err != nil {
		return "", err
	}
	return filename, nil
}
39 changes: 39 additions & 0 deletions text_to_image.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package hfapigo

import (
"bytes"
"encoding/json"
"errors"
"image"

_ "image/jpeg"
_ "image/png"
)

// RecommendedTextToImageModel is the model identifier this package suggests
// passing as the model argument of SendTextToImageRequest.
const RecommendedTextToImageModel = "runwayml/stable-diffusion-v1-5"

// TextToImageRequest is the request structure for a text-to-image model.
// SendTextToImageRequest serializes it to JSON as the request body.
type TextToImageRequest struct {
// Inputs is the text prompt; it is serialized as "inputs" and omitted
// from the JSON when empty.
Inputs string `json:"inputs,omitempty"`
// Options is serialized as "options".
// NOTE(review): encoding/json's omitempty never omits struct values, so
// if Options is a struct type this field is always emitted — confirm
// that this is intended.
Options Options `json:"options,omitempty"`
}

// SendTextToImageRequest sends request to the given model and decodes the
// response body as an image.
//
// On success it returns the decoded image, the name of the image format as
// reported by image.Decode (for example "jpeg" or "png"), and a nil error.
// On failure it returns nil, "", and a non-nil error. A nil request is
// rejected without contacting the API.
func SendTextToImageRequest(model string, request *TextToImageRequest) (image.Image, string, error) {
	if request == nil {
		return nil, "", errors.New("nil TextToImageRequest")
	}

	payload, err := json.Marshal(request)
	if err != nil {
		return nil, "", err
	}

	respBody, err := MakeHFAPIRequest(payload, model)
	if err != nil {
		return nil, "", err
	}

	img, format, err := image.Decode(bytes.NewReader(respBody))
	if err != nil {
		// image.Decode still reports the sniffed format name (and whatever
		// the format's decoder returned) when a recognized format fails to
		// decode; discard both so the documented error contract holds.
		return nil, "", err
	}
	return img, format, nil
}
41 changes: 41 additions & 0 deletions text_to_image_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package hfapigo_test

import (
"testing"

"github.com/TannerKvarfordt/hfapigo"
)

// TestTextToImage sends one request to the recommended model and expects a
// decoded image with a non-empty format name, then sends one request to a
// model name that should not exist and expects an error with zero-valued
// results.
func TestTextToImage(t *testing.T) {
	// Build a fresh request for every call so the two cases share nothing.
	newRequest := func() *hfapigo.TextToImageRequest {
		return &hfapigo.TextToImageRequest{
			Inputs:  "A dog and a cat sleeping adorably.",
			Options: *hfapigo.NewOptions().SetWaitForModel(true),
		}
	}

	// Valid request.
	picture, encoding, err := hfapigo.SendTextToImageRequest(hfapigo.RecommendedTextToImageModel, newRequest())
	switch {
	case err != nil:
		t.Fatal(err)
	case encoding == "":
		t.Fatal("empty encoding returned")
	case picture == nil:
		t.Fatal("nil image returned")
	}

	// Invalid request.
	picture, encoding, err = hfapigo.SendTextToImageRequest("not-a-model", newRequest())
	switch {
	case err == nil:
		t.Fatal("expected an error")
	case encoding != "":
		t.Fatal("expected an empty encoding string")
	case picture != nil:
		t.Fatal("expected a nil image")
	}
}

0 comments on commit ff206ae

Please sign in to comment.