Skip to content

Commit

Permalink
Evidential DNA BERT encoder example
Browse files Browse the repository at this point in the history
  • Loading branch information
jsschreck committed Jan 3, 2024
1 parent 2832f6e commit 6d80373
Showing 1 changed file with 346 additions and 0 deletions.
346 changes: 346 additions & 0 deletions notebooks/dna_example_torch.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,346 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 21,
"id": "5573a4bc-5463-40d8-a75a-6aaef8fcd06c",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch import nn\n",
"from transformers import BertModel, BertTokenizer\n",
"\n",
"from mlguess.torch.class_losses import relu_evidence"
]
},
{
"cell_type": "markdown",
"id": "516cbd75-5da1-435b-b995-320004fa63ca",
"metadata": {},
"source": [
"### Example usage for K-class problem"
]
},
{
"cell_type": "code",
"execution_count": 41,
"id": "b2c4e448-b5ec-4355-bf8f-5cc300cc8bd4",
"metadata": {},
"outputs": [],
"source": [
"class DNABert(nn.Module):\n",
" def __init__(self, n_classes):\n",
" super(DNABert, self).__init__()\n",
" self.n_classes = n_classes\n",
" self.bert = BertModel.from_pretrained('bert-base-uncased')\n",
" self.fc = nn.Linear(self.bert.config.hidden_size, n_classes)\n",
"\n",
" def forward(self, input_ids, attention_mask, token_type_ids=None):\n",
" outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)\n",
" \n",
" # note how we only take one hidden state from the sequeunce, which corresponds with the CLS token\n",
" cls_hidden_state = outputs.last_hidden_state[:, 0, :]\n",
" \n",
" out = self.fc(cls_hidden_state)\n",
" return out\n",
" \n",
" def predict_uncertainty(self, input_ids, attention_mask, token_type_ids=None):\n",
" y_pred = self(input_ids, attention_mask, token_type_ids)\n",
" \n",
" # dempster-shafer theory\n",
" evidence = relu_evidence(outputs) # can also try softplus and exp evidence schemes\n",
" alpha = evidence + 1\n",
" S = torch.sum(alpha, dim=1, keepdim=True)\n",
" u = self.n_classes / S\n",
" prob = alpha / S\n",
" \n",
" # law of total uncertainty \n",
" epistemic = prob * (1 - prob) / (S + 1)\n",
" aleatoric = prob - prob**2 - epistemic\n",
" return prob, u, aleatoric, epistemic"
]
},
{
"cell_type": "code",
"execution_count": 55,
"id": "f642e42e-3083-454b-9720-71a78eb04061",
"metadata": {},
"outputs": [],
"source": [
"# Initialize the model\n",
"num_classes = 10\n",
"\n",
"model = DNABert(n_classes=num_classes)\n",
"\n",
"dna_sequence = \"AGCTAGCTAGCT\"\n",
"\n",
"# We need to convert the DNA sequence to the format expected by BERT\n",
"tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')\n",
"inputs = tokenizer(dna_sequence, return_tensors='pt')\n",
"\n",
"# Forward pass through the model\n",
"outputs = model(**inputs)"
]
},
{
"cell_type": "code",
"execution_count": 43,
"id": "dc73cdee-d1f0-409d-bd01-57bfdc80cf46",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'input_ids': tensor([[ 101, 12943, 25572, 18195, 15900, 6593, 102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1]])}"
]
},
"execution_count": 43,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"inputs"
]
},
{
"cell_type": "code",
"execution_count": 44,
"id": "5a6e6872-28a6-4ac3-b9d9-504792251f34",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[-0.7508, -0.6081, -0.0026, -0.0115, 0.1004, 0.1924, -0.4315, -0.0052,\n",
" 0.0900, 0.8016]], grad_fn=<AddmmBackward0>)"
]
},
"execution_count": 44,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"outputs"
]
},
{
"cell_type": "code",
"execution_count": 45,
"id": "439bc25b-328f-4cd1-a5aa-bc5f9909da7d",
"metadata": {},
"outputs": [],
"source": [
"prob, u, aleatoric, epistemic = model.predict_uncertainty(**inputs)"
]
},
{
"cell_type": "code",
"execution_count": 46,
"id": "5d8e5517-a7bb-4c9c-824b-a69e2ecdbca8",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[0.0894, 0.0894, 0.0894, 0.0894, 0.0984, 0.1066, 0.0894, 0.0894, 0.0975,\n",
" 0.1611]], grad_fn=<DivBackward0>)"
]
},
"execution_count": 46,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"prob"
]
},
{
"cell_type": "code",
"execution_count": 47,
"id": "a886205b-5f16-4255-a6d0-f947e9c17395",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[0.8941]], grad_fn=<MulBackward0>)"
]
},
"execution_count": 47,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"u"
]
},
{
"cell_type": "code",
"execution_count": 48,
"id": "7bf79791-c511-4c6d-960c-b100794181f4",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[0.0747, 0.0747, 0.0747, 0.0747, 0.0814, 0.0874, 0.0747, 0.0747, 0.0807,\n",
" 0.1240]], grad_fn=<SubBackward0>)"
]
},
"execution_count": 48,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"aleatoric"
]
},
{
"cell_type": "code",
"execution_count": 50,
"id": "27842698-60a9-419e-82d2-36b5ee341983",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[0.0067, 0.0067, 0.0067, 0.0067, 0.0073, 0.0078, 0.0067, 0.0067, 0.0072,\n",
" 0.0111]], grad_fn=<DivBackward0>)"
]
},
"execution_count": 50,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"epistemic"
]
},
{
"cell_type": "markdown",
"id": "e8ba8eec-26df-4c24-84a0-141d1caaa28b",
"metadata": {},
"source": [
"### Evidential loss"
]
},
{
"cell_type": "code",
"execution_count": 52,
"id": "cffcfe27-be85-4948-8ba0-f10502e2f108",
"metadata": {},
"outputs": [],
"source": [
"from mlguess.torch.class_losses import edl_digamma_loss, edl_log_loss, edl_mse_loss"
]
},
{
"cell_type": "code",
"execution_count": 59,
"id": "ed550265-225f-4720-af09-1be5a86a8aed",
"metadata": {},
"outputs": [],
"source": [
"loss = \"digamma\"\n",
"annealing_coefficient = 10.\n",
"epoch = 0\n",
"device = \"cpu\""
]
},
{
"cell_type": "code",
"execution_count": 54,
"id": "fee00e6d-be82-40d5-a555-ed2417bca50f",
"metadata": {},
"outputs": [],
"source": [
"if loss == \"digamma\":\n",
" criterion = edl_digamma_loss\n",
"elif loss == \"log\":\n",
" criterion = edl_log_loss\n",
"elif loss == \"mse\":\n",
" criterion = edl_mse_loss\n",
"else:\n",
" logging.error(\"--uncertainty requires --mse, --log or --digamma.\")"
]
},
{
"cell_type": "code",
"execution_count": 60,
"id": "69ccc5a0-a560-40d5-b9c2-6c303c8ba238",
"metadata": {},
"outputs": [],
"source": [
"y_true_hot = torch.tensor([1, 0, 0, 0, 0, 0, 0, 0, 0, 0])\n",
"\n",
"loss = criterion(\n",
" outputs,\n",
" y_true_hot.float(), \n",
" epoch, \n",
" num_classes, \n",
" annealing_coefficient, \n",
" device\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 61,
"id": "2b902b09-1047-45d7-a233-2b48ed026481",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(2.8403, grad_fn=<MeanBackward0>)"
]
},
"execution_count": 61,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"loss"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9bb75bb5-a005-4c44-8c4e-65026a13cbfc",
"metadata": {},
"outputs": [],
"source": [
"# loss.backward"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python [conda env:miniconda3-evidential]",
"language": "python",
"name": "conda-env-miniconda3-evidential-py"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.16"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

0 comments on commit 6d80373

Please sign in to comment.