From e47ee34498967df65e253c7cfef231467c32c078 Mon Sep 17 00:00:00 2001 From: Alexander Golodkov Date: Wed, 13 Dec 2023 16:54:33 +0300 Subject: [PATCH] fixed training script --- .../scripts/train/train_acc_orientation_classifier.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/dedoc/scripts/train/train_acc_orientation_classifier.py b/dedoc/scripts/train/train_acc_orientation_classifier.py index fc43db40..825e5707 100644 --- a/dedoc/scripts/train/train_acc_orientation_classifier.py +++ b/dedoc/scripts/train/train_acc_orientation_classifier.py @@ -13,9 +13,9 @@ parser = argparse.ArgumentParser() checkpoint_path_save = os.path.abspath(os.path.join(os.path.dirname(__file__), - "../../resources/efficient_net_b0_fixed.pth")) + "../../../resources/efficient_net_b0_fixed.pth")) checkpoint_path_load = os.path.abspath(os.path.join(os.path.dirname(__file__), - "../../../resources/efficient_net_b0_fixed_tmp.pth")) + "../../../resources/efficient_net_b0_fixed.pth")) checkpoint_path = "../../../resources" parser.add_argument("-t", "--train", type=bool, help="run for train model", default=False) @@ -39,6 +39,7 @@ def accuracy_step(data_executor: DataLoaderImageOrient, net_executor: ColumnsOri :param net_executor: Classifier :return: """ + net_executor.net.eval() testloader = data_executor.load_dataset( csv_path=os.path.join(args.input_data_folder, 'test/labels.csv'), image_path=args.input_data_folder, @@ -157,6 +158,7 @@ def train_model(trainloader: DataLoader, def train_step(data_executor: DataLoaderImageOrient, classifier: ColumnsOrientationClassifier) -> None: + classifier.net.train() # Part 1 - load datas trainloader = data_executor.load_dataset( csv_path=os.path.join(args.input_data_folder, 'train/labels.csv'), @@ -178,12 +180,11 @@ def train_step(data_executor: DataLoaderImageOrient, classifier: ColumnsOrientat if __name__ == "__main__": - from config import _config as config + from dedoc.config import _config as config data_executor = DataLoaderImageOrient() net = ColumnsOrientationClassifier(on_gpu=True, - checkpoint_path=checkpoint_path if not args.train else None, + checkpoint_path=checkpoint_path if not args.train else '', config=config) - if args.train: train_step(data_executor, net) else: