-
Notifications
You must be signed in to change notification settings - Fork 305
/
flux.hpp
961 lines (802 loc) · 46.2 KB
/
flux.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
#ifndef __FLUX_HPP__
#define __FLUX_HPP__
#include <vector>
#include "ggml_extend.hpp"
#include "model.h"
#define FLUX_GRAPH_SIZE 10240
namespace Flux {
struct MLPEmbedder : public UnaryBlock {
public:
MLPEmbedder(int64_t in_dim, int64_t hidden_dim) {
blocks["in_layer"] = std::shared_ptr<GGMLBlock>(new Linear(in_dim, hidden_dim, true));
blocks["out_layer"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_dim, hidden_dim, true));
}
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
// x: [..., in_dim]
// return: [..., hidden_dim]
auto in_layer = std::dynamic_pointer_cast<Linear>(blocks["in_layer"]);
auto out_layer = std::dynamic_pointer_cast<Linear>(blocks["out_layer"]);
x = in_layer->forward(ctx, x);
x = ggml_silu_inplace(ctx, x);
x = out_layer->forward(ctx, x);
return x;
}
};
class RMSNorm : public UnaryBlock {
protected:
int64_t hidden_size;
float eps;
void init_params(struct ggml_context* ctx, ggml_type wtype) {
params["scale"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hidden_size);
}
public:
RMSNorm(int64_t hidden_size,
float eps = 1e-06f)
: hidden_size(hidden_size),
eps(eps) {}
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
struct ggml_tensor* w = params["scale"];
x = ggml_rms_norm(ctx, x, eps);
x = ggml_mul(ctx, x, w);
return x;
}
};
struct QKNorm : public GGMLBlock {
public:
QKNorm(int64_t dim) {
blocks["query_norm"] = std::shared_ptr<GGMLBlock>(new RMSNorm(dim));
blocks["key_norm"] = std::shared_ptr<GGMLBlock>(new RMSNorm(dim));
}
struct ggml_tensor* query_norm(struct ggml_context* ctx, struct ggml_tensor* x) {
// x: [..., dim]
// return: [..., dim]
auto norm = std::dynamic_pointer_cast<RMSNorm>(blocks["query_norm"]);
x = norm->forward(ctx, x);
return x;
}
struct ggml_tensor* key_norm(struct ggml_context* ctx, struct ggml_tensor* x) {
// x: [..., dim]
// return: [..., dim]
auto norm = std::dynamic_pointer_cast<RMSNorm>(blocks["key_norm"]);
x = norm->forward(ctx, x);
return x;
}
};
__STATIC_INLINE__ struct ggml_tensor* apply_rope(struct ggml_context* ctx,
struct ggml_tensor* x,
struct ggml_tensor* pe) {
// x: [N, L, n_head, d_head]
// pe: [L, d_head/2, 2, 2]
int64_t d_head = x->ne[0];
int64_t n_head = x->ne[1];
int64_t L = x->ne[2];
int64_t N = x->ne[3];
x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // [N, n_head, L, d_head]
x = ggml_reshape_4d(ctx, x, 2, d_head / 2, L, n_head * N); // [N * n_head, L, d_head/2, 2]
x = ggml_cont(ctx, ggml_permute(ctx, x, 3, 0, 1, 2)); // [2, N * n_head, L, d_head/2]
int64_t offset = x->nb[2] * x->ne[2];
auto x_0 = ggml_view_3d(ctx, x, x->ne[0], x->ne[1], x->ne[2], x->nb[1], x->nb[2], offset * 0); // [N * n_head, L, d_head/2]
auto x_1 = ggml_view_3d(ctx, x, x->ne[0], x->ne[1], x->ne[2], x->nb[1], x->nb[2], offset * 1); // [N * n_head, L, d_head/2]
x_0 = ggml_reshape_4d(ctx, x_0, 1, x_0->ne[0], x_0->ne[1], x_0->ne[2]); // [N * n_head, L, d_head/2, 1]
x_1 = ggml_reshape_4d(ctx, x_1, 1, x_1->ne[0], x_1->ne[1], x_1->ne[2]); // [N * n_head, L, d_head/2, 1]
auto temp_x = ggml_new_tensor_4d(ctx, x_0->type, 2, x_0->ne[1], x_0->ne[2], x_0->ne[3]);
x_0 = ggml_repeat(ctx, x_0, temp_x); // [N * n_head, L, d_head/2, 2]
x_1 = ggml_repeat(ctx, x_1, temp_x); // [N * n_head, L, d_head/2, 2]
pe = ggml_cont(ctx, ggml_permute(ctx, pe, 3, 0, 1, 2)); // [2, L, d_head/2, 2]
offset = pe->nb[2] * pe->ne[2];
auto pe_0 = ggml_view_3d(ctx, pe, pe->ne[0], pe->ne[1], pe->ne[2], pe->nb[1], pe->nb[2], offset * 0); // [L, d_head/2, 2]
auto pe_1 = ggml_view_3d(ctx, pe, pe->ne[0], pe->ne[1], pe->ne[2], pe->nb[1], pe->nb[2], offset * 1); // [L, d_head/2, 2]
auto x_out = ggml_add_inplace(ctx, ggml_mul(ctx, x_0, pe_0), ggml_mul(ctx, x_1, pe_1)); // [N * n_head, L, d_head/2, 2]
x_out = ggml_reshape_3d(ctx, x_out, d_head, L, n_head * N); // [N*n_head, L, d_head]
return x_out;
}
__STATIC_INLINE__ struct ggml_tensor* attention(struct ggml_context* ctx,
struct ggml_tensor* q,
struct ggml_tensor* k,
struct ggml_tensor* v,
struct ggml_tensor* pe) {
// q,k,v: [N, L, n_head, d_head]
// pe: [L, d_head/2, 2, 2]
// return: [N, L, n_head*d_head]
q = apply_rope(ctx, q, pe); // [N*n_head, L, d_head]
k = apply_rope(ctx, k, pe); // [N*n_head, L, d_head]
auto x = ggml_nn_attention_ext(ctx, q, k, v, v->ne[1], NULL, false, true); // [N, L, n_head*d_head]
return x;
}
struct SelfAttention : public GGMLBlock {
public:
int64_t num_heads;
public:
SelfAttention(int64_t dim,
int64_t num_heads = 8,
bool qkv_bias = false)
: num_heads(num_heads) {
int64_t head_dim = dim / num_heads;
blocks["qkv"] = std::shared_ptr<GGMLBlock>(new Linear(dim, dim * 3, qkv_bias));
blocks["norm"] = std::shared_ptr<GGMLBlock>(new QKNorm(head_dim));
blocks["proj"] = std::shared_ptr<GGMLBlock>(new Linear(dim, dim));
}
std::vector<struct ggml_tensor*> pre_attention(struct ggml_context* ctx, struct ggml_tensor* x) {
auto qkv_proj = std::dynamic_pointer_cast<Linear>(blocks["qkv"]);
auto norm = std::dynamic_pointer_cast<QKNorm>(blocks["norm"]);
auto qkv = qkv_proj->forward(ctx, x);
auto qkv_vec = split_qkv(ctx, qkv);
int64_t head_dim = qkv_vec[0]->ne[0] / num_heads;
auto q = ggml_reshape_4d(ctx, qkv_vec[0], head_dim, num_heads, qkv_vec[0]->ne[1], qkv_vec[0]->ne[2]);
auto k = ggml_reshape_4d(ctx, qkv_vec[1], head_dim, num_heads, qkv_vec[1]->ne[1], qkv_vec[1]->ne[2]);
auto v = ggml_reshape_4d(ctx, qkv_vec[2], head_dim, num_heads, qkv_vec[2]->ne[1], qkv_vec[2]->ne[2]);
q = norm->query_norm(ctx, q);
k = norm->key_norm(ctx, k);
return {q, k, v};
}
struct ggml_tensor* post_attention(struct ggml_context* ctx, struct ggml_tensor* x) {
auto proj = std::dynamic_pointer_cast<Linear>(blocks["proj"]);
x = proj->forward(ctx, x); // [N, n_token, dim]
return x;
}
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* pe) {
// x: [N, n_token, dim]
// pe: [n_token, d_head/2, 2, 2]
// return [N, n_token, dim]
auto qkv = pre_attention(ctx, x); // q,k,v: [N, n_token, n_head, d_head]
x = attention(ctx, qkv[0], qkv[1], qkv[2], pe); // [N, n_token, dim]
x = post_attention(ctx, x); // [N, n_token, dim]
return x;
}
};
struct ModulationOut {
ggml_tensor* shift = NULL;
ggml_tensor* scale = NULL;
ggml_tensor* gate = NULL;
ModulationOut(ggml_tensor* shift = NULL, ggml_tensor* scale = NULL, ggml_tensor* gate = NULL)
: shift(shift), scale(scale), gate(gate) {}
};
struct Modulation : public GGMLBlock {
public:
bool is_double;
int multiplier;
public:
Modulation(int64_t dim, bool is_double)
: is_double(is_double) {
multiplier = is_double ? 6 : 3;
blocks["lin"] = std::shared_ptr<GGMLBlock>(new Linear(dim, dim * multiplier));
}
std::vector<ModulationOut> forward(struct ggml_context* ctx, struct ggml_tensor* vec) {
// x: [N, dim]
// return: [ModulationOut, ModulationOut]
auto lin = std::dynamic_pointer_cast<Linear>(blocks["lin"]);
auto out = ggml_silu(ctx, vec);
out = lin->forward(ctx, out); // [N, multiplier*dim]
auto m = ggml_reshape_3d(ctx, out, vec->ne[0], multiplier, vec->ne[1]); // [N, multiplier, dim]
m = ggml_cont(ctx, ggml_permute(ctx, m, 0, 2, 1, 3)); // [multiplier, N, dim]
int64_t offset = m->nb[1] * m->ne[1];
auto shift_0 = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 0); // [N, dim]
auto scale_0 = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 1); // [N, dim]
auto gate_0 = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 2); // [N, dim]
if (is_double) {
auto shift_1 = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 3); // [N, dim]
auto scale_1 = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 4); // [N, dim]
auto gate_1 = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 5); // [N, dim]
return {ModulationOut(shift_0, scale_0, gate_0), ModulationOut(shift_1, scale_1, gate_1)};
}
return {ModulationOut(shift_0, scale_0, gate_0), ModulationOut()};
}
};
__STATIC_INLINE__ struct ggml_tensor* modulate(struct ggml_context* ctx,
struct ggml_tensor* x,
struct ggml_tensor* shift,
struct ggml_tensor* scale) {
// x: [N, L, C]
// scale: [N, C]
// shift: [N, C]
scale = ggml_reshape_3d(ctx, scale, scale->ne[0], 1, scale->ne[1]); // [N, 1, C]
shift = ggml_reshape_3d(ctx, shift, shift->ne[0], 1, shift->ne[1]); // [N, 1, C]
x = ggml_add(ctx, x, ggml_mul(ctx, x, scale));
x = ggml_add(ctx, x, shift);
return x;
}
struct DoubleStreamBlock : public GGMLBlock {
public:
DoubleStreamBlock(int64_t hidden_size,
int64_t num_heads,
float mlp_ratio,
bool qkv_bias = false) {
int64_t mlp_hidden_dim = hidden_size * mlp_ratio;
blocks["img_mod"] = std::shared_ptr<GGMLBlock>(new Modulation(hidden_size, true));
blocks["img_norm1"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-6f, false));
blocks["img_attn"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qkv_bias));
blocks["img_norm2"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-6f, false));
blocks["img_mlp.0"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, mlp_hidden_dim));
// img_mlp.1 is nn.GELU(approximate="tanh")
blocks["img_mlp.2"] = std::shared_ptr<GGMLBlock>(new Linear(mlp_hidden_dim, hidden_size));
blocks["txt_mod"] = std::shared_ptr<GGMLBlock>(new Modulation(hidden_size, true));
blocks["txt_norm1"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-6f, false));
blocks["txt_attn"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qkv_bias));
blocks["txt_norm2"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-6f, false));
blocks["txt_mlp.0"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, mlp_hidden_dim));
// img_mlp.1 is nn.GELU(approximate="tanh")
blocks["txt_mlp.2"] = std::shared_ptr<GGMLBlock>(new Linear(mlp_hidden_dim, hidden_size));
}
std::pair<struct ggml_tensor*, struct ggml_tensor*> forward(struct ggml_context* ctx,
struct ggml_tensor* img,
struct ggml_tensor* txt,
struct ggml_tensor* vec,
struct ggml_tensor* pe) {
// img: [N, n_img_token, hidden_size]
// txt: [N, n_txt_token, hidden_size]
// pe: [n_img_token + n_txt_token, d_head/2, 2, 2]
// return: ([N, n_img_token, hidden_size], [N, n_txt_token, hidden_size])
auto img_mod = std::dynamic_pointer_cast<Modulation>(blocks["img_mod"]);
auto img_norm1 = std::dynamic_pointer_cast<LayerNorm>(blocks["img_norm1"]);
auto img_attn = std::dynamic_pointer_cast<SelfAttention>(blocks["img_attn"]);
auto img_norm2 = std::dynamic_pointer_cast<LayerNorm>(blocks["img_norm2"]);
auto img_mlp_0 = std::dynamic_pointer_cast<Linear>(blocks["img_mlp.0"]);
auto img_mlp_2 = std::dynamic_pointer_cast<Linear>(blocks["img_mlp.2"]);
auto txt_mod = std::dynamic_pointer_cast<Modulation>(blocks["txt_mod"]);
auto txt_norm1 = std::dynamic_pointer_cast<LayerNorm>(blocks["txt_norm1"]);
auto txt_attn = std::dynamic_pointer_cast<SelfAttention>(blocks["txt_attn"]);
auto txt_norm2 = std::dynamic_pointer_cast<LayerNorm>(blocks["txt_norm2"]);
auto txt_mlp_0 = std::dynamic_pointer_cast<Linear>(blocks["txt_mlp.0"]);
auto txt_mlp_2 = std::dynamic_pointer_cast<Linear>(blocks["txt_mlp.2"]);
auto img_mods = img_mod->forward(ctx, vec);
ModulationOut img_mod1 = img_mods[0];
ModulationOut img_mod2 = img_mods[1];
auto txt_mods = txt_mod->forward(ctx, vec);
ModulationOut txt_mod1 = txt_mods[0];
ModulationOut txt_mod2 = txt_mods[1];
// prepare image for attention
auto img_modulated = img_norm1->forward(ctx, img);
img_modulated = Flux::modulate(ctx, img_modulated, img_mod1.shift, img_mod1.scale);
auto img_qkv = img_attn->pre_attention(ctx, img_modulated); // q,k,v: [N, n_img_token, n_head, d_head]
auto img_q = img_qkv[0];
auto img_k = img_qkv[1];
auto img_v = img_qkv[2];
// prepare txt for attention
auto txt_modulated = txt_norm1->forward(ctx, txt);
txt_modulated = Flux::modulate(ctx, txt_modulated, txt_mod1.shift, txt_mod1.scale);
auto txt_qkv = txt_attn->pre_attention(ctx, txt_modulated); // q,k,v: [N, n_txt_token, n_head, d_head]
auto txt_q = txt_qkv[0];
auto txt_k = txt_qkv[1];
auto txt_v = txt_qkv[2];
// run actual attention
auto q = ggml_concat(ctx, txt_q, img_q, 2); // [N, n_txt_token + n_img_token, n_head, d_head]
auto k = ggml_concat(ctx, txt_k, img_k, 2); // [N, n_txt_token + n_img_token, n_head, d_head]
auto v = ggml_concat(ctx, txt_v, img_v, 2); // [N, n_txt_token + n_img_token, n_head, d_head]
auto attn = attention(ctx, q, k, v, pe); // [N, n_txt_token + n_img_token, n_head*d_head]
attn = ggml_cont(ctx, ggml_permute(ctx, attn, 0, 2, 1, 3)); // [n_txt_token + n_img_token, N, hidden_size]
auto txt_attn_out = ggml_view_3d(ctx,
attn,
attn->ne[0],
attn->ne[1],
txt->ne[1],
attn->nb[1],
attn->nb[2],
0); // [n_txt_token, N, hidden_size]
txt_attn_out = ggml_cont(ctx, ggml_permute(ctx, txt_attn_out, 0, 2, 1, 3)); // [N, n_txt_token, hidden_size]
auto img_attn_out = ggml_view_3d(ctx,
attn,
attn->ne[0],
attn->ne[1],
img->ne[1],
attn->nb[1],
attn->nb[2],
attn->nb[2] * txt->ne[1]); // [n_img_token, N, hidden_size]
img_attn_out = ggml_cont(ctx, ggml_permute(ctx, img_attn_out, 0, 2, 1, 3)); // [N, n_img_token, hidden_size]
// calculate the img bloks
img = ggml_add(ctx, img, ggml_mul(ctx, img_attn->post_attention(ctx, img_attn_out), img_mod1.gate));
auto img_mlp_out = img_mlp_0->forward(ctx, Flux::modulate(ctx, img_norm2->forward(ctx, img), img_mod2.shift, img_mod2.scale));
img_mlp_out = ggml_gelu_inplace(ctx, img_mlp_out);
img_mlp_out = img_mlp_2->forward(ctx, img_mlp_out);
img = ggml_add(ctx, img, ggml_mul(ctx, img_mlp_out, img_mod2.gate));
// calculate the txt bloks
txt = ggml_add(ctx, txt, ggml_mul(ctx, txt_attn->post_attention(ctx, txt_attn_out), txt_mod1.gate));
auto txt_mlp_out = txt_mlp_0->forward(ctx, Flux::modulate(ctx, txt_norm2->forward(ctx, txt), txt_mod2.shift, txt_mod2.scale));
txt_mlp_out = ggml_gelu_inplace(ctx, txt_mlp_out);
txt_mlp_out = txt_mlp_2->forward(ctx, txt_mlp_out);
txt = ggml_add(ctx, txt, ggml_mul(ctx, txt_mlp_out, txt_mod2.gate));
return {img, txt};
}
};
struct SingleStreamBlock : public GGMLBlock {
public:
int64_t num_heads;
int64_t hidden_size;
int64_t mlp_hidden_dim;
public:
SingleStreamBlock(int64_t hidden_size,
int64_t num_heads,
float mlp_ratio = 4.0f,
float qk_scale = 0.f)
: hidden_size(hidden_size), num_heads(num_heads) {
int64_t head_dim = hidden_size / num_heads;
float scale = qk_scale;
if (scale <= 0.f) {
scale = 1 / sqrt((float)head_dim);
}
mlp_hidden_dim = hidden_size * mlp_ratio;
blocks["linear1"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, hidden_size * 3 + mlp_hidden_dim));
blocks["linear2"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size + mlp_hidden_dim, hidden_size));
blocks["norm"] = std::shared_ptr<GGMLBlock>(new QKNorm(head_dim));
blocks["pre_norm"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-6f, false));
// mlp_act is nn.GELU(approximate="tanh")
blocks["modulation"] = std::shared_ptr<GGMLBlock>(new Modulation(hidden_size, false));
}
struct ggml_tensor* forward(struct ggml_context* ctx,
struct ggml_tensor* x,
struct ggml_tensor* vec,
struct ggml_tensor* pe) {
// x: [N, n_token, hidden_size]
// pe: [n_token, d_head/2, 2, 2]
// return: [N, n_token, hidden_size]
auto linear1 = std::dynamic_pointer_cast<Linear>(blocks["linear1"]);
auto linear2 = std::dynamic_pointer_cast<Linear>(blocks["linear2"]);
auto norm = std::dynamic_pointer_cast<QKNorm>(blocks["norm"]);
auto pre_norm = std::dynamic_pointer_cast<LayerNorm>(blocks["pre_norm"]);
auto modulation = std::dynamic_pointer_cast<Modulation>(blocks["modulation"]);
auto mods = modulation->forward(ctx, vec);
ModulationOut mod = mods[0];
auto x_mod = Flux::modulate(ctx, pre_norm->forward(ctx, x), mod.shift, mod.scale);
auto qkv_mlp = linear1->forward(ctx, x_mod); // [N, n_token, hidden_size * 3 + mlp_hidden_dim]
qkv_mlp = ggml_cont(ctx, ggml_permute(ctx, qkv_mlp, 2, 0, 1, 3)); // [hidden_size * 3 + mlp_hidden_dim, N, n_token]
auto qkv = ggml_view_3d(ctx,
qkv_mlp,
qkv_mlp->ne[0],
qkv_mlp->ne[1],
hidden_size * 3,
qkv_mlp->nb[1],
qkv_mlp->nb[2],
0); // [hidden_size * 3 , N, n_token]
qkv = ggml_cont(ctx, ggml_permute(ctx, qkv, 1, 2, 0, 3)); // [N, n_token, hidden_size * 3]
auto mlp = ggml_view_3d(ctx,
qkv_mlp,
qkv_mlp->ne[0],
qkv_mlp->ne[1],
mlp_hidden_dim,
qkv_mlp->nb[1],
qkv_mlp->nb[2],
qkv_mlp->nb[2] * hidden_size * 3); // [mlp_hidden_dim , N, n_token]
mlp = ggml_cont(ctx, ggml_permute(ctx, mlp, 1, 2, 0, 3)); // [N, n_token, mlp_hidden_dim]
auto qkv_vec = split_qkv(ctx, qkv); // q,k,v: [N, n_token, hidden_size]
int64_t head_dim = hidden_size / num_heads;
auto q = ggml_reshape_4d(ctx, qkv_vec[0], head_dim, num_heads, qkv_vec[0]->ne[1], qkv_vec[0]->ne[2]); // [N, n_token, n_head, d_head]
auto k = ggml_reshape_4d(ctx, qkv_vec[1], head_dim, num_heads, qkv_vec[1]->ne[1], qkv_vec[1]->ne[2]); // [N, n_token, n_head, d_head]
auto v = ggml_reshape_4d(ctx, qkv_vec[2], head_dim, num_heads, qkv_vec[2]->ne[1], qkv_vec[2]->ne[2]); // [N, n_token, n_head, d_head]
q = norm->query_norm(ctx, q);
k = norm->key_norm(ctx, k);
auto attn = attention(ctx, q, k, v, pe); // [N, n_token, hidden_size]
auto attn_mlp = ggml_concat(ctx, attn, ggml_gelu_inplace(ctx, mlp), 0); // [N, n_token, hidden_size + mlp_hidden_dim]
auto output = linear2->forward(ctx, attn_mlp); // [N, n_token, hidden_size]
output = ggml_add(ctx, x, ggml_mul(ctx, output, mod.gate));
return output;
}
};
struct LastLayer : public GGMLBlock {
public:
LastLayer(int64_t hidden_size,
int64_t patch_size,
int64_t out_channels) {
blocks["norm_final"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-06f, false));
blocks["linear"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, patch_size * patch_size * out_channels));
blocks["adaLN_modulation.1"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, 2 * hidden_size));
}
struct ggml_tensor* forward(struct ggml_context* ctx,
struct ggml_tensor* x,
struct ggml_tensor* c) {
// x: [N, n_token, hidden_size]
// c: [N, hidden_size]
// return: [N, n_token, patch_size * patch_size * out_channels]
auto norm_final = std::dynamic_pointer_cast<LayerNorm>(blocks["norm_final"]);
auto linear = std::dynamic_pointer_cast<Linear>(blocks["linear"]);
auto adaLN_modulation_1 = std::dynamic_pointer_cast<Linear>(blocks["adaLN_modulation.1"]);
auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx, c)); // [N, 2 * hidden_size]
m = ggml_reshape_3d(ctx, m, c->ne[0], 2, c->ne[1]); // [N, 2, hidden_size]
m = ggml_cont(ctx, ggml_permute(ctx, m, 0, 2, 1, 3)); // [2, N, hidden_size]
int64_t offset = m->nb[1] * m->ne[1];
auto shift = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 0); // [N, hidden_size]
auto scale = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 1); // [N, hidden_size]
x = Flux::modulate(ctx, norm_final->forward(ctx, x), shift, scale);
x = linear->forward(ctx, x);
return x;
}
};
struct FluxParams {
int64_t in_channels = 64;
int64_t vec_in_dim = 768;
int64_t context_in_dim = 4096;
int64_t hidden_size = 3072;
float mlp_ratio = 4.0f;
int64_t num_heads = 24;
int64_t depth = 19;
int64_t depth_single_blocks = 38;
std::vector<int> axes_dim = {16, 56, 56};
int64_t axes_dim_sum = 128;
int theta = 10000;
bool qkv_bias = true;
bool guidance_embed = true;
};
struct Flux : public GGMLBlock {
public:
std::vector<float> linspace(float start, float end, int num) {
std::vector<float> result(num);
float step = (end - start) / (num - 1);
for (int i = 0; i < num; ++i) {
result[i] = start + i * step;
}
return result;
}
std::vector<std::vector<float>> transpose(const std::vector<std::vector<float>>& mat) {
int rows = mat.size();
int cols = mat[0].size();
std::vector<std::vector<float>> transposed(cols, std::vector<float>(rows));
for (int i = 0; i < rows; ++i) {
for (int j = 0; j < cols; ++j) {
transposed[j][i] = mat[i][j];
}
}
return transposed;
}
std::vector<float> flatten(const std::vector<std::vector<float>>& vec) {
std::vector<float> flat_vec;
for (const auto& sub_vec : vec) {
flat_vec.insert(flat_vec.end(), sub_vec.begin(), sub_vec.end());
}
return flat_vec;
}
std::vector<std::vector<float>> rope(const std::vector<float>& pos, int dim, int theta) {
assert(dim % 2 == 0);
int half_dim = dim / 2;
std::vector<float> scale = linspace(0, (dim * 1.0f - 2) / dim, half_dim);
std::vector<float> omega(half_dim);
for (int i = 0; i < half_dim; ++i) {
omega[i] = 1.0 / std::pow(theta, scale[i]);
}
int pos_size = pos.size();
std::vector<std::vector<float>> out(pos_size, std::vector<float>(half_dim));
for (int i = 0; i < pos_size; ++i) {
for (int j = 0; j < half_dim; ++j) {
out[i][j] = pos[i] * omega[j];
}
}
std::vector<std::vector<float>> result(pos_size, std::vector<float>(half_dim * 4));
for (int i = 0; i < pos_size; ++i) {
for (int j = 0; j < half_dim; ++j) {
result[i][4 * j] = std::cos(out[i][j]);
result[i][4 * j + 1] = -std::sin(out[i][j]);
result[i][4 * j + 2] = std::sin(out[i][j]);
result[i][4 * j + 3] = std::cos(out[i][j]);
}
}
return result;
}
// Generate IDs for image patches and text
std::vector<std::vector<float>> gen_ids(int h, int w, int patch_size, int bs, int context_len) {
int h_len = (h + (patch_size / 2)) / patch_size;
int w_len = (w + (patch_size / 2)) / patch_size;
std::vector<std::vector<float>> img_ids(h_len * w_len, std::vector<float>(3, 0.0));
std::vector<float> row_ids = linspace(0, h_len - 1, h_len);
std::vector<float> col_ids = linspace(0, w_len - 1, w_len);
for (int i = 0; i < h_len; ++i) {
for (int j = 0; j < w_len; ++j) {
img_ids[i * w_len + j][1] = row_ids[i];
img_ids[i * w_len + j][2] = col_ids[j];
}
}
std::vector<std::vector<float>> img_ids_repeated(bs * img_ids.size(), std::vector<float>(3));
for (int i = 0; i < bs; ++i) {
for (int j = 0; j < img_ids.size(); ++j) {
img_ids_repeated[i * img_ids.size() + j] = img_ids[j];
}
}
std::vector<std::vector<float>> txt_ids(bs * context_len, std::vector<float>(3, 0.0));
std::vector<std::vector<float>> ids(bs * (context_len + img_ids.size()), std::vector<float>(3));
for (int i = 0; i < bs; ++i) {
for (int j = 0; j < context_len; ++j) {
ids[i * (context_len + img_ids.size()) + j] = txt_ids[j];
}
for (int j = 0; j < img_ids.size(); ++j) {
ids[i * (context_len + img_ids.size()) + context_len + j] = img_ids_repeated[i * img_ids.size() + j];
}
}
return ids;
}
// Generate positional embeddings
std::vector<float> gen_pe(int h, int w, int patch_size, int bs, int context_len, int theta, const std::vector<int>& axes_dim) {
std::vector<std::vector<float>> ids = gen_ids(h, w, patch_size, bs, context_len);
std::vector<std::vector<float>> trans_ids = transpose(ids);
size_t pos_len = ids.size();
int num_axes = axes_dim.size();
for (int i = 0; i < pos_len; i++) {
// std::cout << trans_ids[0][i] << " " << trans_ids[1][i] << " " << trans_ids[2][i] << std::endl;
}
int emb_dim = 0;
for (int d : axes_dim)
emb_dim += d / 2;
std::vector<std::vector<float>> emb(bs * pos_len, std::vector<float>(emb_dim * 2 * 2, 0.0));
int offset = 0;
for (int i = 0; i < num_axes; ++i) {
std::vector<std::vector<float>> rope_emb = rope(trans_ids[i], axes_dim[i], theta); // [bs*pos_len, axes_dim[i]/2 * 2 * 2]
for (int b = 0; b < bs; ++b) {
for (int j = 0; j < pos_len; ++j) {
for (int k = 0; k < rope_emb[0].size(); ++k) {
emb[b * pos_len + j][offset + k] = rope_emb[j][k];
}
}
}
offset += rope_emb[0].size();
}
return flatten(emb);
}
public:
FluxParams params;
Flux() {}
Flux(FluxParams params)
: params(params) {
int64_t out_channels = params.in_channels;
int64_t pe_dim = params.hidden_size / params.num_heads;
blocks["img_in"] = std::shared_ptr<GGMLBlock>(new Linear(params.in_channels, params.hidden_size, true));
blocks["time_in"] = std::shared_ptr<GGMLBlock>(new MLPEmbedder(256, params.hidden_size));
blocks["vector_in"] = std::shared_ptr<GGMLBlock>(new MLPEmbedder(params.vec_in_dim, params.hidden_size));
if (params.guidance_embed) {
blocks["guidance_in"] = std::shared_ptr<GGMLBlock>(new MLPEmbedder(256, params.hidden_size));
}
blocks["txt_in"] = std::shared_ptr<GGMLBlock>(new Linear(params.context_in_dim, params.hidden_size, true));
for (int i = 0; i < params.depth; i++) {
blocks["double_blocks." + std::to_string(i)] = std::shared_ptr<GGMLBlock>(new DoubleStreamBlock(params.hidden_size,
params.num_heads,
params.mlp_ratio,
params.qkv_bias));
}
for (int i = 0; i < params.depth_single_blocks; i++) {
blocks["single_blocks." + std::to_string(i)] = std::shared_ptr<GGMLBlock>(new SingleStreamBlock(params.hidden_size,
params.num_heads,
params.mlp_ratio));
}
blocks["final_layer"] = std::shared_ptr<GGMLBlock>(new LastLayer(params.hidden_size, 1, out_channels));
}
struct ggml_tensor* patchify(struct ggml_context* ctx,
struct ggml_tensor* x,
int64_t patch_size) {
// x: [N, C, H, W]
// return: [N, h*w, C * patch_size * patch_size]
int64_t N = x->ne[3];
int64_t C = x->ne[2];
int64_t H = x->ne[1];
int64_t W = x->ne[0];
int64_t p = patch_size;
int64_t h = H / patch_size;
int64_t w = W / patch_size;
GGML_ASSERT(h * p == H && w * p == W);
x = ggml_reshape_4d(ctx, x, p, w, p, h * C * N); // [N*C*h, p, w, p]
x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // [N*C*h, w, p, p]
x = ggml_reshape_4d(ctx, x, p * p, w * h, C, N); // [N, C, h*w, p*p]
x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // [N, h*w, C, p*p]
x = ggml_reshape_3d(ctx, x, p * p * C, w * h, N); // [N, h*w, C*p*p]
return x;
}
struct ggml_tensor* unpatchify(struct ggml_context* ctx,
struct ggml_tensor* x,
int64_t h,
int64_t w,
int64_t patch_size) {
// x: [N, h*w, C*patch_size*patch_size]
// return: [N, C, H, W]
int64_t N = x->ne[2];
int64_t C = x->ne[0] / patch_size / patch_size;
int64_t H = h * patch_size;
int64_t W = w * patch_size;
int64_t p = patch_size;
GGML_ASSERT(C * p * p == x->ne[0]);
x = ggml_reshape_4d(ctx, x, p * p, C, w * h, N); // [N, h*w, C, p*p]
x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // [N, C, h*w, p*p]
x = ggml_reshape_4d(ctx, x, p, p, w, h * C * N); // [N*C*h, w, p, p]
x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // [N*C*h, p, w, p]
x = ggml_reshape_4d(ctx, x, W, H, C, N); // [N, C, h*p, w*p]
return x;
}
struct ggml_tensor* forward_orig(struct ggml_context* ctx,
struct ggml_tensor* img,
struct ggml_tensor* txt,
struct ggml_tensor* timesteps,
struct ggml_tensor* y,
struct ggml_tensor* guidance,
struct ggml_tensor* pe) {
auto img_in = std::dynamic_pointer_cast<Linear>(blocks["img_in"]);
auto time_in = std::dynamic_pointer_cast<MLPEmbedder>(blocks["time_in"]);
auto vector_in = std::dynamic_pointer_cast<MLPEmbedder>(blocks["vector_in"]);
auto txt_in = std::dynamic_pointer_cast<Linear>(blocks["txt_in"]);
auto final_layer = std::dynamic_pointer_cast<LastLayer>(blocks["final_layer"]);
img = img_in->forward(ctx, img);
auto vec = time_in->forward(ctx, ggml_nn_timestep_embedding(ctx, timesteps, 256, 10000, 1000.f));
if (params.guidance_embed) {
GGML_ASSERT(guidance != NULL);
auto guidance_in = std::dynamic_pointer_cast<MLPEmbedder>(blocks["guidance_in"]);
// bf16 and fp16 result is different
auto g_in = ggml_nn_timestep_embedding(ctx, guidance, 256, 10000, 1000.f);
vec = ggml_add(ctx, vec, guidance_in->forward(ctx, g_in));
}
vec = ggml_add(ctx, vec, vector_in->forward(ctx, y));
txt = txt_in->forward(ctx, txt);
for (int i = 0; i < params.depth; i++) {
auto block = std::dynamic_pointer_cast<DoubleStreamBlock>(blocks["double_blocks." + std::to_string(i)]);
auto img_txt = block->forward(ctx, img, txt, vec, pe);
img = img_txt.first; // [N, n_img_token, hidden_size]
txt = img_txt.second; // [N, n_txt_token, hidden_size]
}
auto txt_img = ggml_concat(ctx, txt, img, 1); // [N, n_txt_token + n_img_token, hidden_size]
for (int i = 0; i < params.depth_single_blocks; i++) {
auto block = std::dynamic_pointer_cast<SingleStreamBlock>(blocks["single_blocks." + std::to_string(i)]);
txt_img = block->forward(ctx, txt_img, vec, pe);
}
txt_img = ggml_cont(ctx, ggml_permute(ctx, txt_img, 0, 2, 1, 3)); // [n_txt_token + n_img_token, N, hidden_size]
img = ggml_view_3d(ctx,
txt_img,
txt_img->ne[0],
txt_img->ne[1],
img->ne[1],
txt_img->nb[1],
txt_img->nb[2],
txt_img->nb[2] * txt->ne[1]); // [n_img_token, N, hidden_size]
img = ggml_cont(ctx, ggml_permute(ctx, img, 0, 2, 1, 3)); // [N, n_img_token, hidden_size]
img = final_layer->forward(ctx, img, vec); // (N, T, patch_size ** 2 * out_channels)
return img;
}
struct ggml_tensor* forward(struct ggml_context* ctx,
struct ggml_tensor* x,
struct ggml_tensor* timestep,
struct ggml_tensor* context,
struct ggml_tensor* y,
struct ggml_tensor* guidance,
struct ggml_tensor* pe) {
// Forward pass of DiT.
// x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
// timestep: (N,) tensor of diffusion timesteps
// context: (N, L, D)
// y: (N, adm_in_channels) tensor of class labels
// guidance: (N,)
// pe: (L, d_head/2, 2, 2)
// return: (N, C, H, W)
GGML_ASSERT(x->ne[3] == 1);
int64_t W = x->ne[0];
int64_t H = x->ne[1];
int64_t patch_size = 2;
int pad_h = (patch_size - H % patch_size) % patch_size;
int pad_w = (patch_size - W % patch_size) % patch_size;
x = ggml_pad(ctx, x, pad_w, pad_h, 0, 0); // [N, C, H + pad_h, W + pad_w]
// img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
auto img = patchify(ctx, x, patch_size); // [N, h*w, C * patch_size * patch_size]
auto out = forward_orig(ctx, img, context, timestep, y, guidance, pe); // [N, h*w, C * patch_size * patch_size]
// rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)
out = unpatchify(ctx, out, (H + pad_h) / patch_size, (W + pad_w) / patch_size, patch_size); // [N, C, H + pad_h, W + pad_w]
return out;
}
};
struct FluxRunner : public GGMLRunner {
public:
FluxParams flux_params;
Flux flux;
std::vector<float> pe_vec; // for cache
FluxRunner(ggml_backend_t backend,
ggml_type wtype,
SDVersion version = VERSION_FLUX_DEV)
: GGMLRunner(backend, wtype) {
if (version == VERSION_FLUX_SCHNELL) {
flux_params.guidance_embed = false;
}
flux = Flux(flux_params);
flux.init(params_ctx, wtype);
}
std::string get_desc() {
return "flux";
}
void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) {
flux.get_param_tensors(tensors, prefix);
}
struct ggml_cgraph* build_graph(struct ggml_tensor* x,
struct ggml_tensor* timesteps,
struct ggml_tensor* context,
struct ggml_tensor* y,
struct ggml_tensor* guidance) {
GGML_ASSERT(x->ne[3] == 1);
struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, FLUX_GRAPH_SIZE, false);
x = to_backend(x);
context = to_backend(context);
y = to_backend(y);
timesteps = to_backend(timesteps);
if (flux_params.guidance_embed) {
guidance = to_backend(guidance);
}
pe_vec = flux.gen_pe(x->ne[1], x->ne[0], 2, x->ne[3], context->ne[1], flux_params.theta, flux_params.axes_dim);
int pos_len = pe_vec.size() / flux_params.axes_dim_sum / 2;
// LOG_DEBUG("pos_len %d", pos_len);
auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, flux_params.axes_dim_sum / 2, pos_len);
// pe->data = pe_vec.data();
// print_ggml_tensor(pe);
// pe->data = NULL;
set_backend_tensor_data(pe, pe_vec.data());
struct ggml_tensor* out = flux.forward(compute_ctx,
x,
timesteps,
context,
y,
guidance,
pe);
ggml_build_forward_expand(gf, out);
return gf;
}
void compute(int n_threads,
struct ggml_tensor* x,
struct ggml_tensor* timesteps,
struct ggml_tensor* context,
struct ggml_tensor* y,
struct ggml_tensor* guidance,
struct ggml_tensor** output = NULL,
struct ggml_context* output_ctx = NULL) {
// x: [N, in_channels, h, w]
// timesteps: [N, ]
// context: [N, max_position, hidden_size]
// y: [N, adm_in_channels] or [1, adm_in_channels]
// guidance: [N, ]
auto get_graph = [&]() -> struct ggml_cgraph* {
return build_graph(x, timesteps, context, y, guidance);
};
GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
}
void test() {
struct ggml_init_params params;
params.mem_size = static_cast<size_t>(20 * 1024 * 1024); // 20 MB
params.mem_buffer = NULL;
params.no_alloc = false;
struct ggml_context* work_ctx = ggml_init(params);
GGML_ASSERT(work_ctx != NULL);
{
// cpu f16:
// cuda f16: nan
// cuda q8_0: pass
auto x = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, 16, 16, 16, 1);
ggml_set_f32(x, 0.01f);
// print_ggml_tensor(x);
std::vector<float> timesteps_vec(1, 999.f);
auto timesteps = vector_to_ggml_tensor(work_ctx, timesteps_vec);
std::vector<float> guidance_vec(1, 3.5f);
auto guidance = vector_to_ggml_tensor(work_ctx, guidance_vec);
auto context = ggml_new_tensor_3d(work_ctx, GGML_TYPE_F32, 4096, 256, 1);
ggml_set_f32(context, 0.01f);
// print_ggml_tensor(context);
auto y = ggml_new_tensor_2d(work_ctx, GGML_TYPE_F32, 768, 1);
ggml_set_f32(y, 0.01f);
// print_ggml_tensor(y);
struct ggml_tensor* out = NULL;
int t0 = ggml_time_ms();
compute(8, x, timesteps, context, y, guidance, &out, work_ctx);
int t1 = ggml_time_ms();
print_ggml_tensor(out);
LOG_DEBUG("flux test done in %dms", t1 - t0);
}
}
static void load_from_file_and_test(const std::string& file_path) {
// ggml_backend_t backend = ggml_backend_cuda_init(0);
ggml_backend_t backend = ggml_backend_cpu_init();
ggml_type model_data_type = GGML_TYPE_Q8_0;
std::shared_ptr<FluxRunner> flux = std::shared_ptr<FluxRunner>(new FluxRunner(backend, model_data_type));
{
LOG_INFO("loading from '%s'", file_path.c_str());
flux->alloc_params_buffer();
std::map<std::string, ggml_tensor*> tensors;
flux->get_param_tensors(tensors, "model.diffusion_model");
ModelLoader model_loader;
if (!model_loader.init_from_file(file_path, "model.diffusion_model.")) {
LOG_ERROR("init model loader from file failed: '%s'", file_path.c_str());
return;
}
bool success = model_loader.load_tensors(tensors, backend);
if (!success) {
LOG_ERROR("load tensors from model loader failed");
return;
}
LOG_INFO("flux model loaded");
}
flux->test();
}
};
} // namespace Flux
#endif // __FLUX_HPP__