Skip to content

Commit

Permalink
refactor: tool calling
Browse files Browse the repository at this point in the history
Signed-off-by: thxCode <[email protected]>
  • Loading branch information
thxCode committed Dec 20, 2024
1 parent 9f2d2d8 commit 89783b4
Show file tree
Hide file tree
Showing 7 changed files with 384 additions and 200 deletions.
160 changes: 98 additions & 62 deletions llama-box/patches/llama.cpp/tool_calling.patch

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions llama-box/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1867,7 +1867,7 @@ struct server_context {
slot.n_sent_text += result.text_to_send.size();
} catch (const std::exception &e) {
SLT_ERR(slot, "failed to parse tool call: %s, fallback\n", e.what());
send_text = true;
send_text = llama_token_is_eog(llm_model, result.toks[result.toks.size()-1]);
}
}
}
Expand Down Expand Up @@ -4469,7 +4469,7 @@ int main(int argc, char **argv) {
return;
}
if (oaicompat) {
request = oaicompat_completions_request(ctx_server.llm_params, rid, request, ctx_server.llm_model, false);
request = oaicompat_completions_request(ctx_server.llm_params, rid, request, ctx_server.llm_model, false, false);
}

// construct task
Expand Down Expand Up @@ -4597,7 +4597,7 @@ int main(int argc, char **argv) {
res_error(res, format_error_response("\"messages\" must be provided and must be an array", ERROR_TYPE_INVALID_REQUEST));
return;
}
request = oaicompat_completions_request(ctx_server.llm_params, rid, request, ctx_server.llm_model, true);
request = oaicompat_completions_request(ctx_server.llm_params, rid, request, ctx_server.llm_model, true, ctx_server.support_tool_calls);

// construct task
std::vector<server_task> tasks = ctx_server.create_tasks_inference(rid, request, SERVER_TASK_TYPE_COMPLETION, tps);
Expand Down
225 changes: 134 additions & 91 deletions llama-box/tools/chat.sh
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,25 @@
# SPDX-License-Identifier: MIT
#

ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd -P)"

LOG_FILE=${LOG_FILE:-/dev/null}

API_URL="${API_URL:-http://127.0.0.1:8080}"

MESSAGES=(
"{\"role\":\"system\",\"content\":\"You are a helpful assistant.\"}"
"{\"role\":\"system\",\"content\":\"You are a helpful assistant with tool calling capabilities. Today is $(date +"%Y-%m-%d").\"}"
)

# Registry of the tools offered to the model: TOOLNAMES holds the bare names,
# TOOLS the matching JSON definitions. Both arrays are filled by the
# chat_tool_*.sh plugins sourced below.
TOOLNAMES=()
TOOLS=()
# Source every tool plugin that sits next to this script. The glob only matches
# names that END in ".sh", so leftovers such as "chat_tool_x.sh.bak" or
# "chat_tool_x.sh~" are no longer picked up.
for file in "${ROOT_DIR}"/chat_tool_*.sh; do
    # -f also rejects the literal, unexpanded pattern when no plugin exists.
    if [[ -f "${file}" ]]; then
        # shellcheck disable=SC1090
        source "${file}"
    fi
done

