diff --git a/clrs/_src/clrs_text/clrs_utils.py b/clrs/_src/clrs_text/clrs_utils.py new file mode 100644 index 00000000..7ed38ab5 --- /dev/null +++ b/clrs/_src/clrs_text/clrs_utils.py @@ -0,0 +1,601 @@ +# Copyright 2024 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Functions to create text versions of CLRS data.""" +from typing import Any, Optional + +import clrs +import numpy as np + + +CLRS_TASKS_WITH_HINTS = tuple( + [ + 'activity_selector', + 'articulation_points', + 'bellman_ford', + 'bfs', + 'binary_search', + 'bridges', + 'bubble_sort', + 'dag_shortest_paths', + 'dfs', + 'dijkstra', + 'find_maximum_subarray_kadane', + 'floyd_warshall', + 'graham_scan', + 'heapsort', + 'insertion_sort', + 'jarvis_march', + 'kmp_matcher', + 'lcs_length', + 'matrix_chain_order', + 'minimum', + 'mst_kruskal', + 'mst_prim', + 'naive_string_matcher', + 'optimal_bst', + 'quickselect', + 'quicksort', + 'strongly_connected_components', + 'task_scheduling', + 'topological_sort', + ], +) +CLRS_STRING_MATCHING_TASKS_OUTPUT_REPLACER = { + 'naive_string_matcher': 's', + 'kmp_matcher': 's', +} +CLRS_SEARCH_TAKS_OUTPUT_REPLACER = { + 'binary_search': ['low', 'high'], + 'find_maximum_subarray_kadane': ['best_low', 'best_high'], + 'quickselect': ['pivot'], +} +CLRS_PARENTHESES_TRACES = frozenset( + {'binary_search', 'find_maximum_subarray_kadane'} +) +CLRS_SORTING_TASKS = ['bubble_sort', 'heapsort', 'insertion_sort', 'quicksort'] + +DEFAULT_SEPARATOR = ', ' +INPUT_TRACE_MARKER = 'initial_trace:' +TRACE_ANSWER_SEPARATOR = ' | ' +OUTPUT_TRACE_MARKER = 'trace' +PERMUTATION_SEPARATOR = '->' +SEQUENCE_SEPARATOR = ' ' + +_HINT_PREFIX = '_h' + + +def format_clrs_example( + algo: str, + sample: clrs.Feedback, + use_hints: bool = False, +) -> tuple[str, str]: + """Formats CLRS example into prompt for the LLM. + + Args: + algo: Name of the algorithm the sample comes from. + sample: A sample generated by a CLRS sampler. + use_hints: if True the initial CLRS hint is added to the input, the rest of + to the output. + + Returns: + The question and answer prompts. + """ + input_, output_names, output, hints_added = sample_to_str( + algo=algo, + sample=sample, + use_hints=use_hints, + ) + if hints_added: + output_name_str = TRACE_ANSWER_SEPARATOR.join( + [OUTPUT_TRACE_MARKER, output_names] + ) + else: + output_name_str = output_names + + question = f'{algo}:\n{input_}\n{output_name_str}:\n' + answer = f'{output}\n\n' + + return question, answer + + +def _get_output_names( + algo_name: str, + spec: clrs.Spec, + use_hints: bool, +) -> list[str]: + """Gets the output names for a CLRS algorithm.""" + if algo_name in CLRS_STRING_MATCHING_TASKS_OUTPUT_REPLACER and use_hints: + return [CLRS_STRING_MATCHING_TASKS_OUTPUT_REPLACER[algo_name]] + elif algo_name in CLRS_SEARCH_TAKS_OUTPUT_REPLACER and use_hints: + return CLRS_SEARCH_TAKS_OUTPUT_REPLACER[algo_name] + else: + return [ + spec_name + for spec_name in spec + if spec[spec_name][0] == clrs.Stage.OUTPUT + ] + + +def _get_output_str( + sample: clrs.Feedback, spec, algo_name: str, use_hints: bool +) -> list[str]: + """Gets the output string for a CLRS algorithm.""" + if algo_name in CLRS_SEARCH_TAKS_OUTPUT_REPLACER and use_hints: + output_results = [] + spec_names = CLRS_SEARCH_TAKS_OUTPUT_REPLACER[algo_name] + for spec_name in spec_names: + x = _get_feature_by_name(sample.features.hints, spec_name).data[-1] + output_results.append( + _feature_to_str( + name=spec_name, + spec=spec, + x=x, + with_name=False, + inputs=sample.features.inputs, + ) + ) + return [DEFAULT_SEPARATOR.join(output_results)] + else: + return _create_output_feature_strs( + spec=spec, + inputs=sample.features.inputs, + outputs=sample.outputs, + ) + + +def sample_to_str( + algo: str, + sample: clrs.Feedback, + use_hints: bool = False, +) -> tuple[str, str, str, bool]: + """Converts a CLRS sample into input and output strings. + + Output examples without hints: + 1. insertion_sort + input_str = 'key: [0.549 0.715 0.603 0.545 0.424]' + output_names_strs = 'pred' + output_str = '[0.424 0.545 0.549 0.603 0.715]' + 2. find_maximum_subarray + input_str = 'key: [0.098 0.43 0.206 0.09 -0.153]' + output_names_strs = 'start, end' + output_str = '0, 3' + 3. binary_search + input_str = 'key: [0.424 0.545 0.549 0.603 0.715], target: 0.646' + output_names_strs = 'return' + output_str = '4' + + Output examples with hints: + 1. insertion_sort + input_str = 'key: [0.549 0.715 0.603 0.545 0.424], trace: + 0->1->2->3->4' + output_names_strs = 'pred' + output_str = '[0.549 0.715 0.603 0.545 0.424], + [0.549 0.603 0.715 0.545 0.424], + [0.545 0.549 0.603 0.715 0.424], + [0.424 0.545 0.549 0.603 0.715] + | [0.424 0.545 0.549 0.603 0.715]' + 2. find_maximum_subarray + input_str = 'key: [0.098 0.43 0.206 0.09 -0.153]' + output_names_strs = 'start, end' + output_str = '0, 3' + 3. binary_search + input_str = 'key: [0.424 0.545 0.549 0.603 0.715], target: 0.646' + output_names_strs = 'return' + output_str = '4' + + For more details about task specs refer to + clrs._src.specs + + + Args: + algo: Name of the algorithm the sample comes from. + sample: A sample generated by a CLRS sampler. + use_hints: if True the initial CLRS hint is added to the input, the rest of + to the output. + + Returns: + A 3-tuple of (input, output_names, output) strings. + """ + spec = clrs.SPECS[algo] + + # Create input prompt. + input_strs = _create_input_feature_strs(spec, sample.features.inputs) + input_str = DEFAULT_SEPARATOR.join(input_strs) + # Create output prompt. + output_names = _get_output_names( + algo_name=algo, + spec=spec, + use_hints=use_hints, + ) + output_strs = _get_output_str( + sample, + spec, + algo_name=algo, + use_hints=use_hints, + ) + output_str = DEFAULT_SEPARATOR.join(output_strs) + output_names_strs = DEFAULT_SEPARATOR.join(output_names) + + hints_added = False + if use_hints: + input_hint_str, output_hint_str, hints_added = _create_hint_feature_strs( + algo_name=algo, + spec=spec, + inputs=sample.features.inputs, + hints=sample.features.hints, + output_names=output_names, + ) + output_str = _format_hint([output_str], algo_name=algo) + output_names_strs = _format_hint([output_names_strs], algo_name=algo) + + if input_hint_str: + input_hint_str = f'{INPUT_TRACE_MARKER} {input_hint_str}' + input_str = DEFAULT_SEPARATOR.join([input_str, input_hint_str]) + output_str = TRACE_ANSWER_SEPARATOR.join( + [ + output_hint_str if output_hint_str else '', + output_str, + ], + ) + + return input_str, output_names_strs, output_str, hints_added + + +def _create_input_feature_strs( + spec: clrs.Spec, + inputs: clrs.Features, +) -> list[str]: + """Extracts input features and convert them into strings.""" + input_strs = [] + for spec_name in spec: + + stage, _, _ = spec[spec_name] # (stage, location, type) + + if stage != clrs.Stage.INPUT: + continue + + if _do_not_include_input_in_text(spec_name, spec): + continue + + input_strs.append( + _feature_to_str( + name=spec_name, + spec=spec, + x=_get_feature_by_name(inputs, spec_name).data, + with_name=True, + ), + ) + return input_strs + + +def _create_output_feature_strs( + spec: clrs.Spec, + inputs: clrs.Features, + outputs: clrs.Features, +) -> list[str]: + """Extracts output features and convert them into strings.""" + output_strs = [] + for spec_name in spec: + stage, _, _ = spec[spec_name] + + if stage != clrs.Stage.OUTPUT: + continue + + x = _get_feature_by_name(outputs, spec_name).data + output_strs.append( + _feature_to_str( + name=spec_name, + spec=spec, + x=x, + with_name=False, + inputs=inputs, + ) + ) + + return output_strs + + +def _is_hint_field( + field_name: str, + algo_name: str, + output_names: list[str], +) -> bool: + """Checks if a field is a hint field.""" + if algo_name in CLRS_STRING_MATCHING_TASKS_OUTPUT_REPLACER: + return field_name == CLRS_STRING_MATCHING_TASKS_OUTPUT_REPLACER[algo_name] + if algo_name in CLRS_SEARCH_TAKS_OUTPUT_REPLACER: + return field_name in CLRS_SEARCH_TAKS_OUTPUT_REPLACER[algo_name] + else: + return field_name[: -len(_HINT_PREFIX)] in output_names + + +def _get_output_name(hint_name: str, algo_name: str) -> str: + """Gets the output name for a hint field.""" + if algo_name in CLRS_STRING_MATCHING_TASKS_OUTPUT_REPLACER: + return CLRS_STRING_MATCHING_TASKS_OUTPUT_REPLACER[algo_name] + if algo_name in CLRS_SEARCH_TAKS_OUTPUT_REPLACER: + return hint_name + else: + return hint_name[: -len(_HINT_PREFIX)] + + +def _format_hint(hints: list[str], algo_name: str) -> str: + """Formats a hint field.""" + result = DEFAULT_SEPARATOR.join(hints) + if algo_name in CLRS_PARENTHESES_TRACES: + result = f'({result})' + return result + + +def _create_hint_feature_strs( + algo_name: str, + spec: clrs.Spec, + inputs: clrs.Features, + hints: clrs.Features, + output_names: list[str], +) -> tuple[str, str, bool]: + """Extracts hint features and convert them into strings.""" + input_hint_strs = [] + unrolled_hints_strs = [] + for hint in hints: + hint_name = hint.name + if not _is_hint_field(hint_name, algo_name, output_names): + continue + + result_hint = _get_feature_by_name(hints, hint_name).data + + output_name = _get_output_name(hint_name, algo_name) + + # The first element of `result_hint` is the initial hint that is used in the + # input prompt. + input_hint_strs.append( + _feature_to_str( + name=output_name, + spec=spec, + x=np.array(result_hint[0]), + with_name=False, + inputs=inputs, + ) + ) + + unrolled_hints = [] + # The first element of `result_hint` is an input hint, and the last element + # is identical to the output result. We don't need either of these elements. + # for output hints, so we skip them. + for unrolled_hint in result_hint[1:-1]: + unrolled_hints.append( + _feature_to_str( + name=output_name, + spec=spec, + x=np.array(unrolled_hint), + with_name=False, + inputs=inputs, + ), + ) + unrolled_hints_strs.append(unrolled_hints) + + hints_found = len(input_hint_strs) & len(unrolled_hints_strs) + + input_hint_str = _format_hint(input_hint_strs, algo_name=algo_name) + output_hint_strs = [] + if hints_found: + unrolled_hints_lengths = set( + [len(unrolled_hint) for unrolled_hint in unrolled_hints_strs] + ) + if len(unrolled_hints_lengths) != 1: + raise ValueError(f'Output hints have to have equal length. Spec: {spec}') + + for hints in zip(*unrolled_hints_strs): + output_hint_strs.append(_format_hint(hints, algo_name)) + + output_hint_str = DEFAULT_SEPARATOR.join(output_hint_strs) + + return input_hint_str, output_hint_str, bool(hints_found) + + +def _feature_to_str( + name: str, + spec: clrs.Spec, + x: np.ndarray, + with_name: bool, + inputs: Optional[clrs.Features] = None, + edge_masks_as_edge_list: bool = False, +) -> str: + """Converts a numerical CLRS feature into a string.""" + if x.shape[0] != 1: + raise ValueError( + 'Feature first dimension (batch) must be 1 but it has shape' + f' {x.shape}.', + ) + + x = x[0] + unused_stage, location, typ_ = spec[name] + match location: + case clrs.Location.NODE: + output = _convert_node_features_to_str( + x=x, + spec_name=name, + spec=spec, + spec_type=typ_, + inputs=inputs, + ) + case clrs.Location.GRAPH: + output = _convert_graph_features_to_str( + x=x, + spec_name=name, + spec=spec, + spec_type=typ_, + ) + case clrs.Location.EDGE: + output = _convert_edge_features_to_str( + x=x, + spec_name=name, + spec=spec, + spec_type=typ_, + edge_masks_as_edge_list=edge_masks_as_edge_list, + ) + case _: + raise KeyError(f'Hint location not supported in spec {spec[name]}') + + if with_name: + return f'{name}: {output}' + else: + return output + + +def predecessors_to_order(x: np.ndarray) -> np.ndarray: + """From list of predecessors to list of ordered node indices.""" + x = x.astype(int) + y = np.ones(len(x)) + y[x] = 0 + [last] = np.where(y)[0] + order = np.zeros(len(x), dtype=int) + order[-1] = last + for i in range(len(order) - 2, -1, -1): + order[i] = x[order[i+1]] + return order + + +def _convert_node_features_to_str( + x: np.ndarray, + spec_name: str, + spec: clrs.Spec, + spec_type: str, + inputs: Optional[clrs.Features] = None, +) -> str: + """Converts node features into string.""" + match spec_type: + case clrs.Type.SHOULD_BE_PERMUTATION: + # For the text version of CLRS, if the output is a permutation, we present + # the "key" input values in the order given by the permutation. + nonsorted_values = _get_feature_by_name(inputs, 'key').data[0] + permutation_indexes = np.array(predecessors_to_order(x)).astype(int) + sorted_values = np.array( + [nonsorted_values[index] for index in permutation_indexes] + ) + + return _bracket( + SEQUENCE_SEPARATOR.join([f'{scalar:.3g}' for scalar in sorted_values]) + ) + + case clrs.Type.MASK_ONE: + [index] = x.nonzero()[0] + return f'{index}' + + case clrs.Type.SCALAR: + return _bracket(SEQUENCE_SEPARATOR.join([f'{a:.3g}' for a in x])) + + case clrs.Type.MASK | clrs.Type.POINTER | clrs.Type.CATEGORICAL: + if spec_type == clrs.Type.CATEGORICAL: + categories = np.argmax(x, axis=-1) + int_output = categories + else: + int_output = x.astype(int) + return _bracket(SEQUENCE_SEPARATOR.join([f'{a}' for a in int_output])) + + case _: + raise KeyError(f'Feature type not supported in spec {spec[spec_name]}') + + +def _convert_graph_features_to_str( + x: np.ndarray, + spec_name: str, + spec: clrs.Spec, + spec_type: str, +) -> str: + """Converts graph features into string.""" + match spec_type: + case clrs.Type.SCALAR: + return f'{x:.3f}' + + case clrs.Type.CATEGORICAL: + categories = np.argmax(x, axis=-1) + return f'{categories}' + + case _: + if spec_type in [clrs.Type.MASK, clrs.Type.MASK_ONE, clrs.Type.POINTER]: + return f'{x.astype(int)}' + else: + raise KeyError(f'Feature type not supported in spec {spec[spec_name]}') + + +def _convert_edge_features_to_str( + x: np.ndarray, + spec_name: str, + spec: clrs.Spec, + spec_type: str, + edge_masks_as_edge_list: bool, +): + """Converts edge features into string.""" + + if edge_masks_as_edge_list: + if spec_type == clrs.Type.MASK or ( + spec_type == clrs.Type.SCALAR and _is_binary(x) + ): + edges = list(zip(*np.nonzero(x > 0))) + return DEFAULT_SEPARATOR.join([f'({x},{y})' for x, y in edges]) + else: + match spec_type: + case clrs.Type.POINTER | clrs.Type.MASK | clrs.Type.CATEGORICAL: + if spec_type == clrs.Type.CATEGORICAL: + # lcs_length includes masked elements where the category is -1 + mask = np.any(x == clrs.OutputClass.MASKED, axis=-1) + categories = np.argmax(x, axis=-1) + categories[mask] = -1 + int_output = categories + else: + int_output = x.astype(int) + row_to_str = lambda r: _bracket(' '.join([f'{a}' for a in r])) + return _bracket( + DEFAULT_SEPARATOR.join( + [row_to_str(r) for r in int_output], + ), + ) + + case clrs.Type.SCALAR: + row_to_str = lambda r: _bracket(' '.join([f'{a:.3g}' for a in r])) + return _bracket(DEFAULT_SEPARATOR.join([row_to_str(r) for r in x])) + + raise KeyError(f'Feature type not supported in spec {spec[spec_name]}') + + +def _get_feature_by_name(examples: clrs.Features, spec_name: str) -> Any: + filtered_inputs = [ + example for example in examples if example.name == spec_name + ] + + if len(filtered_inputs) > 1: + raise ValueError("More than one example has name '{}'".format(spec_name)) + + return filtered_inputs[0] + + +def _is_binary(x: np.ndarray) -> bool: + precision = 10000 + elements = set(np.unique(np.round(x * precision).astype(int) / precision)) + return elements.issubset({-1, 0, 1}) + + +def _bracket(s: str) -> str: + return f'[{s}]' + + +def _do_not_include_input_in_text(spec_name: str, spec: clrs.Spec) -> bool: + if spec_name == 'pos': + return True + if spec_name == 'adj' and 'A' in spec: + return True # in all cases, 'adj' is redundant with A + + return False diff --git a/clrs/_src/clrs_text/clrs_utils_test.py b/clrs/_src/clrs_text/clrs_utils_test.py new file mode 100644 index 00000000..aea7d2fb --- /dev/null +++ b/clrs/_src/clrs_text/clrs_utils_test.py @@ -0,0 +1,86 @@ +# Copyright 2024 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for clrs.src_.clrs_text.clrs_utils.""" + + +from absl.testing import absltest +from absl.testing import parameterized +import clrs +from clrs._src import probing +from clrs._src.clrs_text import clrs_utils +import numpy as np + + +class TestFormatCLRSExamples(parameterized.TestCase): + + @parameterized.product( + algo_name=list(clrs.CLRS_30_ALGS_SETTINGS.keys()), + use_hints=[True, False], + ) + def test_format(self, algo_name, use_hints): + """Test that we can format samples from any algo into strings.""" + sampler, _ = clrs.build_sampler( + algo_name, + seed=0, + num_samples=-1, + length=16, + track_max_steps=False, + use_padding=False, + ) + + for _ in range(100): + sample = sampler.next(batch_size=1) + + question, answer = clrs_utils.format_clrs_example( + algo_name, + sample, + use_hints=use_hints, + ) + + self.assertTrue(question.startswith(f'{algo_name}:\n')) + self.assertTrue(question.endswith(':\n')) + self.assertTrue(answer.endswith('\n\n')) + + if use_hints and algo_name in clrs_utils.CLRS_TASKS_WITH_HINTS: + self.assertIn('trace | ', question) + self.assertIn('initial_trace:', question) + else: + self.assertNotIn('trace | ', question) + self.assertNotIn('initial_trace:', question) + + +class TestPredecessorToOrder(parameterized.TestCase): + def test_predecessor_to_order(self): + """Test that `predecessor_to_order` matches the slower clrs conversion.""" + for i in range(20): + length = np.random.randint(4, 16) + sampler, unused_spec = clrs.build_sampler( + 'insertion_sort', + seed=i, + num_samples=-1, + length=length, + track_max_steps=False, + ) + x = sampler.next(batch_size=1) + pred = x.outputs[0].data[0] + expected_order = probing.predecessor_pointers_to_permutation_matrix( + pred + ) @ np.arange(pred.shape[0]) + order = clrs_utils.predecessors_to_order(pred) + np.testing.assert_array_equal(expected_order, order) + + +if __name__ == '__main__': + absltest.main()