Skip to content

Commit 08508b6

Browse files
authored
Merge branch 'master' into fix/no-grad-amp-bug
2 parents d22dbd1 + cb1afbe commit 08508b6

File tree

5 files changed

+95
-3
lines changed

5 files changed

+95
-3
lines changed

requirements/pytorch/extra.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
matplotlib>3.1, <3.10.0
66
omegaconf >=2.2.3, <2.4.0
77
hydra-core >=1.2.0, <1.4.0
8-
jsonargparse[signatures] >=4.39.0, <4.41.0
8+
jsonargparse[signatures,jsonnet] >=4.39.0, <4.41.0
99
rich >=12.3.0, <14.1.0
1010
tensorboardX >=2.2, <2.7.0 # min version is set by torch.onnx missing attribute
1111
bitsandbytes >=0.45.2,<0.47.0; platform_system != "Darwin"

src/lightning/pytorch/callbacks/pruning.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -458,7 +458,10 @@ def sanitize_parameters_to_prune(
458458

459459
if not parameters_to_prune:
460460
parameters_to_prune = [
461-
(m, p) for p in parameters for m in current_modules if getattr(m, p, None) is not None
461+
(m, p)
462+
for p in parameters
463+
for m in current_modules
464+
if getattr(m, p, None) is not None and isinstance(getattr(m, p, None), nn.Parameter)
462465
]
463466
elif (
464467
isinstance(parameters_to_prune, (list, tuple))

src/lightning/pytorch/cli.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -554,7 +554,9 @@ def parse_arguments(self, parser: LightningArgumentParser, args: ArgsType) -> No
554554
def _dump_config(self) -> None:
555555
if hasattr(self, "config_dump"):
556556
return
557-
self.config_dump = yaml.safe_load(self.parser.dump(self.config, skip_link_targets=False, skip_none=False))
557+
self.config_dump = yaml.safe_load(
558+
self.parser.dump(self.config, skip_link_targets=False, skip_none=False, format="yaml")
559+
)
558560
if "subcommand" in self.config:
559561
self.config_dump = self.config_dump[self.config.subcommand]
560562

tests/tests_pytorch/callbacks/test_pruning.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -338,3 +338,70 @@ def on_save_checkpoint(self, trainer, pl_module, checkpoint):
338338
assert not hasattr(model.layer.mlp_3, "weight_orig")
339339
model = TestModel.load_from_checkpoint(trainer.checkpoint_callback.last_model_path)
340340
assert not hasattr(model.layer.mlp_3, "weight_orig")
341+
342+
343+
def test_sanitize_parameters_explicit_check():
344+
"""Test the sanitize_parameters_to_prune method with various attribute types."""
345+
346+
class TestModule(nn.Module):
347+
def __init__(self):
348+
super().__init__()
349+
self.weight = nn.Parameter(torch.randn(5, 5))
350+
self.bias = nn.Parameter(torch.randn(5))
351+
self.some_bool = True
352+
self.some_tensor = torch.randn(3, 3) # Regular tensor, not parameter
353+
self.some_string = "test"
354+
self.some_none = None
355+
356+
class TestModel(BoringModel):
357+
def __init__(self):
358+
super().__init__()
359+
self.test_module = TestModule()
360+
361+
model = TestModel()
362+
363+
parameters_to_prune = ModelPruning.sanitize_parameters_to_prune(
364+
model,
365+
parameters_to_prune=(),
366+
parameter_names=["weight", "bias", "some_bool", "some_tensor", "some_string", "some_none"],
367+
)
368+
369+
param_names_found = set()
370+
for module, param_name in parameters_to_prune:
371+
param = getattr(module, param_name)
372+
assert isinstance(param, nn.Parameter), f"Expected Parameter, got {type(param)}"
373+
param_names_found.add(param_name)
374+
375+
assert "weight" in param_names_found
376+
assert "bias" in param_names_found
377+
assert "some_bool" not in param_names_found
378+
assert "some_tensor" not in param_names_found
379+
assert "some_string" not in param_names_found
380+
assert "some_none" not in param_names_found
381+
382+
383+
def test_original_issue_reproduction():
384+
"""Issue: https://github.com/Lightning-AI/pytorch-lightning/issues/10835."""
385+
386+
class ProblematicModel(BoringModel):
387+
def __init__(self):
388+
super().__init__()
389+
self.layer = Sequential(
390+
OrderedDict([
391+
("mlp_1", nn.Linear(32, 32)),
392+
("mlp_2", nn.Linear(32, 2)),
393+
])
394+
)
395+
# Add boolean attributes that would cause the original error
396+
self.layer.mlp_1.training = True
397+
self.layer.mlp_2.requires_grad = True
398+
399+
model = ProblematicModel()
400+
401+
parameters_to_prune = ModelPruning.sanitize_parameters_to_prune(
402+
model, parameters_to_prune=(), parameter_names=["weight", "bias", "training", "requires_grad"]
403+
)
404+
405+
for module, param_name in parameters_to_prune:
406+
param = getattr(module, param_name)
407+
assert isinstance(param, nn.Parameter), f"Non-parameter found: {type(param)}"

tests/tests_pytorch/test_cli.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1858,3 +1858,23 @@ def test_lightning_cli_with_args_given(args):
18581858
def test_lightning_cli_args_and_sys_argv_warning():
18591859
with mock.patch("sys.argv", ["", "--model.foo=456"]), pytest.warns(Warning, match="LightningCLI's args parameter "):
18601860
LightningCLI(TestModel, run=False, args=["--model.foo=789"])
1861+
1862+
1863+
def test_lightning_cli_jsonnet(cleandir):
1864+
class MainModule(BoringModel):
1865+
def __init__(self, main_param: int = 1):
1866+
super().__init__()
1867+
1868+
config = """{
1869+
"model":{
1870+
"main_param": 2
1871+
}
1872+
}"""
1873+
config_path = Path("config.jsonnet")
1874+
config_path.write_text(config)
1875+
1876+
cli_args = [f"--config={config_path}"]
1877+
with mock.patch("sys.argv", ["any.py"] + cli_args):
1878+
cli = LightningCLI(MainModule, run=False, parser_kwargs={"parser_mode": "jsonnet"})
1879+
1880+
assert cli.config["model"]["main_param"] == 2

0 commit comments

Comments
 (0)