Skip to content

Commit f26c84c

Browse files
committed
added
1 parent d37fecd commit f26c84c

File tree

1 file changed

+66
-2
lines changed

1 file changed

+66
-2
lines changed

posts/2024-forecast.ipynb

Lines changed: 66 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -539,11 +539,75 @@
539539
},
540540
{
541541
"cell_type": "code",
542-
"execution_count": null,
542+
"execution_count": 114,
543543
"id": "a06840b9",
544544
"metadata": {},
545545
"outputs": [],
546-
"source": []
546+
"source": [
547+
"from transformers import InformerForPrediction, InformerConfig\n",
548+
"\n",
549+
"config = InformerConfig(\n",
550+
" input_dim=context_length,\n",
551+
" prediction_length=prediction_length,\n",
552+
" num_heads=4,\n",
553+
" encoder_layers=2,\n",
554+
" decoder_layers=2,\n",
555+
" use_mask=True,\n",
556+
" forecast=True\n",
557+
")\n",
558+
"\n",
559+
"informer = InformerForPrediction(config)"
560+
]
561+
},
562+
{
563+
"cell_type": "code",
564+
"execution_count": 116,
565+
"id": "5465e218",
566+
"metadata": {},
567+
"outputs": [
568+
{
569+
"data": {
570+
"text/plain": [
571+
"134531"
572+
]
573+
},
574+
"execution_count": 116,
575+
"metadata": {},
576+
"output_type": "execute_result"
577+
}
578+
],
579+
"source": [
580+
"sum(p.numel() for p in informer.parameters())"
581+
]
582+
},
583+
{
584+
"cell_type": "code",
585+
"execution_count": 117,
586+
"id": "4f46f6a0",
587+
"metadata": {},
588+
"outputs": [
589+
{
590+
"ename": "TypeError",
591+
"evalue": "forward() missing 2 required positional arguments: 'past_time_features' and 'past_observed_mask'",
592+
"output_type": "error",
593+
"traceback": [
594+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
595+
"\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
596+
"Cell \u001b[0;32mIn[117], line 7\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m epoch \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(n_epochs):\n\u001b[1;32m 6\u001b[0m optimizer\u001b[38;5;241m.\u001b[39mzero_grad()\n\u001b[0;32m----> 7\u001b[0m y_pred \u001b[38;5;241m=\u001b[39m \u001b[43minformer\u001b[49m\u001b[43m(\u001b[49m\u001b[43mXs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 8\u001b[0m loss \u001b[38;5;241m=\u001b[39m criterion(y_pred, ys)\n\u001b[1;32m 9\u001b[0m loss\u001b[38;5;241m.\u001b[39mbackward()\n",
597+
"File \u001b[0;32m~/miniconda3/lib/python3.9/site-packages/torch/nn/modules/module.py:1501\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1496\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1497\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1498\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1499\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1500\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1501\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1502\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1503\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
598+
"\u001b[0;31mTypeError\u001b[0m: forward() missing 2 required positional arguments: 'past_time_features' and 'past_observed_mask'"
599+
]
600+
}
601+
],
602+
"source": [
603+
"# training loop\n",
604+
"n_epochs = 100\n",
605+
"\n",
606+
"start_time = time.time()\n",
607+
"for epoch in range(n_epochs):\n",
608+
" optimizer.zero_grad()\n",
609+
" y_pred = informer(Xs, past\n"
610+
]
547611
}
548612
],
549613
"metadata": {

0 commit comments

Comments
 (0)