Reproducibility of edge probing tasks #1305

Open
sagnik opened this issue Apr 1, 2021 · 3 comments

@sagnik

sagnik commented Apr 1, 2021

Describe the bug
I am having trouble reproducing the results for the edge probing tasks. Different papers seem to report different results. For simplicity, I will use coreference on OntoNotes data as the example, and focus on the numbers that use the top layer of BERT.

  1. Tenney 2019, ICLR: bert-base-uncased, 90.2. I don't quite understand the cat part.
  2. Tenney 2019, BERT Rediscovers: bert-base-large, 91.9.
  3. van Aken 2019: bert-base-uncased, approx. 95 (Figure 2).
  4. Merchant 2020: bert-base-uncased, 95.2 (Appendix, Table 2).
  5. When I run with the existing jiant code, without any modeling changes (I have only changed the code to read YAML files): bert-base-uncased, 92.6 (average over 5 runs). My run config is:
jiant_task_container_config_path: ${JIANT_HOME}/run_configs/coref_bert_run_config.yml
output_dir: ${JIANT_DATA}/runs/coref_bert_es
hf_pretrained_model_name_or_path: bert-base-uncased
model_path: ${JIANT_DATA}/models/bert-base-uncased/model/model.p
model_config_path: ${JIANT_DATA}/models/bert-base-uncased/model/config.json
learning_rate: 0.00001
eval_every_steps: 500
no_improvements_for_n_evals: 100
do_train: true
do_val: true
do_save: true
write_val_preds: true
write_test_preds: true
force_overwrite: true

Here, coref_bert_run_config.yml is given by:

task_config_path_dict:
  coref: ${JIANT_HOME}/tasks/configs/coref_config.yml
task_cache_config_dict:
  coref:
    train: ${JIANT_DATA}/cache/coref/train
    val: ${JIANT_DATA}/cache/coref/val
    val_labels: ${JIANT_DATA}/cache/coref/val_labels
    test: ${JIANT_DATA}/cache/coref/test
    test_labels: ${JIANT_DATA}/cache/coref/test_labels
sampler_config:
  sampler_type: ProportionalMultiTaskSampler
global_train_config:
  max_steps: 129900
  warmup_steps: 12990
task_specific_configs_dict:
  coref:
    train_batch_size: 16
    eval_batch_size: 8
    gradient_accumulation_steps: 1
    eval_subset_num: 500
taskmodels_config:
  task_to_taskmodel_map:
    coref: coref
  taskmodel_config_map:
    coref: null
task_run_config:
  train_task_list: &id001
  - coref
  train_val_task_list: *id001
  val_task_list:
  - coref
  test_task_list:
  - coref
metric_aggregator_config:
  metric_aggregator_type: EqualMetricAggregator

However, I am also wondering where the encoder is frozen in the existing implementation, which I think is necessary for probing. As Tenney 2019 ("BERT Rediscovers the Classical NLP Pipeline") writes:
[screenshot of the relevant passage from the paper]
If I look back at the original implementation branch for the Tenney papers, there is an option to do this, which I cannot find in the current code. If I freeze the encoder weights myself, the coref result drops to 77.7 for bert-base-uncased, which is quite low but corresponds well to Liu 2019, Appendix D6 (the data is the same AFAICT, though there is no self-attention pooling layer over the spans). Given this, my question is: where is the encoder frozen for the edge probing tasks? Or am I misunderstanding the task design?

Here's the original code for multi-label span classification in the existing implementation:

    elif task.TASK_TYPE == TaskTypes.MULTI_LABEL_SPAN_CLASSIFICATION:
        assert taskmodel_kwargs is None
        span_comparison_head = heads.SpanComparisonHead(
            hidden_size=hidden_size,
            hidden_dropout_prob=hidden_dropout_prob,
            num_spans=task.num_spans,
            num_labels=len(task.LABELS),
        )
        taskmodel = taskmodels.MultiLabelSpanComparisonModel(
            encoder=encoder, span_comparison_head=span_comparison_head,
        )

and the corresponding taskmodel class:

class MultiLabelSpanComparisonModel(Taskmodel):
    def __init__(self, encoder, span_comparison_head: heads.SpanComparisonHead):
        super().__init__(encoder=encoder)
        self.span_comparison_head = span_comparison_head

    def forward(self, batch, task, tokenizer, compute_loss: bool = False):
        encoder_output = get_output_from_encoder_and_batch(encoder=self.encoder, batch=batch)
        logits = self.span_comparison_head(unpooled=encoder_output.unpooled, spans=batch.spans)
        if compute_loss:
            loss_fct = nn.BCEWithLogitsLoss()
            loss = loss_fct(
                logits.view(-1, self.span_comparison_head.num_labels), batch.label_ids.float(),
            )
            return LogitsAndLossOutput(logits=logits, loss=loss, other=encoder_output.other)
        else:
            return LogitsOutput(logits=logits, other=encoder_output.other)

I added this before creating the taskmodel:

for param in encoder.parameters():
    param.requires_grad = False
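
(For what it's worth, here is a quick sanity check I would use to confirm the freeze took effect; it just counts trainable parameters and assumes encoder is the HuggingFace module passed into create_taskmodel.)

n_total = sum(p.numel() for p in encoder.parameters())
n_trainable = sum(p.numel() for p in encoder.parameters() if p.requires_grad)
# after freezing, n_trainable should be 0 for the encoder; the span head stays trainable
print(f"encoder params trainable: {n_trainable} / {n_total}")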

To Reproduce

  1. Tell us which version of jiant you're using: the latest one, with some modifications for reading YAML files.
  2. Describe the environment where you're using jiant, e.g., "2 P40 GPUs": a single NVIDIA GPU.
  3. Provide the experiment config artifact (e.g., defaults.conf): see above.

Expected behavior

Screenshots

Additional context

@zphang
Collaborator

zphang commented Apr 5, 2021

Hi @sagnik,

I think you're more or less correct, and the issue may be coming from the mix and cat configurations. Indeed, the encoder should be frozen when using these tasks in a "probing" format, and that's not an option in the standard jiant runscript (which supports general, full-model fine-tuning), so your code for looping over and freezing the encoder parameters looks right.

The probing tasks, though, generally do not use just the top-layer representations, but also information from the middle layers:

  • cat concatenates the final-layer activations with the embeddings
  • mix learns a weighted mixture of the activations across all layers (12 layers + 1 embedding layer), where the weights are a set of tunable scalars, similar to ELMo.

The implementation you described takes only the final layer activations, which does not match either of the above setups.
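To make the distinction concrete, here is a rough sketch of the three pooling variants over the per-layer hidden states (shapes and parameter names are just for illustration, not jiant's actual API):

import torch

# hidden_states: the (num_layers + 1)-tuple returned by a BERT-style encoder run with
# output_hidden_states=True; each tensor is [batch, seq_len, hidden_size].
# hidden_states[0] is the embedding output, hidden_states[-1] the top layer.

def top_layer_only(hidden_states):
    # what the snippet above feeds to the span head
    return hidden_states[-1]                                          # [B, T, H]

def cat_pooling(hidden_states):
    # "cat": embeddings concatenated with the top-layer activations -> 2H features
    return torch.cat([hidden_states[0], hidden_states[-1]], dim=-1)   # [B, T, 2H]

def mix_pooling(hidden_states, scalar_weights, gamma):
    # "mix": ELMo-style learned scalar mixture over all layers + embeddings;
    # scalar_weights is a trainable vector of length len(hidden_states).
    s = torch.softmax(scalar_weights, dim=0)
    return gamma * sum(w * h for w, h in zip(s, hidden_states))       # [B, T, H]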

I will be busy over the next few days, but I will try to get some results that compare all three.

@sagnik
Author

sagnik commented Apr 6, 2021

Thanks for the reply @zphang .

Yes, I agree the original paper uses cat and mix, which is different from my setup. I was referring to the cat result anyway, but as I mentioned, I didn't quite understand what it was doing, so thanks for clearing that up. I also don't see the cat or scalar mixing option anywhere in the current master; am I missing it?

If you can run the experiments on your end, that would be great. If scalar mix/cat is difficult to add, just the top-layer results would be fine: I just want to know whether I'm doing something massively wrong. At least two papers have reported results with the top layer, and my results are almost 20% off. If by any chance you can get the layer-wise results, that would be awesome!

Also, in the Tenney-papers branch, transfer_paradigm = frozen is not mentioned in the configs, so I am pretty sure it comes from cl_arguments in main? I didn't go through that code in detail, so I might be missing something. But if that is indeed the case, I think for reproducibility it should be put in the conf files themselves.

Thanks for getting back!

@sagnik
Author

sagnik commented Apr 9, 2021

I have a small update on this. I have updated the code to include scalar mixing and cat. Here are the diffs for the relevant files:

(jiant) sagnik:modeling$ git diff upstream/master..origin/master  -- model_setup.py
diff --git a/jiant/proj/main/modeling/model_setup.py b/jiant/proj/main/modeling/model_setup.py
index 6ff546e..835b287 100644
--- a/jiant/proj/main/modeling/model_setup.py
+++ b/jiant/proj/main/modeling/model_setup.py
@@ -21,10 +21,13 @@ from jiant.tasks import TaskTypes
 
 
 def setup_jiant_model(
-    hf_pretrained_model_name_or_path: str,
-    model_config_path: str,
-    task_dict: Dict[str, Task],
-    taskmodels_config: container_setup.TaskmodelsConfig,
+        hf_pretrained_model_name_or_path: str,
+        model_config_path: str,
+        task_dict: Dict[str, Task],
+        taskmodels_config: container_setup.TaskmodelsConfig,
+        freeeze_encoder: bool,
+        num_hidden_layers: int = -1,
+        scalar_mixing: bool = False
 ):
     """Sets up tokenizer, encoder, and task models, and instantiates and returns a JiantModel.
 
@@ -47,7 +50,9 @@ def setup_jiant_model(
         model_config_path (str): Path to the JSON file containing the configuration parameters.
         task_dict (Dict[str, tasks.Task]): map from task name to task instance.
         taskmodels_config: maps mapping from tasks to models, and specifying task-model configs.
-
+        freeeze_encoder: should we freeze the underlying encoder model?
+        num_hidden_layers: how many layers of encoder would we use (from the bottom)?
+        scalar_mixing: should we use a scalar mixing for ep tasks?
     Returns:
         JiantModel nn.Module.
 
@@ -66,6 +71,9 @@ def setup_jiant_model(
             model_arch=model_arch,
             encoder=encoder,
             taskmodel_kwargs=taskmodels_config.get_taskmodel_kwargs(taskmodel_name),
+            freeze_encoder=freeeze_encoder,
+            scalar_mixing=scalar_mixing,
+            num_hidden_layers=num_hidden_layers
         )
         for taskmodel_name, task_name_list in get_taskmodel_and_task_names(
             taskmodels_config.task_to_taskmodel_map
@@ -144,7 +152,7 @@ def delegate_load(jiant_model, weights_dict: dict, load_mode: str):
 
 
 def load_encoder_from_transformers_weights(
-    encoder: nn.Module, weights_dict: dict, return_remainder=False
+        encoder: nn.Module, weights_dict: dict, return_remainder=False
 ):
     """Find encoder weights in weights dict, load them into encoder, return any remaining weights.
 
@@ -242,7 +250,7 @@ def load_encoder_only(jiant_model, weights_dict):
 
 
 def load_partial_heads(
-    jiant_model, weights_dict, allow_missing_head_weights=False, allow_missing_head_model=False
+        jiant_model, weights_dict, allow_missing_head_weights=False, allow_missing_head_model=False
 ):
     """Loads model weights and returns lists of missing head weights or missing heads (if any).
 
@@ -274,7 +282,8 @@ def load_partial_heads(
 
 
 def create_taskmodel(
-    task, model_arch, encoder, taskmodel_kwargs: Optional[Dict] = None
+        task, model_arch, encoder, taskmodel_kwargs: Optional[Dict] = None, freeze_encoder: bool = False,
+        scalar_mixing: bool = False, num_hidden_layers: int = -1
 ) -> taskmodels.Taskmodel:
     """Creates, initializes and returns the task model for a given task type and encoder.
 
@@ -283,7 +292,9 @@ def create_taskmodel(
         model_arch (ModelArchitectures.Any): Model architecture (e.g., ModelArchitectures.BERT).
         encoder (PreTrainedModel): Transformer w/o heads (embedding layer + self-attention layer).
         taskmodel_kwargs (Optional[Dict]): map containing any kwargs needed for taskmodel setup.
-
+        freeze_encoder: freeze the encoder (for ep tasks)
+        scalar_mixing: for ep tasks, should we use a scalar mixing of layers?
+        num_hidden_layers: how many layers should we use?
     Raises:
         KeyError if task does not have valid TASK_TYPE.
 
@@ -308,7 +319,9 @@ def create_taskmodel(
         hidden_dropout_prob = encoder.config.dropout
     else:
         raise KeyError()
-
+    if freeze_encoder:
+        for param in encoder.parameters():
+            param.requires_grad = False
     if task.TASK_TYPE == TaskTypes.CLASSIFICATION:
         assert taskmodel_kwargs is None
         classification_head = heads.ClassificationHead(
@@ -356,6 +369,9 @@ def create_taskmodel(
         )
     elif task.TASK_TYPE == TaskTypes.MULTI_LABEL_SPAN_CLASSIFICATION:
         assert taskmodel_kwargs is None
+        if not scalar_mixing:  # if we do the scalar mixing, we don't have to concatenate embeddings and the final layer
+            # activation
+            hidden_size = hidden_size*2
         span_comparison_head = heads.SpanComparisonHead(
             hidden_size=hidden_size,
             hidden_dropout_prob=hidden_dropout_prob,
@@ -363,7 +379,8 @@ def create_taskmodel(
             num_labels=len(task.LABELS),
         )
         taskmodel = taskmodels.MultiLabelSpanComparisonModel(
-            encoder=encoder, span_comparison_head=span_comparison_head,
+            encoder=encoder, span_comparison_head=span_comparison_head, scalar_mixing=scalar_mixing,
+            num_hidden_layers=num_hidden_layers
         )
     elif task.TASK_TYPE == TaskTypes.TAGGING:
         assert taskmodel_kwargs is None
(jiant) sagnik:modeling$ git diff upstream/master..origin/master  -- taskmodels.py
diff --git a/jiant/proj/main/modeling/taskmodels.py b/jiant/proj/main/modeling/taskmodels.py
index 53d5e70..509d770 100644
--- a/jiant/proj/main/modeling/taskmodels.py
+++ b/jiant/proj/main/modeling/taskmodels.py
@@ -10,6 +10,7 @@ import jiant.utils.transformer_utils as transformer_utils
 from jiant.proj.main.components.outputs import LogitsOutput, LogitsAndLossOutput
 from jiant.utils.python.datastructures import take_one
 from jiant.shared.model_resolution import ModelArchitectures
+from jiant.proj.main.modeling.scalarmix import ScalarMix
 
 
 class Taskmodel(nn.Module, metaclass=abc.ABCMeta):
@@ -147,13 +148,28 @@ class SpanPredictionModel(Taskmodel):
 
 
 class MultiLabelSpanComparisonModel(Taskmodel):
-    def __init__(self, encoder, span_comparison_head: heads.SpanComparisonHead):
+    def __init__(self, encoder, span_comparison_head: heads.SpanComparisonHead, scalar_mixing: bool,
+                 num_hidden_layers: int):
         super().__init__(encoder=encoder)
+        self.scalar_mixing = scalar_mixing
         self.span_comparison_head = span_comparison_head
+        self.num_hidden_layers = num_hidden_layers
 
     def forward(self, batch, task, tokenizer, compute_loss: bool = False):
-        encoder_output = get_output_from_encoder_and_batch(encoder=self.encoder, batch=batch)
-        logits = self.span_comparison_head(unpooled=encoder_output.unpooled, spans=batch.spans)
+        with transformer_utils.output_hidden_states_context(self.encoder):
+            encoder_output = get_output_from_encoder_and_batch(encoder=self.encoder, batch=batch)
+        embedding_weights = encoder_output.other[0][0]
+        layer_wise_unpooled = encoder_output.other[0][1:]
+        if self.num_hidden_layers == -1:
+            self.num_hidden_layers = len(layer_wise_unpooled) - 1
+        if not self.scalar_mixing:
+            top_layer_acts = layer_wise_unpooled[self.num_hidden_layers]
+            concat_vec = torch.cat((embedding_weights, top_layer_acts), -1)
+            logits = self.span_comparison_head(unpooled=concat_vec, spans=batch.spans)
+        else:  # introduce a mixing weight vector
+            layer_activations = list(layer_wise_unpooled[:self.num_hidden_layers])
+            scalar_mixing_layer = ScalarMix(mixture_size=len(layer_activations))
+            logits = self.span_comparison_head(unpooled=scalar_mixing_layer(layer_activations), spans=batch.spans)
         if compute_loss:
             loss_fct = nn.BCEWithLogitsLoss()
             loss = loss_fct(

The scalarmix module is pretty much the same as AllenNLP's ScalarMix, with small changes:

(jiant) sagnik:modeling$ git diff upstream/master..origin/master  -- scalarmix.py
diff --git a/jiant/proj/main/modeling/scalarmix.py b/jiant/proj/main/modeling/scalarmix.py
new file mode 100644
index 0000000..5951ec5
--- /dev/null
+++ b/jiant/proj/main/modeling/scalarmix.py
@@ -0,0 +1,62 @@
+from typing import List
+
+import torch
+from torch.nn import ParameterList, Parameter
+
+
+class ScalarMix(torch.nn.Module):
+    """
+    Computes a parameterised scalar mixture of N tensors, `mixture = gamma * sum(s_k * tensor_k)`
+    where `s = softmax(w)`, with `w` and `gamma` scalar parameters.
+
+    In addition, if `do_layer_norm=True` then apply layer normalization to each tensor
+    before weighting.
+    """
+
+    def __init__(
+            self,
+            mixture_size: int,
+            initial_scalar_parameters: List[float] = None,
+            trainable: bool = True,
+    ) -> None:
+        super().__init__()
+        self.mixture_size = mixture_size
+        if initial_scalar_parameters is None:
+            initial_scalar_parameters = [0.0] * mixture_size
+        elif len(initial_scalar_parameters) != mixture_size:
+            raise RuntimeError(
+                "Length of initial_scalar_parameters {} differs "
+                "from mixture_size {}".format(initial_scalar_parameters, mixture_size)
+            )
+
+        self.scalar_parameters = ParameterList(
+            [
+                Parameter(
+                    torch.FloatTensor([initial_scalar_parameters[i]]), requires_grad=trainable
+                )
+                for i in range(mixture_size)
+            ]
+        )
+        self.gamma = Parameter(torch.FloatTensor([1.0]), requires_grad=trainable)
+
+    def forward(self, tensors: List[torch.Tensor]) -> torch.Tensor:
+        """
+        Compute a weighted average of the `tensors`.  The input tensors an be any shape
+        with at least two dimensions, but must all be the same shape.
+
+        When `do_layer_norm=False` the `mask` is ignored.
+        """
+        if len(tensors) != self.mixture_size:
+            raise RuntimeError(
+                "{} tensors were passed, but the module was initialized to "
+                "mix {} tensors.".format(len(tensors), self.mixture_size)
+            )
+
+        normed_weights = torch.nn.functional.softmax(
+            torch.cat([parameter for parameter in self.scalar_parameters]), dim=0
+        )
+        normed_weights = torch.split(normed_weights, split_size_or_sections=1)
+        pieces = []
+        for weight, tensor in zip(normed_weights, tensors):
+            pieces.append(weight.to(device='cuda') * tensor)
+        return self.gamma.to(device='cuda') * sum(pieces)
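And a minimal standalone usage sketch of the ScalarMix class above (the sizes are made up for illustration; since forward() as written moves the scalars to cuda, this assumes a GPU is available):

import torch

num_layers, batch, seq_len, hidden = 12, 2, 8, 768
layer_activations = [
    torch.randn(batch, seq_len, hidden, device="cuda") for _ in range(num_layers)
]
mixer = ScalarMix(mixture_size=num_layers).to("cuda")
mixed = mixer(layer_activations)  # [batch, seq_len, hidden]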

Does the code look right?

If so, then I still can't reproduce the results. The cat setting gives me an F1 score of 78.71 with the config given above, and the mix setting does slightly better: 81.87.
