-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathasset_library.py
498 lines (390 loc) · 19.4 KB
/
asset_library.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
import tkinter as tk
from tkinter import ttk, filedialog, scrolledtext, messagebox
from PIL import Image, ImageTk
import boto3
import tempfile
import os
from tkinter import font as tkFont
from tkinter import scrolledtext
import customtkinter as ctk
import pandas as pd
import torch
import torchvision.transforms as transforms
import time
import traceback
import torchvision.models as models
from PIL import Image
from neo4j import GraphDatabase
from scipy import spatial
import numpy as np
from botocore.exceptions import NoCredentialsError
# Neo4j connection settings. They are read from the environment so credentials do
# not have to be written into (and committed with) the source; each one falls back
# to an empty string when its variable is unset.
uri = os.environ.get("NEO4J_URI", "")
username = os.environ.get("NEO4J_USERNAME", "")
password = os.environ.get("NEO4J_PASSWORD", "")
class Neo4jDatabase:
    """Thin wrapper around a Neo4j driver used to store ``Image`` nodes.

    Connection settings default to the module-level ``uri``, ``username`` and
    ``password``. The instance owns a driver, so call :meth:`close` when finished,
    or use it as a context manager (``with Neo4jDatabase() as db: ...``).
    """

    def __init__(self, db_uri=None, db_username=None, db_password=None):
        """Open a driver; any argument left as ``None`` uses the module-level setting."""
        self._driver = GraphDatabase.driver(
            uri if db_uri is None else db_uri,
            auth=(
                username if db_username is None else db_username,
                password if db_password is None else db_password,
            ),
        )

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        # Always release the driver; returning False lets any exception propagate.
        self.close()
        return False

    def close(self):
        """Close the underlying driver."""
        self._driver.close()

    def create_image_node(self, embeddings, name, image_type):
        """Create one ``Image`` node inside a write transaction.

        Args:
            embeddings: stored as the node's ``embeddings`` property.
            name: stored as the node's ``name`` property.
            image_type: stored as the node's ``type`` property.
        """
        with self._driver.session() as session:
            session.execute_write(self._create_image_node, embeddings, name, image_type)

    @staticmethod
    def _create_image_node(tx, embeddings, name, image_type):
        """Transaction function: run the parameterised CREATE for a single node."""
        query = (
            "CREATE (img:Image {embeddings: $embeddings,name: $name,type: $type})"
        )
        # Values are sent as query parameters, never spliced into the Cypher text.
        tx.run(query, embeddings=embeddings, name=name, type=image_type)
class AWSApp:
    """Main application window with an upload tab and a similarity-search tab."""

    def __init__(self, master):
        self.master = master
        self._configure_window()
        self.notebook = self._build_notebook()

        # Placeholders only: each tab creates and keeps its own AWS clients.
        self.s3_client = self.rekognition_client = self.dynamodb_client = None
        self.uploaded_image_path = None

        self.upload_frame = UploadFrame(self.notebook)
        self.notebook.add(self.upload_frame.frame, text="Upload Images")
        self.search_frame = SearchFrame(self.notebook)
        self.notebook.add(self.search_frame.frame, text="Find Similar Images")

        # Start on the upload tab.
        self.notebook.select(self.upload_frame.frame)

    def _configure_window(self):
        """Title the window and size it to cover the whole screen."""
        window = self.master
        window.title("AWS Image App")
        width, height = window.winfo_screenwidth(), window.winfo_screenheight()
        window.geometry(f"{width}x{height}")

    def _build_notebook(self):
        """Style the tab headers, then create and pack the notebook on the left."""
        ttk.Style().configure("TNotebook.Tab", padding=(10, 8), font=('Arial', 14), background='#121212')
        notebook = ttk.Notebook(self.master)
        notebook.pack(side=tk.LEFT, fill=tk.BOTH, expand=True)
        return notebook
