Commit

Created using Colaboratory

Devadeut committed Feb 15, 2024
1 parent ec28664 commit 163e1b4
Showing 1 changed file with 360 additions and 0 deletions.
360 changes: 360 additions & 0 deletions Fairness_Contrastive_learning.ipynb
@@ -0,0 +1,360 @@
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"authorship_tag": "ABX9TyNsf5ThIjJGFtnkwBnuKQeZ",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/github/Devadeut/Neural-Networks-Hands-On-Projects/blob/main/Fairness_Contrastive_learning.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "2yQBp_QQFW_Q",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "b4b03da9-5301-4fdf-de8c-19e222cc0008"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Collecting pydicom\n",
" Downloading pydicom-2.4.4-py3-none-any.whl (1.8 MB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.8/1.8 MB\u001b[0m \u001b[31m8.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hInstalling collected packages: pydicom\n",
"Successfully installed pydicom-2.4.4\n"
]
}
],
"source": [
"import numpy as np\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"from torchvision import models, transforms\n",
"from torchvision.models import resnet50\n",
"from torchvision.transforms.functional import to_pil_image\n",
"from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score\n",
"!pip install pydicom\n",
"import pydicom\n",
"from PIL import Image\n",
"\n"
]
},
{
"cell_type": "markdown",
"source": [
"## DATA"
],
"metadata": {
"id": "790vZcpAo9Hl"
}
},
{
"cell_type": "code",
"source": [
"# @title Data Processing\n",
"\n",
"\n",
"# Define the transformation pipeline\n",
"transform_pipeline = transforms.Compose([\n",
" transforms.Resize(256),\n",
" transforms.CenterCrop(224),\n",
" transforms.RandomRotation(degrees=10),\n",
" transforms.RandomHorizontalFlip(),\n",
" transforms.ToTensor(),\n",
"])\n",
"\n",
"def dicom_to_jpg(dicom_path):\n",
" # Load the DICOM image\n",
" dicom_image = pydicom.dcmread(dicom_path)\n",
" image_array = dicom_image.pixel_array\n",
"\n",
" # Normalize to [0, 255]\n",
" image_array = (np.maximum(image_array, 0) / image_array.max()) * 255.0\n",
"\n",
" # Invert pixels if necessary\n",
" if dicom_image.PhotometricInterpretation == \"MONOCHROME1\":\n",
" image_array = 255.0 - image_array\n",
"\n",
" # Perform histogram equalization\n",
" image_eq = cv2.equalizeHist(image_array.astype(np.uint8))\n",
"\n",
" # Convert to PIL Image\n",
" pil_img = Image.fromarray(image_eq)\n",
"\n",
" # Save as JPG\n",
" pil_img.save(\"output.jpg\", \"JPEG\", quality=95)\n",
"\n",
" return pil_img\n",
"\n",
"def preprocess_and_augment(image_path):\n",
" # Convert DICOM to JPG if it's a DICOM file\n",
" if image_path.endswith('.dcm'):\n",
" image = dicom_to_jpg(image_path)\n",
" else:\n",
" image = Image.open(image_path)\n",
"\n",
" # Apply transformations\n",
" return transform_pipeline(image)\n",
"\n",
"# Example usage:\n",
"# processed_image = preprocess_and_augment('path_to_image.dcm')\n"
],
"metadata": {
"id": "GHqSgjLxdunW"
},
"execution_count": null,
"outputs": []
},
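{
"cell_type": "markdown",
"source": [
"A quick sanity check of the pipeline above (a sketch: `sample.dcm` is a hypothetical path, not a file shipped with this notebook; the expected shape follows from the 224x224 center crop and the RGB conversion):"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# Hypothetical usage sketch: substitute a real DICOM or JPG path\n",
"# tensor = preprocess_and_augment('sample.dcm')\n",
"# print(tensor.shape)  # expected: torch.Size([3, 224, 224])\n"
],
"metadata": {},
"execution_count": null,
"outputs": []
},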
{
"cell_type": "code",
"source": [
"# @title Dataloading\n",
"\n",
"\n",
"from torch.utils.data import Dataset\n",
"from torch.utils.data import DataLoader\n",
"import random\n",
"\n",
"class TripletDataset(Dataset):\n",
" def __init__(self, image_paths, labels, transform=None):\n",
" # image_paths: List of paths to images\n",
" # labels: Dictionary mapping image paths to their labels\n",
" self.image_paths = image_paths\n",
" self.labels = labels\n",
" self.transform = transform\n",
" self.labels_set = set(labels.values())\n",
" self.labels_to_indices = {label: np.where(np.array(labels.values()) == label)[0]\n",
" for label in self.labels_set}\n",
"\n",
" def __getitem__(self, index):\n",
" anchor_path = self.image_paths[index]\n",
" anchor_label = self.labels[anchor_path]\n",
"\n",
" # Get a positive sample (same label, different image)\n",
" positive_index = index\n",
" while positive_index == index:\n",
" positive_index = random.choice(self.labels_to_indices[anchor_label])\n",
" positive_path = self.image_paths[positive_index]\n",
"\n",
" # Get a negative sample (different label)\n",
" negative_label = random.choice(list(self.labels_set - set([anchor_label])))\n",
" negative_index = random.choice(self.labels_to_indices[negative_label])\n",
" negative_path = self.image_paths[negative_index]\n",
"\n",
" # Load images and apply transformations\n",
" anchor_img = preprocess_and_augment(anchor_path)\n",
" positive_img = preprocess_and_augment(positive_path)\n",
" negative_img = preprocess_and_augment(negative_path)\n",
"\n",
" return anchor_img, positive_img, negative_img, anchor_label\n",
"\n",
" def __len__(self):\n",
" return len(self.image_paths)\n",
"\n",
"\n"
],
"metadata": {
"id": "4s80u_dngUVk"
},
"execution_count": null,
"outputs": []
},
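{
"cell_type": "markdown",
"source": [
"A minimal sketch of how `TripletDataset` could be wired up; the file names and label values below are hypothetical placeholders, not part of the notebook's data:"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# Hypothetical example: two classes, four images\n",
"# image_paths = ['a1.dcm', 'a2.dcm', 'b1.jpg', 'b2.jpg']\n",
"# labels = {'a1.dcm': 0, 'a2.dcm': 0, 'b1.jpg': 1, 'b2.jpg': 1}\n",
"# dataset = TripletDataset(image_paths, labels, transform=transform_pipeline)\n",
"# anchor, positive, negative, label = dataset[0]\n",
"# print(anchor.shape, label)  # torch.Size([3, 224, 224]) 0\n"
],
"metadata": {},
"execution_count": null,
"outputs": []
},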
{
"cell_type": "markdown",
"source": [
"## MODEL"
],
"metadata": {
"id": "_jRJ_75yowFj"
}
},
{
"cell_type": "code",
"source": [
"# @title Backbone\n",
"# Model Backbone\n",
"class Backbone(nn.Module):\n",
" def __init__(self):\n",
" super(Backbone, self).__init__()\n",
" # Use a pre-trained model without the top layer\n",
" self.base_model = resnet50(pretrained=True)\n",
" self.base_model.fc = nn.Identity() # Remove the top layer\n",
"\n",
" def forward(self, x):\n",
" return self.base_model(x)"
],
"metadata": {
"cellView": "form",
"id": "sbjnxQNmGHbK"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# @title Contrastive Head\n",
"class ContrastiveHead(nn.Module):\n",
" def __init__(self, feature_dim=1024, embedding_dim=128):\n",
" super(ContrastiveHead, self).__init__()\n",
" self.fc = nn.Linear(feature_dim, embedding_dim)\n",
"\n",
" def forward(self, x):\n",
" return self.fc(x)"
],
"metadata": {
"cellView": "form",
"id": "Ph-seiEoG2wc"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# @title ContrastiveLoss\n",
"class ContrastiveLoss(nn.Module):\n",
" def __init__(self):\n",
" super(ContrastiveLoss, self).__init__()\n",
" self.temperature = temperature\n",
"\n",
" def forward(self, anchor, positives, negatives):\n",
" #compute the similarities\n",
" anchor_dot_positives = torch.matmul(anchor, positives.t()) / self.temperature\n",
" anchor_dot_negatives = torch.matmul(anchor, negatives.t()) / self.temperature\n",
"\n",
" # Compute the log-sum-exp of negatives for each anchor\n",
" negatives_logsumexp = torch.logsumexp(anchor_dot_negatives, dim=1)\n",
"\n",
" # Sum over all positives for each anchor, and average over all anchors in the batch\n",
" loss = 0\n",
" for i in range(anchor.size(0)):\n",
" for j in range(positives.size(0)):\n",
" loss -= anchor_dot_positives[i][j] - negatives_logsumexp[i]\n",
"\n",
" return loss.mean()\n"
],
"metadata": {
"id": "Db5zzE2LOK2l"
},
"execution_count": null,
"outputs": []
},
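{
"cell_type": "markdown",
"source": [
"`ContrastiveLoss` is defined above but not exercised in the training loop below. A minimal sketch of how a contrastive pretraining stage could use it, assuming the `backbone`, `contrastive_head`, and `triplet_dataloader` created in the training cell further down (an assumption, not the notebook's own recipe); embeddings are L2-normalized so the scaled dot products behave like cosine similarities:"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# Sketch of a contrastive pretraining stage (commented out; requires data)\n",
"# contrastive_loss = ContrastiveLoss(temperature=0.07)\n",
"# pretrain_optimizer = torch.optim.Adam(\n",
"#     list(backbone.parameters()) + list(contrastive_head.parameters()), lr=1e-4)\n",
"# for anchor_imgs, positive_imgs, negative_imgs, _ in triplet_dataloader:\n",
"#     pretrain_optimizer.zero_grad()\n",
"#     # L2-normalize the embeddings before computing similarities\n",
"#     z_a = F.normalize(contrastive_head(backbone(anchor_imgs)), dim=1)\n",
"#     z_p = F.normalize(contrastive_head(backbone(positive_imgs)), dim=1)\n",
"#     z_n = F.normalize(contrastive_head(backbone(negative_imgs)), dim=1)\n",
"#     loss = contrastive_loss(z_a, z_p, z_n)\n",
"#     loss.backward()\n",
"#     pretrain_optimizer.step()\n"
],
"metadata": {},
"execution_count": null,
"outputs": []
},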
{
"cell_type": "code",
"source": [
"# @title Projection Head and Classification Head\n",
"class ProjectionHead(nn.Module):\n",
" def __init__(self, embedding_dim, projection_dim):\n",
" super(ProjectionHead, self).__init__()\n",
" self.fc = nn.Linear(embedding_dim, projection_dim)\n",
"\n",
" def forward(self, x):\n",
" return F.relu(self.fc(x))\n",
"\n",
"#\n",
"class ClassificationHead(nn.Module):\n",
" def __init__(self, projection_dim):\n",
" super(ClassificationHead, self).__init__()\n",
" self.fc = nn.Linear(projection_dim, 1)\n",
"\n",
" def forward(self, x):\n",
" return torch.sigmoid(self.fc(x))"
],
"metadata": {
"id": "649Mvyzy0zhX"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# @title Model Training\n",
"# @markdown computes the loss and updates the model parameters.\n",
"\n",
"# Initialize the resnet50 backbone and the contrastive head\n",
"backbone = Backbone()\n",
"contrastive_head = ContrastiveHead()\n",
"\n",
"# Assuming `image_paths` is a list of image file paths and `labels` is a dict mapping image paths to labels\n",
"triplet_dataset = TripletDataset(image_paths=image_paths, labels=labels, transform=transform_pipeline)\n",
"triplet_dataloader = DataLoader(triplet_dataset, batch_size=32, shuffle=True)\n",
"\n",
"# Adding the projection and classification heads to the model\n",
"projection_head = ProjectionHead(embedding_dim=128, projection_dim=128)\n",
"classification_head = ClassificationHead(projection_dim=128)\n",
"\n",
"# Binary Cross-Entropy Loss for binary classification\n",
"loss_function = nn.BCELoss()\n",
"\n",
"# Assuming 'data_loader' is a PyTorch DataLoader that provides batches of images and labels for the downstream task\n",
"optimizer = torch.optim.Adam(list(backbone.parameters()) + list(projection_head.parameters()) + list(classification_head.parameters()), lr=0.0001)\n",
"\n",
"num_epochs = 10\n",
"# Training loop for the downstream task\n",
"for epoch in range(num_epochs):\n",
" for images, labels in triplet_dataloader:\n",
" optimizer.zero_grad()\n",
" embeddings = contrastive_head(backbone(images))\n",
" projections = projection_head(embeddings)\n",
" predictions = classification_head(projections).squeeze(1)\n",
" loss = loss_function(predictions, labels.float())\n",
" loss.backward()\n",
" optimizer.step()\n",
"\n",
"# Function to compute metrics\n",
"def compute_metrics(y_true, y_pred):\n",
" accuracy = accuracy_score(y_true, y_pred > 0.5)\n",
" precision = precision_score(y_true, y_pred > 0.5)\n",
" recall = recall_score(y_true, y_pred > 0.5)\n",
" f1 = f1_score(y_true, y_pred > 0.5)\n",
" auc = roc_auc_score(y_true, y_pred)\n",
" return accuracy, precision, recall, f1, auc\n",
"\n",
"# Example evaluation on validation set\n",
"# with torch.no_grad():\n",
"# y_true = []\n",
"# y_pred = []\n",
"# for images, labels in validation_loader:\n",
"# embeddings = contrastive_head(backbone(images))\n",
"# projections = projection_head(embeddings)\n",
"# predictions = classification_head(projections).squeeze(1)\n",
"# y_true.extend(labels.numpy())\n",
"# y_pred.extend(predictions.numpy())\n",
"# metrics = compute_metrics(np.array(y_true), np.array(y_pred))\n",
"# print(f\"Accuracy: {metrics[0]}, Precision: {metrics[1]}, Recall: {metrics[2]}, F1: {metrics[3]}, AUC: {metrics[4]}\")\n"
],
"metadata": {
"id": "9RKZ7MCux2wL"
},
"execution_count": null,
"outputs": []
}
]
}
