Skip to content

Commit

Permalink
types
Browse files Browse the repository at this point in the history
  • Loading branch information
MarkEdmondson1234 committed Dec 17, 2024
1 parent 4c7c646 commit 961876a
Showing 1 changed file with 12 additions and 4 deletions.
16 changes: 12 additions & 4 deletions sunholo/genai/genaiv2.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Optional, List, Union, Dict, Any, TypedDict
from typing import Optional, List, Union, Dict, Any, TypedDict, TYPE_CHECKING, Generator

import enum
import json
from pydantic import BaseModel
Expand Down Expand Up @@ -26,6 +27,13 @@
except ImportError:
cv2 = None

if TYPE_CHECKING:
from google import genai
from google.genai import types
from google.genai.types import Tool, GenerateContentConfig, EmbedContentConfig
else:
genai = None

class GoogleAIConfig(BaseModel):
"""Configuration class for GoogleAI client initialization.
See https://ai.google.dev/gemini-api/docs/models/gemini-v2
Expand Down Expand Up @@ -63,7 +71,7 @@ def __init__(self, config: GoogleAIConfig):

self.default_model = "gemini-2.0-flash-exp"

def google_search_tool(self) -> types.Tool:
def google_search_tool(self) -> "types.Tool":
from google.genai.types import Tool, GoogleSearch
return Tool(
google_search = GoogleSearch()
Expand All @@ -78,7 +86,7 @@ def generate_text(
top_k: int = 20,
stop_sequences: Optional[List[str]] = None,
system_prompt: Optional[str] = None,
tools: Optional[List[types.Tool]] = None
tools: Optional[List["types.Tool"]] = None
) -> str:
"""Generate text using the specified model.
Expand Down Expand Up @@ -290,7 +298,7 @@ def stream_text(
prompt: str,
model: Optional[str] = None,
**kwargs
):
) -> "Generator[str, None, None]":
"""Stream text generation responses.
Args:
Expand Down

0 comments on commit 961876a

Please sign in to comment.