diff --git a/anomaly_detection/padim/README.md b/anomaly_detection/padim/README.md index f0836d9cf..5df56d8eb 100644 --- a/anomaly_detection/padim/README.md +++ b/anomaly_detection/padim/README.md @@ -71,8 +71,12 @@ You can specify the directory of normal product files with the `--train_dir` opt $ python3 padim.py --train_dir train ``` -The feature vectors created from files in the train directory are saved to the pickle file. -From the second time, by specifying the pickle file by `--feat` option, +The feature vectors created from files in the train directory can be saved to a pickle, NumPy or PyTorch file. To choose the format, pass `pkl`, `npy` or `pt` to the `--save_format` option. Depending on the `--enable_optimization` flag, the saved data may be PyTorch tensors (True) or NumPy arrays (False). +```bash +$ python3 padim.py --save_format pt +``` + +From the second time on, by specifying the saved file with the `--feat` option, it can omit the calculation of the feature vector of the normal product. The name of the pickle file created is the same as the name of a normal product file directory. ```bash @@ -104,6 +108,20 @@ By adding the `--aug` option, you can process with augmentation. $ python3 padim.py --aug ``` +By adding the `--enable_optimization` option, you can use optimized code, which significantly speeds up the distance matrix calculation. This requires a CUDA-compatible GPU and PyTorch to be installed. +(default is processing without optimization) +```bash +$ python3 padim.py --enable_optimization True +``` + +By adding the `--compare_optimization` option, you can compare the output of the optimized code with the output of the original code. This requires a CUDA-compatible GPU and PyTorch to be installed. +(default is processing without comparison) +```bash +$ python3 padim.py --compare_optimization True +``` + + + ## PaDiM GUI You can also use the GUI to train and test. @@ -115,6 +133,7 @@ Start the GUI with the following command. 
```bash $ python3 padim_gui.py ``` +Specify the inference mode from the `Set optimization` button and choose the `Trained file format`. Specify the folder from the `Select train folder` button and press the `Train button`. @@ -123,6 +142,8 @@ Inference results are listed in Result images. Change the Threshold and press the `Test button` again +If you want to benchmark the inference, set `Benchmark mode` + ## Reference [PaDiM-Anomaly-Detection-Localization-master](https://github.com/xiahaifeng1995/PaDiM-Anomaly-Detection-Localization-master) diff --git a/anomaly_detection/padim/padim.py b/anomaly_detection/padim/padim.py index 646ba571f..38ab09bc1 100644 --- a/anomaly_detection/padim/padim.py +++ b/anomaly_detection/padim/padim.py @@ -49,8 +49,8 @@ help='arch model.' ) parser.add_argument( - '-f', '--feat', metavar="PICKLE_FILE", default=None, - help='train set feature pkl files.' + '-f', '--feat', metavar="FILE", default=None, + help='train set feature files.' ) parser.add_argument( '-bs', '--batch_size', default=32, @@ -80,8 +80,50 @@ '-an', '--aug_num', type=int, default=5, help='specify the amplification number of augmentation.' ) +parser.add_argument( + '-eon', '--enable_optimization', type=bool, default=False, + help='Flag to enable optimized code' +) +parser.add_argument( + '--compare_optimization', type=bool, default=False, + help='Flag to compare output of optimization with original code' +) +parser.add_argument( + '--compare_optimization', type=bool, default=False, + help='Flag to compare output of optimization with original code' +) + +parser.add_argument( + '--save_format', metavar="FILE", default="pkl", + help='chose training file format pt, npy or pkl.' 
+) + +parser.add_argument( + '--optimization_device', metavar="device", default='cpu', choices=('cpu', 'cuda', 'mps'), + help='chose optimization device' +) + args = update_parser(parser) +if args.compare_optimization: + args.enable_optimization = True + train_output_list=[] + +if args.enable_optimization: + import torch + if args.optimization_device=="cuda" and torch.cuda.is_available(): + device = torch.device("cuda") + + + elif args.optimization_device=="mps" and torch.backends.mps.is_available(): + device = torch.device("mps") + else: + device = torch.device("cpu") + + weights_torch=gaussian_kernel1d_torch(4, 0, int(4.0*float(4)+0.5), device).unsqueeze(0).unsqueeze(0).expand(1, 1, 33) + logger.info("Torch device : " + str(device)) + + # ====================== # Main functions @@ -101,7 +143,11 @@ def plot_fig(file_list, test_imgs, scores, anormal_scores, gt_imgs, threshold, s gt = gt.transpose(1, 2, 0).squeeze() else: gt = np.zeros((1,1,1)) - heat_map, mask, vis_img = visualize(img, scores[i], threshold) + if args.enable_optimization: + heat_map, mask, vis_img = visualize(img, scores[i].squeeze(0).cpu().numpy(), threshold) + else: + heat_map, mask, vis_img = visualize(img, scores[i], threshold) + fig_img, ax_img = plt.subplots(1, 5, figsize=(12, 3)) fig_img.subplots_adjust(right=0.9) @@ -152,21 +198,34 @@ def plot_fig(file_list, test_imgs, scores, anormal_scores, gt_imgs, threshold, s fig_img.savefig(savepath_tmp, dpi=100) plt.close() +def infer_init_run(net, params, train_outputs, IMAGE_SIZE): + import numpy as np + dummy_image = np.random.rand(1, 3, 224, 224) * 255.0 # Scale between 0 and 255 + # Convert the dtype to float32 for efficiency + dummy_image = dummy_image.astype(np.float32) + logger.info(f"PaDiM initialization inference starts!") + + if args.enable_optimization: + score = infer_optimized(net, params, train_outputs, dummy_image, IMAGE_SIZE, device, logger, weights_torch) + else: + score = infer(net, params, train_outputs, dummy_image, IMAGE_SIZE) + 
logger.info(f"PaDiM initialization inference finish!") + def train_from_image_or_video(net, params): # training - train_outputs = training(net, params, IMAGE_RESIZE, IMAGE_SIZE, KEEP_ASPECT, int(args.batch_size), args.train_dir, args.aug, args.aug_num, args.seed, logger) - + if args.enable_optimization: + train_outputs = training_optimized(net, params, IMAGE_RESIZE, IMAGE_SIZE, KEEP_ASPECT, int(args.batch_size), args.train_dir, args.aug, args.aug_num, args.seed, logger) + else: + train_outputs = training(net, params, IMAGE_RESIZE, IMAGE_SIZE, KEEP_ASPECT, int(args.batch_size), args.train_dir, args.aug, args.aug_num, args.seed, logger) # save learned distribution if args.feat: train_feat_file = args.feat else: train_dir = args.train_dir - train_feat_file = "%s.pkl" % os.path.basename(train_dir) - logger.info('saving train set feature to: %s ...' % train_feat_file) - with open(train_feat_file, 'wb') as f: - pickle.dump(train_outputs, f) - logger.info('saved.') + train_feat_file = str(os.path.basename(train_dir))+"."+str(args.save_format) + + train_outputs=_save_training_flie(train_feat_file, args.save_format, train_outputs) return train_outputs @@ -200,14 +259,20 @@ def decide_threshold_from_gt_image(net, params, train_outputs, gt_imgs): img = load_image(image_path) img = cv2.cvtColor(img, cv2.COLOR_BGRA2RGB) img = preprocess(img, IMAGE_RESIZE, keep_aspect=KEEP_ASPECT, crop_size = IMAGE_SIZE) + if args.enable_optimization: + dist_tmp = infer_optimized(net, params, train_outputs, img, IMAGE_SIZE, device,logger, weights_torch) + else: + dist_tmp = infer(net, params, train_outputs, img, IMAGE_SIZE) - dist_tmp = infer(net, params, train_outputs, img, IMAGE_SIZE) score_map.append(dist_tmp) + if args.enable_optimization: + scores = normalize_scores_torch(score_map, IMAGE_SIZE) + threshold = decide_threshold(scores.cpu().numpy(), gt_imgs) - scores = normalize_scores(score_map, IMAGE_SIZE) - - threshold = decide_threshold(scores, gt_imgs) + else: + scores = 
normalize_scores_torch(score_map, IMAGE_SIZE) + threshold = decide_threshold(scores, gt_imgs) return threshold @@ -219,6 +284,7 @@ def infer_from_image(net, params, train_outputs, threshold, gt_imgs): test_imgs = [] score_map = [] + infer_init_run(net, params, train_outputs, IMAGE_SIZE) for i_img in range(0, len(args.input)): logger.info('from (%s) ' % (args.input[i_img])) @@ -228,25 +294,50 @@ def infer_from_image(net, params, train_outputs, threshold, gt_imgs): img = preprocess(img, IMAGE_RESIZE, keep_aspect=KEEP_ASPECT, crop_size = IMAGE_SIZE) test_imgs.append(img[0]) - + if args.benchmark: logger.info('BENCHMARK mode') total_time = 0 - for i in range(args.benchmark_count): - start = int(round(time.time() * 1000)) - dist_tmp = infer(net, params, train_outputs, img, IMAGE_SIZE) - end = int(round(time.time() * 1000)) - logger.info(f'\tailia processing time {end - start} ms') - if i != 0: - total_time = total_time + (end - start) - logger.info(f'\taverage time {total_time / (args.benchmark_count - 1)} ms') + if args.enable_optimization: + for i in range(args.benchmark_count): + start = int(round(time.time() * 1000)) + dist_tmp = infer_optimized(net, params, train_outputs, img, IMAGE_SIZE, device,logger, weights_torch) + end = int(round(time.time() * 1000)) + logger.info(f'\tailia processing time {end - start} ms') + if i != 0: + total_time = total_time + (end - start) + logger.info(f'\taverage time {total_time / (args.benchmark_count - 1)} ms') + else: + for i in range(args.benchmark_count): + start = int(round(time.time() * 1000)) + dist_tmp = infer(net, params, train_outputs, img, IMAGE_SIZE) + end = int(round(time.time() * 1000)) + logger.info(f'\tailia processing time {end - start} ms') + if i != 0: + total_time = total_time + (end - start) + logger.info(f'\taverage time {total_time / (args.benchmark_count - 1)} ms') + if args.compare_optimization: + logger.info(f'\tResults of optimized and original code is the same: {np.allclose(infer(net, params, 
train_output_list[0], img, IMAGE_SIZE), infer_optimized(net, params, train_output_list[1], img, IMAGE_SIZE, device,logger, weights_torch))}') + + else: - dist_tmp = infer(net, params, train_outputs, img, IMAGE_SIZE) + if args.enable_optimization: + dist_tmp = infer_optimized(net, params, train_outputs, img, IMAGE_SIZE, device,logger, weights_torch) + else: + dist_tmp = infer(net, params, train_outputs, img, IMAGE_SIZE) + if args.compare_optimization: + logger.info('Results of optimized and original code is the same: '+ str(np.allclose(infer(net, params, train_output_list[0], img, IMAGE_SIZE), infer_optimized(net, params, train_output_list[1], img, IMAGE_SIZE, device,logger, weights_torch)))) + score_map.append(dist_tmp) + if args.enable_optimization: + scores = normalize_scores_torch(score_map, IMAGE_SIZE) + anormal_scores = calculate_anormal_scores_torch(score_map, IMAGE_SIZE) + + else: + scores = normalize_scores(score_map, IMAGE_SIZE) + anormal_scores = calculate_anormal_scores(score_map, IMAGE_SIZE) - scores = normalize_scores(score_map, IMAGE_SIZE) - anormal_scores = calculate_anormal_scores(score_map, IMAGE_SIZE) # Plot gt image plot_fig(args.input, test_imgs, scores, anormal_scores, gt_imgs, threshold, args.savepath) @@ -264,29 +355,56 @@ def infer_from_video(net, params, train_outputs, threshold): score_map = [] frame_shown = False - while(True): - ret, frame = capture.read() - if (cv2.waitKey(1) & 0xFF == ord('q')) or not ret: - break - if frame_shown and cv2.getWindowProperty('frame', cv2.WND_PROP_VISIBLE) == 0: - break + infer_init_run(net, params, train_outputs, IMAGE_SIZE) + + if args.enable_optimization: + while(True): + ret, frame = capture.read() + if (cv2.waitKey(1) & 0xFF == ord('q')) or not ret: + break + if frame_shown and cv2.getWindowProperty('frame', cv2.WND_PROP_VISIBLE) == 0: + break - img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) - img = preprocess(img, IMAGE_RESIZE, keep_aspect=KEEP_ASPECT) + img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + 
img = preprocess(img, IMAGE_RESIZE, keep_aspect=KEEP_ASPECT) - dist_tmp = infer(net, params, train_outputs, img) + dist_tmp = infer_optimized(net, params, train_outputs, img, device,logger, weights_torch) - score_map.append(dist_tmp) - scores = normalize_scores(score_map) # min max is calculated dynamically, please set fixed min max value from calibration data for production + score_map.append(dist_tmp) + scores = normalize_scores(score_map) # min max is calculated dynamically, please set fixed min max value from calibration data for production + + heat_map, mask, vis_img = visualize(denormalization(img[0]), scores[len(scores)-1], threshold) + frame = pack_visualize(heat_map, mask, vis_img, scores, IMAGE_SIZE) + + cv2.imshow('frame', frame) + frame_shown = True + + if writer is not None: + writer.write(frame) + else: + while(True): + ret, frame = capture.read() + if (cv2.waitKey(1) & 0xFF == ord('q')) or not ret: + break + if frame_shown and cv2.getWindowProperty('frame', cv2.WND_PROP_VISIBLE) == 0: + break + + img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + img = preprocess(img, IMAGE_RESIZE, keep_aspect=KEEP_ASPECT) + + dist_tmp = infer(net, params, train_outputs, img) - heat_map, mask, vis_img = visualize(denormalization(img[0]), scores[len(scores)-1], threshold) - frame = pack_visualize(heat_map, mask, vis_img, scores, IMAGE_SIZE) + score_map.append(dist_tmp) + scores = normalize_scores(score_map) # min max is calculated dynamically, please set fixed min max value from calibration data for production - cv2.imshow('frame', frame) - frame_shown = True + heat_map, mask, vis_img = visualize(denormalization(img[0]), scores[len(scores)-1], threshold) + frame = pack_visualize(heat_map, mask, vis_img, scores, IMAGE_SIZE) - if writer is not None: - writer.write(frame) + cv2.imshow('frame', frame) + frame_shown = True + + if writer is not None: + writer.write(frame) capture.release() cv2.destroyAllWindows() @@ -295,15 +413,14 @@ def infer_from_video(net, params, 
train_outputs, threshold): def train_and_infer(net, params): + timestart=time.time() if args.feat: - logger.info('loading train set feature from: %s' % args.feat) - with open(args.feat, 'rb') as f: - train_outputs = pickle.load(f) + train_outputs=_load_training_file(args.feat, args.save_format) logger.info('loaded.') else: train_outputs = train_from_image_or_video(net, params) - if args.threshold is None: + if args.threshold is None: if args.video: threshold = 0.5 gt_imgs = None @@ -324,10 +441,170 @@ def train_and_infer(net, params): infer_from_image(net, params, train_outputs, threshold, gt_imgs) logger.info('Script finished successfully.') +def _save_training_flie(train_feat_file, save_format, train_outputs): + if args.compare_optimization: + train_output_list.append(train_outputs) + + if not args.enable_optimization: + if save_format == "pkl" : + if train_feat_file==None: + train_feat_file = "train.pkl" + logger.info('Saving train set feature to: %s ...' % train_feat_file) + with open(train_feat_file, 'wb') as f: + pickle.dump(train_outputs, f) + logger.info('Saved.') + elif save_format == "npy" : + filename=train_feat_file.split(".")[0].strip() + for i, output in enumerate(train_outputs): + + if train_feat_file==None: + train_feat_file = "train_output_"+str(i)+".npy" + else: + train_feat_file = filename+str("_")+str(i)+".npy" + + logger.info('Saving train set feature to: %s ...' % train_feat_file) + np.save(f"{filename}_{i}.npy", output) + logger.info('Saved.') + elif save_format == "pt": + if train_feat_file==None: + train_feat_file = "train.pt" + logger.info('Saving train set feature to: %s ...' 
% train_feat_file) + torch.save(train_outputs, train_feat_file) + logger.info('Saved.') + + + else: + if save_format=="npy": + filename=train_feat_file.split(".")[0].strip() + for i, output in enumerate(train_outputs): + if train_feat_file==None: + train_feat_file = "train_output_"+str(i)+".npy" + else: + train_feat_file = filename+"_"+str(i)+".npy" + logger.info('Saving train set feature to: %s ...' % train_feat_file) + np.save(f"{filename}_{i}.npy", output) + logger.info('Saved.') + + train_outputs=[torch.from_numpy(train_outputs[0]).float().to(device), train_outputs[1], + torch.from_numpy(train_outputs[2]).float().to(device), train_outputs[3] ] + if save_format == "pkl" : + if train_feat_file==None: + train_feat_file = "trainOptimized.pkl" + logger.info('saving train set feature to: %s ...' % train_feat_file) + with open(train_feat_file, 'wb') as f: + pickle.dump(train_outputs, f) + logger.info('Saved.') + elif save_format == "pt": + if train_feat_file==None: + train_feat_file = "trainOptimized.pt" + logger.info('Saving train set feature to: %s ...' 
% train_feat_file) + torch.save(train_outputs, train_feat_file) + logger.info('Saved.') + + + if args.compare_optimization: + train_output_list.append(train_outputs) + + return train_outputs + +def _load_training_file(train_feat_file, save_format): + if _check_file_exists(train_feat_file, save_format): + if not train_feat_file: + train_feat_file = "trainOptimized."+save_format + else: + save_format=train_feat_file.split(".")[1].strip() + logger.info(f"Save format {save_format}") + + if args.enable_optimization: + + + if save_format== "pkl": + logger.info(f"Loading {train_feat_file}") + with open(train_feat_file, 'rb') as f: + train_outputs = pickle.load(f) + train_outputs=[train_outputs[0].to(device), train_outputs[1], + train_outputs[2].to(device), train_outputs[3] ] + elif save_format == "npy": + train_outputs = [] + i = 0 + if train_feat_file: + train_feat_file=train_feat_file.split("_")[0].strip() + else: + train_feat_file="train_" + + while True: + try: + + logger.info(f"{train_feat_file}_{i}.npy") + train_outputs.append(np.load(f"{train_feat_file}_{i}.npy", allow_pickle=True)) + i += 1 + except FileNotFoundError: + break # Stop when there are no more files to load + train_outputs=[torch.from_numpy(train_outputs[0]).float().to(device), train_outputs[1], + torch.from_numpy(train_outputs[2]).float().to(device), train_outputs[3] ] + elif save_format == "pt": + logger.info(f"Loading {train_feat_file}") + train_outputs = torch.load(train_feat_file) + else: + + if save_format == "pkl": + train_feat_file = "train."+save_format + logger.info(f"Loading {train_feat_file}") + with open(train_feat_file, 'rb') as f: + train_outputs = pickle.load(f) + elif save_format == "npy": + train_outputs = [] + i = 0 + if train_feat_file: + train_feat_file=train_feat_file.split("_")[0].strip() + else: + train_feat_file="train_" + while True: + try: + + logger.info(f"{train_feat_file}_{i}.npy") + train_outputs.append(np.load(f"{train_feat_file}_{i}.npy", allow_pickle=True)) + i += 1 + 
except FileNotFoundError: + break # Stop when there are no more files to load + + elif save_format == "pt": + train_feat_file = "train."+save_format + logger.info(f"Loading {train_feat_file}") + train_outputs = torch.load(train_feat_file) + train_outputs_numpy = [] + if train_outputs[0] is torch.Tensor: + for item in train_outputs: + if isinstance(item, torch.Tensor): + train_outputs_numpy.append(item.cpu().numpy()) # Move to CPU and convert to NumPy + else: + train_outputs_numpy.append(item) + return train_outputs_numpy + return train_outputs + else: + logger.info(f"filename {train_feat_file} have not been found") + + +def _check_file_exists(train_feat_file, save_format): + if train_feat_file == None: + if save_format=="npy": + filename="train_output_0.npy" + elif args.enable_optimization: + filename="trainOptimized."+save_format + else: + filename="train."+save_format + if not os.path.isfile(filename): + logger.info(f"File {filename} does not exist. Unable to load the model") + else: + return os.path.isfile(train_feat_file) + + return os.path.isfile(filename) + def main(): # model files check and download - weight_path, model_path, params = get_params(args.arch) + starttime=time.time() + weight_path, model_path, params = get_params(args.arch) check_and_download_models(weight_path, model_path, REMOTE_PATH) # create net instance @@ -335,6 +612,7 @@ def main(): # check input train_and_infer(net, params) + logger.info('Script finished execution time: '+str(int((time.time()-starttime)*1000))) if __name__ == '__main__': diff --git a/anomaly_detection/padim/padim_gui.py b/anomaly_detection/padim/padim_gui.py index d9ebd476c..c9e0c05fc 100644 --- a/anomaly_detection/padim/padim_gui.py +++ b/anomaly_detection/padim/padim_gui.py @@ -40,6 +40,7 @@ model_index = 0 slider_index = 50 + REMOTE_PATH = 'https://storage.googleapis.com/ailia-models/padim/' train_folder = None @@ -47,6 +48,18 @@ test_type = "folder" test_roi = None score_cache = {} +EnableOptimization = True 
+save_format="pkl" +BENCHMARK = False +selected_device = "cpu" +if EnableOptimization: + import torch + + device = torch.device("cpu") + weights_torch=gaussian_kernel1d_torch(4, 0, int(4.0*float(4)+0.5), device).unsqueeze(0).unsqueeze(0).expand(1, 1, 33) + + + logger.info("Torch device : " + str(device)) # ====================== # List box cursor changed @@ -91,6 +104,56 @@ def slider_changed(event): global scale, slider_index slider_index = scale.get() +def enable_optimization(event): + global EnableOptimization + selection = event.widget.curselection() + if selection: + selected_index = selection[0] + selected_value = event.widget.get(selected_index) + EnableOptimization = (selected_value == "True") + else: + EnableOptimization = False + logger.info(f"EnableOptimization set to: {EnableOptimization}") + +def save_type_select(event): + global save_format + selection = event.widget.curselection() + if selection: + selected_index = selection[0] + selected_value = event.widget.get(selected_index) + save_format = selected_value + else: + save_format = "pkl" + logger.info(f"Selected format set to: {save_format}") + +def enable_benchmark(event): + global BENCHMARK + selection = event.widget.curselection() + if selection: + selected_index = selection[0] + selected_value = event.widget.get(selected_index) + BENCHMARK = (selected_value == "True") + else: + BENCHMARK = False + logger.info(f"BENCHMARK set to: {BENCHMARK}") + +def select_device(event): + global device, weights_torch + selection = event.widget.curselection() + if selection: + selected_index = selection[0] + selected_value = event.widget.get(selected_index) + device = torch.device(selected_value) + weights_torch = gaussian_kernel1d_torch(4, 0, int(4.0*float(4)+0.5), device).unsqueeze(0).unsqueeze(0).expand(1, 1, 33) + + if selected_value=="cuda:0": + torch.cuda.set_per_process_memory_fraction(fraction=0.5, device="cuda:0") + else: + device = torch.device("cpu") + + + logger.info(f"Device set to: {device}") + # 
====================== # List box double click # ====================== @@ -147,7 +210,7 @@ def create_photo_image(path,w=CANVAS_W,h=CANVAS_H): #image_bgr = cv2.resize(image_bgr,(w,h)) image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB) # imreadはBGRなのでRGBに変換 image_pil = Image.fromarray(image_rgb) # RGBからPILフォーマットへ変換 - image_pil.thumbnail((w,h), Image.ANTIALIAS) + image_pil.thumbnail((w,h), Image.LANCZOS) image_tk = ImageTk.PhotoImage(image_pil) # ImageTkフォーマットへ変換 return image_tk @@ -226,16 +289,51 @@ def train_button_clicked(): aug = False aug_num = 0 seed = 1024 - train_outputs = training(net, params, get_image_resize(), get_image_crop_size(), get_keep_aspect(), batch_size, train_dir, aug, aug_num, seed, logger) + train_outputs=training_optimized(net, params, get_image_resize(), get_image_crop_size(), get_keep_aspect(), batch_size, train_dir, aug, aug_num, seed, logger) + if not EnableOptimization: + if save_format == "pkl" : + train_feat_file = "train.pkl" + logger.info('Saving train set feature to: %s ...' % train_feat_file) + with open(train_feat_file, 'wb') as f: + pickle.dump(train_outputs, f) + logger.info('Saved.') + elif save_format == "npy" : + + for i, output in enumerate(train_outputs): + train_feat_file = "train_output_"+str(i)+".npy" + logger.info('Saving train set feature to: %s ...' % train_feat_file) + np.save(f"train_output_{i}.npy", output) + logger.info('Saved.') + elif save_format == "pt": + train_feat_file = "train.pt" + logger.info('Saving train set feature to: %s ...' % train_feat_file) + torch.save(train_outputs, train_feat_file) + logger.info('Saved.') - # save learned distribution - train_feat_file = "train.pkl" - #train_dir = args.train_dir - #train_feat_file = "%s.pkl" % os.path.basename(train_dir) - logger.info('saving train set feature to: %s ...' 
% train_feat_file) - with open(train_feat_file, 'wb') as f: - pickle.dump(train_outputs, f) - logger.info('saved.') + else: + if save_format=="npy": + for i, output in enumerate(train_outputs): + train_feat_file = "train_output_"+str(i)+".npy" + logger.info('Saving train set feature to: %s ...' % train_feat_file) + np.save(f"train_output_{i}.npy", output) + logger.info('Saved.') + + train_outputs=[torch.from_numpy(train_outputs[0]).float().to(device), train_outputs[1], + torch.from_numpy(train_outputs[2]).float().to(device), train_outputs[3] ] + if save_format == "pkl" : + train_feat_file = "trainOptimized.pkl" + logger.info('saving train set feature to: %s ...' % train_feat_file) + with open(train_feat_file, 'wb') as f: + pickle.dump(train_outputs, f) + logger.info('Saved.') + elif save_format == "pt": + train_feat_file = "trainOptimized.pt" + logger.info('Saving train set feature to: %s ...' % train_feat_file) + torch.save(train_outputs, train_feat_file) + logger.info('Saved.') + + + global score_cache score_cache = {} @@ -259,11 +357,59 @@ def test_button_clicked(): # create net instance env_id = ailia.get_gpu_environment_id() net = ailia.Net(model_path, weight_path, env_id=env_id) + if _check_file_exists(save_format): + if EnableOptimization: + train_feat_file = "trainOptimized."+save_format + + if save_format== "pkl": + logger.info(f"Loading {train_feat_file}") + with open(train_feat_file, 'rb') as f: + train_outputs = pickle.load(f) + train_outputs=[train_outputs[0].to(device), train_outputs[1], + train_outputs[2].to(device), train_outputs[3] ] + elif save_format == "npy": + train_outputs = [] + i = 0 + while True: + try: + logger.info(f"Loading train_output_{i}.npy") + train_outputs.append(np.load(f"train_output_{i}.npy", allow_pickle=True)) + i += 1 + except FileNotFoundError: + break # Stop when there are no more files to load + train_outputs=[torch.from_numpy(train_outputs[0]).float().to(device), train_outputs[1], + 
torch.from_numpy(train_outputs[2]).float().to(device), train_outputs[3] ] + elif save_format == "pt": + logger.info(f"Loading {train_feat_file}") + train_outputs = torch.load(train_feat_file) + else: + train_feat_file = "train."+save_format + if save_format == "pkl": + logger.info(f"Loading {train_feat_file}") + with open(train_feat_file, 'rb') as f: + train_outputs = pickle.load(f) + elif save_format == "npy": + train_outputs = [] + i = 0 + while True: + try: + logger.info(f"Loading train_output_{i}.npy") + train_outputs.append(np.load(f"train_output_{i}.npy", allow_pickle=True)) + i += 1 + except FileNotFoundError: + break # Stop when there are no more files to load + + elif save_format == "pt": + logger.info(f"Loading {train_feat_file}") + train_outputs = torch.load(train_feat_file) + train_outputs_numpy = [] + if train_outputs[0] is torch.Tensor: + for item in train_outputs: + if isinstance(item, torch.Tensor): + train_outputs_numpy.append(item.cpu().numpy()) # Move to CPU and convert to NumPy + else: + train_outputs_numpy.append(item) - # load trained model - with open("train.pkl", 'rb') as f: - train_outputs = pickle.load(f) - threshold = slider_index / 100.0 if test_type == "folder": @@ -271,6 +417,18 @@ def test_button_clicked(): else: test_from_video(net, params, train_outputs, threshold) +def _check_file_exists(save_format): + if save_format=="npy": + filename="train_output_0.npy" + elif EnableOptimization: + filename="trainOptimized."+save_format + else: + filename="train."+save_format + if not os.path.isfile(filename): + logger.info(f"File {filename} does not exist. 
Unable to load the model") + + return os.path.isfile(filename) + def test_from_folder(net, params, train_outputs, threshold): # file loop test_imgs = [] @@ -282,7 +440,7 @@ def test_from_folder(net, params, train_outputs, threshold): roi_img = preprocess(roi_img, get_image_resize(), keep_aspect=get_keep_aspect(), crop_size=get_image_crop_size(), mask=True) else: roi_img = None - + score_map = [] for i_img in range(0, len(test_list)): logger.info('from (%s) ' % (test_list[i_img])) @@ -296,20 +454,59 @@ def test_from_folder(net, params, train_outputs, threshold): if image_path in score_cache: dist_tmp = score_cache[image_path].copy() else: - dist_tmp = infer(net, params, train_outputs, img, get_image_crop_size()) - score_cache[image_path] = dist_tmp.copy() + if BENCHMARK: + import time + logger.info('BENCHMARK mode') + total_time = 0 + if EnableOptimization: + for i in range(6): + start = int(round(time.time() * 1000)) + dist_tmp = infer_optimized(net, params, train_outputs, img, get_image_crop_size(), device, logger, weights_torch) + end = int(round(time.time() * 1000)) + logger.info(f'\tailia processing time {end - start} ms') + if i != 0: + total_time = total_time + (end - start) + logger.info(f'\taverage time {total_time / 5} ms') + else: + for i in range(6): + start = int(round(time.time() * 1000)) + dist_tmp = infer(net, params, train_outputs, img, get_image_crop_size()) + end = int(round(time.time() * 1000)) + logger.info(f'\tailia processing time {end - start} ms') + if i != 0: + total_time = total_time + (end - start) + logger.info(f'\taverage time {total_time / 5} ms') + + else: + if EnableOptimization: + dist_tmp = infer_optimized(net, params, train_outputs, img, get_image_crop_size(), device, logger, weights_torch) + score_cache[image_path] = dist_tmp + else: + dist_tmp = infer(net, params, train_outputs, img, get_image_crop_size()) + score_cache[image_path] = dist_tmp.copy() score_map.append(dist_tmp) - - scores = normalize_scores(score_map, 
get_image_crop_size(), roi_img) - anormal_scores = calculate_anormal_scores(score_map, get_image_crop_size()) + if EnableOptimization: + scores = normalize_scores_torch(score_map, get_image_crop_size()).squeeze(0).cpu().numpy() + anormal_scores = calculate_anormal_scores_torch(score_map, get_image_crop_size()) + else: + scores = normalize_scores(score_map, get_image_crop_size(), roi_img) + anormal_scores = calculate_anormal_scores(score_map, get_image_crop_size()) # Plot gt image os.makedirs("result", exist_ok=True) global result_list, listsResult, ListboxResult result_list = [] - for i in range(0, scores.shape[0]): + scores_numpy=np.zeros(scores.shape) + for i in range(0, len(test_imgs)): img = denormalization(test_imgs[i]) heat_map, mask, vis_img = visualize(img, scores[i], threshold) + """ + if EnableOptimization: + scores_numpy[i] = scores[i].squeeze(0).cpu().numpy() + heat_map, mask, vis_img = visualize(img, scores[i].squeeze(0).cpu().numpy(), threshold) + else: + heat_map, mask, vis_img = visualize(img, scores[i], threshold) + """ frame = pack_visualize(heat_map, mask, vis_img, scores, get_image_crop_size()) dirname, path = os.path.split(test_list[i]) output_path = "result/"+path @@ -345,17 +542,24 @@ def test_from_video(net, params, train_outputs, threshold): img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) img = preprocess(img, get_image_resize(), keep_aspect=get_keep_aspect(), crop_size=get_image_crop_size()) + if EnableOptimization: + dist_tmp = infer_optimized(net, params, train_outputs, img, get_image_crop_size(), device, logger, weights_torch), + score_map.append(dist_tmp) + roi_img = None + scores = normalize_scores_torch(score_map, get_image_crop_size(), roi_img) + - dist_tmp = infer(net, params, train_outputs, img, get_image_crop_size()) + else: + dist_tmp = infer(net, params, train_outputs, img, get_image_crop_size()) - score_map.append(dist_tmp) - roi_img = None - scores = normalize_scores(score_map, get_image_crop_size(), roi_img) # min max is 
calculated dynamically, please set fixed min max value from calibration data for production + score_map.append(dist_tmp) + roi_img = None + scores = normalize_scores(score_map, get_image_crop_size(), roi_img) # min max is calculated dynamically, please set fixed min max value from calibration data for production - heat_map, mask, vis_img = visualize(denormalization(img[0]), scores[len(scores)-1], threshold) - frame = pack_visualize(heat_map, mask, vis_img, scores, get_image_crop_size()) + heat_map, mask, vis_img = visualize(denormalization(img[0]), scores[len(scores)-1], threshold) + frame = pack_visualize(heat_map, mask, vis_img, scores, get_image_crop_size()) - cv2.imshow('frame', frame) + cv2.imshow('frame', frame) frame_shown = True if writer is not None: @@ -537,6 +741,8 @@ def get_model_id_list(): def get_model_resolution_list(): return [224, 448, 224] +def get_optimization_list(): + return [True, False] # ====================== # GUI # ====================== @@ -553,6 +759,7 @@ def main(): global listsModel, ListboxModel global valueKeepAspect, valueCenterCrop + # rootメインウィンドウの設定 root = tk.Tk() root.title("PaDiM GUI") @@ -567,18 +774,24 @@ def main(): test_list = get_test_file_list() result_list = get_result_file_list() model_list = get_model_list() + listsInput = tk.StringVar(value=train_list) listsOutput = tk.StringVar(value=test_list) listsResult = tk.StringVar(value=result_list) listsModel = tk.StringVar(value=model_list) + + # 各種ウィジェットの作成 ListboxInput = tk.Listbox(frame, listvariable=listsInput, width=20, height=12, selectmode=tk.BROWSE, exportselection=False) ListboxOutput = tk.Listbox(frame, listvariable=listsOutput, width=20, height=12, selectmode=tk.BROWSE, exportselection=False) ListboxResult = tk.Listbox(frame, listvariable=listsResult, width=20, height=12, selectmode=tk.BROWSE, exportselection=False) ListboxModel = tk.Listbox(frame, listvariable=listsModel, width=20, height=6, selectmode=tk.BROWSE, exportselection=False) + + # Bind the listbox 
selection event to the toggle_optimization function + ListboxInput.bind("<>", input_changed) ListboxOutput.bind("<>", output_changed) ListboxResult.bind("<>", result_changed) @@ -587,11 +800,13 @@ def main(): ListboxInput.bind("", input_double_click) ListboxOutput.bind("", output_double_click) ListboxResult.bind("", result_double_click) + ListboxInput.select_set(input_index) ListboxOutput.select_set(output_index) ListboxResult.select_set(result_index) ListboxModel.select_set(model_index) + textRun = tk.StringVar(frame) textRun.set("Train") @@ -650,6 +865,18 @@ def main(): textSave = tk.StringVar(frame) textSave.set("Save images") + textOptimization = tk.StringVar(frame) + textOptimization.set("Set optimization") + + textSaveFile = tk.StringVar(frame) + textSaveFile.set("Trained file format") + + textBenchmark = tk.StringVar(frame) + textBenchmark.set("Benchmark mode") + + textDevice = tk.StringVar(frame) + textDevice.set("Optimization device") + valueKeepAspect = tkinter.BooleanVar() valueKeepAspect.set(True) valueCenterCrop = tkinter.BooleanVar() @@ -667,6 +894,11 @@ def main(): labelModel = tk.Label(frame, textvariable=textModel) labelTestSettings = tk.Label(frame, textvariable=textTestSettings) labelSlider = tk.Label(frame, textvariable=textSlider) + labelOPt = tk.Label(frame, textvariable=textOptimization) + labelSaveFile = tk.Label(frame, textvariable=textSaveFile) + labelBenchmark = tk.Label(frame, textvariable=textBenchmark) + labelDevice = tk.Label(frame, textvariable=textDevice) + buttonTrain = tk.Button(frame, textvariable=textRun, command=train_button_clicked, width=14) buttonTest = tk.Button(frame, textvariable=textStop, command=test_button_clicked, width=14) @@ -738,12 +970,72 @@ def main(): labelSlider.grid(row=9, column=4, sticky=tk.NW, columnspan=3) scale.grid(row=10, column=4, sticky=tk.NW, columnspan=3) + options = ["False", "True"] + listsOPt = tk.StringVar(value=options) + labelOPt.grid(row=11, column=0, sticky=tk.NW, columnspan=3) + 
ListboxOptimization = tk.Listbox(frame, listvariable=listsOPt, width=20, height=len(options), selectmode=tk.BROWSE, exportselection=False) + ListboxOptimization.grid(row=11, column=0, padx=0, pady=20) + + # Set the initial selection in the Listbox + initial_selection = options.index("True") # Default to "True" + ListboxOptimization.select_set(initial_selection) + ListboxOptimization.event_generate('<>') # Trigger the event to set the initial state + + # Bind the listbox selection event to the toggle_optimization function + ListboxOptimization.bind('<>', enable_optimization) + + fileOptions = ["pkl", "pt", "npy"] + listsFileOPt = tk.StringVar(value=fileOptions) + labelSaveFile.grid(row=12, column=0, sticky=tk.NW, columnspan=3) + ListboxFileSelect = tk.Listbox(frame, listvariable=listsFileOPt, width=20, height=len(fileOptions), selectmode=tk.BROWSE, exportselection=False) + ListboxFileSelect.grid(row=12, column=0, padx=0, pady=20) + + # Set the initial selection in the Listbox + initial_selectionFile = fileOptions.index("pkl") # Default to "pkl" + ListboxFileSelect.select_set(initial_selectionFile) + ListboxFileSelect.event_generate('<>') # Trigger the event to set the initial state + + # Bind the listbox selection event to the save_type_select function + ListboxFileSelect.bind('<>', save_type_select) + + + fileDevice = ["cpu", "cuda:0", "mps"] + listsFileDevice = tk.StringVar(value=fileDevice) + labelDevice.grid(row=11, column=4, sticky=tk.NW) + ListboxDeviceSelect = tk.Listbox(frame, listvariable=listsFileDevice, width=20, height=len(fileDevice), selectmode=tk.BROWSE, exportselection=False) + ListboxDeviceSelect.grid(row=11, column=4, padx=0, pady=20) + + # Set the initial selection in the Listbox + initial_selection_Device = fileDevice.index("cpu") # Default to "False" + ListboxDeviceSelect.select_set(initial_selection_Device) + ListboxDeviceSelect.event_generate('<>') # Trigger the event to set the initial state + + # Bind the listbox selection event to the 
save_type_select function + ListboxDeviceSelect.bind('<>', select_device) + # メインフレームの作成と設置 + + + fileBenchmark = ["True", "False"] + listsFileBenchmark = tk.StringVar(value=fileBenchmark) + labelBenchmark.grid(row=11, column=5, sticky=tk.NW) + ListboxFileBenchmark = tk.Listbox(frame, listvariable=listsFileBenchmark, width=20, height=len(fileOptions), selectmode=tk.BROWSE, exportselection=False) + ListboxFileBenchmark.grid(row=11, column=5, padx=0, pady=20) + + # Set the initial selection in the Listbox + initial_selection_benchmark = fileBenchmark.index("False") # Default to "False" + ListboxFileBenchmark.select_set(initial_selection_benchmark) + ListboxFileBenchmark.event_generate('<>') # Trigger the event to set the initial state + + # Bind the listbox selection event to the save_type_select function + ListboxFileBenchmark.bind('<>', enable_benchmark) # メインフレームの作成と設置 frame = ttk.Frame(root) frame.pack(padx=20, pady=10) root.mainloop() + + if __name__ == '__main__': main() diff --git a/anomaly_detection/padim/padim_utils.py b/anomaly_detection/padim/padim_utils.py index 1586a77d4..1fef26520 100644 --- a/anomaly_detection/padim/padim_utils.py +++ b/anomaly_detection/padim/padim_utils.py @@ -9,6 +9,7 @@ from image_utils import normalize_image # noqa: E402 from detector_utils import load_image # noqa: E402 + from scipy.spatial.distance import mahalanobis from scipy.ndimage import gaussian_filter @@ -16,6 +17,9 @@ from skimage import morphology from skimage.segmentation import mark_boundaries +import torch +import torch.nn.functional as F + WEIGHT_RESNET18_PATH = 'resnet18.onnx' MODEL_RESNET18_PATH = 'resnet18.onnx.prototxt' @@ -56,6 +60,48 @@ def embedding_concat(x, y): return a +def embedding_concat_optimizded(x, y, device): + B, C1, H1, W1 = x.shape + _, C2, H2, W2 = y.shape + + assert H1 == W1 + + s = H1 // H2 + # Chesboard pattern downscaling + sel = torch.from_numpy(np.asarray([ + np.array([i for i in range(i, H1, s)]) for i in range(s) + ])).to(device) + 
index3 = sel[torch.repeat_interleave(torch.arange(s, device=device), s)] + index4 = sel[torch.tile(torch.arange(s, device=device), (s,))] + #a = x[:, :, index3[:, None].T, index4.T].permute(0, 1, 4, 2, 3) + a = x[:, :, index3[:, None].permute(2, 1, 0), index4.permute(1, 0)].permute(0, 1, 4, 2, 3) + # Concatenation + z = torch.cat(( + a, + torch.tile(y[:, :, None, :, :], (1, 1, s * s, 1, 1)) + ), axis=1) + # Downsizing and reshaping + z = z.reshape(B, -1, s, s, H2, W2).permute(0, 1, 4, 2, 5, 3).reshape(B, -1, H1, W1) + return z + +def embedding_concat_numpy(x, y): + B, C1, H1, W1 = x.shape + _, C2, H2, W2 = y.shape + assert H1 == W1 + s = H1 // H2 + #Chesboard pattern downscaling + sel = np.asarray([np.array([i for i in range(i, H1, s)]) for i in range(s)]) + index3= sel[np.repeat(np.arange(s), s)] + index4= sel[np.tile(np.arange(s), s)] + a=x[:, :, index3[:, None].T, index4.T].transpose((0, 1, 4, 2, 3)) + #concatination + z=np.concatenate((a, np.tile((y[:, :, None, :, :]), (1, 1, s*s, 1, 1))), axis=1)#.reshape((B, -1, H2 , W2)) + _, C3, _, _, _ = z.shape #(1, 448, 16, 28, 28) + #downsizing and rescaling + z=z.reshape((B, C3, s, s, H2,W2)).transpose((0, 1, 4, 2, 5, 3)).reshape(B, C3, H1, W1) + + return z + def preprocess(img, size, crop_size, mask=False, keep_aspect = True): h, w = img.shape[:2] @@ -69,7 +115,7 @@ def preprocess(img, size, crop_size, mask=False, keep_aspect = True): else: size = (size, size) img = np.array(Image.fromarray(img).resize( - size, resample=Image.ANTIALIAS if not mask else Image.NEAREST)) + size, resample=Image.LANCZOS if not mask else Image.NEAREST)) # center crop h, w = img.shape[:2] @@ -139,11 +185,22 @@ def preprocess_aug(img, size, crop_size, mask=False, keep_aspect = True, angle_r return img -def postprocess(outputs): +def postprocess(outputs, training_c=False): + embedding_vectors = outputs['layer1'] + for layer_name in ['layer2', 'layer3']: + if training_c: + embedding_vectors = embedding_concat_numpy(embedding_vectors, 
outputs[layer_name]) + else: + embedding_vectors = embedding_concat(embedding_vectors, outputs[layer_name]) + + + return embedding_vectors + +def postprocess_optimized(outputs, device): # Embedding concat embedding_vectors = outputs['layer1'] for layer_name in ['layer2', 'layer3']: - embedding_vectors = embedding_concat(embedding_vectors, outputs[layer_name]) + embedding_vectors = embedding_concat_optimizded(embedding_vectors, outputs[layer_name], device) return embedding_vectors @@ -277,6 +334,114 @@ def training(net, params, size, crop_size, keep_aspect, batch_size, train_dir, a train_outputs = [mean, cov, cov_inv, idx] return train_outputs + +def training_optimized(net, params, size, crop_size, keep_aspect, batch_size, train_dir, aug, aug_num, seed, logger): + # set seed + random.seed(seed) + idx = random.sample(range(0, params["t_d"]), params["d"]) + + if os.path.isdir(train_dir): + train_imgs = sorted([ + os.path.join(train_dir, f) for f in os.listdir(train_dir) + if f.endswith('.png') or f.endswith('.jpg') or f.endswith('.bmp') + ]) + if len(train_imgs) == 0: + logger.error("train images not found in '%s'" % train_dir) + sys.exit(-1) + else: + logger.info("capture 200 frames from video") + train_imgs = capture_training_frames_from_video(train_dir) + + if not aug: + logger.info('extract train set features without augmentation') + aug_num = 1 + else: + logger.info('extract train set features with augmentation') + aug_num = aug_num + mean = None + N = 0 + for i_aug in range(aug_num): + for i_img in range(0, len(train_imgs), batch_size): + # prepare input data + imgs = [] + if not aug: + logger.info('from (%s ~ %s) ' % + (i_img, + min(len(train_imgs) - 1, + i_img + batch_size))) + else: + logger.info('from (%s ~ %s) on augmentation lap %d' % + (i_img, + min(len(train_imgs) - 1, + i_img + batch_size), i_aug)) + for image_path in train_imgs[i_img:i_img + batch_size]: + if type(image_path) is str: + img = load_image(image_path) + img = cv2.cvtColor(img, 
cv2.COLOR_BGRA2RGB) + else: + img = cv2.cvtColor(image_path, cv2.COLOR_BGR2RGB) + if not aug: + img = preprocess(img, size, crop_size, keep_aspect=keep_aspect) + else: + img = preprocess_aug(img, size, crop_size, keep_aspect=keep_aspect) + imgs.append(img) + + # countup N + N += len(imgs) + + imgs = np.vstack(imgs) + + logger.debug(f'input images shape: {imgs.shape}') + net.set_input_shape(imgs.shape) + + # inference + _ = net.predict(imgs) + + train_outputs = OrderedDict([ + ('layer1', []), ('layer2', []), ('layer3', []) + ]) + for key, name in zip(train_outputs.keys(), params["feat_names"]): + train_outputs[key].append(net.get_blob_data(name)) + for k, v in train_outputs.items(): + train_outputs[k] = v[0] + + embedding_vectors = postprocess(train_outputs, training_c=True) + + # randomly select d dimension + embedding_vectors = embedding_vectors[:, idx, :, :] + + # reshape 2d pixels to 1d features + B, C, H, W = embedding_vectors.shape + embedding_vectors = embedding_vectors.reshape(B, C, H * W) + + # initialize mean and covariance matrix + if (mean is None): + mean = np.zeros((C, H * W), dtype=np.float32) + cov = np.zeros((C, C, H * W), dtype=np.float32) + + # calculate multivariate Gaussian distribution + # (add up mean and covariance matrix) + mean += np.sum(embedding_vectors, axis=0) + for i in range(H * W): + # https://github.com/numpy/numpy/blob/v1.21.0/numpy/lib/function_base.py#L2324-L2543 + m = embedding_vectors[:, :, i] + m = m - (mean[:, [i]].T / N) + cov[:, :, i] += m.T @ m + + # devide mean by N + mean = mean / N + # devide covariance by N-1, and calculate inverse + I = np.identity(C) + for i in range(H * W): + cov[:, :, i] = (cov[:, :, i] / (N - 1)) + 0.01 * I + + cov_inv = np.zeros(cov.shape) + for i in range(H * W): + cov_inv[:, :, i] = np.linalg.inv(cov[:, :, i]) + + train_outputs = [mean, cov, cov_inv, idx] + return train_outputs + def infer(net, params, train_outputs, img, crop_size): # prepare input data imgs = [] @@ -325,6 +490,120 @@ def 
infer(net, params, train_outputs, img, crop_size): return dist_tmp +def infer_optimized(net, params, train_outputs, img, crop_size, device, logger, weights_torch=None): + # prepare input data + imgs = [] + imgs.append(img) + imgs = np.vstack(imgs) + + # inference + net.set_input_shape(imgs.shape) + _ = net.predict(imgs) + + test_outputs = OrderedDict([ + ('layer1', []), ('layer2', []), ('layer3', []) + ]) + for key, name in zip(test_outputs.keys(), params["feat_names"]): + test_outputs[key].append(net.get_blob_data(name)) + for k, v in test_outputs.items(): + test_outputs[k] = torch.from_numpy(v[0]).to(device) + + embedding_vectors = postprocess_optimized(test_outputs, device) + + # randomly select d dimension + idx = train_outputs[3] + embedding_vectors = embedding_vectors[:, idx, :, :] + + # reshape 2d pixels to 1d features + B, C, H, W = embedding_vectors.shape + embedding_vectors = embedding_vectors.view(B, C, H * W) + + # calculate distance matrix + #mean_vectors = train_outputs[0] + #inv_cov_matrices = train_outputs[2] + #samples = embedding_vectors[0] + # + if str(device) not in str(train_outputs[0].device): + logger.info(f"Changing device from {train_outputs[0].device} to {device}") + train_outputs[0]=train_outputs[0].to(device) + embedding_vectors[0]=embedding_vectors[0].to(device) + train_outputs[2]=train_outputs[2].to(device) + # Step 1: Compute the difference between each sample and its corresponding mean + dist_tmp = embedding_vectors[0] - train_outputs[0] + # Step 2: Apply the inverse covariance matrix + transformed_differences = torch.einsum('ijk,jk->ik', train_outputs[2], dist_tmp) + + # Step 3: Compute the Mahalanobis distance + dist_tmp = torch.sqrt(torch.sum(dist_tmp * transformed_differences, dim=0)) + transformed_differences=0 + # upsample + dist_tmp=F.interpolate(dist_tmp.view(1, -1).view( H, W).unsqueeze(0).unsqueeze(0), + size=(crop_size, crop_size), mode='bilinear', align_corners=False).squeeze(0) + dist_tmp=gausian_filter_torch(dist_tmp, 
weights_torch, mode='reflect') + + """ + if str(device) not in str(mean_vectors.device): + logger.info(f"Changing device from {mean_vectors.device} to {device}") + mean_vectors=mean_vectors.to(device) + samples=samples.to(device) + inv_cov_matrices=inv_cov_matrices.to(device) + # Step 1: Compute the difference between each sample and its corresponding mean + differences = samples - mean_vectors + # Step 2: Apply the inverse covariance matrix + transformed_differences = torch.einsum('ijk,jk->ik', inv_cov_matrices, differences) + # Step 3: Compute the Mahalanobis distance + dist_tmp = torch.sqrt(torch.sum(differences * transformed_differences, dim=0)) + # upsample + """ + + #dist_tmp = dist_tmp.view(1, -1).view( H, W) + """ + dist_tmp = dist_tmp.view(1, -1).view( H, W).cpu().numpy() + print("crop_size ", crop_size) + dist_tmp = np.array(Image.fromarray(dist_tmp).resize( + (crop_size, crop_size), resample=Image.BILINEAR) + ) + print("Shape before gausian filter: ", dist_tmp.shape) + # apply gaussian smoothing on the score map + dist_tmp = gaussian_filter(dist_tmp, sigma=4) + + print("Shape after gausian filter: ", dist_tmp.shape) + """ + + #dist_tmp=F.interpolate(dist_tmp.unsqueeze(0).unsqueeze(0), + #size=(crop_size, crop_size), mode='bilinear', align_corners=False).squeeze(0) + #dist_tmp=gausian_filter_torch(dist_tmp, weights_torch, mode='reflect') + + + return dist_tmp + +def gaussian_kernel1d_torch(sigma, order, radius, device): + """ + Computes a 1-D Gaussian convolution kernel. 
+ """ + if order < 0: + raise ValueError('order must be non-negative') + #exponent_range = torch.arange(order + 1, device=device) + sigma2 = sigma * sigma + x = torch.arange(-radius, radius + 1, dtype=torch.float32, device=device) + phi_x = torch.exp(-0.5 / sigma2 * x ** 2) + phi_x = phi_x / phi_x.sum() + + if order == 0: + return phi_x + +def gausian_filter_torch(input, weights, output=None, mode='constant', cval=0.0, origin=0): + + input=input.permute(2, 0,1 ) + input=F.pad(input, pad=(16, 16), mode='reflect') + input=F.conv1d(input, weights ) + + input=input.permute(2, 1,0 ) + input=F.pad(input, pad=(16, 16), mode='reflect', ) + input=F.conv1d(input, weights).permute(1, 0,2 ) + return input + + def normalize_scores(score_map, crop_size, roi_img = None): N = len(score_map) score_map = np.vstack(score_map) @@ -343,6 +622,27 @@ def normalize_scores(score_map, crop_size, roi_img = None): return scores +def normalize_scores_torch(score_map, crop_size, roi_img=None): + """ + score_map is list of torch tensors + crop size int + """ + # Convert list of tensors to a single tensor + score_map = torch.stack(score_map) + + # Handle ROI (Region of Interest) + if roi_img is not None: + roi_img = (roi_img > 0.5).float() # Threshold to binary mask + for i in range(score_map.shape[0]): + score_map[i] *= roi_img[0, 0] # Element-wise multiplication + + # Normalization using min-max scaling (avoiding division by zero) + max_score = score_map.max() + min_score = score_map.min() + scores = (score_map - min_score) / torch.clamp(max_score - min_score, min=1e-8) + + return scores + def calculate_anormal_scores(score_map, crop_size): N = len(score_map) score_map = np.vstack(score_map) @@ -352,7 +652,22 @@ def calculate_anormal_scores(score_map, crop_size): anormal_scores = np.zeros((score_map.shape[0])) for i in range(score_map.shape[0]): anormal_scores[i] = score_map[i].max() + return anormal_scores + +def calculate_anormal_scores_torch(score_map, crop_size): + N = len(score_map) + + + 
# Stack the score maps into a single tensor + score_map = torch.vstack(score_map) + score_map = score_map.unsqueeze(0).view(N, crop_size, crop_size) + + # Calculate anormal scores + anormal_scores = np.zeros((N)) + for i in range(score_map.shape[0]): + anormal_scores[i] = score_map[i].max().cpu().numpy() + return anormal_scores diff --git a/anomaly_detection/padim_old/padim.py b/anomaly_detection/padim_old/padim.py new file mode 100644 index 000000000..1f37833d4 --- /dev/null +++ b/anomaly_detection/padim_old/padim.py @@ -0,0 +1,605 @@ +import os +import sys +import time +from collections import OrderedDict +import random +import pickle + +import numpy as np +import cv2 +from PIL import Image +import matplotlib +import matplotlib.pyplot as plt + +import ailia + +# import original modules +sys.path.append('../../util') +from arg_utils import get_base_parser, update_parser, get_savepath # noqa: E402 +from model_utils import check_and_download_models # noqa: E402 +from detector_utils import load_image # noqa: E402 +import webcamera_utils # noqa: E402 + +# logger +from logging import getLogger # noqa: E402 + +from padim_utils import * + +logger = getLogger(__name__) + +# ====================== +# Parameters +# ====================== + +REMOTE_PATH = 'https://storage.googleapis.com/ailia-models/padim/' + +IMAGE_PATH = './bottle_000.png' +SAVE_IMAGE_PATH = './output.png' +IMAGE_RESIZE = 256 +IMAGE_SIZE = 224 +KEEP_ASPECT = True + +# ====================== +# Arguemnt Parser Config +# ====================== + +parser = get_base_parser('PaDiM model', IMAGE_PATH, SAVE_IMAGE_PATH) +parser.add_argument( + '-a', '--arch', default='resnet18', choices=('resnet18', 'wide_resnet50_2'), + help='arch model.' +) +parser.add_argument( + '-f', '--feat', metavar="FILE", default=None, + help='train set feature files.' +) +parser.add_argument( + '-bs', '--batch_size', default=32, + help='batch size.' 
+) +parser.add_argument( + '-tr', '--train_dir', metavar="DIR", default="./train", + help='directory of the train files.' +) +parser.add_argument( + '-gt', '--gt_dir', metavar="DIR", default="./gt_masks", + help='directory of the ground truth mask files.' +) +parser.add_argument( + '--seed', type=int, default=1024, + help='random seed' +) +parser.add_argument( + '-th', '--threshold', type=float, default=None, + help='threshold' +) +parser.add_argument( + '-ag', '--aug', action='store_true', + help='process with augmentation.' +) +parser.add_argument( + '-an', '--aug_num', type=int, default=5, + help='specify the amplification number of augmentation.' +) +parser.add_argument( + '-eon', '--enable_optimization', type=bool, default=False, + help='Flag to enable optimized code' +) +parser.add_argument( + '--compare_optimization', type=bool, default=False, + help='Flag to compare output of optimization with original code' +) +parser.add_argument( + '--compare_optimization', type=bool, default=False, + help='Flag to compare output of optimization with original code' +) + +parser.add_argument( + '--save_format', metavar="FILE", default="pkl", + help='chose training file format pt, npy or pkl.' 
+) + +parser.add_argument( + '--optimization_device', metavar="device", default='cpu', choices=('cpu', 'cuda', 'mps'), + help='chose optimization device' +) + +args = update_parser(parser) + +if args.compare_optimization: + args.enable_optimization = True + train_output_list=[] + +if args.enable_optimization: + import torch + if args.optimization_device=="cuda" and torch.cuda.is_available(): + device = torch.device("cuda") + + elif args.optimization_device=="mps" and torch.backends.mps.is_available(): + device = torch.device("mps") + else: + device = torch.device("cpu") + logger.info("Torch device : " + str(device)) + + + +# ====================== +# Main functions +# ====================== + + +def plot_fig(file_list, test_imgs, scores, anormal_scores, gt_imgs, threshold, savepath): + num = len(file_list) + vmax = scores.max() * 255. + vmin = scores.min() * 255. + for i in range(num): + image_path = file_list[i] + img = test_imgs[i] + img = denormalization(img) + if gt_imgs is not None: + gt = gt_imgs[i] + gt = gt.transpose(1, 2, 0).squeeze() + else: + gt = np.zeros((1,1,1)) + print(scores[i], scores[i].shape) + heat_map, mask, vis_img = visualize(img, scores[i], threshold) + + fig_img, ax_img = plt.subplots(1, 5, figsize=(12, 3)) + fig_img.subplots_adjust(right=0.9) + + fig_img.suptitle("Input : " + image_path + " Anomaly score : " + str(anormal_scores[i])) + logger.info("Anomaly score : " + str(anormal_scores[i])) + + norm = matplotlib.colors.Normalize(vmin=vmin, vmax=vmax) + for ax_i in ax_img: + ax_i.axes.xaxis.set_visible(False) + ax_i.axes.yaxis.set_visible(False) + ax_img[0].imshow(img) + ax_img[0].title.set_text('Image') + ax_img[1].imshow(gt, cmap='gray') + ax_img[1].title.set_text('GroundTruth') + ax = ax_img[2].imshow(heat_map, cmap='jet', norm=norm) + ax_img[2].imshow(img, cmap='gray', interpolation='none') + ax_img[2].imshow(heat_map, cmap='jet', alpha=0.5, interpolation='none') + ax_img[2].title.set_text('Predicted heat map') + ax_img[3].imshow(mask, 
cmap='gray') + ax_img[3].title.set_text('Predicted mask') + ax_img[4].imshow(vis_img) + ax_img[4].title.set_text('Segmentation result') + left = 0.92 + bottom = 0.15 + width = 0.015 + height = 1 - 2 * bottom + rect = [left, bottom, width, height] + cbar_ax = fig_img.add_axes(rect) + cb = plt.colorbar(ax, shrink=0.6, cax=cbar_ax, fraction=0.046) + cb.ax.tick_params(labelsize=8) + font = { + 'family': 'serif', + 'color': 'black', + 'weight': 'normal', + 'size': 8, + } + cb.set_label('Anomaly Score', fontdict=font) + + if ('.' in savepath.split('/')[-1]): + savepath_tmp = get_savepath(savepath, image_path, ext='.png') + else: + filename_tmp = image_path.split('/')[-1] + ext_tmp = '.' + filename_tmp.split('.')[-1] + filename_tmp = filename_tmp.replace(ext_tmp, '.png') + savepath_tmp = '%s/%s' % (savepath, filename_tmp) + logger.info(f'saved at : {savepath_tmp}') + fig_img.savefig(savepath_tmp, dpi=100) + plt.close() + +def infer_init_run(net, params, train_outputs, IMAGE_SIZE): + import numpy as np + dummy_image = np.random.rand(1, 3, 224, 224) * 255.0 # Scale between 0 and 255 + # Convert the dtype to float32 for efficiency + dummy_image = dummy_image.astype(np.float32) + logger.info(f"PaDiM initialization inference starts!") + if args.enable_optimization: + score = infer_optimized(net, params, train_outputs, dummy_image, IMAGE_SIZE, device, logger) + else: + score = infer(net, params, train_outputs, dummy_image, IMAGE_SIZE) + logger.info(f"PaDiM initialization inference finish!") + + +def train_from_image_or_video(net, params): + # training + if args.enable_optimization: + train_outputs = training_optimized(net, params, IMAGE_RESIZE, IMAGE_SIZE, KEEP_ASPECT, int(args.batch_size), args.train_dir, args.aug, args.aug_num, args.seed, logger) + else: + train_outputs = training(net, params, IMAGE_RESIZE, IMAGE_SIZE, KEEP_ASPECT, int(args.batch_size), args.train_dir, args.aug, args.aug_num, args.seed, logger) + # save learned distribution + if args.feat: + train_feat_file = 
args.feat + else: + train_dir = args.train_dir + train_feat_file = str(os.path.basename(train_dir))+"."+str(args.save_format) + + train_outputs=_save_training_flie(train_feat_file, args.save_format, train_outputs) + + return train_outputs + + +def load_gt_imgs(gt_type_dir): + gt_imgs = [] + for i_img in range(0, len(args.input)): + image_path = args.input[i_img] + gt_img = None + if gt_type_dir: + fname = os.path.splitext(os.path.basename(image_path))[0] + gt_fpath = os.path.join(gt_type_dir, fname + '_mask.png') + if os.path.exists(gt_fpath): + gt_img = load_image(gt_fpath) + gt_img = cv2.cvtColor(gt_img, cv2.COLOR_BGRA2RGB) + gt_img = preprocess(gt_img, IMAGE_RESIZE, mask=True, keep_aspect=KEEP_ASPECT, crop_size = IMAGE_SIZE) + if gt_img is not None: + gt_img = gt_img[0, [0]] + else: + gt_img = np.zeros((1, IMAGE_SIZE, IMAGE_SIZE)) + gt_imgs.append(gt_img) + return gt_imgs + + +def decide_threshold_from_gt_image(net, params, train_outputs, gt_imgs): + score_map = [] + for i_img in range(0, len(args.input)): + logger.info('from (%s) ' % (args.input[i_img])) + + image_path = args.input[i_img] + img = load_image(image_path) + img = cv2.cvtColor(img, cv2.COLOR_BGRA2RGB) + img = preprocess(img, IMAGE_RESIZE, keep_aspect=KEEP_ASPECT, crop_size = IMAGE_SIZE) + if args.enable_optimization: + dist_tmp = infer_optimized(net, params, train_outputs, img, IMAGE_SIZE, device,logger) + else: + dist_tmp = infer(net, params, train_outputs, img, IMAGE_SIZE) + + + score_map.append(dist_tmp) + + scores = normalize_scores(score_map, IMAGE_SIZE) + print("scores shape", scores.shape) + + threshold = decide_threshold(scores, gt_imgs) + + return threshold + +def infer_from_image(net, params, train_outputs, threshold, gt_imgs): + if len(args.input) == 0: + logger.error("Input file not found") + return + + test_imgs = [] + + score_map = [] + infer_init_run(net, params, train_outputs, IMAGE_SIZE) + for i_img in range(0, len(args.input)): + logger.info('from (%s) ' % (args.input[i_img])) + + 
image_path = args.input[i_img] + img = load_image(image_path) + img = cv2.cvtColor(img, cv2.COLOR_BGRA2RGB) + img = preprocess(img, IMAGE_RESIZE, keep_aspect=KEEP_ASPECT, crop_size = IMAGE_SIZE) + + test_imgs.append(img[0]) + + if args.benchmark: + logger.info('BENCHMARK mode') + total_time = 0 + if args.enable_optimization: + for i in range(args.benchmark_count): + start = int(round(time.time() * 1000)) + dist_tmp = infer_optimized(net, params, train_outputs, img, IMAGE_SIZE, device,logger) + end = int(round(time.time() * 1000)) + logger.info(f'\tailia processing time {end - start} ms') + if i != 0: + total_time = total_time + (end - start) + logger.info(f'\taverage time {total_time / (args.benchmark_count - 1)} ms') + else: + for i in range(args.benchmark_count): + start = int(round(time.time() * 1000)) + dist_tmp = infer(net, params, train_outputs, img, IMAGE_SIZE) + end = int(round(time.time() * 1000)) + logger.info(f'\tailia processing time {end - start} ms') + if i != 0: + total_time = total_time + (end - start) + logger.info(f'\taverage time {total_time / (args.benchmark_count - 1)} ms') + if args.compare_optimization: + logger.info(f'\tResults of optimized and original code is the same: {np.allclose(infer(net, params, train_output_list[0], img, IMAGE_SIZE), infer_optimized(net, params, train_output_list[1], img, IMAGE_SIZE, device,logger))}') + + + else: + if args.enable_optimization: + dist_tmp = infer_optimized(net, params, train_outputs, img, IMAGE_SIZE, device,logger) + else: + dist_tmp = infer(net, params, train_outputs, img, IMAGE_SIZE) + if args.compare_optimization: + logger.info('Results of optimized and original code is the same: '+ str(np.allclose(infer(net, params, train_output_list[0], img, IMAGE_SIZE), infer_optimized(net, params, train_output_list[1], img, IMAGE_SIZE, device,logger)))) + + + score_map.append(dist_tmp) + print("score_map shape ", np.asarray(score_map).shape) + scores = normalize_scores(score_map, IMAGE_SIZE) + 
print("normalize_scores shape ", np.asarray(scores).shape) + print("IMAGE_SIZE shape ", IMAGE_SIZE) + anormal_scores = calculate_anormal_scores(score_map, IMAGE_SIZE) + print("IMAGE_SIZE shape ", np.asarray(anormal_scores).shape) + + # Plot gt image + plot_fig(args.input, test_imgs, scores, anormal_scores, gt_imgs, threshold, args.savepath) + + +def infer_from_video(net, params, train_outputs, threshold): + capture = webcamera_utils.get_capture(args.video) + if args.savepath != SAVE_IMAGE_PATH: + f_h = int(IMAGE_SIZE) + f_w = int(IMAGE_SIZE) * 3 + writer = webcamera_utils.get_writer(args.savepath, f_h, f_w) + else: + writer = None + + score_map = [] + + frame_shown = False + infer_init_run(net, params, train_outputs, IMAGE_SIZE) + if args.enable_optimization: + while(True): + ret, frame = capture.read() + if (cv2.waitKey(1) & 0xFF == ord('q')) or not ret: + break + if frame_shown and cv2.getWindowProperty('frame', cv2.WND_PROP_VISIBLE) == 0: + break + + img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + img = preprocess(img, IMAGE_RESIZE, keep_aspect=KEEP_ASPECT) + + dist_tmp = infer_optimized(net, params, train_outputs, img, device,logger) + + score_map.append(dist_tmp) + scores = normalize_scores(score_map) # min max is calculated dynamically, please set fixed min max value from calibration data for production + + heat_map, mask, vis_img = visualize(denormalization(img[0]), scores[len(scores)-1], threshold) + frame = pack_visualize(heat_map, mask, vis_img, scores, IMAGE_SIZE) + + cv2.imshow('frame', frame) + frame_shown = True + + if writer is not None: + writer.write(frame) + else: + while(True): + ret, frame = capture.read() + if (cv2.waitKey(1) & 0xFF == ord('q')) or not ret: + break + if frame_shown and cv2.getWindowProperty('frame', cv2.WND_PROP_VISIBLE) == 0: + break + + img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + img = preprocess(img, IMAGE_RESIZE, keep_aspect=KEEP_ASPECT) + + dist_tmp = infer(net, params, train_outputs, img) + + score_map.append(dist_tmp) 
+ scores = normalize_scores(score_map) # min max is calculated dynamically, please set fixed min max value from calibration data for production + + heat_map, mask, vis_img = visualize(denormalization(img[0]), scores[len(scores)-1], threshold) + frame = pack_visualize(heat_map, mask, vis_img, scores, IMAGE_SIZE) + + cv2.imshow('frame', frame) + frame_shown = True + + if writer is not None: + writer.write(frame) + + capture.release() + cv2.destroyAllWindows() + if writer is not None: + writer.release() + + +def train_and_infer(net, params): + timestart=time.time() + if args.feat: + train_outputs=_load_training_file(args.feat, args.save_format) + logger.info('loaded.') + else: + train_outputs = train_from_image_or_video(net, params) + + if args.threshold is None: + if args.video: + threshold = 0.5 + gt_imgs = None + logger.info('Please set threshold manually for video mdoe') + else: + gt_type_dir = args.gt_dir if args.gt_dir else None + gt_imgs = load_gt_imgs(gt_type_dir) + + threshold = decide_threshold_from_gt_image(net, params, train_outputs, gt_imgs) + logger.info('Optimal threshold: %f' % threshold) + else: + threshold = args.threshold + gt_imgs = None + + if args.video: + infer_from_video(net, params, train_outputs, threshold) + else: + infer_from_image(net, params, train_outputs, threshold, gt_imgs) + logger.info('Script finished successfully.') + +def _save_training_flie(train_feat_file, save_format, train_outputs): + if args.compare_optimization: + train_output_list.append(train_outputs) + + if not args.enable_optimization: + if save_format == "pkl" : + if train_feat_file==None: + train_feat_file = "train.pkl" + logger.info('Saving train set feature to: %s ...' 
% train_feat_file) + with open(train_feat_file, 'wb') as f: + pickle.dump(train_outputs, f) + logger.info('Saved.') + elif save_format == "npy" : + filename=train_feat_file.split(".")[0].strip() + for i, output in enumerate(train_outputs): + + if train_feat_file==None: + train_feat_file = "train_output_"+str(i)+".npy" + else: + train_feat_file = filename+str("_")+str(i)+".npy" + + logger.info('Saving train set feature to: %s ...' % train_feat_file) + np.save(f"{filename}_{i}.npy", output) + logger.info('Saved.') + elif save_format == "pt": + if train_feat_file==None: + train_feat_file = "train.pt" + logger.info('Saving train set feature to: %s ...' % train_feat_file) + torch.save(train_outputs, train_feat_file) + logger.info('Saved.') + + + else: + if save_format=="npy": + filename=train_feat_file.split(".")[0].strip() + for i, output in enumerate(train_outputs): + if train_feat_file==None: + train_feat_file = "train_output_"+str(i)+".npy" + else: + train_feat_file = filename+"_"+str(i)+".npy" + logger.info('Saving train set feature to: %s ...' % train_feat_file) + np.save(f"{filename}_{i}.npy", output) + logger.info('Saved.') + + train_outputs=[torch.from_numpy(train_outputs[0]).float().to(device), train_outputs[1], + torch.from_numpy(train_outputs[2]).float().to(device), train_outputs[3] ] + if save_format == "pkl" : + if train_feat_file==None: + train_feat_file = "trainOptimized.pkl" + logger.info('saving train set feature to: %s ...' % train_feat_file) + with open(train_feat_file, 'wb') as f: + pickle.dump(train_outputs, f) + logger.info('Saved.') + elif save_format == "pt": + if train_feat_file==None: + train_feat_file = "trainOptimized.pt" + logger.info('Saving train set feature to: %s ...' 
% train_feat_file) + torch.save(train_outputs, train_feat_file) + logger.info('Saved.') + + + if args.compare_optimization: + train_output_list.append(train_outputs) + + return train_outputs + +def _load_training_file(train_feat_file, save_format): + if _check_file_exists(train_feat_file, save_format): + if not train_feat_file: + train_feat_file = "trainOptimized."+save_format + else: + save_format=train_feat_file.split(".")[1].strip() + logger.info(f"Save format {save_format}") + + if args.enable_optimization: + + + if save_format== "pkl": + logger.info(f"Loading {train_feat_file}") + with open(train_feat_file, 'rb') as f: + train_outputs = pickle.load(f) + elif save_format == "npy": + train_outputs = [] + i = 0 + if train_feat_file: + train_feat_file=train_feat_file.split("_")[0].strip() + else: + train_feat_file="train_" + + while True: + try: + + logger.info(f"{train_feat_file}_{i}.npy") + train_outputs.append(np.load(f"{train_feat_file}_{i}.npy", allow_pickle=True)) + i += 1 + except FileNotFoundError: + break # Stop when there are no more files to load + train_outputs=[torch.from_numpy(train_outputs[0]).float().to(device), train_outputs[1], + torch.from_numpy(train_outputs[2]).float().to(device), train_outputs[3] ] + elif save_format == "pt": + logger.info(f"Loading {train_feat_file}") + train_outputs = torch.load(train_feat_file) + else: + + if save_format == "pkl": + train_feat_file = "train."+save_format + logger.info(f"Loading {train_feat_file}") + with open(train_feat_file, 'rb') as f: + train_outputs = pickle.load(f) + elif save_format == "npy": + train_outputs = [] + i = 0 + if train_feat_file: + train_feat_file=train_feat_file.split("_")[0].strip() + else: + train_feat_file="train_" + while True: + try: + + logger.info(f"{train_feat_file}_{i}.npy") + train_outputs.append(np.load(f"{train_feat_file}_{i}.npy", allow_pickle=True)) + i += 1 + except FileNotFoundError: + break # Stop when there are no more files to load + + elif save_format == "pt": + 
train_feat_file = "train."+save_format + logger.info(f"Loading {train_feat_file}") + train_outputs = torch.load(train_feat_file) + train_outputs_numpy = [] + if train_outputs[0] is torch.Tensor: + for item in train_outputs: + if isinstance(item, torch.Tensor): + train_outputs_numpy.append(item.cpu().numpy()) # Move to CPU and convert to NumPy + else: + train_outputs_numpy.append(item) + return train_outputs_numpy + return train_outputs + else: + logger.info(f"filename {train_feat_file} have not been found") + + +def _check_file_exists(train_feat_file, save_format): + if train_feat_file == None: + if save_format=="npy": + filename="train_output_0.npy" + elif args.enable_optimization: + filename="trainOptimized."+save_format + else: + filename="train."+save_format + if not os.path.isfile(filename): + logger.info(f"File {filename} does not exist. Unable to load the model") + else: + return os.path.isfile(train_feat_file) + + return os.path.isfile(filename) + + +def main(): + # model files check and download + starttime=time.time() + weight_path, model_path, params = get_params(args.arch) + check_and_download_models(weight_path, model_path, REMOTE_PATH) + + # create net instance + net = ailia.Net(model_path, weight_path, env_id=args.env_id) + + # check input + train_and_infer(net, params) + logger.info('Script finished execution time: '+str(int((time.time()-starttime)*1000))) + + +if __name__ == '__main__': + main()