The dice score is very low even after 100 epochs #5735
acamargofb
started this conversation in
General
Replies: 2 comments
-
Hello Aldo, could you share the validation curve as well? Also, the best metric of 0.0488 at epoch 88 — that is on the validation set, right?
Beta Was this translation helpful? Give feedback.
0 replies
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
-
I am running code based on an example for 2D segmentation, and I am getting a very low Dice score even after 100 epochs. Does somebody know how I can improve this Dice score?
Thanks a lot in advance,
Aldo Camargo
%matplotlib ipympl
import logging
import os
import sys
import tempfile
from glob import glob
import torch
from PIL import Image
from torch.utils.tensorboard import SummaryWriter
import monai
from monai.data import ArrayDataset, create_test_image_2d, decollate_batch, DataLoader
#from monai.data import create_test_image_2d, list_data_collate, decollate_batch, DataLoader
from monai.inferers import sliding_window_inference
from monai.metrics import DiceMetric
from monai.transforms import (
Activations,
EnsureChannelFirstd,
AsDiscrete,
Compose,
LoadImaged,
LoadImage,
RandCropByPosNegLabeld,
CenterSpatialCrop,
RandRotate90,
RandRotate90d,
ScaleIntensityd,
ScaleIntensity,
GaussianSmoothd,
Lambdad,
)
from monai.visualize import plot_2d_or_3d_image
import matplotlib.pyplot as plt
import skimage.io as io
import skimage.color as color
def _convert_to_grayscale(paths, out_subdir):
    """Convert each RGB PNG in *paths* to grayscale and save it under data_dir/out_subdir.

    The original file had this loop copy-pasted twice (once for images,
    once for masks); the helper removes the duplication.
    """
    for path in paths:
        # presumably rgb2gray yields a float image in [0, 1] -- TODO confirm
        # that io.imsave's handling of float input matches the downstream loader.
        fname, _ext = os.path.splitext(os.path.basename(path))
        gray_im = color.rgb2gray(io.imread(path))
        io.imsave(os.path.join(data_dir, f'{out_subdir}/{fname}.png'), gray_im)


# Collect the raw RGB frames and their segmentation masks.
# NOTE(review): tempdir1, tempdir2 and data_dir must be defined earlier in
# the file -- they are not visible in this chunk; verify they exist.
images = sorted(glob(os.path.join(tempdir1, "frame*.png")))
segs = sorted(glob(os.path.join(tempdir2, "frame*.png")))

_convert_to_grayscale(images, 'grayscale/images')
_convert_to_grayscale(segs, 'grayscale/masks')

# Re-point images/segs at the grayscale copies used for training.
images = sorted(glob(os.path.join(data_dir, 'grayscale/images/frame*.png')))
segs = sorted(glob(os.path.join(data_dir, 'grayscale/masks/frame*.png')))
# define transforms for image and segmentation
def _training_transforms():
    """Build a fresh training chain: load -> scale to [0, 1] -> 512x512 center crop
    -> random 90-degree rotation (p=0.5).

    Image and mask each get their own instance; ArrayDataset re-seeds both
    identically per sample so the random rotations stay aligned.
    """
    return Compose(
        [
            LoadImage(image_only=True, ensure_channel_first=True),
            ScaleIntensity(),
            CenterSpatialCrop(roi_size=(512, 512)),  # it was (96,96)
            # RandSpatialCrop((96, 96), random_size=False),
            RandRotate90(prob=0.5, spatial_axes=(0, 1)),
        ]
    )


def _validation_transforms():
    """Deterministic validation chain: load -> scale to [0, 1] (no crop, no augmentation)."""
    return Compose(
        [
            LoadImage(image_only=True, ensure_channel_first=True),
            ScaleIntensity(),
        ]
    )


train_imtrans = _training_transforms()
train_segtrans = _training_transforms()
val_imtrans = _validation_transforms()
val_segtrans = _validation_transforms()
# define array dataset, data loader
# Sanity check: pull one batch from the dataset and print its shapes.
check_ds = ArrayDataset(images, train_imtrans, segs, train_segtrans)
check_loader = DataLoader(check_ds, batch_size=10, num_workers=2, pin_memory=torch.cuda.is_available())
im, seg = monai.utils.misc.first(check_loader)
print(im.shape, seg.shape)

# Training data loader over ALL image/mask pairs.
train_ds = ArrayDataset(images, train_imtrans, segs, train_segtrans)
train_loader = DataLoader(train_ds, batch_size=4, shuffle=True, num_workers=8, pin_memory=torch.cuda.is_available())

