
Adopt pytorch-directml as accelerator #16870

Open
@wuzhican

Description & Motivation

I have tried to implement a custom Accelerator as follows:

from typing import Any, Dict, List, Union

import torch
import torch_directml
import pytorch_lightning as pl
from pytorch_lightning.accelerators import Accelerator
from torch.utils.data import DataLoader, Dataset
from torchvision.models import densenet121

class DML_Accelerator(Accelerator):
    def setup_device(self, device: torch.device) -> None:
        pass

    def setup(self, trainer: "pl.Trainer") -> None:
        pass

    def get_device_stats(self, device: torch.device) -> Dict[str, Any]:
        return {}

    def teardown(self) -> None:
        pass

    @staticmethod
    def parse_devices(devices: Any) -> Any:
        return devices

    @staticmethod
    def get_parallel_devices(devices: Union[int, str, List[int]]) -> List[torch.device]:
        # torch_directml.device() returns the default DirectML device
        return [torch_directml.device()]

    @staticmethod
    def auto_device_count() -> int:
        return torch_directml.device_count()

    @staticmethod
    def is_available() -> bool:
        return torch_directml.is_available()

    @classmethod
    def register_accelerators(cls, accelerator_registry):
        accelerator_registry.register(
            "dml",
            cls,
            description="GPU Accelerator - optimized for large-scale machine learning.",
        )

class test_model(pl.LightningModule):
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        self._model = densenet121()

    def training_step(self, batch, batch_index, **kwargs: Any) -> Any:
        # A forward pass is enough to observe which device the model runs on.
        y = self._model(batch)
        return None

    def configure_optimizers(self) -> Any:
        return None

class testSet(Dataset):
    def __getitem__(self, index) -> Any:
        return torch.rand(3, 224, 224)

    def __len__(self):
        return 100

if __name__ == '__main__':
    trainer = pl.Trainer(accelerator=DML_Accelerator(), max_epochs=20)
    loader = DataLoader(testSet(), batch_size=12, num_workers=8, drop_last=True)
    trainer.fit(test_model(), loader)

But the model still runs on the 'cpu' device, not on DirectML, after I pass DML_Accelerator to the Trainer. The torch_directml module itself works correctly on its own.
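A plausible explanation (an assumption on my part, not confirmed in this issue): in Lightning, device placement is owned by the Strategy, not the Accelerator, so registering a custom accelerator alone can leave the default strategy's root_device at CPU. A minimal sketch that also passes a SingleDeviceStrategy pointed at the DirectML device:

import torch_directml
import pytorch_lightning as pl
from pytorch_lightning.strategies import SingleDeviceStrategy

# Sketch, not a confirmed fix: give the Strategy the DirectML root device
# so it moves the model and batches there instead of leaving them on CPU.
trainer = pl.Trainer(
    accelerator=DML_Accelerator(),
    strategy=SingleDeviceStrategy(device=torch_directml.device()),
    max_epochs=20,
)

Printing batch.device inside training_step then shows whether placement actually changed.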

Pitch

No response

Alternatives

No response

Additional context

No response

cc @Borda @carmocca @justusschock @awaelchli

Labels

accelerator · fabric (lightning.fabric.Fabric) · feature (Is an improvement or enhancement) · pl (Generic label for PyTorch Lightning package)