class UploadFrame:
    """The "Upload Images" tab: uploads images and FBX files and indexes the images.

    Images are stored in S3 under ``images/<file name>`` and FBX files under
    ``fbx/<file name>`` in the ``dns-assets`` bucket. Each image is also embedded
    with a ResNet-50 backbone and recorded in Neo4j together with the image type
    typed by the user.
    """

    BUCKET_NAME = 'dns-assets'

    # Preprocessing pipeline and ResNet-50 backbone, built on first use and shared
    # by every instance instead of being reloaded for each image.
    _preprocess = None
    _model = None

    def __init__(self, notebook):
        self.notebook = notebook
        # AWS clients are created once the user loads a credentials CSV.
        self.s3_client = None
        self.dynamodb_client = None
        self.frame = tk.Frame(self.notebook, bg='#121212')
        self.frame.pack(fill=tk.BOTH, expand=True)
        helv36 = tkFont.Font(family='Courier', size=18, weight='bold')
        button_font = ctk.CTkFont(family='Courier', size=24, weight='bold')

        # AWS credentials
        tk.Label(self.frame, text="Select CSV File:", font=helv36, bg="#000",
                 fg="#fff").pack(pady=10)
        self._add_button("Browse Access Key", self.load_aws_info, button_font)

        self.image_path_label = tk.Label(self.frame, text="Image Path:", bg='#121212', fg='white')
        self.image_path_label.pack(pady=(100, 0))
        # Paths picked in the file dialogs; None (or empty) until files are selected.
        self.uploaded_image_paths = None
        self.uploaded_fbx_paths = None

        # Image selection and type
        self._add_button("Upload Image", self.browse_image, button_font, pady=(0, 10))
        self.image_type_label = tk.Label(self.frame, text="Image Type:", bg='#121212', fg='white')
        self.image_type_label.pack()
        self.image_type_entry = tk.Entry(self.frame, bg='#484848', fg='white')
        self.image_type_entry.pack(pady=10)
        self.progress_label = tk.Label(self.frame, text="Progress: ", bg='#121212', fg='white')
        self.progress_label.pack(pady=10)

        # FBX selection and submission
        self._add_button("Upload FBX", self.browse_fbx, button_font, pady=(0, 10))
        self._add_button("Submit", self.submit_image, button_font, pady=30)

    def _add_button(self, text, command, font, **pack_options):
        """Create one of the tab's purple action buttons and pack it with *pack_options*."""
        button = ctk.CTkButton(self.frame, text=text, command=command,
                               font=font, corner_radius=10, fg_color='#bb86fc', text_color='#000',
                               hover_color='#a435f0')
        button.pack(**pack_options)
        return button

    def load_aws_info(self):
        """Ask for the credentials CSV and build the AWS clients from it."""
        file_path = filedialog.askopenfilename(title="Select CSV File", filetypes=[("CSV files", "*.csv")])
        if file_path:
            aws_info = self.read_aws_info_from_csv(file_path)
            self.initialize_aws_clients(aws_info)

    def read_aws_info_from_csv(self, file_path):
        """Return access key (row 0, column 0), secret key (row 0, column 1) and a fixed region."""
        creds = pd.read_csv(file_path)
        aws_info = {
            'AccessKey': creds.iloc[0, 0],
            'SecretKey': creds.iloc[0, 1],
            'Region': 'ap-south-1',
        }
        print('Info uploaded')
        return aws_info

    def initialize_aws_clients(self, aws_info):
        """Create the S3 client and report success or failure in a dialog."""
        try:
            self.s3_client = boto3.client('s3', aws_access_key_id=aws_info['AccessKey'],
                                          aws_secret_access_key=aws_info['SecretKey'],
                                          region_name=aws_info['Region'])
            messagebox.showinfo("Success", "AWS clients initialized successfully.")
        except Exception as e:
            messagebox.showerror("Error", f"Failed to initialize AWS clients. {str(e)}")

    def browse_image(self):
        """Let the user pick image files and show how many were chosen."""
        file_paths = filedialog.askopenfilenames()
        self.image_path_label.config(text=f"Number of images selected: {len(file_paths)}")
        self.uploaded_image_paths = file_paths

    def browse_fbx(self):
        """Let the user pick FBX files and show how many were chosen."""
        file_paths = filedialog.askopenfilenames()
        self.image_path_label.config(text=f"Number of FBX files selected: {len(file_paths)}")
        self.uploaded_fbx_paths = file_paths

    def submit_image(self):
        """Upload the selected images and FBX files, then report the outcome.

        Images need a non-empty image type. Each image is uploaded to S3, embedded,
        and recorded in Neo4j; FBX files are only uploaded to S3. The selections are
        cleared after an upload attempt, so the same files are not sent twice.
        "Success" is shown only when every upload succeeded.
        """
        if self.s3_client is None:
            messagebox.showerror("Error", "AWS client not initialized. Please load AWS info first.")
            return
        # The file dialogs may return '' or an empty tuple when cancelled.
        image_paths = self.uploaded_image_paths or ()
        fbx_paths = self.uploaded_fbx_paths or ()
        if not image_paths and not fbx_paths:
            messagebox.showerror("Error", "Please choose files to upload first.")
            return

        all_ok = True
        if image_paths:
            image_type = self.image_type_entry.get()
            if not image_type:
                messagebox.showerror("Error", "Please enter an image type.")
                return
            all_ok = self._upload_images(image_paths, image_type) and all_ok
            self.uploaded_image_paths = None
        if fbx_paths:
            all_ok = self._upload_fbx(fbx_paths) and all_ok
            self.uploaded_fbx_paths = None

        if all_ok:
            messagebox.showinfo("Success", "Files uploaded successfully!")

    def _upload_images(self, image_paths, image_type):
        """Upload, embed and index each image; stop at the first error and return False."""
        total = len(image_paths)
        self._set_progress(0, total)
        try:
            for done, image_path in enumerate(image_paths, start=1):
                image_name = os.path.basename(image_path)
                self.s3_client.upload_file(image_path, self.BUCKET_NAME, f'images/{image_name}')
                image_vector = self.vectorize_image(image_path)
                self.upload_to_neo4j(image_name, image_vector, image_type)
                self._set_progress(done, total)
        except Exception as e:
            messagebox.showerror("Error", str(e))
            return False
        return True

    def _upload_fbx(self, fbx_paths):
        """Upload each FBX file under ``fbx/``; stop at the first error and return False."""
        total = len(fbx_paths)
        self._set_progress(0, total)
        try:
            for done, fbx_path in enumerate(fbx_paths, start=1):
                s3_key_fbx = f'fbx/{os.path.basename(fbx_path)}'
                self.s3_client.upload_file(fbx_path, self.BUCKET_NAME, s3_key_fbx)
                self._set_progress(done, total)
        except Exception as e:
            messagebox.showerror("Error", str(e))
            return False
        return True

    def _set_progress(self, done, total):
        """Show *done* of *total* in the progress label and repaint it right away."""
        self.progress_label.config(text=f"Progress: {done} of {total} files uploaded")
        # Uploads run on the Tk event thread; without flushing idle tasks the label
        # would only change after the whole batch has finished.
        self.progress_label.update_idletasks()

    @classmethod
    def _load_backbone(cls):
        """Return the shared (preprocess, model) pair, building it on first use."""
        if cls._model is None:
            cls._preprocess = transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ])
            resnet = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
            # Drop the final classification layer; the pooled features are the embedding.
            cls._model = torch.nn.Sequential(*list(resnet.children())[:-1])
        return cls._preprocess, cls._model

    def vectorize_image(self, image_path):
        """Return the ResNet-50 embedding of the image at *image_path* as a NumPy array.

        NOTE(review): the model stays in training mode (the ``nn.Module`` default), so
        BatchNorm normalises with this single image's statistics rather than the
        pretrained running statistics. Reusing the cached model is therefore
        output-preserving, because in training mode the running statistics it updates
        are never read. The search tab embeds its queries the same way, and the nodes
        already in Neo4j were presumably indexed like this. Switching to
        ``model.eval()`` must be done on both paths together and followed by
        re-indexing.
        """
        preprocess, model = self._load_backbone()
        with Image.open(image_path) as img:
            batch = preprocess(img.convert('RGB')).unsqueeze(0)  # add batch dimension
        with torch.no_grad():
            embeddings = model(batch)
        return embeddings.squeeze().numpy()

    def upload_to_neo4j(self, image_name, embedding, image_type):
        """Store one image's embedding in Neo4j, always closing the connection."""
        neo4j_db = Neo4jDatabase()
        try:
            neo4j_db.create_image_node(embedding, image_name, image_type)
        finally:
            neo4j_db.close()
