Skip to content

Commit

Permalink
Merge pull request #136 from stanfordnlp/zen/add_labels
Browse files Browse the repository at this point in the history
[Minor] Accepting `labels` field for loss calculation
  • Loading branch information
frankaging authored Mar 26, 2024
2 parents db7c676 + 9c5a2ff commit 37782e6
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 2 deletions.
6 changes: 5 additions & 1 deletion pyvene/models/intervenable_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1317,6 +1317,7 @@ def forward(
unit_locations: Optional[Dict] = None,
source_representations: Optional[Dict] = None,
subspaces: Optional[List] = None,
labels: Optional[torch.LongTensor] = None,
output_original_output: Optional[bool] = False,
return_dict: Optional[bool] = None,
):
Expand Down Expand Up @@ -1438,7 +1439,10 @@ def forward(
)

# run intervened forward
counterfactual_outputs = self.model(**base)
if labels is not None:
counterfactual_outputs = self.model(**base, labels=labels)
else:
counterfactual_outputs = self.model(**base)
set_handlers_to_remove.remove()

self._output_validation()
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

setup(
name="pyvene",
version="0.0.8dev",
version="0.0.8",
description="Use Activation Intervention to Interpret Causal Mechanism of Model",
long_description=long_description,
long_description_content_type='text/markdown',
Expand Down

0 comments on commit 37782e6

Please sign in to comment.