{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"authorship_tag": "ABX9TyNsf5ThIjJGFtnkwBnuKQeZ",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/github/Devadeut/Neural-Networks-Hands-On-Projects/blob/main/Fairness_Contrastive_learning.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "2yQBp_QQFW_Q",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "b4b03da9-5301-4fdf-de8c-19e222cc0008"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Collecting pydicom\n",
"  Downloading pydicom-2.4.4-py3-none-any.whl (1.8 MB)\n",
"\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.8/1.8 MB\u001b[0m \u001b[31m8.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hInstalling collected packages: pydicom\n",
"Successfully installed pydicom-2.4.4\n"
]
}
],
"source": [ | ||
"import numpy as np\n", | ||
"import torch\n", | ||
"import torch.nn as nn\n", | ||
"import torch.nn.functional as F\n", | ||
"from torchvision import models, transforms\n", | ||
"from torchvision.models import resnet50\n", | ||
"from torchvision.transforms.functional import to_pil_image\n", | ||
"from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score\n", | ||
"!pip install pydicom\n", | ||
"import pydicom\n", | ||
"from PIL import Image\n", | ||
"\n" | ||
] | ||
},
{
"cell_type": "markdown",
"source": [
"## DATA"
],
"metadata": {
"id": "790vZcpAo9Hl"
}
},
{
"cell_type": "code",
"source": [
"# @title Data Processing\n", | ||
"\n", | ||
"\n", | ||
"# Define the transformation pipeline\n", | ||
"transform_pipeline = transforms.Compose([\n", | ||
" transforms.Resize(256),\n", | ||
" transforms.CenterCrop(224),\n", | ||
" transforms.RandomRotation(degrees=10),\n", | ||
" transforms.RandomHorizontalFlip(),\n", | ||
" transforms.ToTensor(),\n", | ||
"])\n", | ||
"\n", | ||
"def dicom_to_jpg(dicom_path):\n", | ||
" # Load the DICOM image\n", | ||
" dicom_image = pydicom.dcmread(dicom_path)\n", | ||
" image_array = dicom_image.pixel_array\n", | ||
"\n", | ||
" # Normalize to [0, 255]\n", | ||
" image_array = (np.maximum(image_array, 0) / image_array.max()) * 255.0\n", | ||
"\n", | ||
" # Invert pixels if necessary\n", | ||
" if dicom_image.PhotometricInterpretation == \"MONOCHROME1\":\n", | ||
" image_array = 255.0 - image_array\n", | ||
"\n", | ||
" # Perform histogram equalization\n", | ||
" image_eq = cv2.equalizeHist(image_array.astype(np.uint8))\n", | ||
"\n", | ||
" # Convert to PIL Image\n", | ||
" pil_img = Image.fromarray(image_eq)\n", | ||
"\n", | ||
" # Save as JPG\n", | ||
" pil_img.save(\"output.jpg\", \"JPEG\", quality=95)\n", | ||
"\n", | ||
" return pil_img\n", | ||
"\n", | ||
"def preprocess_and_augment(image_path):\n", | ||
" # Convert DICOM to JPG if it's a DICOM file\n", | ||
" if image_path.endswith('.dcm'):\n", | ||
" image = dicom_to_jpg(image_path)\n", | ||
" else:\n", | ||
" image = Image.open(image_path)\n", | ||
"\n", | ||
" # Apply transformations\n", | ||
" return transform_pipeline(image)\n", | ||
"\n", | ||
"# Example usage:\n", | ||
"# processed_image = preprocess_and_augment('path_to_image.dcm')\n" | ||
], | ||
"metadata": { | ||
"id": "GHqSgjLxdunW" | ||
}, | ||
"execution_count": null, | ||
"outputs": [] | ||
}, | ||
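{
"cell_type": "code",
"source": [
"# @title Preprocessing sanity check (added sketch)\n",
"# A minimal, hedged check that the pipeline above yields model-ready tensors.\n",
"# It uses a synthetic grayscale PIL image instead of a real DICOM file, so no\n",
"# data paths are assumed; after the RGB conversion in preprocess_and_augment,\n",
"# the expected shape is (3, 224, 224).\n",
"synthetic = Image.fromarray(np.random.randint(0, 256, (512, 512), dtype=np.uint8))\n",
"tensor = transform_pipeline(synthetic.convert('RGB'))\n",
"print(tensor.shape)  # expected: torch.Size([3, 224, 224])\n"
],
"metadata": {
"id": "added_preproc_check"
},
"execution_count": null,
"outputs": []
},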
{
"cell_type": "code",
"source": [
"# @title Dataloading\n",
"\n",
"\n",
"from torch.utils.data import Dataset\n",
"from torch.utils.data import DataLoader\n",
"import random\n",
"\n",
"class TripletDataset(Dataset):\n", | ||
" def __init__(self, image_paths, labels, transform=None):\n", | ||
" # image_paths: List of paths to images\n", | ||
" # labels: Dictionary mapping image paths to their labels\n", | ||
" self.image_paths = image_paths\n", | ||
" self.labels = labels\n", | ||
" self.transform = transform\n", | ||
" self.labels_set = set(labels.values())\n", | ||
" self.labels_to_indices = {label: np.where(np.array(labels.values()) == label)[0]\n", | ||
" for label in self.labels_set}\n", | ||
"\n", | ||
" def __getitem__(self, index):\n", | ||
" anchor_path = self.image_paths[index]\n", | ||
" anchor_label = self.labels[anchor_path]\n", | ||
"\n", | ||
" # Get a positive sample (same label, different image)\n", | ||
" positive_index = index\n", | ||
" while positive_index == index:\n", | ||
" positive_index = random.choice(self.labels_to_indices[anchor_label])\n", | ||
" positive_path = self.image_paths[positive_index]\n", | ||
"\n", | ||
" # Get a negative sample (different label)\n", | ||
" negative_label = random.choice(list(self.labels_set - set([anchor_label])))\n", | ||
" negative_index = random.choice(self.labels_to_indices[negative_label])\n", | ||
" negative_path = self.image_paths[negative_index]\n", | ||
"\n", | ||
" # Load images and apply transformations\n", | ||
" anchor_img = preprocess_and_augment(anchor_path)\n", | ||
" positive_img = preprocess_and_augment(positive_path)\n", | ||
" negative_img = preprocess_and_augment(negative_path)\n", | ||
"\n", | ||
" return anchor_img, positive_img, negative_img, anchor_label\n", | ||
"\n", | ||
" def __len__(self):\n", | ||
" return len(self.image_paths)\n", | ||
"\n", | ||
"\n" | ||
],
"metadata": {
"id": "4s80u_dngUVk"
},
"execution_count": null,
"outputs": []
},
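{
"cell_type": "code",
"source": [
"# @title TripletDataset usage sketch (added)\n",
"# A hedged, commented-out sketch of how the dataset and loader could be wired\n",
"# up. The file names below are hypothetical placeholders, not paths from this\n",
"# project; the only assumptions are that `image_paths` is a list, `labels`\n",
"# maps each path to a binary label, and every label has at least two images.\n",
"# image_paths = ['scan_0.dcm', 'scan_1.dcm', 'scan_2.dcm', 'scan_3.dcm']\n",
"# labels = {'scan_0.dcm': 0, 'scan_1.dcm': 0, 'scan_2.dcm': 1, 'scan_3.dcm': 1}\n",
"# triplet_dataset = TripletDataset(image_paths, labels, transform=transform_pipeline)\n",
"# anchor, positive, negative, label = triplet_dataset[0]  # each image: (3, 224, 224)\n"
],
"metadata": {
"id": "added_dataset_sketch"
},
"execution_count": null,
"outputs": []
},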
{
"cell_type": "markdown",
"source": [
"## MODEL"
],
"metadata": {
"id": "_jRJ_75yowFj"
}
},
{
"cell_type": "code", | ||
"source": [ | ||
"# @title Backbone\n", | ||
"# Model Backbone\n", | ||
"class Backbone(nn.Module):\n", | ||
" def __init__(self):\n", | ||
" super(Backbone, self).__init__()\n", | ||
" # Use a pre-trained model without the top layer\n", | ||
" self.base_model = resnet50(pretrained=True)\n", | ||
" self.base_model.fc = nn.Identity() # Remove the top layer\n", | ||
"\n", | ||
" def forward(self, x):\n", | ||
" return self.base_model(x)" | ||
], | ||
"metadata": { | ||
"cellView": "form", | ||
"id": "sbjnxQNmGHbK" | ||
}, | ||
"execution_count": null, | ||
"outputs": [] | ||
}, | ||
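{
"cell_type": "code",
"source": [
"# @title Backbone output check (added sketch)\n",
"# A hedged smoke test: the headless ResNet-50 should emit 2048-dim feature\n",
"# vectors, which is why ContrastiveHead below defaults to feature_dim=2048.\n",
"# Note this instantiates a Backbone and downloads pretrained weights.\n",
"with torch.no_grad():\n",
"    dummy_batch = torch.randn(2, 3, 224, 224)  # a fake batch of two RGB images\n",
"    print(Backbone()(dummy_batch).shape)  # expected: torch.Size([2, 2048])\n"
],
"metadata": {
"id": "added_backbone_check"
},
"execution_count": null,
"outputs": []
},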
{
"cell_type": "code",
"source": [
"# @title Contrastive Head\n",
"class ContrastiveHead(nn.Module):\n",
"    def __init__(self, feature_dim=2048, embedding_dim=128):\n",
"        super(ContrastiveHead, self).__init__()\n",
"        # feature_dim matches the 2048-dim output of the ResNet-50 backbone\n",
"        self.fc = nn.Linear(feature_dim, embedding_dim)\n",
"\n",
"    def forward(self, x):\n",
"        return self.fc(x)"
],
"metadata": {
"cellView": "form",
"id": "Ph-seiEoG2wc"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# @title ContrastiveLoss\n",
"class ContrastiveLoss(nn.Module):\n",
"    def __init__(self, temperature=0.07):\n",
"        super(ContrastiveLoss, self).__init__()\n",
"        self.temperature = temperature\n",
"\n",
"    def forward(self, anchor, positives, negatives):\n",
"        # Normalize so the dot products below are cosine similarities\n",
"        anchor = F.normalize(anchor, dim=1)\n",
"        positives = F.normalize(positives, dim=1)\n",
"        negatives = F.normalize(negatives, dim=1)\n",
"\n",
"        # Compute the similarities\n",
"        anchor_dot_positives = torch.matmul(anchor, positives.t()) / self.temperature\n",
"        anchor_dot_negatives = torch.matmul(anchor, negatives.t()) / self.temperature\n",
"\n",
"        # Compute the log-sum-exp of negatives for each anchor\n",
"        negatives_logsumexp = torch.logsumexp(anchor_dot_negatives, dim=1)\n",
"\n",
"        # -(positive similarity - log-sum-exp of negatives), averaged over all\n",
"        # anchor-positive pairs in the batch\n",
"        loss = (negatives_logsumexp.unsqueeze(1) - anchor_dot_positives).mean()\n",
"        return loss\n"
],
"metadata": {
"id": "Db5zzE2LOK2l"
},
"execution_count": null,
"outputs": []
},
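{
"cell_type": "code",
"source": [
"# @title ContrastiveLoss sanity check (added sketch)\n",
"# A hedged smoke test on random embeddings: the loss is a scalar, and it\n",
"# should come out lower when each anchor is used as its own positive than\n",
"# with unrelated random positives.\n",
"criterion = ContrastiveLoss(temperature=0.07)\n",
"a = torch.randn(4, 128)\n",
"print(criterion(a, torch.randn(4, 128), torch.randn(8, 128)))  # random positives\n",
"print(criterion(a, a, torch.randn(8, 128)))  # matched positives -> lower loss\n"
],
"metadata": {
"id": "added_loss_check"
},
"execution_count": null,
"outputs": []
},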
{
"cell_type": "code",
"source": [
"# @title Projection Head and Classification Head\n",
"class ProjectionHead(nn.Module):\n",
"    def __init__(self, embedding_dim, projection_dim):\n",
"        super(ProjectionHead, self).__init__()\n",
"        self.fc = nn.Linear(embedding_dim, projection_dim)\n",
"\n",
"    def forward(self, x):\n",
"        return F.relu(self.fc(x))\n",
"\n",
"class ClassificationHead(nn.Module):\n",
"    def __init__(self, projection_dim):\n",
"        super(ClassificationHead, self).__init__()\n",
"        self.fc = nn.Linear(projection_dim, 1)\n",
"\n",
"    def forward(self, x):\n",
"        return torch.sigmoid(self.fc(x))"
],
"metadata": {
"id": "649Mvyzy0zhX"
},
"execution_count": null,
"outputs": []
},
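{
"cell_type": "code",
"source": [
"# @title Contrastive pre-training sketch (added)\n",
"# The training cell below optimizes only the BCE objective and never calls\n",
"# ContrastiveLoss. This hedged, commented-out sketch shows one way the triplet\n",
"# loader could drive a contrastive pre-training stage first; it assumes\n",
"# `backbone`, `contrastive_head`, and `triplet_dataloader` are built exactly\n",
"# as in the next cell.\n",
"# contrastive_loss = ContrastiveLoss(temperature=0.07)\n",
"# pretrain_optimizer = torch.optim.Adam(list(backbone.parameters()) + list(contrastive_head.parameters()), lr=0.0001)\n",
"# for epoch in range(5):\n",
"#     for anchor_imgs, positive_imgs, negative_imgs, _ in triplet_dataloader:\n",
"#         pretrain_optimizer.zero_grad()\n",
"#         a = contrastive_head(backbone(anchor_imgs))\n",
"#         p = contrastive_head(backbone(positive_imgs))\n",
"#         n = contrastive_head(backbone(negative_imgs))\n",
"#         loss = contrastive_loss(a, p, n)\n",
"#         loss.backward()\n",
"#         pretrain_optimizer.step()\n"
],
"metadata": {
"id": "added_pretrain_sketch"
},
"execution_count": null,
"outputs": []
},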
{
"cell_type": "code",
"source": [
"# @title Model Training\n",
"# @markdown computes the loss and updates the model parameters.\n",
"\n",
"# Initialize the resnet50 backbone and the contrastive head\n",
"backbone = Backbone()\n",
"contrastive_head = ContrastiveHead()\n",
"\n",
"# Assuming `image_paths` is a list of image file paths and `labels` is a dict mapping image paths to labels\n",
"triplet_dataset = TripletDataset(image_paths=image_paths, labels=labels, transform=transform_pipeline)\n",
"triplet_dataloader = DataLoader(triplet_dataset, batch_size=32, shuffle=True)\n",
"\n",
"# Adding the projection and classification heads to the model\n",
"projection_head = ProjectionHead(embedding_dim=128, projection_dim=128)\n",
"classification_head = ClassificationHead(projection_dim=128)\n",
"\n",
"# Binary Cross-Entropy Loss for binary classification\n",
"loss_function = nn.BCELoss()\n",
"\n",
"# Optimize the backbone together with all three heads used in the forward pass\n",
"optimizer = torch.optim.Adam(list(backbone.parameters()) + list(contrastive_head.parameters()) + list(projection_head.parameters()) + list(classification_head.parameters()), lr=0.0001)\n",
"\n",
"num_epochs = 10\n",
"# Training loop for the downstream task; the loader yields\n",
"# (anchor, positive, negative, label) and only the anchor is needed here\n",
"for epoch in range(num_epochs):\n",
"    for anchor_imgs, _, _, batch_labels in triplet_dataloader:\n",
"        optimizer.zero_grad()\n",
"        embeddings = contrastive_head(backbone(anchor_imgs))\n",
"        projections = projection_head(embeddings)\n",
"        predictions = classification_head(projections).squeeze(1)\n",
"        loss = loss_function(predictions, batch_labels.float())\n",
"        loss.backward()\n",
"        optimizer.step()\n",
"\n",
"# Function to compute metrics\n",
"def compute_metrics(y_true, y_pred):\n",
"    accuracy = accuracy_score(y_true, y_pred > 0.5)\n",
"    precision = precision_score(y_true, y_pred > 0.5)\n",
"    recall = recall_score(y_true, y_pred > 0.5)\n",
"    f1 = f1_score(y_true, y_pred > 0.5)\n",
"    auc = roc_auc_score(y_true, y_pred)\n",
"    return accuracy, precision, recall, f1, auc\n",
"\n",
"# Example evaluation on a validation set\n",
"# with torch.no_grad():\n",
"#     y_true = []\n",
"#     y_pred = []\n",
"#     for images, val_labels in validation_loader:\n",
"#         embeddings = contrastive_head(backbone(images))\n",
"#         projections = projection_head(embeddings)\n",
"#         predictions = classification_head(projections).squeeze(1)\n",
"#         y_true.extend(val_labels.numpy())\n",
"#         y_pred.extend(predictions.numpy())\n",
"# metrics = compute_metrics(np.array(y_true), np.array(y_pred))\n",
"# print(f\"Accuracy: {metrics[0]}, Precision: {metrics[1]}, Recall: {metrics[2]}, F1: {metrics[3]}, AUC: {metrics[4]}\")\n"
],
"metadata": {
"id": "9RKZ7MCux2wL"
},
"execution_count": null,
"outputs": []
}
] | ||
} |