Feature request: Inject chat history before LLM calls #116

Open
ShaojieJiang opened this issue Aug 29, 2024 · 2 comments

@ShaojieJiang

Hi TextGrad developers,

First of all, thanks a lot for this great work!

I wonder whether you have any plans to allow specifying chat history before LLM calls. Below is some context explaining why this is important to me.

My task

Optimise the system prompt so that the chatbot behaves in a more controlled way as an interviewer. The objective is therefore not only to get a single good answer, but also to get good responses after several turns of conversation.

Why can't I put the chat history in the prompt?

Most LLMs would understand the chat history well if I simply put it in the prompt. However, to rule out any mismatch with how the model is used at inference time, it's better to pass the history in each LLM's own native format, e.g. a list of message dicts for OpenAI models.
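
For illustration, here is the same two-turn history flattened into a prompt string versus expressed in OpenAI's native chat format (the contents are just placeholders):

# History flattened into the prompt string (works, but may not match how the
# model is prompted at inference time):
flattened_prompt = (
    "Previous conversation:\n"
    "User: Hi, how are you?\n"
    "Assistant: I'm fine!\n\n"
    "New question: ..."
)

# The same history in OpenAI's native format, a list of role/content dicts:
native_history = [
    {"role": "user", "content": "Hi, how are you?"},
    {"role": "assistant", "content": "I'm fine!"},
]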

MWE of my solution

Below is my solution. Although it works, it looks a bit hacky, so I'm calling it "history injection". Looking forward to your comments on a more proper implementation.

import json
import os
from typing import List, Union
import textgrad as tg
from textgrad import Variable
from textgrad.engine.openai import ChatOpenAI

os.environ["OPENAI_API_KEY"] = ""


class ChatOpenAIWithHistory(ChatOpenAI):
    def __init__(self, *args, **kwargs):
        # Messages injected for the next call only; consumed and cleared per call.
        self.history_messages = []
        super().__init__(*args, **kwargs)

    def inject_history(self, messages: list[dict]) -> None:
        self.history_messages = messages

    def _generate_from_single_prompt(
        self,
        prompt: str,
        system_prompt: str = None,
        temperature=0,
        max_tokens=2000,
        top_p=0.99,
    ):
        sys_prompt_arg = system_prompt if system_prompt else self.system_prompt

        # Include the injected history in the cache key so that a cached
        # response without history is not reused for a call with history.
        cache_key = sys_prompt_arg + json.dumps(self.history_messages) + prompt
        cache_or_none = self._check_cache(cache_key)
        if cache_or_none is not None:
            return cache_or_none

        messages = [
            {"role": "system", "content": sys_prompt_arg},
            *self.history_messages,
            {"role": "user", "content": prompt},
        ]
        # The history is consumed once: clear it so later calls start fresh.
        self.history_messages.clear()
        response = self.client.chat.completions.create(
            model=self.model_string,
            messages=messages,
            frequency_penalty=0,
            presence_penalty=0,
            stop=None,
            temperature=temperature,
            max_tokens=max_tokens,
            top_p=top_p,
        )

        response = response.choices[0].message.content
        self._save_cache(cache_key, response)
        return response

    def _generate_from_multiple_input(
        self,
        content: List[Union[str, bytes]],
        system_prompt=None,
        temperature=0,
        max_tokens=2000,
        top_p=0.99,
    ):
        sys_prompt_arg = system_prompt if system_prompt else self.system_prompt
        formatted_content = self._format_content(content)

        # As above, include the injected history in the cache key.
        cache_key = (
            sys_prompt_arg
            + json.dumps(self.history_messages)
            + json.dumps(formatted_content)
        )
        cache_or_none = self._check_cache(cache_key)
        if cache_or_none is not None:
            return cache_or_none

        messages = [
            {"role": "system", "content": sys_prompt_arg},
            *self.history_messages,
            {"role": "user", "content": formatted_content},
        ]
        self.history_messages.clear()
        response = self.client.chat.completions.create(
            model=self.model_string,
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
            top_p=top_p,
        )

        response_text = response.choices[0].message.content
        self._save_cache(cache_key, response_text)
        return response_text


class BlackboxLLMWithHistory(tg.BlackboxLLM):
    def forward(self, x: Variable, history: list[dict] | None = None) -> Variable:
        # Avoid a mutable default argument; inject only when a history is given
        # and the engine supports it.
        if history and hasattr(self.engine, "inject_history"):
            self.engine.inject_history(history)

        return self.llm_call(x)


tg.set_backward_engine("gpt-4o", override=True)

# Step 1: Get an initial response from an LLM.
model = BlackboxLLMWithHistory(ChatOpenAIWithHistory("gpt-4o"))
question_string = (
    "If it takes 1 hour to dry 25 shirts under the sun, "
    "how long will it take to dry 30 shirts under the sun? "
    "Reason step by step"
)


question = tg.Variable(
    question_string, role_description="question to the LLM", requires_grad=False
)

history = [
    {
        "role": "user",
        "content": "Hi, how are you?",
    },
    {
        "role": "assistant",
        "content": "I'm fine!",
    },
]
answer = model(question, history=history)
print(answer)


answer.set_role_description("concise and accurate answer to the question")

# Step 2: Define the loss function and the optimizer, just like in PyTorch!
# Here, we don't have SGD, but we have TGD (Textual Gradient Descent)
# that works with "textual gradients".
optimizer = tg.TGD(parameters=[answer])
evaluation_instruction = (
    f"Here's a question: {question_string}. "
    "Evaluate any given answer to this question, "
    "be smart, logical, and very critical. "
    "Just provide concise feedback."
)


# TextLoss is a natural-language specified loss function that describes
# how we want to evaluate the reasoning.
loss_fn = tg.TextLoss(evaluation_instruction)

# Step 3: Do the loss computation, backward pass, and update the answer.
# Exact same syntax as PyTorch!
loss = loss_fn(answer)
loss.backward()
optimizer.step()
answer1 = model(question, history)
print(answer1)

Best regards,
Shaojie Jiang

@mertyg
Member

mertyg commented Aug 29, 2024

Thank you Shaojie! This is indeed useful to have and a good start; we'd love to merge this once it's in shape!

Something I'm not too sure about: do you really need the history injection? E.g., in your BlackboxLLMWithHistory class, couldn't you directly pass a list of messages (including the variable at hand and the history) in the forward pass to _generate_from_multiple_input, without using this new inject_history function?

This is effectively what you are doing under the hood, but I think you could just extend the forward pass of LLMCall to pass not only the variable but also the history as a list to _generate_from_multiple_input, roughly like the sketch below.
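
A minimal, untested sketch of that idea, assuming we add a hypothetical history keyword to the engine method (it reuses only the attributes already shown in your MWE and drops caching for brevity):

# Untested sketch; the `history` keyword is hypothetical and not part of the
# current engine API.
from textgrad.engine.openai import ChatOpenAI

class ChatOpenAIExplicitHistory(ChatOpenAI):
    def _generate_from_multiple_input(
        self,
        content,
        system_prompt=None,
        history=None,  # hypothetical extra argument carrying the chat history
        temperature=0,
        max_tokens=2000,
        top_p=0.99,
    ):
        sys_prompt_arg = system_prompt if system_prompt else self.system_prompt
        formatted_content = self._format_content(content)
        messages = [
            {"role": "system", "content": sys_prompt_arg},
            *(history or []),
            {"role": "user", "content": formatted_content},
        ]
        response = self.client.chat.completions.create(
            model=self.model_string,
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
            top_p=top_p,
        )
        return response.choices[0].message.content

The remaining work would then be extending the LLMCall forward pass so that it forwards the history list down to this method.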

@ShaojieJiang
Author

Hi @mertyg, thanks for your swift response!

Good point! My short answer is that history injection is easier to implement.

Long answer

History through tg.Variable

I've thought about a more proper implementation: passing the history as the content of a tg.Variable. However, Variable only accepts str (for text) and bytes (for other modalities?). It might work if I simply pass a list as the content, but that's too hacky, and a proper implementation would require subclassing Variable, which in turn requires a better understanding of Variable, especially of how gradients are handled. And that's not the end of the story. If we go down that road, consider this not-so-edge case:

  1. I may want to apply gradients to the question/last message to understand what could be said better.
  2. I rarely want to apply gradients to the history, because part of the task objective is to see how the chatbot reacts in various situations, in both easy and difficult conversations.

Taking these factors into account, history injection makes it much easier to fulfill these needs.

Modify LLMCall

If I pass a list of Variables as the input, then I need to modify LLMCall, but again, this is more complex to implement and has broader implications.