trim() {
shopt -s extglob
set -- "${1##+([[:space:]])}"
Expand Down Expand Up @@ -53,101 +64,132 @@ chat_completion() {
else
DATA="{\"messages\":[$(format_messages){\"role\":\"user\",\"content\":\"${PROMPT}\"}]}"
fi
DATA="$(echo -n "${DATA}" | jq -cr \
--argjson frequency_penalty "${FREQUENCY_PENALTY}" \
--argjson logprobs "${LOGPROBS}" \
--argjson top_logprobs "${TOP_LOGPROBS}" \
--argjson max_tokens "${MAX_TOKENS}" \
--argjson presence_penalty "${PRESENCE_PENALTY}" \
--argjson response_format "{\"type\":\"${RESPONSE_FORMAT}\"}" \
--argjson seed "${SEED}" \
--argjson stop "${STOP}" \
--argjson temp "${TEMP}" \
--argjson top_p "${TOP_P}" \
'{
frequency_penalty: $frequency_penalty,
logprobs: $logprobs,
top_logprobs: $top_logprobs,
max_tokens: $max_tokens,
n: 1,
presence_penalty: $presence_penalty,
response_format: $response_format,
seed: $seed,
stop: $stop,
stream: true,
stream_options: {include_usage: true},
temperature: $temp,
top_p: $top_p
} * .')"
echo "Q: ${DATA}" >>"${LOG_FILE}"
echo "${DATA}" > /tmp/request.json

ANSWER=''
PRE_CONTENT=''
START_TIME=$(date +%s)

while IFS= read -r LINE; do
echo "A: ${LINE}" >>"${LOG_FILE}"
if [[ ! "${LINE}" = data:* ]]; then
if [[ "${LINE}" =~ error:.* ]]; then
LINE="${LINE:7}"
echo "Error: ${LINE}"
while true; do
DATA="$(echo -n "${DATA}" | jq -cr \
--argjson frequency_penalty "${FREQUENCY_PENALTY}" \
--argjson logprobs "${LOGPROBS}" \
--argjson top_logprobs "${TOP_LOGPROBS}" \
--argjson max_tokens "${MAX_TOKENS}" \
--argjson presence_penalty "${PRESENCE_PENALTY}" \
--argjson response_format "{\"type\":\"${RESPONSE_FORMAT}\"}" \
--argjson seed "${SEED}" \
--argjson stop "${STOP}" \
--argjson temp "${TEMP}" \
--argjson top_p "${TOP_P}" \
--argjson tools "$(printf '%s\n' "${TOOLS[@]}" | jq -cs .)" \
'{
frequency_penalty: $frequency_penalty,
logprobs: $logprobs,
top_logprobs: $top_logprobs,
max_tokens: $max_tokens,
n: 1,
presence_penalty: $presence_penalty,
response_format: $response_format,
seed: $seed,
stop: $stop,
stream: true,
stream_options: {include_usage: true},
temperature: $temp,
top_p: $top_p,
tools: $tools,
parallel_tool_calls: false
} * .')"
echo "Q: ${DATA}" >>"${LOG_FILE}"
echo "${DATA}" >/tmp/request.json

TOOL_CALLS=''
TOOL_RESULTS=()
CONTENT=''
PRE_CONTENT=''
START_TIME=$(date +%s)

while IFS= read -r LINE; do
echo "A: ${LINE}" >>"${LOG_FILE}"
if [[ ! "${LINE}" = data:* ]]; then
if [[ "${LINE}" =~ error:.* ]]; then
LINE="${LINE:7}"
echo "Error: ${LINE}"
fi
continue
fi
continue
fi
if [[ "${LINE}" =~ data:\ \[DONE\].* ]]; then
break
fi
LINE="${LINE:5}"
CONTENT="$(echo "${LINE}" | jq -cr '.choices[0].delta.content')"
if [[ "${CONTENT}" == "null" ]]; then
CONTENT=""
if [[ "${LINE}" =~ data:\ \[DONE\].* ]]; then
break
fi
LINE="${LINE:5}"
TOOL_CALLS="$(echo "${LINE}" | jq -cr '.choices[0].delta.tool_calls')"
if [[ "${TOOL_CALLS}" != "null" ]]; then
while IFS= read -r TOOL_CALL; do
ID="$(echo "${TOOL_CALL}" | jq -cr '.id')"
FUNC_NAME="$(echo "${TOOL_CALL}" | jq -cr '.function.name')"
FUNC_ARGS="$(echo "${TOOL_CALL}" | jq -cr '.function.arguments')"
printf "Call: %s %s %s\n" "${FUNC_NAME}" "${FUNC_ARGS}" "${ID}"
RESULT=$("${FUNC_NAME}" "${FUNC_ARGS}" "${ID}")
printf "Result: %s\n\n" "${RESULT}"
TOOL_RESULTS+=("${RESULT}")
done < <(jq -cr '.[]' <<<"${TOOL_CALLS}")
else
TOOL_CALLS=''
fi
CONTENT_SEG="$(echo "${LINE}" | jq -cr '.choices[0].delta.content')"
if [[ "${CONTENT_SEG}" != "null" ]]; then
if [[ "${PRE_CONTENT: -1}" == "\\" ]] && [[ "${CONTENT_SEG}" =~ ^b|n|r|t|\\|\'|\"$ ]]; then
printf "\b "
case "${CONTENT_SEG}" in
b) printf "\b\b" ;;
n) printf "\b\n" ;;
r) printf "\b\r" ;;
t) printf "\b\t" ;;
\\) printf "\b\\" ;;
\') printf "\b'" ;;
\") printf "\b\"" ;;
esac
CONTENT_SEG=""
else
PRE_CONTENT="${CONTENT_SEG}"
printf "%s" "${CONTENT_SEG}"
fi
CONTENT+="${CONTENT_SEG}"
fi
USAGE="$(echo "${LINE}" | jq -cr '.usage')"
if [[ "${USAGE}" != "null" ]]; then
printf "\n------------------------"
printf "\n- TTFT : %10.2fms -" "$(echo "${USAGE}" | jq -cr '.time_to_first_token_ms')"
printf "\n- TBT : %10.2fms -" "$(echo "${USAGE}" | jq -cr '.time_per_output_token_ms')"
printf "\n- TPS : %10.2f -" "$(echo "${USAGE}" | jq -cr '.tokens_per_second')"
DRAFTED_N="$(echo "${USAGE}" | jq -cr '.draft_tokens')"
if [[ "${DRAFTED_N}" != "null" ]]; then
printf "\n- DT : %10d -" "${DRAFTED_N}"
printf "\n- DTA : %10.2f%% -" "$(echo "${USAGE}" | jq -cr '.draft_tokens_acceptance*100')"
fi
ELAPSED=$(($(date +%s) - START_TIME))
printf "\n- TC : %10.2fs -" "${ELAPSED}"
printf "\n------------------------\n"
break
fi
done < <(curl \
--silent \
--no-buffer \
--request POST \
--url "${API_URL}/v1/chat/completions" \
--header "Content-Type: application/json" \
--data @/tmp/request.json)

