CoT-ICL Lab: A Synthetic Framework for Studying Chain-of-Thought Learning from In-Context Demonstrations
- Create a virtual environment and install the package.

  ```shell
  $ python3.12 -m venv .venv
  $ source .venv/bin/activate
  (.venv) $ pip install -e .
  ```

- Run unit tests as a sanity check.

  ```shell
  (.venv) $ pytest
  ```

- (Development) Run ruff + isort fixes to sanitize the code changes.

  ```shell
  (.venv) $ ./beautify.sh
  ```
Our framework serves as a test bed for generating synthetic tokenized datasets to train and evaluate transformer models. We do so by using the `DAG` and `TokenProcessor` classes, which can be configured directly via the `Args` dataclass. For example:
```python
from tokenized_cot_icl.core.args import Args
from tokenized_cot_icl.core.data import TokenizedDataset

args = Args(
    vocab_size=1024,
    n_inputs=4,
    n_parents=2,
    chain_length=3,
    n_examples=1,
    enable_cot=True,
    prompt_strategy="cot",
    activation="leaky_relu",
    n_tasks=10,
)
dataset = TokenizedDataset(args=args)
print(dataset[0])
```
The printed dataset item is as follows:
```python
{
    'adj_list': tensor([[0, 2], [4, 3], [5, 3]]),
    'attention_mask': tensor([1, 1, 1, 1, 1, 1, 1]),
    'input_ids': tensor([ 556, 197, 1002, 867, 240, 466, 217]),
    'labels': tensor([-100, -100, -100, -100, 240, 466, 217]),
    'cot_eval': {
        'attention_mask': tensor([1, 1, 1, 1]),
        'input_ids': tensor([ 556, 197, 1002, 867]),
        'last_example_cot': tensor([240, 466, 217]),
    },
}
```
Let's break down the result above to understand the DAG structure. The `'adj_list': tensor([[0, 2], [4, 3], [5, 3]])` entry (based on zero-indexing into the token sequence) indicates that the parent tokens for the chain tokens are as follows:
Chain Token | Parent Tokens |
---|---|
240 (position 4) | 556, 1002 (positions 0, 2) |
466 (position 5) | 240, 867 (positions 4, 3) |
217 (position 6) | 466, 867 (positions 5, 3) |
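As a quick sanity check, here is a minimal sketch (assuming the item layout shown above, i.e. input tokens followed by chain tokens, with one row of parent positions per chain token) that recovers these parent tokens programmatically:

```python
item = dataset[0]
tokens = item["input_ids"]   # input tokens followed by the chain tokens
adj_list = item["adj_list"]  # one row of parent positions per chain token

# Chain tokens start right after the n_inputs input tokens.
for i, parents in enumerate(adj_list):
    chain_pos = args.n_inputs + i
    parent_ids = tokens[parents].tolist()
    print(f"chain token {tokens[chain_pos].item()} <- parent tokens {parent_ids}")
```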
Note: The `TokenCoverage` metric introduced in the paper relies on the uniqueness of chain tokens across the entire dataset and depends heavily on the `vocab_size` and `activation` settings, which in turn control the difficulty of the tasks.
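For intuition only, here is a rough sketch of the idea behind the metric. It is an illustrative approximation, not necessarily the exact definition from the paper, and it assumes the dataset exposes `__len__`/`__getitem__` and the `labels` field shown above:

```python
def approx_token_coverage(dataset, vocab_size: int) -> float:
    """Fraction of the vocabulary that appears as a chain token somewhere
    in the dataset (illustrative approximation, not the paper's definition)."""
    chain_tokens = set()
    for i in range(len(dataset)):
        labels = dataset[i]["labels"]
        chain_tokens.update(labels[labels != -100].tolist())  # unmasked chain tokens
    return len(chain_tokens) / vocab_size

print(approx_token_coverage(dataset, args.vocab_size))
```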
We leverage the HuggingFace `transformers` library to create custom Llama models and expose a `MODEL_REGISTRY` to register new model families.
```python
# src/tokenized_cot_icl/core/models.py
MODEL_REGISTRY = {"llama": create_llama_model}
```
Tip: Users can register the creation function for models of their choice from the `transformers` library to explore new architectures and validate ideas (see the sketch below).
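For example, here is a hypothetical sketch of registering a GPT-2 style family. It assumes the creation function receives the `Args` dataclass, mirroring `create_llama_model` (the actual signature may differ), and the config values are illustrative:

```python
from transformers import GPT2Config, GPT2LMHeadModel

from tokenized_cot_icl.core.args import Args
from tokenized_cot_icl.core.models import MODEL_REGISTRY


def create_gpt2_model(args: Args):
    """Build a small GPT-2 style model sized for the synthetic vocabulary."""
    config = GPT2Config(
        vocab_size=args.vocab_size,  # keep in sync with the tokenized dataset
        n_embd=256,
        n_layer=4,
        n_head=4,
    )
    return GPT2LMHeadModel(config)


MODEL_REGISTRY["gpt2"] = create_gpt2_model
```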
To make it easy to launch experiments in bulk, we rely on a `TASK_CARD` to collate all the args. For instance, to train a model with the args from the example above:
```python
# src/tokenized_cot_icl/core/task_card.py
from typing import Dict

from tokenized_cot_icl.core.args import Args


def custom_task_card() -> Dict[int, Args]:
    """A custom task card."""
    args = Args(...)  # set as needed
    return {0: args}


# set the dictionary
TASK_CARD = custom_task_card()
```
The `TASK_CARD` allows us to index into the experimental config of our choice and launch the torch distributed data parallel (DDP) training runs. For example:
```shell
(.venv) $ cd src
(.venv) $ export NUM_NODES=1         # change as needed
(.venv) $ export LOCAL_WORLD_SIZE=4  # change as needed
(.venv) $ torchrun --nnodes=$NUM_NODES --nproc-per-node=$LOCAL_WORLD_SIZE -m tokenized_cot_icl.core.train --task_card_key 0
```
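To launch a bulk sweep, the task card can simply return multiple configs, one per key, and each run is selected with a different `--task_card_key`. Below is a minimal sketch that reuses the `Args` fields from the example above (the swept values are illustrative):

```python
# src/tokenized_cot_icl/core/task_card.py
from typing import Dict

from tokenized_cot_icl.core.args import Args


def chain_length_sweep() -> Dict[int, Args]:
    """One config per chain length; launch key i with --task_card_key i."""
    return {
        key: Args(
            vocab_size=1024,
            n_inputs=4,
            n_parents=2,
            chain_length=chain_length,
            n_examples=1,
            enable_cot=True,
            prompt_strategy="cot",
            activation="leaky_relu",
            n_tasks=10,
        )
        for key, chain_length in enumerate([1, 2, 3, 4])
    }


TASK_CARD = chain_length_sweep()
```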
- By default, we use `metric_logger="stdout"` in `Args` and log the metrics/params to `STDOUT`.
- We also support logging to an MLFlow tracking server by setting the `MLFLOW_SERVICE_URL` environment variable and using `Args(metric_logger="mlflow")`, as shown in the sketch after this list.
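A minimal sketch of the MLflow setup (the tracking server URL below is a placeholder):

```python
import os

from tokenized_cot_icl.core.args import Args

# Point the metric logger at an MLflow tracking server (placeholder URL).
os.environ["MLFLOW_SERVICE_URL"] = "http://localhost:5000"

args = Args(metric_logger="mlflow")  # defaults to "stdout"
```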
Users can also apply the Liger-Kernel optimizations to patch the llama models by setting `Args(use_liger_kernels=True)` to speed up the training runs.

```shell
(.venv) $ pip install liger-kernel  # install a suitable version
```
In addition to using the `transformers.GenerationConfig` for small-scale inference during the training runs, we also support vLLM and SGLang based evaluation of the trained model (or model checkpoints) to analyze the predictions.

```shell
(.venv) $ pip install vllm    # install a suitable version
(.venv) $ pip install sglang  # install a suitable version
```
We provide an easy-to-extend example for calculating the answer token prediction accuracy as follows:
```shell
# for vllm
(.venv) $ cd src && python tokenized_cot_icl/inference/vllm/evaluator.py \
    --model_base_dir /opt/cot-icl-lab/run_name \
    --checkpoint final  # either final or 1000, 2000 etc.

# for sglang
(.venv) $ cd src && python tokenized_cot_icl/inference/sglang/evaluator.py \
    --model_base_dir /opt/cot-icl-lab/run_name \
    --checkpoint final  # either final or 1000, 2000 etc.
```
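For a quick, small-scale check with `transformers.GenerationConfig` (mentioned above), here is a hypothetical sketch. It assumes the trained checkpoint is saved in a `from_pretrained`-compatible layout under the run directory shown in the commands above; adjust the path to your setup:

```python
import torch
from transformers import AutoModelForCausalLM, GenerationConfig

# Hypothetical checkpoint path; adjust to your run layout.
model = AutoModelForCausalLM.from_pretrained("/opt/cot-icl-lab/run_name/final")
model.eval()

# Prompt with the input tokens of the last example and compare the generated
# tokens against 'last_example_cot' from the dataset item shown earlier.
prompt = dataset[0]["cot_eval"]["input_ids"].unsqueeze(0)
gen_config = GenerationConfig(max_new_tokens=args.chain_length, do_sample=False)
with torch.no_grad():
    output = model.generate(prompt, generation_config=gen_config)
print(output[0, prompt.shape[1]:])  # a perfect model would emit tensor([240, 466, 217])
```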
```bibtex
@misc{kothapalli2025coticllabsyntheticframework,
      title={CoT-ICL Lab: A Synthetic Framework for Studying Chain-of-Thought Learning from In-Context Demonstrations},
      author={Vignesh Kothapalli and Hamed Firooz and Maziar Sanjabi},
      year={2025},
      eprint={2502.15132},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2502.15132},
}
```