Skip to content

Commit

Permalink
less hardcoded parameters and split dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
VCasecnikovs committed Dec 15, 2020
1 parent a5e9716 commit bcfd1b5
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 29 deletions.
38 changes: 13 additions & 25 deletions SplitDataset.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -12,7 +12,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -21,27 +21,22 @@
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"all_labels = []\n",
"\n",
"for date_path in Path(\"labels\").iterdir():\n",
"for date_path in Path(\"images\").iterdir():\n",
" for augs_path in date_path.iterdir():\n",
" with open(augs_path, 'r') as f:\n",
" labels = f.readlines()\n",
" for label in labels:\n",
" img_name = label.strip().split(\"/\")[-1]\n",
" splitted_path = str(augs_path).split(\"\\\\\")\n",
" splitted_path[0] = \"images\"\n",
" splitted_path.append(img_name)\n",
" all_labels.append(\"/\".join(splitted_path))"
" all_labels.append(str(augs_path)) "
]
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -50,7 +45,7 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -59,7 +54,7 @@
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -69,7 +64,7 @@
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -79,20 +74,13 @@
},
{
"cell_type": "code",
"execution_count": 18,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"with open(\"valid.txt\", \"w\") as vf:\n",
" vf.write(\"\\n\".join(valid_labels))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand Down
5 changes: 4 additions & 1 deletion dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def __init__(self, list_path, img_dir="images", labels_dir="labels", img_extens
path = path.replace(img_dir, labels_dir)
for ext in img_extensions:
path = path.replace(ext, ".txt")

self.label_files.append(path)

self.img_size = img_size
Expand Down Expand Up @@ -47,7 +48,9 @@ def __getitem__(self, index):

if os.path.exists(label_path):
boxes = torch.from_numpy(np.loadtxt(label_path).reshape(-1, 5))

else:
print(label_path)

# RESIZING
if width > height:
ratio = height/width
Expand Down
6 changes: 3 additions & 3 deletions pl_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@ def __init__(self, hparams):

self.hparams = hparams

self.train_ds = ListDataset(hparams.train_ds, train=True)
self.valid_ds = ListDataset(hparams.valid_ds, train=False)
self.train_ds = ListDataset(hparams.train_ds, train=True, img_extensions=hparams.img_extensions)
self.valid_ds = ListDataset(hparams.valid_ds, train=False, img_extensions=hparams.img_extensions)

self.model = YOLOv4(n_classes=5,
self.model = YOLOv4(n_classes=hparams.n_classes,
pretrained=hparams.pretrained,
dropblock=hparams.Dropblock,
sam=hparams.SAM,
Expand Down

0 comments on commit bcfd1b5

Please sign in to comment.