# Validation data loader: last 20 pairs of the same list.
# NOTE(review): these 20 pairs are ALSO in train_ds above, so validation
# overlaps the training set -- confirm this is intended, or slice train_ds
# to images[:-20] for a proper hold-out.
val_ds = ArrayDataset(images[-20:], val_imtrans, segs[-20:], val_segtrans)
val_loader = DataLoader(val_ds, batch_size=1, num_workers=4, pin_memory=torch.cuda.is_available())

dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
# Post-processing for validation predictions: sigmoid -> binarize at 0.5.
# (This was defined twice in the original; the duplicate has been removed.)
post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])

# Debug aid: run the training transform on one frame and save the result
# so the grayscale/crop pipeline can be inspected by eye.
frame_transform = train_imtrans(os.path.join(data_dir, 'grayscale/images/frame_000839.png'))
from torchvision.utils import save_image
# tensor = frame_transform.cpu().numpy()  # make sure tensor is on cpu
print(frame_transform.size())
save_image(frame_transform, 'GREY_img.png')
# create UNet, DiceLoss and Adam optimizer
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 2D U-Net: single-channel input and output, five resolution levels with
# 2x downsampling between them, two residual units per level.
unet_channels = (16, 32, 64, 128, 256)
unet_strides = (2, 2, 2, 2)
model = monai.networks.nets.UNet(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    channels=unet_channels,
    strides=unet_strides,
    num_res_units=2,
).to(device)

# Soft Dice loss on sigmoid-activated logits; Adam at learning rate 1e-3.
loss_function = monai.losses.DiceLoss(sigmoid=True)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# start a typical PyTorch training
# Typical PyTorch training loop with periodic validation.
val_interval = 2      # run validation every 2 epochs
max_epochs = 100      # was hard-coded as range(100) plus a misleading "/10" print
best_metric = -1
best_metric_epoch = -1
epoch_loss_values = list()
metric_values = list()
writer = SummaryWriter()
for epoch in range(max_epochs):
    print("-" * 10)
    # fix: the original printed f"epoch {epoch + 1}/{10}" while training 100 epochs
    print(f"epoch {epoch + 1}/{max_epochs}")
    model.train()
    epoch_loss = 0
    step = 0
    # fix: len(train_ds) // batch_size floor-divides and the logs showed
    # "244/243"; len(train_loader) is the true batch count, hoisted out of
    # the inner loop since it is loop-invariant.
    epoch_len = len(train_loader)
    for batch_data in train_loader:
        step += 1
        inputs, labels = batch_data[0].to(device), batch_data[1].to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
        writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step)
    epoch_loss /= step
    epoch_loss_values.append(epoch_loss)
    print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

    # fix: the pasted script never ran validation, so best_metric stayed -1 and
    # dice_metric/post_trans/val_loader were dead code.  This restores the
    # standard validation step whose output format matches the logs below
    # ("current epoch: ... current mean dice: ... best mean dice: ...").
    if (epoch + 1) % val_interval == 0:
        model.eval()
        with torch.no_grad():
            for val_data in val_loader:
                val_images, val_labels = val_data[0].to(device), val_data[1].to(device)
                roi_size = (512, 512)
                sw_batch_size = 4
                val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
                val_outputs = [post_trans(i) for i in decollate_batch(val_outputs)]
                # accumulate per-batch Dice; aggregated after the loop
                dice_metric(y_pred=val_outputs, y=val_labels)
            metric = dice_metric.aggregate().item()
            dice_metric.reset()
            metric_values.append(metric)
            if metric > best_metric:
                best_metric = metric
                best_metric_epoch = epoch + 1
                torch.save(model.state_dict(), "best_metric_model_segmentation2d_array.pth")
                print("saved new best metric model")
            print(
                f"current epoch: {epoch + 1} current mean dice: {metric:.4f}"
                f" best mean dice: {best_metric:.4f} at epoch {best_metric_epoch}"
            )
            writer.add_scalar("val_mean_dice", metric, epoch + 1)
