Skip to content

Commit

Permalink
added inference script
Browse files Browse the repository at this point in the history
  • Loading branch information
anikde committed Oct 7, 2024
1 parent a149307 commit b5cf34c
Show file tree
Hide file tree
Showing 2 changed files with 147 additions and 0 deletions.
68 changes: 68 additions & 0 deletions YOLO/TableDetection/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# YOLO RapidAI Table Detection

This repository contains scripts and resources for converting annotation data from JSON format to YOLO format, training a YOLO model for header and footer detection, and visualizing the ground truth annotations.
<!--
## Directory Structure
```
yolo_tabledetecton/
├── check.png # Sample image for checking
├── convert_data_to_yolo.py # Script for converting JSON to YOLO format
├── data/ # Directory for storing datasets
├── main.py # Main script for data conversion and processing
├── rapidAI_data_to_yolo.py # Alternate script for data conversion and data splitting
├── runs/ # Directory for storing training results
├── train.yaml # YAML configuration file for training
├── visualize_gt.py # Script for visualizing ground truth annotations
└── yolov8n.pt # Pre-trained YOLOv8n model
``` -->

## Installation

Create a virtual environment with python >= 3.9 and install the following python packages.
```
pip install ultralytics
pip install opencv-python
pip install pillow
pip install matplotlib
pip install fire
```
<!--
## Converting Data from JSON to YOLO Format
Use the ```convert_data_to_yolo.py``` script to convert your JSON annotations to YOLO format. This script processes a directory of JSON files and saves the converted annotations in the specified output directory.
Usage
```
python convert_data_to_yolo.py json_dir images_dir output_dir
```
Also use ``` python convert_data_to_yolo.py --help ``` to know more.
Use the ``` rapidAI_data_to_yolo.py``` script which converts JSON annotations to YOLO format, splits the dataset into training and testing sets, and saves the resulting files in the specified directory structure.
Usage
```
python rapidAI_data_to_yolo.py json_dir images_dir
```
Also use ``` python rapidAI_data_to_yolo.py --help ``` to know more. -->

## Inference
To get predictions of collection of images use ```predict.py``` script.
```
python predict.py --image_dir "/path/to/images" --model_path "/path/to/model" --save_dir "/path/to/save_dir"
```


<!-- ## Training the YOLO Model
To train the YOLO model using the converted data, use the ```main.py``` script.
Usage
```
python main.py --yaml_file train.yaml --epochs 20 --devices [1] --batch 128
```
Also use ``` python main.py --help ``` to know more.
## Visualizing Ground Truth Annotations
The ```visualize_gt.py``` script allows you to visualize the ground truth annotations on the images. This create a png file named as "check.png" by default. -->
79 changes: 79 additions & 0 deletions YOLO/TableDetection/predict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import fire
from ultralytics import YOLO
import numpy as np
import cv2
import os
import json

def create_label_mapping():
"""Define a mapping for label classes."""
label_mapping = {
0: "Table",
}
return label_mapping

def convert_to_json(filename, cls_list, conf_list, xyxy_list):
"""Convert prediction data to JSON format.
Args:
filename (str): Name of the image file.
cls_list (list): List of class indices.
conf_list (list): List of confidence scores.
xyxy_list (list): List of bounding box coordinates in xyxy format.
Returns:
str: JSON-formatted string of the prediction data.
"""
label_map = create_label_mapping()
json_data = {
"filename" : filename,
"detections": [
{
"cls": int(cls_list[i]),
"cls_name" : label_map.get(int(cls_list[i]), "Unknown label"),
"conf": float(conf_list[i]),
"bbox":
{
"x1": float(xyxy_list[i][0]),
"y1": float(xyxy_list[i][1]),
"x2": float(xyxy_list[i][2]),
"y2": float(xyxy_list[i][3])
},
}
for i in range(len(cls_list))
]
}
return json.dumps(json_data, indent=4)

def save_prediction(image_dir, model_path, save_dir):
"""Run model prediction on images and save the results.
Args:
image_dir (str): Directory containing images to be processed.
model_path (str): Path to the YOLO model weights file.
save_dir (str): Directory to save images with predictions drawn.
"""

if not os.path.exists(save_dir):
os.makedirs(save_dir)

model = YOLO(model_path)
for image in os.listdir(image_dir):
frame = cv2.imread(os.path.join(image_dir, image))
results = model.predict(frame, save=False)

json_string = convert_to_json(image, results[0].boxes.cls, results[0].boxes.conf, results[0].boxes.xyxy)

filename = image.split(".")[0]
os.makedirs("predictions/jsons", exist_ok=True)
with open(f"predictions/jsons/{filename}.json", 'w') as file:
file.write(json_string)

img = results[0].plot(font_size=20, pil=True)
cv2.imwrite(f"{save_dir}/{image}", img)

if __name__ == "__main__":
fire.Fire(save_prediction)


# python your_script.py --image_dir "/path/to/images" --model_path "model/best.pt" --save_dir "/path/to/save_dir"

0 comments on commit b5cf34c

Please sign in to comment.