Skip to content

Commit a966b4a

Browse files
ursgdaniandtheweb
authored andcommitted
Support BREAK pseudo-token
1 parent 14206fd commit a966b4a

File tree

2 files changed

+41
-5
lines changed

2 files changed

+41
-5
lines changed

conditioner.hpp

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -248,13 +248,30 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
248248
const std::string& curr_text = item.first;
249249
float curr_weight = item.second;
250250
// printf(" %s: %f \n", curr_text.c_str(), curr_weight);
251-
std::vector<int> curr_tokens = tokenizer.encode(curr_text, on_new_token_cb);
252251
int32_t clean_index = 0;
252+
if(curr_text == "BREAK" && curr_weight == -1.0f) {
253+
// Pad token array up to chunk size at this point.
254+
// TODO: This is a hardcoded chunk_len, like in stable-diffusion.cpp, make it a parameter for the future?
255+
// Also, this is 75 instead of 77 to leave room for BOS and EOS tokens.
256+
int padding_size = 75 - (tokens_acc % 75);
257+
for (int j = 0; j < padding_size; j++) {
258+
clean_input_ids.push_back(tokenizer.EOS_TOKEN_ID);
259+
clean_index++;
260+
}
261+
262+
// After padding, continue to the next iteration to process the following text as a new segment
263+
tokens.insert(tokens.end(), clean_input_ids.begin(), clean_input_ids.end());
264+
weights.insert(weights.end(), padding_size, curr_weight);
265+
continue;
266+
}
267+
268+
// Regular token, process normally
269+
std::vector<int> curr_tokens = tokenizer.encode(curr_text, on_new_token_cb);
253270
for (uint32_t i = 0; i < curr_tokens.size(); i++) {
254271
int token_id = curr_tokens[i];
255-
if (token_id == image_token)
272+
if (token_id == image_token) {
256273
class_token_index.push_back(clean_index - 1);
257-
else {
274+
} else {
258275
clean_input_ids.push_back(token_id);
259276
clean_index++;
260277
}
@@ -354,6 +371,22 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
354371
for (const auto& item : parsed_attention) {
355372
const std::string& curr_text = item.first;
356373
float curr_weight = item.second;
374+
375+
if(curr_text == "BREAK" && curr_weight == -1.0f) {
376+
// Pad token array up to chunk size at this point.
377+
// TODO: This is a hardcoded chunk_len, like in stable-diffusion.cpp, make it a parameter for the future?
378+
// Also, this is 75 instead of 77 to leave room for BOS and EOS tokens.
379+
size_t current_size = tokens.size();
380+
size_t padding_size = (75 - (current_size % 75)) % 75; // Ensure no negative padding
381+
382+
if (padding_size > 0) {
383+
LOG_DEBUG("BREAK token encountered, padding current chunk by %zu tokens.", padding_size);
384+
tokens.insert(tokens.end(), padding_size, tokenizer.EOS_TOKEN_ID);
385+
weights.insert(weights.end(), padding_size, 1.0f);
386+
}
387+
continue; // Skip to the next item after handling BREAK
388+
}
389+
357390
std::vector<int> curr_tokens = tokenizer.encode(curr_text, on_new_token_cb);
358391
tokens.insert(tokens.end(), curr_tokens.begin(), curr_tokens.end());
359392
weights.insert(weights.end(), curr_tokens.size(), curr_weight);
@@ -1203,4 +1236,4 @@ struct FluxCLIPEmbedder : public Conditioner {
12031236
}
12041237
};
12051238

1206-
#endif
1239+
#endif

util.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include <codecvt>
66
#include <fstream>
77
#include <locale>
8+
#include <regex>
89
#include <sstream>
910
#include <string>
1011
#include <thread>
@@ -606,7 +607,7 @@ std::vector<std::pair<std::string, float>> parse_prompt_attention(const std::str
606607
float round_bracket_multiplier = 1.1f;
607608
float square_bracket_multiplier = 1 / 1.1f;
608609

609-
std::regex re_attention(R"(\\\(|\\\)|\\\[|\\\]|\\\\|\\|\(|\[|:([+-]?[.\d]+)\)|\)|\]|[^\\()\[\]:]+|:)");
610+
std::regex re_attention(R"(\\\(|\\\)|\\\[|\\\]|\\\\|\\|\(|\[|:([+-]?[.\d]+)\)|\)|\]|\bBREAK\b|[^\\()\[\]:]+|:)");
610611
std::regex re_break(R"(\s*\bBREAK\b\s*)");
611612

612613
auto multiply_range = [&](int start_position, float multiplier) {
@@ -639,6 +640,8 @@ std::vector<std::pair<std::string, float>> parse_prompt_attention(const std::str
639640
square_brackets.pop_back();
640641
} else if (text == "\\(") {
641642
res.push_back({text.substr(1), 1.0f});
643+
} else if (std::regex_search(text, re_break)) {
644+
res.push_back({"BREAK", -1.0f});
642645
} else {
643646
res.push_back({text, 1.0f});
644647
}

0 commit comments

Comments
 (0)