Skip to content

Commit

Permalink
Revert "Pin PT version: Fix FPX Inductor error" (#843)
Browse files Browse the repository at this point in the history
* Revert "Pin PT version: Fix FPX Inductor error (#790)"

This reverts commit 287458c.

* udpates

* yolo

* yolo

* yolo

* yolo
  • Loading branch information
msaroufim authored and jainapurva committed Sep 10, 2024
1 parent 0802f16 commit f65ad6d
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 6 deletions.
8 changes: 4 additions & 4 deletions .github/workflows/regression_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,9 @@ jobs:
torch-spec: 'torch==2.4.0'
gpu-arch-type: "cuda"
gpu-arch-version: "12.1"
- name: CUDA Nightly (Aug 29, 2024)
- name: CUDA Nightly
runs-on: linux.g5.12xlarge.nvidia.gpu
torch-spec: '--pre torch==2.5.0.dev20240829+cu121 --index-url https://download.pytorch.org/whl/nightly/cu121'
torch-spec: '--pre torch --index-url https://download.pytorch.org/whl/nightly/cu121'
gpu-arch-type: "cuda"
gpu-arch-version: "12.1"

Expand All @@ -57,9 +57,9 @@ jobs:
torch-spec: 'torch==2.4.0 --index-url https://download.pytorch.org/whl/cpu'
gpu-arch-type: "cpu"
gpu-arch-version: ""
- name: CPU Nightly (Aug 29, 2024)
- name: CPU Nightly
runs-on: linux.4xlarge
torch-spec: '--pre torch==2.5.0.dev20240829+cpu --index-url https://download.pytorch.org/whl/nightly/cpu'
torch-spec: '--pre torch --index-url https://download.pytorch.org/whl/nightly/cpu'
gpu-arch-type: "cpu"
gpu-arch-version: ""

Expand Down
3 changes: 2 additions & 1 deletion test/dtypes/test_bitnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from torchao.prototype.dtypes import BitnetTensor
from torchao.prototype.dtypes.uint2 import unpack_uint2
from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter
from torchao.utils import TORCH_VERSION_AT_LEAST_2_4
from torchao.utils import TORCH_VERSION_AT_LEAST_2_4, TORCH_VERSION_AT_LEAST_2_5

if not TORCH_VERSION_AT_LEAST_2_4:
pytest.skip("Unsupported PyTorch version", allow_module_level=True)
Expand Down Expand Up @@ -58,6 +58,7 @@ def fn(mod):
lambda mod, fqn: isinstance(mod, torch.nn.Linear),
)

@pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="Regression introdued in nightlies")
@pytest.mark.parametrize("input_shape", [[2, 4], [5, 5, 5, 4], [1, 4, 4]])
def test_uint2_quant(input_shape):
device = 'cuda' if torch.cuda.is_available() else 'cpu'
Expand Down
11 changes: 10 additions & 1 deletion test/integration/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -817,6 +817,9 @@ def test_int8_dynamic_quant_subclass_api(self, device, dtype):
@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(is_fbcode(), "broken in fbcode")
def test_int8_weight_only_quant_subclass_api(self, device, dtype):
if TORCH_VERSION_AT_LEAST_2_5 and device == "cpu":
self.skipTest("Regression introduced in PT nightlies")

undo_recommended_configs()
self._test_lin_weight_subclass_api_impl(
_int8wo_api, device, 40, test_dtype=dtype
Expand All @@ -826,6 +829,9 @@ def test_int8_weight_only_quant_subclass_api(self, device, dtype):
@torch._inductor.config.patch({"freezing": True})
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "freeze requires torch 2.4 and after.")
def test_int8_weight_only_quant_with_freeze(self, device, dtype):
if TORCH_VERSION_AT_LEAST_2_5 and device == "cpu":
self.skipTest("Regression introduced in PT nightlies")

self._test_lin_weight_subclass_api_impl(
_int8wo_api, device, 40, test_dtype=dtype
)
Expand Down Expand Up @@ -1039,7 +1045,10 @@ def test_save_load_dqtensors(self, device, dtype):
@parameterized.expand(COMMON_DEVICE_DTYPE)
@torch.no_grad()
@unittest.skipIf(is_fbcode(), "broken in fbcode")
def test_save_load_int8woqtensors(self, device, dtype):
def test_save_load_int8woqtensors(self, device, dtype):
if TORCH_VERSION_AT_LEAST_2_5 and device == "cpu":
self.skipTest(f"Regression introduced in PT nightlies")

undo_recommended_configs()
self._test_handle_save_load_meta_impl(_int8wo_api, device, test_dtype=dtype)

Expand Down

0 comments on commit f65ad6d

Please sign in to comment.