From 03ca97f882b9b09f711141c51815d205263d1979 Mon Sep 17 00:00:00 2001 From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com> Date: Fri, 6 Jun 2025 02:04:42 +0000 Subject: [PATCH 1/2] feat(vertex_ai_example): Add image editing and upscaling examples to Imagen page I've implemented example usage for the new ImagenModel capabilities (upscaling, inpainting/outpainting, and mask-free editing) in the `imagen_page.dart` of the example app. Key changes: - I added UI elements (buttons, dropdowns, text fields) for selecting images, specifying prompts, and configuring parameters for upscaling and editing. - I implemented image picking functionality using the `image_picker` package to allow you to select source and mask images. - I added functions to call `ImagenModel.upscaleImage()` and `ImagenModel.editImage()` with the user-provided inputs. - The app now displays the resulting images or any errors encountered during the process in the existing UI structure. - I reused existing mechanisms for loading indicators and error dialogs. 
--- .../firebase_ai/lib/firebase_ai.dart | 7 +- .../firebase_ai/lib/src/imagen_api.dart | 85 ++++ .../firebase_ai/lib/src/imagen_model.dart | 86 +++- .../firebase_ai/test/imagen_test.dart | 371 ++++++++++++++++ .../example/lib/pages/imagen_page.dart | 398 ++++++++++++++++-- 5 files changed, 908 insertions(+), 39 deletions(-) diff --git a/packages/firebase_ai/firebase_ai/lib/firebase_ai.dart b/packages/firebase_ai/firebase_ai/lib/firebase_ai.dart index dbc95e1bca24..e7af8ecc96a0 100644 --- a/packages/firebase_ai/firebase_ai/lib/firebase_ai.dart +++ b/packages/firebase_ai/firebase_ai/lib/firebase_ai.dart @@ -65,8 +65,11 @@ export 'src/imagen_api.dart' ImagenSafetyFilterLevel, ImagenPersonFilterLevel, ImagenGenerationConfig, - ImagenAspectRatio; -export 'src/imagen_content.dart' show ImagenInlineImage; + ImagenAspectRatio, + ImagenEditingConfig, + ImagenEditMode, + ImagenUpscaleFactor; +export 'src/imagen_content.dart' show ImagenInlineImage, ImagenGenerationResponse; export 'src/live_api.dart' show LiveGenerationConfig, diff --git a/packages/firebase_ai/firebase_ai/lib/src/imagen_api.dart b/packages/firebase_ai/firebase_ai/lib/src/imagen_api.dart index 1579b6740a92..1a5b02d17841 100644 --- a/packages/firebase_ai/firebase_ai/lib/src/imagen_api.dart +++ b/packages/firebase_ai/firebase_ai/lib/src/imagen_api.dart @@ -15,6 +15,8 @@ import 'dart:developer'; import 'package:meta/meta.dart'; +import 'imagen_content.dart'; + /// Specifies the level of safety filtering for image generation. /// /// If not specified, default will be "block_medium_and_above". @@ -232,3 +234,86 @@ final class ImagenFormat { 'compressionQuality': compressionQuality, }; } + +/// Enum representing the mode for image editing. +@experimental +enum ImagenEditMode { + /// Inpaint mode for image editing. + inpaint, + + /// Outpaint mode for image editing. + outpaint, +} + +/// Enum representing the upscale factor for image upscaling. +@experimental +enum ImagenUpscaleFactor { + /// Upscale factor of 2x. 
+ x2('x2'), + + /// Upscale factor of 4x. + x4('x4'); + + const ImagenUpscaleFactor(this._jsonString); + + final String _jsonString; + + // ignore: public_member_api_docs + String toJson() => _jsonString; +} + +/// Configuration for Imagen image editing. +@experimental +final class ImagenEditingConfig { + /// Source image for editing. + final ImagenInlineImage image; + + /// Mask image for editing, optional for mask-free editing. + final ImagenInlineImage? mask; + + /// Mask dilation factor. + final double? maskDilation; + + /// Number of editing steps. + final int? editSteps; + + /// Number of images to generate. + final int? numberOfImages; + + /// Editing mode. + final ImagenEditMode? editMode; + + // ignore: public_member_api_docs + ImagenEditingConfig({ + required this.image, + this.mask, + this.maskDilation, + this.editSteps, + this.numberOfImages, + this.editMode, + }); + + /// Factory constructor for mask-free image editing. + /// Takes numberOfImages as an optional parameter. + factory ImagenEditingConfig.maskFree({ + required ImagenInlineImage image, + int? numberOfImages, + }) { + return ImagenEditingConfig( + image: image, + numberOfImages: numberOfImages, + // Mask and editMode related to masking are left null for mask-free. + // Other fields like maskDilation, editSteps are also left null. 
+ ); + } + + // ignore: public_member_api_docs + Map toJson() => { + 'image': image.toJson(), + if (mask != null) 'mask': mask!.toJson(), + if (maskDilation != null) 'maskDilation': maskDilation, + if (editSteps != null) 'editSteps': editSteps, + if (numberOfImages != null) 'numberOfImages': numberOfImages, + if (editMode != null) 'editMode': editMode!.name, + }; +} diff --git a/packages/firebase_ai/firebase_ai/lib/src/imagen_model.dart b/packages/firebase_ai/firebase_ai/lib/src/imagen_model.dart index bf4731a3b264..4133d6ae3db8 100644 --- a/packages/firebase_ai/firebase_ai/lib/src/imagen_model.dart +++ b/packages/firebase_ai/firebase_ai/lib/src/imagen_model.dart @@ -14,6 +14,9 @@ part of 'base_model.dart'; +import 'imagen_api.dart'; +import 'imagen_content.dart'; + /// Represents a remote Imagen model with the ability to generate images using /// text prompts. /// @@ -57,7 +60,7 @@ final class ImagenModel extends BaseApiClientModel { if (gcsUri != null) 'storageUri': gcsUri, 'sampleCount': _generationConfig?.numberOfImages ?? 1, if (_generationConfig?.aspectRatio case final aspectRatio?) - 'aspectRatio': aspectRatio, + 'aspectRatio': aspectRatio.toJson(), if (_generationConfig?.negativePrompt case final negativePrompt?) 'negativePrompt': negativePrompt, if (_generationConfig?.addWatermark case final addWatermark?) @@ -110,6 +113,87 @@ final class ImagenModel extends BaseApiClientModel { (jsonObject) => parseImagenGenerationResponse(jsonObject), ); + + /// Edits an image based on a prompt and configuration. + @experimental + Future> editImage( + String prompt, { + required ImagenEditingConfig config, + }) async { + // Construct the request payload. 
+ final payload = { + 'instances': [ + { + 'prompt': prompt, + 'image': config.image.toJson(), + if (config.mask != null) 'mask': config.mask!.toJson(), + } + ], + 'parameters': { + if (config.editMode != null) 'editMode': config.editMode!.name, + if (config.maskDilation != null) 'maskDilation': config.maskDilation, + if (config.editSteps != null) 'editSteps': config.editSteps, + 'sampleCount': config.numberOfImages ?? _generationConfig?.numberOfImages ?? 1, + + // Parameters from model-level _generationConfig and _safetySettings + if (_generationConfig?.aspectRatio case final aspectRatio?) + 'aspectRatio': aspectRatio.toJson(), + if (_generationConfig?.negativePrompt case final negativePrompt?) + 'negativePrompt': negativePrompt, + if (_generationConfig?.addWatermark case final addWatermark?) + 'addWatermark': addWatermark, + if (_generationConfig?.imageFormat case final imageFormat?) + 'outputOption': imageFormat.toJson(), + if (_safetySettings?.personFilterLevel case final personFilterLevel?) + 'personGeneration': personFilterLevel.toJson(), + if (_safetySettings?.safetyFilterLevel case final safetyFilterLevel?) + 'safetySetting': safetyFilterLevel.toJson(), + }, + }; + + return makeRequest( + Task.predict, + payload, + (jsonObject) => parseImagenGenerationResponse(jsonObject), + ); + } + + /// Upscales an image. + @experimental + Future> upscaleImage({ + required ImagenInlineImage image, + required ImagenUpscaleFactor upscaleFactor, + ImagenSafetySettings? safetySettings, + ImagenGenerationConfig? generationConfig, + }) async { + // Construct the request payload for upscaling. + final payload = { + 'instances': [ + { + 'image': image.toJson(), + } + ], + 'parameters': { + 'upscaleFactor': upscaleFactor.toJson(), + if (generationConfig?.aspectRatio ?? _generationConfig?.aspectRatio case final aspectRatio?) + 'aspectRatio': aspectRatio.toJson(), + if (generationConfig?.addWatermark ?? _generationConfig?.addWatermark case final addWatermark?) 
+ 'addWatermark': addWatermark, + if (generationConfig?.imageFormat ?? _generationConfig?.imageFormat case final imageFormat?) + 'outputOption': imageFormat.toJson(), + if (safetySettings?.personFilterLevel ?? _safetySettings?.personFilterLevel case final personFilterLevel?) + 'personGeneration': personFilterLevel.toJson(), + if (safetySettings?.safetyFilterLevel ?? _safetySettings?.safetyFilterLevel case final safetyFilterLevel?) + 'safetySetting': safetyFilterLevel.toJson(), + }, + }; + + return makeRequest( + Task.predict, + payload, + (jsonObject) => parseImagenGenerationResponse(jsonObject), + ); + } } /// Returns a [ImagenModel] using it's private constructor. diff --git a/packages/firebase_ai/firebase_ai/test/imagen_test.dart b/packages/firebase_ai/firebase_ai/test/imagen_test.dart index 4bd7ae5b763a..80891344dcf1 100644 --- a/packages/firebase_ai/firebase_ai/test/imagen_test.dart +++ b/packages/firebase_ai/firebase_ai/test/imagen_test.dart @@ -17,9 +17,32 @@ import 'dart:typed_data'; import 'package:firebase_ai/src/error.dart'; import 'package:firebase_ai/src/imagen_content.dart'; +import 'package:firebase_ai/firebase_ai.dart'; +import 'package:firebase_ai/src/api_client.dart'; +import 'package:firebase_core/firebase_core.dart'; import 'package:flutter_test/flutter_test.dart'; +import 'package:mockito/annotations.dart'; +import 'package:mockito/mockito.dart'; +import 'mock.dart'; +import 'imagen_test.mocks.dart'; // Generated by Mockito +// Mock HttpApiClient +@GenerateMocks([HttpApiClient]) void main() { + setupFirebaseCoreMocks(); + + setUpAll(() async { + await Firebase.initializeApp( + name: 'testApp', + options: const FirebaseOptions( + apiKey: 'test-api-key', + appId: 'test-app-id', + messagingSenderId: 'test-sender-id', + projectId: 'test-project-id', + ), + ); + }); + group('ImagenInlineImage', () { test('fromJson with valid base64', () { final json = { @@ -238,4 +261,352 @@ void main() { throwsA(isA())); }); }); + + group('ImagenModel Tests', () { 
+ late ImagenModel imagenModel; + late MockHttpApiClient mockClient; + final app = Firebase.app('testApp'); + + setUp(() { + mockClient = MockHttpApiClient(); + imagenModel = ImagenModel._( + app: app, + model: 'gemini-1.5-flash', // Example model + location: 'us-central1', + useVertexBackend: true, // Assuming Vertex backend for these tests + client: mockClient, + // No default generationConfig or safetySettings for cleaner test isolation + ); + }); + + group('editImage', () { + final sourceImageBytes = Uint8List.fromList(utf8.encode('source_image_bytes')); + final sourceImage = ImagenInlineImage(bytesBase64Encoded: sourceImageBytes, mimeType: 'image/png'); + final maskImageBytes = Uint8List.fromList(utf8.encode('mask_image_bytes')); + final maskImage = ImagenInlineImage(bytesBase64Encoded: maskImageBytes, mimeType: 'image/png'); + const prompt = 'a test prompt'; + + test('should construct correct payload for mask-free editing', () async { + // Assuming model has no default _generationConfig or _safetySettings for this test + final config = ImagenEditingConfig.maskFree(image: sourceImage, numberOfImages: 2); + final expectedPayload = { + 'instances': [ + { + 'prompt': prompt, + 'image': sourceImage.toJson(), + } + ], + 'parameters': { + 'sampleCount': 2, // From config + // No other parameters expected if model defaults are null + }, + }; + + when(mockClient.makeRequest>(any, any, any)) + .thenAnswer((_) async => ImagenGenerationResponse(images: [sourceImage])); // Dummy success + + await imagenModel.editImage(prompt, config: config); + final captured = verify(mockClient.makeRequest>( + captureAny, captureAny, captureAny)) + .captured; + expect(captured[1], equals(expectedPayload)); + }); + + test('should construct correct payload for full editing config', () async { + final config = ImagenEditingConfig( + image: sourceImage, + mask: maskImage, + maskDilation: 0.05, + editSteps: 60, + numberOfImages: 3, + editMode: ImagenEditMode.inpaint, + // negativePrompt, 
safetySettings, and generationConfig are no longer part of ImagenEditingConfig + ); + + // For this test, let's assume the model has some default _generationConfig and _safetySettings + // to see them appear in the payload. + final modelForFullTest = ImagenModel._( + app: app, + model: 'gemini-1.5-flash', + location: 'us-central1', + useVertexBackend: true, + client: mockClient, + generationConfig: ImagenGenerationConfig( + aspectRatio: ImagenAspectRatio.landscape16x9, + imageFormat: ImagenFormat.jpeg(compressionQuality: 80), + addWatermark: false, + negativePrompt: 'model-level blurry', + ), + safetySettings: ImagenSafetySettings(ImagenSafetyFilterLevel.blockLowAndAbove, ImagenPersonFilterLevel.blockAll), + ); + + final expectedPayload = { + 'instances': [ + { + 'prompt': prompt, + 'image': sourceImage.toJson(), + 'mask': maskImage.toJson(), + } + ], + 'parameters': { + // From config + 'editMode': 'inpaint', + 'maskDilation': 0.05, + 'editSteps': 60, + 'sampleCount': 3, + // From model-level settings + 'aspectRatio': '16:9', + 'negativePrompt': 'model-level blurry', + 'addWatermark': false, + 'outputOption': {'mimeType': 'image/jpeg', 'compressionQuality': 80}, + 'personGeneration': 'dont_allow', + 'safetySetting': 'block_low_and_above', + }, + }; + + when(mockClient.makeRequest>(any, any, any)) + .thenAnswer((_) async => ImagenGenerationResponse(images: [sourceImage])); + + await modelForFullTest.editImage(prompt, config: config); // Use modelForFullTest here + final captured = verify(mockClient.makeRequest>( + captureAny, captureAny, captureAny)) + .captured; + expect(captured[1], equals(expectedPayload)); + }); + + test('should parse successful response', () async { + final config = ImagenEditingConfig.maskFree(image: sourceImage); + final apiResponse = { + 'predictions': [ + {'bytesBase64Encoded': base64Encode(sourceImageBytes), 'mimeType': 'image/png'} + ] + }; + when(mockClient.makeRequest>(any, any, any)) + .thenAnswer((_) async => 
parseImagenGenerationResponse(apiResponse)); + + final response = await imagenModel.editImage(prompt, config: config); // Use default imagenModel + expect(response.images.length, 1); + expect(response.images.first.bytesBase64Encoded, equals(sourceImageBytes)); + expect(response.images.first.mimeType, 'image/png'); + }); + + test('should throw ImagenImagesBlockedException for filtered response', () async { + final config = ImagenEditingConfig.maskFree(image: sourceImage); + final apiResponse = { + 'predictions': [ + {'raiFilteredReason': 'Blocked due to safety reasons'} + ] + }; + when(mockClient.makeRequest>(any, any, any)) + .thenAnswer((_) async => parseImagenGenerationResponse(apiResponse)); + + expect( + () => imagenModel.editImage(prompt, config: config), // Use default imagenModel + throwsA(isA())); + }); + + test('should throw ServerException for error response', () async { + final config = ImagenEditingConfig.maskFree(image: sourceImage); + final apiResponse = { + 'error': {'code': 400, 'message': 'Bad request', 'status': 'INVALID_ARGUMENT'} + }; + when(mockClient.makeRequest>(any, any, any)) + .thenAnswer((_) async => parseImagenGenerationResponse(apiResponse)); + + expect( + () => imagenModel.editImage(prompt, config: config), // Use default imagenModel + throwsA(isA())); + }); + + test('model-level defaults should be used when not overridden by ImagenEditingConfig', () async { + // Initialize model with default settings + final defaultSafety = ImagenSafetySettings(ImagenSafetyFilterLevel.blockMediumAndAbove, ImagenPersonFilterLevel.allowAdult); + final defaultGeneration = ImagenGenerationConfig( + negativePrompt: "default negative", + numberOfImages: 1, // This will be overridden by config.numberOfImages if set, or config's default + aspectRatio: ImagenAspectRatio.square1x1, + addWatermark: true, + imageFormat: ImagenFormat.png(), + ); + + final modelWithDefaults = ImagenModel._( + app: app, + model: 'gemini-1.5-flash', + location: 'us-central1', + 
useVertexBackend: true, + client: mockClient, + safetySettings: defaultSafety, + generationConfig: defaultGeneration, + ); + + // ImagenEditingConfig now only has image and numberOfImages that could affect these global params + final config = ImagenEditingConfig( + image: sourceImage, + numberOfImages: 2, // This will override defaultGeneration.numberOfImages for 'sampleCount' + editMode: ImagenEditMode.outpaint, // Specific to edit + ); + + final expectedPayloadParameters = { + // From config + 'sampleCount': 2, + 'editMode': 'outpaint', + // From model-level defaults (defaultGeneration and defaultSafety) + 'aspectRatio': '1:1', + 'negativePrompt': 'default negative', + 'addWatermark': true, + 'outputOption': {'mimeType': 'image/png'}, + 'personGeneration': 'allow_adult', + 'safetySetting': 'block_medium_and_above', + }; + + when(mockClient.makeRequest>(any, any, any)) + .thenAnswer((_) async => ImagenGenerationResponse(images: [sourceImage])); + + await modelWithDefaults.editImage(prompt, config: config); + final captured = verify(mockClient.makeRequest>( + captureAny, captureAny, captureAny)) + .captured; + // We are interested in the 'parameters' part of the payload (captured[1]) + final actualParameters = captured[1]['parameters'] as Map; + expect(actualParameters, equals(expectedPayloadParameters)); + }); + }); + + group('upscaleImage', () { + final sourceImageBytes = Uint8List.fromList(utf8.encode('source_image_bytes_for_upscale')); + final sourceImage = ImagenInlineImage(bytesBase64Encoded: sourceImageBytes, mimeType: 'image/jpeg'); + + test('should construct correct payload for basic upscaling', () async { + final expectedPayload = { + 'instances': [ + {'image': sourceImage.toJson()} + ], + 'parameters': { + 'upscaleFactor': 'x2', + }, + }; + when(mockClient.makeRequest>(any, any, any)) + .thenAnswer((_) async => ImagenGenerationResponse(images: [sourceImage])); + + await imagenModel.upscaleImage(image: sourceImage, upscaleFactor: ImagenUpscaleFactor.x2); 
+ final captured = verify(mockClient.makeRequest>( + captureAny, captureAny, captureAny)) + .captured; + expect(captured[1], equals(expectedPayload)); + }); + + test('should construct correct payload with optional safety and generation config', () async { + final safety = ImagenSafetySettings(ImagenSafetyFilterLevel.blockLowAndAbove, ImagenPersonFilterLevel.blockAll); + final generation = ImagenGenerationConfig( + imageFormat: ImagenFormat.png(), + addWatermark: true, + ); + final expectedPayload = { + 'instances': [ + {'image': sourceImage.toJson()} + ], + 'parameters': { + 'upscaleFactor': 'x4', + 'outputOption': {'mimeType': 'image/png'}, + 'addWatermark': true, + 'personGeneration': 'dont_allow', + 'safetySetting': 'block_low_and_above', + }, + }; + when(mockClient.makeRequest>(any, any, any)) + .thenAnswer((_) async => ImagenGenerationResponse(images: [sourceImage])); + + await imagenModel.upscaleImage( + image: sourceImage, + upscaleFactor: ImagenUpscaleFactor.x4, + safetySettings: safety, + generationConfig: generation, + ); + final captured = verify(mockClient.makeRequest>( + captureAny, captureAny, captureAny)) + .captured; + expect(captured[1], equals(expectedPayload)); + }); + + test('should parse successful upscale response', () async { + final apiResponse = { + 'predictions': [ + {'bytesBase64Encoded': base64Encode(sourceImageBytes), 'mimeType': 'image/jpeg'} + ] + }; + when(mockClient.makeRequest>(any, any, any)) + .thenAnswer((_) async => parseImagenGenerationResponse(apiResponse)); + + final response = await imagenModel.upscaleImage(image: sourceImage, upscaleFactor: ImagenUpscaleFactor.x2); + expect(response.images.length, 1); + expect(response.images.first.bytesBase64Encoded, equals(sourceImageBytes)); + expect(response.images.first.mimeType, 'image/jpeg'); + }); + + test('should throw ImagenImagesBlockedException for filtered upscale response', () async { + final apiResponse = { + 'predictions': [ + {'raiFilteredReason': 'Blocked due to safety 
reasons'} + ] + }; + when(mockClient.makeRequest>(any, any, any)) + .thenAnswer((_) async => parseImagenGenerationResponse(apiResponse)); + + expect( + () => imagenModel.upscaleImage(image: sourceImage, upscaleFactor: ImagenUpscaleFactor.x2), + throwsA(isA())); + }); + + test('should throw ServerException for error in upscale response', () async { + final apiResponse = { + 'error': {'code': 500, 'message': 'Internal server error', 'status': 'UNAVAILABLE'} + }; + when(mockClient.makeRequest>(any, any, any)) + .thenAnswer((_) async => parseImagenGenerationResponse(apiResponse)); + expect( + () => imagenModel.upscaleImage(image: sourceImage, upscaleFactor: ImagenUpscaleFactor.x2), + throwsA(isA())); + }); + + test('upscale method parameters should override model-level defaults', () async { + final defaultSafety = ImagenSafetySettings(ImagenSafetyFilterLevel.blockNone, ImagenPersonFilterLevel.allowAll); + final defaultGeneration = ImagenGenerationConfig(imageFormat: ImagenFormat.jpeg(), addWatermark: false); + + final modelWithDefaults = ImagenModel._( + app: app, + model: 'gemini-1.5-flash', + location: 'us-central1', + useVertexBackend: true, + client: mockClient, + safetySettings: defaultSafety, + generationConfig: defaultGeneration, + ); + + final methodSafety = ImagenSafetySettings(ImagenSafetyFilterLevel.blockMediumAndAbove, null); // Override safety, keep person from default + final methodGeneration = ImagenGenerationConfig(addWatermark: true, imageFormat: ImagenFormat.png()); // Override watermark and format + + final expectedPayloadParameters = { + 'upscaleFactor': 'x2', + 'outputOption': {'mimeType': 'image/png'}, // from methodGeneration + 'addWatermark': true, // from methodGeneration + 'personGeneration': 'allow_all', // from model default + 'safetySetting': 'block_medium_and_above', // from methodSafety + }; + + when(mockClient.makeRequest>(any, any, any)) + .thenAnswer((_) async => ImagenGenerationResponse(images: [sourceImage])); + + await 
modelWithDefaults.upscaleImage( + image: sourceImage, + upscaleFactor: ImagenUpscaleFactor.x2, + safetySettings: methodSafety, + generationConfig: methodGeneration, + ); + final captured = verify(mockClient.makeRequest>( + captureAny, captureAny, captureAny)) + .captured; + expect(captured[1]['parameters'], equals(expectedPayloadParameters)); + }); + }); + }); } diff --git a/packages/firebase_vertexai/firebase_vertexai/example/lib/pages/imagen_page.dart b/packages/firebase_vertexai/firebase_vertexai/example/lib/pages/imagen_page.dart index 0ab750b13fef..7bb7ee6138cf 100644 --- a/packages/firebase_vertexai/firebase_vertexai/example/lib/pages/imagen_page.dart +++ b/packages/firebase_vertexai/firebase_vertexai/example/lib/pages/imagen_page.dart @@ -12,6 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. +import 'dart:typed_data'; +import 'dart:io'; // Though XFile.readAsBytes is available, good to have for other File ops if needed. +import 'package:image_picker/image_picker.dart'; + import 'package:flutter/material.dart'; import 'package:firebase_vertexai/firebase_vertexai.dart'; //import 'package:firebase_storage/firebase_storage.dart'; @@ -38,6 +42,20 @@ class _ImagenPageState extends State { final List _generatedContent = []; bool _loading = false; + // For image picking + ImagenInlineImage? _sourceImageForUpscaling; + ImagenInlineImage? _sourceImageForEditing; + ImagenInlineImage? 
_maskImageForEditing; + + // For upscale factor + ImagenUpscaleFactor _selectedUpscaleFactor = ImagenUpscaleFactor.x2; // Default + + // For editing parameters + final TextEditingController _editPromptController = TextEditingController(); + final TextEditingController _maskDilationController = TextEditingController(); + final TextEditingController _editStepsController = TextEditingController(); + ImagenEditMode _selectedEditMode = ImagenEditMode.inpaint; // Default + void _scrollDown() { WidgetsBinding.instance.addPostFrameCallback( (_) => _scrollController.animateTo( @@ -80,45 +98,160 @@ class _ImagenPageState extends State { vertical: 25, horizontal: 15, ), - child: Row( + child: Column( children: [ - Expanded( - child: TextField( - autofocus: true, - focusNode: _textFieldFocus, - controller: _textController, - ), + // Generate Image Row + Row( + children: [ + Expanded( + child: TextField( + autofocus: true, + focusNode: _textFieldFocus, + decoration: const InputDecoration( + hintText: 'Enter a prompt...', + ), + controller: _textController, + ), + ), + const SizedBox.square(dimension: 15), + if (!_loading) + IconButton( + onPressed: () async { + await _generateImageFromPrompt(_textController.text); + }, + icon: Icon( + Icons.image_search, + color: Theme.of(context).colorScheme.primary, + ), + tooltip: 'Generate Image', + ) + else + const CircularProgressIndicator(), + ], ), - const SizedBox.square( - dimension: 15, + const SizedBox(height: 20), + // Upscaling UI + const Text('Image Upscaling', style: TextStyle(fontWeight: FontWeight.bold)), + Row( + mainAxisAlignment: MainAxisAlignment.spaceEvenly, + children: [ + ElevatedButton( + onPressed: _loading ? null : _pickSourceImageForUpscaling, + child: const Text('Pick Source'), + ), + if (_sourceImageForUpscaling != null) + Text(_sourceImageForUpscaling!.mimeType, style: const TextStyle(fontSize: 12)), + DropdownButton( + value: _selectedUpscaleFactor, + onChanged: _loading ? null : (ImagenUpscaleFactor? 
newValue) { + if (newValue != null) { + setState(() { + _selectedUpscaleFactor = newValue; + }); + } + }, + items: ImagenUpscaleFactor.values + .map>( + (ImagenUpscaleFactor value) { + return DropdownMenuItem( + value: value, + child: Text(value.name), + ); + }).toList(), + ), + ElevatedButton( + onPressed: _loading || _sourceImageForUpscaling == null ? null : _upscaleImage, + child: const Text('Upscale Image'), + ), + ], + ), + const SizedBox(height: 20), + // Editing UI + const Text('Image Editing', style: TextStyle(fontWeight: FontWeight.bold)), + Row( + mainAxisAlignment: MainAxisAlignment.spaceEvenly, + children: [ + ElevatedButton( + onPressed: _loading ? null : _pickSourceImageForEditing, + child: const Text('Pick Source'), + ), + if (_sourceImageForEditing != null) + Text(_sourceImageForEditing!.mimeType, style: const TextStyle(fontSize: 12)), + Expanded( + child: TextField( + controller: _editPromptController, + decoration: const InputDecoration(hintText: 'Edit prompt'), + enabled: !_loading, + ), + ), + ], + ), + const SizedBox(height: 10), + // Inpaint/Outpaint Specific UI + const Text('Inpaint/Outpaint', style: TextStyle(fontStyle: FontStyle.italic)), + Row( + mainAxisAlignment: MainAxisAlignment.spaceEvenly, + children: [ + ElevatedButton( + onPressed: _loading ? null : _pickMaskImageForEditing, + child: const Text('Pick Mask'), + ), + if (_maskImageForEditing != null) + Text(_maskImageForEditing!.mimeType, style: const TextStyle(fontSize: 12)), + DropdownButton( + value: _selectedEditMode, + onChanged: _loading ? null : (ImagenEditMode? 
newValue) { + if (newValue != null) { + setState(() { + _selectedEditMode = newValue; + }); + } + }, + items: ImagenEditMode.values + .map>( + (ImagenEditMode value) { + return DropdownMenuItem( + value: value, + child: Text(value.name), + ); + }).toList(), + ), + ], ), - if (!_loading) - IconButton( - onPressed: () async { - await _testImagen(_textController.text); - }, - icon: Icon( - Icons.image_search, - color: Theme.of(context).colorScheme.primary, + Row( + children: [ + Expanded( + child: TextField( + controller: _maskDilationController, + decoration: const InputDecoration(hintText: 'Mask Dilation (e.g., 0.01)'), + keyboardType: const TextInputType.numberWithOptions(decimal: true), + enabled: !_loading, + ), + ), + const SizedBox(width: 10), + Expanded( + child: TextField( + controller: _editStepsController, + decoration: const InputDecoration(hintText: 'Edit Steps (e.g., 50)'), + keyboardType: TextInputType.number, + enabled: !_loading, + ), ), - tooltip: 'Imagen raw data', - ) - else - const CircularProgressIndicator(), - // NOTE: Keep this API private until future release. - // if (!_loading) - // IconButton( - // onPressed: () async { - // await _testImagenGCS(_textController.text); - // }, - // icon: Icon( - // Icons.imagesearch_roller, - // color: Theme.of(context).colorScheme.primary, - // ), - // tooltip: 'Imagen GCS', - // ) - // else - // const CircularProgressIndicator(), + ], + ), + ElevatedButton( + onPressed: _loading || _sourceImageForEditing == null || _maskImageForEditing == null + ? null + : _editImageInpaintOutpaint, + child: const Text('Edit (Inpaint/Outpaint)'), + ), + const SizedBox(height: 10), + // Mask-Free Editing Button + const Text('Mask-Free Edit', style: TextStyle(fontStyle: FontStyle.italic)), + ElevatedButton( + onPressed: _loading || _sourceImageForEditing == null ? 
null : _editImageMaskFree, + child: const Text('Edit (Mask-Free)'), + ), ], ), ), @@ -128,7 +261,200 @@ class _ImagenPageState extends State { ); } - Future _testImagen(String prompt) async { + Future _pickImage() async { + final ImagePicker picker = ImagePicker(); + try { + final XFile? imageFile = await picker.pickImage(source: ImageSource.gallery); + if (imageFile != null) { + // Attempt to get mimeType, default if null. + // Note: imageFile.mimeType might be null on some platforms or for some files. + final String mimeType = imageFile.mimeType ?? 'image/jpeg'; + final Uint8List imageBytes = await imageFile.readAsBytes(); + return ImagenInlineImage(bytesBase64Encoded: imageBytes, mimeType: mimeType); + } + } catch (e) { + _showError('Error picking image: $e'); + } + return null; + } + + Future _pickSourceImageForUpscaling() async { + final pickedImage = await _pickImage(); + if (pickedImage != null) { + setState(() { + _sourceImageForUpscaling = pickedImage; + }); + } + } + + Future _pickSourceImageForEditing() async { + final pickedImage = await _pickImage(); + if (pickedImage != null) { + setState(() { + _sourceImageForEditing = pickedImage; + }); + } + } + + Future _pickMaskImageForEditing() async { + final pickedImage = await _pickImage(); + if (pickedImage != null) { + setState(() { + _maskImageForEditing = pickedImage; + }); + } + } + + Future _upscaleImage() async { + if (_sourceImageForUpscaling == null) { + _showError('Please pick a source image for upscaling.'); + return; + } + setState(() { + _loading = true; + }); + + try { + final response = await widget.model.upscaleImage( + image: _sourceImageForUpscaling!, + upscaleFactor: _selectedUpscaleFactor, + ); + if (response.images.isNotEmpty) { + final upscaledImage = response.images[0]; + setState(() { + _generatedContent.add( + MessageData( + image: Image.memory(upscaledImage.bytesBase64Encoded), + text: 'Upscaled image (Factor: ${_selectedUpscaleFactor.name})', + fromUser: false, + ), + ); + 
_scrollDown(); + }); + } else { + _showError('No image was returned from upscaling.'); + } + } catch (e) { + _showError('Error upscaling image: $e'); + } + + setState(() { + _loading = false; + }); + } + + Future _editImageInpaintOutpaint() async { + if (_sourceImageForEditing == null || _maskImageForEditing == null) { + _showError('Please pick a source image and a mask image for inpainting/outpainting.'); + return; + } + setState(() { + _loading = true; + }); + + final String prompt = _editPromptController.text; + final double? maskDilation = double.tryParse(_maskDilationController.text); + final int? editSteps = int.tryParse(_editStepsController.text); + + final editConfig = ImagenEditingConfig( + image: _sourceImageForEditing!, + mask: _maskImageForEditing!, + editMode: _selectedEditMode, + maskDilation: maskDilation, + editSteps: editSteps, + // numberOfImages: 1, // Default in model or could be added to UI + ); + + try { + final response = await widget.model.editImage( + prompt, + config: editConfig, + ); + if (response.images.isNotEmpty) { + final editedImage = response.images[0]; + setState(() { + _generatedContent.add( + MessageData( + image: Image.memory(editedImage.bytesBase64Encoded), + text: 'Edited image (Inpaint/Outpaint): $prompt', + fromUser: false, + ), + ); + _scrollDown(); + }); + } else { + _showError('No image was returned from editing.'); + } + } catch (e) { + _showError('Error editing image: $e'); + } + setState(() { + _loading = false; + }); + } + + Future _editImageMaskFree() async { + if (_sourceImageForEditing == null) { + _showError('Please pick a source image for mask-free editing.'); + return; + } + setState(() { + _loading = true; + }); + + final String prompt = _editPromptController.text; + final editConfig = ImagenEditingConfig.maskFree( + image: _sourceImageForEditing!, + // numberOfImages: 1, // Default in model or could be added to UI + ); + + try { + final response = await widget.model.editImage( + prompt, + config: editConfig, 
+ ); + if (response.images.isNotEmpty) { + final editedImage = response.images[0]; + setState(() { + _generatedContent.add( + MessageData( + image: Image.memory(editedImage.bytesBase64Encoded), + text: 'Edited image (Mask-Free): $prompt', + fromUser: false, + ), + ); + _scrollDown(); + }); + } else { + _showError('No image was returned from mask-free editing.'); + } + } catch (e) { + _showError('Error performing mask-free edit: $e'); + } + setState(() { + _loading = false; + }); + } + + void _showNotImplementedDialog() { + showDialog( + context: context, + builder: (context) { + return AlertDialog( + title: const Text('Not Implemented'), + content: const Text('This feature will be implemented in a later step.'), + actions: [ + TextButton( + onPressed: () => Navigator.of(context).pop(), + child: const Text('OK'), + ), + ], + ); + }, + ); + } + + Future _generateImageFromPrompt(String prompt) async { setState(() { _loading = true; }); From 503c0afe771b688e7c081af71cb7834430b01354 Mon Sep 17 00:00:00 2001 From: Cynthia J Date: Fri, 6 Jun 2025 10:25:48 -0700 Subject: [PATCH 2/2] make everything buildable after jules --- .../example/lib/pages/imagen_page.dart | 329 ++++++++++++++-- .../macos/Runner/DebugProfile.entitlements | 2 + .../example/macos/Runner/Info.plist | 2 + .../firebase_ai/example/pubspec.yaml | 1 + .../firebase_ai/lib/src/imagen_model.dart | 29 +- .../firebase_ai/test/chat_test.dart | 4 +- .../test/firebase_vertexai_test.dart | 4 +- .../test/google_ai_generative_model_test.dart | 4 +- .../firebase_ai/test/imagen_test.dart | 360 +----------------- .../test/{mock.dart => mocks/mock_core.dart} | 10 +- .../firebase_ai/test/model_test.dart | 4 +- .../example/lib/pages/imagen_page.dart | 339 +---------------- 12 files changed, 334 insertions(+), 754 deletions(-) rename packages/firebase_ai/firebase_ai/test/{mock.dart => mocks/mock_core.dart} (89%) diff --git a/packages/firebase_ai/firebase_ai/example/lib/pages/imagen_page.dart 
b/packages/firebase_ai/firebase_ai/example/lib/pages/imagen_page.dart index c957f207278e..b5551d402ca4 100644 --- a/packages/firebase_ai/firebase_ai/example/lib/pages/imagen_page.dart +++ b/packages/firebase_ai/firebase_ai/example/lib/pages/imagen_page.dart @@ -12,8 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -import 'package:flutter/material.dart'; +import 'dart:typed_data'; + +import 'package:image_picker/image_picker.dart'; import 'package:firebase_ai/firebase_ai.dart'; + +import 'package:flutter/material.dart'; //import 'package:firebase_storage/firebase_storage.dart'; import '../widgets/message_widget.dart'; @@ -38,6 +42,10 @@ class _ImagenPageState extends State { final List _generatedContent = []; bool _loading = false; + // For image picking + ImagenInlineImage? _sourceImage; + ImagenInlineImage? _maskImageForEditing; + void _scrollDown() { WidgetsBinding.instance.addPostFrameCallback( (_) => _scrollController.animateTo( @@ -80,45 +88,89 @@ class _ImagenPageState extends State { vertical: 25, horizontal: 15, ), - child: Row( + child: Column( children: [ - Expanded( - child: TextField( - autofocus: true, - focusNode: _textFieldFocus, - controller: _textController, - ), - ), - const SizedBox.square( - dimension: 15, - ), - if (!_loading) - IconButton( - onPressed: () async { - await _testImagen(_textController.text); - }, - icon: Icon( - Icons.image_search, - color: Theme.of(context).colorScheme.primary, + // Generate Image Row + Row( + children: [ + Expanded( + child: TextField( + autofocus: true, + focusNode: _textFieldFocus, + decoration: const InputDecoration( + hintText: 'Enter a prompt...', + ), + controller: _textController, + ), ), - tooltip: 'Imagen raw data', - ) - else - const CircularProgressIndicator(), - // NOTE: Keep this API private until future release. 
- // if (!_loading) - // IconButton( - // onPressed: () async { - // await _testImagenGCS(_textController.text); - // }, - // icon: Icon( - // Icons.imagesearch_roller, - // color: Theme.of(context).colorScheme.primary, - // ), - // tooltip: 'Imagen GCS', - // ) - // else - // const CircularProgressIndicator(), + const SizedBox.square(dimension: 15), + IconButton( + onPressed: () async { + await _pickSourceImage(); + }, + icon: Icon( + Icons.add_a_photo, + color: Theme.of(context).colorScheme.primary, + ), + tooltip: 'Pick Source Image', + ), + IconButton( + onPressed: () async { + await _pickMaskImage(); + }, + icon: Icon( + Icons.add_to_photos, + color: Theme.of(context).colorScheme.primary, + ), + tooltip: 'Pick mask', + ), + IconButton( + onPressed: () async { + await _editImageMaskFree(); + }, + icon: Icon( + Icons.edit, + color: Theme.of(context).colorScheme.primary, + ), + tooltip: 'Edit Image Mask Free', + ), + IconButton( + onPressed: () async { + await _editImageInpaintOutpaint(); + }, + icon: Icon( + Icons.masks, + color: Theme.of(context).colorScheme.primary, + ), + tooltip: 'Mask Inpaint Outpaint', + ), + IconButton( + onPressed: () async { + await _upscaleImage(); + }, + icon: Icon( + Icons.plus_one, + color: Theme.of(context).colorScheme.primary, + ), + tooltip: 'Upscale', + ), + if (!_loading) + IconButton( + onPressed: () async { + await _generateImageFromPrompt( + _textController.text, + ); + }, + icon: Icon( + Icons.image_search, + color: Theme.of(context).colorScheme.primary, + ), + tooltip: 'Generate Image', + ) + else + const CircularProgressIndicator(), + ], + ), ], ), ), @@ -128,7 +180,206 @@ class _ImagenPageState extends State { ); } - Future _testImagen(String prompt) async { + Future _pickImage() async { + final ImagePicker picker = ImagePicker(); + try { + final XFile? imageFile = + await picker.pickImage(source: ImageSource.gallery); + if (imageFile != null) { + // Attempt to get mimeType, default if null. 
+ // Note: imageFile.mimeType might be null on some platforms or for some files. + final String mimeType = imageFile.mimeType ?? 'image/jpeg'; + final Uint8List imageBytes = await imageFile.readAsBytes(); + return ImagenInlineImage( + bytesBase64Encoded: imageBytes, mimeType: mimeType); + } + } catch (e) { + _showError('Error picking image: $e'); + } + return null; + } + + Future _pickSourceImage() async { + final pickedImage = await _pickImage(); + if (pickedImage != null) { + setState(() { + _sourceImage = pickedImage; + }); + } + } + + Future _pickMaskImage() async { + final pickedImage = await _pickImage(); + if (pickedImage != null) { + setState(() { + _maskImageForEditing = pickedImage; + }); + } + } + + Future _upscaleImage() async { + if (_sourceImage == null) { + _showError('Please pick a source image for upscaling.'); + return; + } + setState(() { + _loading = true; + }); + + setState(() { + _generatedContent.add( + MessageData( + image: Image.memory(_sourceImage!.bytesBase64Encoded), + text: + 'Try to Upscaled image (Factor: ${ImagenUpscaleFactor.x2.name})', + fromUser: true, + ), + ); + _scrollDown(); + }); + + try { + final response = await widget.model.upscaleImage( + image: _sourceImage!, + upscaleFactor: ImagenUpscaleFactor.x2, + ); + if (response.images.isNotEmpty) { + final upscaledImage = response.images[0]; + setState(() { + _generatedContent.add( + MessageData( + image: Image.memory(upscaledImage.bytesBase64Encoded), + text: 'Upscaled image (Factor: ${ImagenUpscaleFactor.x2.name})', + fromUser: false, + ), + ); + _scrollDown(); + }); + } else { + _showError('No image was returned from upscaling.'); + } + } catch (e) { + _showError('Error upscaling image: $e'); + } + + setState(() { + _loading = false; + }); + } + + Future _editImageInpaintOutpaint() async { + if (_sourceImage == null || _maskImageForEditing == null) { + _showError( + 'Please pick a source image and a mask image for inpainting/outpainting.'); + return; + } + setState(() { + 
_loading = true; + }); + + final String prompt = _textController.text; + + setState(() { + _generatedContent.add( + MessageData( + image: Image.memory(_sourceImage!.bytesBase64Encoded), + text: prompt, + fromUser: true, + ), + ); + _scrollDown(); + }); + + final editConfig = ImagenEditingConfig( + image: _sourceImage!, + mask: _maskImageForEditing, + maskDilation: 0.01, + editSteps: 50, + ); + + try { + final response = await widget.model.editImage( + prompt, + config: editConfig, + ); + if (response.images.isNotEmpty) { + final editedImage = response.images[0]; + setState(() { + _generatedContent.add( + MessageData( + image: Image.memory(editedImage.bytesBase64Encoded), + text: 'Edited image (Inpaint/Outpaint): $prompt', + fromUser: false, + ), + ); + _scrollDown(); + }); + } else { + _showError('No image was returned from editing.'); + } + } catch (e) { + _showError('Error editing image: $e'); + } + setState(() { + _loading = false; + }); + } + + Future _editImageMaskFree() async { + if (_sourceImage == null) { + _showError('Please pick a source image for mask-free editing.'); + return; + } + setState(() { + _loading = true; + }); + + final String prompt = _textController.text; + + setState(() { + _generatedContent.add( + MessageData( + image: Image.memory(_sourceImage!.bytesBase64Encoded), + text: prompt, + fromUser: true, + ), + ); + _scrollDown(); + }); + final editConfig = ImagenEditingConfig.maskFree( + image: _sourceImage!, + // numberOfImages: 1, // Default in model or could be added to UI + ); + + try { + final response = await widget.model.editImage( + prompt, + config: editConfig, + ); + if (response.images.isNotEmpty) { + final editedImage = response.images[0]; + setState(() { + _generatedContent.add( + MessageData( + image: Image.memory(editedImage.bytesBase64Encoded), + text: 'Edited image (Mask-Free): $prompt', + fromUser: false, + ), + ); + _scrollDown(); + }); + } else { + _showError('No image was returned from mask-free editing.'); + } + } catch 
 (e) { + _showError('Error performing mask-free edit: $e'); + } + setState(() { + _loading = false; + }); + } + + Future _generateImageFromPrompt(String prompt) async { setState(() { _loading = true; }); diff --git a/packages/firebase_ai/firebase_ai/example/macos/Runner/DebugProfile.entitlements b/packages/firebase_ai/firebase_ai/example/macos/Runner/DebugProfile.entitlements index b4bd9ee174a1..8560da29b687 100644 --- a/packages/firebase_ai/firebase_ai/example/macos/Runner/DebugProfile.entitlements +++ b/packages/firebase_ai/firebase_ai/example/macos/Runner/DebugProfile.entitlements @@ -14,5 +14,7 @@ com.apple.security.device.audio-input + com.apple.security.files.user-selected.read-only + diff --git a/packages/firebase_ai/firebase_ai/example/macos/Runner/Info.plist b/packages/firebase_ai/firebase_ai/example/macos/Runner/Info.plist index a81b3fd0d617..d4369e6253fa 100644 --- a/packages/firebase_ai/firebase_ai/example/macos/Runner/Info.plist +++ b/packages/firebase_ai/firebase_ai/example/macos/Runner/Info.plist @@ -30,5 +30,7 @@ NSApplication NSMicrophoneUsageDescription Permission to Record audio + NSPhotoLibraryUsageDescription + This app needs access to your photo library so you can select source and mask images for the Imagen editing and upscaling examples. 
diff --git a/packages/firebase_ai/firebase_ai/example/pubspec.yaml b/packages/firebase_ai/firebase_ai/example/pubspec.yaml index 4868f106d648..475907298c36 100644 --- a/packages/firebase_ai/firebase_ai/example/pubspec.yaml +++ b/packages/firebase_ai/firebase_ai/example/pubspec.yaml @@ -27,6 +27,7 @@ dependencies: sdk: flutter flutter_markdown: ^0.6.20 flutter_soloud: ^3.1.6 + image_picker: ^1.1.2 path_provider: ^2.1.5 record: ^5.2.1 diff --git a/packages/firebase_ai/firebase_ai/lib/src/imagen_model.dart b/packages/firebase_ai/firebase_ai/lib/src/imagen_model.dart index 4133d6ae3db8..752942e104d1 100644 --- a/packages/firebase_ai/firebase_ai/lib/src/imagen_model.dart +++ b/packages/firebase_ai/firebase_ai/lib/src/imagen_model.dart @@ -14,9 +14,6 @@ part of 'base_model.dart'; -import 'imagen_api.dart'; -import 'imagen_content.dart'; - /// Represents a remote Imagen model with the ability to generate images using /// text prompts. /// @@ -133,7 +130,8 @@ final class ImagenModel extends BaseApiClientModel { if (config.editMode != null) 'editMode': config.editMode!.name, if (config.maskDilation != null) 'maskDilation': config.maskDilation, if (config.editSteps != null) 'editSteps': config.editSteps, - 'sampleCount': config.numberOfImages ?? _generationConfig?.numberOfImages ?? 1, + 'sampleCount': + config.numberOfImages ?? _generationConfig?.numberOfImages ?? 1, // Parameters from model-level _generationConfig and _safetySettings if (_generationConfig?.aspectRatio case final aspectRatio?) @@ -154,7 +152,8 @@ final class ImagenModel extends BaseApiClientModel { return makeRequest( Task.predict, payload, - (jsonObject) => parseImagenGenerationResponse(jsonObject), + (jsonObject) => + parseImagenGenerationResponse(jsonObject), ); } @@ -175,15 +174,22 @@ final class ImagenModel extends BaseApiClientModel { ], 'parameters': { 'upscaleFactor': upscaleFactor.toJson(), - if (generationConfig?.aspectRatio ?? _generationConfig?.aspectRatio case final aspectRatio?) 
+ if (generationConfig?.aspectRatio ?? _generationConfig?.aspectRatio + case final aspectRatio?) 'aspectRatio': aspectRatio.toJson(), - if (generationConfig?.addWatermark ?? _generationConfig?.addWatermark case final addWatermark?) + if (generationConfig?.addWatermark ?? _generationConfig?.addWatermark + case final addWatermark?) 'addWatermark': addWatermark, - if (generationConfig?.imageFormat ?? _generationConfig?.imageFormat case final imageFormat?) + if (generationConfig?.imageFormat ?? _generationConfig?.imageFormat + case final imageFormat?) 'outputOption': imageFormat.toJson(), - if (safetySettings?.personFilterLevel ?? _safetySettings?.personFilterLevel case final personFilterLevel?) + if (safetySettings?.personFilterLevel ?? + _safetySettings?.personFilterLevel + case final personFilterLevel?) 'personGeneration': personFilterLevel.toJson(), - if (safetySettings?.safetyFilterLevel ?? _safetySettings?.safetyFilterLevel case final safetyFilterLevel?) + if (safetySettings?.safetyFilterLevel ?? + _safetySettings?.safetyFilterLevel + case final safetyFilterLevel?) 
'safetySetting': safetyFilterLevel.toJson(), }, }; @@ -191,7 +197,8 @@ final class ImagenModel extends BaseApiClientModel { return makeRequest( Task.predict, payload, - (jsonObject) => parseImagenGenerationResponse(jsonObject), + (jsonObject) => + parseImagenGenerationResponse(jsonObject), ); } } diff --git a/packages/firebase_ai/firebase_ai/test/chat_test.dart b/packages/firebase_ai/firebase_ai/test/chat_test.dart index ab5819f0f12a..76f302e1162b 100644 --- a/packages/firebase_ai/firebase_ai/test/chat_test.dart +++ b/packages/firebase_ai/firebase_ai/test/chat_test.dart @@ -17,12 +17,12 @@ import 'package:firebase_ai/src/base_model.dart'; import 'package:firebase_core/firebase_core.dart'; import 'package:flutter_test/flutter_test.dart'; -import 'mock.dart'; +import 'mocks/mock_core.dart'; import 'utils/matchers.dart'; import 'utils/stub_client.dart'; void main() { - setupFirebaseVertexAIMocks(); + setupFirebaseAIMocks(); // ignore: unused_local_variable late FirebaseApp app; diff --git a/packages/firebase_ai/firebase_ai/test/firebase_vertexai_test.dart b/packages/firebase_ai/firebase_ai/test/firebase_vertexai_test.dart index 8edb2d0f8480..a5a0e6dd2518 100644 --- a/packages/firebase_ai/firebase_ai/test/firebase_vertexai_test.dart +++ b/packages/firebase_ai/firebase_ai/test/firebase_vertexai_test.dart @@ -17,10 +17,10 @@ import 'package:firebase_app_check/firebase_app_check.dart'; import 'package:firebase_core/firebase_core.dart'; import 'package:flutter_test/flutter_test.dart'; -import 'mock.dart'; +import 'mocks/mock_core.dart'; void main() { - setupFirebaseVertexAIMocks(); + setupFirebaseAIMocks(); // ignore: unused_local_variable late FirebaseApp app; // ignore: unused_local_variable diff --git a/packages/firebase_ai/firebase_ai/test/google_ai_generative_model_test.dart b/packages/firebase_ai/firebase_ai/test/google_ai_generative_model_test.dart index 9883102c2729..ef67d2505b34 100644 --- 
a/packages/firebase_ai/firebase_ai/test/google_ai_generative_model_test.dart +++ b/packages/firebase_ai/firebase_ai/test/google_ai_generative_model_test.dart @@ -17,12 +17,12 @@ import 'package:firebase_ai/src/base_model.dart'; import 'package:firebase_core/firebase_core.dart'; import 'package:flutter_test/flutter_test.dart'; -import 'mock.dart'; +import 'mocks/mock_core.dart'; import 'utils/matchers.dart'; import 'utils/stub_client.dart'; void main() { - setupFirebaseVertexAIMocks(); + setupFirebaseAIMocks(); late FirebaseApp app; setUpAll(() async { // Initialize Firebase diff --git a/packages/firebase_ai/firebase_ai/test/imagen_test.dart b/packages/firebase_ai/firebase_ai/test/imagen_test.dart index 80891344dcf1..d56a8d6cd8a3 100644 --- a/packages/firebase_ai/firebase_ai/test/imagen_test.dart +++ b/packages/firebase_ai/firebase_ai/test/imagen_test.dart @@ -15,21 +15,17 @@ import 'dart:convert'; import 'dart:typed_data'; +import 'package:firebase_ai/firebase_ai.dart'; import 'package:firebase_ai/src/error.dart'; import 'package:firebase_ai/src/imagen_content.dart'; -import 'package:firebase_ai/firebase_ai.dart'; -import 'package:firebase_ai/src/api_client.dart'; import 'package:firebase_core/firebase_core.dart'; import 'package:flutter_test/flutter_test.dart'; -import 'package:mockito/annotations.dart'; -import 'package:mockito/mockito.dart'; -import 'mock.dart'; -import 'imagen_test.mocks.dart'; // Generated by Mockito + +import 'mocks/mock_core.dart'; // Mock HttpApiClient -@GenerateMocks([HttpApiClient]) void main() { - setupFirebaseCoreMocks(); + setupFirebaseAIMocks(); setUpAll(() async { await Firebase.initializeApp( @@ -261,352 +257,4 @@ void main() { throwsA(isA())); }); }); - - group('ImagenModel Tests', () { - late ImagenModel imagenModel; - late MockHttpApiClient mockClient; - final app = Firebase.app('testApp'); - - setUp(() { - mockClient = MockHttpApiClient(); - imagenModel = ImagenModel._( - app: app, - model: 'gemini-1.5-flash', // Example model - 
location: 'us-central1', - useVertexBackend: true, // Assuming Vertex backend for these tests - client: mockClient, - // No default generationConfig or safetySettings for cleaner test isolation - ); - }); - - group('editImage', () { - final sourceImageBytes = Uint8List.fromList(utf8.encode('source_image_bytes')); - final sourceImage = ImagenInlineImage(bytesBase64Encoded: sourceImageBytes, mimeType: 'image/png'); - final maskImageBytes = Uint8List.fromList(utf8.encode('mask_image_bytes')); - final maskImage = ImagenInlineImage(bytesBase64Encoded: maskImageBytes, mimeType: 'image/png'); - const prompt = 'a test prompt'; - - test('should construct correct payload for mask-free editing', () async { - // Assuming model has no default _generationConfig or _safetySettings for this test - final config = ImagenEditingConfig.maskFree(image: sourceImage, numberOfImages: 2); - final expectedPayload = { - 'instances': [ - { - 'prompt': prompt, - 'image': sourceImage.toJson(), - } - ], - 'parameters': { - 'sampleCount': 2, // From config - // No other parameters expected if model defaults are null - }, - }; - - when(mockClient.makeRequest>(any, any, any)) - .thenAnswer((_) async => ImagenGenerationResponse(images: [sourceImage])); // Dummy success - - await imagenModel.editImage(prompt, config: config); - final captured = verify(mockClient.makeRequest>( - captureAny, captureAny, captureAny)) - .captured; - expect(captured[1], equals(expectedPayload)); - }); - - test('should construct correct payload for full editing config', () async { - final config = ImagenEditingConfig( - image: sourceImage, - mask: maskImage, - maskDilation: 0.05, - editSteps: 60, - numberOfImages: 3, - editMode: ImagenEditMode.inpaint, - // negativePrompt, safetySettings, and generationConfig are no longer part of ImagenEditingConfig - ); - - // For this test, let's assume the model has some default _generationConfig and _safetySettings - // to see them appear in the payload. 
- final modelForFullTest = ImagenModel._( - app: app, - model: 'gemini-1.5-flash', - location: 'us-central1', - useVertexBackend: true, - client: mockClient, - generationConfig: ImagenGenerationConfig( - aspectRatio: ImagenAspectRatio.landscape16x9, - imageFormat: ImagenFormat.jpeg(compressionQuality: 80), - addWatermark: false, - negativePrompt: 'model-level blurry', - ), - safetySettings: ImagenSafetySettings(ImagenSafetyFilterLevel.blockLowAndAbove, ImagenPersonFilterLevel.blockAll), - ); - - final expectedPayload = { - 'instances': [ - { - 'prompt': prompt, - 'image': sourceImage.toJson(), - 'mask': maskImage.toJson(), - } - ], - 'parameters': { - // From config - 'editMode': 'inpaint', - 'maskDilation': 0.05, - 'editSteps': 60, - 'sampleCount': 3, - // From model-level settings - 'aspectRatio': '16:9', - 'negativePrompt': 'model-level blurry', - 'addWatermark': false, - 'outputOption': {'mimeType': 'image/jpeg', 'compressionQuality': 80}, - 'personGeneration': 'dont_allow', - 'safetySetting': 'block_low_and_above', - }, - }; - - when(mockClient.makeRequest>(any, any, any)) - .thenAnswer((_) async => ImagenGenerationResponse(images: [sourceImage])); - - await modelForFullTest.editImage(prompt, config: config); // Use modelForFullTest here - final captured = verify(mockClient.makeRequest>( - captureAny, captureAny, captureAny)) - .captured; - expect(captured[1], equals(expectedPayload)); - }); - - test('should parse successful response', () async { - final config = ImagenEditingConfig.maskFree(image: sourceImage); - final apiResponse = { - 'predictions': [ - {'bytesBase64Encoded': base64Encode(sourceImageBytes), 'mimeType': 'image/png'} - ] - }; - when(mockClient.makeRequest>(any, any, any)) - .thenAnswer((_) async => parseImagenGenerationResponse(apiResponse)); - - final response = await imagenModel.editImage(prompt, config: config); // Use default imagenModel - expect(response.images.length, 1); - expect(response.images.first.bytesBase64Encoded, 
equals(sourceImageBytes)); - expect(response.images.first.mimeType, 'image/png'); - }); - - test('should throw ImagenImagesBlockedException for filtered response', () async { - final config = ImagenEditingConfig.maskFree(image: sourceImage); - final apiResponse = { - 'predictions': [ - {'raiFilteredReason': 'Blocked due to safety reasons'} - ] - }; - when(mockClient.makeRequest>(any, any, any)) - .thenAnswer((_) async => parseImagenGenerationResponse(apiResponse)); - - expect( - () => imagenModel.editImage(prompt, config: config), // Use default imagenModel - throwsA(isA())); - }); - - test('should throw ServerException for error response', () async { - final config = ImagenEditingConfig.maskFree(image: sourceImage); - final apiResponse = { - 'error': {'code': 400, 'message': 'Bad request', 'status': 'INVALID_ARGUMENT'} - }; - when(mockClient.makeRequest>(any, any, any)) - .thenAnswer((_) async => parseImagenGenerationResponse(apiResponse)); - - expect( - () => imagenModel.editImage(prompt, config: config), // Use default imagenModel - throwsA(isA())); - }); - - test('model-level defaults should be used when not overridden by ImagenEditingConfig', () async { - // Initialize model with default settings - final defaultSafety = ImagenSafetySettings(ImagenSafetyFilterLevel.blockMediumAndAbove, ImagenPersonFilterLevel.allowAdult); - final defaultGeneration = ImagenGenerationConfig( - negativePrompt: "default negative", - numberOfImages: 1, // This will be overridden by config.numberOfImages if set, or config's default - aspectRatio: ImagenAspectRatio.square1x1, - addWatermark: true, - imageFormat: ImagenFormat.png(), - ); - - final modelWithDefaults = ImagenModel._( - app: app, - model: 'gemini-1.5-flash', - location: 'us-central1', - useVertexBackend: true, - client: mockClient, - safetySettings: defaultSafety, - generationConfig: defaultGeneration, - ); - - // ImagenEditingConfig now only has image and numberOfImages that could affect these global params - final 
config = ImagenEditingConfig( - image: sourceImage, - numberOfImages: 2, // This will override defaultGeneration.numberOfImages for 'sampleCount' - editMode: ImagenEditMode.outpaint, // Specific to edit - ); - - final expectedPayloadParameters = { - // From config - 'sampleCount': 2, - 'editMode': 'outpaint', - // From model-level defaults (defaultGeneration and defaultSafety) - 'aspectRatio': '1:1', - 'negativePrompt': 'default negative', - 'addWatermark': true, - 'outputOption': {'mimeType': 'image/png'}, - 'personGeneration': 'allow_adult', - 'safetySetting': 'block_medium_and_above', - }; - - when(mockClient.makeRequest>(any, any, any)) - .thenAnswer((_) async => ImagenGenerationResponse(images: [sourceImage])); - - await modelWithDefaults.editImage(prompt, config: config); - final captured = verify(mockClient.makeRequest>( - captureAny, captureAny, captureAny)) - .captured; - // We are interested in the 'parameters' part of the payload (captured[1]) - final actualParameters = captured[1]['parameters'] as Map; - expect(actualParameters, equals(expectedPayloadParameters)); - }); - }); - - group('upscaleImage', () { - final sourceImageBytes = Uint8List.fromList(utf8.encode('source_image_bytes_for_upscale')); - final sourceImage = ImagenInlineImage(bytesBase64Encoded: sourceImageBytes, mimeType: 'image/jpeg'); - - test('should construct correct payload for basic upscaling', () async { - final expectedPayload = { - 'instances': [ - {'image': sourceImage.toJson()} - ], - 'parameters': { - 'upscaleFactor': 'x2', - }, - }; - when(mockClient.makeRequest>(any, any, any)) - .thenAnswer((_) async => ImagenGenerationResponse(images: [sourceImage])); - - await imagenModel.upscaleImage(image: sourceImage, upscaleFactor: ImagenUpscaleFactor.x2); - final captured = verify(mockClient.makeRequest>( - captureAny, captureAny, captureAny)) - .captured; - expect(captured[1], equals(expectedPayload)); - }); - - test('should construct correct payload with optional safety and 
generation config', () async { - final safety = ImagenSafetySettings(ImagenSafetyFilterLevel.blockLowAndAbove, ImagenPersonFilterLevel.blockAll); - final generation = ImagenGenerationConfig( - imageFormat: ImagenFormat.png(), - addWatermark: true, - ); - final expectedPayload = { - 'instances': [ - {'image': sourceImage.toJson()} - ], - 'parameters': { - 'upscaleFactor': 'x4', - 'outputOption': {'mimeType': 'image/png'}, - 'addWatermark': true, - 'personGeneration': 'dont_allow', - 'safetySetting': 'block_low_and_above', - }, - }; - when(mockClient.makeRequest>(any, any, any)) - .thenAnswer((_) async => ImagenGenerationResponse(images: [sourceImage])); - - await imagenModel.upscaleImage( - image: sourceImage, - upscaleFactor: ImagenUpscaleFactor.x4, - safetySettings: safety, - generationConfig: generation, - ); - final captured = verify(mockClient.makeRequest>( - captureAny, captureAny, captureAny)) - .captured; - expect(captured[1], equals(expectedPayload)); - }); - - test('should parse successful upscale response', () async { - final apiResponse = { - 'predictions': [ - {'bytesBase64Encoded': base64Encode(sourceImageBytes), 'mimeType': 'image/jpeg'} - ] - }; - when(mockClient.makeRequest>(any, any, any)) - .thenAnswer((_) async => parseImagenGenerationResponse(apiResponse)); - - final response = await imagenModel.upscaleImage(image: sourceImage, upscaleFactor: ImagenUpscaleFactor.x2); - expect(response.images.length, 1); - expect(response.images.first.bytesBase64Encoded, equals(sourceImageBytes)); - expect(response.images.first.mimeType, 'image/jpeg'); - }); - - test('should throw ImagenImagesBlockedException for filtered upscale response', () async { - final apiResponse = { - 'predictions': [ - {'raiFilteredReason': 'Blocked due to safety reasons'} - ] - }; - when(mockClient.makeRequest>(any, any, any)) - .thenAnswer((_) async => parseImagenGenerationResponse(apiResponse)); - - expect( - () => imagenModel.upscaleImage(image: sourceImage, upscaleFactor: 
ImagenUpscaleFactor.x2), - throwsA(isA())); - }); - - test('should throw ServerException for error in upscale response', () async { - final apiResponse = { - 'error': {'code': 500, 'message': 'Internal server error', 'status': 'UNAVAILABLE'} - }; - when(mockClient.makeRequest>(any, any, any)) - .thenAnswer((_) async => parseImagenGenerationResponse(apiResponse)); - expect( - () => imagenModel.upscaleImage(image: sourceImage, upscaleFactor: ImagenUpscaleFactor.x2), - throwsA(isA())); - }); - - test('upscale method parameters should override model-level defaults', () async { - final defaultSafety = ImagenSafetySettings(ImagenSafetyFilterLevel.blockNone, ImagenPersonFilterLevel.allowAll); - final defaultGeneration = ImagenGenerationConfig(imageFormat: ImagenFormat.jpeg(), addWatermark: false); - - final modelWithDefaults = ImagenModel._( - app: app, - model: 'gemini-1.5-flash', - location: 'us-central1', - useVertexBackend: true, - client: mockClient, - safetySettings: defaultSafety, - generationConfig: defaultGeneration, - ); - - final methodSafety = ImagenSafetySettings(ImagenSafetyFilterLevel.blockMediumAndAbove, null); // Override safety, keep person from default - final methodGeneration = ImagenGenerationConfig(addWatermark: true, imageFormat: ImagenFormat.png()); // Override watermark and format - - final expectedPayloadParameters = { - 'upscaleFactor': 'x2', - 'outputOption': {'mimeType': 'image/png'}, // from methodGeneration - 'addWatermark': true, // from methodGeneration - 'personGeneration': 'allow_all', // from model default - 'safetySetting': 'block_medium_and_above', // from methodSafety - }; - - when(mockClient.makeRequest>(any, any, any)) - .thenAnswer((_) async => ImagenGenerationResponse(images: [sourceImage])); - - await modelWithDefaults.upscaleImage( - image: sourceImage, - upscaleFactor: ImagenUpscaleFactor.x2, - safetySettings: methodSafety, - generationConfig: methodGeneration, - ); - final captured = verify(mockClient.makeRequest>( - 
captureAny, captureAny, captureAny)) - .captured; - expect(captured[1]['parameters'], equals(expectedPayloadParameters)); - }); - }); - }); } diff --git a/packages/firebase_ai/firebase_ai/test/mock.dart b/packages/firebase_ai/firebase_ai/test/mocks/mock_core.dart similarity index 89% rename from packages/firebase_ai/firebase_ai/test/mock.dart rename to packages/firebase_ai/firebase_ai/test/mocks/mock_core.dart index ed883d924371..0a78e72888a7 100644 --- a/packages/firebase_ai/firebase_ai/test/mock.dart +++ b/packages/firebase_ai/firebase_ai/test/mocks/mock_core.dart @@ -18,7 +18,7 @@ import 'package:flutter_test/flutter_test.dart'; import 'package:mockito/mockito.dart'; import 'package:plugin_platform_interface/plugin_platform_interface.dart'; -class MockFirebaseAppVertexAI implements TestFirebaseCoreHostApi { +class MockFirebaseAppAI implements TestFirebaseCoreHostApi { @override Future initializeApp( String appName, @@ -58,16 +58,16 @@ class MockFirebaseAppVertexAI implements TestFirebaseCoreHostApi { } } -void setupFirebaseVertexAIMocks() { +void setupFirebaseAIMocks() { TestWidgetsFlutterBinding.ensureInitialized(); - TestFirebaseCoreHostApi.setup(MockFirebaseAppVertexAI()); + TestFirebaseCoreHostApi.setup(MockFirebaseAppAI()); } // FirebaseVertexAIPlatform Mock -class MockFirebaseVertexAI extends Mock +class MockFirebaseAI extends Mock with // ignore: prefer_mixin, plugin_platform_interface needs to migrate to use `mixin` MockPlatformInterfaceMixin { - MockFirebaseVertexAI(); + MockFirebaseAI(); } diff --git a/packages/firebase_ai/firebase_ai/test/model_test.dart b/packages/firebase_ai/firebase_ai/test/model_test.dart index 2ddf4d55406c..7049d43a5e33 100644 --- a/packages/firebase_ai/firebase_ai/test/model_test.dart +++ b/packages/firebase_ai/firebase_ai/test/model_test.dart @@ -17,12 +17,12 @@ import 'package:firebase_ai/src/base_model.dart'; import 'package:firebase_core/firebase_core.dart'; import 'package:flutter_test/flutter_test.dart'; -import 
'mock.dart'; +import 'mocks/mock_core.dart'; import 'utils/matchers.dart'; import 'utils/stub_client.dart'; void main() { - setupFirebaseVertexAIMocks(); + setupFirebaseAIMocks(); // ignore: unused_local_variable late FirebaseApp app; setUpAll(() async { diff --git a/packages/firebase_vertexai/firebase_vertexai/example/lib/pages/imagen_page.dart b/packages/firebase_vertexai/firebase_vertexai/example/lib/pages/imagen_page.dart index 7bb7ee6138cf..c5dc1e7fbcaf 100644 --- a/packages/firebase_vertexai/firebase_vertexai/example/lib/pages/imagen_page.dart +++ b/packages/firebase_vertexai/firebase_vertexai/example/lib/pages/imagen_page.dart @@ -12,12 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -import 'dart:typed_data'; -import 'dart:io'; // Though XFile.readAsBytes is available, good to have for other File ops if needed. -import 'package:image_picker/image_picker.dart'; +import 'package:firebase_vertexai/firebase_vertexai.dart'; import 'package:flutter/material.dart'; -import 'package:firebase_vertexai/firebase_vertexai.dart'; //import 'package:firebase_storage/firebase_storage.dart'; import '../widgets/message_widget.dart'; @@ -42,20 +39,6 @@ class _ImagenPageState extends State { final List _generatedContent = []; bool _loading = false; - // For image picking - ImagenInlineImage? _sourceImageForUpscaling; - ImagenInlineImage? _sourceImageForEditing; - ImagenInlineImage? 
_maskImageForEditing; - - // For upscale factor - ImagenUpscaleFactor _selectedUpscaleFactor = ImagenUpscaleFactor.x2; // Default - - // For editing parameters - final TextEditingController _editPromptController = TextEditingController(); - final TextEditingController _maskDilationController = TextEditingController(); - final TextEditingController _editStepsController = TextEditingController(); - ImagenEditMode _selectedEditMode = ImagenEditMode.inpaint; // Default - void _scrollDown() { WidgetsBinding.instance.addPostFrameCallback( (_) => _scrollController.animateTo( @@ -117,7 +100,9 @@ class _ImagenPageState extends State { if (!_loading) IconButton( onPressed: () async { - await _generateImageFromPrompt(_textController.text); + await _generateImageFromPrompt( + _textController.text, + ); }, icon: Icon( Icons.image_search, @@ -129,129 +114,6 @@ class _ImagenPageState extends State { const CircularProgressIndicator(), ], ), - const SizedBox(height: 20), - // Upscaling UI - const Text('Image Upscaling', style: TextStyle(fontWeight: FontWeight.bold)), - Row( - mainAxisAlignment: MainAxisAlignment.spaceEvenly, - children: [ - ElevatedButton( - onPressed: _loading ? null : _pickSourceImageForUpscaling, - child: const Text('Pick Source'), - ), - if (_sourceImageForUpscaling != null) - Text(_sourceImageForUpscaling!.mimeType, style: const TextStyle(fontSize: 12)), - DropdownButton( - value: _selectedUpscaleFactor, - onChanged: _loading ? null : (ImagenUpscaleFactor? newValue) { - if (newValue != null) { - setState(() { - _selectedUpscaleFactor = newValue; - }); - } - }, - items: ImagenUpscaleFactor.values - .map>( - (ImagenUpscaleFactor value) { - return DropdownMenuItem( - value: value, - child: Text(value.name), - ); - }).toList(), - ), - ElevatedButton( - onPressed: _loading || _sourceImageForUpscaling == null ? 
null : _upscaleImage, - child: const Text('Upscale Image'), - ), - ], - ), - const SizedBox(height: 20), - // Editing UI - const Text('Image Editing', style: TextStyle(fontWeight: FontWeight.bold)), - Row( - mainAxisAlignment: MainAxisAlignment.spaceEvenly, - children: [ - ElevatedButton( - onPressed: _loading ? null : _pickSourceImageForEditing, - child: const Text('Pick Source'), - ), - if (_sourceImageForEditing != null) - Text(_sourceImageForEditing!.mimeType, style: const TextStyle(fontSize: 12)), - Expanded( - child: TextField( - controller: _editPromptController, - decoration: const InputDecoration(hintText: 'Edit prompt'), - enabled: !_loading, - ), - ), - ], - ), - const SizedBox(height: 10), - // Inpaint/Outpaint Specific UI - const Text('Inpaint/Outpaint', style: TextStyle(fontStyle: FontStyle.italic)), - Row( - mainAxisAlignment: MainAxisAlignment.spaceEvenly, - children: [ - ElevatedButton( - onPressed: _loading ? null : _pickMaskImageForEditing, - child: const Text('Pick Mask'), - ), - if (_maskImageForEditing != null) - Text(_maskImageForEditing!.mimeType, style: const TextStyle(fontSize: 12)), - DropdownButton( - value: _selectedEditMode, - onChanged: _loading ? null : (ImagenEditMode? 
newValue) { - if (newValue != null) { - setState(() { - _selectedEditMode = newValue; - }); - } - }, - items: ImagenEditMode.values - .map>( - (ImagenEditMode value) { - return DropdownMenuItem( - value: value, - child: Text(value.name), - ); - }).toList(), - ), - ], - ), - Row( - children: [ - Expanded( - child: TextField( - controller: _maskDilationController, - decoration: const InputDecoration(hintText: 'Mask Dilation (e.g., 0.01)'), - keyboardType: const TextInputType.numberWithOptions(decimal: true), - enabled: !_loading, - ), - ), - const SizedBox(width: 10), - Expanded( - child: TextField( - controller: _editStepsController, - decoration: const InputDecoration(hintText: 'Edit Steps (e.g., 50)'), - keyboardType: TextInputType.number, - enabled: !_loading, - ), - ), - ], - ), - ElevatedButton( - onPressed: _loading || _sourceImageForEditing == null || _maskImageForEditing == null - ? null - : _editImageInpaintOutpaint, - child: const Text('Edit (Inpaint/Outpaint)'), - ), - const SizedBox(height: 10), - // Mask-Free Editing Button - const Text('Mask-Free Edit', style: TextStyle(fontStyle: FontStyle.italic)), - ElevatedButton( - onPressed: _loading || _sourceImageForEditing == null ? null : _editImageMaskFree, - child: const Text('Edit (Mask-Free)'), - ), ], ), ), @@ -261,199 +123,6 @@ class _ImagenPageState extends State { ); } - Future _pickImage() async { - final ImagePicker picker = ImagePicker(); - try { - final XFile? imageFile = await picker.pickImage(source: ImageSource.gallery); - if (imageFile != null) { - // Attempt to get mimeType, default if null. - // Note: imageFile.mimeType might be null on some platforms or for some files. - final String mimeType = imageFile.mimeType ?? 
'image/jpeg'; - final Uint8List imageBytes = await imageFile.readAsBytes(); - return ImagenInlineImage(bytesBase64Encoded: imageBytes, mimeType: mimeType); - } - } catch (e) { - _showError('Error picking image: $e'); - } - return null; - } - - Future _pickSourceImageForUpscaling() async { - final pickedImage = await _pickImage(); - if (pickedImage != null) { - setState(() { - _sourceImageForUpscaling = pickedImage; - }); - } - } - - Future _pickSourceImageForEditing() async { - final pickedImage = await _pickImage(); - if (pickedImage != null) { - setState(() { - _sourceImageForEditing = pickedImage; - }); - } - } - - Future _pickMaskImageForEditing() async { - final pickedImage = await _pickImage(); - if (pickedImage != null) { - setState(() { - _maskImageForEditing = pickedImage; - }); - } - } - - Future _upscaleImage() async { - if (_sourceImageForUpscaling == null) { - _showError('Please pick a source image for upscaling.'); - return; - } - setState(() { - _loading = true; - }); - - try { - final response = await widget.model.upscaleImage( - image: _sourceImageForUpscaling!, - upscaleFactor: _selectedUpscaleFactor, - ); - if (response.images.isNotEmpty) { - final upscaledImage = response.images[0]; - setState(() { - _generatedContent.add( - MessageData( - image: Image.memory(upscaledImage.bytesBase64Encoded), - text: 'Upscaled image (Factor: ${_selectedUpscaleFactor.name})', - fromUser: false, - ), - ); - _scrollDown(); - }); - } else { - _showError('No image was returned from upscaling.'); - } - } catch (e) { - _showError('Error upscaling image: $e'); - } - - setState(() { - _loading = false; - }); - } - - Future _editImageInpaintOutpaint() async { - if (_sourceImageForEditing == null || _maskImageForEditing == null) { - _showError('Please pick a source image and a mask image for inpainting/outpainting.'); - return; - } - setState(() { - _loading = true; - }); - - final String prompt = _editPromptController.text; - final double? 
maskDilation = double.tryParse(_maskDilationController.text); - final int? editSteps = int.tryParse(_editStepsController.text); - - final editConfig = ImagenEditingConfig( - image: _sourceImageForEditing!, - mask: _maskImageForEditing!, - editMode: _selectedEditMode, - maskDilation: maskDilation, - editSteps: editSteps, - // numberOfImages: 1, // Default in model or could be added to UI - ); - - try { - final response = await widget.model.editImage( - prompt, - config: editConfig, - ); - if (response.images.isNotEmpty) { - final editedImage = response.images[0]; - setState(() { - _generatedContent.add( - MessageData( - image: Image.memory(editedImage.bytesBase64Encoded), - text: 'Edited image (Inpaint/Outpaint): $prompt', - fromUser: false, - ), - ); - _scrollDown(); - }); - } else { - _showError('No image was returned from editing.'); - } - } catch (e) { - _showError('Error editing image: $e'); - } - setState(() { - _loading = false; - }); - } - - Future _editImageMaskFree() async { - if (_sourceImageForEditing == null) { - _showError('Please pick a source image for mask-free editing.'); - return; - } - setState(() { - _loading = true; - }); - - final String prompt = _editPromptController.text; - final editConfig = ImagenEditingConfig.maskFree( - image: _sourceImageForEditing!, - // numberOfImages: 1, // Default in model or could be added to UI - ); - - try { - final response = await widget.model.editImage( - prompt, - config: editConfig, - ); - if (response.images.isNotEmpty) { - final editedImage = response.images[0]; - setState(() { - _generatedContent.add( - MessageData( - image: Image.memory(editedImage.bytesBase64Encoded), - text: 'Edited image (Mask-Free): $prompt', - fromUser: false, - ), - ); - _scrollDown(); - }); - } else { - _showError('No image was returned from mask-free editing.'); - } - } catch (e) { - _showError('Error performing mask-free edit: $e'); - } - setState(() { - _loading = false; - }); - } - - void _showNotImplementedDialog() { - 
showDialog( - context: context, - builder: (context) { - return AlertDialog( - title: const Text('Not Implemented'), - content: const Text('This feature will be implemented in a later step.'), - actions: [ - TextButton( - onPressed: () => Navigator.of(context).pop(), - child: const Text('OK'), - ), - ], - ); - }, - ); - } - Future _generateImageFromPrompt(String prompt) async { setState(() { _loading = true;