diff --git a/README.md b/README.md
index 572e8b9..1771588 100644
--- a/README.md
+++ b/README.md
@@ -2,7 +2,7 @@ BertViz
-![model view](https://github.com/jessevig/bertviz/raw/master/images/model-view-noscroll.gif)
+![model view](images/model-view-noscroll.gif)

 ### Neuron View
 The *neuron view* visualizes individual neurons in the query and key vectors and shows how they are used to compute attention.
@@ -49,25 +47,24 @@ The *neuron view* visualizes individual neurons in the query and key vectors and

 🕹 Try out the neuron view in the [Interactive Colab Tutorial](https://colab.research.google.com/drive/1MV7u8hdMgpwUd9nIlONQp-EBo8Fsj7CJ?usp=sharing) (all visualizations pre-loaded).

-![neuron view](https://github.com/jessevig/bertviz/raw/master/images/neuron-view-dark.gif)
+![neuron view](images/neuron-view-dark.gif)

 ## ⚡️ Getting Started

-### Installation
-#### Jupyter Notebook
+### Running BertViz in a Jupyter Notebook

 From the command line:

 ```bash
 pip install bertviz
 ```

-You must also have Jupyter Notebook and ipywidgets installed in order to run BertViz in a notebook:
+You must also have Jupyter Notebook and ipywidgets installed:

 ```bash
 pip install jupyterlab
 pip install ipywidgets
 ```

-If you run into any issues installing Jupyter or ipywidgets, consult the documentation [here](https://jupyter.org/install) and [here](https://ipywidgets.readthedocs.io/en/stable/user_install.html).
+(If you run into any issues installing Jupyter or ipywidgets, consult the documentation [here](https://jupyter.org/install) and [here](https://ipywidgets.readthedocs.io/en/stable/user_install.html).)

 To create a new Jupyter notebook, simply run:

@@ -78,7 +75,7 @@ jupyter notebook

 Then click `New` and select `Python 3 (ipykernel)` if prompted.

-#### Colab
+### Running BertViz in Colab

 To run in [Colab](https://colab.research.google.com/), simply add the following cell at the beginning of your Colab notebook:

@@ -121,7 +118,7 @@ jupyter notebook

 ## 🕹 Interactive Tutorial

 Check out the [Interactive Colab Tutorial](https://colab.research.google.com/drive/1MV7u8hdMgpwUd9nIlONQp-EBo8Fsj7CJ?usp=sharing)
-to interact with BertViz and learn more about the tool. Note: all visualizations are pre-loaded, so there is no need to execute any cells.
+to learn more about BertViz and try out the tool. Note: all visualizations are pre-loaded, so there is no need to execute any cells.

 [![Tutorial](images/tutorial-screenshots.jpg)](https://colab.research.google.com/drive/1MV7u8hdMgpwUd9nIlONQp-EBo8Fsj7CJ?usp=sharing)

@@ -133,7 +130,7 @@ to interact with BertViz and learn more about the tool. Note: all visuali
 - [Self-attention models (BERT, GPT-2, etc.)](#self-attention-models-bert-gpt-2-etc)
   * [Head and Model Views](#head-and-model-views)
   * [Neuron View](#neuron-view-1)
-- [Encoder-decoder models (BART, MarianMT, etc.)](#encoder-decoder-models-bart-marianmt-etc)
+- [Encoder-decoder models (BART, T5, etc.)](#encoder-decoder-models-bart-t5-etc)
 - [Installing from source](#installing-from-source)
 - [Additional options](#additional-options)
   * [Dark / light mode](#dark--light-mode)

@@ -205,9 +202,9 @@ GPT-2 ([Notebook](notebooks/neuron_view_gpt2.ipynb),

 RoBERTa ([Notebook](notebooks/neuron_view_roberta.ipynb))

-For full API, please refer to the [source](bertviz/neuron_view.py).
+Note that only one instance of the Neuron View may be displayed within a notebook. For full API, please refer to the [source](bertviz/neuron_view.py).

-### Encoder-decoder models (BART, MarianMT, etc.)
+### Encoder-decoder models (BART, T5, etc.)

 The head view and model view both support encoder-decoder models.

@@ -357,10 +354,10 @@ which required modifying the model code (see `transformers_neuron_view` director

 Also, only one neuron view may be included per notebook.
 ### Attention as "explanation"?
-* Visualizing attention weights illuminates a particular mechanism within the model architecture but does not
-necessarily provide a direct *explanation* for model predictions. See [[1](https://arxiv.org/pdf/1909.11218.pdf), [2](https://arxiv.org/abs/1902.10186), [3](https://arxiv.org/pdf/1908.04626.pdf)].
+* Visualizing attention weights illuminates one type of architecture within the model but does not
+necessarily provide a direct *explanation* for predictions [[1](https://arxiv.org/pdf/1909.11218.pdf), [2](https://arxiv.org/abs/1902.10186), [3](https://arxiv.org/pdf/1908.04626.pdf)].
 * If you wish to understand how the input text influences output predictions more directly, consider [saliency methods](https://arxiv.org/pdf/2010.05607.pdf) provided
-by excellent tools such such as the [Language Interpretability Toolkit](https://github.com/PAIR-code/lit) or [Ecco](https://github.com/jalammar/ecco).
+by tools such as the [Language Interpretability Toolkit](https://github.com/PAIR-code/lit) or [Ecco](https://github.com/jalammar/ecco).

 ## 🔬 Paper
diff --git a/bertviz/model_view.js b/bertviz/model_view.js
index f7a8b0f..ca024a0 100644
--- a/bertviz/model_view.js
+++ b/bertviz/model_view.js
@@ -32,7 +32,7 @@ requirejs(['jquery', 'd3'], function($, d3) {
     const DETAIL_ATTENTION_WIDTH = 140;
     const DETAIL_BOX_WIDTH = 80;
     const DETAIL_BOX_HEIGHT = 18;
-    const DETAIL_PADDING = 10;
+    const DETAIL_PADDING = 15;
     const ATTN_PADDING = 0;
     const DETAIL_HEADING_HEIGHT = 25;
     const HEADING_TEXT_SIZE = 15;
@@ -96,6 +96,7 @@ requirejs(['jquery', 'd3'], function($, d3) {
         config.svg.append("text")
             .text("Heads")
             .attr("fill", "black")
+            .attr("font-weight", "bold")
             .attr("font-size", HEADING_TEXT_SIZE + "px")
             .attr("x", axisSize + tableWidth / 2)
             .attr("text-anchor", "middle")
@@ -117,6 +118,7 @@ requirejs(['jquery', 'd3'], function($, d3) {
         config.svg.append("text")
             .text("Layers")
             .attr("fill", "black")
+            .attr("font-weight", "bold")
             .attr("transform", "rotate(270, " + x + ", " + y + ")")
             .attr("font-size", HEADING_TEXT_SIZE + "px")
             .attr("x", x)
@@ -149,7 +151,7 @@ requirejs(['jquery', 'd3'], function($, d3) {
         const axisSize = TEXT_SIZE + HEADING_PADDING + TEXT_SIZE + TEXT_PADDING;
         var xOffset = .8 * config.thumbnailWidth;
         var maxX = DIV_WIDTH;
-        var maxY = config.divHeight;
+        var maxY = config.divHeight - 3;
         var leftPos = axisSize + headIndex * config.thumbnailWidth;
         var x = leftPos + THUMBNAIL_PADDING + xOffset;
         if (x < MIN_X) {
diff --git a/bertviz/neuron_view.js b/bertviz/neuron_view.js
index 08f27e5..b612258 100644
--- a/bertviz/neuron_view.js
+++ b/bertviz/neuron_view.js
@@ -42,13 +42,13 @@ requirejs(['jquery', 'd3'],
            'attn': '#2994de',
            'neg': '#ff6318',
            'pos': '#2090dd',
-           'text': '#bbb',
+           'text': '#ccc',
            'selected_text': 'white',
            'heading_text': 'white',
            'text_highlight_left': "#1b86cd",
            'text_highlight_right': "#1b86cd",
            'vector_border': "#444",
-           'connector': "#8aa4d2",
+           'connector': "#2994de",
            'background': 'black',
            'dropdown': 'white',
            'icon': 'white'
@@ -266,13 +266,13 @@ requirejs(['jquery', 'd3'],
            .attr("y2", function (d, targetIndex) {
                return targetIndex * BOXHEIGHT + HEADING_HEIGHT + BOXHEIGHT / 2;
            })
-           .attr("stroke-width", 1.9)
+           .attr("stroke-width", 2)
            .attr("stroke", getColor('connector'))
            .attr("stroke-opacity", function (d) {
                if (d==0) {
                    return 0;
                } else {
-                   return Math.max(MIN_CONNECTOR_OPACITY, Math.tanh(Math.abs(2 * d)));
+                   return Math.max(MIN_CONNECTOR_OPACITY, Math.tanh(Math.abs(1.8 * d)));
                }
            });
    }
@@ -303,10 +303,10 @@ requirejs(['jquery', 'd3'],
            .attr("height", BOXHEIGHT - 5)
            .attr("width", MATRIX_WIDTH + 3)
            .style("fill-opacity", 0)
-           .attr("stroke-width", 1.9)
+           .attr("stroke-width", 2)
            .attr("stroke", getColor('connector'))
            .attr("stroke-opacity", function (d) {
-               return Math.tanh(Math.abs(2*d) );
+               return Math.tanh(Math.abs(1.8*d) );
            });
    }
@@ -353,13 +353,13 @@ requirejs(['jquery', 'd3'],
                ])
            })
            .attr("fill", "none")
-           .attr("stroke-width", 1.9)
+           .attr("stroke-width", 2)
            .attr("stroke", getColor('connector'))
            .attr("stroke-opacity", function (d) {
                if (d==0) {
                    return 0;
                } else {
-                   return Math.max(MIN_CONNECTOR_OPACITY, Math.tanh(Math.abs(2 * d)));
+                   return Math.max(MIN_CONNECTOR_OPACITY, Math.tanh(Math.abs(1.8 * d)));
                }
            });
    }
@@ -382,7 +382,7 @@ requirejs(['jquery', 'd3'],
            .attr("y2", function (d, i) {
                return i * BOXHEIGHT + HEADING_HEIGHT + BOXHEIGHT / 2;
            })
-           .attr("stroke-width", 1.9)
+           .attr("stroke-width", 2)
            .attr("stroke", getColor('connector'))
    }
@@ -412,7 +412,7 @@ requirejs(['jquery', 'd3'],
            .attr("y2", function (d, targetIndex) {
                return targetIndex * BOXHEIGHT + HEADING_HEIGHT + BOXHEIGHT / 2;
            })
-           .attr("stroke-width", 3)
+           .attr("stroke-width", 2)
            .attr("stroke", getColor('attn'))
            .attr("stroke-opacity", function (d) {
                return d;
@@ -742,7 +742,7 @@ requirejs(['jquery', 'd3'],
                return i == index ? getColor('connector') : getColor('vector_border');
            })
            .style("stroke-width", function (d, i) {
-               return i == index ? 1.9 : 1;
+               return i == index ? 2 : 1;
            })
        ;
        svg.select("#queries")
diff --git a/images/model-view-noscroll.gif b/images/model-view-noscroll.gif
index 2d283e3..de73931 100644
Binary files a/images/model-view-noscroll.gif and b/images/model-view-noscroll.gif differ
diff --git a/images/neuron-view-dark.gif b/images/neuron-view-dark.gif
index 1a8e780..2842d8b 100644
Binary files a/images/neuron-view-dark.gif and b/images/neuron-view-dark.gif differ
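Background for the renamed *Running BertViz in a Jupyter Notebook* section above: once `bertviz`, Jupyter, and `ipywidgets` are installed, a single notebook cell is enough to bring up the head view. The sketch below is illustrative only and not part of this patch; it assumes the `head_view(attention, tokens)` call from the BertViz quick-start and the Hugging Face `transformers` auto classes, and the input sentence is an arbitrary example.

```python
# Minimal head-view sketch -- assumes bertviz's head_view API and Hugging Face transformers.
from transformers import AutoTokenizer, AutoModel
from bertviz import head_view

model_name = "bert-base-uncased"  # any supported self-attention model
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name, output_attentions=True)  # return attention weights

inputs = tokenizer.encode("The cat sat on the mat", return_tensors="pt")
outputs = model(inputs)
attention = outputs[-1]  # tuple with one attention tensor per layer
tokens = tokenizer.convert_ids_to_tokens(inputs[0])

head_view(attention, tokens)  # renders the interactive head view in the notebook output cell
```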
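The recurring `neuron_view.js` edits above change the connector stroke width from 1.9 to 2 and the opacity scale factor inside `Math.tanh` from 2 to 1.8. To make the opacity tweak concrete, here is the same mapping re-expressed in Python purely for illustration; the real code runs in D3, and `MIN_CONNECTOR_OPACITY` is defined elsewhere in `neuron_view.js`, so the `floor` default below is a placeholder assumption.

```python
import math

def connector_opacity(d: float, scale: float, floor: float = 0.0) -> float:
    """Opacity rule from neuron_view.js: 0 for zero weights, else max(floor, tanh(|scale * d|))."""
    if d == 0:
        return 0.0
    return max(floor, math.tanh(abs(scale * d)))

# Compare the old (2.0) and new (1.8) scale factors for a few attention-derived weights.
for d in (0.1, 0.25, 0.5, 1.0):
    old, new = connector_opacity(d, 2.0), connector_opacity(d, 1.8)
    print(f"d={d}: old scale -> {old:.3f}, new scale -> {new:.3f}")
```

The lower factor softens connectors for mid-range weights (for example, d = 0.5 drops from roughly 0.76 to 0.72 opacity) while zero-weight connectors stay hidden.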