Skip to content

Commit

Permalink
remove unused load function
Browse files Browse the repository at this point in the history
  • Loading branch information
johnathanchiu committed May 3, 2024
1 parent bf81f3b commit caf0217
Showing 1 changed file with 0 additions and 16 deletions.
16 changes: 0 additions & 16 deletions examples/bayes_llama3/llama3/modules/bayesllama_copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,22 +26,6 @@ def __init__(self, config):
[copy.deepcopy(self.layers[-1]) for _ in range(2)]
)

def load_ensemble_weights(
self, layer_idx: int, param_names: list, ensemble_params: list[torch.Tensor]
):
module = self.layers[layer_idx]

for param_name, params in zip(param_names, ensemble_params):
attributes = re.split(r"\d+", param_name)[-1].split(".")[1:]

sub_module = module
attr = None
for attr in attributes:
sub_module = getattr(sub_module, attr)
setattr(sub_module, attr, torch.nn.Parameter(params.to(self.device)))

return module

def forward(
self,
input_ids: torch.LongTensor = None,
Expand Down

0 comments on commit caf0217

Please sign in to comment.