Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
mivanit committed Jul 26, 2023
1 parent 148b289 commit 62cd7bb
Show file tree
Hide file tree
Showing 4 changed files with 198 additions and 11 deletions.
93 changes: 93 additions & 0 deletions maze_dataset/plotting/print_tokens.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
from typing import Sequence
import html

import numpy as np
from jaxtyping import UInt8
from IPython.display import display, HTML
import matplotlib
from maze_dataset.constants import SPECIAL_TOKENS

from maze_dataset.tokenization import tokens_between
from maze_dataset.tokenization.token_utils import get_adj_list_tokens, get_origin_tokens, get_path_tokens, get_target_tokens

RGBArray = UInt8[np.ndarray, "n 3"]

_DEFAULT_TEMPLATE: str = '<span style="color: black; background-color: rgb{clr}">&nbsp{tok}&nbsp</span>'

def color_tokens_rgb(
tokens: list[str],
colors: RGBArray,
template: str = _DEFAULT_TEMPLATE,
):
output: list[str] = [
template.format(
tok=html.escape(tok),
clr=tuple(np.array(clr, dtype=np.uint8)),
)
for tok, clr in zip(tokens, colors)
]
return ' '.join(output)

def color_tokens_cmap(
tokens: list[str],
weights: Sequence[float],
cmap: str|matplotlib.colors.Colormap = "Blues",
):
assert len(tokens) == len(weights)
weights = np.array(weights)

if isinstance(cmap, str):
cmap = matplotlib.cm.get_cmap(cmap)

colors: RGBArray = cmap(weights)[:, :3] * 255

return color_tokens_rgb(tokens, colors)

# these colors are to match those from the original understanding-search talk at the conclusion of AISC 2023
_MAZE_TOKENS_DEFAULT_COLORS: dict[tuple[str, str], tuple[int, int, int]] = {
(SPECIAL_TOKENS.ADJLIST_START, SPECIAL_TOKENS.ADJLIST_END): (234, 209, 220), # pink
(SPECIAL_TOKENS.ORIGIN_START, SPECIAL_TOKENS.ORIGIN_END): (217, 210, 233), # purple
(SPECIAL_TOKENS.TARGET_START, SPECIAL_TOKENS.TARGET_END): (207, 226, 243), # blue
(SPECIAL_TOKENS.PATH_START, SPECIAL_TOKENS.PATH_END): (217, 234, 211), # green
}

def color_maze_tokens_AOTP(
tokens: list[str],
) -> str:

output: list[str] = [
" ".join(tokens_between(
tokens, start_tok, end_tok, include_start=True, include_end=True
))
for start_tok, end_tok in _MAZE_TOKENS_DEFAULT_COLORS.keys()
]

colors: RGBArray = np.array(list(_MAZE_TOKENS_DEFAULT_COLORS.values()), dtype=np.uint8)

return color_tokens_rgb(output, colors)

def display_html(html: str):
display(HTML(html))


def display_color_tokens_rgb(
tokens: list[str],
colors: RGBArray,
template: str = _DEFAULT_TEMPLATE,
) -> None:
html: str = color_tokens_rgb(tokens, colors, template)
display_html(html)

def display_color_tokens_cmap(
tokens: list[str],
weights: Sequence[float],
cmap: str|matplotlib.colors.Colormap = "Blues",
) -> None:
html: str = color_tokens_cmap(tokens, weights, cmap)
display_html(html)

def display_color_maze_tokens_AOTP(
tokens: list[str],
) -> None:
html: str = color_maze_tokens_AOTP(tokens)
display_html(html)
7 changes: 4 additions & 3 deletions maze_dataset/tokenization/maze_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,8 @@ class MazeTokenizer(SerializableDataclass):

@property
def name(self) -> str:
return f"maze_tokenizer-{self.tokenization_mode.value}-n{self.max_grid_size}"
max_grid_size_str: str = f"-g{self.max_grid_size}" if self.max_grid_size is not None else ""
return f"maze_tokenizer-{self.tokenization_mode.value}{max_grid_size_str}"

@cached_property
def node_token_map(self) -> dict[CoordTup, str]:
Expand Down Expand Up @@ -177,13 +178,13 @@ def coords_to_strings(
):
return coords_to_strings(
coords=coords,
coords_to_strings_func=_coord_to_strings_UT,
coord_to_strings_func=_coord_to_strings_UT,
when_noncoord=when_noncoord,
)
elif self.tokenization_mode == TokenizationMode.AOTP_indexed:
return coords_to_strings(
coords=coords,
coords_to_strings_func=_coord_to_strings_indexed,
coord_to_strings_func=_coord_to_strings_indexed,
when_noncoord=when_noncoord,
)
else:
Expand Down
Binary file not shown.
Loading

0 comments on commit 62cd7bb

Please sign in to comment.