-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathclassification.py
125 lines (101 loc) · 2.88 KB
/
classification.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import vertexai
import json
import numpy as np
from vertexai.preview.generative_models import (
GenerationConfig,
GenerativeModel,
Image
)
from datasets import load_dataset
from tqdm import tqdm
# Authentication
# Read the Vertex AI settings from a local JSON file; it must provide the
# "project" and "location" keys used below.
# The encoding is given explicitly so that decoding the JSON file does not
# depend on the platform's default locale encoding.
with open("config-vertexai.json", encoding="utf-8") as f:
    creds = json.load(f)

vertexai.init(
    project=creds["project"],
    location=creds["location"],
)

# Model handle used for every classification request later in the script.
multimodal_model = GenerativeModel("gemini-pro-vision")
# Data
# Evaluation images come from the "full" configuration of the
# keremberke/painting-style-classification dataset.
painting_style_ds = load_dataset(
    "keremberke/painting-style-classification",
    name="full",
)

# Number of test rows to evaluate.
sample_size = 50
# Fixed shuffle seed: without it every run scores a different random sample,
# so the printed accuracy cannot be compared between runs.
shuffle_seed = 42

# Shuffle the test split, then keep the first `sample_size` rows.
test_data = painting_style_ds['test'].shuffle(seed=shuffle_seed)[0:sample_size]
# Passed to Image.load_from_file in the evaluation loop.
test_images = test_data['image_file_path']
# Used as indices into dataset_labels in the evaluation loop.
test_labels = test_data['labels']
# Evaluation
# Class names exactly as they are shown to the model, one per prompt line
# (alphabetical order).
_prompt_class_names = (
    "Abstract_Expressionism",
    "Action_painting",
    "Analytical_Cubism",
    "Art_Nouveau_Modern",
    "Baroque",
    "Color_Field_Painting",
    "Contemporary_Realism",
    "Cubism",
    "Early_Renaissance",
    "Expressionism",
    "Fauvism",
    "High_Renaissance",
    "Impressionism",
    "Mannerism_Late_Renaissance",
    "Minimalism",
    "Naive_Art_Primitivism",
    "New_Realism",
    "Northern_Renaissance",
    "Pointillism",
    "Pop_Art",
    "Post_Impressionism",
    "Realism",
    "Rococo",
    "Romanticism",
    "Symbolism",
    "Synthetic_Cubism",
    "Ukiyo_e",
)

# First prompt part: describes the task context and lists the allowed
# classes. The empty first and last items give the text its leading and
# trailing newline.
system_instructions = "\n".join(
    (
        "",
        "Instructions: Consider the following image that contains movement art images that range from "
        "Abstract Expressionism to Pop Art.",
        "Each image corresponds to one of the following classes:",
        *_prompt_class_names,
        "",
    )
)

# Last prompt part: asks for exactly one class name and nothing else.
task_prompt = (
    "\n"
    "Identify the class of the art depicted in the image as one of the above classes.\n"
    "The class label generated should strictly belong to one of the classes above.\n"
    "Your answer should only contain the class depicted. Do not explain your answer.\n"
)
# Generation settings passed with every classification request:
# temperature 0, top_p 1.0, and at most 16 output tokens (the task prompt
# asks for a bare class name only).
art_classification_generation_config = GenerationConfig(
    max_output_tokens=16, top_p=1.0, temperature=0
)
# Class names addressed by integer label: the evaluation loop reads
# dataset_labels[test_label] to obtain the ground-truth class name.
# NOTE(review): the order presumably mirrors the dataset's own label
# encoding -- confirm against the dataset before reordering anything.
dataset_labels = [
    'Realism',                     # 0
    'Art_Nouveau_Modern',          # 1
    'Analytical_Cubism',           # 2
    'Cubism',                      # 3
    'Expressionism',               # 4
    'Action_painting',             # 5
    'Synthetic_Cubism',            # 6
    'Symbolism',                   # 7
    'Ukiyo_e',                     # 8
    'Naive_Art_Primitivism',       # 9
    'Post_Impressionism',          # 10
    'Impressionism',               # 11
    'Fauvism',                     # 12
    'Rococo',                      # 13
    'Minimalism',                  # 14
    'Mannerism_Late_Renaissance',  # 15
    'Color_Field_Painting',        # 16
    'High_Renaissance',            # 17
    'Romanticism',                 # 18
    'Pop_Art',                     # 19
    'Contemporary_Realism',        # 20
    'Baroque',                     # 21
    'New_Realism',                 # 22
    'Pointillism',                 # 23
    'Northern_Renaissance',        # 24
    'Early_Renaissance',           # 25
    'Abstract_Expressionism',      # 26
]
# Parallel lists: model_predictions[i] and ground_truths[i] always describe
# the same image.
model_predictions, ground_truths = [], []

# zip() has no length, so tqdm is told the total explicitly; without it the
# progress bar cannot show a percentage or remaining time.
for test_image, test_label in tqdm(
    zip(test_images, test_labels), total=len(test_images)
):
    test_image_input = Image.load_from_file(test_image)
    # Prompt order: class list, then the image, then the task statement.
    prompt = [
        system_instructions,
        test_image_input,
        task_prompt
    ]

    try:
        response = multimodal_model.generate_content(
            prompt,
            generation_config=art_classification_generation_config
        )
        # Resolve both values before touching either list: if anything here
        # raises, neither list grows, so the two can never fall out of step.
        prediction = response.text.strip()
        ground_truth = dataset_labels[test_label]
    except Exception as e:
        # Best effort: report the failure and leave this image out of the
        # score instead of aborting the whole run.
        print(e)
        continue

    model_predictions.append(prediction)
    ground_truths.append(ground_truth)

# Accuracy is the share of exact string matches among the images that were
# scored (skipped images do not count).
if model_predictions:
    accuracy = (np.array(model_predictions) == np.array(ground_truths)).mean()
else:
    # Nothing was scored; the mean of an empty array would emit a
    # RuntimeWarning, so the undefined result is reported as nan directly.
    accuracy = float("nan")
print(accuracy)