diff --git a/Fairness_Contrastive_learning.ipynb b/Fairness_Contrastive_learning.ipynb new file mode 100644 index 0000000..2c0130d --- /dev/null +++ b/Fairness_Contrastive_learning.ipynb @@ -0,0 +1,360 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "provenance": [], + "authorship_tag": "ABX9TyNsf5ThIjJGFtnkwBnuKQeZ", + "include_colab_link": true + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + } + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "view-in-github", + "colab_type": "text" + }, + "source": [ + "\"Open" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "2yQBp_QQFW_Q", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "b4b03da9-5301-4fdf-de8c-19e222cc0008" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Collecting pydicom\n", + " Downloading pydicom-2.4.4-py3-none-any.whl (1.8 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.8/1.8 MB\u001b[0m \u001b[31m8.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hInstalling collected packages: pydicom\n", + "Successfully installed pydicom-2.4.4\n" + ] + } + ], + "source": [ + "import numpy as np\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "from torchvision import models, transforms\n", + "from torchvision.models import resnet50\n", + "from torchvision.transforms.functional import to_pil_image\n", + "from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score\n", + "!pip install pydicom\n", + "import pydicom\n", + "from PIL import Image\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "source": [ + "## DATA" + ], + "metadata": { + "id": "790vZcpAo9Hl" + } + }, + { + "cell_type": "code", + "source": [ + "# @title Data Processing\n", + "\n", + "\n", + "# Define the transformation pipeline\n", + "transform_pipeline = transforms.Compose([\n", + " transforms.Resize(256),\n", + " transforms.CenterCrop(224),\n", + " transforms.RandomRotation(degrees=10),\n", + " transforms.RandomHorizontalFlip(),\n", + " transforms.ToTensor(),\n", + "])\n", + "\n", + "def dicom_to_jpg(dicom_path):\n", + " # Load the DICOM image\n", + " dicom_image = pydicom.dcmread(dicom_path)\n", + " image_array = dicom_image.pixel_array\n", + "\n", + " # Normalize to [0, 255]\n", + " image_array = (np.maximum(image_array, 0) / image_array.max()) * 255.0\n", + "\n", + " # Invert pixels if necessary\n", + " if dicom_image.PhotometricInterpretation == \"MONOCHROME1\":\n", + " image_array = 255.0 - image_array\n", + "\n", + " # Perform histogram equalization\n", + " image_eq = cv2.equalizeHist(image_array.astype(np.uint8))\n", + "\n", + " # Convert to PIL Image\n", + " pil_img = Image.fromarray(image_eq)\n", + "\n", + " # Save as JPG\n", + " pil_img.save(\"output.jpg\", \"JPEG\", quality=95)\n", + "\n", + " return pil_img\n", + "\n", + "def preprocess_and_augment(image_path):\n", + " # Convert DICOM to JPG if it's a DICOM file\n", + " if image_path.endswith('.dcm'):\n", + " image = dicom_to_jpg(image_path)\n", + " else:\n", + " image = Image.open(image_path)\n", + "\n", + " # Apply transformations\n", + " return transform_pipeline(image)\n", + "\n", + "# Example usage:\n", + "# processed_image = preprocess_and_augment('path_to_image.dcm')\n" + ], + "metadata": { + "id": "GHqSgjLxdunW" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# @title Dataloading\n", + "\n", + "\n", + "from torch.utils.data import Dataset\n", + "from torch.utils.data import DataLoader\n", + "import random\n", + "\n", + "class TripletDataset(Dataset):\n", + " def __init__(self, image_paths, labels, transform=None):\n", + " # image_paths: List of paths to images\n", + " # labels: Dictionary mapping image paths to their labels\n", + " self.image_paths = image_paths\n", + " self.labels = labels\n", + " self.transform = transform\n", + " self.labels_set = set(labels.values())\n", + " self.labels_to_indices = {label: np.where(np.array(labels.values()) == label)[0]\n", + " for label in self.labels_set}\n", + "\n", + " def __getitem__(self, index):\n", + " anchor_path = self.image_paths[index]\n", + " anchor_label = self.labels[anchor_path]\n", + "\n", + " # Get a positive sample (same label, different image)\n", + " positive_index = index\n", + " while positive_index == index:\n", + " positive_index = random.choice(self.labels_to_indices[anchor_label])\n", + " positive_path = self.image_paths[positive_index]\n", + "\n", + " # Get a negative sample (different label)\n", + " negative_label = random.choice(list(self.labels_set - set([anchor_label])))\n", + " negative_index = random.choice(self.labels_to_indices[negative_label])\n", + " negative_path = self.image_paths[negative_index]\n", + "\n", + " # Load images and apply transformations\n", + " anchor_img = preprocess_and_augment(anchor_path)\n", + " positive_img = preprocess_and_augment(positive_path)\n", + " negative_img = preprocess_and_augment(negative_path)\n", + "\n", + " return anchor_img, positive_img, negative_img, anchor_label\n", + "\n", + " def __len__(self):\n", + " return len(self.image_paths)\n", + "\n", + "\n" + ], + "metadata": { + "id": "4s80u_dngUVk" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "## MODEL" + ], + "metadata": { + "id": "_jRJ_75yowFj" + } + }, + { + "cell_type": "code", + "source": [ + "# @title Backbone\n", + "# Model Backbone\n", + "class Backbone(nn.Module):\n", + " def __init__(self):\n", + " super(Backbone, self).__init__()\n", + " # Use a pre-trained model without the top layer\n", + " self.base_model = resnet50(pretrained=True)\n", + " self.base_model.fc = nn.Identity() # Remove the top layer\n", + "\n", + " def forward(self, x):\n", + " return self.base_model(x)" + ], + "metadata": { + "cellView": "form", + "id": "sbjnxQNmGHbK" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# @title Contrastive Head\n", + "class ContrastiveHead(nn.Module):\n", + " def __init__(self, feature_dim=1024, embedding_dim=128):\n", + " super(ContrastiveHead, self).__init__()\n", + " self.fc = nn.Linear(feature_dim, embedding_dim)\n", + "\n", + " def forward(self, x):\n", + " return self.fc(x)" + ], + "metadata": { + "cellView": "form", + "id": "Ph-seiEoG2wc" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# @title ContrastiveLoss\n", + "class ContrastiveLoss(nn.Module):\n", + " def __init__(self):\n", + " super(ContrastiveLoss, self).__init__()\n", + " self.temperature = temperature\n", + "\n", + " def forward(self, anchor, positives, negatives):\n", + " #compute the similarities\n", + " anchor_dot_positives = torch.matmul(anchor, positives.t()) / self.temperature\n", + " anchor_dot_negatives = torch.matmul(anchor, negatives.t()) / self.temperature\n", + "\n", + " # Compute the log-sum-exp of negatives for each anchor\n", + " negatives_logsumexp = torch.logsumexp(anchor_dot_negatives, dim=1)\n", + "\n", + " # Sum over all positives for each anchor, and average over all anchors in the batch\n", + " loss = 0\n", + " for i in range(anchor.size(0)):\n", + " for j in range(positives.size(0)):\n", + " loss -= anchor_dot_positives[i][j] - negatives_logsumexp[i]\n", + "\n", + " return loss.mean()\n" + ], + "metadata": { + "id": "Db5zzE2LOK2l" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# @title Projection Head and Classification Head\n", + "class ProjectionHead(nn.Module):\n", + " def __init__(self, embedding_dim, projection_dim):\n", + " super(ProjectionHead, self).__init__()\n", + " self.fc = nn.Linear(embedding_dim, projection_dim)\n", + "\n", + " def forward(self, x):\n", + " return F.relu(self.fc(x))\n", + "\n", + "#\n", + "class ClassificationHead(nn.Module):\n", + " def __init__(self, projection_dim):\n", + " super(ClassificationHead, self).__init__()\n", + " self.fc = nn.Linear(projection_dim, 1)\n", + "\n", + " def forward(self, x):\n", + " return torch.sigmoid(self.fc(x))" + ], + "metadata": { + "id": "649Mvyzy0zhX" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# @title Model Training\n", + "# @markdown computes the loss and updates the model parameters.\n", + "\n", + "# Initialize the resnet50 backbone and the contrastive head\n", + "backbone = Backbone()\n", + "contrastive_head = ContrastiveHead()\n", + "\n", + "# Assuming `image_paths` is a list of image file paths and `labels` is a dict mapping image paths to labels\n", + "triplet_dataset = TripletDataset(image_paths=image_paths, labels=labels, transform=transform_pipeline)\n", + "triplet_dataloader = DataLoader(triplet_dataset, batch_size=32, shuffle=True)\n", + "\n", + "# Adding the projection and classification heads to the model\n", + "projection_head = ProjectionHead(embedding_dim=128, projection_dim=128)\n", + "classification_head = ClassificationHead(projection_dim=128)\n", + "\n", + "# Binary Cross-Entropy Loss for binary classification\n", + "loss_function = nn.BCELoss()\n", + "\n", + "# Assuming 'data_loader' is a PyTorch DataLoader that provides batches of images and labels for the downstream task\n", + "optimizer = torch.optim.Adam(list(backbone.parameters()) + list(projection_head.parameters()) + list(classification_head.parameters()), lr=0.0001)\n", + "\n", + "num_epochs = 10\n", + "# Training loop for the downstream task\n", + "for epoch in range(num_epochs):\n", + " for images, labels in triplet_dataloader:\n", + " optimizer.zero_grad()\n", + " embeddings = contrastive_head(backbone(images))\n", + " projections = projection_head(embeddings)\n", + " predictions = classification_head(projections).squeeze(1)\n", + " loss = loss_function(predictions, labels.float())\n", + " loss.backward()\n", + " optimizer.step()\n", + "\n", + "# Function to compute metrics\n", + "def compute_metrics(y_true, y_pred):\n", + " accuracy = accuracy_score(y_true, y_pred > 0.5)\n", + " precision = precision_score(y_true, y_pred > 0.5)\n", + " recall = recall_score(y_true, y_pred > 0.5)\n", + " f1 = f1_score(y_true, y_pred > 0.5)\n", + " auc = roc_auc_score(y_true, y_pred)\n", + " return accuracy, precision, recall, f1, auc\n", + "\n", + "# Example evaluation on validation set\n", + "# with torch.no_grad():\n", + "# y_true = []\n", + "# y_pred = []\n", + "# for images, labels in validation_loader:\n", + "# embeddings = contrastive_head(backbone(images))\n", + "# projections = projection_head(embeddings)\n", + "# predictions = classification_head(projections).squeeze(1)\n", + "# y_true.extend(labels.numpy())\n", + "# y_pred.extend(predictions.numpy())\n", + "# metrics = compute_metrics(np.array(y_true), np.array(y_pred))\n", + "# print(f\"Accuracy: {metrics[0]}, Precision: {metrics[1]}, Recall: {metrics[2]}, F1: {metrics[3]}, AUC: {metrics[4]}\")\n" + ], + "metadata": { + "id": "9RKZ7MCux2wL" + }, + "execution_count": null, + "outputs": [] + } + ] +} \ No newline at end of file