Skip to content

Commit

Permalink
Inference scripts
Browse files Browse the repository at this point in the history
  • Loading branch information
emathian committed Jan 12, 2024
1 parent a2126fd commit 83adaee
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 1 deletion.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,6 @@ stdout.txt

# network weights
#CheckpointKi67/

# Inference folder
Inference/
17 changes: 16 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,22 @@ python eval_opt_thresholds.py --inputPath test256 --configPath configs/eval_Ki67
- Results are saved in the table specified by `dataname`.
- To optimize the threshold associated with cells detected as negative at the marker, specify `--channel 1`.


## Step 5: Inference
- The `infer.py` script is used to run the model in inference mode.
- Command line
```
python infer.py --inputPath KI67_Tiles_256_256_40x/ --configPath configs/eval_Ki67_LNEN.json --outputPath Inference --save_numpy --visualization
```
- The `inputPath` folder must have the following structure:
- `inputPath`
- patient_ID
- accept
- patient_id_tiles_x_y.jpg
- The inference script saves the inference results for each tile in three formats:
- `json` file (default)
- `npy` numpy matrix if the `--save_numpy` argument is specified
- `jpg` annotated image if the `--visualization` argument is specified
- The `outputPath` folder will have the same organization as the `inputPath` folder.

## TO DO LIST

Expand Down
73 changes: 73 additions & 0 deletions infer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
from pipeline import Pipeline
from config import Config
import os
import glob
import imageio
import cv2
import argparse
import numpy as np
import json
import warnings
import pandas as pd
warnings.filterwarnings("ignore", category=DeprecationWarning)
def get_parser():

parser = argparse.ArgumentParser('Inference')
parser.add_argument('--inputPath', '-i', required=True, help="Path to the test set")
parser.add_argument('--outputPath', '-o', required=True, help="Path to the folder where the results will be saved")
parser.add_argument('--configPath', '-c', required=True, help="Path to the json config file")
parser.add_argument('--save_numpy', action='store_true', help="Save the inference results for each tile in npy format")
parser.add_argument('--visualization', action='store_true', help="Save the tile annotated in jpg format for visualization")
return parser

def visualizer(img,points):
r=1
colors=[
(255,0,0),
(0,255,0),
(0,0,255)
]
image=np.copy(img)
for p in points:
x,y,c=p[0],p[1],p[2]
cv2.circle(image, (int(x), int(y)), int(r), colors[int(c)], 2)
return image


def infer(args=None):
parser = get_parser()
args = parser.parse_args(args)
conf=Config()
conf.load(args.configPath)
pipeline=Pipeline(conf)

if os.path.isdir(args.inputPath):
#print('args.inputPath = ', args.inputPath, os.listdir(args.inputPath))
patient_ids_list = os.listdir(args.inputPath)
for patient_id in patient_ids_list:
print('patient_id' , patient_id)
os.makedirs(os.path.join(args.outputPath, patient_id), exist_ok = True)
os.makedirs(os.path.join(args.outputPath, patient_id, 'accept'), exist_ok = True)
data = [ os.path.join(args.inputPath, patient_id, 'accept', f) for f in os.listdir(os.path.join(args.inputPath, patient_id, 'accept')) if '.jpg' in f]
for d in data:
img=imageio.imread(d)
pred_cell, pred_neg, pred_pos =pipeline.predict(img, exhaustive=True)
if args.save_numpy:
np.save(os.path.join(args.outputPath, patient_id, 'accept',d.split("/")[-1][:-4] + "_pred_neg_mask.npy"), pred_neg)
np.save(os.path.join(args.outputPath, patient_id, 'accept',d.split("/")[-1][:-4] + "_pred_pos_mask.npy"), pred_pos)
if args.visualization:
output=visualizer(img,pred_cell)
imageio.imwrite(os.path.join(args.outputPath, patient_id, 'accept',d.split("/")[-1][:-4] + "viz.jpg"),output)
list_cells_for_json = []
for cell in pred_cell:
c_cell= {}
c_cell['x'] = cell[0]
c_cell['y'] = cell[1]
c_cell['label_id'] = int(cell[2]) + 1
list_cells_for_json.append(c_cell)
json_pred_fname = os.path.join(args.outputPath, patient_id, 'accept', d.split("/")[-1][:-3]+'json')
with open(json_pred_fname, 'w') as f:
json.dump(list_cells_for_json, f)

if __name__ == "__main__":
infer()

0 comments on commit 83adaee

Please sign in to comment.