printf "\n"

MESSAGES+=("{\"role\":\"user\",\"content\":\"$PROMPT\"}")
if [[ -n "${TOOL_CALLS}" ]]; then
MESSAGES+=("{\"role\":\"assistant\",\"tool_calls\":$TOOL_CALLS}")
fi
if [[ "${PRE_CONTENT: -1}" == "\\" ]] && [[ "${CONTENT}" =~ ^b|n|r|t|\\|\'|\"$ ]]; then
printf "\b "
case "${CONTENT}" in
b) printf "\b\b" ;;
n) printf "\b\n" ;;
r) printf "\b\r" ;;
t) printf "\b\t" ;;
\\) printf "\b\\" ;;
\') printf "\b'" ;;
\") printf "\b\"" ;;
esac
CONTENT=""
if [[ -n "${CONTENT}" ]]; then
MESSAGES+=("{\"role\":\"assistant\",\"content\":\"$CONTENT\"}")
fi
PRE_CONTENT="${CONTENT}"
printf "%s" "${CONTENT}"
ANSWER+="${CONTENT}"
USAGE="$(echo "${LINE}" | jq -cr '.usage')"
if [[ "${USAGE}" != "null" ]]; then
printf "\n------------------------"
printf "\n- TTFT : %10.2fms -" "$(echo "${USAGE}" | jq -cr '.time_to_first_token_ms')"
printf "\n- TBT : %10.2fms -" "$(echo "${USAGE}" | jq -cr '.time_per_output_token_ms')"
printf "\n- TPS : %10.2f -" "$(echo "${USAGE}" | jq -cr '.tokens_per_second')"
DRAFTED_N="$(echo "${USAGE}" | jq -cr '.draft_tokens')"
if [[ "${DRAFTED_N}" != "null" ]]; then
printf "\n- DT : %10d -" "${DRAFTED_N}"
printf "\n- DTA : %10.2f%% -" "$(echo "${USAGE}" | jq -cr '.draft_tokens_acceptance*100')"
fi
ELAPSED=$(($(date +%s) - START_TIME))
printf "\n- TC : %10.2fs -" "${ELAPSED}"
printf "\n------------------------\n"
if [[ "${#TOOL_RESULTS[@]}" -gt 0 ]]; then
MESSAGES+=("${TOOL_RESULTS[@]}")
DATA="{\"messages\":$(printf '%s\n' "${MESSAGES[@]}" | jq -cs .)}"
else
break
fi
done < <(curl \
--silent \
--no-buffer \
--request POST \
--url "${API_URL}/v1/chat/completions" \
--header "Content-Type: application/json" \
--data @/tmp/request.json)

