From ff206aeeab8e2754a47105a21b977bd3b00daa5d Mon Sep 17 00:00:00 2001
From: Tanner Kvarfordt
Date: Wed, 26 Apr 2023 22:30:06 -0600
Subject: [PATCH] Added text to image support

---
 .gitignore                     |   4 +-
 README.md                      |   1 +
 examples/text_to_image/main.go | 103 +++++++++++++++++++++++++++++++++
 text_to_image.go               |  42 ++++++++++++++
 text_to_image_test.go          |  51 ++++++++++++++++
 5 files changed, 200 insertions(+), 1 deletion(-)
 create mode 100644 examples/text_to_image/main.go
 create mode 100644 text_to_image.go
 create mode 100644 text_to_image_test.go

diff --git a/.gitignore b/.gitignore
index c2dee30..e018574 100644
--- a/.gitignore
+++ b/.gitignore
@@ -33,4 +33,6 @@ examples/text_classification/text_classification
 examples/text_generation/text_generation
 examples/token_classification/token_classification
 examples/translation/translation
-examples/zeroshot/zeroshot
\ No newline at end of file
+examples/zeroshot/zeroshot
+examples/text_to_image/*
+!examples/text_to_image/*.go
\ No newline at end of file
diff --git a/README.md b/README.md
index 2789689..7976a0a 100644
--- a/README.md
+++ b/README.md
@@ -27,6 +27,7 @@ See the [examples](./examples) directory.
 - [Table Question Answering](./examples/table_question_answering/main.go)
 - [Text Classification](./examples/text_classification/main.go)
 - [Text Generation](./examples/text_generation/main.go)
+- [Text-To-Image](./examples/text_to_image/main.go)
 - [Token Classification](./examples/token_classification/main.go)
 - [Translation](./examples/translation/main.go)
 - [Zero-shot Classification](./examples/zeroshot/main.go)
diff --git a/examples/text_to_image/main.go b/examples/text_to_image/main.go
new file mode 100644
index 0000000..6b9c90e
--- /dev/null
+++ b/examples/text_to_image/main.go
@@ -0,0 +1,103 @@
+package main
+
+import (
+	"bufio"
+	"fmt"
+	"image"
+	"image/jpeg"
+	"image/png"
+	"os"
+	"strings"
+	"time"
+
+	"github.com/TannerKvarfordt/hfapigo"
+)
+
+const HuggingFaceTokenEnv = "HUGGING_FACE_TOKEN"
+
+func init() {
+	key := os.Getenv(HuggingFaceTokenEnv)
+	if key != "" {
+		hfapigo.SetAPIKey(key)
+	}
+}
+
+// writeImage encodes img to filename using the encoder matching format
+// ("jpeg" or "png", the format names reported by image.Decode).
+func writeImage(filename string, img image.Image, format string) error {
+	var encode func(*os.File) error
+	switch format {
+	case "jpeg":
+		encode = func(f *os.File) error { return jpeg.Encode(f, img, nil) }
+	case "png":
+		encode = func(f *os.File) error { return png.Encode(f, img) }
+	default:
+		// Reject unsupported formats before creating the file so that no
+		// empty output file is left behind.
+		return fmt.Errorf("unknown image format: %s", format)
+	}
+
+	fout, err := os.Create(filename)
+	if err != nil {
+		return err
+	}
+	if err := encode(fout); err != nil {
+		fout.Close()
+		return err
+	}
+	// A failed Close can mean the image was not fully written.
+	return fout.Close()
+}
+
+func main() {
+	fmt.Print("Enter an image prompt: ")
+
+	reader := bufio.NewReader(os.Stdin)
+	input, err := reader.ReadString('\n')
+	if err != nil {
+		fmt.Println(err)
+		return
+	}
+	// Drop the trailing line terminator so it is not sent as part of the prompt.
+	input = strings.TrimSpace(input)
+
+	type ChanRv struct {
+		resp   image.Image
+		format string
+		err    error
+	}
+	ch := make(chan ChanRv)
+
+	fmt.Print("Sending request")
+	go func() {
+		img, format, err := hfapigo.SendTextToImageRequest(hfapigo.RecommendedTextToImageModel, &hfapigo.TextToImageRequest{
+			Inputs:  input,
+			Options: *hfapigo.NewOptions().SetWaitForModel(true),
+		})
+		ch <- ChanRv{img, format, err}
+	}()
+
+	for {
+		select {
+		default:
+			fmt.Print(".")
+			time.Sleep(time.Millisecond * 100)
+		case chrv := <-ch:
+			fmt.Println()
+			// Report a failed request instead of trying to encode a nil image.
+			if chrv.err != nil {
+				fmt.Println(chrv.err)
+				return
+			}
+
+			filename := fmt.Sprintf("output.%s", chrv.format)
+			if err := writeImage(filename, chrv.resp, chrv.format); err != nil {
+				fmt.Println(err)
+				return
+			}
+
+			fmt.Printf("Wrote image to %s\n", filename)
+			return
+		}
+	}
+}
diff --git a/text_to_image.go b/text_to_image.go
new file mode 100644
index 0000000..6ba576b
--- /dev/null
+++ b/text_to_image.go
@@ -0,0 +1,42 @@
+package hfapigo
+
+import (
+	"bytes"
+	"encoding/json"
+	"errors"
+	"image"
+
+	// Register the JPEG and PNG decoders so image.Decode can recognize them.
+	_ "image/jpeg"
+	_ "image/png"
+)
+
+// RecommendedTextToImageModel is the model suggested for text-to-image requests.
+const RecommendedTextToImageModel = "runwayml/stable-diffusion-v1-5"
+
+// TextToImageRequest is the request structure for text-to-image models.
+type TextToImageRequest struct {
+	Inputs  string  `json:"inputs,omitempty"`
+	Options Options `json:"options,omitempty"`
+}
+
+// SendTextToImageRequest sends request to the given model. If successful, it
+// returns the generated image object, its format name, and nil.
+// If unsuccessful, it returns nil, "", and an error.
+func SendTextToImageRequest(model string, request *TextToImageRequest) (image.Image, string, error) {
+	if request == nil {
+		return nil, "", errors.New("nil TextToImageRequest")
+	}
+
+	jsonBuf, err := json.Marshal(request)
+	if err != nil {
+		return nil, "", err
+	}
+
+	respBody, err := MakeHFAPIRequest(jsonBuf, model)
+	if err != nil {
+		return nil, "", err
+	}
+
+	return image.Decode(bytes.NewReader(respBody))
+}
diff --git a/text_to_image_test.go b/text_to_image_test.go
new file mode 100644
index 0000000..4e31f2a
--- /dev/null
+++ b/text_to_image_test.go
@@ -0,0 +1,51 @@
+package hfapigo_test
+
+import (
+	"testing"
+
+	"github.com/TannerKvarfordt/hfapigo"
+)
+
+func TestTextToImage(t *testing.T) {
+	{ // Test valid request
+		img, format, err := hfapigo.SendTextToImageRequest(hfapigo.RecommendedTextToImageModel, &hfapigo.TextToImageRequest{
+			Inputs:  "A dog and a cat sleeping adorably.",
+			Options: *hfapigo.NewOptions().SetWaitForModel(true),
+		})
+		if err != nil {
+			t.Fatal(err)
+		}
+		if format == "" {
+			t.Fatal("empty encoding returned")
+		}
+		if img == nil {
+			t.Fatal("nil image returned")
+		}
+	}
+
+	{ // Test invalid request
+		img, format, err := hfapigo.SendTextToImageRequest("not-a-model", &hfapigo.TextToImageRequest{
+			Inputs:  "A dog and a cat sleeping adorably.",
+			Options: *hfapigo.NewOptions().SetWaitForModel(true),
+		})
+		if err == nil {
+			t.Fatal("expected an error")
+		}
+		if format != "" {
+			t.Fatal("expected an empty encoding string")
+		}
+		if img != nil {
+			t.Fatal("expected a nil image")
+		}
+	}
+
+	{ // Test nil request
+		img, format, err := hfapigo.SendTextToImageRequest(hfapigo.RecommendedTextToImageModel, nil)
+		if err == nil {
+			t.Fatal("expected an error")
+		}
+		if format != "" || img != nil {
+			t.Fatal("expected an empty encoding string and a nil image")
+		}
+	}
+}