From 876c8fc15b3e9d784da15f916642318c0f81c6db Mon Sep 17 00:00:00 2001
From: pytorchbot <pytorchbot@pytorch.com>
Date: Tue, 10 Dec 2024 11:35:06 +0000
Subject: [PATCH] 2024-12-10 nightly release
 (fad795e27dc2df84efe95e903c7a6585066889d3)
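
Switch the pyre and unit test CI workflows from installing the PyTorch
nightly with conda to installing it with pip from the nightly wheel index
(parameterized by the CUDA tag in the GPU matrix), verify the torch import
after installation, and add pyre-fixme[29] annotations in
torchrec/distributed/model_parallel.py.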

---
 .github/workflows/pyre.yml             |  2 +-
 .github/workflows/unittest_ci.yml      | 10 ++++++----
 .github/workflows/unittest_ci_cpu.yml  |  8 +++++---
 torchrec/distributed/model_parallel.py |  4 ++++
 4 files changed, 16 insertions(+), 8 deletions(-)

diff --git a/.github/workflows/pyre.yml b/.github/workflows/pyre.yml
index be99ff62c..ed25404e7 100644
--- a/.github/workflows/pyre.yml
+++ b/.github/workflows/pyre.yml
@@ -19,7 +19,7 @@ jobs:
         uses: actions/checkout@v2
       - name: Install dependencies
         run: >
-          conda install --yes pytorch cpuonly -c pytorch-nightly &&
+          pip install torch --index-url https://download.pytorch.org/whl/nightly/cpu &&
           pip install fbgemm-gpu --index-url https://download.pytorch.org/whl/nightly/cpu &&
           pip install -r requirements.txt &&
           pip install pyre-check-nightly==$(cat .pyre_configuration | grep version | awk '{print $2}' | sed 's/\"//g')
diff --git a/.github/workflows/unittest_ci.yml b/.github/workflows/unittest_ci.yml
index 8865acee4..0b3fe0638 100644
--- a/.github/workflows/unittest_ci.yml
+++ b/.github/workflows/unittest_ci.yml
@@ -73,13 +73,15 @@ jobs:
         conda info
         python --version
         conda run -n build_binary python --version
-        conda install -n build_binary \
-          --yes \
-          pytorch pytorch-cuda=11.8 -c pytorch-nightly -c nvidia
+        conda run -n build_binary \
+          pip install torch --index-url https://download.pytorch.org/whl/nightly/${{ matrix.cuda-tag }}
+        conda run -n build_binary \
+          python -c "import torch"
+        echo "torch succeeded"
         conda run -n build_binary \
           python -c "import torch.distributed"
         conda run -n build_binary \
-          pip install fbgemm-gpu --index-url https://download.pytorch.org/whl/nightly/cu118
+          pip install fbgemm-gpu --index-url https://download.pytorch.org/whl/nightly/${{ matrix.cuda-tag }}
         conda run -n build_binary \
           python -c "import fbgemm_gpu"
         echo "fbgemm_gpu succeeded"
diff --git a/.github/workflows/unittest_ci_cpu.yml b/.github/workflows/unittest_ci_cpu.yml
index 2861029e1..1efe64178 100644
--- a/.github/workflows/unittest_ci_cpu.yml
+++ b/.github/workflows/unittest_ci_cpu.yml
@@ -45,9 +45,11 @@ jobs:
         conda info
         python --version
         conda run -n build_binary python --version
-        conda install -n build_binary \
-          --yes \
-          pytorch cpuonly -c pytorch-nightly
+        conda run -n build_binary \
+          pip install torch --index-url https://download.pytorch.org/whl/nightly/cpu
+        conda run -n build_binary \
+          python -c "import torch"
+        echo "torch succeeded"
         conda run -n build_binary \
           python -c "import torch.distributed"
         conda run -n build_binary \
diff --git a/torchrec/distributed/model_parallel.py b/torchrec/distributed/model_parallel.py
index 0f60362b6..5cbd2429b 100644
--- a/torchrec/distributed/model_parallel.py
+++ b/torchrec/distributed/model_parallel.py
@@ -746,6 +746,7 @@ def sync(self, include_optimizer_state: bool = True) -> None:
         all_weights = [
             w
             for emb_kernel in self._modules_to_sync
+            # pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
             for w in emb_kernel.split_embedding_weights()
         ]
         handle = self._replica_pg.allreduce_coalesced(all_weights, opts=opts)
@@ -755,6 +756,7 @@ def sync(self, include_optimizer_state: bool = True) -> None:
             # Sync accumulated square of grad of local optimizer shards
             optim_list = []
             for emb_kernel in self._modules_to_sync:
+                # pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
                 all_optimizer_states = emb_kernel.get_optimizer_state()
                 momentum1 = [optim["sum"] for optim in all_optimizer_states]
                 optim_list.extend(momentum1)
@@ -864,6 +866,8 @@ def _find_sharded_modules(
             if isinstance(module, SplitTableBatchedEmbeddingBagsCodegen):
                 sharded_modules.append(module)
             if hasattr(module, "_lookups"):
+                # pyre-fixme[29]: `Union[(self: Tensor) -> Any, Module, Tensor]` is
+                #  not a function.
                 for lookup in module._lookups:
                     _find_sharded_modules(lookup)
                 return