server: Experimental new speculative decoding algorithm #14132
I had this kinda backwards: it's the break-even probabilities (ie: the reciprocals of the relative costs) that are getting under-estimated (at the tail), and that's not really what we want (but not sure it will matter that much...). We could add an extra parameter like so:

```python
import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit

# Your data points
#y_data = np.array([ 1, 0.526, 0.352, 0.269, 0.229, 0.226, 0.217, 0.209, 0.137, 0.123, 0.112, 0.103, 0.095, 0.089, 0.083, 0.078, 0.075, 0.071, 0.068, 0.064, 0.061, 0.059, 0.056, 0.054, 0.051, 0.049, 0.047, 0.046, 0.044, 0.043, 0.042, 0.04 ])
#y_data = np.array([ 1, 1.544, 1.045, 0.823, 0.695, 0.604, 0.544, 0.494, 0.472, 0.439, 0.413, 0.391, 0.374, 0.35, 0.342, 0.33, 0.323, 0.316, 0.309, 0.301, 0.295, 0.29, 0.285, 0.281, 0.277, 0.274, 0.27, 0.268, 0.264, 0.261, 0.259, 0.257 ])
y_data = np.array([ 1, 0.652, 0.469, 0.363, 0.312, 0.278, 0.252, 0.235, 0.222, 0.213, 0.205, 0.201, 0.197, 0.196, 0.196, 0.194, 0.191, 0.189, 0.188, 0.186, 0.185, 0.183, 0.181, 0.18, 0.179, 0.179, 0.179, 0.178, 0.178, 0.177, 0.177, 0.176 ])
#y_data = np.array([ 1, 0.638, 0.459, 0.36, 0.309, 0.279, 0.254, 0.236, 0.225, 0.214, 0.208, 0.203, 0.2, 0.198, 0.198, 0.196 ])

x_data = np.arange(len(y_data))

# Find the first value less than 1
n_skipped = 0
for i, val in enumerate(y_data):
    if val < 1:
        n_skipped = i
        break

# Get the base value (first value less than 1)
base_value = y_data[n_skipped]

# Define the modified power decay function with asymptotic value c
def power_decay(x, b, c):
    return (base_value - c) * (x + 1)**(-b) + c

# Adjust the data to start from the first value < 1
x_fit = x_data[n_skipped:] - n_skipped  # Shift x to start at 0
y_fit = y_data[n_skipped:]

try:
    # Fit the function with initial guesses for b and c
    # (c should be between 0 and base_value, b positive)
    popt_power, _ = curve_fit(power_decay, x_fit, y_fit, p0=[0.5, 0.0])
    power = popt_power[0]
    offset = popt_power[1]

    # Calculate fitted values
    y_fitted = power_decay(x_fit, power, offset)

    # Plot results
    plt.figure(figsize=(10, 6))
    plt.scatter(x_data, y_data, label='Original Data')
    plt.scatter(x_fit + n_skipped, y_fit, color='red', label='Data used for fitting')
    plt.plot(x_fit + n_skipped, y_fitted, label=f'Modified Power fit: ({base_value:.2f}-{offset:.2f})*x^(-{power:.2f}) + {offset:.2f}')
    plt.legend()
    plt.xlabel('Drafted Tokens')
    plt.ylabel('Relative Cost')
    plt.title('Modified Power Law Fitting')
    plt.show()

    # Calculate and print RMSE
    def rmse(y_true, y_pred):
        return np.sqrt(np.mean((y_true - y_pred)**2))

    print("RMSE for modified power fit:", rmse(y_fit, y_fitted))

    # Print the actual line and rounded version
    print(f"\nActual fit line: ({base_value} - {offset})*x^(-{power}) + {offset}")
    print(f"Rounded fit line: ({base_value:.2f} - {offset:.2f})*x^(-{power:.2f}) + {offset:.2f}")

    # Print suggested PR parameters for llama-server
    print("\nSuggested PR parameters for llama-server:\n")
    print(f"--draft-min {n_skipped}")
    print(f"--draft-max {len(y_data)}")
    print(f"--draft-p-min 0.{100.0*base_value:.0f}{100.0*power:.0f}{100.0*offset:.0f}{n_skipped} (NOTE: Encoded as 0.{{base}}{{power}}{{offset}}{{min}} for use with this PR only!)")
except Exception as e:
    print("Error during fitting:", e)
```

and the fit is much better:

but it's starting to get much more complex and unintuitive to set by hand... :/

It may not matter that much in practice too:
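To get a feel for how much the extra asymptote parameter actually buys, here's a small self-contained sketch (numpy only, using a coarse grid search as a stand-in for `scipy.optimize.curve_fit`, on the active `y_data` series from the script above) comparing the RMSE of the plain power law `base * (x+1)**(-b)` against the modified form `(base - c) * (x+1)**(-b) + c`:

```python
import numpy as np

# Relative-cost data: the active y_data series from the fitting script above
y = np.array([1, 0.652, 0.469, 0.363, 0.312, 0.278, 0.252, 0.235, 0.222,
              0.213, 0.205, 0.201, 0.197, 0.196, 0.196, 0.194, 0.191, 0.189,
              0.188, 0.186, 0.185, 0.183, 0.181, 0.18, 0.179, 0.179, 0.179,
              0.178, 0.178, 0.177, 0.177, 0.176])
x = np.arange(len(y), dtype=float)
base = y[0]

def rmse(pred):
    return float(np.sqrt(np.mean((y - pred) ** 2)))

# Plain power law base*(x+1)**(-b): grid search over b
bs = np.linspace(0.01, 3.0, 300)
plain = min(rmse(base * (x + 1) ** (-b)) for b in bs)

# Modified power law (base - c)*(x+1)**(-b) + c: grid search over b and c
# (c = 0 is included, so "modified" can never be worse than "plain")
cs = np.linspace(0.0, 0.5, 200)
modified = min(rmse((base - c) * (x + 1) ** (-b) + c)
               for b in bs for c in cs)

print(f"plain RMSE:    {plain:.4f}")
print(f"modified RMSE: {modified:.4f}")
```

The asymptote term mostly pays off in the flat tail of the curve, which the plain power law can only reach by flattening everywhere.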
Just tested on a Qwen-2.5-Coder:32B:

Qwen-3:32B:

This actually suggests to me that it would be much better to pass a vector of probability thresholds (eg: like my first attempt at this) rather than attempt any type of fit... The potential gains for models like this could be huge, as it should be clear that using a fixed

I'll try and see if I can find a way to pass the vector more cleanly than my
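A minimal sketch of the vector-of-thresholds idea (the function name and the exact stopping rule here are my assumptions, not the PR's implementation): keep drafting while each position's draft probability clears that position's break-even threshold.

```python
def tokens_to_draft(draft_probs, thresholds):
    """Number of draft tokens to keep: stop at the first position whose
    draft probability falls below that position's break-even threshold."""
    n = 0
    for p, t in zip(draft_probs, thresholds):
        if p < t:
            break
        n += 1
    return n

# Later tokens in a batch are cheaper, so their break-even thresholds are
# lower (illustrative values only)
thresholds = [0.50, 0.30, 0.20, 0.15, 0.12]
print(tokens_to_draft([0.9, 0.4, 0.25, 0.1, 0.9], thresholds))  # -> 3
```

This avoids forcing the measured cost curve through any particular functional form: jagged or model-specific profiles just become different threshold vectors.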
It also shows that the
I've figured out what's causing the extreme jumps for the CUDA and Metal tests:
So rerunning the tests now to see what comes out with
Surprisingly, increasing the existing context doesn't really change the jaggedness and just moves up the asymptote due to the extra constant overhead. We could easily add another parameter to model the overhead:

```python
def power_decay(x, b, c, base_value):
    return (base_value - c) * (x + 1)**(-b) + c
```

but the jagged line is really a killer of the whole idea, so gonna close this and have a rethink over the weekend to see if there is anything better we can do... It definitely looks like there are some serious gains to be made here, but having to run the
How to use this PR

1. Edit the parameters of the `llama-batched-bench` call in the `BENCHMARK_COMMAND` command here to match how you intend to use the model, eg:

NOTE: You will need the `jq` tool installed for the results processing.

NOTE: Not all `llama-server` parameters are available for use with `llama-batched-bench`.

2. Run this and it will produce a line of python that looks like this:

3. Copy that line into this python program, replacing the line under the `# Your data points` comment:

4. Run this (eg: online here: https://python-fiddle.com/examples/matplotlib) and it will produce some output like this:

and a graph:

5. Then run your `llama-server` using the draft parameters it has generated, eg:

You can also manually set the parameters, eg:

`--draft-p-min` translates to use this power-law formula:

and with the last digit set to always match your `--draft-min 1`.

NOTE: Don't change the `--draft-min` without also changing this last digit or the formula will be completely wrong!

So, in general you can set:

where `base` and `power` are always 2 digits and `min` is always 1 digit.

(sorry it's such a crappy way of doing this, but if this shows more promise I will add the proper command line arg(s) later...)
The discussion that led to the idea of this PR starts here:

#10466 (comment)

and the basic idea is that the marginal cost of adding 1 more token to a batch goes way down as you add more and more tokens, but different models have very different cost profiles.
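Under one plausible reading of the benchmark numbers (I'm assuming the series above gives the per-token cost of a batch of n relative to a batch of 1; the data here is the Qwen series from earlier), the falling marginal cost can be sketched as:

```python
import numpy as np

# Assumption: y[n-1] is the per-token cost of a batch of n tokens,
# relative to a batch of 1 (the Qwen relative-cost series from above).
y = np.array([1.0, 0.652, 0.469, 0.363, 0.312, 0.278, 0.252, 0.235])
n = np.arange(1, len(y) + 1)

total = n * y                            # total relative cost of a batch of n
marginal = np.diff(total, prepend=0.0)   # cost of adding the n-th token

for k, m in zip(n, marginal):
    print(f"token {k}: marginal relative cost {m:.3f}")
```

The first token costs the full 1.0, but by the third or fourth drafted token the marginal cost has already dropped to roughly a tenth of that, which is exactly why deeper drafts can pay off even at modest acceptance probabilities.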
For example, here I repeat the above for `deepseek-v3-0324`:

and we get a completely different set of values:

which give a different set of optimal values:

So basically drafts of less than 3 tokens always have negative expectation here!

(also note in the `qwen-2.5-coder:32b` case the effect the flash attention kernels have for the small batch sizes)

One final thing to note is that the linked discussion shows that a rational approximation fit the data much better, but actually the power-law fit is better at modelling the marginal cost here, as it tends to under-estimate the costs (ie: the gradient of the line is usually steeper than the data shows), and this in turn reduces the need to recalibrate the draft models' output.
It's also:

`--draft-p-min` if it ever makes it into the code.

I'm keen to get some feedback on this, as for my use cases and models it looks to be quite a big improvement, and it also seems to work really well for different levels of "draftability" without the need to reload the model for refactoring tasks, etc.