Skip to content

Commit

Permalink
default to guessing layer matcher if not provided
Browse files Browse the repository at this point in the history
  • Loading branch information
chanind committed Nov 19, 2023
1 parent c1c01d6 commit 9c52625
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 21 deletions.
9 changes: 5 additions & 4 deletions linear_relational/CausalEditor.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Callable, Sequence, TypeVar, Union, cast
from typing import Callable, Optional, Sequence, TypeVar, Union, cast

import torch
from tokenizers import Tokenizer
Expand All @@ -12,6 +12,7 @@
LayerMatcher,
collect_matching_layers,
get_layer_name,
guess_hidden_layer_matcher,
)
from linear_relational.lib.token_utils import (
ensure_tokenizer_has_pad_token,
Expand Down Expand Up @@ -55,18 +56,18 @@ def __init__(
model: nn.Module,
tokenizer: Tokenizer,
concepts: list[Concept],
layer_matcher: LayerMatcher,
layer_matcher: Optional[LayerMatcher] = None,
) -> None:
self.concepts = concepts
self.model = model
self.tokenizer = tokenizer
self.layer_matcher = layer_matcher
self.layer_matcher = layer_matcher or guess_hidden_layer_matcher(model)
ensure_tokenizer_has_pad_token(tokenizer)
num_layers = len(collect_matching_layers(self.model, self.layer_matcher))
self.layer_name_to_num = {}
for layer_num in range(num_layers):
self.layer_name_to_num[
get_layer_name(model, layer_matcher, layer_num)
get_layer_name(model, self.layer_matcher, layer_num)
] = layer_num

@property
Expand Down
16 changes: 7 additions & 9 deletions linear_relational/ConceptMatcher.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,23 @@
from dataclasses import dataclass
from typing import Callable, Sequence, Union
from typing import Callable, Optional, Sequence, Union

import torch
from tokenizers import Tokenizer
from torch import nn

from linear_relational.Concept import Concept
from linear_relational.lib.constants import DEFAULT_DEVICE
from linear_relational.lib.extract_token_activations import extract_token_activations
from linear_relational.lib.layer_matching import (
LayerMatcher,
collect_matching_layers,
get_layer_name,
guess_hidden_layer_matcher,
)
from linear_relational.lib.token_utils import (
ensure_tokenizer_has_pad_token,
find_final_word_token_index,
)
from linear_relational.lib.torch_utils import get_device
from linear_relational.lib.util import batchify

QuerySubject = Union[str, int, Callable[[str, list[int]], int]]
Expand All @@ -40,28 +41,25 @@ class ConceptMatcher:
tokenizer: Tokenizer
layer_matcher: LayerMatcher
layer_name_to_num: dict[str, int]
device: torch.device

def __init__(
self,
model: nn.Module,
tokenizer: Tokenizer,
concepts: list[Concept],
layer_matcher: LayerMatcher,
device: torch.device = DEFAULT_DEVICE,
layer_matcher: Optional[LayerMatcher] = None,
) -> None:
self.concepts = concepts
self.model = model
self.tokenizer = tokenizer
self.layer_matcher = layer_matcher
self.layer_matcher = layer_matcher or guess_hidden_layer_matcher(model)
ensure_tokenizer_has_pad_token(tokenizer)
num_layers = len(collect_matching_layers(self.model, self.layer_matcher))
self.layer_name_to_num = {}
for layer_num in range(num_layers):
self.layer_name_to_num[
get_layer_name(model, layer_matcher, layer_num)
get_layer_name(model, self.layer_matcher, layer_num)
] = layer_num
self.device = device

def query(self, query: str, subject: QuerySubject) -> dict[str, ConceptMatchResult]:
return self.query_bulk([ConceptMatchQuery(query, subject)])[0]
Expand All @@ -88,7 +86,7 @@ def _query_batch(
layers=self.layer_name_to_num.keys(),
texts=[q.text for q in queries],
token_indices=subj_tokens,
device=self.device,
device=get_device(self.model),
# batching is handled already, so no need to batch here too
batch_size=len(queries),
show_progress=False,
Expand Down
27 changes: 19 additions & 8 deletions linear_relational/training/Trainer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from collections import defaultdict
from dataclasses import dataclass
from time import time
from typing import Literal, Optional

Expand All @@ -10,7 +9,11 @@
from linear_relational.Concept import Concept
from linear_relational.lib.balance_grouped_items import balance_grouped_items
from linear_relational.lib.extract_token_activations import extract_token_activations
from linear_relational.lib.layer_matching import LayerMatcher, get_layer_name
from linear_relational.lib.layer_matching import (
LayerMatcher,
get_layer_name,
guess_hidden_layer_matcher,
)
from linear_relational.lib.logger import log_or_print, logger
from linear_relational.lib.token_utils import PromptAnswerData, find_prompt_answer_data
from linear_relational.lib.torch_utils import get_device
Expand All @@ -23,12 +26,23 @@
VectorAggregation = Literal["pre_mean", "post_mean"]


@dataclass
class Trainer:
model: nn.Module
tokenizer: Tokenizer
layer_matcher: LayerMatcher
prompt_validator: Optional[PromptValidator] = None
prompt_validator: PromptValidator

def __init__(
self,
model: nn.Module,
tokenizer: Tokenizer,
layer_matcher: Optional[LayerMatcher] = None,
prompt_validator: Optional[PromptValidator] = None,
) -> None:
self.model = model
self.tokenizer = tokenizer
self.layer_matcher = layer_matcher or guess_hidden_layer_matcher(model)
self.prompt_validator = prompt_validator or PromptValidator(model, tokenizer)

def train_lre(
self,
Expand Down Expand Up @@ -171,10 +185,7 @@ def _process_relation_prompts(
valid_prompts = prompts
if validate_prompts:
log_or_print(f"validating {len(prompts)} prompts", verbose=verbose)
prompt_validator = self.prompt_validator or PromptValidator(
self.model, self.tokenizer
)
valid_prompts = prompt_validator.filter_prompts(
valid_prompts = self.prompt_validator.filter_prompts(
prompts, batch_size, verbose
)
if len(valid_prompts) == 0:
Expand Down

0 comments on commit 9c52625

Please sign in to comment.