-
Notifications
You must be signed in to change notification settings - Fork 6.5k
/
Copy pathtest_embeddings_examples.py
139 lines (111 loc) · 4.75 KB
/
test_embeddings_examples.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import backoff
from google.api_core.exceptions import FailedPrecondition, ResourceExhausted
import google.auth
from google.cloud import aiplatform
from google.cloud.aiplatform import initializer as aiplatform_init
import pytest
import batch_example
import code_retrieval_example
import document_retrieval_example
import generate_embeddings_with_lower_dimension
import model_tuning_example
import multimodal_example
import multimodal_example_syntax
import multimodal_image_example
import multimodal_video_example
import text_example_syntax
# Deliberately NOT decorated with ``pytest.fixture``: pytest treats a
# fixture-decorated function as a fixture instead of a test, so the
# assertion below would never be exercised as a test.
@backoff.on_exception(backoff.expo, ResourceExhausted, max_time=10)
def test_embed_text_batch() -> None:
    """Run the batch embedding sample and check that it returns a job.

    ``GCS_OUTPUT_URI`` is exported before the sample is invoked; presumably
    the sample reads its output location from that variable (confirm in
    ``batch_example``).

    The call is retried with exponential backoff on ``ResourceExhausted``
    for up to 10 seconds.
    """
    os.environ["GCS_OUTPUT_URI"] = "gs://python-docs-samples-tests/"
    batch_prediction_job = batch_example.embed_text_batch()
    assert batch_prediction_job
@backoff.on_exception(backoff.expo, ResourceExhausted, max_time=10)
def test_multimodal_embedding_image_video_text() -> None:
    """Check that the combined sample yields image, video and text embeddings.

    Retried with exponential backoff on ``ResourceExhausted`` for up to
    10 seconds.
    """
    result = multimodal_example.get_image_video_text_embeddings()
    assert result is not None
    # Each of the three embedding attributes must be populated.
    assert result.image_embedding is not None
    assert result.video_embeddings is not None
    assert result.text_embedding is not None
@backoff.on_exception(backoff.expo, ResourceExhausted, max_time=10)
def test_multimodal_embedding_video() -> None:
    """Check that the video sample returns a result carrying video embeddings.

    Retried with exponential backoff on ``ResourceExhausted`` for up to
    10 seconds.
    """
    result = multimodal_video_example.get_video_embeddings()
    assert result is not None
    assert result.video_embeddings is not None
@backoff.on_exception(backoff.expo, ResourceExhausted, max_time=10)
def test_multimodal_embedding_image() -> None:
    """Check that the image sample returns both image and text embeddings.

    Retried with exponential backoff on ``ResourceExhausted`` for up to
    10 seconds.
    """
    response = multimodal_image_example.get_image_text_embeddings()
    assert response is not None
    # The sample is expected to populate both attributes.
    assert response.image_embedding is not None
    assert response.text_embedding is not None
@backoff.on_exception(backoff.expo, ResourceExhausted, max_time=10)
def test_generate_embeddings_with_lower_dimension() -> None:
    """Check that the reduced-dimension sample returns 128-element embeddings.

    Both the image and the text embedding must be present and have
    exactly 128 entries. Retried with exponential backoff on
    ``ResourceExhausted`` for up to 10 seconds.
    """
    expected_size = 128
    sample = generate_embeddings_with_lower_dimension
    reduced = sample.generate_embeddings_with_lower_dimension()
    assert reduced is not None
    # Same checks, in the same order, for each embedding attribute.
    for attribute in ("image_embedding", "text_embedding"):
        assert getattr(reduced, attribute) is not None
        assert len(getattr(reduced, attribute)) == expected_size
@backoff.on_exception(backoff.expo, ResourceExhausted, max_time=10)
def test_create_embeddings() -> None:
    """Check that the multimodal syntax sample returns a non-``None`` result.

    Retried with exponential backoff on ``ResourceExhausted`` for up to
    10 seconds.
    """
    result = multimodal_example_syntax.create_embeddings()
    assert result is not None
@backoff.on_exception(backoff.expo, ResourceExhausted, max_time=10)
def test_create_text_embeddings() -> None:
    """Check that the text syntax sample returns a non-``None`` result.

    Retried with exponential backoff on ``ResourceExhausted`` for up to
    10 seconds.
    """
    result = text_example_syntax.create_embeddings()
    assert result is not None
@backoff.on_exception(backoff.expo, ResourceExhausted, max_time=10)
def test_text_embed_text() -> None:
    """Check that the document-retrieval sample returns two 256-element vectors.

    Retried with exponential backoff on ``ResourceExhausted`` for up to
    10 seconds.
    """
    vectors = document_retrieval_example.embed_text()
    sizes = [len(vector) for vector in vectors]
    assert sizes == [256, 256]
@backoff.on_exception(backoff.expo, ResourceExhausted, max_time=10)
def test_code_embed_text() -> None:
    """Check that the code-retrieval sample returns one vector per query.

    Three short queries are embedded with the ``CODE_RETRIEVAL_QUERY`` task
    and a requested dimensionality of 256; every returned vector must have
    that length. Retried with exponential backoff on ``ResourceExhausted``
    for up to 10 seconds.
    """
    queries = ["banana bread?", "banana muffin?", "banana?"]
    dimensionality = 256
    vectors = code_retrieval_example.embed_text(
        texts=queries,
        task="CODE_RETRIEVAL_QUERY",
        dimensionality=dimensionality,
    )
    # Falls back to 768 when no dimensionality is requested (presumably the
    # model's default output size -- confirm against the sample).
    expected_size = dimensionality or 768
    actual_sizes = [len(vector) for vector in vectors]
    assert actual_sizes == [expected_size] * len(queries)
@backoff.on_exception(backoff.expo, FailedPrecondition, max_time=300)
def dispose(tuning_job) -> None:  # noqa: ANN001
    """Cancel ``tuning_job`` if its pipeline is still in the running state.

    Jobs in any other state are left untouched. The call is retried with
    exponential backoff on ``FailedPrecondition`` for up to 300 seconds.

    Args:
        tuning_job: Object exposing ``_status.name`` and ``_cancel()``.
    """
    still_running = tuning_job._status.name == "PIPELINE_STATE_RUNNING"
    if not still_running:
        return
    tuning_job._cancel()
def test_tune_embedding_model() -> None:
    """Start an embedding-model tuning job and check it has not failed.

    The Vertex AI SDK is initialised with explicit credentials, a fixed
    API endpoint and a staging bucket before the sample is invoked. The
    job is always passed to ``dispose`` afterwards, whether or not the
    assertion holds.
    """
    # Set explicit credentials with Oauth scopes.
    oauth_scopes = ["https://www.googleapis.com/auth/cloud-platform"]
    credentials, _ = google.auth.default(scopes=oauth_scopes)

    aiplatform.init(
        api_endpoint="us-central1-aiplatform.googleapis.com:443",
        project=os.getenv("GOOGLE_CLOUD_PROJECT"),
        staging_bucket="gs://ucaip-samples-us-central1/training_pipeline_output",
        credentials=credentials,
    )

    # Hand the sample the endpoint that was just configured globally.
    endpoint = aiplatform_init.global_config.api_endpoint
    tuning_job = model_tuning_example.tune_embedding_model(endpoint)
    try:
        assert tuning_job._status.name != "PIPELINE_STATE_FAILED"
    finally:
        dispose(tuning_job)