Skip to content

Commit

Permalink
added
Browse files Browse the repository at this point in the history
  • Loading branch information
nipunbatra committed May 30, 2024
1 parent d37fecd commit f26c84c
Showing 1 changed file with 66 additions and 2 deletions.
68 changes: 66 additions & 2 deletions posts/2024-forecast.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -539,11 +539,75 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 114,
"id": "a06840b9",
"metadata": {},
"outputs": [],
"source": []
"source": [
"from transformers import InformerForPrediction, InformerConfig\n",
"\n",
"config = InformerConfig(\n",
" input_dim=context_length,\n",
" prediction_length=prediction_length,\n",
" num_heads=4,\n",
" encoder_layers=2,\n",
" decoder_layers=2,\n",
" use_mask=True,\n",
" forecast=True\n",
")\n",
"\n",
"informer = InformerForPrediction(config)"
]
},
{
"cell_type": "code",
"execution_count": 116,
"id": "5465e218",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"134531"
]
},
"execution_count": 116,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sum(p.numel() for p in informer.parameters())"
]
},
{
"cell_type": "code",
"execution_count": 117,
"id": "4f46f6a0",
"metadata": {},
"outputs": [
{
"ename": "TypeError",
"evalue": "forward() missing 2 required positional arguments: 'past_time_features' and 'past_observed_mask'",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[117], line 7\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m epoch \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(n_epochs):\n\u001b[1;32m 6\u001b[0m optimizer\u001b[38;5;241m.\u001b[39mzero_grad()\n\u001b[0;32m----> 7\u001b[0m y_pred \u001b[38;5;241m=\u001b[39m \u001b[43minformer\u001b[49m\u001b[43m(\u001b[49m\u001b[43mXs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 8\u001b[0m loss \u001b[38;5;241m=\u001b[39m criterion(y_pred, ys)\n\u001b[1;32m 9\u001b[0m loss\u001b[38;5;241m.\u001b[39mbackward()\n",
"File \u001b[0;32m~/miniconda3/lib/python3.9/site-packages/torch/nn/modules/module.py:1501\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1496\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1497\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1498\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1499\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1500\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1501\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1502\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1503\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
"\u001b[0;31mTypeError\u001b[0m: forward() missing 2 required positional arguments: 'past_time_features' and 'past_observed_mask'"
]
}
],
"source": [
"# training loop\n",
"n_epochs = 100\n",
"\n",
"start_time = time.time()\n",
"for epoch in range(n_epochs):\n",
" optimizer.zero_grad()\n",
" y_pred = informer(Xs, past\n"
]
}
],
"metadata": {
Expand Down

0 comments on commit f26c84c

Please sign in to comment.