Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add early stopping support for object detection example #87

Merged
merged 6 commits into from
Jan 29, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 103 additions & 26 deletions object_detection/detectron2_training-kfold.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,16 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from detectron2.engine import DefaultTrainer\n",
"from detectron2.config import get_cfg\n",
"import pickle\n",
"# import some common libraries\n",
"from detectron2.data import build_detection_test_loader, build_detection_train_loader\n",
"import numpy as np\n",
"import os, json, cv2, random\n",
"from detectron2.data import build_detection_test_loader\n",
Expand All @@ -52,12 +55,15 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"!wget -nc \"http://images.cocodataset.org/annotations/annotations_trainval2017.zip\" && unzip -q -o annotations_trainval2017.zip\n",
"!wget -nc \"http://images.cocodataset.org/zips/val2017.zip\" && unzip -q -o val2017.zip\n",
"!wget -nc \"http://images.cocodataset.org/zips/train2017.zip\" && unzip -q -o train2017.zip"
"!wget -nc \"http://images.cocodataset.org/zips/train2017.zip\" && unzip -q -o train2017.zip\n",
"!wget -nc \"https://cleanlab-public.s3.amazonaws.com/ObjectDetectionBenchmarking/tutorial/TRAIN_COCO_ALL_labels.pkl\""
]
},
{
Expand Down Expand Up @@ -92,7 +98,9 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"import json\n",
Expand Down Expand Up @@ -141,13 +149,26 @@
" annotations_count = len(data_dict['annotations'])\n",
" print(f\"Number of images: {images_count}, Number of annotations: {annotations_count}\")\n",
"\n",
" \n",
"def unregister_coco_instances(name):\n",
" if name in DatasetCatalog.list():\n",
" DatasetCatalog.remove(name)\n",
" MetadataCatalog.remove(name)\n",
"\n",
"# Generate K-Fold cross-validation\n",
"kf = KFold(n_splits=NUM_FOLDS)\n",
"pairs = []\n",
"for fold, (train_indices, test_indices) in enumerate(kf.split(image_ids)):\n",
" train_data, test_data = split_data(train_indices, test_indices)\n",
" train_file = f\"train_coco_{fold}_fold.json\"\n",
" test_file = f\"test_coco_{fold}_fold.json\"\n",
" # Unregister instances with the same names only if they exist\n",
" unregister_coco_instances(train_file)\n",
" unregister_coco_instances(test_file)\n",
" # Register COCO instances for training and validation. \n",
" # Note: The 'train2017' folder is retained as the base path for images.\n",
" register_coco_instances(train_file, {}, train_file, \"train2017\")\n",
aditya1503 marked this conversation as resolved.
Show resolved Hide resolved
" register_coco_instances(test_file, {}, test_file, \"train2017\")\n",
" pairs.append([train_file,test_file])\n",
" with open(train_file, 'w') as train_file:\n",
" json.dump(train_data, train_file)\n",
Expand All @@ -156,7 +177,9 @@
" print(f\"Data info for training data fold {fold}:\")\n",
" print_data_info(train_data, fold)\n",
" print(f\"Data info for test data fold {fold}:\")\n",
" print_data_info(test_data, fold)\n"
" print_data_info(test_data, fold)\n",
" \n",
"TRAIN_PATH = os.path.join(os.getcwd(),\"train2017\")"
]
},
{
Expand All @@ -175,36 +198,83 @@
"The number of worker threads is set to 2 and the batch size is set to 2.\n",
"The learning rate and maximum number of iterations are also specified. The model is initialized from the COCO-Detection model zoo and the output directory for the trained model is created. Finally, the configuration is passed to the DefaultTrainer class for training the object detection model.\n",
"\n",
"<strong>Note:</strong> The number of iterations was set based on [early stopping.](https://en.wikipedia.org/wiki/Early_stopping#:~:text=In%20machine%20learning%2C%20early%20stopping,training%20data%20with%20each%20iteration.)"
"<strong>Note:</strong> The number of training iterations is determined by [early stopping](https://en.wikipedia.org/wiki/Early_stopping#:~:text=In%20machine%20learning%2C%20early%20stopping,training%20data%20with%20each%20iteration.). This technique monitors the validation loss throughout training, saves the model whenever the loss improves, and halts training if no improvement is observed within a defined patience period. Early stopping aims to identify the best model iteration and mitigate the risk of overfitting."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
"scrolled": true,
"tags": []
},
"outputs": [],
"source": [
"def train_data(TRAIN,VALIDATION,folder):\n",
"class Early_stopping(DefaultTrainer):\n",
" def __init__(self, cfg, early_stop_patience=5, model_checkpoint_path=\"model_checkpoint.pth\"):\n",
" super().__init__(cfg)\n",
" self.early_stop_patience = early_stop_patience\n",
" self.model_checkpoint_path = model_checkpoint_path\n",
" self.best_validation_loss = float('inf')\n",
" self.current_patience = 0\n",
"\n",
" def build_train_loader(self, cfg):\n",
" return build_detection_train_loader(cfg)\n",
" \n",
" def data_loader_mapper(self, batch):\n",
" return batch\n",
"\n",
" def run_hooks(self):\n",
" val_loss = self.validation()\n",
" if val_loss < self.best_validation_loss:\n",
" self.best_validation_loss = val_loss\n",
" self.current_patience = 0\n",
" self.save_checkpoint()\n",
" else:\n",
" self.current_patience += 1\n",
" if self.current_patience >= self.early_stop_patience:\n",
" self._trainer.save_checkpoint()\n",
" self._trainer.has_finished = True\n",
"\n",
" def validation(self):\n",
" # Define evaluator here\n",
" evaluator = COCOEvaluator(self.cfg.DATASETS.TEST[0], self.cfg, True, output_dir=\"./output/\")\n",
" val_loader = build_detection_test_loader(self.cfg, self.cfg.DATASETS.TEST[0], evaluators=[evaluator])\n",
" val_results = self._trainer.test(self.cfg, self.model, evaluators=[evaluator])[0]\n",
" val_loss = val_results[\"total_loss\"]\n",
" return val_loss\n",
"\n",
" def save_checkpoint(self):\n",
" checkpointer = DetectionCheckpointer(self.model)\n",
" checkpointer.save(self.model_checkpoint_path)\n",
" \n",
"\n",
"def train_model(TRAIN,VALIDATION,folder):\n",
" cfg = get_cfg()\n",
" MODEL = 'faster_rcnn_X_101_32x8d_FPN_3x.yaml'\n",
" cfg.merge_from_file(model_zoo.get_config_file(\"COCO-Detection/\"+MODEL))\n",
" cfg.DATASETS.TRAIN = (TRAIN,)\n",
" cfg.DATASETS.TEST = (VALIDATION,)\n",
" cfg.DATALOADER.NUM_WORKERS = 2\n",
" cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(\"COCO-Detection/\"+MODEL) # Let training initialize from model zoo\n",
" #Uncomment if you want to use pre-trained weights for finetuning, not recommended for K fold training\n",
" # cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(\"COCO-Detection/\"+MODEL) # Let training initialize from model zoo\n",
" \n",
" \n",
" cfg.SOLVER.IMS_PER_BATCH = 2 # This is the real \"batch size\" commonly known to deep learning people\n",
" cfg.SOLVER.BASE_LR = 0.00025 # pick a good LR\n",
" cfg.SOLVER.MAX_ITER = 6000 # \n",
" cfg.SOLVER.BASE_LR = 0.004 # pick a good LR\n",
" cfg.SOLVER.STEPS = [] # milestones where LR is reduced, in this case there's no decay\n",
" cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 128 # The \"RoIHead batch size\". \n",
" cfg.MODEL.ROI_HEADS.NUM_CLASSES = 80 \n",
" cfg.TEST.EVAL_PERIOD = 500\n",
" cfg.TEST.EVAL_PERIOD = 15000\n",
" os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)\n",
" trainer = DefaultTrainer(cfg) \n",
" trainer = Early_stopping(cfg, early_stop_patience=5, model_checkpoint_path=\"model_checkpoint.pth\")\n",
" # Specify evaluators during testing\n",
" evaluator = COCOEvaluator(cfg.DATASETS.TEST[0], cfg, True, output_dir=\"./output/\")\n",
" trainer.resume_or_load(resume=False)\n",
" trainer.test(cfg, trainer.model, evaluators=[evaluator])\n",
" trainer.resume_or_load(resume=False)\n",
" trainer.train();\n"
" trainer.train();\n",
" return cfg\n"
]
},
{
Expand All @@ -224,7 +294,9 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"def format_detectron2_predictions(instances, num_classes):\n",
Expand Down Expand Up @@ -254,7 +326,7 @@
" formatted_results = []\n",
" for i in results:\n",
" if len(i) == 0:\n",
" formatted_array = np.array(i, dtype=np.float32).reshape((0, num_classes))\n",
" formatted_array = np.array(i, dtype=np.float32).reshape((0, 5))\n",
" else:\n",
" formatted_array = np.array(i, dtype=np.float32)\n",
" formatted_results.append(formatted_array)\n",
Expand All @@ -266,46 +338,51 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
"scrolled": true,
"tags": []
},
"outputs": [],
"source": [
"for k in range(0,NUM_FOLDS):\n",
" result_dict = {}\n",
" train_data = pairs[k][0]\n",
" val_data = pairs[k][1]\n",
" train_data(train_data,val_data,\"COCO_TRAIN_\"+str(k)+\"_FOLD\")\n",
" cfg = train_model(train_data,val_data,\"COCO_TRAIN_\"+str(k)+\"_FOLD\")\n",
" evaluator = COCOEvaluator(val_data, output_dir=\"output\")\n",
" val_loader = build_detection_test_loader(cfg, val_data)\n",
" cfg.MODEL.WEIGHTS = os.path.join(cfg.OUTPUT_DIR, \"model_final.pth\") # path to the model we just trained\n",
" cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.1 # set a custom testing threshold\n",
" predictor = DefaultPredictor(cfg)\n",
" dataset = json.load(open(\"../\"+pairs[k][1]+'.json','rb'))\n",
" for image in dat['images']:\n",
" im_name = os.path.join(TRAIN_PATH, i['file_name'])\n",
" dataset = json.load(open(pairs[k][1],'rb'))\n",
" for image in dataset['images']:\n",
" im_name = os.path.join(TRAIN_PATH, image['file_name'])\n",
" im = cv2.imread(im_name)\n",
" outputs = predictor(im)\n",
" result_dict[im_name](format_detectron2_predictions(outputs[\"instances\"].to(\"cpu\"),cfg.MODEL.ROI_HEADS.NUM_CLASSES))\n",
" result_dict[im_name] = (format_detectron2_predictions(outputs[\"instances\"].to(\"cpu\"),cfg.MODEL.ROI_HEADS.NUM_CLASSES))\n",
" pickle.dump(result_dict,open(\"results_fold_\"+str(k)+\".pkl\",'wb'))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"result_dict = {}\n",
"for k in range(0,NUM_FOLDS):\n",
" res_d = pickle.load(open(\"results_fold_\"+str(k)+'.pkl','rb'))\n",
" for r in res_d:\n",
" result_dict[r] = res_d[i]"
" result_dict[r] = res_d[r]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"dataset = pickle.load(open(\"TRAIN_COCO_ALL_labels.pkl\",'rb'))\n",
Expand Down Expand Up @@ -333,7 +410,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.12"
"version": "3.11.5"
}
},
"nbformat": 4,
Expand Down
Loading