Skip to content

Commit

Permalink
fixed training script (#383)
Browse files Browse the repository at this point in the history
Co-authored-by: Alexander Golodkov <[email protected]>
  • Loading branch information
alexander1999-hub and Alexander Golodkov authored Dec 13, 2023
1 parent f3ec0e5 commit 72d27f7
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions dedoc/scripts/train/train_acc_orientation_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@

parser = argparse.ArgumentParser()
checkpoint_path_save = os.path.abspath(os.path.join(os.path.dirname(__file__),
"../../resources/efficient_net_b0_fixed.pth"))
"../../../resources/efficient_net_b0_fixed.pth"))
checkpoint_path_load = os.path.abspath(os.path.join(os.path.dirname(__file__),
"../../../resources/efficient_net_b0_fixed_tmp.pth"))
"../../../resources/efficient_net_b0_fixed.pth"))
checkpoint_path = "../../../resources"

parser.add_argument("-t", "--train", type=bool, help="run for train model", default=False)
Expand All @@ -39,6 +39,7 @@ def accuracy_step(data_executor: DataLoaderImageOrient, net_executor: ColumnsOri
:param net_executor: Classifier
:return:
"""
net_executor.net.eval()
testloader = data_executor.load_dataset(
csv_path=os.path.join(args.input_data_folder, 'test/labels.csv'),
image_path=args.input_data_folder,
Expand Down Expand Up @@ -157,6 +158,7 @@ def train_model(trainloader: DataLoader,


def train_step(data_executor: DataLoaderImageOrient, classifier: ColumnsOrientationClassifier) -> None:
classifier.net.train()
# Part 1 - load datas
trainloader = data_executor.load_dataset(
csv_path=os.path.join(args.input_data_folder, 'train/labels.csv'),
Expand All @@ -178,12 +180,11 @@ def train_step(data_executor: DataLoaderImageOrient, classifier: ColumnsOrientat


if __name__ == "__main__":
from config import _config as config
from dedoc.config import _config as config
data_executor = DataLoaderImageOrient()
net = ColumnsOrientationClassifier(on_gpu=True,
checkpoint_path=checkpoint_path if not args.train else None,
checkpoint_path=checkpoint_path if not args.train else '',
config=config)

if args.train:
train_step(data_executor, net)
else:
Expand Down

0 comments on commit 72d27f7

Please sign in to comment.