diff --git a/notebooks/dna_example_torch.ipynb b/notebooks/dna_example_torch.ipynb new file mode 100644 index 0000000..e22afdc --- /dev/null +++ b/notebooks/dna_example_torch.ipynb @@ -0,0 +1,346 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 21, + "id": "5573a4bc-5463-40d8-a75a-6aaef8fcd06c", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "from torch import nn\n", + "from transformers import BertModel, BertTokenizer\n", + "\n", + "from mlguess.torch.class_losses import relu_evidence" + ] + }, + { + "cell_type": "markdown", + "id": "516cbd75-5da1-435b-b995-320004fa63ca", + "metadata": {}, + "source": [ + "### Example usage for K-class problem" + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "id": "b2c4e448-b5ec-4355-bf8f-5cc300cc8bd4", + "metadata": {}, + "outputs": [], + "source": [ + "class DNABert(nn.Module):\n", + " def __init__(self, n_classes):\n", + " super(DNABert, self).__init__()\n", + " self.n_classes = n_classes\n", + " self.bert = BertModel.from_pretrained('bert-base-uncased')\n", + " self.fc = nn.Linear(self.bert.config.hidden_size, n_classes)\n", + "\n", + " def forward(self, input_ids, attention_mask, token_type_ids=None):\n", + " outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)\n", + " \n", + " # note how we only take one hidden state from the sequeunce, which corresponds with the CLS token\n", + " cls_hidden_state = outputs.last_hidden_state[:, 0, :]\n", + " \n", + " out = self.fc(cls_hidden_state)\n", + " return out\n", + " \n", + " def predict_uncertainty(self, input_ids, attention_mask, token_type_ids=None):\n", + " y_pred = self(input_ids, attention_mask, token_type_ids)\n", + " \n", + " # dempster-shafer theory\n", + " evidence = relu_evidence(outputs) # can also try softplus and exp evidence schemes\n", + " alpha = evidence + 1\n", + " S = torch.sum(alpha, dim=1, keepdim=True)\n", + " u = self.n_classes / S\n", + " prob = alpha / S\n", + " \n", + " # law of total uncertainty \n", + " epistemic = prob * (1 - prob) / (S + 1)\n", + " aleatoric = prob - prob**2 - epistemic\n", + " return prob, u, aleatoric, epistemic" + ] + }, + { + "cell_type": "code", + "execution_count": 55, + "id": "f642e42e-3083-454b-9720-71a78eb04061", + "metadata": {}, + "outputs": [], + "source": [ + "# Initialize the model\n", + "num_classes = 10\n", + "\n", + "model = DNABert(n_classes=num_classes)\n", + "\n", + "dna_sequence = \"AGCTAGCTAGCT\"\n", + "\n", + "# We need to convert the DNA sequence to the format expected by BERT\n", + "tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')\n", + "inputs = tokenizer(dna_sequence, return_tensors='pt')\n", + "\n", + "# Forward pass through the model\n", + "outputs = model(**inputs)" + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "id": "dc73cdee-d1f0-409d-bd01-57bfdc80cf46", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'input_ids': tensor([[ 101, 12943, 25572, 18195, 15900, 6593, 102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1]])}" + ] + }, + "execution_count": 43, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "inputs" + ] + }, + { + "cell_type": "code", + "execution_count": 44, + "id": "5a6e6872-28a6-4ac3-b9d9-504792251f34", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[-0.7508, -0.6081, -0.0026, -0.0115, 0.1004, 0.1924, -0.4315, -0.0052,\n", + " 0.0900, 0.8016]], grad_fn=)" + ] + }, + "execution_count": 44, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "outputs" + ] + }, + { + "cell_type": "code", + "execution_count": 45, + "id": "439bc25b-328f-4cd1-a5aa-bc5f9909da7d", + "metadata": {}, + "outputs": [], + "source": [ + "prob, u, aleatoric, epistemic = model.predict_uncertainty(**inputs)" + ] + }, + { + "cell_type": "code", + "execution_count": 46, + "id": "5d8e5517-a7bb-4c9c-824b-a69e2ecdbca8", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[0.0894, 0.0894, 0.0894, 0.0894, 0.0984, 0.1066, 0.0894, 0.0894, 0.0975,\n", + " 0.1611]], grad_fn=)" + ] + }, + "execution_count": 46, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "prob" + ] + }, + { + "cell_type": "code", + "execution_count": 47, + "id": "a886205b-5f16-4255-a6d0-f947e9c17395", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[0.8941]], grad_fn=)" + ] + }, + "execution_count": 47, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "u" + ] + }, + { + "cell_type": "code", + "execution_count": 48, + "id": "7bf79791-c511-4c6d-960c-b100794181f4", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[0.0747, 0.0747, 0.0747, 0.0747, 0.0814, 0.0874, 0.0747, 0.0747, 0.0807,\n", + " 0.1240]], grad_fn=)" + ] + }, + "execution_count": 48, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "aleatoric" + ] + }, + { + "cell_type": "code", + "execution_count": 50, + "id": "27842698-60a9-419e-82d2-36b5ee341983", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[0.0067, 0.0067, 0.0067, 0.0067, 0.0073, 0.0078, 0.0067, 0.0067, 0.0072,\n", + " 0.0111]], grad_fn=)" + ] + }, + "execution_count": 50, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "epistemic" + ] + }, + { + "cell_type": "markdown", + "id": "e8ba8eec-26df-4c24-84a0-141d1caaa28b", + "metadata": {}, + "source": [ + "### Evidential loss" + ] + }, + { + "cell_type": "code", + "execution_count": 52, + "id": "cffcfe27-be85-4948-8ba0-f10502e2f108", + "metadata": {}, + "outputs": [], + "source": [ + "from mlguess.torch.class_losses import edl_digamma_loss, edl_log_loss, edl_mse_loss" + ] + }, + { + "cell_type": "code", + "execution_count": 59, + "id": "ed550265-225f-4720-af09-1be5a86a8aed", + "metadata": {}, + "outputs": [], + "source": [ + "loss = \"digamma\"\n", + "annealing_coefficient = 10.\n", + "epoch = 0\n", + "device = \"cpu\"" + ] + }, + { + "cell_type": "code", + "execution_count": 54, + "id": "fee00e6d-be82-40d5-a555-ed2417bca50f", + "metadata": {}, + "outputs": [], + "source": [ + "if loss == \"digamma\":\n", + " criterion = edl_digamma_loss\n", + "elif loss == \"log\":\n", + " criterion = edl_log_loss\n", + "elif loss == \"mse\":\n", + " criterion = edl_mse_loss\n", + "else:\n", + " logging.error(\"--uncertainty requires --mse, --log or --digamma.\")" + ] + }, + { + "cell_type": "code", + "execution_count": 60, + "id": "69ccc5a0-a560-40d5-b9c2-6c303c8ba238", + "metadata": {}, + "outputs": [], + "source": [ + "y_true_hot = torch.tensor([1, 0, 0, 0, 0, 0, 0, 0, 0, 0])\n", + "\n", + "loss = criterion(\n", + " outputs,\n", + " y_true_hot.float(), \n", + " epoch, \n", + " num_classes, \n", + " annealing_coefficient, \n", + " device\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 61, + "id": "2b902b09-1047-45d7-a233-2b48ed026481", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor(2.8403, grad_fn=)" + ] + }, + "execution_count": 61, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "loss" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9bb75bb5-a005-4c44-8c4e-65026a13cbfc", + "metadata": {}, + "outputs": [], + "source": [ + "# loss.backward" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python [conda env:miniconda3-evidential]", + "language": "python", + "name": "conda-env-miniconda3-evidential-py" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.16" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}