- numpy
- pynvml
- pytorch_lightning
- torch
- torchinfo
- rich
假设已经完成了PyTorch
和Pytorch_lightning
安装!
pip install pynvml torchinfo
将Wandb
传入的str
根据指定分隔符转译为指定数据类型的list
。
str_params
:str
,Wandb
输入的字符串。params_type
:str
,转译后的数据类型。split_marker
:str
,default:#
,分隔符。
list<eval('params_type')>
,使用指定数据类型的数组。
为PyTorch
和Pytorch_lightning
设置固定随机种子。
random_seed
:float / int
,指定的随机种子。
null
使用torchinfo
的summary
函数预估在Float32
格式下训练模型所需的显存占用(GB)。
model
:torch.nn.Module
,指定的模型。input_shape
:list[...]
,与summary
输入格式一致的模型输入数据形状。
float
,显存占用(GB)。
仅适用于单卡训练,根据模型的显存占用情况自动选择剩余显存最大的显卡,如果显存均不足,则根据设定的等待时间(秒)自动等待直到超时退出或有满足要求的显卡激活训练过程。
card_list
:list[int, ...]
,指定使用的显卡编号列表。model_memory_usage
:float
,Float32
下模型的显存占用(GB)。idle
:bool
,default:false
,是否启动等待模式;若不启用,则未找到满足显存要求的显卡时直接报错退出,若启用,则按照idle_max_seconds
的设置等待空余显卡直到超时。idle_max_seconds
:int
,超时时长,仅在启用idle
时生效。
int
,可供使用的显卡编号。
Made By Egg Targaryen
MIT License