printf "\n"

MESSAGES+=(
"{\"role\":\"user\",\"content\":\"$PROMPT\"}"
"{\"role\":\"assistant\",\"content\":\"$ANSWER\"}")
done
}

echo "====================================================="
Expand All @@ -163,6 +205,7 @@ echo "SEED : ${SEED}"
echo "STOP : ${STOP}"
echo "TEMP : ${TEMP}"
echo "TOP_P : ${TOP_P}"
echo "TOOLS : $(printf '%s\n' "${TOOLNAMES[@]}" | jq -R . | jq -cs .)"
printf "=====================================================\n\n"

if [[ -f "${LOG_FILE}" ]]; then
Expand Down
24 changes: 24 additions & 0 deletions llama-box/tools/chat_tool_get_location.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#!/bin/bash

#
# MIT license
# Copyright (c) 2024 llama-box authors
# SPDX-License-Identifier: MIT
#

# get_location ARGS ID
#
# Tool handler for "get_location": requests https://wttr.in/?format=%l and
# prints a single JSON "tool" message on stdout.
#
#   $1 - JSON-encoded arguments of the tool call (ignored, the tool takes none).
#   $2 - id of the tool call, echoed back as "tool_call_id".
#
# Output shape:
#   {"role":"tool","content":"{\"location\":\"<response>\"}","tool_call_id":"<id>"}
function get_location() {
    # Locals keep the handler from clobbering variables of the sourcing script.
    local id="${2}"
    local location

    # The URL is quoted so that '?' can never be expanded as a glob character.
    location="$(curl -s "https://wttr.in/?format=%l")"

    # Let jq do the (double) JSON encoding: quotes, backslashes or control
    # characters in the response can no longer produce an invalid message.
    jq -cn --arg location "${location}" --arg id "${id}" \
        '{role: "tool", content: ({location: $location} | tojson), tool_call_id: $id}'
}

# Adds the get_location tool to the shared registry: its name is appended to
# TOOLNAMES and its JSON function definition to TOOLS.
function register_tool_get_location() {
    local tool_name='get_location'
    # Single quotes keep the JSON readable; the resulting value is unchanged.
    local tool_schema='{"type":"function","function":{"name":"get_location","description":"Return the location without any arguments.","parameters":{"type":"object","properties":{},"required":[]}}}'

    TOOLNAMES+=("${tool_name}")
    TOOLS+=("${tool_schema}")
}

register_tool_get_location
25 changes: 25 additions & 0 deletions llama-box/tools/chat_tool_get_temperature.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#!/bin/bash

#
# MIT license
# Copyright (c) 2024 llama-box authors
# SPDX-License-Identifier: MIT
#

