Skip to content

Commit

Permalink
FLUX dev and lite support
Browse files Browse the repository at this point in the history
  • Loading branch information
likholat committed Dec 6, 2024
1 parent ee91fcf commit ed2d70a
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@ class OPENVINO_GENAI_EXPORTS FluxTransformer2DModel {
public:
struct Config {
size_t in_channels = 64;
bool guidance_embeds = false;

size_t m_default_sample_size = 128;
std::vector<std::string> m_model_input_names;

explicit Config(const std::filesystem::path& config_path);
};
Expand Down
8 changes: 8 additions & 0 deletions src/cpp/src/image_generation/flux_pipeline.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,14 @@ class FluxPipeline : public DiffusionPipeline {

ov::Tensor latent_image_ids = prepare_latent_image_ids(generation_config.num_images_per_prompt, height / 2, width / 2);

// add guidance tensor if input exist
auto input_names = m_transformer->get_config().m_model_input_names;
if (std::find(input_names.begin(), input_names.end(), "guidance") != input_names.end()) {
ov::Tensor guidance = ov::Tensor(ov::element::f32, {generation_config.num_images_per_prompt});
std::fill_n(guidance.data<float>(), guidance.get_size(), static_cast<float>(m_generation_config.guidance_scale));
m_transformer->set_hidden_states("guidance", guidance);
}

m_transformer->set_hidden_states("pooled_projections", pooled_prompt_embeds);
m_transformer->set_hidden_states("encoder_hidden_states", prompt_embeds);
m_transformer->set_hidden_states("txt_ids", text_ids);
Expand Down
12 changes: 12 additions & 0 deletions src/cpp/src/image_generation/models/flux_transformer_2d_model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,14 @@
#include "json_utils.hpp"
#include "utils.hpp"

namespace {
void get_input_names(std::vector<std::string>& input_names, const std::vector<ov::Output<const ov::Node>>& inputs_info) {
for (const auto& port : inputs_info) {
input_names.push_back(port.get_any_name());
}
}
}

namespace ov {
namespace genai {

Expand All @@ -21,6 +29,7 @@ FluxTransformer2DModel::Config::Config(const std::filesystem::path& config_path)
using utils::read_json_param;

read_json_param(data, "in_channels", in_channels);
read_json_param(data, "guidance_embeds", guidance_embeds);
file.close();
}

Expand Down Expand Up @@ -95,6 +104,8 @@ FluxTransformer2DModel& FluxTransformer2DModel::reshape(int batch_size,
name_to_shape[input_name] = {height * width / 4, name_to_shape[input_name][1]};
} else if (input_name == "txt_ids") {
name_to_shape[input_name] = {tokenizer_model_max_length, name_to_shape[input_name][1]};
} else if (input_name == "guidance") {
name_to_shape[input_name] = {batch_size};
}
}

Expand All @@ -106,6 +117,7 @@ FluxTransformer2DModel& FluxTransformer2DModel::reshape(int batch_size,
FluxTransformer2DModel& FluxTransformer2DModel::compile(const std::string& device, const ov::AnyMap& properties) {
OPENVINO_ASSERT(m_model, "Model has been already compiled. Cannot re-compile already compiled model");
ov::CompiledModel compiled_model = utils::singleton_core().compile_model(m_model, device, properties);
get_input_names(m_config.m_model_input_names, compiled_model.inputs());
m_request = compiled_model.create_infer_request();
// release the original model
m_model.reset();
Expand Down
2 changes: 2 additions & 0 deletions src/docs/SUPPORTED_MODELS.md
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,8 @@ The pipeline can work with other similar topologies produced by `optimum-intel`
<td>
<ul>
<li><a href="https://huggingface.co/black-forest-labs/FLUX.1-schnell"><code>black-forest-labs/FLUX.1-schnell</code></a></li>
<li><a href="https://huggingface.co/Freepik/flux.1-lite-8B-alpha"><code>Freepik/flux.1-lite-8B-alpha</code></a></li>
<li><a href="https://huggingface.co/black-forest-labs/FLUX.1-dev"><code>black-forest-labs/FLUX.1-dev</code></a></li>
</ul>
</td>
</tr>
Expand Down

0 comments on commit ed2d70a

Please sign in to comment.