
Commit 74edb42

llama : rename batch.logits to batch.output
This commit renames the `logits` field of the `llama_batch` struct to `output`. The motivation for this change (apart from the existing TODO comment) is that the `logits` field actually specifies whether any output should be generated for a token, not that logits in particular are wanted. For example, when generating embeddings, setting `logits` to true is confusing, since no logits are used when generating embeddings.
1 parent: 873279b · commit: 74edb42

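To illustrate the rename in context, here is a minimal sketch of a typical prompt evaluation after this commit (illustrative only; `ctx`, `prompt_tokens`, and `n_prompt` stand in for the caller's setup):

```cpp
// Queue the prompt without requesting output for any token ...
llama_batch batch = llama_batch_init(n_prompt, /*embd =*/ 0, /*n_seq_max =*/ 1);
for (int i = 0; i < n_prompt; i++) {
    common_batch_add(batch, prompt_tokens[i], i, { 0 }, /*output =*/ false);
}

// ... then flag only the last token. Before this commit the same field was
// batch.logits, even though it requests output in general, not just logits.
batch.output[batch.n_tokens - 1] = true;

if (llama_decode(ctx, batch) != 0) {
    fprintf(stderr, "llama_decode() failed\n");
}
llama_batch_free(batch);
```
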
File tree: 16 files changed, +49 −49 lines

common/common.cpp

Lines changed: 3 additions & 3 deletions
@@ -554,7 +554,7 @@ std::string string_from(const struct llama_context * ctx, const struct llama_bat
             << ":pos "      << std::to_string(batch.pos[i])
             << ":n_seq_id " << std::to_string(batch.n_seq_id[i])
             << ":seq_id "   << std::to_string(batch.seq_id[i][0])
-            << ":logits "   << std::to_string(batch.logits[i]);
+            << ":output "   << std::to_string(batch.output[i]);
     }
 
     buf << " ]";
@@ -1480,7 +1480,7 @@ void common_batch_add(
                 llama_token   id,
                   llama_pos   pos,
    const std::vector<llama_seq_id> & seq_ids,
-                       bool   logits) {
+                       bool   output) {
     GGML_ASSERT(batch.seq_id[batch.n_tokens] && "llama_batch size exceeded");
 
     batch.token [batch.n_tokens] = id;
@@ -1489,7 +1489,7 @@ void common_batch_add(
     for (size_t i = 0; i < seq_ids.size(); ++i) {
         batch.seq_id[batch.n_tokens][i] = seq_ids[i];
     }
-    batch.logits [batch.n_tokens] = logits;
+    batch.output [batch.n_tokens] = output;
 
     batch.n_tokens++;
}

examples/batched-bench/batched-bench.cpp

Lines changed: 2 additions & 2 deletions
@@ -73,7 +73,7 @@ int main(int argc, char ** argv) {
             batch.pos      + i,
             batch.n_seq_id + i,
             batch.seq_id   + i,
-            batch.logits   + i,
+            batch.output   + i,
         };
 
         const int ret = llama_decode(ctx, batch_view);
@@ -128,7 +128,7 @@ int main(int argc, char ** argv) {
                 common_batch_add(batch, 0, i, { j }, false);
             }
         }
-        batch.logits[batch.n_tokens - 1] = true;
+        batch.output[batch.n_tokens - 1] = true;
 
         const auto t_pp_start = ggml_time_us();

examples/batched.swift/Sources/main.swift

Lines changed: 3 additions & 3 deletions
@@ -99,11 +99,11 @@ for (i, token) in tokens.enumerated() {
     if let seq_id = batch.seq_id[i] {
         seq_id[0] = 0
     }
-    batch.logits[i] = 0
+    batch.output[i] = 0
 }
 
 // llama_decode will output logits only for the last token of the prompt
-batch.logits[Int(batch.n_tokens) - 1] = 1
+batch.output[Int(batch.n_tokens) - 1] = 1
 
 if llama_decode(context, batch) != 0 {
     print("llama_decode() failed")
@@ -166,7 +166,7 @@ while n_cur <= n_len {
     if let seq_id = batch.seq_id[Int(batch.n_tokens)] {
         seq_id[0] = Int32(i)
     }
-    batch.logits[Int(batch.n_tokens)] = 1
+    batch.output[Int(batch.n_tokens)] = 1
 
     i_batch[i] = batch.n_tokens

examples/batched/batched.cpp

Lines changed: 1 addition & 1 deletion
@@ -128,7 +128,7 @@ int main(int argc, char ** argv) {
     }
 
     // llama_decode will output logits only for the last token of the prompt
-    batch.logits[batch.n_tokens - 1] = true;
+    batch.output[batch.n_tokens - 1] = true;
 
     if (llama_decode(ctx, batch) != 0) {
         LOG_ERR("%s: llama_decode() failed\n", __func__);

examples/embedding/embedding.cpp

Lines changed: 1 addition & 1 deletion
@@ -54,7 +54,7 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
     }
 
     for (int i = 0; i < batch.n_tokens; i++) {
-        if (!batch.logits[i]) {
+        if (!batch.output[i]) {
             continue;
         }

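Since the output flag gates embedding extraction too (the point of the rename), here is a hedged sketch of the read-back side that the loop above feeds, assuming non-pooled embeddings and the `llama_get_embeddings_ith` accessor:

```cpp
// After llama_decode, fetch one embedding per token that was flagged via
// batch.output - the flag requests an output, not logits specifically.
for (int i = 0; i < batch.n_tokens; i++) {
    if (!batch.output[i]) {
        continue;
    }
    const float * embd = llama_get_embeddings_ith(ctx, i);
    // ... copy or normalize the n_embd floats pointed to by embd ...
}
```
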
examples/llama.android/llama/src/main/cpp/llama-android.cpp

Lines changed: 3 additions & 3 deletions
@@ -193,7 +193,7 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model(
         common_batch_add(*batch, 0, i, { 0 }, false);
     }
 
-    batch->logits[batch->n_tokens - 1] = true;
+    batch->output[batch->n_tokens - 1] = true;
     llama_kv_cache_clear(context);
 
     const auto t_pp_start = ggml_time_us();
@@ -297,7 +297,7 @@ Java_android_llama_cpp_LLamaAndroid_new_1batch(JNIEnv *, jobject, jint n_tokens,
     for (int i = 0; i < n_tokens; ++i) {
         batch->seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max);
     }
-    batch->logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens);
+    batch->output = (int8_t *) malloc(sizeof(int8_t) * n_tokens);
 
     return reinterpret_cast<jlong>(batch);
}
@@ -377,7 +377,7 @@ Java_android_llama_cpp_LLamaAndroid_completion_1init(
     }
 
     // llama_decode will output logits only for the last token of the prompt
-    batch->logits[batch->n_tokens - 1] = true;
+    batch->output[batch->n_tokens - 1] = true;
 
     if (llama_decode(context, *batch) != 0) {
         LOGe("llama_decode() failed");

examples/llama.swiftui/llama.cpp.swift/LibLlama.swift

Lines changed: 2 additions & 2 deletions
@@ -137,7 +137,7 @@ actor LlamaContext {
         let i = Int(i1)
         llama_batch_add(&batch, tokens_list[i], Int32(i), [0], false)
     }
-    batch.logits[Int(batch.n_tokens) - 1] = 1 // true
+    batch.output[Int(batch.n_tokens) - 1] = 1 // true
 
     if llama_decode(context, batch) != 0 {
         print("llama_decode() failed")
@@ -206,7 +206,7 @@ actor LlamaContext {
     for i in 0..<n_tokens {
         llama_batch_add(&batch, 0, Int32(i), [0], false)
     }
-    batch.logits[Int(batch.n_tokens) - 1] = 1 // true
+    batch.output[Int(batch.n_tokens) - 1] = 1 // true
 
     llama_kv_cache_clear(context)

examples/llava/llava.cpp

Lines changed: 4 additions & 4 deletions
@@ -406,13 +406,13 @@ struct llava_embd_batch {
     std::vector<int32_t>        n_seq_id;
     std::vector<llama_seq_id>   seq_id_0;
     std::vector<llama_seq_id *> seq_ids;
-    std::vector<int8_t>         logits;
+    std::vector<int8_t>         outputs;
     llama_batch batch;
     llava_embd_batch(float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) {
         pos     .resize(n_tokens);
         n_seq_id.resize(n_tokens);
         seq_ids .resize(n_tokens + 1);
-        logits  .resize(n_tokens);
+        outputs .resize(n_tokens);
         seq_id_0.resize(1);
         seq_id_0[0] = seq_id;
         seq_ids [n_tokens] = nullptr;
@@ -423,13 +423,13 @@ struct llava_embd_batch {
             /*pos      =*/ pos.data(),
             /*n_seq_id =*/ n_seq_id.data(),
             /*seq_id   =*/ seq_ids.data(),
-            /*logits   =*/ logits.data(),
+            /*output   =*/ outputs.data(),
         };
         for (int i = 0; i < n_tokens; i++) {
             batch.pos     [i] = pos_0 + i;
             batch.n_seq_id[i] = 1;
             batch.seq_id  [i] = seq_id_0.data();
-            batch.logits  [i] = false;
+            batch.output  [i] = false;
         }
     }
};

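Note that `llava_embd_batch` leaves every output flag false: the image embeddings are decoded only to populate the KV cache, and outputs are requested later for text tokens. A rough sketch of the same pattern with a hand-built embeddings batch (`image_embd`, `n_img_tokens`, `n_embd`, and `pos_0` are assumed from the surrounding code):

```cpp
// Decode embeddings purely to fill the KV cache: with every batch.output
// flag left false, llama_decode produces no logits or embeddings here.
llama_batch batch = llama_batch_init(n_img_tokens, n_embd, /*n_seq_max =*/ 1);
batch.n_tokens = n_img_tokens;
memcpy(batch.embd, image_embd, sizeof(float) * n_img_tokens * n_embd);
for (int i = 0; i < n_img_tokens; i++) {
    batch.pos     [i]    = pos_0 + i;
    batch.n_seq_id[i]    = 1;
    batch.seq_id  [i][0] = 0;
    batch.output  [i]    = false;
}
llama_decode(ctx, batch);
llama_batch_free(batch);
```
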
examples/parallel/parallel.cpp

Lines changed: 2 additions & 2 deletions
@@ -264,7 +264,7 @@ int main(int argc, char ** argv) {
 
                 // extract the logits only for the last token
                 if (batch.n_tokens > 0) {
-                    batch.logits[batch.n_tokens - 1] = true;
+                    batch.output[batch.n_tokens - 1] = true;
                 }
 
                 client.n_prompt = tokens_prompt.size();
@@ -307,7 +307,7 @@ int main(int argc, char ** argv) {
                 batch.pos      + i,
                 batch.n_seq_id + i,
                 batch.seq_id   + i,
-                batch.logits   + i,
+                batch.output   + i,
             };
 
             const int ret = llama_decode(ctx, batch_view);

examples/passkey/passkey.cpp

Lines changed: 2 additions & 2 deletions
@@ -144,7 +144,7 @@ int main(int argc, char ** argv) {
         }
 
         if (i + n_batch >= n_tokens_all) {
-            batch.logits[batch.n_tokens - 1] = true;
+            batch.output[batch.n_tokens - 1] = true;
         }
 
         if (llama_decode(ctx, batch) != 0) {
@@ -178,7 +178,7 @@ int main(int argc, char ** argv) {
         }
 
         if (i + n_batch >= n_tokens_all) {
-            batch.logits[batch.n_tokens - 1] = true;
+            batch.output[batch.n_tokens - 1] = true;
         }
 
         if (llama_decode(ctx, batch) != 0) {

examples/perplexity/perplexity.cpp

Lines changed: 7 additions & 7 deletions
@@ -615,9 +615,9 @@ static results_perplexity perplexity(llama_context * ctx, const common_params &
                 batch.pos     [idx]    = j*n_batch + k;
                 batch.n_seq_id[idx]    = 1;
                 batch.seq_id  [idx][0] = seq;
-                batch.logits  [idx]    = batch.pos[idx] >= first ? 1 : 0;
+                batch.output  [idx]    = batch.pos[idx] >= first ? 1 : 0;
 
-                n_outputs += batch.logits[idx] != 0;
+                n_outputs += batch.output[idx] != 0;
             }
             batch.n_tokens += batch_size;
 
@@ -712,7 +712,7 @@ static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector<
             batch.pos      + i,
             batch.n_seq_id + i,
             batch.seq_id   + i,
-            batch.logits   + i,
+            batch.output   + i,
         };
 
         const int ret = llama_decode(ctx, batch_view);
@@ -723,7 +723,7 @@ static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector<
 
         int n_outputs = 0;
         for (int i = 0; i < n_tokens; ++i) {
-            n_outputs += batch_view.logits[i] != 0;
+            n_outputs += batch_view.output[i] != 0;
        }
 
        memcpy(batch_logits.data() + size_t(prev_outputs)*n_vocab, llama_get_logits(ctx), size_t(n_outputs)*n_vocab*sizeof(float));
@@ -936,7 +936,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
         for (size_t i = 0; i < hs_cur.common_prefix; ++i) {
             common_batch_add(batch, hs_cur.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3 }, false);
         }
-        batch.logits[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix
+        batch.output[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix
         n_logits += 1;
 
         for (int s = 0; s < 4; ++s) {
@@ -1215,7 +1215,7 @@ static void winogrande_score(llama_context * ctx, const common_params & params)
         for (size_t i = 0; i < data[i1].common_prefix; ++i) {
             common_batch_add(batch, data[i1].seq_tokens[0][i], i, { s0 + 0, s0 + 1 }, false);
         }
-        batch.logits[batch.n_tokens - 1] = true;
+        batch.output[batch.n_tokens - 1] = true;
         n_logits += 1;
 
         for (int s = 0; s < 2; ++s) {
@@ -1581,7 +1581,7 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par
             //llama_batch_add(batch, cur_task.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3}, false);
             common_batch_add(batch, cur_task.seq_tokens[0][i], i, batch_indeces, false);
         }
-        batch.logits[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix
+        batch.output[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix
         n_logits += 1;
 
         for (int s = 0; s < int(cur_task.seq_tokens.size()); ++s) {

examples/retrieval/retrieval.cpp

Lines changed: 1 addition & 1 deletion
@@ -92,7 +92,7 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
     }
 
     for (int i = 0; i < batch.n_tokens; i++) {
-        if (!batch.logits[i]) {
+        if (!batch.output[i]) {
             continue;
         }

examples/save-load-state/save-load-state.cpp

Lines changed: 1 addition & 1 deletion
@@ -52,7 +52,7 @@ int main(int argc, char ** argv) {
     for (size_t i = 0; i < tokens.size(); i++) {
         common_batch_add(batch, tokens[i], i, {0}, false);
     }
-    batch.logits[batch.n_tokens - 1] = true; // generate next token
+    batch.output[batch.n_tokens - 1] = true; // generate next token
 
     // evaluate prompt
     llama_decode(ctx, batch);

examples/server/server.cpp

Lines changed: 4 additions & 4 deletions
@@ -1382,7 +1382,7 @@ struct server_context {
         std::vector<float> embd_res(n_embd, 0.0f);
 
         for (int i = 0; i < batch.n_tokens; ++i) {
-            if (!batch.logits[i] || batch.seq_id[i][0] != slot.id + 1) {
+            if (!batch.output[i] || batch.seq_id[i][0] != slot.id + 1) {
                 continue;
             }
 
@@ -1422,7 +1422,7 @@ struct server_context {
         res.stop = true;
 
         for (int i = 0; i < batch.n_tokens; ++i) {
-            if (!batch.logits[i] || batch.seq_id[i][0] != slot.id + 1) {
+            if (!batch.output[i] || batch.seq_id[i][0] != slot.id + 1) {
                 continue;
             }
 
@@ -2289,7 +2289,7 @@ struct server_context {
                 GGML_ASSERT(batch.n_tokens > 0);
 
                 // extract the logits only for the last token
-                batch.logits[batch.n_tokens - 1] = true;
+                batch.output[batch.n_tokens - 1] = true;
 
                 slot.n_decoded = 0;
                 slot.i_batch   = batch.n_tokens - 1;
@@ -2325,7 +2325,7 @@ struct server_context {
                 batch.pos      + i,
                 batch.n_seq_id + i,
                 batch.seq_id   + i,
-                batch.logits   + i,
+                batch.output   + i,
             };
 
             const int ret = llama_decode(ctx, batch_view);

include/llama.h

Lines changed: 1 addition & 1 deletion
@@ -247,7 +247,7 @@ extern "C" {
        llama_pos    *  pos;
        int32_t      *  n_seq_id;
        llama_seq_id ** seq_id;
-       int8_t       *  logits; // TODO: rename this to "output"
+       int8_t       *  output;
    } llama_batch;
 
    enum llama_model_kv_override_type {

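For reference, the public batch struct reads roughly as follows after the rename (abridged sketch; the `n_tokens`, `token`, and `embd` members shown for context are untouched by this commit):

```cpp
typedef struct llama_batch {
    int32_t         n_tokens;

    llama_token  *  token;    // token ids, if this is a token batch
    float        *  embd;     // embeddings, if this is an embeddings batch
    llama_pos    *  pos;      // position of each token in its sequence
    int32_t      *  n_seq_id; // number of sequence ids per token
    llama_seq_id ** seq_id;   // sequence ids per token
    int8_t       *  output;   // if non-zero, an output (logits or embeddings)
                              // is generated for this token
} llama_batch;
```
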
src/llama.cpp

Lines changed: 12 additions & 12 deletions
@@ -3072,17 +3072,17 @@ struct llama_sbatch {
                     ubatch.output[ubatch.n_tokens + i] = 1;
                     out_ids.push_back(ids[seq.offset + i]);
                 }
-            } else if (batch->logits) {
+            } else if (batch->output) {
                 if (ubatch.equal_seqs) {
                     for (size_t i = 0; i < length; ++i) {
                         size_t id = ids[seq.offset + i];
-                        int8_t is_output = batch->logits[id];
+                        int8_t is_output = batch->output[id];
                         ubatch.output[ubatch.n_tokens + i] = is_output;
                         if (is_output) { out_ids.push_back(id); }
                     }
                 } else {
                     // simple split
-                    ubatch.output = batch->logits + seq.offset;
+                    ubatch.output = batch->output + seq.offset;
                     for (size_t i = 0; i < length; ++i) {
                         if (ubatch.output[i] != 0) { out_ids.push_back(seq.offset + i); }
                     }
@@ -5184,7 +5184,7 @@ struct llama_batch_allocr {
    std::vector<llama_pos>      pos;
    std::vector<int32_t>        n_seq_id;
    std::vector<llama_seq_id *> seq_id;
-   std::vector<int8_t>         logits;
+   std::vector<int8_t>         outputs;
    struct llama_batch batch;
    // optionally fulfill the batch returned by llama_batch_get_one
    llama_batch_allocr(llama_context & ctx, struct llama_batch in_batch) {
@@ -5220,10 +5220,10 @@ struct llama_batch_allocr {
            }
            batch.seq_id = seq_id.data();
        }
-       if (!batch.logits) {
-           logits.resize(batch.n_tokens);
-           logits[logits.size() - 1] = true;
-           batch.logits = logits.data();
+       if (!batch.output) {
+           outputs.resize(batch.n_tokens);
+           outputs[outputs.size() - 1] = true;
+           batch.output = outputs.data();
        }
    }
};
@@ -17200,9 +17200,9 @@ static int llama_decode_internal(
    lctx.embd_seq.clear();
 
    // count outputs
-   if (batch.logits && !embd_pooled) {
+   if (batch.output && !embd_pooled) {
        for (uint32_t i = 0; i < n_tokens_all; ++i) {
-           n_outputs += batch.logits[i] != 0;
+           n_outputs += batch.output[i] != 0;
        }
    } else if (lctx.logits_all || embd_pooled) {
        n_outputs = n_tokens_all;
@@ -21129,7 +21129,7 @@ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_
    }
    batch.seq_id[n_tokens_alloc] = nullptr;
 
-   batch.logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens_alloc);
+   batch.output = (int8_t *) malloc(sizeof(int8_t) * n_tokens_alloc);
 
    return batch;
}
@@ -21145,7 +21145,7 @@ void llama_batch_free(struct llama_batch batch) {
        }
        free(batch.seq_id);
    }
-   if (batch.logits) free(batch.logits);
+   if (batch.output) free(batch.output);
}
 
int32_t llama_encode(

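One behavioral detail visible in the `llama_batch_allocr` hunk above: when a batch arrives with `output == NULL` (as with `llama_batch_get_one`), the allocator fills in a default that requests output for the last token only. A hedged sketch of what that means for a caller (token ids are hypothetical; the two-argument `llama_batch_get_one` signature is assumed):

```cpp
std::vector<llama_token> tokens = { 1, 15043, 1781 }; // hypothetical token ids

// llama_batch_get_one leaves the output flags NULL, so llama_batch_allocr
// defaults to requesting output for the last token only
llama_batch batch = llama_batch_get_one(tokens.data(), (int32_t) tokens.size());

if (llama_decode(ctx, batch) == 0) {
    // only the final token produced logits; -1 indexes the last output
    const float * logits = llama_get_logits_ith(ctx, -1);
}
```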