Token rotation #987
Changes from all commits
70eae6a
b350a8a
2296950
d0fb2f8
2ef481f
9586a82
65d72c9
f28eb35
a2e8f2d
207d069
619735c
510c29e
9327d71
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -267,4 +267,99 @@ namespace ov::genai { | |
m_scores[decoder_layer_idx] = new_scores; | ||
m_cache_counter[decoder_layer_idx] = new_counter; | ||
} | ||
|
||
CacheRotationCalculator::CacheRotationCalculator(size_t block_size, | ||
size_t max_context_length_in_blocks, | ||
size_t kv_head_size, | ||
double rope_theta) | ||
: m_block_size(block_size), | ||
m_head_size(kv_head_size) { | ||
// Frequencies follow the original recipe from RoFormer: | ||
// https://arxiv.org/pdf/2104.09864v5 | ||
// | ||
// However, the way the rotation coefficients are ultimately applied in Llama and related models from | ||
// huggingface is very different from the RoFormer - the embedding-dimension coefficients are not treated as | ||
// consecutive x-y coordinate pairs, but are rather divided into contiguous x-like and y-like halves - see | ||
// `rotate_half` function in HF transformers. It can be shown that this form still preserves the relative | ||
// positioning property from the RoFormer article. | ||
Comment on lines +283 to +284
I don't understand how I should interpret this part. Does it mean that the LUT computed here suits both interleaved and non-interleaved RoPE? Or does it suit only one of them, but the difference can be ignored?
The LUT here suits only the non-interleaved case. |
||
OPENVINO_ASSERT(rope_theta > 0, "rope_theta must be positive"); | ||
size_t num_freqs = kv_head_size / 2; | ||
m_rope_sin_lut.resize(max_context_length_in_blocks); | ||
m_rope_cos_lut.resize(max_context_length_in_blocks); | ||
|
||
for (size_t i = 0; i < max_context_length_in_blocks; i++) { | ||
m_rope_sin_lut[i].reserve(num_freqs); | ||
m_rope_cos_lut[i].reserve(num_freqs); | ||
for (size_t j = 0; j < num_freqs; j++) { | ||
double exponent = -static_cast<double>(2 * j) / kv_head_size; | ||
double base_angle = std::pow(rope_theta, exponent); | ||
m_rope_sin_lut[i].push_back( | ||
-std::sin(i * block_size * base_angle)); // minus since we will be rotating by an inverse angle | ||
m_rope_cos_lut[i].push_back(std::cos(i * block_size * base_angle)); | ||
} | ||
} | ||
} | ||
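To make the non-interleaved layout discussed in the review comment concrete, here is a minimal standalone sketch. It is not the PR's code: `RopeLut`, `build_lut`, and `rotate_non_interleaved` are hypothetical names introduced for illustration. It builds the per-block tables with the same formula as the constructor above and applies them in the `rotate_half` style, where element `j` is paired with element `j + head_size / 2`.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical container for illustration; the PR stores these as class members.
struct RopeLut {
    std::vector<std::vector<float>> sin;  // already negated, i.e. inverse rotation
    std::vector<std::vector<float>> cos;
};

// Builds one row of head_size / 2 coefficients per block offset,
// using the same frequency formula as the constructor above.
RopeLut build_lut(std::size_t block_size, std::size_t max_blocks, std::size_t head_size, double rope_theta) {
    assert(rope_theta > 0);
    const std::size_t num_freqs = head_size / 2;
    RopeLut lut;
    lut.sin.resize(max_blocks);
    lut.cos.resize(max_blocks);
    for (std::size_t i = 0; i < max_blocks; i++) {
        for (std::size_t j = 0; j < num_freqs; j++) {
            const double freq = std::pow(rope_theta, -static_cast<double>(2 * j) / head_size);
            const double angle = static_cast<double>(i * block_size) * freq;
            lut.sin[i].push_back(static_cast<float>(-std::sin(angle)));
            lut.cos[i].push_back(static_cast<float>(std::cos(angle)));
        }
    }
    return lut;
}

// Non-interleaved ("rotate_half") application: x is split into two contiguous
// halves, and element j of the first half rotates together with element j of the second.
std::vector<float> rotate_non_interleaved(const std::vector<float>& x,
                                          const std::vector<float>& cos,
                                          const std::vector<float>& sin) {
    const std::size_t half = x.size() / 2;
    std::vector<float> out(x.size());
    for (std::size_t j = 0; j < half; j++) {
        out[j] = x[j] * cos[j] - x[j + half] * sin[j];
        out[j + half] = x[j + half] * cos[j] + x[j] * sin[j];
    }
    return out;
}
```

Under these assumptions, rotating a vector forward by `k * block_size` positions and then applying row `k` of the table should restore the original vector, which is the property the negated sine is meant to provide. An interleaved RoPE variant would need a different pairing of elements, so the same rows would not apply to it directly.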
|
||
const std::vector<std::vector<float>>& CacheRotationCalculator::get_sin_lut() const { | ||
return m_rope_sin_lut; | ||
} | ||
|
||
const std::vector<std::vector<float>>& CacheRotationCalculator::get_cos_lut() const { | ||
return m_rope_cos_lut; | ||
} | ||
|
||
std::vector<CacheRotationCalculator::BlockRotationData> CacheRotationCalculator::get_rotation_data( | ||
const std::set<size_t>& evicted_block_logical_indices, | ||
size_t num_logical_blocks_before_eviction, | ||
bool deltas_only) { | ||
|
||
|
||
std::vector<BlockRotationData> retval; | ||
if (evicted_block_logical_indices.empty()) { | ||
return retval; | ||
} | ||
|
||
for (auto idx : evicted_block_logical_indices) { | ||
OPENVINO_ASSERT(idx < num_logical_blocks_before_eviction); | ||
} | ||
|
||
// num_logical_blocks_before_eviction >= evicted_block_logical_indices.size() is guaranteed because the set | |
// holds distinct indices and the assertion above bounds each of them by num_logical_blocks_before_eviction | |
retval.reserve(num_logical_blocks_before_eviction - evicted_block_logical_indices.size()); | ||
|
||
ptrdiff_t current_rotation_delta_in_blocks = 0; | ||
std::vector<size_t> logical_block_space(num_logical_blocks_before_eviction); | ||
std::iota(logical_block_space.begin(), logical_block_space.end(), 0); | ||
|
||
for (size_t logical_block_idx : logical_block_space) { | ||
if (evicted_block_logical_indices.find(logical_block_idx) != evicted_block_logical_indices.end()) { | ||
current_rotation_delta_in_blocks += 1; | ||
} else { | ||
if (current_rotation_delta_in_blocks != 0) { | ||
BlockRotationData block_rotation_data; | ||
block_rotation_data.logical_block_idx = logical_block_idx - current_rotation_delta_in_blocks; | ||
|
||
// rotation delta is in tokens, but LUT is in blocks right now since we evict per-block | ||
// delta recomputation to a valid LUT index is done at a later stage | ||
block_rotation_data.rotation_delta = current_rotation_delta_in_blocks * m_block_size; | ||
So the number of different
Implemented the idea, adjusting for reality (
Had to revert this back to maximum sequence length because in the latest iteration, in an effort to align with the Python POC, I only evict once the prompt is pre-filled, which means that the cache occupancy of a single sequence is still bound only by the max sequence length. |
||
OPENVINO_ASSERT(block_rotation_data.rotation_delta / m_block_size < m_rope_cos_lut.size(), "rotation delta does not fit into the LUT"); | |
|
||
if (!deltas_only) { | ||
block_rotation_data.cosines.reserve(m_block_size); | ||
block_rotation_data.sines.reserve(m_block_size); | ||
for (size_t i = 0; i < m_block_size; i++) { | ||
block_rotation_data.cosines.push_back( | ||
m_rope_cos_lut[current_rotation_delta_in_blocks]); | ||
block_rotation_data.sines.push_back( | ||
m_rope_sin_lut[current_rotation_delta_in_blocks]); | ||
} | ||
} | ||
|
||
retval.push_back(block_rotation_data); | ||
} | ||
} | ||
} | ||
|
||
return retval; | ||
} | ||
} |
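The delta bookkeeping in `get_rotation_data` can be sketched in isolation as follows. This is a simplified illustration covering only the deltas-only case; `BlockDelta` and `rotation_deltas` are hypothetical names that do not appear in the PR. Each surviving block that follows at least one evicted block moves down by the number of evictions before it, and its rotation delta in tokens is that count times the block size.

```cpp
#include <cassert>
#include <cstddef>
#include <set>
#include <vector>

// Hypothetical simplified record: only the fields needed for the deltas-only case.
struct BlockDelta {
    std::size_t logical_block_idx;  // index of the block after eviction
    std::size_t rotation_delta;     // shift in tokens
};

std::vector<BlockDelta> rotation_deltas(const std::set<std::size_t>& evicted,
                                        std::size_t num_blocks_before_eviction,
                                        std::size_t block_size) {
    // Same precondition as in the diff: every evicted index must refer to an existing block.
    for (std::size_t idx : evicted) {
        assert(idx < num_blocks_before_eviction);
    }

    std::vector<BlockDelta> result;
    std::size_t evicted_so_far = 0;
    for (std::size_t idx = 0; idx < num_blocks_before_eviction; idx++) {
        if (evicted.count(idx) != 0) {
            evicted_so_far++;  // every later surviving block shifts by one more block
        } else if (evicted_so_far != 0) {
            // Blocks before the first eviction keep their position and are skipped.
            result.push_back({idx - evicted_so_far, evicted_so_far * block_size});
        }
    }
    return result;
}
```

For example, with five blocks of four tokens and blocks 1 and 3 evicted, block 2 becomes logical block 1 with a delta of 4 tokens, and block 4 becomes logical block 2 with a delta of 8 tokens. Block 0 is not reported because nothing before it was evicted.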
Can we enable rotation automatically based on model topology, without mandatory user-level control? I, as a user, have no idea when I should enable it, and will either leave it as-is (disabled) or enable it every time. What is the recommendation for the user? Is `true` a better default value, given that the majority of modern architectures use RoPE?
We don't achieve the expected accuracy results in the currently available test cases with rotation enabled. Will probably set the default to `true` once we do get accuracy, but the user's model may have a RoPE scheme that differs from what we currently support. I'll add some words about that to the docstring.
Then it is OK to merge it to master to continue development there, but not into the currently open release branch.