|
237 | 237 | }
|
238 | 238 | ],
|
239 | 239 | "source": [
|
240 |
| - "class_names = sorted(x for x in os.listdir(data_dir)\n", |
241 |
| - " if os.path.isdir(os.path.join(data_dir, x)))\n", |
| 240 | + "class_names = sorted(x for x in os.listdir(data_dir) if os.path.isdir(os.path.join(data_dir, x)))\n", |
242 | 241 | "num_class = len(class_names)\n",
|
243 | 242 | "image_files = [\n",
|
244 |
| - " [\n", |
245 |
| - " os.path.join(data_dir, class_names[i], x)\n", |
246 |
| - " for x in os.listdir(os.path.join(data_dir, class_names[i]))\n", |
247 |
| - " ]\n", |
| 243 | + " [os.path.join(data_dir, class_names[i], x) for x in os.listdir(os.path.join(data_dir, class_names[i]))]\n", |
248 | 244 | " for i in range(num_class)\n",
|
249 | 245 | "]\n",
|
250 | 246 | "num_each = [len(image_files[i]) for i in range(num_class)]\n",
|
|
341 | 337 | "test_x = [image_files_list[i] for i in test_indices]\n",
|
342 | 338 | "test_y = [image_class[i] for i in test_indices]\n",
|
343 | 339 | "\n",
|
344 |
| - "print(\n", |
345 |
| - " f\"Training count: {len(train_x)}, Validation count: \"\n", |
346 |
| - " f\"{len(val_x)}, Test count: {len(test_x)}\")" |
| 340 | + "print(f\"Training count: {len(train_x)}, Validation count: \" f\"{len(val_x)}, Test count: {len(test_x)}\")" |
347 | 341 | ]
|
348 | 342 | },
|
349 | 343 | {
|
|
370 | 364 | " ]\n",
|
371 | 365 | ")\n",
|
372 | 366 | "\n",
|
373 |
| - "val_transforms = Compose(\n", |
374 |
| - " [LoadImage(image_only=True), EnsureChannelFirst(), ScaleIntensity()])\n", |
| 367 | + "val_transforms = Compose([LoadImage(image_only=True), EnsureChannelFirst(), ScaleIntensity()])\n", |
375 | 368 | "\n",
|
376 | 369 | "y_pred_trans = Compose([Activations(softmax=True)])\n",
|
377 | 370 | "y_trans = Compose([AsDiscrete(to_onehot=num_class)])"
|
|
397 | 390 | "\n",
|
398 | 391 | "\n",
|
399 | 392 | "train_ds = MedNISTDataset(train_x, train_y, train_transforms)\n",
|
400 |
| - "train_loader = DataLoader(\n", |
401 |
| - " train_ds, batch_size=300, shuffle=True, num_workers=10)\n", |
| 393 | + "train_loader = DataLoader(train_ds, batch_size=300, shuffle=True, num_workers=10)\n", |
402 | 394 | "\n",
|
403 | 395 | "val_ds = MedNISTDataset(val_x, val_y, val_transforms)\n",
|
404 |
| - "val_loader = DataLoader(\n", |
405 |
| - " val_ds, batch_size=300, num_workers=10)\n", |
| 396 | + "val_loader = DataLoader(val_ds, batch_size=300, num_workers=10)\n", |
406 | 397 | "\n",
|
407 | 398 | "test_ds = MedNISTDataset(test_x, test_y, val_transforms)\n",
|
408 |
| - "test_loader = DataLoader(\n", |
409 |
| - " test_ds, batch_size=300, num_workers=10)" |
| 399 | + "test_loader = DataLoader(test_ds, batch_size=300, num_workers=10)" |
410 | 400 | ]
|
411 | 401 | },
|
412 | 402 | {
|
|
430 | 420 | "outputs": [],
|
431 | 421 | "source": [
|
432 | 422 | "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
|
433 |
| - "model = DenseNet121(spatial_dims=2, in_channels=1,\n", |
434 |
| - " out_channels=num_class).to(device)\n", |
| 423 | + "model = DenseNet121(spatial_dims=2, in_channels=1, out_channels=num_class).to(device)\n", |
435 | 424 | "loss_function = torch.nn.CrossEntropyLoss()\n",
|
436 | 425 | "optimizer = torch.optim.Adam(model.parameters(), 1e-5)\n",
|
437 | 426 | "max_epochs = 4\n",
|
|
477 | 466 | " loss.backward()\n",
|
478 | 467 | " optimizer.step()\n",
|
479 | 468 | " epoch_loss += loss.item()\n",
|
480 |
| - " print(\n", |
481 |
| - " f\"{step}/{len(train_ds) // train_loader.batch_size}, \"\n", |
482 |
| - " f\"train_loss: {loss.item():.4f}\")\n", |
| 469 | + " print(f\"{step}/{len(train_ds) // train_loader.batch_size}, \" f\"train_loss: {loss.item():.4f}\")\n", |
483 | 470 | " epoch_len = len(train_ds) // train_loader.batch_size\n",
|
484 | 471 | " epoch_loss /= step\n",
|
485 | 472 | " epoch_loss_values.append(epoch_loss)\n",
|
|
509 | 496 | " if result > best_metric:\n",
|
510 | 497 | " best_metric = result\n",
|
511 | 498 | " best_metric_epoch = epoch + 1\n",
|
512 |
| - " torch.save(model.state_dict(), os.path.join(\n", |
513 |
| - " root_dir, \"best_metric_model.pth\"))\n", |
| 499 | + " torch.save(model.state_dict(), os.path.join(root_dir, \"best_metric_model.pth\"))\n", |
514 | 500 | " print(\"saved new best metric model\")\n",
|
515 | 501 | " print(\n",
|
516 | 502 | " f\"current epoch: {epoch + 1} current AUC: {result:.4f}\"\n",
|
|
519 | 505 | " f\" at epoch: {best_metric_epoch}\"\n",
|
520 | 506 | " )\n",
|
521 | 507 | "\n",
|
522 |
| - "print(\n", |
523 |
| - " f\"train completed, best_metric: {best_metric:.4f} \"\n", |
524 |
| - " f\"at epoch: {best_metric_epoch}\")" |
| 508 | + "print(f\"train completed, best_metric: {best_metric:.4f} \" f\"at epoch: {best_metric_epoch}\")" |
525 | 509 | ]
|
526 | 510 | },
|
527 | 511 | {
|
|
581 | 565 | "metadata": {},
|
582 | 566 | "outputs": [],
|
583 | 567 | "source": [
|
584 |
| - "model.load_state_dict(torch.load(\n", |
585 |
| - " os.path.join(root_dir, \"best_metric_model.pth\")))\n", |
| 568 | + "model.load_state_dict(torch.load(os.path.join(root_dir, \"best_metric_model.pth\")))\n", |
586 | 569 | "model.eval()\n",
|
587 | 570 | "y_true = []\n",
|
588 | 571 | "y_pred = []\n",
|
|
626 | 609 | }
|
627 | 610 | ],
|
628 | 611 | "source": [
|
629 |
| - "print(classification_report(\n", |
630 |
| - " y_true, y_pred, target_names=class_names, digits=4))" |
| 612 | + "print(classification_report(y_true, y_pred, target_names=class_names, digits=4))" |
631 | 613 | ]
|
632 | 614 | },
|
633 | 615 | {
|
|
0 commit comments