class SearchFrame:
    """The "Find Similar Images" tab: embeds a query image and lists the closest matches.

    Stored ``Image`` nodes of the entered type are ranked by cosine similarity to the
    query's ResNet-50 embedding. Each match gets a button that downloads the image
    from the ``dns-assets`` bucket.
    """

    BUCKET_NAME = 'dns-assets'
    # How many of the best matches are listed for a query.
    MAX_RESULTS = 10

    # Preprocessing pipeline and ResNet-50 backbone, built on first use and shared
    # by every instance instead of being reloaded for each search.
    _preprocess = None
    _model = None

    def __init__(self, notebook):
        self.notebook = notebook
        self.frame = tk.Frame(self.notebook, bg='#121212')
        self.frame.pack(fill=tk.BOTH, expand=True)
        helv36 = tkFont.Font(family='Courier', size=18, weight='bold')
        button_font = ctk.CTkFont(family='Courier', size=24, weight='bold')
        # AWS Information Entry
        tk.Label(self.frame, text="Select CSV File:", font=helv36, bg="#000",
                 fg="#fff").pack(pady=10)
        ctk.CTkButton(self.frame, text="Browse Access Key", command=self.load_aws_info,
                      font=button_font, corner_radius=10, fg_color='#bb86fc', text_color='#000',
                      hover_color='#a435f0').pack()
        self.image_type_label = tk.Label(self.frame, text="Image Type:", bg='#121212', fg='white')
        self.image_type_label.pack(pady=(100, 0))
        self.image_type_entry = tk.Entry(self.frame, bg='#484848', fg='white')
        self.image_type_entry.pack(pady=(0, 0))
        ctk.CTkButton(self.frame, text="Search Image", command=self.search_image,
                      font=button_font, corner_radius=10, fg_color='#bb86fc', text_color='#000',
                      hover_color='#a435f0').pack(pady=(10, 10))
        # Result display
        self.helv12 = tkFont.Font(family='Courier', size=12)
        self.result_frame = tk.Frame(self.frame, bg='#121212')
        self.result_frame.pack()
        self.result_text = scrolledtext.ScrolledText(self.frame, wrap=tk.WORD, bg='#121212', fg='white', insertbackground='white', selectbackground='#444', selectforeground='white', font=self.helv12)
        self.result_text.pack()
        # AWS clients are created once the user loads a credentials CSV.
        self.s3_client = None
        self.rekognition_client = None
        # Path of the most recent query image.
        self.uploaded_image_path = None
        # Flat, dark styling for ttk buttons.
        style = ttk.Style()
        style.configure('TButton', borderwidth=0, focuscolor='#121212', lightcolor='#121212', darkcolor='#121212', relief='flat', background='#444', foreground='white', padding=10, font=('Arial', 10))
        style.map('TButton', background=[('active', '#555')])

    def load_aws_info(self):
        """Ask for the credentials CSV and build the AWS clients from it."""
        file_path = filedialog.askopenfilename(title="Select CSV File", filetypes=[("CSV files", "*.csv")])
        if file_path:
            aws_info = self.read_aws_info_from_csv(file_path)
            self.initialize_aws_clients(aws_info)

    def read_aws_info_from_csv(self, file_path):
        """Return access key (row 0, column 0), secret key (row 0, column 1) and a fixed region."""
        creds = pd.read_csv(file_path)
        aws_info = {
            'AccessKey': creds.iloc[0, 0],
            'SecretKey': creds.iloc[0, 1],
            'Region': 'ap-south-1',
        }
        print('Info. uploaded and initialized')
        return aws_info

    def initialize_aws_clients(self, aws_info):
        """Create the S3 and Rekognition clients and report the outcome in a dialog."""
        try:
            client_options = dict(aws_access_key_id=aws_info['AccessKey'],
                                  aws_secret_access_key=aws_info['SecretKey'],
                                  region_name=aws_info['Region'])
            self.s3_client = boto3.client('s3', **client_options)
            # Used by get_image_labels, which previously referenced a client that was never created.
            self.rekognition_client = boto3.client('rekognition', **client_options)
            messagebox.showinfo("Success", "AWS clients initialized successfully.")
        except Exception as e:
            messagebox.showerror("Error", f"Failed to initialize AWS clients. {str(e)}")

    def search_image(self):
        """Ask for a query image and list the most similar stored images of the entered type."""
        if self.s3_client is None:
            messagebox.showerror("Error", "AWS client not initialized. Please load AWS info first.")
            return
        image_type = self.image_type_entry.get()
        file_path = filedialog.askopenfilename()
        if not file_path:
            return
        self.uploaded_image_path = file_path
        embeddings = self._embed(file_path)
        similar_images = self.search_similar_images(embeddings, image_type)
        # search_similar_images orders names from most to least similar.
        self.display_results(list(similar_images)[:self.MAX_RESULTS])

    @classmethod
    def _load_backbone(cls):
        """Return the shared (preprocess, model) pair, building it on first use."""
        if cls._model is None:
            cls._preprocess = transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ])
            resnet = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
            # Drop the final classification layer; the pooled features are the embedding.
            cls._model = torch.nn.Sequential(*list(resnet.children())[:-1])
        return cls._preprocess, cls._model

    def _embed(self, file_path):
        """Return the ResNet-50 embedding of the image at *file_path* as a NumPy array.

        NOTE(review): the model stays in training mode (the ``nn.Module`` default), so
        BatchNorm uses this single image's statistics. This must match how the stored
        embeddings were produced; switching to ``model.eval()`` requires re-indexing
        and the same change on the upload path.
        """
        preprocess, model = self._load_backbone()
        with Image.open(file_path) as img:
            batch = preprocess(img.convert('RGB')).unsqueeze(0)  # add batch dimension
        with torch.no_grad():
            embeddings = model(batch)
        return embeddings.squeeze().numpy()

    def get_image_labels(self, file_path):
        """Return the label names Amazon Rekognition detects in the image at *file_path*.

        Raises:
            RuntimeError: if AWS info has not been loaded yet.
        """
        if self.rekognition_client is None:
            raise RuntimeError("Rekognition client not initialized. Please load AWS info first.")
        with open(file_path, 'rb') as image_file:
            image_bytes = image_file.read()
        response = self.rekognition_client.detect_labels(Image={'Bytes': image_bytes})
        return [label['Name'] for label in response['Labels']]

    def search_similar_images(self, embeddings, image_type):
        """Return ``{image name: similarity}`` for stored images of *image_type*.

        The connection uses the module-level ``uri``/``username``/``password``
        settings and is always closed. Nodes without stored embeddings are skipped.
        The result is ordered from most to least similar (see ``_rank_by_similarity``).
        """
        query = (
            """
MATCH (n:Image)
where n.type=$image_type
RETURN n.embeddings as embedding, n.name as path;
"""
        )
        driver_con = GraphDatabase.driver(uri, auth=(username, password))
        try:
            with driver_con.session() as session:
                result = session.run(query, image_type=image_type)
                # Records must be consumed while the session is still open.
                images = {record['path']: record['embedding']
                          for record in result if record['embedding'] is not None}
        finally:
            driver_con.close()
        return self._rank_by_similarity(embeddings, images)

    @staticmethod
    def _rank_by_similarity(query, candidates):
        """Score each candidate embedding against *query*, most similar first.

        Args:
            query: embedding of the query image.
            candidates: mapping of image name to stored embedding.

        Returns:
            dict mapping image name to ``1 - cosine distance``, in descending order.
        """
        scores = {name: 1 - spatial.distance.cosine(query, vector)
                  for name, vector in candidates.items()}
        return dict(sorted(scores.items(), key=lambda item: item[1], reverse=True))

    def create_button(self, image_id):
        """Return a button (child of the result text) that downloads *image_id*."""
        download_button = ctk.CTkButton(self.result_text, text=f"Download Image {image_id}",
                                        command=lambda i=image_id: self.download_image(i),
                                        font=('Helvetica', 12), fg_color='#ac7ed7',
                                        text_color='black', hover_color='#a435f0')
        return download_button

    def display_results(self, similar_images):
        """Replace the result list with one line and one download button per image name."""
        for widget in self.result_frame.winfo_children():
            widget.destroy()
        # The text widget is left read-only after each render; re-enable it so the
        # previous results can be cleared and new ones inserted.
        self.result_text.config(state=tk.NORMAL)
        for widget in self.result_text.winfo_children():
            widget.destroy()  # download buttons embedded by the previous search
        self.result_text.delete('1.0', tk.END)
        if similar_images:
            for image_id in similar_images:
                self.result_text.insert(tk.END, f"POID: {image_id}\n")
                self.result_text.window_create(tk.END, window=self.create_button(image_id))
                self.result_text.insert(tk.END, '\n\n\n')
        else:
            self.result_text.insert(tk.END, "No similar images found.")
        self.result_text.config(state=tk.DISABLED)

    def download_image(self, image_id):
        """Download ``images/<image_id>`` to the temp directory and open it.

        Failures other than a missing object are retried up to three times with
        delays of 2, 4 and 8 seconds. The sleeps run on the Tk thread, so the window
        is unresponsive while retrying.
        """
        if self.s3_client is None:
            messagebox.showerror("Error", "AWS S3 client not initialized.")
            return
        object_key = f"images/{image_id}"
        local_path = os.path.join(tempfile.gettempdir(), f"{image_id}")
        try:
            self._download_with_retries(self.BUCKET_NAME, object_key, local_path)
            self._open_with_default_viewer(local_path)
        except Exception as e:
            messagebox.showerror("Error", f"Failed to download image. {str(e)}")

    def _download_with_retries(self, bucket_name, object_key, local_path, retries=3, delay=2):
        """Download one S3 object, retrying with doubling delays; re-raise the last error."""
        from botocore.exceptions import ClientError

        for attempt in range(retries + 1):
            try:
                self.s3_client.download_file(bucket_name, object_key, local_path)
                return
            except ClientError as e:
                # A missing object will not appear on retry, so fail immediately.
                if e.response.get('Error', {}).get('Code') in ('404', 'NoSuchKey'):
                    raise
                if attempt == retries:
                    raise
                print(f"Download of {object_key} failed ({e}); retry {attempt + 1} in {delay}s")
            except Exception as e:
                if attempt == retries:
                    raise
                print(f"Download of {object_key} failed ({e}); retry {attempt + 1} in {delay}s")
            time.sleep(delay)
            delay *= 2

    @staticmethod
    def _open_with_default_viewer(path):
        """Open *path* with the OS default application (``os.startfile`` is Windows-only)."""
        if hasattr(os, 'startfile'):
            os.startfile(path)
            return
        import subprocess
        import sys
        opener = 'open' if sys.platform == 'darwin' else 'xdg-open'
        subprocess.Popen([opener, path])
if __name__ == "__main__":
    # Create the Tk root, give it the dark background, attach the app, and run.
    root = tk.Tk()
    root.configure(bg='#121212')
    app = AWSApp(root)
    root.mainloop()