Skip to content

feat(go/plugins/googlegenai): add image-generation models #2903

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Jun 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions go/plugins/googlegenai/gemini.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,19 +140,30 @@ func defineModel(g *genkit.Genkit, client *genai.Client, name string, info ai.Mo
provider = vertexAIProvider
}

var config any
config = &genai.GenerateContentConfig{}
if mi, found := supportedImagenModels[name]; found {
config = &genai.GenerateImagesConfig{}
info = mi
}
meta := &ai.ModelInfo{
Label: info.Label,
Supports: info.Supports,
Versions: info.Versions,
ConfigSchema: configToMap(genai.GenerateContentConfig{}),
ConfigSchema: configToMap(config),
}

fn := func(
ctx context.Context,
input *ai.ModelRequest,
cb func(context.Context, *ai.ModelResponseChunk) error,
) (*ai.ModelResponse, error) {
return generate(ctx, client, name, input, cb)
switch config.(type) {
case *genai.GenerateImagesConfig:
return generateImage(ctx, client, name, input, cb)
default:
return generate(ctx, client, name, input, cb)
}
}
// the gemini api doesn't support downloading media from http(s)
if info.Supports.Media {
Expand Down
117 changes: 117 additions & 0 deletions go/plugins/googlegenai/imagen.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0

package googlegenai

import (
"context"
"encoding/base64"
"fmt"

"github.com/firebase/genkit/go/ai"
"google.golang.org/genai"
)

// Media describes model capabilities for Gemini models with media and text
// input and image only output
var Media = ai.ModelSupports{
Media: true,
Multiturn: false,
Tools: false,
ToolChoice: false,
SystemRole: false,
}

// imagenConfigFromRequest translates an [*ai.ModelRequest] configuration to [*genai.GenerateImagesConfig]
func imagenConfigFromRequest(input *ai.ModelRequest) (*genai.GenerateImagesConfig, error) {
var result genai.GenerateImagesConfig

switch config := input.Config.(type) {
case genai.GenerateImagesConfig:
result = config
case *genai.GenerateImagesConfig:
result = *config
case map[string]any:
if err := mapToStruct(config, &result); err != nil {
return nil, err
}
case nil:
// empty but valid config
default:
return nil, fmt.Errorf("unexpected config type: %T", input.Config)
}

return &result, nil
}

// translateImagenCandidates translates the image generation response to [*ai.ModelResponse]
func translateImagenCandidates(images []*genai.GeneratedImage) *ai.ModelResponse {
m := &ai.ModelResponse{}
m.FinishReason = ai.FinishReasonStop

msg := &ai.Message{}
msg.Role = ai.RoleModel

for _, img := range images {
msg.Content = append(msg.Content, ai.NewMediaPart(img.Image.MIMEType, "data:"+img.Image.MIMEType+";base64,"+base64.StdEncoding.EncodeToString(img.Image.ImageBytes)))
}

m.Message = msg
return m
}

// translateImagenResponse translates [*genai.GenerateImagesResponse] to an [*ai.ModelResponse]
func translateImagenResponse(resp *genai.GenerateImagesResponse) *ai.ModelResponse {
return translateImagenCandidates(resp.GeneratedImages)
}

// generateImage requests a generate call to the specified imagen model with the
// provided configuration
func generateImage(
ctx context.Context,
client *genai.Client,
model string,
input *ai.ModelRequest,
cb func(context.Context, *ai.ModelResponseChunk) error,
) (*ai.ModelResponse, error) {
gic, err := imagenConfigFromRequest(input)
if err != nil {
return nil, err
}

var userPrompt string
for _, m := range input.Messages {
if m.Role == ai.RoleUser {
userPrompt += m.Text()
}
}
if userPrompt == "" {
return nil, fmt.Errorf("error generating images: empty prompt detected")
}

if cb != nil {
return nil, fmt.Errorf("streaming mode not supported for image generation")
}

resp, err := client.Models.GenerateImages(ctx, model, userPrompt, gic)
if err != nil {
return nil, err
}

r := translateImagenResponse(resp)
r.Request = input
return r, nil
}
59 changes: 53 additions & 6 deletions go/plugins/googlegenai/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ const (
gemini25ProExp0325 = "gemini-2.5-pro-exp-03-25"
gemini25ProPreview0325 = "gemini-2.5-pro-preview-03-25"
gemini25ProPreview0506 = "gemini-2.5-pro-preview-05-06"

imagen3Generate001 = "imagen-3.0-generate-001"
imagen3Generate002 = "imagen-3.0-generate-002"
imagen3FastGenerate001 = "imagen-3.0-fast-generate-001"
)

var (
Expand All @@ -50,6 +54,10 @@ var (
gemini25ProExp0325,
gemini25ProPreview0325,
gemini25ProPreview0506,

imagen3Generate001,
imagen3Generate002,
imagen3FastGenerate001,
}

googleAIModels = []string{
Expand All @@ -66,9 +74,11 @@ var (
gemini25ProExp0325,
gemini25ProPreview0325,
gemini25ProPreview0506,

imagen3Generate002,
}

// models with native image support generation
// Gemini models with native image support generation
imageGenModels = []string{
gemini20FlashPrevImageGen,
}
Expand Down Expand Up @@ -175,6 +185,27 @@ var (
},
}

supportedImagenModels = map[string]ai.ModelInfo{
imagen3Generate001: {
Label: "Imagen 3 Generate 001",
Versions: []string{},
Supports: &Media,
Stage: ai.ModelStageStable,
},
imagen3Generate002: {
Label: "Imagen 3 Generate 002",
Versions: []string{},
Supports: &Media,
Stage: ai.ModelStageStable,
},
imagen3FastGenerate001: {
Label: "Imagen 3 Fast Generate 001",
Versions: []string{},
Supports: &Media,
Stage: ai.ModelStageStable,
},
}

googleAIEmbedders = []string{
"text-embedding-004",
"embedding-001",
Expand All @@ -194,7 +225,7 @@ var (
// listModels returns a map of supported models and their capabilities
// based on the detected backend
func listModels(provider string) (map[string]ai.ModelInfo, error) {
names := []string{}
var names []string
var prefix string

switch provider {
Expand All @@ -210,7 +241,13 @@ func listModels(provider string) (map[string]ai.ModelInfo, error) {

models := make(map[string]ai.ModelInfo, 0)
for _, n := range names {
m, ok := supportedGeminiModels[n]
var m ai.ModelInfo
var ok bool
if strings.HasPrefix(n, "image") {
m, ok = supportedImagenModels[n]
} else {
m, ok = supportedGeminiModels[n]
}
if !ok {
return nil, fmt.Errorf("model %s not found for provider %s", n, provider)
}
Expand All @@ -227,7 +264,7 @@ func listModels(provider string) (map[string]ai.ModelInfo, error) {
// listEmbedders returns a list of supported embedders based on the
// detected backend
func listEmbedders(backend genai.Backend) ([]string, error) {
embedders := []string{}
var embedders []string

switch backend {
case genai.BackendGeminiAPI:
Expand All @@ -242,9 +279,10 @@ func listEmbedders(backend genai.Backend) ([]string, error) {
}

// genaiModels collects all the available models in go-genai SDK
// TODO: add imagen and veo models
// TODO: add veo models
type genaiModels struct {
gemini []string
imagen []string
embedders []string
}

Expand All @@ -253,6 +291,7 @@ type genaiModels struct {
func listGenaiModels(ctx context.Context, client *genai.Client) (genaiModels, error) {
models := genaiModels{}
allowedModels := []string{"gemini", "gemma"}
allowedImagenModels := []string{"imagen"}

for item, err := range client.Models.All(ctx) {
var name string
Expand Down Expand Up @@ -283,7 +322,15 @@ func listGenaiModels(ctx context.Context, client *genai.Client) (genaiModels, er
continue
}

// TODO: add imagen and veo models
found = slices.ContainsFunc(allowedImagenModels, func(s string) bool {
return strings.Contains(name, s)
})
// filter out: Aqa, Text-bison, Chat, learnlm
if found {
models.imagen = append(models.imagen, name)
continue
}
}

return models, nil
}
61 changes: 61 additions & 0 deletions go/samples/imagen/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package main

import (
"context"
"log"

"github.com/firebase/genkit/go/ai"
"github.com/firebase/genkit/go/genkit"
"github.com/firebase/genkit/go/plugins/googlegenai"
"google.golang.org/genai"
)

func main() {
ctx := context.Background()
g, err := genkit.Init(ctx, genkit.WithPlugins(&googlegenai.VertexAI{}))
if err != nil {
log.Fatal(err)
}

genkit.DefineFlow(g, "image-generation", func(ctx context.Context, input string) ([]string, error) {
r, err := genkit.Generate(ctx, g,
ai.WithModelName("vertexai/imagen-3.0-generate-001"),
ai.WithPrompt("Generate an image of %s", input),
ai.WithConfig(&genai.GenerateImagesConfig{
NumberOfImages: 2,
NegativePrompt: "night",
AspectRatio: "9:16",
SafetyFilterLevel: genai.SafetyFilterLevelBlockLowAndAbove,
PersonGeneration: genai.PersonGenerationAllowAll,
Language: genai.ImagePromptLanguageEn,
AddWatermark: true,
OutputMIMEType: "image/jpeg",
}),
)
if err != nil {
log.Fatal(err)
}

var images []string
for _, m := range r.Message.Content {
images = append(images, m.Text)
}
return images, nil
})

<-ctx.Done()
}
Loading