Skip to content

Commit

Permalink
Add more enum samples. (#543)
Browse files Browse the repository at this point in the history
* Add more enum samples

Change-Id: I743d5967cc1cc91576b8ddf5a60db1767d94508d

* format

Change-Id: I8f6f9389f1cae0a7c934217968d4e2e20bb9590e
  • Loading branch information
MarkDaoust authored Sep 9, 2024
1 parent 4647e79 commit 836d31a
Showing 1 changed file with 69 additions and 0 deletions.
69 changes: 69 additions & 0 deletions samples/controlled_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,34 @@ class Choice(enum.Enum):
print(result) # "Keyboard"
# [END json_enum]

def test_enum_in_json(self):
# [START enum_in_json]
import enum
from typing_extensions import TypedDict

class Grade(enum.Enum):
A_PLUS = "a+"
A = "a"
B = "b"
C = "c"
D = "d"
F = "f"

class Recipe(TypedDict):
recipe_name: str
grade: Grade

model = genai.GenerativeModel("gemini-1.5-pro-latest")

result = model.generate_content(
"List about 10 cookie recipes, grade them based on popularity",
generation_config=genai.GenerationConfig(
response_mime_type="application/json", response_schema=list[Recipe]
),
)
print(result) # [{"grade": "a+", "recipe_name": "Chocolate Chip Cookies"}, ...]
# [END enum_in_json]

def test_json_enum_raw(self):
# [START json_enum_raw]
model = genai.GenerativeModel("gemini-1.5-pro-latest")
Expand All @@ -91,6 +119,47 @@ def test_json_enum_raw(self):
print(result) # "Keyboard"
# [END json_enum_raw]

def test_x_enum(self):
# [START x_enum]
import enum

class Choice(enum.Enum):
PERCUSSION = "Percussion"
STRING = "String"
WOODWIND = "Woodwind"
BRASS = "Brass"
KEYBOARD = "Keyboard"

model = genai.GenerativeModel("gemini-1.5-pro-latest")

organ = genai.upload_file(media / "organ.jpg")
result = model.generate_content(
["What kind of instrument is this:", organ],
generation_config=genai.GenerationConfig(
response_mime_type="text/x.enum", response_schema=Choice
),
)
print(result) # "Keyboard"
# [END x_enum]

def test_x_enum_raw(self):
# [START x_enum_raw]
model = genai.GenerativeModel("gemini-1.5-pro-latest")

organ = genai.upload_file(media / "organ.jpg")
result = model.generate_content(
["What kind of instrument is this:", organ],
generation_config=genai.GenerationConfig(
response_mime_type="text/x.enum",
response_schema={
"type": "STRING",
"enum": ["Percussion", "String", "Woodwind", "Brass", "Keyboard"],
},
),
)
print(result) # "Keyboard"
# [END x_enum_raw]


if __name__ == "__main__":
absltest.main()

0 comments on commit 836d31a

Please sign in to comment.