From 961876a96b8589a23d9649365feecf11e6daee3f Mon Sep 17 00:00:00 2001 From: Mark Date: Tue, 17 Dec 2024 14:20:29 +0100 Subject: [PATCH] types --- sunholo/genai/genaiv2.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/sunholo/genai/genaiv2.py b/sunholo/genai/genaiv2.py index 10af86997..9aec87504 100644 --- a/sunholo/genai/genaiv2.py +++ b/sunholo/genai/genaiv2.py @@ -1,4 +1,5 @@ -from typing import Optional, List, Union, Dict, Any, TypedDict +from typing import Optional, List, Union, Dict, Any, TypedDict, TYPE_CHECKING, Generator + import enum import json from pydantic import BaseModel @@ -26,6 +27,13 @@ except ImportError: cv2 = None +if TYPE_CHECKING: + from google import genai + from google.genai import types + from google.genai.types import Tool, GenerateContentConfig, EmbedContentConfig +else: + genai = None + class GoogleAIConfig(BaseModel): """Configuration class for GoogleAI client initialization. See https://ai.google.dev/gemini-api/docs/models/gemini-v2 @@ -63,7 +71,7 @@ def __init__(self, config: GoogleAIConfig): self.default_model = "gemini-2.0-flash-exp" - def google_search_tool(self) -> types.Tool: + def google_search_tool(self) -> "types.Tool": from google.genai.types import Tool, GoogleSearch return Tool( google_search = GoogleSearch() @@ -78,7 +86,7 @@ def generate_text( top_k: int = 20, stop_sequences: Optional[List[str]] = None, system_prompt: Optional[str] = None, - tools: Optional[List[types.Tool]] = None + tools: Optional[List["types.Tool"]] = None ) -> str: """Generate text using the specified model. @@ -290,7 +298,7 @@ def stream_text( prompt: str, model: Optional[str] = None, **kwargs - ): + ) -> "Generator[str, None, None]": """Stream text generation responses. Args: