diff --git a/.flake8 b/.flake8 index f59d3e5f6..f98e9e526 100644 --- a/.flake8 +++ b/.flake8 @@ -18,6 +18,7 @@ per-file-ignores = tools/infer/text/parallel/base_predict.py:E402 tools/infer/text/parallel/predict_system.py:E402 tools/infer/text/predict_system.py:E402 + tools/infer/text/predict_cls.py:E402 tools/infer/text/predict_rec.py:E402 tools/dataset_converters/convert.py:F401,F403 mindocr/data/transforms/transforms_factory.py:F401,F403 diff --git a/mindocr/data/transforms/rec_transforms.py b/mindocr/data/transforms/rec_transforms.py index dfe697b5f..7be0bdd97 100644 --- a/mindocr/data/transforms/rec_transforms.py +++ b/mindocr/data/transforms/rec_transforms.py @@ -411,14 +411,16 @@ def __call__(self, data): # tar_h, tar_w = self.targt_shape resize_h = self.tar_h - max_wh_ratio = self.tar_w / float(self.tar_h) - if not self.keep_ratio: assert self.tar_w is not None, "Must specify target_width if keep_ratio is False" resize_w = self.tar_w # if self.tar_w is not None else resized_h * self.max_wh_ratio else: src_wh_ratio = w / float(h) - resize_w = math.ceil(min(src_wh_ratio, max_wh_ratio) * resize_h) + if self.tar_w is not None: + max_wh_ratio = self.tar_w / float(self.tar_h) + resize_w = math.ceil(min(src_wh_ratio, max_wh_ratio) * resize_h) + else: + resize_w = math.ceil(src_wh_ratio * resize_h) resized_img = cv2.resize(img, (resize_w, resize_h), interpolation=self.interpolation) # TODO: norm before padding diff --git a/tests/st/test_online_infer.py b/tests/st/test_online_infer.py index 1e17cb383..5b92e88b4 100644 --- a/tests/st/test_online_infer.py +++ b/tests/st/test_online_infer.py @@ -42,6 +42,7 @@ def _gen_text_image(texts=TEXTS_2, boxes=BOXES_2, save_fp="gen_img.jpg"): det_img_fp = _gen_text_image(save_fp="gen_det_input.jpg") rec_img_fp = _gen_text_image([TEXTS_2[0]], [BOXES_2[0]], "gen_rec_input.jpg") +cls_img_fp = rec_img_fp def test_det_infer(): @@ -55,6 +56,17 @@ def test_det_infer(): assert ret == 0, "Det inference fails" +def test_cls_infer(): + algo 
= "MV3" + cmd = ( + f"python tools/infer/text/predict_cls.py --image_dir {cls_img_fp} --cls_algorithm {algo} " + f"--draw_img_save_dir ./infer_test" + ) + print(f"Running command: \n{cmd}") + ret = subprocess.call(cmd.split(), stdout=sys.stdout, stderr=sys.stderr) + assert ret == 0, "Cls inference fails" + + def test_rec_infer(): algo = "CRNN" cmd = ( @@ -68,10 +80,12 @@ def test_rec_infer(): def test_system_infer(): det_algo = "DB" + cls_algo = "MV3" rec_algo = "CRNN_CH" cmd = ( f"python tools/infer/text/predict_system.py --image_dir {det_img_fp} --det_algorithm {det_algo} " - f"--rec_algorithm {rec_algo} --draw_img_save_dir ./infer_test --visualize_output True" + f"--cls_algorithm {cls_algo} --rec_algorithm {rec_algo} " + f"--draw_img_save_dir ./infer_test --visualize_output True" ) print(f"Running command: \n{cmd}") ret = subprocess.call(cmd.split(), stdout=sys.stdout, stderr=sys.stderr) diff --git a/tools/infer/text/config.py b/tools/infer/text/config.py index 4076cd4f6..43c1f4109 100644 --- a/tools/infer/text/config.py +++ b/tools/infer/text/config.py @@ -37,6 +37,7 @@ def create_parser(): # parser.add_argument("--gpu_id", type=int, default=0) parser.add_argument("--det_model_config", type=str, help="path to det model yaml config") # added + parser.add_argument("--cls_model_config", type=str, help="path to cls model yaml config") # added parser.add_argument("--rec_model_config", type=str, help="path to rec model yaml config") # added # params for text detector @@ -90,6 +91,39 @@ def create_parser(): parser.add_argument("--use_dilation", type=str2bool, default=False) parser.add_argument("--det_db_score_mode", type=str, default="fast") + # params for text direction classification + parser.add_argument( + "--cls_model_dir", + type=str, + help="directory containing the text direction classification model checkpoint best.ckpt, " + "or path to a specific checkpoint file.", + ) # determine the network weights + parser.add_argument( + "--cls_batch_mode", + 
type=str2bool, + default=True, + help="Whether to run text direction classification inference in batch-mode, " + "which is faster but may degrade the accraucy due to padding or resizing to the same shape.", + ) # added + parser.add_argument("--cls_batch_num", type=int, default=8) + parser.add_argument("--cls_algorithm", type=str, choices=["MV3"]) + parser.add_argument( + "--cls_rotate_thre", + type=float, + default=0.9, + help="Rotate the image when text direction classification score is larger than this threshold.", + ) + parser.add_argument( + "--cls_image_shape", + type=str, + default="3, 48, 192", + help="C, H, W for taget image shape. max_wh_ratio=W/H will be used to control the maximum width " + "after 'aspect-ratio-kept' resizing. Set W larger for longer text.", + ) + parser.add_argument( + "--cls_label_list", type=str, nargs="+", default=["0", "180"], choices=["0", "90", "180", "270"] + ) # argparse checks each nargs="+" value against choices, so single labels are listed + # params for text recognizer parser.add_argument( "--rec_algorithm", diff --git a/tools/infer/text/postprocess.py b/tools/infer/text/postprocess.py index d2524c47f..552bed06a 100644 --- a/tools/infer/text/postprocess.py +++ b/tools/infer/text/postprocess.py @@ -35,7 +35,7 @@ def __init__(self, task="det", algo="DB", **kwargs): raise ValueError(f"No postprocess config defined for {algo}. Please check the algorithm name.") self.rescale_internally = True self.round = True - elif task == "rec": + elif task in ("rec", "cls"): # TODO: update character_dict_path and use_space_char after CRNN trained using en_dict.txt released if algo.startswith("CRNN") or algo.startswith("SVTR"): # TODO: allow users to input char dict path @@ -52,7 +52,10 @@ def __init__(self, task="det", algo="DB", **kwargs): character_dict_path=dict_path, use_space_char=False, ) - + elif algo.startswith("MV"): + postproc_cfg = dict( + name="ClsPostprocess", + ) else: raise ValueError(f"No postprocess config defined for {algo}. 
Please check the algorithm name.") @@ -108,6 +111,6 @@ def __call__(self, pred, data=None): det_res = dict(polys=polys, scores=scores) return det_res - elif self.task == "rec": + elif self.task in ("rec", "cls"): output = self.postprocess(pred) return output diff --git a/tools/infer/text/predict_cls.py b/tools/infer/text/predict_cls.py new file mode 100644 index 000000000..576cbdc61 --- /dev/null +++ b/tools/infer/text/predict_cls.py @@ -0,0 +1,266 @@ +""" +Text classification inference + +Example: + $ python tools/infer/text/predict_cls.py --image_dir {img_dir_or_img_path} --cls_algorithm MV3 +""" +import os +import sys +from time import time + +import numpy as np +from config import parse_args +from postprocess import Postprocessor +from preprocess import Preprocessor +from tqdm import tqdm +from utils import get_ckpt_file, get_image_paths + +import mindspore as ms + +__dir__ = os.path.dirname(os.path.abspath(__file__)) +sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../../../"))) + +from mindocr import build_model +from mindocr.utils.logger import Logger +from mindocr.utils.visualize import show_imgs + +# map algorithm name to model name (which can be checked by `mindocr.list_models()`) +# NOTE: Modify it to add new model for inference. 
+algo_to_model_name = {"MV3": "cls_mobilenet_v3_small_100_model"} +_logger = Logger("mindocr") + + +class DirectionClassifier(object): + def __init__(self, args): + self.batch_num = args.cls_batch_num + self.batch_mode = args.cls_batch_mode + self.rotate_thre = args.cls_rotate_thre + self.visualize_output = args.visualize_output + # self.batch_mode = args.cls_batch_mode and (self.batch_num > 1) + _logger.info( + "classify text direction in {} mode {}".format( + "batch" if self.batch_mode else "serial", + "batch_size: " + str(self.batch_num) if self.batch_mode else "", + ) + ) + + # build model for algorithm with pretrained weights or local checkpoint + ckpt_dir = args.cls_model_dir + if ckpt_dir is None: + pretrained = True + ckpt_load_path = None + else: + ckpt_load_path = get_ckpt_file(ckpt_dir) + pretrained = False + assert args.cls_algorithm in algo_to_model_name, f"Invalid cls_algorithm {args.cls_algorithm}. " \ + f"Supported classification algorithms are {list(algo_to_model_name.keys())}." + + model_name = algo_to_model_name[args.cls_algorithm] + self.model = build_model(model_name, pretrained=pretrained, ckpt_load_path=ckpt_load_path) + self.model.set_train(False) + _logger.info( + "Init text direction classification model: {} --> {}. Model weights loaded from {}".format( + args.cls_algorithm, model_name, "pretrained url" if pretrained else ckpt_load_path + ) + ) + + # build preprocess and postprocess + # NOTE: most process hyper-params should be set optimally for the pick algo. + self.preprocess = Preprocessor( + task="cls", + algo=args.cls_algorithm, + cls_image_shape=args.cls_image_shape, + cls_batch_mode=self.batch_mode, + cls_batch_num=self.batch_num, + ) + + # TODO: try GeneratorDataset to wrap preprocess transform on batch for possible speed-up. 
+ # if use_ms_dataset: ds = ms.dataset.GeneratorDataset(wrap_preprocess, ) in run_batchwise + + self.postprocess = Postprocessor(task="cls", algo=args.cls_algorithm, label_list=args.cls_label_list) + + self.vis_dir = args.draw_img_save_dir + os.makedirs(self.vis_dir, exist_ok=True) + + def __call__(self, img_or_path_list: list): + """ + Run text direction classification for input images, in batch or serial mode according to `cls_batch_mode` + + Args: + img_or_path_list: list of str for img path or np.array for RGB image + + Return: + cls_res_all: list of tuple (angle, score), the text direction classification result of each image. + all_rotated_imgs: list of the preprocessed images in (h, w, c) layout. An image is rotated in its + h-w plane by the predicted angle when the angle is not 0 and the score is larger than + `cls_rotate_thre`. + """ + + assert isinstance( + img_or_path_list, list + ), "Input for text direction classification must be list of images or image paths." + _logger.info(f"num images for cls: {len(img_or_path_list)}") + if self.batch_mode: + cls_res_all, all_rotated_imgs = self.run_batchwise(img_or_path_list) + else: + cls_res_all, all_rotated_imgs = [], [] + for i, img_or_path in enumerate(img_or_path_list): + cls_res, rotated_imgs = self.run_single(img_or_path, i) + cls_res_all.append(cls_res) + all_rotated_imgs.extend(rotated_imgs) + + # TODO: add vis and save function + return cls_res_all, all_rotated_imgs + + def rotate(self, img_batch, batch_res): + rotated_img_batch = [] + for i, score in enumerate(batch_res["scores"]): + tmp_img = img_batch[i] + if int(batch_res["angles"][i]) != 0 and score > self.rotate_thre: + tmp_img = np.rot90(tmp_img, k=int(int(batch_res["angles"][i]) / 90), axes=(1, 2)) + _logger.info(f"After text direction classification, image is rotated {batch_res['angles'][i]} degree.") + rotated_img_batch.append(tmp_img.transpose(1, 2, 0)) # c, h, w --> h, w, c for saving and visualization + return rotated_img_batch + + def run_batchwise(self, img_or_path_list: list): + """ + Run text direction classification batch by batch for input images + + Args: + img_or_path_list: list of str for 
img path or np.array for RGB image + + Return: + cls_res: list of tuple, where each tuple is (angle, score) - text direction classification result for + each input image in order. where angle is the predicted angle label, score is its confidence score. + all_rotated_imgs: list of the preprocessed images, in (h, w, c) layout, returned by `rotate`. + """ + cls_res, all_rotated_imgs = [], [] + num_imgs = len(img_or_path_list) + + for idx in tqdm(range(0, num_imgs, self.batch_num)): # batch begin index i + batch_begin = idx + batch_end = min(idx + self.batch_num, num_imgs) + # print(f"Cls img idx range: [{batch_begin}, {batch_end})") + # TODO: set max_wh_ratio to the maximum wh ratio of images in the batch. and update it for resize, + # TODO: which may improve text direction classification accuracy in batch-mode + # TODO: especially for long text image. max_wh_ratio=max(max_wh_ratio, img_w / img_h). + # TODO: The short ones should be scaled with a.r. unchanged and padded to max width in batch. + + # preprocess + # TODO: run in parallel with multiprocessing + img_batch = [] + for j in range(batch_begin, batch_end): # image index j + data = self.preprocess(img_or_path_list[j]) + img_batch.append(data["image"]) # c,h,w + if self.visualize_output: + fn = os.path.basename(data.get("img_path", f"crop_{j}.png")).split(".")[0] + show_imgs( + [data["image"]], + title=fn + "_cls_preprocessed", + mean_rgb=[127.0, 127.0, 127.0], + std_rgb=[127.0, 127.0, 127.0], + is_chw=True, + show=False, + save_path=os.path.join(self.vis_dir, fn + "_cls_preproc.png"), + ) + + img_batch = np.stack(img_batch) if len(img_batch) > 1 else np.expand_dims(img_batch[0], axis=0) + # infer + net_pred = self.model(ms.Tensor(img_batch)) + + # postprocess + batch_res = self.postprocess(net_pred) + img_batch = self.rotate(img_batch, batch_res) + + cls_res.extend(list(zip(batch_res["angles"], batch_res["scores"]))) + all_rotated_imgs.extend(img_batch) + + return cls_res, all_rotated_imgs + + def run_single(self, img_or_path, crop_idx=0): + """ + Text direction 
classification inference on a single image + Args: + img_or_path: str for image path or np.array for image rgb value + + Return: + tuple (cls_res, rotated_imgs): + - cls_res (tuple): (angle, score) predicted for the image + - rotated_imgs (list): the preprocessed image returned by `rotate`, in (h, w, c) layout + """ + # preprocess + data = self.preprocess(img_or_path) + + # visualize preprocess result + if self.visualize_output: + # show_imgs([data['image_ori']], is_bgr_img=False, title=f'origin_{i}') + fn = os.path.basename(data.get("img_path", f"crop_{crop_idx}.png")).split(".")[0] + show_imgs( + [data["image"]], + title=fn + "_cls_preprocessed", + mean_rgb=[127.0, 127.0, 127.0], + std_rgb=[127.0, 127.0, 127.0], + is_chw=True, + show=False, + save_path=os.path.join(self.vis_dir, fn + "_cls_preproc.png"), + ) + print("Origin image shape: ", data["image_ori"].shape) + print("Preprocessed image shape: ", data["image"].shape) + + # infer + input_np = data["image"] + # add the batch dim to a c, h, w image; an input that already has it is used as is + net_input = np.expand_dims(input_np, axis=0) if len(input_np.shape) == 3 else input_np + + net_output = self.model(ms.Tensor(net_input)) + + # postprocess + cls_res = self.postprocess(net_output) + net_input_rot = self.rotate(net_input, cls_res) + cls_res = (cls_res["angles"][0], cls_res["scores"][0]) + + print(f"Crop {crop_idx} cls result:", cls_res) + + return cls_res, net_input_rot + + +def save_cls_res(cls_res_all, img_paths, include_score=False, save_path="./cls_results.txt"): + lines = [] + for i, cls_res in enumerate(cls_res_all): + if include_score: + img_pred = os.path.basename(img_paths[i]) + "\t" + str(list(cls_res)) + "\n" + else: + img_pred = os.path.basename(img_paths[i]) + "\t" + str(cls_res[0]) + "\n" + lines.append(img_pred) + + with open(save_path, "w") as f: + f.writelines(lines) + + return lines + + +if __name__ == "__main__": + # parse args + args = parse_args() + save_dir = args.draw_img_save_dir + img_paths = get_image_paths(args.image_dir) + # uncomment it to quick test the infer FPS + # img_paths = img_paths[:250] + + ms.set_context(mode=args.mode) + + # init classifier + 
classifier = DirectionClassifier(args) + + # TODO: warmup + + # run for each image + start = time() + cls_res_all, _ = classifier(img_paths) + t = time() - start + # save all results in a txt file + save_fp = os.path.join(save_dir, "cls_results.txt" if args.cls_batch_mode else "cls_results_serial.txt") + save_cls_res(cls_res_all, img_paths, include_score=True, save_path=save_fp) + # print('All cls res: ', cls_res_all) + print("Done! Text direction classification results saved in ", save_dir) + print("CLS time: ", t, "FPS: ", len(img_paths) / t) diff --git a/tools/infer/text/predict_det.py b/tools/infer/text/predict_det.py index f8c6cdd62..0a1730af9 100644 --- a/tools/infer/text/predict_det.py +++ b/tools/infer/text/predict_det.py @@ -2,7 +2,7 @@ Text detection inference Example: - $ python tools/infer/text/predict_det.py --image_dir {path_to_img} --rec_algorithm DB++ + $ python tools/infer/text/predict_det.py --image_dir {img_dir_or_img_path} --det_algorithm DB++ """ import json @@ -222,7 +222,6 @@ def save_det_res(det_res_all: List[dict], img_paths: List[str], include_score=Fa with open(save_path, "w") as f: f.writelines(lines) - f.close() if __name__ == "__main__": diff --git a/tools/infer/text/predict_rec.py b/tools/infer/text/predict_rec.py index e7251ef40..0da119b6b 100644 --- a/tools/infer/text/predict_rec.py +++ b/tools/infer/text/predict_rec.py @@ -2,8 +2,8 @@ Text recognition inference Example: - $ python tools/infer/text/predict_rec.py --image_dir {path_to_img} --rec_algorithm CRNN - $ python tools/infer/text/predict_rec.py --image_dir {path_to_img} --rec_algorithm CRNN_CH + $ python tools/infer/text/predict_rec.py --image_dir {img_dir_or_img_path} --rec_algorithm CRNN + $ python tools/infer/text/predict_rec.py --image_dir {img_dir_or_img_path} --rec_algorithm CRNN_CH """ import os import sys @@ -246,14 +246,13 @@ def save_rec_res(rec_res_all, img_paths, include_score=False, save_path="./rec_r lines = [] for i, rec_res in enumerate(rec_res_all): if 
include_score: - img_pred = os.path.basename(img_paths[i]) + "\t" + rec_res[0] + "\t" + rec_res[1] + "\n" + img_pred = os.path.basename(img_paths[i]) + "\t" + rec_res[0] + "\t" + str(rec_res[1]) + "\n" else: img_pred = os.path.basename(img_paths[i]) + "\t" + rec_res[0] + "\n" lines.append(img_pred) with open(save_path, "w") as f: f.writelines(lines) - f.close() return lines diff --git a/tools/infer/text/predict_system.py b/tools/infer/text/predict_system.py index 3a76e6d4c..4479b89cf 100644 --- a/tools/infer/text/predict_system.py +++ b/tools/infer/text/predict_system.py @@ -2,10 +2,12 @@ Text detection and recognition inference Example: - $ python tools/infer/text/predict_system.py --image_dir {path_to_img_file} --det_algorithm DB++ \ + $ python tools/infer/text/predict_system.py --image_dir {img_dir_or_img_path} --det_algorithm DB++ \ --rec_algorithm CRNN - $ python tools/infer/text/predict_system.py --image_dir {path_to_img_dir} --det_algorithm DB++ \ + $ python tools/infer/text/predict_system.py --image_dir {img_dir_or_img_path} --det_algorithm DB++ \ --rec_algorithm CRNN_CH + $ python tools/infer/text/predict_system.py --image_dir {img_dir_or_img_path} --det_algorithm DB++ \ + --cls_algorithm MV3 --rec_algorithm CRNN_CH --visualize_output True """ import json @@ -17,6 +19,7 @@ import cv2 import numpy as np from config import parse_args +from predict_cls import DirectionClassifier from predict_det import TextDetector from predict_rec import TextRecognizer from utils import crop_text_region, get_image_paths @@ -36,7 +39,10 @@ class TextSystem(object): def __init__(self, args): self.text_detect = TextDetector(args) self.text_recognize = TextRecognizer(args) + if args.cls_algorithm: + self.direction_classify = DirectionClassifier(args) + self.cls_algorithm = args.cls_algorithm self.box_type = args.det_box_type self.drop_score = args.drop_score self.save_crop_res = args.save_crop_res @@ -49,7 +55,7 @@ def __init__(self, args): def __call__(self, img_or_path: 
Union[str, np.ndarray], do_visualize=True): """ - Detect and recognize texts in an image + Detect, (classify direction) and recognize texts in an image Args: img_or_path (str or np.ndarray): path to image or image rgb values as a numpy array @@ -86,6 +92,13 @@ def __call__(self, img_or_path: Union[str, np.ndarray], do_visualize=True): cv2.imwrite(os.path.join(self.crop_res_save_dir, f"{fn}_crop_{i}.jpg"), cropped_img) # show_imgs(crops, is_bgr_img=False) + # classify the text direction of cropped images + if self.cls_algorithm: + cls_start = time() + _, crops = self.direction_classify(crops) + time_profile["cls"] = time() - cls_start + _logger.info(f"\nCls time: {time_profile['cls']}") + # recognize cropped images rs = time() rec_res_all_crops = self.text_recognize(crops, do_visualize=False) @@ -172,7 +185,7 @@ def main(): text_spot(img_paths[0], do_visualize=False) # run - tot_time = {} # {'det': 0, 'rec': 0, 'all': 0} + tot_time = {} # {'det': 0, 'rec': 0, 'cls': 0, 'all': 0} boxes_all, text_scores_all = [], [] for i, img_path in enumerate(img_paths): print(f"\nINFO: Infering [{i+1}/{len(img_paths)}]: ", img_path) diff --git a/tools/infer/text/preprocess.py b/tools/infer/text/preprocess.py index 219559232..ba4cb1ce7 100644 --- a/tools/infer/text/preprocess.py +++ b/tools/infer/text/preprocess.py @@ -54,7 +54,7 @@ def __init__(self, task="det", algo="DB", **kwargs): # TODO: modify the base pipeline for non-DBNet network if needed # if algo == 'DB++': # pipeline[1]['DetResize']['limit_side_len'] = 1152 - elif task == "rec": + elif task in ("rec", "cls"): # defalut value if not claim in optim_hparam DEFAULT_PADDING = True DEFAULT_KEEP_RATIO = True @@ -64,29 +64,35 @@ def __init__(self, task="det", algo="DB", **kwargs): # register optimal hparam for each model optimal_hparam = { - # 'CRNN': dict(target_height=32, target_width=100, padding=True, keep_ratio=True, norm_before_pad=True), + # "CRNN": dict(target_height=32, target_width=100, padding=True, keep_ratio=True, 
norm_before_pad=True), "CRNN": dict(target_height=32, target_width=100, padding=False, keep_ratio=False), - "CRNN_CH": dict(target_height=32, taget_width=320, padding=True, keep_ratio=True), + "CRNN_CH": dict(target_height=32, target_width=320, padding=True, keep_ratio=True), "RARE": dict(target_height=32, target_width=100, padding=False, keep_ratio=False), "RARE_CH": dict(target_height=32, target_width=320, padding=True, keep_ratio=True), "SVTR": dict(target_height=64, target_width=256, padding=False, keep_ratio=False), + "MV3": dict(target_height=48, target_width=192, padding=False, keep_ratio=False), } - # get hparam by combining default value, optimal value, and arg parser value. Prior: optimal value -> - # parser value -> default value - parsed_img_shape = kwargs.get("rec_image_shape", "3, 32, 320").split(",") + # get hparam by combining default value, optimal value, and arg parser value. + # priority: optimal value -> parser value -> default value + if task == "cls": + parsed_img_shape = kwargs.get("cls_image_shape", "3, 48, 192").split(",") + batch_mode = kwargs.get("cls_batch_mode", False) # and (batch_num > 1) + else: + parsed_img_shape = kwargs.get("rec_image_shape", "3, 32, 100").split(",") + batch_mode = kwargs.get("rec_batch_mode", False) # and (batch_num > 1) + parsed_height, parsed_width = int(parsed_img_shape[1]), int(parsed_img_shape[2]) if algo in optimal_hparam: target_height = optimal_hparam[algo]["target_height"] + norm_before_pad = optimal_hparam[algo].get("norm_before_pad", DEFAULT_NORM_BEFORE_PAD) else: target_height = parsed_height - - norm_before_pad = optimal_hparam[algo].get("norm_before_pad", DEFAULT_NORM_BEFORE_PAD) + norm_before_pad = DEFAULT_NORM_BEFORE_PAD # TODO: update max_wh_ratio for each batch # max_wh_ratio = parsed_width / float(parsed_height) # batch_num = kwargs.get('rec_batch_num', 1) - batch_mode = kwargs.get("rec_batch_mode", False) # and (batch_num > 1) if not batch_mode: # For single infer, the optimal choice is to 
resize the image to target height while keeping # aspect ratio, no padding. limit the max width. @@ -106,25 +112,24 @@ def __init__(self, task="det", algo="DB", **kwargs): if (target_height != parsed_height) or (target_width != parsed_width): _logger.warning( - f"`rec_image_shape` {parsed_img_shape[1:]} dose not meet the network input requirement or " - f"is not optimal, which should be [{target_height}, {target_width}] under batch mode = {batch_mode}" + f"`{task}_image_shape` {parsed_img_shape[1:]} dose not meet the network input requirement " + f"or is not optimal, which should be [{target_height}, {target_width}] under " + f"batch mode = {batch_mode}." ) - _logger.info( - f"Pick optimal preprocess hyper-params for rec algo {algo}:\n" - + "\n".join( - [ - f"{k}:\t{str(v)}" - for k, v in dict( - target_height=target_height, - target_width=target_width, - padding=padding, - keep_ratio=keep_ratio, - norm_before_pad=norm_before_pad, - ).items() - ] - ) + hparam = "\n".join( + [ + f"{k}:\t{str(v)}" + for k, v in dict( + target_height=target_height, + target_width=target_width, + padding=padding, + keep_ratio=keep_ratio, + norm_before_pad=norm_before_pad, + ).items() + ] ) + _logger.info(f"Pick optimal preprocess hyper-params for {task} algo {algo}:\n{hparam}") pipeline = [ {"DecodeImage": {"img_mode": "RGB", "keep_ori": True, "to_float32": False}},