Skip to content

Commit

Permalink
Merge pull request #516 from taddyb/master
Browse files Browse the repository at this point in the history
Added support for saving checkpoints with CUDA device numbers
KindXiaoming authored Jan 19, 2025

Verified

This commit was signed with the committer’s verified signature.
2 parents 0a452a0 + 1625986 commit ecde4ec
Showing 1 changed file with 3 additions and 0 deletions.
3 changes: 3 additions & 0 deletions kan/MultKAN.py
Original file line number Diff line number Diff line change
@@ -534,6 +534,9 @@ def saveckpt(self, path='model'):
round = model.round,
device = str(model.device)
)

if dic["device"].isdigit():
dic["device"] = int(model.device)

for i in range (model.depth):
dic[f'symbolic.funs_name.{i}'] = model.symbolic_fun[i].funs_name

0 comments on commit ecde4ec

Please sign in to comment.