print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")
writer.close()
epoch 1/10
1/243, train_loss: 0.9906
2/243, train_loss: 0.9879
3/243, train_loss: 0.9874
4/243, train_loss: 0.9886
5/243, train_loss: 0.9878
6/243, train_loss: 0.9860
7/243, train_loss: 0.9859
8/243, train_loss: 0.9887
9/243, train_loss: 0.9870
10/243, train_loss: 0.9881
11/243, train_loss: 0.9874
12/243, train_loss: 0.9880
13/243, train_loss: 0.9891
14/243, train_loss: 0.9867
15/243, train_loss: 0.9871
16/243, train_loss: 0.9856
17/243, train_loss: 0.9863
18/243, train_loss: 0.9880
19/243, train_loss: 0.9883
20/243, train_loss: 0.9850
21/243, train_loss: 0.9866
22/243, train_loss: 0.9873
23/243, train_loss: 0.9885
24/243, train_loss: 0.9875
25/243, train_loss: 0.9877
26/243, train_loss: 0.9891
27/243, train_loss: 0.9887
28/243, train_loss: 0.9842
29/243, train_loss: 0.9864
30/243, train_loss: 0.9872
31/243, train_loss: 0.9849
32/243, train_loss: 0.9878
33/243, train_loss: 0.9865
34/243, train_loss: 0.9859
35/243, train_loss: 0.9841
36/243, train_loss: 0.9865
37/243, train_loss: 0.9840
38/243, train_loss: 0.9846
39/243, train_loss: 0.9876
40/243, train_loss: 0.9843
41/243, train_loss: 0.9842
42/243, train_loss: 0.9861
43/243, train_loss: 0.9828
44/243, train_loss: 0.9816
45/243, train_loss: 0.9820
46/243, train_loss: 0.9818
47/243, train_loss: 0.9849
48/243, train_loss: 0.9819
49/243, train_loss: 0.9843
50/243, train_loss: 0.9831
51/243, train_loss: 0.9825
52/243, train_loss: 0.9833
53/243, train_loss: 0.9848
54/243, train_loss: 0.9828
55/243, train_loss: 0.9823
56/243, train_loss: 0.9803
57/243, train_loss: 0.9829
58/243, train_loss: 0.9819
59/243, train_loss: 0.9823
60/243, train_loss: 0.9807
61/243, train_loss: 0.9803
62/243, train_loss: 0.9783
63/243, train_loss: 0.9822
64/243, train_loss: 0.9794
65/243, train_loss: 0.9808
66/243, train_loss: 0.9822
67/243, train_loss: 0.9797
68/243, train_loss: 0.9847
69/243, train_loss: 0.9798
70/243, train_loss: 0.9797
71/243, train_loss: 0.9801
72/243, train_loss: 0.9764
73/243, train_loss: 0.9775
74/243, train_loss: 0.9753
75/243, train_loss: 0.9784
76/243, train_loss: 0.9812
77/243, train_loss: 0.9786
78/243, train_loss: 0.9782
79/243, train_loss: 0.9775
80/243, train_loss: 0.9769
81/243, train_loss: 0.9803
82/243, train_loss: 0.9812
83/243, train_loss: 0.9753
84/243, train_loss: 0.9776
85/243, train_loss: 0.9787
86/243, train_loss: 0.9807
87/243, train_loss: 0.9786
88/243, train_loss: 0.9763
89/243, train_loss: 0.9784
90/243, train_loss: 0.9795
91/243, train_loss: 0.9771
92/243, train_loss: 0.9730
93/243, train_loss: 0.9802
94/243, train_loss: 0.9786
95/243, train_loss: 0.9770
96/243, train_loss: 0.9801
97/243, train_loss: 0.9755
98/243, train_loss: 0.9782
99/243, train_loss: 0.9748
100/243, train_loss: 0.9774
101/243, train_loss: 0.9792
102/243, train_loss: 0.9753
103/243, train_loss: 0.9745
104/243, train_loss: 0.9779
105/243, train_loss: 0.9750
106/243, train_loss: 0.9759
107/243, train_loss: 0.9783
108/243, train_loss: 0.9735
109/243, train_loss: 0.9777
110/243, train_loss: 0.9797
111/243, train_loss: 0.9784
112/243, train_loss: 0.9739
113/243, train_loss: 0.9751
114/243, train_loss: 0.9751
115/243, train_loss: 0.9784
116/243, train_loss: 0.9770
117/243, train_loss: 0.9794
118/243, train_loss: 0.9749
119/243, train_loss: 0.9790
120/243, train_loss: 0.9756
121/243, train_loss: 0.9758
122/243, train_loss: 0.9746
123/243, train_loss: 0.9736
124/243, train_loss: 0.9758
125/243, train_loss: 0.9704
126/243, train_loss: 0.9752
127/243, train_loss: 0.9727
128/243, train_loss: 0.9775
129/243, train_loss: 0.9729
130/243, train_loss: 0.9767
131/243, train_loss: 0.9712
132/243, train_loss: 0.9785
133/243, train_loss: 0.9747
134/243, train_loss: 0.9756
135/243, train_loss: 0.9754
136/243, train_loss: 0.9772
137/243, train_loss: 0.9754
138/243, train_loss: 0.9768
139/243, train_loss: 0.9711
140/243, train_loss: 0.9773
141/243, train_loss: 0.9706
142/243, train_loss: 0.9757
143/243, train_loss: 0.9756
144/243, train_loss: 0.9764
145/243, train_loss: 0.9726
146/243, train_loss: 0.9777
147/243, train_loss: 0.9708
148/243, train_loss: 0.9725
149/243, train_loss: 0.9694
150/243, train_loss: 0.9767
151/243, train_loss: 0.9743
152/243, train_loss: 0.9706
153/243, train_loss: 0.9699
154/243, train_loss: 0.9746
155/243, train_loss: 0.9721
156/243, train_loss: 0.9737
157/243, train_loss: 0.9722
158/243, train_loss: 0.9721
159/243, train_loss: 0.9718
160/243, train_loss: 0.9701
161/243, train_loss: 0.9699
162/243, train_loss: 0.9696
163/243, train_loss: 0.9652
164/243, train_loss: 0.9729
165/243, train_loss: 0.9750
166/243, train_loss: 0.9758
167/243, train_loss: 0.9713
168/243, train_loss: 0.9735
169/243, train_loss: 0.9722
170/243, train_loss: 0.9702
171/243, train_loss: 0.9696
172/243, train_loss: 0.9682
173/243, train_loss: 0.9752
174/243, train_loss: 0.9735
175/243, train_loss: 0.9692
176/243, train_loss: 0.9721
177/243, train_loss: 0.9725
178/243, train_loss: 0.9696
179/243, train_loss: 0.9733
180/243, train_loss: 0.9700
181/243, train_loss: 0.9761
182/243, train_loss: 0.9697
183/243, train_loss: 0.9727
184/243, train_loss: 0.9695
185/243, train_loss: 0.9675
186/243, train_loss: 0.9704
187/243, train_loss: 0.9650
188/243, train_loss: 0.9795
189/243, train_loss: 0.9725
190/243, train_loss: 0.9721
191/243, train_loss: 0.9735
192/243, train_loss: 0.9713
193/243, train_loss: 0.9670
194/243, train_loss: 0.9693
195/243, train_loss: 0.9718
196/243, train_loss: 0.9669
197/243, train_loss: 0.9720
198/243, train_loss: 0.9647
199/243, train_loss: 0.9689
200/243, train_loss: 0.9701
201/243, train_loss: 0.9670
202/243, train_loss: 0.9741
203/243, train_loss: 0.9693
204/243, train_loss: 0.9739
205/243, train_loss: 0.9704
206/243, train_loss: 0.9667
207/243, train_loss: 0.9665
208/243, train_loss: 0.9605
209/243, train_loss: 0.9699
210/243, train_loss: 0.9707
211/243, train_loss: 0.9703
212/243, train_loss: 0.9697
213/243, train_loss: 0.9691
214/243, train_loss: 0.9684
215/243, train_loss: 0.9676
216/243, train_loss: 0.9680
217/243, train_loss: 0.9648
218/243, train_loss: 0.9688
219/243, train_loss: 0.9747
220/243, train_loss: 0.9681
221/243, train_loss: 0.9660
222/243, train_loss: 0.9640
223/243, train_loss: 0.9655
224/243, train_loss: 0.9676
225/243, train_loss: 0.9647
226/243, train_loss: 0.9695
227/243, train_loss: 0.9691
228/243, train_loss: 0.9666
229/243, train_loss: 0.9670
230/243, train_loss: 0.9719
231/243, train_loss: 0.9634
232/243, train_loss: 0.9667
233/243, train_loss: 0.9646
234/243, train_loss: 0.9694
235/243, train_loss: 0.9650
236/243, train_loss: 0.9698
237/243, train_loss: 0.9659
238/243, train_loss: 0.9662
239/243, train_loss: 0.9628
240/243, train_loss: 0.9603
241/243, train_loss: 0.9593
242/243, train_loss: 0.9646
243/243, train_loss: 0.9670
244/243, train_loss: 0.9646
epoch 1 average loss: 0.9761
epoch 2/10
1/243, train_loss: 0.9655
...
...
242/243, train_loss: 0.0166
243/243, train_loss: 0.0136
244/243, train_loss: 0.0027
epoch 100 average loss: 0.0131
current epoch: 100 current mean dice: 0.0123 best mean dice: 0.0488 at epoch 88
train completed, best_metric: 0.0488 at epoch: 88
As you can see, the best Dice score is 0.0488 at epoch 88.
Please any help can be useful, thanks a lot and have a nice day,
Aldo Camargo
Beta Was this translation helpful? Give feedback.
All reactions