# get_temperature ARGS ID
#
# Tool handler for "get_temperature": reads ".location" from the tool-call
# arguments, requests https://wttr.in/<location>?format=%t and prints a single
# JSON "tool" message on stdout. A trailing "°C" is stripped from the response.
#
#   $1 - JSON-encoded arguments of the tool call, e.g. {"location":"Oslo"}.
#   $2 - id of the tool call, echoed back as "tool_call_id".
#
# Output shape:
#   {"role":"tool","content":"{\"temperature\":\"<value>\"}","tool_call_id":"<id>"}
function get_temperature() {
    # Locals keep the handler from clobbering variables of the sourcing script.
    local args="${1}"
    local id="${2}"
    local location temperature

    # Percent-encode the location: names with spaces or non-ASCII characters
    # ("New York", "São Paulo") would otherwise yield a malformed URL.
    location="$(jq -r '.location | @uri' <<<"${args}")"
    temperature="$(curl -s "https://wttr.in/${location}?format=%t")"

    # Let jq do the (double) JSON encoding so that an unexpected response
    # cannot break the message.
    jq -cn --arg temperature "${temperature%%°C}" --arg id "${id}" \
        '{role: "tool", content: ({temperature: $temperature} | tojson), tool_call_id: $id}'
}

# Adds the get_temperature tool to the shared registry: its name is appended
# to TOOLNAMES and its JSON function definition to TOOLS.
function register_tool_get_temperature() {
    local tool_name='get_temperature'
    # Single quotes keep the JSON readable; the resulting value is unchanged.
    local tool_schema='{"type":"function","function":{"name":"get_temperature","description":"Return the degrees Celsius temperature value of the given location.","parameters":{"type":"object","properties":{"location":{"type":"string"}},"required":["location"]}}}'

    TOOLNAMES+=("${tool_name}")
    TOOLS+=("${tool_schema}")
}

register_tool_get_temperature
25 changes: 25 additions & 0 deletions llama-box/tools/chat_tool_get_weather.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#!/bin/bash

#
# MIT license
# Copyright (c) 2024 llama-box authors
# SPDX-License-Identifier: MIT
#

# get_weather ARGS ID
#
# Tool handler for "get_weather": reads ".location" from the tool-call
# arguments, requests https://wttr.in/<location>?format=%C and prints a single
# JSON "tool" message on stdout.
#
#   $1 - JSON-encoded arguments of the tool call, e.g. {"location":"Oslo"}.
#   $2 - id of the tool call, echoed back as "tool_call_id".
#
# Output shape:
#   {"role":"tool","content":"{\"weather\":\"<response>\"}","tool_call_id":"<id>"}
function get_weather() {
    # Locals keep the handler from clobbering variables of the sourcing script.
    local args="${1}"
    local id="${2}"
    local location weather

    # Percent-encode the location: names with spaces or non-ASCII characters
    # ("New York", "São Paulo") would otherwise yield a malformed URL.
    location="$(jq -r '.location | @uri' <<<"${args}")"
    weather="$(curl -s "https://wttr.in/${location}?format=%C")"

    # Let jq do the (double) JSON encoding so that quotes or backslashes in
    # the response cannot break the message.
    jq -cn --arg weather "${weather}" --arg id "${id}" \
        '{role: "tool", content: ({weather: $weather} | tojson), tool_call_id: $id}'
}

# Adds the get_weather tool to the shared registry: its name is appended to
# TOOLNAMES and its JSON function definition to TOOLS.
function register_tool_get_weather() {
    local tool_name='get_weather'
    # Single quotes keep the JSON readable; the resulting value is unchanged.
    local tool_schema='{"type":"function","function":{"name":"get_weather","description":"Return the weather by the given location.","parameters":{"type":"object","properties":{"location":{"type":"string"}},"required":["location"]}}}'

    TOOLNAMES+=("${tool_name}")
    TOOLS+=("${tool_schema}")
}

register_tool_get_weather
Loading

0 comments on commit 89783b4

Please sign in to comment.