|
539 | 539 | },
|
540 | 540 | {
|
541 | 541 | "cell_type": "code",
|
542 | | - "execution_count": null, |
| 542 | + "execution_count": 114, |
543 | 543 | "id": "a06840b9",
|
544 | 544 | "metadata": {},
|
545 | 545 | "outputs": [],
|
546 | | - "source": [] |
| 546 | + "source": [ |
| 547 | + "from transformers import InformerForPrediction, InformerConfig\n", |
| 548 | + "\n", |
| 549 | + "config = InformerConfig(\n", |
| 550 | + " input_dim=context_length,\n", |
| 551 | + " prediction_length=prediction_length,\n", |
| 552 | + " num_heads=4,\n", |
| 553 | + " encoder_layers=2,\n", |
| 554 | + " decoder_layers=2,\n", |
| 555 | + " use_mask=True,\n", |
| 556 | + " forecast=True\n", |
| 557 | + ")\n", |
| 558 | + "\n", |
| 559 | + "informer = InformerForPrediction(config)" |
| 560 | + ] |
| 561 | + }, |
| 562 | + { |
| 563 | + "cell_type": "code", |
| 564 | + "execution_count": 116, |
| 565 | + "id": "5465e218", |
| 566 | + "metadata": {}, |
| 567 | + "outputs": [ |
| 568 | + { |
| 569 | + "data": { |
| 570 | + "text/plain": [ |
| 571 | + "134531" |
| 572 | + ] |
| 573 | + }, |
| 574 | + "execution_count": 116, |
| 575 | + "metadata": {}, |
| 576 | + "output_type": "execute_result" |
| 577 | + } |
| 578 | + ], |
| 579 | + "source": [ |
| 580 | + "sum(p.numel() for p in informer.parameters())" |
| 581 | + ] |
| 582 | + }, |
| 583 | + { |
| 584 | + "cell_type": "code", |
| 585 | + "execution_count": 117, |
| 586 | + "id": "4f46f6a0", |
| 587 | + "metadata": {}, |
| 588 | + "outputs": [ |
| 589 | + { |
| 590 | + "ename": "TypeError", |
| 591 | + "evalue": "forward() missing 2 required positional arguments: 'past_time_features' and 'past_observed_mask'", |
| 592 | + "output_type": "error", |
| 593 | + "traceback": [ |
| 594 | + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", |
| 595 | + "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", |
| 596 | + "Cell \u001b[0;32mIn[117], line 7\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m epoch \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(n_epochs):\n\u001b[1;32m 6\u001b[0m optimizer\u001b[38;5;241m.\u001b[39mzero_grad()\n\u001b[0;32m----> 7\u001b[0m y_pred \u001b[38;5;241m=\u001b[39m \u001b[43minformer\u001b[49m\u001b[43m(\u001b[49m\u001b[43mXs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 8\u001b[0m loss \u001b[38;5;241m=\u001b[39m criterion(y_pred, ys)\n\u001b[1;32m 9\u001b[0m loss\u001b[38;5;241m.\u001b[39mbackward()\n", |
| 597 | + "File \u001b[0;32m~/miniconda3/lib/python3.9/site-packages/torch/nn/modules/module.py:1501\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1496\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1497\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1498\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1499\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1500\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1501\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1502\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1503\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n", |
| 598 | + "\u001b[0;31mTypeError\u001b[0m: forward() missing 2 required positional arguments: 'past_time_features' and 'past_observed_mask'" |
| 599 | + ] |
| 600 | + } |
| 601 | + ], |
| 602 | + "source": [ |
| 603 | + "# training loop\n", |
| 604 | + "n_epochs = 100\n", |
| 605 | + "\n", |
| 606 | + "start_time = time.time()\n", |
| 607 | + "for epoch in range(n_epochs):\n", |
| 608 | + " optimizer.zero_grad()\n", |
| 609 | + " y_pred = informer(Xs, past\n" |
| 610 | + ] |
547 | 611 | }
|
548 | 612 | ],
|
549 | 613 | "metadata": {
|
|