This repository has been archived by the owner on May 14, 2024. It is now read-only.
When resuming training from a checkpoint, the script prints:
load network weights from /content/drive/MyDrive/XXXXX.safetensors: None
I took a quick look at the code and found that the 'load_weights' function always passes False for what appears to be the 'dtype' parameter, whereas the correct dtype is always supplied when the checkpoint is saved. There is a simple fix:
In the file 'kohya-trainer/train_network.py', starting at line 206, replace
if args.network_weights is not None:
    info = network.load_weights(args.network_weights)
    print(f"load network weights from {args.network_weights}: {info}")
with:
def load_weights_2(network, file, dtype):
    # Pick the loader from the file extension.
    if os.path.splitext(file)[1] == ".safetensors":
        from safetensors.torch import load_file
        weights_sd = load_file(file)
    else:
        weights_sd = torch.load(file, map_location="cpu")
    # Note: if network is a torch.nn.Module, the second positional argument
    # of load_state_dict is 'strict', so dtype is interpreted as that flag here.
    info = network.load_state_dict(weights_sd, dtype)
    return info

if args.network_weights is not None:
    info = load_weights_2(network, args.network_weights, save_dtype)
    print(f"load network weights from {args.network_weights}: {info}")
This should work for LoRA, LoCon and LoHa; I am not sure whether it works for XL. The edit can be made easily in Google Colab's editing mode.
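To check whether the resumed weights were actually applied, the logged value could be summarized instead of printed raw. The following is only a sketch under stated assumptions: 'describe_load_result' is a hypothetical helper that is not part of kohya-trainer, and it assumes the loader either returns None or an object with 'missing_keys' and 'unexpected_keys' fields, as torch.nn.Module.load_state_dict does. A named tuple stands in for the real result so the example runs without PyTorch.

```python
from collections import namedtuple


def describe_load_result(info):
    """Summarize the value returned by a weight-loading call (hypothetical helper)."""
    if info is None:
        # The loader returned nothing, so the log cannot confirm which keys matched.
        return "no load report returned"
    # Fall back to empty lists if the object lacks the expected fields.
    missing = list(getattr(info, "missing_keys", []))
    unexpected = list(getattr(info, "unexpected_keys", []))
    if not missing and not unexpected:
        return "all keys matched"
    return f"missing keys: {len(missing)}, unexpected keys: {len(unexpected)}"


# Stand-in for the named tuple returned by torch.nn.Module.load_state_dict.
FakeResult = namedtuple("FakeResult", ["missing_keys", "unexpected_keys"])

print(describe_load_result(None))
print(describe_load_result(FakeResult([], [])))
print(describe_load_result(FakeResult(["example.key"], [])))
```

In the snippet above, the final print call could then use describe_load_result(info) in place of {info}, so a None result reads as "no load report returned" rather than as a bare None.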
I hope this can be fixed officially soon.