diff --git a/apps/hallmark/CMakeLists.txt b/apps/hallmark/CMakeLists.txt new file mode 100644 index 000000000000..1c17108581c4 --- /dev/null +++ b/apps/hallmark/CMakeLists.txt @@ -0,0 +1,108 @@ +cmake_minimum_required(VERSION 3.22) +project(hallmark) + +# We need to set this for some of the subprojects pulled in by TFLite (eg flatbuffers) +# set(CMAKE_POLICY_DEFAULT_CMP0077 NEW) + +enable_testing() + +# ---------------------------- + +# Compatibility cruft +if (CMAKE_VERSION VERSION_GREATER_EQUAL "3.24.0") + cmake_policy(SET CMP0135 NEW) +endif() + +# ---------------------------- + +# Set up language settings +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED YES) +set(CMAKE_CXX_EXTENSIONS NO) + +# ---------------------------- + +# Find HalideHelpers -- this is just the Runtime headers and CMake functions, but no libraries +find_package(Halide REQUIRED) +find_package(ZLIB REQUIRED) + +# ---------------------------- + +# go get various deps we need + +include(FetchContent) + +set(ABSL_PROPAGATE_CXX_STD ON) +set(ABSL_USE_EXTERNAL_GOOGLETEST ON) +FetchContent_Declare(abseil + GIT_REPOSITORY https://github.com/abseil/abseil-cpp.git + GIT_TAG 20240116.2) + +set(FLATBUFFERS_BUILD_TESTS OFF) +set(FLATBUFFERS_INSTALL OFF) +FetchContent_Declare(flatbuffers + GIT_REPOSITORY https://github.com/google/flatbuffers.git + GIT_TAG v23.5.26 + GIT_SHALLOW TRUE) + +set(BUILD_GMOCK OFF) +set(INSTALL_GTEST OFF) +set(GTEST_HAS_ABSL OFF) +FetchContent_Declare(googletest + GIT_REPOSITORY https://github.com/google/googletest.git + GIT_TAG v1.14.0) + +set(BENCHMARK_ENABLE_TESTING OFF) +set(BENCHMARK_ENABLE_EXCEPTIONS OFF) +set(BENCHMARK_ENABLE_INSTALL OFF) +set(BENCHMARK_INSTALL_DOCS OFF) +set(BENCHMARK_ENABLE_GTEST_TESTS OFF) +FetchContent_Declare(googlebenchmark + GIT_REPOSITORY https://github.com/google/benchmark.git + GIT_TAG v1.8.3) + +set(protobuf_INSTALL OFF) +set(protobuf_BUILD_TESTS OFF) +set(protobuf_BUILD_CONFORMANCE OFF) +set(protobuf_BUILD_EXAMPLES OFF) +set(protobuf_BUILD_PROTOC_BINARIES ON) +set(protobuf_BUILD_LIBPROTOC OFF) +set(protobuf_BUILD_LIBUPB OFF) +set(protobuf_DISABLE_RTTI ON) +set(protobuf_WITH_ZLIB ON CACHE BOOL "" FORCE) +FetchContent_Declare(protobuf + GIT_REPOSITORY https://github.com/protocolbuffers/protobuf.git + GIT_TAG v26.1 + GIT_SHALLOW TRUE) + +# TODO(srj,zvookin): Exact parameters should be double checked. 
+FetchContent_Declare(sentencepiece
+    GIT_REPOSITORY https://github.com/google/sentencepiece
+    GIT_TAG v0.2.0  # Old: 53de76561cfc149d3c01037f0595669ad32a5e7c
+    )
+
+FetchContent_MakeAvailable(abseil flatbuffers googletest googlebenchmark protobuf sentencepiece)
+
+# ---------- Set up targets for flatc
+add_library(hallmark_flatbuffers INTERFACE)
+target_sources(hallmark_flatbuffers INTERFACE $>)
+target_include_directories(hallmark_flatbuffers
+    SYSTEM  # Use -isystem instead of -I; this is a trick so that clang-tidy won't analyze these includes
+    INTERFACE
+    $/include
+    $/include)
+set_target_properties(hallmark_flatbuffers PROPERTIES EXPORT_NAME flatbuffers)
+add_executable(flatbuffers::flatc ALIAS flatc)
+
+# ---------- Set up targets for protobuf
+FetchContent_GetProperties(protobuf SOURCE_DIR protobuf_SOURCE_DIR)
+# Include the script which defines 'protobuf_generate'
+include(${protobuf_SOURCE_DIR}/cmake/protobuf-generate.cmake)
+
+# ----------------------------
+
+add_subdirectory(contrib)
+add_subdirectory(src)
+add_subdirectory(test)
+
+# ----------------------------
diff --git a/apps/hallmark/README.md b/apps/hallmark/README.md
new file mode 100644
index 000000000000..c97a77bac286
--- /dev/null
+++ b/apps/hallmark/README.md
@@ -0,0 +1,17 @@
+Hallmark (HAlide LLM Advanced Research Kit) is a Halide-written execution engine for Gemini/Gemma models;
+it serves as a testbed for writing efficient ML kernels in Halide.
+
+To build with CMake:
+
+- build and install Halide locally to ${HALIDE_INSTALL}
+- cd apps/hallmark
+- mkdir build && cd build
+- cmake .. -DHalide_DIR=${HALIDE_INSTALL}/lib/cmake/Halide -DCMAKE_BUILD_TYPE=Release
+- ninja (or make)
+
+To run the tests:
+- ./build/test/llm_generator_test --model_path=/path/to/model.tflite
+
+To run the benchmarks:
+- ./build/test/llm_generator_bench --model_path=/path/to/model.tflite --benchmark_filter=all
+
diff --git a/apps/hallmark/contrib/CMakeLists.txt b/apps/hallmark/contrib/CMakeLists.txt
new file mode 100644
index 000000000000..187cc4dfc635
--- /dev/null
+++ b/apps/hallmark/contrib/CMakeLists.txt
@@ -0,0 +1,41 @@
+
+# --------------------------- Generate flatbuffer files
+set(tflite_schema_source "${CMAKE_CURRENT_SOURCE_DIR}/tflite_schema.fbs")
+set(tflite_generated_header "${CMAKE_CURRENT_BINARY_DIR}/tflite_schema_generated.h")
+add_custom_command(
+    OUTPUT "${tflite_generated_header}"
+    COMMAND flatbuffers::flatc --cpp --cpp-std C++17 --no-union-value-namespacing --keep-prefix -o "${CMAKE_CURRENT_BINARY_DIR}" "${tflite_schema_source}"
+    DEPENDS "${tflite_schema_source}"
+    VERBATIM
+)
+add_custom_target(generate_tflite_schema_header DEPENDS "${tflite_generated_header}")
+set_source_files_properties("${tflite_generated_header}" PROPERTIES GENERATED TRUE)
+
+# --------------------------- Generate protobuf files
+add_library(proto_objects OBJECT llm_params.proto transformer_params.proto)
+target_link_libraries(proto_objects PUBLIC protobuf::libprotobuf)
+target_include_directories(proto_objects PUBLIC "$")
+
+protobuf_generate(
+    TARGET proto_objects
+    PROTOC_OUT_DIR "${CMAKE_CURRENT_BINARY_DIR}")
+
+
+# --------------------------- Ordinary code
+add_library(hallmark_contrib
+    llm_params.cc
+    sampler.cc
+    weights_loader.cc)
+add_dependencies(hallmark_contrib
+    generate_tflite_schema_header)
+target_include_directories(hallmark_contrib INTERFACE
+    $)
+target_include_directories(hallmark_contrib PRIVATE
+    $
+    $)
+target_link_libraries(hallmark_contrib
+    PRIVATE
+    absl::status
+    Halide::Runtime
+    hallmark_flatbuffers
+    proto_objects)
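For orientation, here is a minimal usage sketch of the library these rules build: it ties together LoadLlmParams() from contrib/llm_params.h and the Sampler from contrib/sampler.h, both added later in this patch. The function name SampleOnce, the sampling hyperparameters, and the caller-provided logits buffer are illustrative assumptions, not code from the patch.

    #include <string>
    #include <vector>

    #include "HalideBuffer.h"
    #include "absl/status/status.h"
    #include "absl/status/statusor.h"
    #include "contrib/llm_params.h"
    #include "contrib/sampler.h"

    // Illustrative sketch only: load model hyperparameters from a .tflite file,
    // then draw one token per batch entry from a logits buffer shaped
    // (vocab_size, 1 [seq_len], batch), the layout Sampler::Sample() expects.
    absl::StatusOr<std::vector<int>> SampleOnce(
        const std::string &model_path, Halide::Runtime::Buffer<float> &logits) {
        auto params = hallmark::LoadLlmParams(model_path);
        if (!params.ok()) {
            return params.status();
        }
        // Sanity-check the logits buffer against the loaded vocabulary size.
        if (logits.dim(0).extent() != static_cast<int>(params->voc_size_V)) {
            return absl::InvalidArgumentError("logits vocab size mismatch");
        }
        // The sampling hyperparameters below are placeholders, not values taken
        // from this patch.
        auto sampler = hallmark::Sampler::Create(hallmark::Sampler::Type::kTopK,
                                                 /*top_k=*/40, /*top_p=*/0.95f,
                                                 /*temperature=*/0.8f, /*seed=*/1234);
        if (!sampler.ok()) {
            return sampler.status();
        }
        return (*sampler)->Sample(logits);
    }

Top-k of 40 with temperature 0.8 is merely a plausible default; a kGreedy sampler is used the same way and ignores the other arguments.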
diff --git a/apps/hallmark/contrib/llm_params.cc b/apps/hallmark/contrib/llm_params.cc new file mode 100644 index 000000000000..c7b8bab57b82 --- /dev/null +++ b/apps/hallmark/contrib/llm_params.cc @@ -0,0 +1,194 @@ +#include "contrib/llm_params.h" + +#include "contrib/memory_mapped_file.h" +#include "contrib/status_helpers.h" +#include "contrib/llm_params.pb.h" +#include "contrib/transformer_params.pb.h" +// macOS system headers #define this value in syslimits.h +#undef ARG_MAX +#include "contrib/tflite_schema_generated.h" + +namespace hallmark { + +namespace { + +using odml::infra::proto::LlmParameters; +using odml::infra::proto::TransformerParameters; + +const ::tflite::Metadata *FindMetadata(const ::tflite::Model *tflite_model, + std::string name) { + if (tflite_model->metadata() == nullptr) { + return nullptr; + } + + for (const auto *metadata : *tflite_model->metadata()) { + if (name == metadata->name()->c_str()) { + return metadata; + } + } + return nullptr; +} + +LlmParams::Norm TransformerParametersProtoNormTypeToLlmParamsNormType( + TransformerParameters::Norm norm_type) { + switch (norm_type) { + case TransformerParameters::NORM_UNSPECIFIED: + ABSL_LOG(DFATAL) << "Unspecified norm type."; + return LlmParams::Norm::UNSPECIFIED; + case TransformerParameters::NO_NORM: + return LlmParams::Norm::NO_NORM; + case TransformerParameters::RMS_NORM: + return LlmParams::Norm::RMS_NORM; + case TransformerParameters::LAYER_NORM: + return LlmParams::Norm::LAYER_NORM; + default: + ABSL_LOG(DFATAL) << "Unknown norm type: " << norm_type; + } + return LlmParams::Norm::UNSPECIFIED; +} + +LlmParams FromLLMParametersProto(const LlmParameters &llm_params) { + const auto &transformer_params = llm_params.transformer_parameters(); + LlmParams params = { + .num_transformer_M = static_cast(transformer_params.num_stacks()), + .batch_size_B = static_cast(transformer_params.batch_size()), + .seq_size_T = static_cast(transformer_params.max_seq_length()), + .model_dim_D = static_cast(transformer_params.embedding_dim()), + .hidden_dim_HD = + static_cast(transformer_params.hidden_dimension()), + .head_dim_H = static_cast(transformer_params.head_dimension()), + .n_heads_N = static_cast(transformer_params.num_heads()), + .voc_size_V = static_cast(llm_params.vocab_size()), + + .num_kv_heads = + static_cast(transformer_params.num_kv_heads() == 0 ? 
transformer_params.num_heads() : transformer_params.num_kv_heads()), + .enable_kv_cache = true, + .enable_dynamic_shape = false}; + switch ( + transformer_params.self_attention_parameters().attention_mask_type()) { + case TransformerParameters::UNSPECIFIED: + ABSL_LOG(DFATAL) << "Unspecified attention_mask_type, assuming causal"; + params.model_type = LlmParams::ModelType::UNSPECIFIED; + break; + case TransformerParameters::CAUSAL: + params.model_type = LlmParams::ModelType::CAUSAL; + break; + case TransformerParameters::PREFIX: + params.model_type = LlmParams::ModelType::PREFIX; + break; + default: + ABSL_LOG(DFATAL) << "Unknown attention_mask_type: " + << transformer_params.self_attention_parameters() + .attention_mask_type(); + } + params.ff_params = LlmParams::FeedForwardParams{ + .no_bias = transformer_params.feed_forward_parameters().no_bias(), + }; + params.final_proj_params = LlmParams::FinalProjectParams{ + .no_bias = transformer_params.final_project_parameters().no_bias(), + }; + switch (transformer_params.feed_forward_parameters().activation()) { + case TransformerParameters::ACTIVATION_UNSPECIFIED: + ABSL_LOG(DFATAL) << "Unspecified feed_forward_parameters.activation."; + params.ff_params.activation = LlmParams::Activation::UNSPECIFIED; + break; + case TransformerParameters::GELU: + params.ff_params.activation = LlmParams::Activation::GELU; + break; + case TransformerParameters::SILU: + params.ff_params.activation = LlmParams::Activation::SILU; + break; + case TransformerParameters::RELU: + params.ff_params.activation = LlmParams::Activation::RELU; + break; + default: + ABSL_LOG(DFATAL) + << "Unknown feed_forward_parameters.activation: " + << transformer_params.feed_forward_parameters().activation(); + } + params.sa_params.qkv_no_bias = + transformer_params.self_attention_parameters().qkv_no_bias(); + params.sa_params.post_proj_no_bias = + transformer_params.self_attention_parameters().post_proj_no_bias(); + params.sa_params.pre_norm = + TransformerParametersProtoNormTypeToLlmParamsNormType( + transformer_params.pre_norm()); + params.sa_params.post_norm = + TransformerParametersProtoNormTypeToLlmParamsNormType( + transformer_params.post_norm()); + params.sa_params.soft_cap_value = + transformer_params.self_attention_parameters().soft_cap_value(); + params.ff_params.pre_norm = + TransformerParametersProtoNormTypeToLlmParamsNormType( + transformer_params.feed_forward_parameters().pre_norm()); + params.ff_params.post_norm = + TransformerParametersProtoNormTypeToLlmParamsNormType( + transformer_params.feed_forward_parameters().post_norm()); + params.final_norm = TransformerParametersProtoNormTypeToLlmParamsNormType( + transformer_params.final_norm()); + params.skip_absolute_positional_embeddings = + transformer_params.skip_absolute_positional_embeddings(); + if (transformer_params.self_attention_parameters() + .has_attention_scale_type()) { + switch ( + transformer_params.self_attention_parameters().attention_scale_type()) { + case TransformerParameters::SCALE_TYPE_UNSPECIFIED: + ABSL_LOG(DFATAL) << "Unspecified attention_scale_type."; + params.sa_params.attention_scale_type = + LlmParams::AttentionScaleType::UNSPECIFIED; + break; + case TransformerParameters::SCALE_TYPE_PER_DIM_SCALE: + params.sa_params.attention_scale_type = + LlmParams::AttentionScaleType::PER_DIM_SCALE; + break; + case TransformerParameters::SCALE_TYPE_INV_SQRT_HEAD_DIM: + params.sa_params.attention_scale_type = + LlmParams::AttentionScaleType::INV_SQRT_HEAD_DIM; + break; + default: + ABSL_LOG(DFATAL) << 
"Unknown attention_scale_type: " + << transformer_params.self_attention_parameters() + .attention_scale_type(); + } + } else { + if (transformer_params.num_kv_heads() == 0 || + transformer_params.num_heads() == transformer_params.num_kv_heads()) { + // If MHA, PER_DIM_SCALE is used. + params.sa_params.attention_scale_type = + LlmParams::AttentionScaleType::PER_DIM_SCALE; + } else { + // If MQA or GQA, INV_SQRT_HEAD_DIM is used. + params.sa_params.attention_scale_type = + LlmParams::AttentionScaleType::INV_SQRT_HEAD_DIM; + } + } + + return params; +} + +} // namespace + +absl::StatusOr LoadLlmParams(absl::string_view tflite_path) { + MemoryMappedFile file(tflite_path); + if (!file.valid()) { + return absl::InvalidArgumentError("Could not open file for llm_params"); + } + + const ::tflite::Model *tflite_model = ::tflite::GetModel(file.data()); + const auto *metadata = + FindMetadata(tflite_model, "odml.infra.proto.LlmParameters"); + if (!metadata) { + return absl::InvalidArgumentError("No metadata found in model"); + } + + const ::tflite::Buffer *buffer = + tflite_model->buffers()->Get(metadata->buffer()); + const void *base = (const char *)file.data() + buffer->offset(); + const size_t len = buffer->size(); + + LlmParameters llm_parameters; + RET_CHECK(llm_parameters.ParseFromArray(base, len)); + return FromLLMParametersProto(llm_parameters); +} + +} // namespace hallmark diff --git a/apps/hallmark/contrib/llm_params.h b/apps/hallmark/contrib/llm_params.h new file mode 100644 index 000000000000..3e94a342a23a --- /dev/null +++ b/apps/hallmark/contrib/llm_params.h @@ -0,0 +1,112 @@ +#ifndef HALIDE_APPS_HALLMARK_LLM_PARAMS_H_ +#define HALIDE_APPS_HALLMARK_LLM_PARAMS_H_ + +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" + +namespace hallmark { + +struct LlmParams { + size_t num_transformer_M = 0; + size_t batch_size_B = 0; + size_t seq_size_T = 0; + size_t model_dim_D = 0; + size_t hidden_dim_HD = 0; + size_t head_dim_H = 0; + size_t n_heads_N = 0; + size_t voc_size_V = 0; + + // Number of kv heads. In case of Multi-Head-Attention (MHA), num_kv_heads is + // the same as n_heads_N, which is number of query heads; In case of + // Multi-Query-Attention (MQA), key and value have one head; otherwise, this + // specifies the number of heads for key and value, and + // Grouped-Query-Attention (GQA) will be used. See + // https://arxiv.org/pdf/2305.13245.pdf for details. + size_t num_kv_heads = 0; + + // Meant to be a mapping of pax LanguageModelType. This will affect e.g. + // attention mask shape. + enum class ModelType { + UNSPECIFIED = 0, + // Attention mask for input are prefixed to be bidirectional. + PREFIX = 1, + // Attention mask are forward only. + CAUSAL = 2, + } model_type = ModelType::CAUSAL; + + enum class Activation { + UNSPECIFIED = 0, + // Gaussian Error Linear Unit. + GELU = 1, + // Sigmoid-Weighted Linear Unit. + SILU = 2, + // Rectified Linear Unit. + RELU = 3, + }; + + enum class Norm { + UNSPECIFIED = 0, + NO_NORM = 1, + RMS_NORM = 2, + LAYER_NORM = 3, + }; + + enum class AttentionScaleType { + UNSPECIFIED = 0, + // Per dimension scale, query is scaled by log_2(1 + exp(w)) / + // sqrt(head_dim) where w is s static weight. + PER_DIM_SCALE = 1, + // Query is scaled by 1/sqrt(head_dim). + INV_SQRT_HEAD_DIM = 2, + }; + + // If false, add absolute positional embeddings. 
+ bool skip_absolute_positional_embeddings = false; + + struct SelfAttentionParams { + bool qkv_no_bias = false; + bool post_proj_no_bias = false; + Norm pre_norm = Norm::RMS_NORM; + Norm post_norm = Norm::RMS_NORM; + + // If greater than 0, CapTanh will be applied. Otherwise, no cap will be + // applied. + float soft_cap_value = 0.0f; + + // Attention scale type to be applied within the transformer. + AttentionScaleType attention_scale_type; + } sa_params; + + struct FeedForwardParams { + // If `no_bias`, fully connect will degrade to matrix multiply. + bool no_bias = false; + Activation activation = Activation::GELU; + Norm pre_norm = Norm::RMS_NORM; + Norm post_norm = Norm::RMS_NORM; + } ff_params; + + Norm final_norm = Norm::RMS_NORM; + + struct FinalProjectParams { + // If `no_bias`, final fully connect will degrade to matrix multiply. + bool no_bias = false; + } final_proj_params; + + /* + * Parameters below do NOT change the "correctness" of the model, they + * configure the acceleration of inference. + */ + + bool enable_kv_cache = false; + // If true, inference engine will optimize tensor shape according to current + // sequence length to avoid computation waste. + bool enable_dynamic_shape = false; +}; + +absl::StatusOr LoadLlmParams(absl::string_view tflite_path); + +} // namespace hallmark + +#endif // HALIDE_APPS_HALLMARK_LLM_PARAMS_H_ diff --git a/apps/hallmark/contrib/llm_params.proto b/apps/hallmark/contrib/llm_params.proto new file mode 100644 index 000000000000..18989da05527 --- /dev/null +++ b/apps/hallmark/contrib/llm_params.proto @@ -0,0 +1,69 @@ +/* Copyright 2023 The ODML Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto3"; + +package odml.infra.proto; + +import "transformer_params.proto"; + +option java_package = "com.google.odml.infra.proto"; +option java_outer_classname = "LLMParametersProto"; + +// Naming convention is as following: LLM_MODEL_TYPE__. +enum LlmModelType { + // Unknown LLM model type + LLM_MODEL_TYPE_UNKNOWN = 0; + + reserved 1, 2, 3, 4, 7, 9, 10; + + // FALCON RefinedWeb with 1B parameters. + // https://huggingface.co/tiiuae/falcon-rw-1b + LLM_MODEL_TYPE_FALCON_RW_1B = 5; + + // GEMMA with 2B parameters + LLM_MODEL_TYPE_GEMMA_2B = 6; + + // GEMMA with 7B parameters + LLM_MODEL_TYPE_GEMMA_7B = 12; + + // StableLM 4E1T with 3B parameters + LLM_MODEL_TYPE_STABLELM_4E1T_3B = 8; + + // Phi-2 + // https://huggingface.co/microsoft/phi-2 + LLM_MODEL_TYPE_PHI_2 = 11; +} + +// Parameters for Large Language Models (LLM). +message LlmParameters { + TransformerParameters transformer_parameters = 1; + + // Size of vocabulary. + int32 vocab_size = 2; + + // Was used for disable_kv_cache. + reserved 3; + + // Start token prepended to the beginning of input sequence. + oneof start_token_union { + int32 start_token_id = 4; + + string start_token = 6; + } + + // Stop tokens to determine the end of output stream. 
+ repeated string stop_tokens = 5; +} diff --git a/apps/hallmark/contrib/llm_weights.h b/apps/hallmark/contrib/llm_weights.h new file mode 100644 index 000000000000..5f17efa72e7a --- /dev/null +++ b/apps/hallmark/contrib/llm_weights.h @@ -0,0 +1,85 @@ +#ifndef HALIDE_APPS_HALLMARK_LLM_WEIGHTS_H_ +#define HALIDE_APPS_HALLMARK_LLM_WEIGHTS_H_ + +#include +#include +#include +#include + +#include "HalideBuffer.h" + +namespace hallmark { + +// Provides access to data tied to an underlying resource. The resource may be +// released when this object is destroyed. +class DataHolder { +public: + virtual ~DataHolder() = default; +}; + +// If dim_scale >= 0, then `weights` should be scaled by that dimension. +// Otherwise, scale is an empty (unallocated) Buffer. +struct ScaledTensor { + Halide::Runtime::Buffer<> weights, scale; + int dim_scale = -1; +}; + +struct RMSNormWeights { + ScaledTensor norm_weight; +}; + +struct LayerNormWeights { + float epsilon = 1e-5; + ScaledTensor gamma; + ScaledTensor beta; +}; + +struct LlmWeights { + using NormWeights = std::variant; + + struct SelfAttentionWeights { + std::optional pre_norm_weight; + + ScaledTensor k_weight; + ScaledTensor k_bias; + ScaledTensor q_weight; + ScaledTensor q_bias; + ScaledTensor v_weight; + ScaledTensor v_bias; + ScaledTensor per_dim_scale; + ScaledTensor post_proj_weight; + ScaledTensor post_proj_bias; + + std::optional post_norm_weight; + }; + + struct FeedForwardWeights { + std::optional pre_norm_weight; + ScaledTensor layer_1_weight; + ScaledTensor layer_1_bias; + ScaledTensor layer_1_gate_weight; + ScaledTensor layer_1_gate_bias; + ScaledTensor layer_2_weight; + ScaledTensor layer_2_bias; + std::optional post_norm_weight; + }; + + std::vector ffs; + std::vector sas; + std::optional final_norm_weight; + ScaledTensor softmax_linear; + ScaledTensor softmax_bias; + + // Usually same as softmax_linear, but some models use different + // softmax_linear v.s. embedding table. + ScaledTensor token_embedding; + + // TODO: a bit of an ugly hack here; if the weights are loaded from + // a memory-mapped file, this is a shared_ptr to ensure that the mapping + // remains valid for the life of this instance. 
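+    // (DataHolder's virtual destructor is the release hook: for example, a
+    // MemoryMappedFile, defined in contrib/memory_mapped_file.h below, unmaps
+    // its region when destroyed, so keeping a holder alive here keeps weight
+    // buffers that point into the mapping valid.)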
+ std::shared_ptr data_holder; +}; + +} // namespace hallmark + +#endif // HALIDE_APPS_HALLMARK_LLM_WEIGHTS_H_ diff --git a/apps/hallmark/contrib/memory_mapped_file.h b/apps/hallmark/contrib/memory_mapped_file.h new file mode 100644 index 000000000000..361a96b7e47c --- /dev/null +++ b/apps/hallmark/contrib/memory_mapped_file.h @@ -0,0 +1,52 @@ +#ifndef HALIDE_APPS_HALLMARK_MEMORY_MAPPED_FILE_H_ +#define HALIDE_APPS_HALLMARK_MEMORY_MAPPED_FILE_H_ + +#include +#include +#include + +#include "absl/strings/string_view.h" + +namespace hallmark { + +class MemoryMappedFile final { +public: + explicit MemoryMappedFile(absl::string_view path) { + fd_ = open(path.data(), O_RDONLY); + if (fd_ >= 0) { + length_ = lseek(fd_, 0, SEEK_END); + data_ = mmap(nullptr, length_, PROT_READ, MAP_SHARED, fd_, 0); + } else { + length_ = 0; + data_ = nullptr; + } + } + + virtual ~MemoryMappedFile() { + if (data_) { + munmap(data_, length_); + } + if (fd_ >= 0) { + close(fd_); + } + } + + uint64_t length() const { + return length_; + } + void *data() const { + return data_; + } + bool valid() const { + return data_ != nullptr; + } + +private: + int fd_; + uint64_t length_; + void *data_; +}; + +} // namespace hallmark + +#endif // HALIDE_APPS_HALLMARK_MEMORY_MAPPED_FILE_H_ diff --git a/apps/hallmark/contrib/sampler.cc b/apps/hallmark/contrib/sampler.cc new file mode 100644 index 000000000000..66b8fab964ab --- /dev/null +++ b/apps/hallmark/contrib/sampler.cc @@ -0,0 +1,208 @@ +#include "contrib/sampler.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "HalideBuffer.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "contrib/status_helpers.h" + +namespace hallmark { + +absl::StatusOr> Sampler::Create(Type type, int top_k, + float top_p, + float temperature, + int seed) { + if (type == Type::kTopK || type == Type::kTopP) { + if (top_k <= 1) { + return absl::InvalidArgumentError("top_k must be > 1"); + } else if (temperature < 0.0f) { + return absl::InvalidArgumentError("temperature must be >= 0"); + } + if (type == Type::kTopP && (top_p <= 0 || top_p > 1.0)) { + return absl::InvalidArgumentError("top_p must be between 0 and 1"); + } + } + return std::unique_ptr( + new Sampler(type, top_k, top_p, temperature, seed)); +} + +absl::StatusOr> Sampler::Sample( + Halide::Runtime::Buffer &logits) { + if (logits.dimensions() != 3 || logits.dim(1).extent() != 1) { + return absl::InvalidArgumentError( + "Buffer must be (vocab_size, 1 [seq_len], Batch)"); + } + + switch (type_) { + case Type::kGreedy: + return SampleGreedy(logits); + case Type::kTopK: + return SampleTopK(logits); + case Type::kTopP: + return SampleTopP(logits); + default: + return absl::InvalidArgumentError("Unsupported sampler type"); + } +}; + +Sampler::Sampler(Type type, int top_k, float top_p, float temperature, int seed) + : type_(type), + top_k_(top_k), + top_p_(top_p), + temperature_(temperature), + generator_(std::make_unique(seed)) { +} + +absl::StatusOr> Sampler::SampleGreedy( + Halide::Runtime::Buffer &logits) { + const int vocab_size = logits.dim(0).extent(); + const int seq_pos = logits.dim(1).min(); + const int batch_size = logits.dim(2).extent(); + + std::vector outputs; + outputs.reserve(batch_size); + // select the token with the highest logit directly. 
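+    // Logits are indexed as logits(vocab, seq_pos, batch); the loop below is an
+    // argmax over the vocab dimension, computed independently for each batch
+    // entry at the single sequence position present.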
+ for (int c = 0; c < batch_size; ++c) { + float max_logit = logits(0, seq_pos, c); + int max_id = 0; + for (int v = 0; v < vocab_size; ++v) { + const float prob = logits(v, seq_pos, c); + if (prob > max_logit) { + max_logit = prob; + max_id = v; + } + } + outputs.push_back(max_id); + } + return outputs; +}; + +absl::StatusOr> Sampler::SampleTopK( + Halide::Runtime::Buffer &logits) { + const int vocab_size = logits.dim(0).extent(); + const int seq_pos = logits.dim(1).min(); + const int batch_size = logits.dim(2).extent(); + + std::vector outputs; + outputs.reserve(batch_size); + for (int batch = 0; batch < batch_size; ++batch) { + std::vector> logits_ids; + logits_ids.reserve(vocab_size); + for (int v = 0; v < vocab_size; ++v) { + const float logit = logits(v, seq_pos, batch); + logits_ids.push_back(std::make_pair(logit, v)); + } + RETURN_IF_ERROR(SelectTopK(logits_ids, top_k_)); + // No need to normalize logits here, sampler takes care of that. + RETURN_IF_ERROR(ScaledSoftmax(logits_ids, /*normalize=*/false)); + auto sample_idx = DoSampling(logits_ids); + if (!sample_idx.ok()) { + return sample_idx.status(); + } + outputs.push_back(sample_idx.value()); + } + return outputs; +} + +absl::StatusOr> Sampler::SampleTopP( + Halide::Runtime::Buffer &logits) { + const int vocab_size = logits.dim(0).extent(); + const int seq_pos = logits.dim(1).min(); + const int batch_size = logits.dim(2).extent(); + const int k = top_k_ > 0 ? top_k_ : vocab_size; + + std::vector outputs; + outputs.reserve(batch_size); + for (int batch = 0; batch < batch_size; ++batch) { + std::vector> logits_ids; + logits_ids.reserve(vocab_size); + for (int v = 0; v < vocab_size; ++v) { + const float logit = logits(v, seq_pos, batch); + logits_ids.push_back(std::make_pair(logit, v)); + } + RETURN_IF_ERROR(SelectTopK(logits_ids, k)); + RETURN_IF_ERROR(ScaledSoftmax(logits_ids, /*normalize=*/true)); + RETURN_IF_ERROR(SelectTopP(logits_ids, top_p_)); + auto sample_idx = DoSampling(logits_ids); + if (!sample_idx.ok()) { + return sample_idx.status(); + } + outputs.push_back(sample_idx.value()); + } + return outputs; +} + +absl::Status Sampler::SelectTopK(std::vector> &logits_ids, + int k) { + if (k > logits_ids.size()) { + return absl::InvalidArgumentError( + "Top k value must be smaller than the number of logits."); + } + std::partial_sort( + logits_ids.begin(), logits_ids.begin() + k, logits_ids.end(), + [](const std::pair &a, const std::pair &b) { + // reverse order. + return a.first > b.first; + }); + logits_ids.resize(k); + return absl::OkStatus(); +} + +absl::Status Sampler::SelectTopP(std::vector> &logits_ids, + float p) { + int included = 0; + float prob_sum = 0.0; + for (const auto &[logit, _] : logits_ids) { + ++included; + prob_sum += logit; + if (prob_sum >= p) { + break; + } + } + if (included == 0) { + return absl::InternalError("Bad top_p value."); + } + logits_ids.resize(included); + return absl::OkStatus(); +} + +absl::Status Sampler::ScaledSoftmax( + std::vector> &logits_ids, bool normalize) { + float scale = 1 / (temperature_ ? 
temperature_ : 1.0);
+    double sum = 0.0;
+    float max_logit = logits_ids[0].first;
+    for (int i = 0; i < logits_ids.size(); ++i) {
+        const float logit = logits_ids[i].first;
+        const float p = expf(scale * (logit - max_logit));
+        sum += p;
+        logits_ids[i].first = p;
+    }
+    if (normalize) {
+        for (int i = 0; i < logits_ids.size(); ++i) {
+            logits_ids[i].first /= sum;
+        }
+    }
+    return absl::OkStatus();
+}
+
+absl::StatusOr Sampler::DoSampling(
+    std::vector> &logits_ids) {
+    std::vector probs;
+    probs.reserve(logits_ids.size());
+    for (const auto &[logit, _] : logits_ids) {
+        probs.push_back(logit);
+    }
+    // Probabilities are normalized by `discrete_distribution`.
+    std::discrete_distribution<> dist(probs.begin(), probs.end());
+    int sample_idx = dist(*generator_);
+    return logits_ids[sample_idx].second;
+}
+
+} // namespace hallmark
diff --git a/apps/hallmark/contrib/sampler.h b/apps/hallmark/contrib/sampler.h
new file mode 100644
index 000000000000..20fec1868720
--- /dev/null
+++ b/apps/hallmark/contrib/sampler.h
@@ -0,0 +1,60 @@
+#ifndef HALIDE_APPS_HALLMARK_SAMPLER_H_
+#define HALIDE_APPS_HALLMARK_SAMPLER_H_
+
+#include 
+
+#include "HalideBuffer.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+
+namespace hallmark {
+
+class Sampler {
+public:
+    enum class Type { kGreedy,
+                      kTopK,
+                      kTopP };
+
+    // Creates a Sampler.
+    // * If kGreedy sampler is used, Argmax will be returned ignoring all other
+    //   arguments provided.
+    // * If kTopK sampler is used, the top k logit values are selected. That is
+    //   followed by temperature scaling and applying softmax. Finally, a sample
+    //   is drawn from the resulting distribution.
+    // * If kTopP sampler is selected, the top k logits are first selected if k > 0.
+    //   Otherwise, k = vocab size. This is followed by temperature scaling and
+    //   applying softmax. Finally, the top p are selected from the probabilities
+    //   such that the sum of p_i is greater than or equal to top_p. Lastly, a sample
+    //   is drawn from the resulting distribution.
+    static absl::StatusOr> Create(Type type, int top_k,
+                                  float top_p,
+                                  float temperature,
+                                  int seed);
+    // Given an input tensor of shape `(vocab_size, 1 [seq_len], Batch)`, runs
+    // the configured sampling algorithm to find a winning class. The results are
+    // reported as a vector of integer indices where each entry corresponds to a
+    // batch.
+    absl::StatusOr> Sample(Halide::Runtime::Buffer &logits);
+
+private:
+    Sampler(Type type, int top_k, float top_p, float temperature, int seed);
+    absl::StatusOr> SampleGreedy(Halide::Runtime::Buffer &logits);
+    absl::StatusOr> SampleTopK(Halide::Runtime::Buffer &logits);
+    absl::StatusOr> SampleTopP(Halide::Runtime::Buffer &logits);
+    absl::Status SelectTopK(std::vector> &logits_ids, int k);
+    // `logits_ids` must be sorted and normalized.
+    absl::Status SelectTopP(std::vector> &logits_ids, float p);
+    // `logits_ids` must be sorted.
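+    // (As implemented in sampler.cc: the first entry is treated as the max
+    // logit and subtracted before exponentiation for numerical stability; when
+    // `normalize` is false the resulting weights are left unnormalized and
+    // DoSampling() relies on std::discrete_distribution to normalize them.)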
+ absl::Status ScaledSoftmax(std::vector> &logits_ids, bool normalize); + absl::StatusOr DoSampling(std::vector> &logits_ids); + + const Type type_; + const int top_k_; + const float top_p_; + const float temperature_; + std::unique_ptr generator_; +}; + +} // namespace hallmark + +#endif // HALIDE_APPS_HALLMARK_SAMPLER_H_ diff --git a/apps/hallmark/contrib/status_helpers.h b/apps/hallmark/contrib/status_helpers.h new file mode 100644 index 000000000000..d0a81cda3966 --- /dev/null +++ b/apps/hallmark/contrib/status_helpers.h @@ -0,0 +1,69 @@ +#ifndef HALIDE_APPS_HALLMARK_STATUS_HELPERS_H_ +#define HALIDE_APPS_HALLMARK_STATUS_HELPERS_H_ + +#include "HalideRuntime.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" + +namespace hallmark { + +inline ::absl::Status StatusFromHalide(int halide_error) { + switch (halide_error) { + case halide_error_code_t::halide_error_code_success: + return ::absl::OkStatus(); + case halide_error_code_t::halide_error_code_out_of_memory: + return absl::ResourceExhaustedError("Halide error: out of memory"); + case halide_error_code_t::halide_error_code_device_malloc_failed: + return absl::ResourceExhaustedError("Halide error: device malloc failed"); + case halide_error_code_t::halide_error_code_buffer_allocation_too_large: + return absl::OutOfRangeError("Halide error: buffer allocation too large. Consider enabling 'large_buffers'"); + case halide_error_code_t::halide_error_code_buffer_extents_too_large: + return absl::OutOfRangeError("Halide error: buffer extents too large"); + case halide_error_code_t::halide_error_code_constraint_violated: + return absl::OutOfRangeError("Halide error: A constraint on a size or stride of an input or output buffer was not met."); + case halide_error_code_t::halide_error_code_bad_dimensions: + return absl::InvalidArgumentError("Halide error: The dimensions of an input buffer do not match the generator Input or Param dimensions."); + default: + return absl::UnknownError(::absl::StrFormat("Halide error: %d", halide_error)); + } +} + +// ------------------------------------- + +#define RETURN_IF_ERROR(expr) \ + do { \ + auto status = (expr); \ + if (!status.ok()) return status; \ + } while (0) + +#define RET_CHECK(expr) \ + do { \ + if (!(expr)) return absl::UnknownError("RET_CHECK failure: " #expr); \ + } while (0) + +// ------------------------------------- + +template +inline absl::Status DoAssignOrReturn(T &lhs, absl::StatusOr result) { + if (result.ok()) { + lhs = result.value(); + } + return result.status(); +} + +#define STATUS_MACROS_CONCAT_NAME_INNER(x, y) x##y +#define STATUS_MACROS_CONCAT_NAME(x, y) STATUS_MACROS_CONCAT_NAME_INNER(x, y) + +#define ASSIGN_OR_RETURN_IMPL(status, lhs, rexpr) \ + absl::Status status = DoAssignOrReturn(lhs, (rexpr)); \ + if (!status.ok()) return status; + +#define ASSIGN_OR_RETURN(lhs, rexpr) \ + ASSIGN_OR_RETURN_IMPL( \ + STATUS_MACROS_CONCAT_NAME(_status_or_value, __COUNTER__), lhs, rexpr); + +// ------------------------------------- + +} // namespace hallmark + +#endif // HALIDE_APPS_HALLMARK_STATUS_HELPERS_H_ diff --git a/apps/hallmark/contrib/tflite_schema.fbs b/apps/hallmark/contrib/tflite_schema.fbs new file mode 100644 index 000000000000..382462f938d9 --- /dev/null +++ b/apps/hallmark/contrib/tflite_schema.fbs @@ -0,0 +1,1642 @@ +// Copyright 2017 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Revision History +// Version 0: Initial version. +// Version 1: Add subgraphs to schema. +// Version 2: Rename operators to conform to NN API. +// Version 3: Move buffer data from Model.Subgraph.Tensors to Model.Buffers. +// Version 3a: Add new builtin op code field. Has backward compatibility with +// version 3. +// Version 3b: Rename fields in SignatureDef. Has backward compatibility with +// version 3 and 3a. +// Version 3c: Move constant tensor buffers & custom op buffers outside from +// Flatbuffers. Has backward compatibility with version 3, 3a and +// 3b. + +namespace tflite; + +// This corresponds to the version. +file_identifier "TFL3"; +// File extension of any written files. +file_extension "tflite"; + +// IMPORTANT: All new members of tables, enums and unions must be added at the +// end to ensure backwards compatibility. + +// The type of data stored in a tensor. +enum TensorType : byte { + FLOAT32 = 0, + FLOAT16 = 1, + INT32 = 2, + UINT8 = 3, + INT64 = 4, + STRING = 5, + BOOL = 6, + INT16 = 7, + COMPLEX64 = 8, + INT8 = 9, + FLOAT64 = 10, + COMPLEX128 = 11, + UINT64 = 12, + // Experimental: Resource and variant types are experimental, that are subject + // to change. Do not implement custom kernels using resource & variant types + // now. + RESOURCE = 13, + VARIANT = 14, + UINT32 = 15, + UINT16 = 16, + INT4 = 17, +} + +// Custom quantization parameters for experimenting with new quantization +// techniques. +table CustomQuantization { + custom:[ubyte] (force_align: 16); +} + +// Represents a specific quantization technique's parameters. +union QuantizationDetails { + CustomQuantization, +} + +// Parameters for converting a quantized tensor back to float. +table QuantizationParameters { + // These four parameters are the asymmetric linear quantization parameters. + // Given a quantized value q, the corresponding float value f should be: + // f = scale * (q - zero_point) + // For other quantization types, the QuantizationDetails below is used. + min:[float]; // For importing back into tensorflow. + max:[float]; // For importing back into tensorflow. + scale:[float]; // For dequantizing the tensor's values. + zero_point:[long]; + + // If this is not none, the other quantization parameters (i.e. min, max, + // scale, zero_point fields above) are ignored and the value of the + // QuantizationDetails union should be used. + details:QuantizationDetails; + + // Specifies the dimension of the Tensor's shape that the scales and + // zero_points correspond to. For example, a tensor t, with dims=[4, 3, 2, 1] + // with quantization params: + // scale=[1.0, 2.0, 3.0], zero_point=[1, 2, 3], quantization_dimension=1 + // will be quantized across the second dimension of t. + // t[:, 0, :, :] will have scale[0]=1.0, zero_point[0]=1 + // t[:, 1, :, :] will have scale[1]=2.0, zero_point[0]=2 + // t[:, 2, :, :] will have scale[2]=3.0, zero_point[0]=3 + quantized_dimension:int; +} + +// Sparse tensors. +// We use a modification of the TACO format. 
+// Reference: http://tensor-compiler.org/kjolstad-oopsla17-tensor-compiler.pdf +// +// To encode a conceptual n-dimensional dense tensor with dims (d0, ..., dn-1), +// potentially with a k-dimensional block (0 <= k <= n) with dims +// (dn, ..., dn+k-1), the format needs to specify: +// 1. In what order to traverse these dimensions. For example, to store a 2-D +// matrix in row major order, the traversal order would be (d0, d1), +// whereas to store it in column major order, the traversal order would be +// (d1, d0). If the 2-D matrix has a 2-D inner block, the traversal order +// could be (d0, d1, d2, d3). +// 2. How each block dimension in (dn, ..., dn+k-1) maps to the original +// tensor dimension in (d0, ..., dn-1). +// 3. In the traversal order defined above, the format (dense vs. sparse) and +// index metadata for each dimension. For a dense dimension, this is just +// the size of that dimension. For a sparse dimension, it's the same as +// the compressed index defined in the Compressed Sparse Row (CSR) format. +// (http://scipy-lectures.org/advanced/scipy_sparse/csr_matrix.html) + +// The storage type for a dimension. Currently we support: +// 1. DENSE: each coordinate in this dimension is stored implicitly. +// 2. SPARSE_CSR: only the coordinates with non-zero elements are stored. The +// compression technique is the same what CSR uses. +// More types like a sparse dimension with a different compression technique +// could be added to the list in the future. +enum DimensionType : byte { + DENSE = 0, + SPARSE_CSR = 1, +} + +table Int32Vector { + values:[int]; +} + +table Uint16Vector { + values:[ushort] (force_align: 4); +} + +table Uint8Vector { + values:[ubyte] (force_align: 4); +} + +// Variable-typed buffer to store the index metadata for a sparse dimension. +// The widest type is Int32 instead of UInt32 because tensor's shape is a int32 +// vector. We don't want the per-dimensional index to overflow that range. +union SparseIndexVector { + Int32Vector, + Uint16Vector, + Uint8Vector +} + +table DimensionMetadata { + // Whether a dimension is dense or sparse. + format:DimensionType; + // Index metadata used for a dimension. + // - If format is DimensionType.DENSE then we use the dense_size field to + // store the size of that dimension. Each index in that dimension is + // stored implicitly. + // - If format is DimensionType.SPARSE_CSR then we use array_segments and + // array_indices to encode that dimension. array_segments represents how + // to segment the indices array, each segment corresponds to one element + // in the previous dimension. array_indices represents the index of the + // non-zero elements within this dimension (as those in the CSR matrix + // format, where the first array is row pointers and the second array is + // column indices). + dense_size:int; + array_segments:SparseIndexVector; + array_indices:SparseIndexVector; +} + +// Parameters to encode a sparse TfLite tensor. +table SparsityParameters { + // The traversal order of the dimensions defined in the `shape` field of the + // conceptual dense tensor. For a n-dimensional tensors with dims (d0, d1, + // ..., dn-1), + // - if not block sparse, the traversal_order is just a permutation of (d0, + // ..., dn-1). For example, a 2-D matrix stored in row-major order would + // have traversal_order = (d0, d1). + // - if block sparse with a k-dimensional block (0 <= k <= n), the + // traversal_order has n + k elements. The first n elements are still a + // permutation of (d0, ..., dn-1). 
The lask k elements are a permutation + // of (dn, ..., dn+k-1), defining how to traverse a block internally. For + // example, a 2-D matrix with 2-D blocks, both stored in row-major order + // would have traversal_order = (d0, d1, d2, d3). + traversal_order:[int]; + // For an n-dimensional tensor with a k-dimensional block (0 <= k <= n), + // stores how a block dimension in (dn, ..., dn+k-1) maps to the original + // tensor dimension in (d0, ..., dn). + // It's stored in the order of (dn, ..., dn+k-1). + // If not block-sparse, this field is NULL. + block_map:[int]; + // In the traversal order defined above, the metadata needed for + // each dimension to locate the non-zero values in the original dense tensor. + // The size of the dim_metadata array = the size of the traversal_order array + // = n + k. + dim_metadata:[DimensionMetadata]; +} + +// The nested tensor type for VARIANT type. +table VariantSubType { + // The tensor shape. + shape:[int]; + type:TensorType; + // If false, the rank or the number of tensor dimensions is unknown. + // If false, "shape" must be []. + has_rank: bool = false; +} + +table Tensor { + // The tensor shape. The meaning of each entry is operator-specific but + // builtin ops use: [batch size, height, width, number of channels] (That's + // Tensorflow's NHWC). + shape:[int]; + type:TensorType; + // An index that refers to the buffers table at the root of the model. Or, + // if there is no data buffer associated (i.e. intermediate results), then + // this is 0 (which refers to an always existent empty buffer). + // + // The data_buffer itself is an opaque container, with the assumption that the + // target device is little-endian. In addition, all builtin operators assume + // the memory is ordered such that if `shape` is [4, 3, 2], then index + // [i, j, k] maps to data_buffer[i*3*2 + j*2 + k]. + buffer:uint; + name:string; // For debugging and importing back into tensorflow. + quantization:QuantizationParameters; // Optional. + + is_variable:bool = false; + + // Parameters to encode a sparse tensor. See the example in + // tensorflow/lite/testdata/sparse_tensor.json. + sparsity:SparsityParameters; // Optional. + + // Encodes `shape` with unknown dimensions. Unknown dimensions are + // represented with -1. + shape_signature:[int]; // Optional. + + // This field is added to distinguish between scalars and tensors of unknown + // ranks (both of which shape is []). + // For scalars (rank = 0), shape = [] and has_rank = true. + // For tensors with known rank (rank > 0) and shape, shape = [...] and + // has_rank = true. + // For tensors with unknown rank and shape, shape = [] and has_rank = false. + has_rank: bool = false; + + // The nested Tensor types for VARIANT type. This is always empty for + // non-VARIANT types. This is optional because the nested type can be omitted. + // Currently only 1 subtype is supported. The field is defined as an array for + // flexibility of supporting multiple subtypes in the future. + variant_tensors:[VariantSubType]; +} + +// A list of builtin operators. Builtin operators are slightly faster than custom +// ones, but not by much. Moreover, while custom operators accept an opaque +// object containing configuration parameters, builtins have a predetermined +// set of acceptable options. 
+// LINT.IfChange +enum BuiltinOperator : int32 { + ADD = 0, + AVERAGE_POOL_2D = 1, + CONCATENATION = 2, + CONV_2D = 3, + DEPTHWISE_CONV_2D = 4, + DEPTH_TO_SPACE = 5, + DEQUANTIZE = 6, + EMBEDDING_LOOKUP = 7, + FLOOR = 8, + FULLY_CONNECTED = 9, + HASHTABLE_LOOKUP = 10, + L2_NORMALIZATION = 11, + L2_POOL_2D = 12, + LOCAL_RESPONSE_NORMALIZATION = 13, + LOGISTIC = 14, + LSH_PROJECTION = 15, + LSTM = 16, + MAX_POOL_2D = 17, + MUL = 18, + RELU = 19, + // NOTE(aselle): RELU_N1_TO_1 used to be called RELU1, but it was renamed + // since different model developers use RELU1 in different ways. Never + // create another op called RELU1. + RELU_N1_TO_1 = 20, + RELU6 = 21, + RESHAPE = 22, + RESIZE_BILINEAR = 23, + RNN = 24, + SOFTMAX = 25, + SPACE_TO_DEPTH = 26, + SVDF = 27, + TANH = 28, + CONCAT_EMBEDDINGS = 29, + SKIP_GRAM = 30, + CALL = 31, + CUSTOM = 32, + EMBEDDING_LOOKUP_SPARSE = 33, + PAD = 34, + UNIDIRECTIONAL_SEQUENCE_RNN = 35, + GATHER = 36, + BATCH_TO_SPACE_ND = 37, + SPACE_TO_BATCH_ND = 38, + TRANSPOSE = 39, + MEAN = 40, + SUB = 41, + DIV = 42, + SQUEEZE = 43, + UNIDIRECTIONAL_SEQUENCE_LSTM = 44, + STRIDED_SLICE = 45, + BIDIRECTIONAL_SEQUENCE_RNN = 46, + EXP = 47, + TOPK_V2 = 48, + SPLIT = 49, + LOG_SOFTMAX = 50, + // DELEGATE is a special op type for the operations which are delegated to + // other backends. + // WARNING: Experimental interface, subject to change + DELEGATE = 51, + BIDIRECTIONAL_SEQUENCE_LSTM = 52, + CAST = 53, + PRELU = 54, + MAXIMUM = 55, + ARG_MAX = 56, + MINIMUM = 57, + LESS = 58, + NEG = 59, + PADV2 = 60, + GREATER = 61, + GREATER_EQUAL = 62, + LESS_EQUAL = 63, + SELECT = 64, + SLICE = 65, + SIN = 66, + TRANSPOSE_CONV = 67, + SPARSE_TO_DENSE = 68, + TILE = 69, + EXPAND_DIMS = 70, + EQUAL = 71, + NOT_EQUAL = 72, + LOG = 73, + SUM = 74, + SQRT = 75, + RSQRT = 76, + SHAPE = 77, + POW = 78, + ARG_MIN = 79, + FAKE_QUANT = 80, + REDUCE_PROD = 81, + REDUCE_MAX = 82, + PACK = 83, + LOGICAL_OR = 84, + ONE_HOT = 85, + LOGICAL_AND = 86, + LOGICAL_NOT = 87, + UNPACK = 88, + REDUCE_MIN = 89, + FLOOR_DIV = 90, + REDUCE_ANY = 91, + SQUARE = 92, + ZEROS_LIKE = 93, + FILL = 94, + FLOOR_MOD = 95, + RANGE = 96, + RESIZE_NEAREST_NEIGHBOR = 97, + LEAKY_RELU = 98, + SQUARED_DIFFERENCE = 99, + MIRROR_PAD = 100, + ABS = 101, + SPLIT_V = 102, + UNIQUE = 103, + CEIL = 104, + REVERSE_V2 = 105, + ADD_N = 106, + GATHER_ND = 107, + COS = 108, + WHERE = 109, + RANK = 110, + ELU = 111, + REVERSE_SEQUENCE = 112, + MATRIX_DIAG = 113, + QUANTIZE = 114, + MATRIX_SET_DIAG = 115, + ROUND = 116, + HARD_SWISH = 117, + IF = 118, + WHILE = 119, + NON_MAX_SUPPRESSION_V4 = 120, + NON_MAX_SUPPRESSION_V5 = 121, + SCATTER_ND = 122, + SELECT_V2 = 123, + DENSIFY = 124, + SEGMENT_SUM = 125, + BATCH_MATMUL = 126, + PLACEHOLDER_FOR_GREATER_OP_CODES = 127, + CUMSUM = 128, + CALL_ONCE = 129, + BROADCAST_TO = 130, + RFFT2D = 131, + CONV_3D = 132, + IMAG=133, + REAL=134, + COMPLEX_ABS=135, + HASHTABLE = 136, + HASHTABLE_FIND = 137, + HASHTABLE_IMPORT = 138, + HASHTABLE_SIZE = 139, + REDUCE_ALL = 140, + CONV_3D_TRANSPOSE = 141, + VAR_HANDLE = 142, + READ_VARIABLE = 143, + ASSIGN_VARIABLE = 144, + BROADCAST_ARGS = 145, + RANDOM_STANDARD_NORMAL = 146, + BUCKETIZE = 147, + RANDOM_UNIFORM = 148, + MULTINOMIAL = 149, + GELU = 150, + DYNAMIC_UPDATE_SLICE = 151, + RELU_0_TO_1 = 152, + UNSORTED_SEGMENT_PROD = 153, + UNSORTED_SEGMENT_MAX = 154, + UNSORTED_SEGMENT_SUM = 155, + ATAN2 = 156, + UNSORTED_SEGMENT_MIN = 157, + SIGN = 158, + BITCAST = 159, + BITWISE_XOR = 160, + RIGHT_SHIFT = 161, + // All Operators start with STABLEHLO_ 
prefixes are subject to change + // Many of the ops below can not be executed by TFlite runtime + STABLEHLO_LOGISTIC = 162, // WARNING: Do not have runtime support + STABLEHLO_ADD = 163, + STABLEHLO_DIVIDE = 164, // WARNING: No runtime support yet + STABLEHLO_MULTIPLY = 165, + STABLEHLO_MAXIMUM = 166, + STABLEHLO_RESHAPE = 167, // WARNING: No runtime support yet + STABLEHLO_CLAMP = 168, // WARNING: No runtime support + STABLEHLO_CONCATENATE = 169, // WARNING: No runtime support + STABLEHLO_BROADCAST_IN_DIM = 170, // WARNING: No runtime support + STABLEHLO_CONVOLUTION = 171, // WARNING: No runtime support + STABLEHLO_SLICE = 172, // WARNING: No runtime support + STABLEHLO_CUSTOM_CALL = 173, // WARNING: No runtime support + STABLEHLO_REDUCE = 174, // WARNING: No runtime support + STABLEHLO_ABS = 175, // WARNING: No runtime support + STABLEHLO_AND = 176, // WARNING: No runtime support + STABLEHLO_COSINE = 177, // WARNING: No runtime support + STABLEHLO_EXPONENTIAL = 178, // WARNING: No runtime support + STABLEHLO_FLOOR = 179, // WARNING: No runtime support + STABLEHLO_LOG = 180, // WARNING: No runtime support + STABLEHLO_MINIMUM = 181, + STABLEHLO_NEGATE = 182, // WARNING: No runtime support + STABLEHLO_OR = 183, // WARNING: No runtime support + STABLEHLO_POWER = 184, // WARNING: No runtime support + STABLEHLO_REMAINDER = 185, // WARNING: No runtime support + STABLEHLO_RSQRT = 186, // WARNING: No runtime support + STABLEHLO_SELECT = 187, // WARNING: No runtime support + STABLEHLO_SUBTRACT = 188, // WARNING: No runtime support + STABLEHLO_TANH = 189, // WARNING: No runtime support + STABLEHLO_SCATTER = 190, + STABLEHLO_COMPARE = 191, // WARNING: No runtime support + STABLEHLO_CONVERT = 192, // WARNING: No runtime support + STABLEHLO_DYNAMIC_SLICE = 193, // WARNING: No runtime support + STABLEHLO_DYNAMIC_UPDATE_SLICE = 194, // WARNING: No runtime support + STABLEHLO_PAD = 195, + STABLEHLO_IOTA = 196, // WARNING: No runtime support + STABLEHLO_DOT_GENERAL = 197, // WARNING: No runtime support + STABLEHLO_REDUCE_WINDOW = 198, + STABLEHLO_SORT = 199, // WARNING: No runtime support + STABLEHLO_WHILE = 200, // WARNING: No runtime support + STABLEHLO_GATHER = 201, + STABLEHLO_TRANSPOSE = 202, // WARNING: No runtime support + DILATE = 203, + STABLEHLO_RNG_BIT_GENERATOR = 204, + REDUCE_WINDOW = 205 (deprecated), +} +// LINT.ThenChange(nnapi_linter/linter.proto) + +// Options for the builtin operators. 
+union BuiltinOptions { + Conv2DOptions, + DepthwiseConv2DOptions, + ConcatEmbeddingsOptions, + LSHProjectionOptions, + Pool2DOptions, + SVDFOptions, + RNNOptions, + FullyConnectedOptions, + SoftmaxOptions, + ConcatenationOptions, + AddOptions, + L2NormOptions, + LocalResponseNormalizationOptions, + LSTMOptions, + ResizeBilinearOptions, + CallOptions, + ReshapeOptions, + SkipGramOptions, + SpaceToDepthOptions, + EmbeddingLookupSparseOptions, + MulOptions, + PadOptions, + GatherOptions, + BatchToSpaceNDOptions, + SpaceToBatchNDOptions, + TransposeOptions, + ReducerOptions, + SubOptions, + DivOptions, + SqueezeOptions, + SequenceRNNOptions, + StridedSliceOptions, + ExpOptions, + TopKV2Options, + SplitOptions, + LogSoftmaxOptions, + CastOptions, + DequantizeOptions, + MaximumMinimumOptions, + ArgMaxOptions, + LessOptions, + NegOptions, + PadV2Options, + GreaterOptions, + GreaterEqualOptions, + LessEqualOptions, + SelectOptions, + SliceOptions, + TransposeConvOptions, + SparseToDenseOptions, + TileOptions, + ExpandDimsOptions, + EqualOptions, + NotEqualOptions, + ShapeOptions, + PowOptions, + ArgMinOptions, + FakeQuantOptions, + PackOptions, + LogicalOrOptions, + OneHotOptions, + LogicalAndOptions, + LogicalNotOptions, + UnpackOptions, + FloorDivOptions, + SquareOptions, + ZerosLikeOptions, + FillOptions, + BidirectionalSequenceLSTMOptions, + BidirectionalSequenceRNNOptions, + UnidirectionalSequenceLSTMOptions, + FloorModOptions, + RangeOptions, + ResizeNearestNeighborOptions, + LeakyReluOptions, + SquaredDifferenceOptions, + MirrorPadOptions, + AbsOptions, + SplitVOptions, + UniqueOptions, + ReverseV2Options, + AddNOptions, + GatherNdOptions, + CosOptions, + WhereOptions, + RankOptions, + ReverseSequenceOptions, + MatrixDiagOptions, + QuantizeOptions, + MatrixSetDiagOptions, + HardSwishOptions, + IfOptions, + WhileOptions, + DepthToSpaceOptions, + NonMaxSuppressionV4Options, + NonMaxSuppressionV5Options, + ScatterNdOptions, + SelectV2Options, + DensifyOptions, + SegmentSumOptions, + BatchMatMulOptions, + CumsumOptions, + CallOnceOptions, + BroadcastToOptions, + Rfft2dOptions, + Conv3DOptions, + HashtableOptions, + HashtableFindOptions, + HashtableImportOptions, + HashtableSizeOptions, + VarHandleOptions, + ReadVariableOptions, + AssignVariableOptions, + RandomOptions, + BucketizeOptions, + GeluOptions, + DynamicUpdateSliceOptions, + UnsortedSegmentProdOptions, + UnsortedSegmentMaxOptions, + UnsortedSegmentMinOptions, + UnsortedSegmentSumOptions, + ATan2Options, + SignOptions, + BitcastOptions, + BitwiseXorOptions, + RightShiftOptions, + // DO NOT add new options this union, will cause failure in Java api + // generation otherwise + // Add new builtin options into builtin options 2 instead +} + +union BuiltinOptions2{ + StablehloConcatenateOptions, + StablehloBroadcastInDimOptions, + StablehloSliceOptions, + StablehloConvolutionOptions, + StablehloCustomCallOptions, + StablehloReduceOptions, + StablehloScatterOptions, + StablehloCompareOptions, + StablehloDynamicSliceOptions, + StablehloPadOptions, + StablehloIotaOptions, + StablehloDotGeneralOptions, + StablehloReduceWindowOptions, + StablehloSortOptions, + StablehloWhileOptions, + StablehloGatherOptions, + StablehloTransposeOptions, + DilateOptions, + StablehloRngBitGeneratorOptions, + ReduceWindowOptions (deprecated), +} + +table StablehloGatherOptions{ + offset_dims : [long]; + collapsed_slice_dims : [long]; + start_index_map : [long]; + index_vector_dim : long; + slice_sizes : [long]; + indices_are_sorted : bool; +} + +table 
StablehloTransposeOptions{ + permutation : [long]; +} + +enum StablehloPrecisionConfig : uint { + DEFAULT, + HIGH, + HIGHEST, +} + +table StablehloDotGeneralOptions{ + lhs_batching_dimensions : [long]; + rhs_batching_dimensions : [long]; + lhs_contracting_dimensions : [long]; + rhs_contracting_dimensions : [long]; + precision_config : [StablehloPrecisionConfig]; +} + +table StablehloReduceWindowOptions{ + window_dimensions : [long]; + window_strides : [long]; + base_dilations : [long]; + window_dilations : [long]; + padding : [long]; + body_subgraph_index : int; +} + +table StablehloWhileOptions{ + cond_subgraph_index : int; + body_subgraph_index : int; +} + +table StablehloSortOptions{ + dimension : long; + is_stable : bool; + comparator_subgraph_index : int; +} + +table StablehloConcatenateOptions { + dimension : long; +} + +table StablehloBroadcastInDimOptions{ + broadcast_dimensions : [long]; +} + +enum StablehloComparisonDirection : uint { + STABLEHLO_COMPARISON_DIRECTION_EQ, + STABLEHLO_COMPARISON_DIRECTION_NE, + STABLEHLO_COMPARISON_DIRECTION_GE, + STABLEHLO_COMPARISON_DIRECTION_GT, + STABLEHLO_COMPARISON_DIRECTION_LE, + STABLEHLO_COMPARISON_DIRECTION_LT, + +} + +enum StablehloComparisonType : uint { + STABLEHLO_COMPARISON_TYPE_NOTYPE, + STABLEHLO_COMPARISON_TYPE_FLOAT, + STABLEHLO_COMPARISON_TYPE_FLOAT_TOTAL_ORDER, + STABLEHLO_COMPARISON_TYPE_SIGNED, + STABLEHLO_COMPARISON_TYPE_UNSIGNED, +} + +table StablehloCompareOptions{ + comparison_direction : StablehloComparisonDirection; + compare_type : StablehloComparisonType; +} + +table StablehloDynamicSliceOptions{ + slice_sizes : [long]; +} + +table StablehloPadOptions{ + edge_padding_low : [long]; + edge_padding_high : [long]; + interior_padding : [long]; +} + +table StablehloIotaOptions{ + iota_dimension : long; +} + +table StablehloCustomCallOptions { + call_target_name : string; + has_side_effect : bool; + backend_config: string; + api_version : int; // will be decprecated + called_computations: [int]; // should point to subgraphs of the computations + custom_attributes : [ubyte]; +} + +table StablehloReduceOptions { + dimensions : [long]; + body_subgraph_index : int; +} + +table StablehloSliceOptions{ + start_indices : [long]; + limit_indices : [long]; + strides : [long]; +} + +table StablehloConvolutionOptions{ + window_strides : [long]; + padding : [long]; + lhs_dilation : [long]; + rhs_dilation : [long]; + window_reversal : [bool]; + input_batch_dimension : long; + input_feature_dimension : long; + input_spatial_dimensions : [long]; + kernel_input_feature_dimension : long; + kernel_output_feature_dimension : long; + kernel_spatial_dimensions : [long]; + output_batch_dimension : long; + output_feature_dimension : long; + output_spatial_dimensions : [long]; + feature_group_count : long; + batch_group_count : long; + precision_config : [StablehloPrecisionConfig]; +} + +table StablehloScatterOptions { + indices_are_sorted: bool; + update_window_dims: [long]; + inserted_window_dims: [long]; + scatter_dims_to_operand_dims: [long]; + index_vector_dim: long; + unique_indices: bool; + update_computation_subgraph_index: int; +} + +enum RngAlgorithm : byte { + // An algorithm auto-selected by the system according to device type. 
+ DEFAULT = 0, + // The Philox algorithm, as described in paper + // ['Parallel Random Numbers: As Easy as 1, 2, 3'] + // (https://www.thesalmons.org/john/random123/papers/random123sc11.pdf) + PHILOX = 1, + // The ThreeFry algorithm, as described in paper + // ['Parallel Random Numbers: As Easy as 1, 2, 3'] + // (https://www.thesalmons.org/john/random123/papers/random123sc11.pdf) + THREEFRY = 2, +} + +table StablehloRngBitGeneratorOptions { + algorithm:RngAlgorithm; +} + +// LINT.IfChange +enum Padding : byte { SAME, VALID } +// LINT.ThenChange(//tensorflow/compiler/mlir/lite/ir/tfl_op_enums.td) + +// LINT.IfChange +enum ActivationFunctionType : byte { + NONE = 0, + RELU = 1, + RELU_N1_TO_1 = 2, + RELU6 = 3, + TANH = 4, + SIGN_BIT = 5, +} +// LINT.ThenChange(//tensorflow/compiler/mlir/lite/ir/tfl_op_enums.td) + +table Conv2DOptions { + padding:Padding; + stride_w:int; + stride_h:int; + fused_activation_function:ActivationFunctionType; + dilation_w_factor:int = 1; + dilation_h_factor:int = 1; + // Parameters for Conv2D version 8 or above. + // When set, quantized_bias_type defines the dtype for both bias and accumulator. + quantized_bias_type: TensorType; +} + +// Options for both Conv3D and Conv3DTranspose. +table Conv3DOptions { + padding:Padding; + stride_d:int; + stride_w:int; + stride_h:int; + fused_activation_function:ActivationFunctionType; + dilation_d_factor:int = 1; + dilation_w_factor:int = 1; + dilation_h_factor:int = 1; +} + +table Pool2DOptions { + padding:Padding; + stride_w:int; + stride_h:int; + filter_width:int; + filter_height:int; + fused_activation_function:ActivationFunctionType; +} + +table DepthwiseConv2DOptions { + // Parameters for DepthwiseConv version 1 or above. + padding:Padding; + stride_w:int; + stride_h:int; + // `depth_multiplier` is redundant. It's used by CPU kernels in + // TensorFlow 2.0 or below, but ignored in versions above. + // See comments in lite/c/builtin_op_data.h for more details. + depth_multiplier:int; + fused_activation_function:ActivationFunctionType; + // Parameters for DepthwiseConv version 2 or above. + dilation_w_factor:int = 1; + dilation_h_factor:int = 1; +} + +table ConcatEmbeddingsOptions { + num_channels:int; + num_columns_per_channel:[int]; + embedding_dim_per_channel:[int]; // This could be inferred from parameters. +} + +enum LSHProjectionType: byte { + UNKNOWN = 0, + SPARSE = 1, + DENSE = 2, +} + +table LSHProjectionOptions { + type: LSHProjectionType; +} + +table SVDFOptions { + rank:int; + fused_activation_function:ActivationFunctionType; + // For weights-only quantization, use asymmetric quantization for non + // constant inputs at evaluation time. + asymmetric_quantize_inputs:bool; +} + +// An implementation of TensorFlow RNNCell. +table RNNOptions { + fused_activation_function:ActivationFunctionType; + asymmetric_quantize_inputs:bool; +} + +// An implementation of TensorFlow dynamic_rnn with RNNCell. +table SequenceRNNOptions { + time_major:bool; + fused_activation_function:ActivationFunctionType; + asymmetric_quantize_inputs:bool; +} + +// An implementation of TensorFlow bidrectional_dynamic_rnn with RNNCell. 
+table BidirectionalSequenceRNNOptions { + time_major:bool; + fused_activation_function:ActivationFunctionType; + merge_outputs: bool; + asymmetric_quantize_inputs:bool; +} + +// LINT.IfChange +enum FullyConnectedOptionsWeightsFormat: byte { + DEFAULT = 0, + SHUFFLED4x16INT8 = 1, +} +// LINT.ThenChange(//tensorflow/compiler/mlir/lite/ir/tfl_op_enums.td) + +// An implementation of TensorFlow fully_connected (a.k.a Dense) layer. +table FullyConnectedOptions { + // Parameters for FullyConnected version 1 or above. + fused_activation_function:ActivationFunctionType; + + // Parameters for FullyConnected version 2 or above. + weights_format:FullyConnectedOptionsWeightsFormat = DEFAULT; + + // Parameters for FullyConnected version 5 or above. + // If set to true, then the number of dimension is preserved. Furthermore, + // all but the last dimension of the input and output shapes will be equal. + keep_num_dims: bool; + + // Parameters for FullyConnected version 7 or above. + // If set to true, then weights-only op will use asymmetric quantization for + // inputs. + asymmetric_quantize_inputs: bool; + + // Parameters for FullyConnected version 11 or above. + // When set, quantized_bias_type defines the dtype for both bias and accumulator. + quantized_bias_type: TensorType; +} + +table SoftmaxOptions { + beta: float; +} + +// An implementation of TensorFlow concat. +table ConcatenationOptions { + axis:int; + fused_activation_function:ActivationFunctionType; +} + +table AddOptions { + fused_activation_function:ActivationFunctionType; + // Parameters supported by version 3. + pot_scale_int16:bool = true; +} + +table MulOptions { + fused_activation_function:ActivationFunctionType; +} + +table L2NormOptions { + // This field is currently ignored in the L2 Norm Op. + fused_activation_function:ActivationFunctionType; +} + +table LocalResponseNormalizationOptions { + radius:int; + bias:float; + alpha:float; + beta:float; +} + +// LINT.IfChange +enum LSTMKernelType : byte { + // Full LSTM kernel which supports peephole and projection. + FULL = 0, + // Basic LSTM kernels. Equivalent to TensorFlow BasicLSTMCell. + BASIC = 1, +} +// LINT.ThenChange(//tensorflow/compiler/mlir/lite/ir/tfl_op_enums.td) + +// An implementation of TensorFlow LSTMCell and CoupledInputForgetGateLSTMCell +table LSTMOptions { + // Parameters for LSTM version 1 or above. + fused_activation_function:ActivationFunctionType; + cell_clip: float; // Optional, 0.0 means no clipping + proj_clip: float; // Optional, 0.0 means no clipping + + // Parameters for LSTM version 2 or above. + // Basic kernel is only supported in version 2 or above. + kernel_type: LSTMKernelType = FULL; + + // Parameters for LSTM version 4 or above. + asymmetric_quantize_inputs: bool; +} + +// An implementation of TensorFlow dynamic_rnn with LSTMCell. +table UnidirectionalSequenceLSTMOptions { + fused_activation_function:ActivationFunctionType; + cell_clip: float; // Optional, 0.0 means no clipping + proj_clip: float; // Optional, 0.0 means no clipping + + // If true then first dimension is sequence, otherwise batch. + time_major:bool; + + // Parameter for Unidirectional Sequence LSTM version 3. + asymmetric_quantize_inputs:bool; + + // Parameter for unidirectional sequence RNN version 4. 
+ diagonal_recurrent_tensors:bool; +} + +table BidirectionalSequenceLSTMOptions { + // Parameters supported by version 1: + fused_activation_function:ActivationFunctionType; + cell_clip: float; // Optional, 0.0 means no clipping + proj_clip: float; // Optional, 0.0 means no clipping + + // If true, store the outputs of both directions into the first output. + merge_outputs: bool; + + // Parameters supported by version 2: + // If true then first dimension is sequence, otherwise batch. + // Version 1 implementations assumed time_major to be true, so this default + // value should never change. + time_major: bool = true; + + // Parameters for version 3 or above. + asymmetric_quantize_inputs:bool; +} + +table ResizeBilinearOptions { + new_height: int (deprecated); + new_width: int (deprecated); + align_corners: bool; + half_pixel_centers: bool; +} + +table ResizeNearestNeighborOptions { + align_corners: bool; + half_pixel_centers: bool; +} + +// A call operation options +table CallOptions { + // The subgraph index that needs to be called. + subgraph:uint; +} + +table PadOptions { +} + +table PadV2Options { +} + +table ReshapeOptions { + new_shape:[int]; +} + +table SpaceToBatchNDOptions { +} + +table BatchToSpaceNDOptions { +} + +table SkipGramOptions { + ngram_size: int; + max_skip_size: int; + include_all_ngrams: bool; +} + +table SpaceToDepthOptions { + block_size: int; +} + +table DepthToSpaceOptions { + block_size: int; +} + +table SubOptions { + fused_activation_function:ActivationFunctionType; + // Parameters supported by version 5 + pot_scale_int16:bool = true; +} + +table DivOptions { + fused_activation_function:ActivationFunctionType; +} + +table TopKV2Options { +} + +enum CombinerType : byte { + SUM = 0, + MEAN = 1, + SQRTN = 2, +} + +table EmbeddingLookupSparseOptions { + combiner:CombinerType; +} + +table GatherOptions { + axis: int; + // Parameters for Gather version 5 or above. + batch_dims: int = 0; +} + +table TransposeOptions { +} + +table ExpOptions { +} + +table CosOptions { +} + +table ReducerOptions { + keep_dims: bool; +} + +table SqueezeOptions { + squeeze_dims:[int]; +} + +table SplitOptions { + num_splits: int; +} + +table SplitVOptions { + num_splits: int; +} + +table StridedSliceOptions { + begin_mask: int; + end_mask: int; + ellipsis_mask: int; + new_axis_mask: int; + shrink_axis_mask: int; + // If true, then the end tensor is an offset of the begin tensor. + offset: bool; +} + +table LogSoftmaxOptions { +} + +table CastOptions { + in_data_type: TensorType; + out_data_type: TensorType; +} + +table DequantizeOptions { +} + +table MaximumMinimumOptions { +} + +table TileOptions { +} + +table ArgMaxOptions { + output_type : TensorType; +} + +table ArgMinOptions { + output_type : TensorType; +} + +table GreaterOptions { +} + +table GreaterEqualOptions { +} + +table LessOptions { +} + +table LessEqualOptions { +} + +table NegOptions { +} + +table SelectOptions { +} + +table SliceOptions { +} + +table TransposeConvOptions { + // Parameters supported by version 1, 2, 3: + padding:Padding; + stride_w:int; + stride_h:int; + + // Parameters supported by version 4: + fused_activation_function:ActivationFunctionType = NONE; + + // Parameters for TransposeConv version 5 or above. + // If set, use this for bias and accumulator. + // When set, quantized_bias_type defines the dtype for both bias and accumulator. 
+ quantized_bias_type: TensorType; +} + +table ExpandDimsOptions { +} + +table SparseToDenseOptions { + validate_indices:bool; +} + +table EqualOptions { +} + +table NotEqualOptions { +} + +table ShapeOptions { + // Optional output type of the operation (int32 or int64). Defaults to int32. + out_type : TensorType; +} + +table RankOptions { +} + +table PowOptions { +} + +table FakeQuantOptions { + // Parameters supported by version 1: + min:float; + max:float; + num_bits:int; + + // Parameters supported by version 2: + narrow_range:bool; +} + +table PackOptions { + values_count:int; + axis:int; +} + +table LogicalOrOptions { +} + +table OneHotOptions { + axis:int; +} + +table AbsOptions { +} + + +table HardSwishOptions { +} + +table LogicalAndOptions { +} + +table LogicalNotOptions { +} + +table UnpackOptions { + num:int; + axis:int; +} + +table FloorDivOptions { +} + +table SquareOptions { +} + +table ZerosLikeOptions { +} + +table FillOptions { +} + +table FloorModOptions { +} + +table RangeOptions { +} + +table LeakyReluOptions { + alpha:float; +} + +table SquaredDifferenceOptions { +} + +// LINT.IfChange +enum MirrorPadMode : byte { + // Doesn't include borders. + REFLECT = 0, + // Includes borders. + SYMMETRIC = 1, +} +// LINT.ThenChange(//tensorflow/compiler/mlir/lite/ir/tfl_op_enums.td) + +table MirrorPadOptions { + mode:MirrorPadMode; +} + +table UniqueOptions { + idx_out_type:TensorType = INT32; +} + +table ReverseV2Options { +} + +table AddNOptions { +} + +table GatherNdOptions { +} + +table WhereOptions { +} + +table ReverseSequenceOptions { + seq_dim:int; + batch_dim:int = 0; +} + +table MatrixDiagOptions { +} + +table QuantizeOptions { +} + +table MatrixSetDiagOptions { +} + +table IfOptions { + then_subgraph_index:int; + else_subgraph_index:int; +} + +table CallOnceOptions { + init_subgraph_index:int; +} + +table WhileOptions { + cond_subgraph_index:int; + body_subgraph_index:int; +} + +table NonMaxSuppressionV4Options { +} + +table NonMaxSuppressionV5Options { +} + +table ScatterNdOptions { +} + +table SelectV2Options { +} + +table DensifyOptions { +} + +table SegmentSumOptions { +} + +table BatchMatMulOptions { + adj_x:bool; + adj_y:bool; + // Parameters for BatchMatMul version 4 or above. + // If set to true, then weights-only op will use asymmetric quantization for + // inputs. + asymmetric_quantize_inputs: bool; +} + +table CumsumOptions { + exclusive:bool; + reverse:bool; +} + +table BroadcastToOptions { +} + +table Rfft2dOptions { +} + +table HashtableOptions { + // The identity of hash tables. This identity will be used across different + // subgraphs in the same interpreter instance. + table_id:int; + key_dtype:TensorType; + value_dtype:TensorType; +} + +table HashtableFindOptions { +} + +table HashtableImportOptions { +} + +table HashtableSizeOptions { +} + +table VarHandleOptions { + container:string; + shared_name:string; +} + +table ReadVariableOptions { +} + +table AssignVariableOptions { +} + +table RandomOptions { + seed: long; + seed2: long; +} + +table BucketizeOptions { + boundaries: [float]; // The bucket boundaries. 
+} + +table GeluOptions { + approximate: bool; +} + +table DynamicUpdateSliceOptions { +} + +table UnsortedSegmentProdOptions { +} + +table UnsortedSegmentMaxOptions { +} + +table UnsortedSegmentSumOptions { +} + +table ATan2Options { +} + +table UnsortedSegmentMinOptions{ +} + +table SignOptions { +} + +table BitcastOptions { +} + +table BitwiseXorOptions { +} + +table RightShiftOptions { +} + +table DilateOptions { +} + +enum ReduceWindowFunction : int { + UNSUPPORTED, + ADD, + MUL, + MINIMUM, + MAXIMUM, + ALL, + ANY, +} + +table ReduceWindowOptions (deprecated) { + reduce_function: ReduceWindowFunction; +} + +// An OperatorCode can be an enum value (BuiltinOperator) if the operator is a +// builtin, or a string if the operator is custom. +table OperatorCode { + // This field is for backward compatibility. This field will be used when + // the value of the extended builtin_code field has less than + // BulitinOperator_PLACEHOLDER_FOR_GREATER_OP_CODES. + deprecated_builtin_code:byte; + custom_code:string; + + // The version of the operator. The version need to be bumped whenever new + // parameters are introduced into an op. + version:int = 1; + + // This field is introduced for resolving op builtin code shortage problem + // (the original BuiltinOperator enum field was represented as a byte). + // This field will be used when the value of the extended builtin_code field + // has greater than BulitinOperator_PLACEHOLDER_FOR_GREATER_OP_CODES. + builtin_code:BuiltinOperator; +} + +enum CustomOptionsFormat : byte { + FLEXBUFFERS = 0, +} + +// An operator takes tensors as inputs and outputs. The type of operation being +// performed is determined by an index into the list of valid OperatorCodes, +// while the specifics of each operations is configured using builtin_options +// or custom_options. +table Operator { + // Index into the operator_codes array. Using an integer here avoids + // complicate map lookups. + opcode_index:uint; + + // Optional input are indicated by -1. + inputs:[int]; + outputs:[int]; + + builtin_options:BuiltinOptions; + custom_options:[ubyte]; + custom_options_format:CustomOptionsFormat; + + // A list of booleans indicating the input tensors which are being mutated by + // this operator.(e.g. used by RNN and LSTM). + // For example, if the "inputs" array refers to 5 tensors and the second and + // fifth are mutable variables, then this list will contain + // [false, true, false, false, true]. + // + // If the list is empty, no variable is mutated in this operator. + // The list either has the same length as `inputs`, or is empty. + mutating_variable_inputs:[bool]; + + // A list of indices to the subgraph's "tensors" that are internal to an Op. + // Internal tensors are those that do not flow in or out of the operation, + // but instead are part of internal computation. As such, the operation's + // implementation may manage its memory more efficiently. They are needed + // however (i.e. not just an implementation detail) since they are part of the + // computation, which may require relevant metadata such as quantization + // parameters. 
+ intermediates:[int]; + + // When an op is using custom_options in a model that is larger than 2GB, then + // we instead use the following attributes to find the buffer location which + // is stored outside of flatbuffers, the offset is calculated relative to the + // beginning of the file and is only valid if > 1 + large_custom_options_offset: ulong; + large_custom_options_size: ulong; + + // Flatbuffers union struct has a 128 elements limit in JAVA, so a second + // union is added, in the case of where BuitlinOptions2 runs out, a third + // one can be added + builtin_options_2 : BuiltinOptions2; +} + +// The root type, defining a subgraph, which typically represents an entire +// model. +table SubGraph { + // A list of all tensors used in this subgraph. + tensors:[Tensor]; + + // Indices of the tensors that are inputs into this subgraph. Note this is + // the list of non-static tensors that feed into the subgraph for inference. + inputs:[int]; + + // Indices of the tensors that are outputs out of this subgraph. Note this is + // the list of output tensors that are considered the product of the + // subgraph's inference. + outputs:[int]; + + // All operators, in execution order. + operators:[Operator]; + + // Name of this subgraph (used for debugging). + name:string; +} + +// Table of raw data buffers (used for constant tensors). Referenced by tensors +// by index. The generous alignment accommodates mmap-friendly data structures. +table Buffer { + data:[ubyte] (force_align: 16); + + // In a model that is larger than 2GB, then buffers instead uses the following + // attributes to find stored data, which is outside of flatbuffers + // the offset is calculated relative to the beginning of the file and is only + // valid if > 1. + offset: ulong; + size: ulong; +} + +table Metadata { + // A human readable string to uniquely identify a Metadata. + name:string; + // An index to the buffers table. + buffer:uint; +} + +// Map from an alias name of tensor to tensor index in the graph. +// This is used in Signature def. +table TensorMap { + // Represents the alias to use for this tensor. + name:string; + + // The actual tensor index in the primary graph, that 'name' corresponds to. + tensor_index:uint; +} + +// This corresponds to SignatureDef in Tensorflow SavedModel. +// The SignatureDef will be part of the SavedModel provided for conversion. +table SignatureDef { + // Named inputs for this signature. + inputs:[TensorMap]; + + // Named outputs for this signature. + outputs:[TensorMap]; + + // Key value which was in the Tensorflow SavedModel SignatureDef map. + signature_key:string; + + // Model tag, deprecated. + deprecated_tag:string (deprecated); + + // Index of subgraphs that corresponds to the exported method. + subgraph_index:uint; +} + +table Model { + // Version of the schema. + version:uint; + + // A list of all operator codes used in this model. This is + // kept in order because operators carry an index into this + // vector. + operator_codes:[OperatorCode]; + + // All the subgraphs of the model. The 0th is assumed to be the main + // model. + subgraphs:[SubGraph]; + + // A description of the model. + description:string; + + // Buffers of the model. + // Note the 0th entry of this array must be an empty buffer (sentinel). + // This is a convention so that tensors without a buffer can provide 0 as + // their buffer. + buffers:[Buffer]; + + // Metadata about the model. Indirects into the existings buffers list. + // Deprecated, prefer to use metadata field. 
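Hallmark consumes only a small slice of this schema: it walks Model → SubGraph → Tensor and resolves each tensor's payload through the buffers table, using the 64-bit offset for data stored outside the flatbuffer. That is exactly what BuildWeightsMapFromTfliteModel does in contrib/weights_loader.cc further down in this patch; the following is a minimal sketch of the same traversal (not part of the patch), assuming only the flatc-generated contrib/tflite_schema_generated.h and a pointer to the start of the memory-mapped file.

```cpp
// Sketch only: enumerate the tensors of a memory-mapped .tflite file using
// the header flatc generates from this schema.
#include <iostream>
#include "contrib/tflite_schema_generated.h"

void ListTensors(const char *data) {
    const ::tflite::Model *model = ::tflite::GetModel(data);
    const auto *buffers = model->buffers();
    for (const auto *subgraph : *model->subgraphs()) {
        for (const auto *tensor : *subgraph->tensors()) {
            const ::tflite::Buffer *buf = buffers->Get(tensor->buffer());
            // For models larger than 2GB the payload lives outside the
            // flatbuffer; offset() is relative to the start of the file.
            std::cout << tensor->name()->str()
                      << " type=" << static_cast<int>(tensor->type())
                      << " offset=" << buf->offset() << "\n";
        }
    }
}
```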
+ metadata_buffer:[int]; + + // Metadata about the model. + metadata:[Metadata]; + + // Optional SignatureDefs for the model. + signature_defs:[SignatureDef]; +} + +root_type Model; diff --git a/apps/hallmark/contrib/transformer_params.proto b/apps/hallmark/contrib/transformer_params.proto new file mode 100644 index 000000000000..2b2089ba680f --- /dev/null +++ b/apps/hallmark/contrib/transformer_params.proto @@ -0,0 +1,158 @@ +/* Copyright 2023 The ODML Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto3"; + +package odml.infra.proto; + +option java_package = "com.google.odml.infra.proto"; +option java_outer_classname = "TransformerParametersProto"; + +// The parameters of transformer (https://arxiv.org/pdf/1706.03762.pdf) +message TransformerParameters { + // Batch size of tensors. + int32 batch_size = 1; + + // TODO: Deprecate parameter. + // Maximum sequence length of the input/output tensor. + int32 max_seq_length = 2; + + // Embedding dimension (or model dimension), `d_model` in the paper. + // `d_k` == `d_v` == `d_model`/`h`. + int32 embedding_dim = 3; + + // Hidden dimension used in the feedforward layer, `d_ff` in the paper. + int32 hidden_dimension = 4; + + // Head dimension, `d_k` or `d_v` in the paper. + int32 head_dimension = 5; + + // Number of heads, `h` in the paper. + int32 num_heads = 6; + + // Number of stacked transformers, `N` in the paper. + int32 num_stacks = 7; + + // Deprecated: bool use_mqa. Use num_kv_heads below. + reserved 8; + + // Number of kv heads. 0 means Multi-Head-Attention (MHA), key and value have + // same number of heads as query; 1 means Multi-Query-Attention (MQA), key and + // value have one head; otherwise, this specifies the number of heads for key + // and value, and Grouped-Query-Attention (GQA) will be used. See + // https://arxiv.org/pdf/2305.13245.pdf for details. + int32 num_kv_heads = 9; + + // Different types of attention mask type. + enum AttentionMaskType { + UNSPECIFIED = 0; + CAUSAL = 1; + PREFIX = 2; + } + // Deprecated, use SelfAttentionParameters. + reserved 10; + + enum Activation { + ACTIVATION_UNSPECIFIED = 0; + // GELU stands for Gaussian Error Linear Unit, see + // https://arxiv.org/pdf/1606.08415.pdf for deatils. + GELU = 1; + // SILU stands for Sigmoid-Weighted Linear Unit, see + // https://arxiv.org/pdf/1702.03118v3.pdf for deatils. + SILU = 2; + // RELU stands for Rectified Linear Unit, see + // https://dl.acm.org/doi/10.5555/3104322.3104425 for details. + RELU = 3; + } + + enum Norm { + NORM_UNSPECIFIED = 0; + // No normalization operation will be perform. + NO_NORM = 1; + // RMSNORM stands for Root Mean Square Layer Normalization, see + // https://arxiv.org/pdf/1910.07467.pdf for deatils. + RMS_NORM = 2; + // LAYERNORM stands for Layer Normalization, see + // https://arxiv.org/pdf/1607.06450v1.pdf for deatils. + LAYER_NORM = 3; + } + + message FeedForwardParameters { + // If `no_bias`, fully connect will degrade to matrix multiply. 
+ bool no_bias = 1; + Activation activation = 2; + // Normalization before the dense layer. + Norm pre_norm = 3; + // Normalization after the dense layer. + Norm post_norm = 4; + } + + FeedForwardParameters feed_forward_parameters = 11; + + message FinalProjectParameters { + // If `no_bias`, fully connect will degrade to matrix multiply. + bool no_bias = 1; + + // The value to set the soft cap (Tanh) before calling the final project + // layer. Setting the value to be <=0 indicates there is no cap. + float soft_cap_value = 2; + } + + FinalProjectParameters final_project_parameters = 12; + + // Normalization before the transformer block. + Norm pre_norm = 13; + // Normalization after the transformer block. + Norm post_norm = 14; + Norm final_norm = 15; + + enum AttentionScaleType { + SCALE_TYPE_UNSPECIFIED = 0; + + // Per dimension scale, query is scaled by log_2(1 + exp(w)) / + // sqrt(head_dim) where w is s static weight. + SCALE_TYPE_PER_DIM_SCALE = 1; + + // Query is scaled by 1/sqrt(head_dim). + SCALE_TYPE_INV_SQRT_HEAD_DIM = 2; + } + + message SelfAttentionParameters { + // Whether bias term is used in Q, K, and V projections. + bool qkv_no_bias = 1; + // Whether bias term is used in post-projection. + bool post_proj_no_bias = 2; + + AttentionMaskType attention_mask_type = 3; + + // The value to set the soft cap (Tanh) before calling the attention + // softmax. Setting the value to be <=0 indicates there is no cap. + float soft_cap_value = 4; + + // If specified, inference pipeline will use the specified scale type. + // Otherwise SCALE_TYPE_PER_DIM_SCALE is used for Multi-Query-Attention by + // default, and SCALE_TYPE_INV_SQRT_HEAD_DIM is used for + // Multi-Head-Attention by default. + optional AttentionScaleType attention_scale_type = 5; + } + + SelfAttentionParameters self_attention_parameters = 16; + + reserved 17; + // Whether to skip absolute positional embeddings. If the value is false, then + // the absolute positional embeddings will be applied to the token embeddings + // before the attention. 
+ bool skip_absolute_positional_embeddings = 18; +} diff --git a/apps/hallmark/contrib/weights_loader.cc b/apps/hallmark/contrib/weights_loader.cc new file mode 100644 index 000000000000..a7c7e81dd7ff --- /dev/null +++ b/apps/hallmark/contrib/weights_loader.cc @@ -0,0 +1,438 @@ +#include "contrib/weights_loader.h" + +#include "absl/log/check.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "contrib/memory_mapped_file.h" +#include "contrib/status_helpers.h" +// macOS system headers #define this value in syslimits.h +#undef ARG_MAX +#include "contrib/tflite_schema_generated.h" + +namespace hallmark { + +namespace { + +class DataHolderMemoryMappedFile : public DataHolder { +public: + explicit DataHolderMemoryMappedFile(absl::string_view path) + : file(path) { + } + MemoryMappedFile file; +}; + +using FeedForwardWeights = LlmWeights::FeedForwardWeights; +using SelfAttentionWeights = LlmWeights::SelfAttentionWeights; + +class LlmWeightsLoader { +public: + LlmWeightsLoader(absl::string_view weight_path, const LlmParams ¶ms); + + absl::StatusOr LoadWeights(); + +private: + LlmParams params_; + std::shared_ptr mapped_file_; + absl::flat_hash_map> weights_; + + absl::StatusOr LoadSelfAttention( + int layer_id); + + absl::StatusOr LoadFeedForward(int layer_id); + + absl::StatusOr LoadWeight(absl::string_view tensor_name, + std::vector expected_dims, + size_t dim_scale_if_any = 0) const; + + absl::StatusOr LoadTransposedWeight( + absl::string_view tensor_name, std::vector expected_dims, + size_t dim_scale_if_any) const; + + absl::StatusOr> LoadNormWeights( + LlmParams::Norm norm_type, absl::string_view basename); + + // is_query: indicating whether the weight is for query projection or not. + // Note that the key/value projection weights are handled differently between + // MHA vs. MQA. + absl::StatusOr TryCacheThenLoadSelfAttention( + absl::string_view filename_prefix, absl::string_view alt_filename_prefix, + bool is_query); + + void BuildWeightsMapFromTfliteModel(char *data); +}; + +// According to norm_type, load necessary weights with given basename. 
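Before the per-norm helper that the comment above introduces, it is worth seeing how this class is meant to be driven. Everything here is private except the free function LoadLlmWeights declared in contrib/weights_loader.h; a hedged usage sketch follows, assuming the LlmParams have already been produced (normally by contrib/llm_params.cc from the proto above).

```cpp
// Usage sketch only: load all transformer weights from a .tflite file.
#include <iostream>
#include <string>
#include "contrib/llm_params.h"
#include "contrib/weights_loader.h"

int LoadAndReport(const std::string &tflite_path, const hallmark::LlmParams &params) {
    absl::StatusOr<hallmark::LlmWeights> weights =
        hallmark::LoadLlmWeights(tflite_path, params);
    if (!weights.ok()) {
        std::cerr << weights.status() << "\n";
        return 1;
    }
    // One self-attention and one feed-forward entry per transformer layer.
    std::cout << "layers: " << weights->sas.size()
              << " / " << weights->ffs.size() << "\n";
    return 0;
}
```

Note that the returned LlmWeights keeps the memory-mapped file alive through its data_holder member, so the Halide buffers inside it (which point into the mapping) stay valid for as long as the struct does.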
+absl::StatusOr> +LlmWeightsLoader::LoadNormWeights(LlmParams::Norm norm_type, + absl::string_view basename) { + switch (norm_type) { + case LlmParams::Norm::UNSPECIFIED: + break; + case LlmParams::Norm::NO_NORM: + break; + case LlmParams::Norm::RMS_NORM: { + auto rms_norm_weights = RMSNormWeights(); + ASSIGN_OR_RETURN(rms_norm_weights.norm_weight, + LoadWeight(absl::StrCat(basename, ".scale"), + {(int)params_.model_dim_D})); + return rms_norm_weights; + } + case LlmParams::Norm::LAYER_NORM: { + auto layer_norm_weights = LayerNormWeights(); + ASSIGN_OR_RETURN(layer_norm_weights.beta, + LoadWeight(absl::StrCat(basename, ".bias"), + {1, 1, (int)params_.model_dim_D})); + ASSIGN_OR_RETURN(layer_norm_weights.gamma, + LoadWeight(absl::StrCat(basename, ".scale"), + {1, 1, (int)params_.model_dim_D})); + return layer_norm_weights; + } + default: + break; + } + return std::nullopt; +} + +void LlmWeightsLoader::BuildWeightsMapFromTfliteModel(char *data) { + auto *tflite_model = ::tflite::GetModel(mapped_file_->file.data()); + const auto *buffers = tflite_model->buffers(); + for (const auto *subgraph : *tflite_model->subgraphs()) { + for (const auto *tfl_tensor : *subgraph->tensors()) { + auto tensor_name = absl::string_view(tfl_tensor->name()->data(), + tfl_tensor->name()->size()); + halide_type_t halide_type; + CHECK(tfl_tensor->buffer() < buffers->size()); + const ::tflite::Buffer &tfl_buffer = *buffers->Get(tfl_tensor->buffer()); + switch (tfl_tensor->type()) { + case ::tflite::TensorType::FLOAT32: + halide_type = halide_type_t(halide_type_float, 32); + break; + case ::tflite::TensorType::INT8: + halide_type = halide_type_t(halide_type_int, 8); + break; + case ::tflite::TensorType::INT4: + halide_type = halide_type_t(halide_type_int, 4); + break; + default: + std::cerr << "Unsupported tensor type: " << (int) tfl_tensor->type(); + std::abort(); + break; + } + + std::vector tfl_dims(tfl_tensor->shape()->begin(), + tfl_tensor->shape()->end()); + // Halide convention has dims in opposite order of Tensor. + std::vector halide_dims; + halide_dims.reserve(tfl_dims.size()); + for (size_t i = tfl_dims.size(); i > 0; i--) { + halide_dims.push_back(static_cast(tfl_dims[i - 1])); + } + + weights_[tensor_name] = Halide::Runtime::Buffer<>( + halide_type, data + tfl_buffer.offset(), halide_dims); + } + } +} + +absl::StatusOr LlmWeightsLoader::LoadWeight( + absl::string_view tensor_name, std::vector expected_dims, + size_t dim_scale_if_any) const { + if (!weights_.contains(tensor_name)) { + // LOG(WARNING) << "Tensor not found: " << tensor_name; + return ScaledTensor(); + } + ScaledTensor result; + result.weights = weights_.at(tensor_name); + // Check dimensions. + { + bool correct_dimension = true; + const int d = result.weights.dimensions(); + correct_dimension &= (d == expected_dims.size()); + // Note that 'expected' is in the reverse order of what we expect. + for (int i = 0; i < d; ++i) { + correct_dimension &= + (result.weights.dim(i).extent() == expected_dims[d - i - 1]); + } + if (!correct_dimension) { + return absl::InvalidArgumentError( + absl::StrCat("Dimension mismatch for ", tensor_name)); + } + } + + if (result.weights.type().code == halide_type_float) { + result.scale = Halide::Runtime::Buffer<>(); + return result; + } + + // Following are logic for quantized weights. 
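The convention implemented below pairs every quantized tensor `<name>` with a float tensor `<name>_quantized_scale`, and dim_scale records which axis the per-channel scales run along. Consumers expand the integer weights back to float by multiplying through the scale, as ConvertToF32 in src/llm.cpp later does for int8 tensors; a per-element sketch of that convention:

```cpp
// Sketch: dequantize one element of a 2-D int8 ScaledTensor. Matches the
// convention used by ConvertToF32 in src/llm.cpp: dim_scale == 1 means the
// scale vector is indexed by the innermost (x) coordinate, otherwise by y.
#include "HalideBuffer.h"

float Dequantize(const Halide::Runtime::Buffer<int8_t> &w,
                 const Halide::Runtime::Buffer<float> &scale,
                 size_t dim_scale, int x, int y) {
    const float s = (dim_scale == 1) ? scale(x) : scale(y);
    return static_cast<float>(w(x, y)) * s;
}
```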
+ std::string scale_tensor_name = absl::StrCat(tensor_name, "_quantized_scale"); + if (!weights_.contains(scale_tensor_name)) { + return absl::NotFoundError( + absl::StrCat("Scale tensor not found: ", scale_tensor_name)); + } + result.scale = weights_.at(scale_tensor_name); + result.dim_scale = dim_scale_if_any; + if (expected_dims[dim_scale_if_any] != result.scale.number_of_elements()) { + return absl::InvalidArgumentError( + absl::StrCat("Dimension mismatch for ", scale_tensor_name)); + } + return result; +} + +absl::StatusOr LlmWeightsLoader::LoadTransposedWeight( + absl::string_view tensor_name, std::vector expected_dims, + size_t dim_scale_if_any) const { + return LoadWeight( + tensor_name, + std::vector(expected_dims.rbegin(), expected_dims.rend()), + 1 - dim_scale_if_any); +} + +LlmWeightsLoader::LlmWeightsLoader(absl::string_view weight_path, + const LlmParams ¶ms) + : params_(params) { + mapped_file_ = std::make_shared(weight_path); + if (mapped_file_->file.valid()) { + BuildWeightsMapFromTfliteModel(static_cast(mapped_file_->file.data())); + } +} + +absl::StatusOr LlmWeightsLoader::TryCacheThenLoadSelfAttention( + absl::string_view filename_prefix, absl::string_view alt_filename_prefix, + bool is_query) { + ScaledTensor r; + if (!is_query) { + ASSIGN_OR_RETURN(r, LoadTransposedWeight(filename_prefix, + {(int)params_.model_dim_D, + (int)params_.num_kv_heads * + (int)params_.head_dim_H}, + 1)); + if (!r.weights.data()) { + ASSIGN_OR_RETURN(r, LoadTransposedWeight(alt_filename_prefix, + {(int)params_.model_dim_D, + (int)params_.num_kv_heads * + (int)params_.head_dim_H}, + 1)); + } + // r->SetMetadata("self_attention_reshaped_weight_N", params_.num_kv_heads); + } else { + ASSIGN_OR_RETURN(r, LoadTransposedWeight( + filename_prefix, + {(int)params_.model_dim_D, + (int)params_.n_heads_N * (int)params_.head_dim_H}, + 1)); + if (!r.weights.data()) { + ASSIGN_OR_RETURN(r, LoadTransposedWeight(alt_filename_prefix, + {(int)params_.model_dim_D, + (int)params_.n_heads_N * + (int)params_.head_dim_H}, + 1)); + } + // r->SetMetadata("self_attention_reshaped_weight_N", params_.n_heads_N); + } + // r->SetMetadata("in_dim_last_in_weight", 1); + return r; +} + +absl::StatusOr LlmWeightsLoader::LoadFeedForward( + int layer_id) { + const auto ¶ms = params_; + auto ff_file_prefix = + absl::StrCat("params.lm.transformer.x_layers_", layer_id, ".ff_layer."); + FeedForwardWeights feed_forward; + + ASSIGN_OR_RETURN( + feed_forward.pre_norm_weight, + LoadNormWeights(params.ff_params.pre_norm, + absl::StrCat(ff_file_prefix, "pre_layer_norm"))); + + ASSIGN_OR_RETURN( + feed_forward.post_norm_weight, + LoadNormWeights(params.ff_params.post_norm, + absl::StrCat(ff_file_prefix, "post_layer_norm"))); + + ASSIGN_OR_RETURN( + feed_forward.layer_1_weight, + LoadTransposedWeight(absl::StrCat(ff_file_prefix, "ffn_layer1.w"), + {(int)params.model_dim_D, (int)params.hidden_dim_HD}, + /*original_dim_scale=*/1)); + if (!feed_forward.layer_1_weight.weights.data()) { + ASSIGN_OR_RETURN(feed_forward.layer_1_weight, + LoadTransposedWeight( + absl::StrCat(ff_file_prefix, "ffn_layer1.linear.w"), + {(int)params.model_dim_D, (int)params.hidden_dim_HD}, + /*original_dim_scale=*/1)); + } + ASSIGN_OR_RETURN( + feed_forward.layer_1_gate_weight, + LoadTransposedWeight(absl::StrCat(ff_file_prefix, "ffn_layer1_gate.w"), + {(int)params.model_dim_D, (int)params.hidden_dim_HD}, + /*original_dim_scale=*/1)); + if (!feed_forward.layer_1_gate_weight.weights.data()) { + ASSIGN_OR_RETURN( + feed_forward.layer_1_gate_weight, + 
LoadTransposedWeight( + absl::StrCat(ff_file_prefix, "ffn_layer1_gate.linear.w"), + {(int)params.model_dim_D, (int)params.hidden_dim_HD}, + /*original_dim_scale=*/1)); + } + ASSIGN_OR_RETURN( + feed_forward.layer_2_weight, + LoadTransposedWeight(absl::StrCat(ff_file_prefix, "ffn_layer2.w"), + {(int)params.hidden_dim_HD, (int)params.model_dim_D}, + /*original_dim_scale=*/1)); + if (!feed_forward.layer_2_weight.weights.data()) { + ASSIGN_OR_RETURN(feed_forward.layer_2_weight, + LoadTransposedWeight( + absl::StrCat(ff_file_prefix, "ffn_layer2.linear.w"), + {(int)params.hidden_dim_HD, (int)params.model_dim_D}, + /*original_dim_scale=*/1)); + } + + if (!params.ff_params.no_bias) { + ASSIGN_OR_RETURN( + feed_forward.layer_1_bias, + LoadWeight(absl::StrCat(ff_file_prefix, "ffn_layer1.bias.b"), + {(int)params.hidden_dim_HD})); + ASSIGN_OR_RETURN( + feed_forward.layer_1_gate_bias, + LoadWeight(absl::StrCat(ff_file_prefix, "ffn_layer1_gate.bias.b"), + {(int)params.hidden_dim_HD})); + ASSIGN_OR_RETURN( + feed_forward.layer_2_bias, + LoadWeight(absl::StrCat(ff_file_prefix, "ffn_layer2.bias.b"), + {(int)params.model_dim_D})); + } + + return feed_forward; +} + +absl::StatusOr LlmWeightsLoader::LoadSelfAttention( + int layer_id) { + const auto ¶ms = params_; + SelfAttentionWeights self_attention; + + auto sa_file_prefix = + absl::StrCat("params.lm.transformer.x_layers_", layer_id); + + ASSIGN_OR_RETURN( + self_attention.pre_norm_weight, + LoadNormWeights(params.sa_params.pre_norm, + absl::StrCat(sa_file_prefix, ".pre_layer_norm"))); + ASSIGN_OR_RETURN( + self_attention.post_norm_weight, + LoadNormWeights(params.sa_params.post_norm, + absl::StrCat(sa_file_prefix, ".post_layer_norm"))); + + absl::StrAppend(&sa_file_prefix, ".self_attention."); + + ASSIGN_OR_RETURN( + self_attention.k_weight, + TryCacheThenLoadSelfAttention(absl::StrCat(sa_file_prefix, "k.w"), + absl::StrCat(sa_file_prefix, "k.linear.w"), + /*is_query=*/false)); + ASSIGN_OR_RETURN( + self_attention.q_weight, + TryCacheThenLoadSelfAttention(absl::StrCat(sa_file_prefix, "q.w"), + absl::StrCat(sa_file_prefix, "q.linear.w"), + /*is_query=*/true)); + ASSIGN_OR_RETURN( + self_attention.v_weight, + TryCacheThenLoadSelfAttention(absl::StrCat(sa_file_prefix, "v.w"), + absl::StrCat(sa_file_prefix, "v.linear.w"), + /*is_query=*/false)); + + if (!params.sa_params.qkv_no_bias) { + ASSIGN_OR_RETURN( + self_attention.q_bias, + LoadWeight(absl::StrCat(sa_file_prefix, "q.bias.b"), + {(int)params.n_heads_N * (int)params.head_dim_H})); + ASSIGN_OR_RETURN( + self_attention.k_bias, + LoadWeight(absl::StrCat(sa_file_prefix, "k.bias.b"), + {(int)params.n_heads_N * (int)params.head_dim_H})); + ASSIGN_OR_RETURN( + self_attention.v_bias, + LoadWeight(absl::StrCat(sa_file_prefix, "v.bias.b"), + {(int)params.n_heads_N * (int)params.head_dim_H})); + } + + if (params.sa_params.attention_scale_type == + LlmParams::AttentionScaleType::PER_DIM_SCALE) { + ASSIGN_OR_RETURN( + self_attention.per_dim_scale, + LoadWeight(absl::StrCat(sa_file_prefix, "per_dim_scale.per_dim_scale"), + {(int)params.head_dim_H})); + } + ASSIGN_OR_RETURN(self_attention.post_proj_weight, + LoadWeight(absl::StrCat(sa_file_prefix, "post.w"), + {(int)params.model_dim_D, + (int)params.n_heads_N * (int)params.head_dim_H}, + /*dim_scale_if_any=*/0)); + if (!self_attention.post_proj_weight.weights.data()) { + ASSIGN_OR_RETURN( + self_attention.post_proj_weight, + LoadWeight(absl::StrCat(sa_file_prefix, "post.linear.w"), + {(int)params.model_dim_D, + (int)params.n_heads_N * (int)params.head_dim_H}, + 
/*dim_scale_if_any=*/0)); + } + if (!params.sa_params.post_proj_no_bias) { + ASSIGN_OR_RETURN(self_attention.post_proj_bias, + LoadWeight(absl::StrCat(sa_file_prefix, "post.bias.b"), + {(int)params.model_dim_D})); + } + + return self_attention; +} + +absl::StatusOr LlmWeightsLoader::LoadWeights() { + LlmWeights result; + + for (int layer_id = 0; layer_id < params_.num_transformer_M; ++layer_id) { + FeedForwardWeights ff; + ASSIGN_OR_RETURN(ff, LoadFeedForward(layer_id)); + result.ffs.push_back(std::move(ff)); + SelfAttentionWeights sa; + ASSIGN_OR_RETURN(sa, LoadSelfAttention(layer_id)); + result.sas.push_back(std::move(sa)); + } + + ASSIGN_OR_RETURN(result.final_norm_weight, + LoadNormWeights(params_.final_norm, "params.lm.final_ln")); + + ASSIGN_OR_RETURN(result.softmax_linear, + LoadTransposedWeight( + "params.lm.softmax.logits_ffn.w", + {(int)params_.model_dim_D, (int)params_.voc_size_V}, 1)); + if (!result.softmax_linear.weights.data()) { + ASSIGN_OR_RETURN( + result.softmax_linear, + LoadTransposedWeight( + "params.lm.softmax.logits_ffn.linear.w", + {(int)params_.model_dim_D, (int)params_.voc_size_V}, 1)); + } + if (!params_.final_proj_params.no_bias) { + ASSIGN_OR_RETURN(result.softmax_bias, + LoadWeight("params.lm.softmax.logits_ffn.bias.b", + {(int)params_.voc_size_V})); + } + + ASSIGN_OR_RETURN( + result.token_embedding, + LoadWeight("params.lm.token_embedding.w", + {(int)params_.voc_size_V, (int)params_.model_dim_D})); + + result.data_holder = mapped_file_; + return result; +} + +} // namespace + +absl::StatusOr LoadLlmWeights(absl::string_view tflite_path, const LlmParams ¶ms) { + LlmWeightsLoader loader(tflite_path, params); + return loader.LoadWeights(); +} + +} // namespace hallmark diff --git a/apps/hallmark/contrib/weights_loader.h b/apps/hallmark/contrib/weights_loader.h new file mode 100644 index 000000000000..6bafcc9fa732 --- /dev/null +++ b/apps/hallmark/contrib/weights_loader.h @@ -0,0 +1,20 @@ +#ifndef HALIDE_APPS_HALLMARK_WEIGHTS_LOADER_H_ +#define HALIDE_APPS_HALLMARK_WEIGHTS_LOADER_H_ + +#include "HalideBuffer.h" +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "contrib/llm_params.h" +#include "contrib/llm_weights.h" + +namespace hallmark { + +absl::StatusOr LoadLlmWeights(absl::string_view tflite_path, const LlmParams ¶ms); + +} // namespace hallmark + +#endif // HALIDE_APPS_HALLMARK_WEIGHTS_LOADER_H_ diff --git a/apps/hallmark/src/CMakeLists.txt b/apps/hallmark/src/CMakeLists.txt new file mode 100644 index 000000000000..eb0fe3f590f1 --- /dev/null +++ b/apps/hallmark/src/CMakeLists.txt @@ -0,0 +1,149 @@ +add_subdirectory(ml_ops) + +add_halide_generator(llm.generator llm_generator.cpp) +target_link_libraries(llm.generator PRIVATE hallmark_ml_ops absl::log) + +add_halide_library( + hallmark_rope_values FROM llm.generator + FUNCTION_NAME rope_values + GENERATOR LlmRoPEValues + PARAMS + head_dim_H=256 + processing_type=float32 + NAMESPACE hallmark + FEATURES c_plus_plus_name_mangling +) + +add_halide_library( + hallmark_preprocessor FROM llm.generator + FUNCTION_NAME preprocessor + GENERATOR LlmPreprocessor + PARAMS + model_dim_D=2048 + skip_absolute_positional_embeddings=true + processing_type=float32 + NAMESPACE hallmark + FEATURES c_plus_plus_name_mangling +) + +add_halide_library( + hallmark_transformer_no_kv_cache FROM llm.generator + FUNCTION_NAME transformer_no_kv_cache + GENERATOR 
LlmTransformer + PARAMS + seq_size_T=512 + model_dim_D=2048 + head_dim_H=256 + hidden_dim_HD=16384 + transformer_kind=prefix_only_uncached + processing_type=float32 + sa_pre_norm=rms + sa_post_norm=none + feedforward_pre_norm=rms + feedforward_post_norm=none + attention_scale_type=inverse_sqrt_head_dim + use_mqa=false + soft_cap=0.0 + feed_forward_params_activation=gelu + NAMESPACE hallmark + FEATURES c_plus_plus_name_mangling +) + +add_halide_library( + hallmark_transformer_kv_update_cache FROM llm.generator + FUNCTION_NAME transformer_kv_update_cache + GENERATOR LlmTransformer + PARAMS + seq_size_T=512 + model_dim_D=2048 + hidden_dim_HD=16384 + head_dim_H=256 + transformer_kind=prefix_decode_update_cache + processing_type=float32 + sa_pre_norm=rms + sa_post_norm=none + feedforward_pre_norm=rms + feedforward_post_norm=none + attention_scale_type=inverse_sqrt_head_dim + use_mqa=false + soft_cap=0.0 + feed_forward_params_activation=gelu + NAMESPACE hallmark + FEATURES c_plus_plus_name_mangling +) + +add_halide_library( + hallmark_transformer_kv_use_cache FROM llm.generator + FUNCTION_NAME transformer_kv_use_cache + GENERATOR LlmTransformer + PARAMS + seq_size_T=512 + model_dim_D=2048 + hidden_dim_HD=16384 + head_dim_H=256 + transformer_kind=prefix_decode_use_cache + processing_type=float32 + sa_pre_norm=rms + sa_post_norm=none + feedforward_pre_norm=rms + feedforward_post_norm=none + attention_scale_type=inverse_sqrt_head_dim + use_mqa=false + soft_cap=0.0 + feed_forward_params_activation=gelu + NAMESPACE hallmark + FEATURES c_plus_plus_name_mangling +) + +add_halide_library( + hallmark_postprocessor FROM llm.generator + FUNCTION_NAME postprocessor + GENERATOR LlmPostprocessor + PARAMS + seq_size_T=512 + model_dim_D=2048 + head_dim_H=256 + voc_size_V=256000 + NAMESPACE hallmark + FEATURES c_plus_plus_name_mangling +) + +add_halide_library( + hallmark_position_embedding FROM llm.generator + FUNCTION_NAME position_embedding + GENERATOR LlmPositionEmbedding +) + +# -------------------- + +add_library(hallmark_llm + llm.cpp) +target_include_directories(hallmark_llm PRIVATE + $) +target_link_libraries(hallmark_llm + PRIVATE + absl::status + hallmark_contrib + hallmark_position_embedding + hallmark_postprocessor + hallmark_preprocessor + hallmark_rope_values + hallmark_transformer_kv_update_cache + hallmark_transformer_kv_use_cache + hallmark_transformer_no_kv_cache) + +add_executable(llm_runner llm_runner.cpp) +target_link_libraries(llm_runner + PRIVATE + absl::flags + absl::flags_parse + hallmark_contrib + hallmark_llm + hallmark_position_embedding + hallmark_postprocessor + hallmark_preprocessor + hallmark_rope_values + hallmark_transformer_kv_update_cache + hallmark_transformer_kv_use_cache + sentencepiece) +target_include_directories(llm_runner PUBLIC ${sentencepiece_SOURCE_DIR}) diff --git a/apps/hallmark/src/llm.cpp b/apps/hallmark/src/llm.cpp new file mode 100644 index 000000000000..e767aa036551 --- /dev/null +++ b/apps/hallmark/src/llm.cpp @@ -0,0 +1,699 @@ +#include "src/llm.h" + +#include + +#include "absl/flags/flag.h" +#include "absl/log/check.h" +#include "contrib/status_helpers.h" +#include "hallmark_position_embedding.h" +#include "hallmark_postprocessor.h" +#include "hallmark_preprocessor.h" +#include "hallmark_rope_values.h" +#include "hallmark_transformer_kv_update_cache.h" +#include "hallmark_transformer_kv_use_cache.h" +#include "hallmark_transformer_no_kv_cache.h" + +#define DUMP_INFO_TO_STDOUT 0 + +namespace hallmark { + +namespace { + +void dump_segpos(const float *data, 
size_t n) { + for (size_t i = 0; i < n; i++) { + std::cout << "data[" << i << "] = " << data[i] << "\n"; + } +} + +void do_indent(int indent) { + while (indent-- > 0) { + std::cout << "\t"; + } +} + +void PrintBuffer(const std::string &base_name, const Halide::Runtime::Buffer<> &buf, const char *end_of_line = "\n") { +#if DUMP_INFO_TO_STDOUT + std::cout << base_name << ": ["; + const char *prefix = ""; + for (int32_t i = 0; i < buf.dimensions(); i++) { + std::cout << prefix << "{" << buf.dim(i).min() << ", " + << buf.dim(i).extent() << "}"; + prefix = ", "; + } + std::cout << "]" << end_of_line + << std::flush; +#endif +} + +void DumpFloatBuffer(const std::string &base_name, const Halide::Runtime::Buffer<> &buf, + int dim0_count, int dim1_count = 1) { +#if DUMP_INFO_TO_STDOUT + PrintBuffer(base_name, buf); + Halide::Runtime::Buffer temp_buf = buf; + float *data = temp_buf.data(); + for (int j = temp_buf.dim(1).min(); + j < + std::min(temp_buf.dim(1).max() + 1, temp_buf.dim(1).min() + dim1_count); + j++) { + std::cout << "Start of dump for " << base_name << " (0, " << j << ") " + << /*data <<*/ ":\n"; + for (int i = temp_buf.dim(0).min(); + i < std::min(temp_buf.dim(0).max() + 1, + temp_buf.dim(0).min() + dim0_count); + i++) { + std::cout << "data[" << i << "] = " << data[i] << "\n"; + } + std::cout << "End of dump for " << base_name << " (0, " << j << "):\n"; + data += temp_buf.dim(1).stride(); + } + std::cout << std::flush; +#endif +} + +void PrintTensorInfo(int indent, const char *label, const ScaledTensor &tensor, const char *end_of_line = "\n") { + do_indent(indent); + PrintBuffer(std::string(label) + " weights: ", tensor.weights, ""); + PrintBuffer(" scale: ", tensor.scale, ""); + std::cout << " dim_scale: " << tensor.dim_scale << end_of_line; +} + +void PrintNormWeightInfo(int indent, const char *label, const std::optional &norm_weights) { + do_indent(indent); + if (norm_weights) { + switch (norm_weights->index()) { + case 0: { + const auto &rms_weight = std::get<0>(*norm_weights); + std::cout << label << ": RMS Norm "; + PrintTensorInfo(0, "", rms_weight.norm_weight, ""); + } break; + case 1: { + const auto &layer_weight = std::get<1>(*norm_weights); + std::cout << label << ": Layer Norm epsilon: " << layer_weight.epsilon << " gamma: "; + PrintTensorInfo(0, "", layer_weight.gamma, ""); + std::cout << " beta: "; + PrintTensorInfo(0, "", layer_weight.beta, ""); + } break; + default: + std::cout << label << " "; + break; + } + } else { + std::cout << label << ": "; + } + std::cout << "\n"; +} + +void PrintInFloatBuffer2D(const std::string &base_name, + const Halide::Runtime::Buffer<> &buf) { +#if DUMP_INFO_TO_STDOUT + PrintBuffer(base_name, buf); + const Halide::Runtime::Buffer &fp_buf = buf; + if (fp_buf.dim(0).extent() > 0) { + std::cout << base_name << "[0, 0] : " << fp_buf(0, 0) << "\n"; + } else { + std::cout << base_name << ": empty\n"; + return; + } + if (fp_buf.dim(0).extent() > 1) { + int32_t index = fp_buf.dim(0).extent() - 1; + std::cout << base_name << "[" << index << ", 0] : " << fp_buf(index, 0) + << "\n"; + } + if (fp_buf.dim(1).extent() > 0) { + std::cout << base_name << "[0, 1] : " << fp_buf(0, 1) << "\n"; + if (fp_buf.dim(0).extent() > 1) { + int32_t index = fp_buf.dim(0).extent() - 1; + std::cout << base_name << "[" << index << ", 1] : " << fp_buf(index, 1) + << "\n"; + } + } + if (fp_buf.dim(1).extent() > 1) { + int32_t index_outer = fp_buf.dim(1).extent() - 1; + std::cout << base_name << "[0, " << index_outer + << "] : " << fp_buf(0, index_outer) << "\n"; + if 
(fp_buf.dim(0).extent() > 1) { + int32_t index_inner = fp_buf.dim(0).extent() - 1; + std::cout << base_name << "[" << index_inner << ", " << index_outer + << "] : " << fp_buf(index_inner, index_outer) << "\n"; + } + } + std::cout << std::flush; +#endif +} + +void PrintInFloatBuffer(const std::string &base_name, const Halide::Runtime::Buffer<> &buf) { +#if DUMP_INFO_TO_STDOUT + PrintBuffer(base_name, buf); + const Halide::Runtime::Buffer &fp_buf = buf; + if (fp_buf.dim(0).extent() > 0) { + std::cout << base_name << "[0, 0, 0] : " << fp_buf(0, 0, 0) << "\n"; + } else { + std::cout << base_name << ": empty\n"; + return; + } + if (fp_buf.dim(0).extent() > 1) { + int32_t index = fp_buf.dim(0).extent() - 1; + std::cout << base_name << "[" << index + << ", 0, 0] : " << fp_buf(index, 0, 0) << "\n"; + } + if (fp_buf.dim(1).extent() > 0) { + std::cout << base_name << "[0, 1, 0] : " << fp_buf(0, 1, 0) << "\n"; + if (fp_buf.dim(0).extent() > 1) { + int32_t index = fp_buf.dim(0).extent() - 1; + std::cout << base_name << "[" << index + << ", 1, 0] : " << fp_buf(index, 1, 0) << "\n"; + } + } + std::cout << std::flush; +#endif +} + +} // anonymous namespace + +void do_indent(int indent) { + while (indent-- > 0) { + std::cout << "\t"; + } +} + +void Llm::PrintParamsAndWeights() const { +#if DUMP_INFO_TO_STDOUT + std::cout << "LLM Params:\n\t"; + std::cout << "num_transformer_M: " << llm_params_.num_transformer_M << "\n\t"; + std::cout << "batch_size_B: " << llm_params_.batch_size_B << "\n\t"; + std::cout << "seq_size_T: " << llm_params_.seq_size_T << "\n\t"; + std::cout << "model_dim_D: " << llm_params_.model_dim_D << "\n\t"; + std::cout << "hidden_dim_HD: " << llm_params_.hidden_dim_HD << "\n\t"; + std::cout << "head_dim_H: " << llm_params_.head_dim_H << "\n\t"; + std::cout << "n_heads_N: " << llm_params_.n_heads_N << "\n\t"; + std::cout << "voc_size_V: " << llm_params_.voc_size_V << "\n\t"; + std::cout << "num_kv_heads: " << llm_params_.num_kv_heads << "\n\t"; + std::cout << "model_type: " << static_cast(llm_params_.model_type) + << "\n\t"; + std::cout << "skip_absolute_positional_embeddings: " + << llm_params_.skip_absolute_positional_embeddings << "\n\t"; + std::cout << "sa_params:" + << "\n\t\t"; + std::cout << "qkv_no_bias: " << llm_params_.sa_params.qkv_no_bias << "\n\t\t"; + std::cout << "post_proj_no_bias: " << llm_params_.sa_params.post_proj_no_bias + << "\n\t\t"; + std::cout << "pre_norm: " << static_cast(llm_params_.sa_params.pre_norm) + << "\n\t\t"; + std::cout << "post_norm: " + << static_cast(llm_params_.sa_params.post_norm) << "\n\t\t"; + std::cout << "soft_cap_value: " << llm_params_.sa_params.soft_cap_value + << "\n\t\t"; + std::cout << "attention_scale_type: " + << static_cast(llm_params_.sa_params.attention_scale_type) + << "\n\t"; + std::cout << "ff_params:" + << "\n\t\t"; + std::cout << "no_bias: " << llm_params_.ff_params.no_bias << "\n\t\t"; + std::cout << "activation: " + << static_cast(llm_params_.ff_params.activation) << "\n\t\t"; + std::cout << "pre_norm: " << static_cast(llm_params_.ff_params.pre_norm) + << "\n\t\t"; + std::cout << "post_norm: " + << static_cast(llm_params_.ff_params.post_norm) << "\n\t"; + std::cout << "final_norm: " << static_cast(llm_params_.final_norm) + << "\n\t"; + std::cout << "final_proj_params:" + << "\n\t\t"; + std::cout << "no_bias: " << llm_params_.final_proj_params.no_bias << "\n\t"; + std::cout << "enable_kv_cache: " << llm_params_.enable_kv_cache << "\n\t"; + std::cout << "enable_dynamic_shape: " << llm_params_.enable_dynamic_shape + << "\n\t"; + 
std::cout << "Weights Info:\n"; + for (const auto &sa : llm_weights_.sas) { + std::cout << "\tSelf Attention:\n"; + PrintNormWeightInfo(2, "pre_norm_weight", sa.pre_norm_weight); + PrintTensorInfo(2, "k_weight", sa.k_weight); + PrintTensorInfo(2, "k_bias", sa.k_bias); + PrintTensorInfo(2, "q_weight", sa.q_weight); + PrintTensorInfo(2, "q_bias", sa.q_bias); + PrintTensorInfo(2, "v_weight", sa.v_weight); + PrintTensorInfo(2, "v_bias", sa.v_bias); + PrintTensorInfo(2, "per_dim_scale", sa.per_dim_scale); + PrintTensorInfo(2, "post_proj_weight", sa.post_proj_weight); + PrintTensorInfo(2, "post_proj_bias", sa.post_proj_bias); + PrintNormWeightInfo(2, "post_norm_weight", sa.post_norm_weight); + } + + for (const auto &ff : llm_weights_.ffs) { + std::cout << "\tFeed Forward:\n"; + PrintNormWeightInfo(2, "pre_norm_weight", ff.pre_norm_weight); + PrintTensorInfo(2, "layer_1_weight", ff.layer_1_weight); + PrintTensorInfo(2, "layer_1_bias", ff.layer_1_bias); + PrintTensorInfo(2, "layer_1_gate_weight", ff.layer_1_gate_weight); + PrintTensorInfo(2, "layer_1_gate_bias", ff.layer_1_gate_bias); + PrintTensorInfo(2, "layer_2_weight", ff.layer_2_weight); + PrintTensorInfo(2, "layer_2_bias", ff.layer_2_bias); + PrintNormWeightInfo(2, "post_norm_weight", ff.post_norm_weight); + } + PrintNormWeightInfo(1, "final_norm_weight", llm_weights_.final_norm_weight); + PrintTensorInfo(1, "softmax_linear", llm_weights_.softmax_linear); + PrintTensorInfo(1, "softmax_bias", llm_weights_.softmax_bias); + PrintTensorInfo(1, "token_embedding", llm_weights_.token_embedding); +#endif +} + +absl::StatusOr> Llm::CreateLlm( + const LlmWeights &llm_weights, const LlmParams &llm_params) { + auto llm = std::make_unique(); + llm->llm_params_ = llm_params; + llm->llm_weights_ = llm_weights; + + const int top_k = 0; + const float top_p = 0.f; + const float temperature = 0.f; + const int seed = 0; + auto s = Sampler::Create(Sampler::Type::kGreedy, top_k, + top_p, temperature, seed); + if (!s.ok()) { + return s.status(); + } + llm->sampler_ = std::move(s.value()); + + llm->PrintParamsAndWeights(); + + return llm; +} + +// TODO: convert this to Halide? +absl::StatusOr> ConvertToF32( + const ScaledTensor &in) { + if (in.weights.type().code == halide_type_float && + in.weights.type().bits == 32) { + return in.weights; + } + if (in.weights.type().code == halide_type_int) { + if (in.dim_scale != 0 && in.dim_scale != 1) { + return absl::InvalidArgumentError("Unsupported dim_scale"); + } + if (in.scale.type() != halide_type_t(halide_type_float, 32)) { + return absl::InvalidArgumentError("Unsupported scale type"); + } + auto b = Halide::Runtime::Buffer::make_with_shape_of(in.weights); + if (in.weights.type().bits == 8) { + auto &w = in.weights.as(); + auto &s = in.scale.as(); + if (in.dim_scale == 1) { + b.for_each_element([&](int x, int y) { b(x, y) = w(x, y) * s(x); }); + } else { + b.for_each_element([&](int x, int y) { b(x, y) = w(x, y) * s(y); }); + } + return b; + } else if (in.weights.type().bits == 4) { + return absl::InvalidArgumentError("TODO support scaled int4 here"); + } + // else fall thru + } + return absl::InvalidArgumentError("Unsupported scaled type"); +} + +absl::Status Llm::Reset() { + prev_ids_.clear(); + last_kv_cache_start_ = 0; + attention_mask_values_ = Halide::Runtime::Buffer<>(); + position_embedding_values_ = Halide::Runtime::Buffer<>(); + segment_pos_values_ = Halide::Runtime::Buffer(llm_params_.head_dim_H, + llm_params_.seq_size_T); + // TODO: This will be potentially large. 
Though probably not onerously so + // compared to weights. Halide currently doesn't support sparse buffers, but + // it might be possible to use extern calls to get slices of the cache, + // which might allow using a non-contiguous representation. + kv_cache_.clear(); + kv_cache_.resize(llm_params_.num_transformer_M); + for (auto &entry : kv_cache_) { + auto k_cache = Halide::Runtime::Buffer( + llm_params_.head_dim_H, + 1, // llm_params_.model_dim_D / llm_params_.head_dim_H, + llm_params_.seq_size_T, llm_params_.batch_size_B); + k_cache.fill(0.0f); + entry.k_cache = k_cache; + auto v_cache = Halide::Runtime::Buffer( + llm_params_.head_dim_H, + 1, // llm_params_.model_dim_D / llm_params_.head_dim_H, + llm_params_.seq_size_T, llm_params_.batch_size_B); + v_cache.fill(0.0f); + entry.v_cache = v_cache; + } + auto s = ConvertToF32(llm_weights_.softmax_linear); + if (!s.ok()) { + return s.status(); + } + softmax_linear_f32_ = std::move(s.value()); + return absl::OkStatus(); +} + +absl::Status Llm::InitAttentionMaskValues(size_t process_seq_len) { + const auto &seq_size = llm_params_.seq_size_T; + constexpr float neg_value = 0.5 * std::numeric_limits::lowest(); + Halide::Runtime::Buffer attention_mask_values(seq_size, seq_size); + // TODO: Could be sped up as a Halide kernel. + switch (llm_params_.model_type) { + case LlmParams::ModelType::PREFIX: { + // std::cout << "InitAttentionMaskValues prefix\n"; + RET_CHECK(process_seq_len <= seq_size); + // Prefix full attention for all tokens within input ids size(input), + // and causal attention mask for all following tokens. + for (int i = 0; i < seq_size; ++i) { + for (int j = 0; j < seq_size; ++j) { + attention_mask_values(j, i) = + (j <= i || std::max(j, i) < process_seq_len) ? 0.0f : neg_value; + } + } + break; + } + case LlmParams::ModelType::CAUSAL: { + // std::cout << "InitAttentionMaskValues causal\n"; + for (int i = 0; i < seq_size; ++i) { + for (int j = 0; j < seq_size; ++j) { + attention_mask_values(j, i) = (j <= i) ? 0 : neg_value; + } + } + break; + } + default: { + return absl::InvalidArgumentError( + absl::StrCat("Unsupported model type: ", llm_params_.model_type)); + } + } +#if DUMP_INFO_TO_STDOUT + std::cout << "AttentionMaskValues dims [" << seq_size << ", " << seq_size + << "]\n"; + std::cout << "AttentionMaskValues[0, 0]: " << attention_mask_values(0, 0) + << "\n"; + std::cout << "AttentionMaskValues[" << (seq_size - 1) + << ", 0]: " << attention_mask_values(seq_size - 1, 0) << "\n"; + std::cout << "AttentionMaskValues[0, 1]: " << attention_mask_values(0, 1) + << "\n"; + std::cout << "AttentionMaskValues[" << (seq_size - 1) + << ", 1]: " << attention_mask_values(seq_size - 1, 1) << "\n"; + std::cout << "AttentionMaskValues[0, " << (seq_size - 1) + << "]: " << attention_mask_values(0, seq_size - 1) << "\n"; + std::cout << "AttentionMaskValues[" << (seq_size - 1) << ", " + << (seq_size - 1) + << "]: " << attention_mask_values(seq_size - 1, seq_size - 1) + << "\n"; +#endif + attention_mask_values_ = attention_mask_values; + return absl::OkStatus(); +} + +Halide::Runtime::Buffer<> Llm::AllocateSeqBuffer(int current_seq_size) { + int seq_len = llm_params_.enable_dynamic_shape ? current_seq_size : llm_params_.seq_size_T; + auto result = Halide::Runtime::Buffer(llm_params_.model_dim_D, seq_len, + llm_params_.batch_size_B); + result.fill(0.0f); + return result; +} + +// TODO: Rewrite this whole operation in Halide. 
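One possible shape for that rewrite, purely illustrative and not part of this patch: the memcpy loop in UpdateInput just below is a gather of embedding rows by token id, which maps naturally onto a small Halide generator. All names and the schedule in this sketch are hypothetical.

```cpp
// Hypothetical sketch only: the embedding gather from UpdateInput() as a
// Halide generator. Layout follows the rest of this file: embeddings are
// [model_dim_D, voc_size_V], the result is [model_dim_D, seq, batch].
#include "Halide.h"

class TokenEmbeddingGather : public Halide::Generator<TokenEmbeddingGather> {
public:
    Input<Buffer<int32_t, 1>> token_ids{"token_ids"};
    Input<Buffer<float, 2>> embedding{"embedding"};
    Output<Buffer<float, 3>> result{"result"};

    void generate() {
        Var d("d"), t("t"), b("b");
        Expr id = clamp(token_ids(t), 0, embedding.dim(1).extent() - 1);
        result(d, t, b) = embedding(d, id);
        result.vectorize(d, natural_vector_size<float>()).parallel(t);
    }
};
HALIDE_REGISTER_GENERATOR(TokenEmbeddingGather, token_embedding_gather)
```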
+absl::Status Llm::UpdateInput(const std::vector &input_ids) { +#if DUMP_INFO_TO_STDOUT + for (size_t i = 0; i < input_ids.size(); i++) { + std::cout << "UpdateInput Token " << (i + prev_ids_.size()) << ": " << input_ids[i] << "\n" + << std::flush; + } +#endif + // At present prev_ids_ is always empty at entry, but it seems the + // design is intended to support some sort of incremental operation. + RET_CHECK(input_ids.size() + prev_ids_.size() <= llm_params_.seq_size_T); + if (llm_weights_.token_embedding.weights.data()) { + PrintBuffer("token_embedding_", llm_weights_.token_embedding.weights); + } + PrintInFloatBuffer("softmax_linear_f32_", softmax_linear_f32_); + auto token_embedding = llm_weights_.token_embedding.weights.data() ? llm_weights_.token_embedding.weights : softmax_linear_f32_; + RET_CHECK(token_embedding.dim(1).extent() == llm_params_.voc_size_V); + RET_CHECK(token_embedding.dim(0).extent() == llm_params_.model_dim_D); + // TODO: Support conversion. + RET_CHECK(token_embedding.type().code == halide_type_float && + token_embedding.type().bits == 32); + Halide::Runtime::Buffer float_token_embedding = token_embedding; + Halide::Runtime::Buffer float_input = *transformer_input_; + size_t base_id = prev_ids_.size(); + for (size_t batch = 0; batch < llm_params_.batch_size_B; ++batch) { + for (size_t id = 0; id < input_ids.size(); id++) { + memcpy(&float_input(0, static_cast(base_id + id), batch), + &float_token_embedding(0, input_ids[id]), + llm_params_.model_dim_D * sizeof(float)); + } + } + PrintInFloatBuffer("float_token_embedding", float_token_embedding); + PrintInFloatBuffer("transformer_input_", *transformer_input_); + prev_ids_.insert(prev_ids_.end(), input_ids.begin(), input_ids.end()); + // prev_id.size - 1 is the output. + return absl::OkStatus(); +} + +absl::Status Llm::InitInputTokens(const std::vector &input_ids) { + RETURN_IF_ERROR(Reset()); + RETURN_IF_ERROR(InitAttentionMaskValues(input_ids.size())); + + if (!llm_params_.skip_absolute_positional_embeddings) { + std::cout << "Initing pos_embedding.\n"; + pos_embedding_ = Halide::Runtime::Buffer( + static_cast(llm_params_.model_dim_D), + static_cast(llm_params_.seq_size_T)); + int32_t input_length; + switch (llm_params_.model_type) { + case LlmParams::ModelType::PREFIX: + input_length = input_ids.size(); + break; + case LlmParams::ModelType::CAUSAL: + input_length = prev_ids_.size(); + break; + default: + return absl::InvalidArgumentError( + absl::StrCat("Unsupported model type: ", llm_params_.model_type)); + } + RETURN_IF_ERROR(StatusFromHalide(position_embedding( + input_length, llm_params_.seq_size_T, llm_params_.model_dim_D, 1.0f, + 10000.0f, pos_embedding_))); + } + + RETURN_IF_ERROR(StatusFromHalide(rope_values(segment_pos_values_))); + PrintInFloatBuffer2D("segment_pos_values_", segment_pos_values_); + + // Prepare input from ids and token embedding. + // TODO: Do we need to resize input here? 
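As an aside for readers new to Halide buffers: the layout used by AllocateSeqBuffer (called just below the first time transformer_input_ is needed) is innermost-dimension-first, the reverse of the [batch, seq, feature] order of the TFLite tensors, which is also why BuildWeightsMapFromTfliteModel reverses dims when wrapping weights. A tiny sketch, with sizes mirroring the generator parameters in src/CMakeLists.txt (model_dim_D=2048, seq_size_T=512) and an assumed batch size of 1:

```cpp
// Sketch: Halide::Runtime::Buffer dimension order as used throughout this
// file. A logical [batch_size_B, seq_size_T, model_dim_D] activation is
// allocated innermost-first as (D, T, B).
#include "HalideBuffer.h"

void ShapeConventionExample() {
    Halide::Runtime::Buffer<float> x(2048, 512, 1);  // (D, T, B)
    x.fill(0.0f);
    x(5, 3, 0) = 1.0f;  // feature 5 of token 3 in batch 0
}
```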
+ if (!transformer_input_) { + transformer_input_ = + AllocateSeqBuffer(static_cast(llm_params_.seq_size_T)); + } + + RETURN_IF_ERROR(UpdateInput(input_ids)); + + if (llm_params_.enable_kv_cache) { + RETURN_IF_ERROR(GetNextToken(&saved_token_)); + // std::cout << "Saved token is: " << saved_token_[0] << "\n"; + } + + return absl::OkStatus(); +} + +absl::Status Llm::GetNextToken(std::vector *output_ids) { + if (!saved_token_.empty()) { + *output_ids = std::move(saved_token_); + saved_token_.clear(); + return absl::OkStatus(); + } + + if (prev_ids_.size() >= llm_params_.seq_size_T - 1) { + return absl::OutOfRangeError( + absl::StrCat("Hit max sequence length ", llm_params_.seq_size_T)); + } + + // PrefixDecodeLlm::GetNextToken. + RETURN_IF_ERROR(Run()); + + RET_CHECK(logits_output_); + CHECK_EQ(logits_output_->number_of_elements(), llm_params_.voc_size_V); + PrintBuffer("logits_output_", *logits_output_); + + auto o = sampler_->Sample(logits_output_->as()); + if (!o.ok()) { + return o.status(); + } + *output_ids = std::move(o.value()); + RET_CHECK(output_ids != nullptr && output_ids->size() == 1); + +#if DUMP_INFO_TO_STDOUT + std::cout << "Output ID size is " << output_ids->size() << " is " + << output_ids->at(0) << "\n"; +#endif + + return UpdateInput(*output_ids); +} + +absl::Status Llm::RunStack(Llm::TempBuffers &buffers) { + int decode_step = prev_ids_.size(); + int run_extent = decode_step - last_kv_cache_start_; + +#if DUMP_INFO_TO_STDOUT + std::cout << "Llm::RunStack: Decode step " << decode_step << " run_extent " + << run_extent << " llm_params.enable_dynamic_shape " + << llm_params_.enable_dynamic_shape << "\n"; + { + // std::cout << "Llm::RunStack transformer_input_" << *transformer_input_ + // << "\n"; + Halide::Runtime::Buffer temp_buf_out = *transformer_input_; + DumpFloatBuffer("transformer_input_", *transformer_input_, 16, + decode_step + 1); +#if 0 + PrintBuffer("transformer_input_", temp_buf_out); + int start = llm_params_.enable_kv_cache ? prev_ids.size() : 0; + int end = llm_params_.enable_kv_cache ? 
1 : prev_ids.size(); + for (int tok = start; tok < end; tok++) { + std::cout << "Llm::GetNextToken start transformer_input_ tokens " << tok << "\n"; + dump_segpos(&temp_buf_out(0, tok, 0), 32); + std::cout << "Llm::GetNextToken end transformer_input_ tokens " << tok << "\n"; + } +#endif + } +#endif + + if (llm_params_.enable_kv_cache) { + buffers.FocusSeqDimCrop(last_kv_cache_start_, run_extent); + } else { + buffers.FocusSeqDimCrop(0, prev_ids_.size()); + } + + RETURN_IF_ERROR(StatusFromHalide( + preprocessor(*transformer_input_, buffers.StartInput()))); + + DumpFloatBuffer("start_input", buffers.StartInput(), 16, 2); + + if (llm_params_.enable_kv_cache) { + Halide::Runtime::Buffer<> attention_slice = + attention_mask_values_.cropped(1, last_kv_cache_start_, run_extent); + PrintBuffer("attention_slice", attention_slice); + for (int i = 0; i < llm_params_.num_transformer_M; i++) { + auto &sas = llm_weights_.sas[i]; + auto &ffs = llm_weights_.ffs[i]; + auto key_slice = + kv_cache_[i].k_cache.cropped(2, last_kv_cache_start_, run_extent); + auto value_slice = + kv_cache_[i].v_cache.cropped(2, last_kv_cache_start_, run_extent); + +#if DUMP_INFO_TO_STDOUT + std::cout << "Compute output step " << i << "\n"; +#endif + DumpFloatBuffer("Compute enable_kv_cache input", buffers.CurrentInput(), + 16); + // DumpFloatBuffer("Compute output k_cache", kv_cache_[i].k_cache, 16); + // DumpFloatBuffer("Compute output v_cache", kv_cache_[i].v_cache, 16); + DumpFloatBuffer("Compute output attention_slice", attention_slice, 16); + + RETURN_IF_ERROR(StatusFromHalide(transformer_kv_update_cache( + buffers.CurrentInput(), segment_pos_values_, attention_slice, + std::get<0>(*(sas.pre_norm_weight)).norm_weight.weights, + sas.k_weight.weights, sas.k_weight.scale, sas.q_weight.weights, + sas.q_weight.scale, sas.v_weight.weights, sas.v_weight.scale, + sas.post_proj_weight.weights, sas.post_proj_weight.scale, key_slice, + value_slice))); +#if DUMP_INFO_TO_STDOUT + std::cout << "Done with transformer_kv_update_cache " << i << "\n"; +#endif + RETURN_IF_ERROR(StatusFromHalide(transformer_kv_use_cache( + buffers.CurrentInput(), segment_pos_values_, attention_slice, + std::get<0>(*(sas.pre_norm_weight)).norm_weight.weights, + sas.k_weight.weights, sas.k_weight.scale, sas.q_weight.weights, + sas.q_weight.scale, sas.v_weight.weights, sas.v_weight.scale, + sas.post_proj_weight.weights, sas.post_proj_weight.scale, + std::get<0>(*(ffs.pre_norm_weight)).norm_weight.weights, + ffs.layer_1_weight.weights, ffs.layer_1_weight.scale, + ffs.layer_1_gate_weight.weights, ffs.layer_1_gate_weight.scale, + ffs.layer_2_weight.weights, ffs.layer_2_weight.scale, + kv_cache_[i].k_cache, kv_cache_[i].v_cache, + buffers.CurrentOutput()))); + + DumpFloatBuffer("Compute output output", buffers.CurrentOutput(), 16); + buffers.Swap(); + } + last_kv_cache_start_ += run_extent; + } else { + for (int i = 0; i < llm_params_.num_transformer_M; i++) { + auto &sas = llm_weights_.sas[i]; + auto &ffs = llm_weights_.ffs[i]; + +#if DUMP_INFO_TO_STDOUT + std::cout << "Compute output step " << i << "\n"; +#endif + DumpFloatBuffer("Compute !enable_kv_cache input", buffers.CurrentInput(), + 16); + DumpFloatBuffer("Compute output attention_slice", attention_mask_values_, + 16); + // PrintBuffer("current output", buffers.CurrentOutput()); + + RETURN_IF_ERROR(StatusFromHalide(transformer_no_kv_cache( + buffers.CurrentInput(), segment_pos_values_, attention_mask_values_, + std::get<0>(*(sas.pre_norm_weight)).norm_weight.weights, + sas.k_weight.weights, 
sas.k_weight.scale, sas.q_weight.weights, + sas.q_weight.scale, sas.v_weight.weights, sas.v_weight.scale, + sas.post_proj_weight.weights, sas.post_proj_weight.scale, + std::get<0>(*(ffs.pre_norm_weight)).norm_weight.weights, + ffs.layer_1_weight.weights, ffs.layer_1_weight.scale, + ffs.layer_1_gate_weight.weights, ffs.layer_1_gate_weight.scale, + ffs.layer_2_weight.weights, ffs.layer_2_weight.scale, + buffers.CurrentOutput()))); + + DumpFloatBuffer("Compute output output", buffers.CurrentOutput(), 16); + buffers.Swap(); + } + } + + PrintInFloatBuffer("current_output_ after transformer stack", + buffers.CurrentOutput()); +#if 1 +#if DUMP_INFO_TO_STDOUT + Halide::Runtime::Buffer temp_buf_out = buffers.CurrentInput(); + std::cout << "Start of dump for transformer stack output:\n"; + dump_segpos(temp_buf_out.data(), 2048 * 3); + std::cout << "End of dump for transformer stack output\n"; +#endif +#endif + + // TODO: can free current_output_ here as not currently reused. + logits_output_ = Halide::Runtime::Buffer(llm_params_.voc_size_V, 1, + llm_params_.batch_size_B); + // Only compute logits for the last token. + + PrintBuffer("logits current input", buffers.CurrentInput()); + PrintBuffer("logits current output", buffers.CurrentOutput()); + + logits_output_->set_min(0, buffers.CurrentInput().dim(1).max(), 0); + // Postprocess + RETURN_IF_ERROR(StatusFromHalide( + postprocessor(buffers.CurrentInput(), + std::get<0>(*llm_weights_.final_norm_weight).norm_weight.weights, + llm_weights_.softmax_linear.weights, + llm_weights_.softmax_linear.scale, *logits_output_))); + +#if DUMP_INFO_TO_STDOUT + Halide::Runtime::Buffer temp_buf = *logits_output_; + std::cout << "Start of dump for logits output:\n"; + dump_segpos(temp_buf.data(), 2048 * 3); + std::cout << "End of dump for logits output\n"; +#endif // DUMP_INFO_TO_STDOUT + + return absl::OkStatus(); +} + +absl::Status Llm::Run() { + // Weights cache operations? + // KV cache? + + const int extent = transformer_input_->dim(1).extent(); + + TempBuffers buffers; + buffers.initial_input_full = AllocateSeqBuffer(extent); + buffers.buffers_full[0] = AllocateSeqBuffer(extent); + buffers.buffers_full[1] = AllocateSeqBuffer(extent); + buffers.FocusSeqDimCrop(0, extent); + + return RunStack(buffers); +} + +} // namespace hallmark diff --git a/apps/hallmark/src/llm.h b/apps/hallmark/src/llm.h new file mode 100644 index 000000000000..cb05f42ec364 --- /dev/null +++ b/apps/hallmark/src/llm.h @@ -0,0 +1,151 @@ +// TODO: license. +#ifndef HALIDE_APPS_HALLMARK_LLM_H_ +#define HALIDE_APPS_HALLMARK_LLM_H_ + +#include + +#include "HalideBuffer.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "contrib/llm_params.h" +#include "contrib/llm_weights.h" +#include "contrib/sampler.h" +#include "contrib/weights_loader.h" + +namespace hallmark { + +class Llm { +public: + Llm() = default; + Llm(Llm &&) = default; + ~Llm() = default; + + static absl::StatusOr> CreateLlm( + const LlmWeights &llm_weights, const LlmParams &llm_params); + + // (Re)Initialize with input token ids. This will reset the cache, mask etc. + absl::Status InitInputTokens(const std::vector &input_ids); + + // Get the next token id. + absl::Status GetNextToken(std::vector *output_ids); + + // The size of all tokens, including prompt and generated tokens. 
+ size_t TotalTokenSize() { + return prev_ids_.size(); + } + + const LlmParams &GetLlmParams() { + return llm_params_; + } + + // These are public only for test/benchmark purposes; don't use it elsewhere. + Halide::Runtime::Buffer<> AllocateSeqBuffer(int seq_size); + absl::Status Reset(); + absl::Status InitAttentionMaskValues(size_t process_seq_len); + +private: + struct TempBuffers { + Halide::Runtime::Buffer<> initial_input_full; + Halide::Runtime::Buffer<> buffers_full[2]; + Halide::Runtime::Buffer<> initial_input; + Halide::Runtime::Buffer<> buffers[2]; + bool first{true}; + int current_input{0}; + + void FocusSeqDimCrop(int min, int extent) { + initial_input = initial_input_full.cropped(1, min, extent); + buffers[0] = buffers_full[0].cropped(1, min, extent); + buffers[1] = buffers_full[1].cropped(1, min, extent); + } + + Halide::Runtime::Buffer<> &StartInput() { + return initial_input; + } + + Halide::Runtime::Buffer<> &CurrentInput() { + return first ? initial_input : buffers[current_input]; + } + + Halide::Runtime::Buffer<> &CurrentOutput() { + return first ? buffers[0] : buffers[current_input ^ 1]; + } + + void Swap() { + if (first) { + first = false; + } else { + current_input ^= 1; + } + } + + void ResetToStart() { + first = true; + current_input = 0; + } + }; + + absl::Status Run(); + absl::Status RunStack(TempBuffers &buffers); + absl::Status UpdateInput(const std::vector &ids); + void PrintParamsAndWeights() const; + + LlmWeights llm_weights_; + LlmParams llm_params_; + + std::unique_ptr sampler_; + +public: + const std::vector &ffs() const { + return llm_weights_.ffs; + } + const std::vector sas() const { + return llm_weights_.sas; + } + std::optional &final_norm_weight() { + return llm_weights_.final_norm_weight; + } + Halide::Runtime::Buffer<> &softmax_linear_weights() { + return llm_weights_.softmax_linear.weights; + } + Halide::Runtime::Buffer<> &softmax_linear_scale() { + return llm_weights_.softmax_linear.scale; + } + + Halide::Runtime::Buffer<> &segment_pos_values() { + return segment_pos_values_; + } + Halide::Runtime::Buffer<> &attention_mask_values() { + return attention_mask_values_; + } + +private: + Halide::Runtime::Buffer<> softmax_linear_f32_; + + // Enable if enable_kv_cache + struct KVCache { + Halide::Runtime::Buffer<> k_cache; + Halide::Runtime::Buffer<> v_cache; + }; + + Halide::Runtime::Buffer<> pos_embedding_; + Halide::Runtime::Buffer<> atten_masks_; + Halide::Runtime::Buffer<> segment_pos_; + + Halide::Runtime::Buffer<> position_embedding_values_; + Halide::Runtime::Buffer<> attention_mask_values_; + Halide::Runtime::Buffer<> segment_pos_values_; + + std::optional> transformer_input_; + std::optional> logits_output_; + + // Previous ids, including prompt. + std::vector prev_ids_; + int last_kv_cache_start_; + std::vector kv_cache_; + std::vector saved_token_; +}; + +} // namespace hallmark + +#endif // HALIDE_APPS_HALLMARK_LLM_H_ diff --git a/apps/hallmark/src/llm_generator.cpp b/apps/hallmark/src/llm_generator.cpp new file mode 100644 index 000000000000..f692f445becb --- /dev/null +++ b/apps/hallmark/src/llm_generator.cpp @@ -0,0 +1,704 @@ +#include "Halide.h" +#include "src/ml_ops/batch_matrix_multiply.h" +#include "src/ml_ops/fully_connected.h" +#include "src/ml_ops/ml_common.h" +#include "src/ml_ops/normalization.h" +#include "src/ml_ops/rope_weights.h" +#include "src/ml_ops/softmax.h" + +namespace hallmark { + +namespace { + +using namespace Halide; + +Var b{"b"}, t{"t"}, n{"n"}, h{"h"}, s{"s"}; +// TODO(zalman): Ugly global. 
+Type generating_type; + +enum class AttentionScaleType { + PerDimScale, + InverseSqrtHeadDim, +}; + +// TODO: Should be moved into Halide proper once sorted. +Expr fast_tanh(Expr x) { + // In theory, this should be a really good approximation for tanh; + // in practice, even very small (< 1e-7) differences in the result + // can have profound impact on correctness of output. TODO: consider + // adapting XNNPACK's approximation(s)? + // + // Expr r = (fast_exp(2*x)-1.f)/(fast_exp(2*x)+1.f); + + return tanh(x); +} + +const std::map attention_scale_names = { + {"per_dim_scale", AttentionScaleType::PerDimScale}, + {"inverse_sqrt_head_dim", AttentionScaleType::InverseSqrtHeadDim}}; + +Func soft_plus(Func weights, Expr dims_norm) { + Expr scale = 1.442695041f / dims_norm; + Func soft_plus("soft_plus"); + soft_plus(_) = (halide_log(1 + default_exp(cast(-abs(weights(_))))) + + max(weights(_), 0.0f)) * + scale; + return soft_plus; +} + +Func gelu(Func in) { + constexpr float sqrt_2_over_pi = 0.7978845608f; + Expr elem = in(_); + Func gelu_result("gelu_result"); + + // Based on approximation from e.g: https://arxiv.org/pdf/1606.08415.pdf + gelu_result(_) = + elem * + ((fast_tanh(((1 + (elem * elem * 0.044715f)) * elem) * sqrt_2_over_pi) + + 1) * + .5f); + return gelu_result; +} + +// TODO: Optimize, make match xnnpack. +Func silu(Func in) { + Func silu_result("silu_result"); + silu_result(_) = in(_) / (1 + default_exp(in(_))); + return silu_result; +} + +Func relu(Func in) { + Func relu_result("relu_result"); + relu_result(_) = max(0.0f, in(_)); + return relu_result; +} + +class LlmRoPEValues : public Halide::Generator { +public: + GeneratorParam head_dim_H_{"head_dim_H", 128}; + + GeneratorParam processing_type_{"processing_type", Float(32)}; + + Output> segment_pos_values_{"segment_pos_values"}; + + void configure() { + segment_pos_values_.set_type(processing_type_); + } + + void generate() { + segment_pos_weights_.apply(head_dim_H_, processing_type_); + segment_pos_values_ = segment_pos_weights_.result; + } + + void schedule() { + // TODO: apply static bounds. + segment_pos_weights_.default_schedule(LoopLevel::root(), get_target()); + } + + RoPEWeights segment_pos_weights_{"segment_pos_weights"}; // @@@ +}; + +class LlmPreprocessor : public Halide::Generator { +public: + GeneratorParam model_dim_D_{"model_dim_D", 2048}; + + GeneratorParam skip_absolute_positional_embeddings_{ + "skip_absolute_positional_embeddings", true}; + + GeneratorParam processing_type_{"processing_type", Float(32)}; + + Input> input_{"input"}; + // Optional input pos_embedding + Input> *pos_embedding_; + + Output> scaled_embedding_{"scaled_embedding"}; + + void configure() { + input_.set_type(processing_type_); + if (!skip_absolute_positional_embeddings_) { + pos_embedding_ = + add_input>("pos_embeddings", processing_type_, 3); + } + scaled_embedding_.set_type(processing_type_); + } + + void generate() { + scaled_embedding_(n, t, b) = + input_(n, t, b) * std::sqrt((float)model_dim_D_) + + (skip_absolute_positional_embeddings_ ? 0 : (*pos_embedding_)(n, t, b)); + } + + void schedule() { + // TODO: apply static bounds. + scaled_embedding_.compute_root().vectorize(n, natural_vector_size()); + } +}; + +// TODO: What should these kinds/modes really be called? 
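+// These presumably map onto the three AOT pipelines invoked from
+// Llm::RunStack(): PrefixOnlyUncached backs transformer_no_kv_cache (full
+// attention recomputed every step), PrefixDecodeUpdateCache backs
+// transformer_kv_update_cache (writes the key/value cache slices for newly
+// processed tokens), and PrefixDecodeUseCache backs transformer_kv_use_cache
+// (reads the whole cache to produce the layer output).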
+enum class TransformerKind { + PrefixOnlyUncached, + PrefixDecodeUpdateCache, + PrefixDecodeUseCache, +}; + +inline const std::map transformer_kind_names = { + {"prefix_only_uncached", TransformerKind::PrefixOnlyUncached}, + {"prefix_decode_update_cache", TransformerKind::PrefixDecodeUpdateCache}, + {"prefix_decode_use_cache", TransformerKind::PrefixDecodeUseCache}}; + +class LlmTransformer : public Halide::Generator { +public: + GeneratorParam batch_size_B_{"batch_size_B", 1}; + GeneratorParam seq_size_T_{"seq_size_T", 512}; + GeneratorParam model_dim_D_{"model_dim_D", 128}; + GeneratorParam hidden_dim_HD_{"hidden_dim_HD", 128}; + GeneratorParam head_dim_H_{"head_dim_H", 128}; + GeneratorParam n_heads_N_{"n_heads_N", 8}; + GeneratorParam voc_size_V_{"voc_size_V", 128}; + GeneratorParam num_kv_heads_{"num_kv_heads", 1}; + GeneratorParam transformer_kind_{ + "transformer_kind", TransformerKind::PrefixOnlyUncached, + transformer_kind_names}; + + GeneratorParam processing_type_{"processing_type", Float(32)}; + GeneratorParam sa_pre_norm_{ + "sa_pre_norm", NormalizationKind::RMS, normalization_kind_names}; + GeneratorParam sa_post_norm_{ + "sa_post_norm", NormalizationKind::RMS, normalization_kind_names}; + GeneratorParam feed_forward_pre_norm_{ + "feedforward_pre_norm", NormalizationKind::RMS, normalization_kind_names}; + GeneratorParam feed_forward_post_norm_{ + "feedforward_post_norm", NormalizationKind::RMS, + normalization_kind_names}; + + GeneratorParam attention_scale_type_{ + "attention_scale_type", AttentionScaleType::InverseSqrtHeadDim, + attention_scale_names}; + + GeneratorParam use_mqa_{"use_mqa", false}; + + // TODO: Does this need to be made an input? + GeneratorParam soft_cap_{"soft_cap", 0.0f}; + + GeneratorParam feed_forward_params_activation_{ + "feed_forward_params_activation", Activation::RELU, activation_names}; + + Input> layer_input_{"layer_input", 3}; + Input> segment_pos_values_{"segment_pos_values", 2}; + Input> attention_mask_{"attention_mask", 2}; + + // Only used for PrefixDecodeUpdateCache + Output> *key_cache_slice_output_; + Output> *value_cache_slice_output_; + + // Only used for PrefixDecodeUseCache + Input> *key_cache_input_; + Input> *value_cache_input_; + + // Optional per_dim_scale only present is attention_scale_type is PerDimScale. + Input> *per_dim_scale_; + + Output> *layer_output_; + + void configure() { + generating_type = processing_type_; + layer_input_.set_type(processing_type_); + segment_pos_values_.set_type(processing_type_); + // TODO: handle layer norm args? + pre_normed_.add_inputs(sa_pre_norm_, processing_type_, this); + post_normed_.add_inputs(sa_post_norm_, processing_type_, this); + + // TODO: Parameterize quantization kind. + // TODO: Find better convention for the "_3d" naming. + key_proj_3d_.add_inputs(QuantizationKind::QC8NoBias, Int(8), model_dim_D_, + head_dim_H_, this); + query_proj_3d_.add_inputs(QuantizationKind::QC8NoBias, Int(8), model_dim_D_, + model_dim_D_, this); + value_proj_3d_.add_inputs(QuantizationKind::QC8NoBias, Int(8), model_dim_D_, + head_dim_H_, this); + + attention_mask_.set_type(processing_type_); + post_attention_proj_.add_inputs(QuantizationKind::QC8NoBias, Int(8), + model_dim_D_, model_dim_D_, this); + if (transformer_kind_ != TransformerKind::PrefixDecodeUpdateCache) { + final_pre_normed_.add_inputs(feed_forward_pre_norm_, processing_type_, + this); + final_post_normed_.add_inputs(feed_forward_post_norm_, processing_type_, + this); + // TODO: is the order of model and hidden dims correct? 
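+      // For reference: FullyConnected::add_inputs takes
+      // (kind, type, input_features, output_features, generator), so the
+      // calls below declare layer_1 / layer_1_gate as D -> HD and layer_2 as
+      // HD -> D, which matches the gated feed-forward in generate().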
+ feed_forward_layer_1_.add_inputs(QuantizationKind::QC8NoBias, Int(8), + model_dim_D_, hidden_dim_HD_, this); + feed_forward_layer_1_gate_.add_inputs(QuantizationKind::QC8NoBias, Int(8), + model_dim_D_, hidden_dim_HD_, this); + feed_forward_layer_2_.add_inputs(QuantizationKind::QC8NoBias, Int(8), + hidden_dim_HD_, model_dim_D_, this); + layer_output_ = add_output>("layer_output", processing_type_, 3); + } + + if (attention_scale_type_ == AttentionScaleType::PerDimScale) { + per_dim_scale_ = + add_input>("per_dim_scale", processing_type_, 1); + } + + if (transformer_kind_ == TransformerKind::PrefixDecodeUpdateCache) { + // These are N, H, T, B or S, H, T, B (Halide ordering) + key_cache_slice_output_ = + add_output>("key_cache_slice_output", processing_type_, 4); + value_cache_slice_output_ = + add_output>("value_cache_slice_output", processing_type_, 4); + } + if (transformer_kind_ == TransformerKind::PrefixDecodeUseCache) { + key_cache_input_ = add_input>("key_cache", processing_type_, 4); + value_cache_input_ = + add_input>("value_cache", processing_type_, 4); + } + } + + void generate() { + Func input("input"); + // Name dimensions of input + input(n, t, b) = layer_input_(n, t, b); + pre_normed_.apply(input, model_dim_D_); + + key_proj_3d_.apply(pre_normed_.result, get_target()); + query_proj_3d_.apply(pre_normed_.result, get_target()); + value_proj_3d_.apply(pre_normed_.result, get_target()); + + // TODO: The splits here may be computed from generator params that just + // happen to work for this case based on the passed in key/query/value + // projection weight sizes. Should probably introduce checks to make sure + // the weights have these sizes or introduce new generator parameters. + // It is possible to make these dynamic from the extents of the inputs, but + // it may be expensive. + // + // Converts B,T,NH -> B,T,N,H or B,T,NH -> B,T,S,H + // + // Note, in original code, the split divisor here comes from + // kKeySelfAttentionReshapedWeight in metadata. + // The numerator of the split could be dim(0).extent from the weights, but + // the generator param is a constant. + + Expr query_scale; + if (attention_scale_type_ == AttentionScaleType::PerDimScale) { + // TODO: memoize this. 
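+      // One untested option for the TODO above: give the cached Func below a
+      // concrete schedule and let Halide cache it across realizations, e.g.
+      //   per_dim_scale_cached.compute_root().memoize();
+      // placed in schedule(); per_dim_scale_ is constant for a given model,
+      // so recomputing the soft_plus each step is pure overhead.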
+ Func per_dim_scale_cached("per_dim_scale_cached"); + per_dim_scale_cached(h) = soft_plus(*per_dim_scale_, head_dim_H_)(h); + query_scale = per_dim_scale_cached(h); + } else if (attention_scale_type_ == + AttentionScaleType::InverseSqrtHeadDim) { + query_scale = fast_inverse_sqrt(cast(processing_type_, head_dim_H_)); + } + + int key_value_split = + static_cast(head_dim_H_) / static_cast(num_kv_heads_); + key_proj_4d_(s, n, t, b) = + key_proj_3d_.result(s + n * key_value_split, t, b); + value_proj_4d_(s, n, t, b) = + value_proj_3d_.result(s + n * key_value_split, t, b); + int query_split = + static_cast(model_dim_D_) / static_cast(n_heads_N_); + query_proj_4d_(h, n, t, b) = + query_proj_3d_.result(h + n * query_split, t, b); + + CHECK(key_value_split == query_split); + + roped_key_proj_4d_.apply(key_proj_4d_, segment_pos_values_, head_dim_H_); + roped_query_proj_4d_.apply(query_proj_4d_, segment_pos_values_, + head_dim_H_); + + if (transformer_kind_ == TransformerKind::PrefixDecodeUpdateCache) { + (*key_cache_slice_output_)(s, n, t, b) = + roped_key_proj_4d_.result(s, n, t, b); + (*value_cache_slice_output_)(s, n, t, b) = value_proj_4d_(s, n, t, b); + } else { + Func roped_key_proj_4d_switch("roped_key_proj_4d_switch"); + Func value_proj_4d_switch("value_proj_4d_switch"); + if (transformer_kind_ == TransformerKind::PrefixOnlyUncached) { + roped_key_proj_4d_switch = roped_key_proj_4d_.result; + value_proj_4d_switch = value_proj_4d_; + } else { + roped_key_proj_4d_switch = *key_cache_input_; + value_proj_4d_switch = *value_cache_input_; + } + + // Swap middle dimensions for key and query. BTN{H,S} -> BNT{S,H} + key_proj_permuted_(s, t, n, b) = roped_key_proj_4d_switch(s, n, t, b); + // BTNS -> BNST + value_proj_permuted_(t, s, n, b) = value_proj_4d_switch(s, n, t, b); + query_proj_permuted_(h, t, n, b) = + roped_query_proj_4d_.result(h, n, t, b) * query_scale; + + // "maybe" because I'm not 100% sure this is what this input means. Also + // not 100% sure it is the thing to use where it's being used, but I think + // so. + // TODO: These should probably be taken from the output. + Expr input_seq_len_maybe = layer_input_.dim(1).extent(); + Expr total_seq_len = + layer_input_.dim(1).min() + layer_input_.dim(1).extent(); + + Func logits("logits"); + if (use_mqa_) { + // reshape key_permuted {0, llm_params_.head_dim_H} + Func key_proj_permuted_reshaped("key_proj_permuted_reshaped"); + // TODO: Figure out the best way to do this. + key_proj_permuted_reshaped = key_proj_permuted_; + logits_fc_ = FullyConnected::float32_layer(query_proj_permuted_, + key_proj_permuted_reshaped, + head_dim_H_, model_dim_D_, get_target()); + logits = logits_fc_.result; + } else { + Func broadcast_key_proj_permuted("broadcast_key_proj_permuted"); + broadcast_key_proj_permuted(s, t, n, b) = + key_proj_permuted_(s, t, 0, b); + Func transposed_key_proj_permuted("transposed_key_proj_permuted"); + transposed_key_proj_permuted(t, s, n, b) = + broadcast_key_proj_permuted(s, t, n, b); + logits_bmm_.float32_layer(query_proj_permuted_, + transposed_key_proj_permuted, key_value_split, + input_seq_len_maybe, total_seq_len); + logits = logits_bmm_.result; + } + + // BNTS + if (soft_cap_ > 0.0f) { + logits(s, t, n, b) = + fast_tanh(logits(s, t, n, b) / soft_cap_) * soft_cap_; + } + Func padded_logits("padded_logits"); + padded_logits(s, t, n, b) = logits(s, t, n, b) + attention_mask_(s, t); + + // TODO: is size for this softmax correct? 
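+      // The softmax below normalizes over the s (key position) axis of
+      // padded_logits, and total_seq_len is dim(1).min() + extent of the
+      // layer input, i.e. every attended position, so the size passed here
+      // looks consistent; it has not been re-derived independently.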
+ probs_softmax_.apply(padded_logits, total_seq_len, generating_type); + + Func broadcast_value_proj_permuted("broadcast_value_proj_permuted"); + broadcast_value_proj_permuted(s, t, n, b) = + value_proj_permuted_(s, t, 0, b); + Func transposed_value_proj_permuted("transposed_value_proj_permuted"); + transposed_value_proj_permuted(t, s, n, b) = + broadcast_value_proj_permuted(s, t, n, b); + outcome_before_permute_bmm_.float32_layer( + probs_softmax_.result, transposed_value_proj_permuted, total_seq_len, + input_seq_len_maybe, head_dim_H_); + // Swap middle two dimensions back. + kqv_merged_(h, n, t, b) = outcome_before_permute_bmm_.result(h, t, n, b); + + // Merge h and n dimensions. + outcome_reshaped_(n, t, b) = + kqv_merged_(n % head_dim_H_, n / head_dim_H_, t, b); + post_attention_proj_.apply(outcome_reshaped_, get_target()); + + post_normed_.apply(post_attention_proj_.result, model_dim_D_); + + // Rename for now to match use in calling function. + Func output("output"); + output(n, t, b) = post_normed_.result(n, t, b) + input(n, t, b); + + final_pre_normed_.apply(output, model_dim_D_); + feed_forward_layer_1_.apply(final_pre_normed_.result, get_target()); + feed_forward_layer_1_gate_.apply(final_pre_normed_.result, get_target()); + + if (feed_forward_params_activation_ == Activation::None) { + feed_forward_gate_(n, t, b) = + feed_forward_layer_1_gate_.result(n, t, b); + } else if (feed_forward_params_activation_ == Activation::GELU) { + feed_forward_gate_(n, t, b) = + gelu(feed_forward_layer_1_gate_.result)(n, t, b); + } else if (feed_forward_params_activation_ == Activation::SILU) { + feed_forward_gate_(n, t, b) = + silu(feed_forward_layer_1_gate_.result)(n, t, b); + } else if (feed_forward_params_activation_ == Activation::RELU) { + feed_forward_gate_(n, t, b) = + relu(feed_forward_layer_1_gate_.result)(n, t, b); + } else { + std::abort(); + } + + feed_forward_layer_1_and_gate_(n, t, b) = + feed_forward_layer_1_.result(n, t, b) * feed_forward_gate_(n, t, b); + + feed_forward_layer_2_.apply(feed_forward_layer_1_and_gate_, get_target()); + + final_post_normed_.apply(feed_forward_layer_2_.result, model_dim_D_); + + if (transformer_kind_ != TransformerKind::PrefixDecodeUpdateCache) { + (*layer_output_)(n, t, b) = + final_post_normed_.result(n, t, b) + output(n, t, b); + } + } + } + + using LL = LoopLevel; + + void schedule() { +#if 0 + root_schedule(); +#else + pre_normed_.default_schedule(LL::root(), target); + + // t and b are unbounded but n is always exactly 2048 + layer_input_.dim(0).set_extent(2048); + + if (transformer_kind_ == TransformerKind::PrefixDecodeUpdateCache) { + key_proj_3d_.default_schedule(LL(roped_key_proj_4d_.inner, t), target); + key_proj_4d_.compute_inline(); + key_proj_permuted_.compute_inline(); + + query_proj_3d_.default_schedule(LL(roped_query_proj_4d_.inner, t), + target); + query_proj_4d_.compute_inline(); + query_proj_permuted_.compute_inline(); + + value_proj_3d_.default_schedule(LL::root(), target); + value_proj_4d_.compute_inline(); + value_proj_permuted_.compute_inline(); + + roped_query_proj_4d_.default_schedule(LL(*key_cache_slice_output_, t), + target); + roped_key_proj_4d_.default_schedule(LL(*key_cache_slice_output_, t), + target); + } else { + auto &layer_output = (*layer_output_); + key_proj_3d_.default_schedule(LL(roped_key_proj_4d_.inner, t), target); + key_proj_4d_.compute_inline(); + key_proj_permuted_.compute_inline(); + + query_proj_3d_.default_schedule(LL(probs_softmax_.result, b), target); + query_proj_4d_.compute_inline(); + 
query_proj_permuted_.compute_inline(); + + value_proj_3d_.default_schedule(LL::root(), target); + value_proj_4d_.compute_inline(); + value_proj_permuted_.compute_inline(); + + if (use_mqa_) { + logits_fc_.default_schedule(LL(probs_softmax_.result, b), target); + } else { + // Parallel here causes overruns and likely doesn't help. + const int parallel_split = 0; + logits_bmm_.default_schedule(LL(probs_softmax_.result, b), target, + parallel_split); + } + + const bool vectorize_softmax = + transformer_kind_ == TransformerKind::PrefixOnlyUncached; + probs_softmax_.default_schedule(LL::root(), target, vectorize_softmax); + + const int parallel_split = 4; + outcome_before_permute_bmm_.default_schedule( + LL(post_attention_proj_.result, b), target, parallel_split); + kqv_merged_.compute_inline(); + outcome_reshaped_.compute_inline(); + + post_attention_proj_.default_schedule(LL::root(), target); + post_normed_.default_schedule(LL::root(), target); + + roped_query_proj_4d_.default_schedule(LL(logits_bmm_.result, b), target); + roped_key_proj_4d_.default_schedule(LL(logits_bmm_.result, b), target); + final_pre_normed_.default_schedule(LL::root(), target); + feed_forward_layer_1_.default_schedule(LL(feed_forward_layer_2_.result, b), target); + feed_forward_layer_1_.result.hoist_storage_root(); + feed_forward_layer_1_gate_.default_schedule(LL(feed_forward_layer_2_.result, b), target); + feed_forward_layer_1_gate_.result.hoist_storage_root(); + feed_forward_layer_1_and_gate_.compute_at(feed_forward_layer_2_.result, b) + .hoist_storage_root(); + feed_forward_layer_2_.default_schedule(LL(layer_output, b), target); + final_post_normed_.default_schedule(LL(layer_output, b), target); + + layer_output.dim(0).set_extent(2048); + for (int d = 0; d < 3; d++) { + layer_input_.dim(d).set_extent(layer_output.dim(d).extent()); + } + } +#endif + } + + void root_schedule() { + pre_normed_.default_schedule(LL::root(), target); + + kqv_merged_.compute_root(); + + key_proj_3d_.default_schedule(LL::root(), target); + query_proj_3d_.default_schedule(LL::root(), target); + // query_proj_3d_.result.debug_to_file("/tmp/qp3d.npy"); + value_proj_3d_.default_schedule(LL::root(), target); + + roped_key_proj_4d_.default_schedule(LL::root(), target); + roped_query_proj_4d_.default_schedule(LL::root(), target); + + if (transformer_kind_ != TransformerKind::PrefixDecodeUpdateCache) { + logits_bmm_.default_schedule(LL::root(), target, /*parallel_split*/ 0); + probs_softmax_.default_schedule(LL::root(), target, /*vectorize*/ false); + + outcome_before_permute_bmm_.default_schedule(LL::root(), target, + /*parallel_split*/ 0); + + post_attention_proj_.default_schedule(LL::root(), target); + post_normed_.default_schedule(LL::root(), target); + + final_pre_normed_.default_schedule(LL::root(), target); + // Can these two be compute_with? 
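+      // Untested thought on the TODO above: both feed-forward layer-1 Funcs
+      // read the same final_pre_normed_ result over identical extents, so
+      // fusing their loop nests with Halide's compute_with(), e.g.
+      //   feed_forward_layer_1_gate_.result.compute_with(/*...*/);
+      // might improve locality, but the exact stage/var pairing would need
+      // to be worked out and validated.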
+ feed_forward_layer_1_.default_schedule(LL::root(), target); + feed_forward_layer_1_gate_.default_schedule(LL::root(), target); + feed_forward_layer_2_.default_schedule(LL::root(), target); + final_post_normed_.default_schedule(LL::root(), target); + } + } + + Input> *sa_pre_norm_weights_{}; + Input> *sa_post_norm_weights_{}; + Input> *feed_forward_pre_norm_weights_{}; + Input> *feed_forward_post_norm_weights_{}; + + Normalization pre_normed_{"pre_normed"}; + + FullyConnected key_proj_3d_{"key_proj_3d"}; + FullyConnected query_proj_3d_{"query_proj_3d"}; + FullyConnected value_proj_3d_{"value_proj_3d"}; + + Func key_proj_4d_{"key_proj_4d"}; + Func query_proj_4d_{"query_proj_4d"}; + Func value_proj_4d_{"value_proj_4d"}; + + RoPE roped_key_proj_4d_{"roped_key_proj_4d"}; + RoPE roped_query_proj_4d_{"roped_query_proj_4d"}; + + Func query_proj_permuted_{"query_proj_permuted"}; + Func key_proj_permuted_{"key_proj_permuted"}; + Func value_proj_permuted_{"value_proj_permuted"}; + + FullyConnected logits_fc_{"logits"}; + BatchMatrixMultiply logits_bmm_{"logits"}; + + Normalization post_normed_{"post_normed"}; + + Softmax probs_softmax_{"probs_softmax"}; + + BatchMatrixMultiply outcome_before_permute_bmm_{"outcome_before_permute"}; + + Func kqv_merged_{"kqv_merged"}; + Func outcome_reshaped_{"outcome_reshaped"}; + FullyConnected post_attention_proj_{"post_attention_proj"}; + + FullyConnected feed_forward_layer_1_{"feed_forward_layer_1"}; + Func feed_forward_gate_{"feed_forward_gate"}; + FullyConnected feed_forward_layer_1_gate_{"feed_forward_layer_1_gate"}; + Func feed_forward_layer_1_and_gate_{"feed_forward_layer_1_and_gate"}; + FullyConnected feed_forward_layer_2_{"feed_forward_layer_2"}; + + Normalization final_pre_normed_{"final_pre_normed"}; + Normalization final_post_normed_{"final_post_normed"}; +}; + +class LlmPostprocessor : public Halide::Generator { +public: + GeneratorParam batch_size_B_{"batch_size_B", 1}; + GeneratorParam seq_size_T_{"seq_size_T", 512}; + GeneratorParam model_dim_D_{"model_dim_D", 128}; + GeneratorParam head_dim_H_{"head_dim_H", 128}; + GeneratorParam voc_size_V_{"voc_size_V", 128}; + + GeneratorParam final_norm_{ + "final_norm", NormalizationKind::RMS, normalization_kind_names}; + + GeneratorParam processing_type_{"processing_type", Float(32)}; + + // Inputs are last transformer layer output, final_norm, + // final_post_process_weights + Input> layer_input_{"layer_input", 3}; + + Output> result_{"result", 3}; + + void configure() { + generating_type = processing_type_; + layer_input_.set_type(processing_type_); + post_process_normed_.add_inputs(final_norm_, processing_type_, this); + + feed_forward_.add_inputs(QuantizationKind::QC8NoBias, Int(8), model_dim_D_, + voc_size_V_, this); + + result_.set_type(processing_type_); + } + + void generate() { + // Gives var names to arguments, which are used in operators. + Func postprocess_input("postprocess_input"); + postprocess_input(n, t, b) = layer_input_(n, t, b); + // TODO: is size right for normalization here? + post_process_normed_.apply(postprocess_input, head_dim_H_); + // TODO: Anything to do to ensure softmax linear? 
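+    // Note: despite the name, the "softmax linear" weights are consumed here
+    // only as the final D -> V vocabulary projection; this generator emits
+    // raw logits and the actual softmax/sampling happens later in the
+    // Sampler (see Llm::GetNextToken).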
+ feed_forward_.apply(post_process_normed_.result, get_target()); + + result_ = feed_forward_.result; + } + + void schedule() { + post_process_normed_.default_schedule(LoopLevel(feed_forward_.result, b), + get_target()); + feed_forward_.default_schedule(LoopLevel(feed_forward_.result, t), + get_target()); + } + + Normalization post_process_normed_{"post_process_normed"}; + FullyConnected feed_forward_{"feed_forward"}; +}; + +class LlmPositionEmbedding : public Halide::Generator { +public: + Input input_length_{"input_length"}; + Input seq_length_{"seq_length"}; + Input embedding_dim_{"embedding_dim"}; + Input min_timescale_{"min_timescale"}; + Input max_timescale_{"max_timescale"}; + + Output> result_{"result", 2}; + + void generate() { + input_range_ = RDom(0, embedding_dim_ / 2, 0, input_length_); + seq_range_ = + RDom(0, embedding_dim_ / 2, input_length_, seq_length_ - input_length_); + Expr log_timescale_inc = default_log(max_timescale_ / min_timescale_) / + max(embedding_dim_ / 2.0f, 1.0f); + Expr inv_timescale = + min_timescale_ * default_exp(input_range_.x * log_timescale_inc); + + result_(n, h) = undef(); + result_(input_range_.x, input_range_.y) = + select(input_range_.x > embedding_dim_ / 2, + fast_cos(input_range_.y * inv_timescale), + fast_sin(input_range_.y * inv_timescale)); + result_(seq_range_.x, seq_range_.y) = + select(seq_range_.x > embedding_dim_ / 2, 0.f, 1.f); + } + + void schedule() { + // Turning this on causes a Halide compilation error complaining about + // a redundant update definition. + // Schedule + RVar ro("ro"), ri("ri"); + result_.compute_root(); + result_.update(0) + .split(input_range_.x, ro, ri, embedding_dim_ / 2) + .unroll(ro) + .vectorize(ri, natural_vector_size()); + result_.update(1) + .split(seq_range_.x, ro, ri, embedding_dim_ / 2) + .unroll(ro) + .vectorize(ri, natural_vector_size()); + } + + RDom input_range_; + RDom seq_range_; +}; + +} // namespace + +} // namespace hallmark + +HALIDE_REGISTER_GENERATOR(hallmark::LlmRoPEValues, LlmRoPEValues); +HALIDE_REGISTER_GENERATOR(hallmark::LlmPreprocessor, + LlmPreprocessor); +HALIDE_REGISTER_GENERATOR(hallmark::LlmTransformer, LlmTransformer); +HALIDE_REGISTER_GENERATOR(hallmark::LlmPostprocessor, + LlmPostprocessor); +HALIDE_REGISTER_GENERATOR(hallmark::LlmPositionEmbedding, + LlmPositionEmbedding); diff --git a/apps/hallmark/src/llm_runner.cpp b/apps/hallmark/src/llm_runner.cpp new file mode 100644 index 000000000000..5f5ac83af582 --- /dev/null +++ b/apps/hallmark/src/llm_runner.cpp @@ -0,0 +1,183 @@ +#include "llm.h" + +#include + +#include +#include +#include + +#include "absl/flags/flag.h" +#include "absl/flags/parse.h" +#include "src/sentencepiece_processor.h" + +ABSL_FLAG(std::string, model_path, "model.tflite", + "Path to the tflite model file."); + +ABSL_FLAG(std::string, tokenizer_path, "tokenizer.spm", + "Path to the sentence piece model."); + +ABSL_FLAG(std::string, prompt, "Write a memo to myself titled \"Do the dumb things I gotta do.\"", + "Initial prompt for llm."); + +ABSL_FLAG(int, max_tokens, 512, + "Maximum number of input and output tokens. This value needs to be " + "at least larger than the number of input tokens."); + +ABSL_FLAG(bool, show_timing, false, + "Show timing for operations."); + +namespace { + +// Prefer high_resolution_clock, but only if it's steady... +template +struct SteadyClock { + using type = std::chrono::high_resolution_clock; +}; + +// ...otherwise use steady_clock. 
+template<> +struct SteadyClock { + using type = std::chrono::steady_clock; +}; + + +struct TimingScope { + TimingScope(const char *name, int iterations = 1) : name(name), iterations(iterations) { + start = SteadyClock<>::type::now(); + } + + ~TimingScope() { + if (absl::GetFlag(FLAGS_show_timing)) { + SteadyClock<>::type::time_point end = SteadyClock<>::type::now(); + double secs = std::chrono::duration_cast>(end - start).count(); + std::cerr << name << ": took " << secs << "s"; + if (iterations != 1) { + std::cerr << " " << secs / iterations << "s per iteration.\n"; + } else { + std::cerr << "\n"; + } + } + } + + std::string name; + int iterations; + SteadyClock<>::type::time_point start; +}; + +} // anonymous namespace + +int main(int argc, char *argv[]) { + absl::ParseCommandLine(argc, argv); + + auto model_path = absl::GetFlag(FLAGS_model_path); + auto tokenizer_path = absl::GetFlag(FLAGS_tokenizer_path); + auto prompt = absl::GetFlag(FLAGS_prompt); + auto max_tokens = absl::GetFlag(FLAGS_max_tokens); + + sentencepiece::SentencePieceProcessor tokenizer; + { + TimingScope load_tokenizer("Loading tokenizer"); + auto result = tokenizer.Load(tokenizer_path); + if (!result.ok()) { + std::cerr << result.message(); + return 1; + } + } + + std::vector prompt_tokens; + { + // TODO: Apparently this is required by the Gemma IT + // model. Find some documentation on the mechanism and see if + // there is a better way to handle this or to make it + // conditional on some info from the model file. + std::string bracketed_prompt = "user\n" + prompt + + "\nmodel\n"; + + auto result = tokenizer.Encode(bracketed_prompt, &prompt_tokens); + } + + hallmark::LlmParams llm_params; + { + TimingScope load_tokenizer("Loading LLM params"); + auto p = hallmark::LoadLlmParams(model_path); + if (!p.ok()) { + std::cerr << p.status() << "\n"; + return 1; + } + llm_params = std::move(p.value()); + } + llm_params.seq_size_T = max_tokens; + + hallmark::LlmWeights llm_weights; + { + TimingScope load_tokenizer("Loading LLM params"); + auto w = hallmark::LoadLlmWeights(model_path, llm_params); + if (!w.ok()) { + std::cerr << w.status() << "\n"; + return 1; + } + llm_weights = std::move(w.value()); + } + + std::unique_ptr llm; + { + TimingScope load_tokenizer("Creating LLM"); + auto l = hallmark::Llm::CreateLlm(llm_weights, llm_params); + if (!l.ok()) { + std::cerr << l.status() << "\n"; + return 2; + } + llm = std::move(l.value()); + } + + if (!llm->Reset().ok()) { + std::cerr << "Reset fails\n"; + return 3; + } + { + TimingScope load_tokenizer("Init attention mask"); + if (!llm->InitAttentionMaskValues(llm_params.seq_size_T).ok()) { + std::cerr << "InitAttentionMaskValues fails\n"; + return 4; + } + } + + { + TimingScope load_tokenizer("Init input tokens", prompt_tokens.size()); + if (!llm->InitInputTokens(prompt_tokens).ok()) { + std::cerr << "InitInputTokens fails\n"; + return 1; + } + } + + std::cout << prompt << "\n"; + + { + TimingScope generate("\nGenerate tokens", max_tokens); + std::vector output_tokens; + for (int token = prompt_tokens.size(); token < max_tokens - 2; token += output_tokens.size()) { + output_tokens.clear(); + if (!llm->GetNextToken(&output_tokens).ok()) { + std::cerr << "GetNextToken fails\n"; + return 6; + } + if (output_tokens.empty()) { + std::cerr << "Empty result from GetNextToken.\n"; + } else if (output_tokens.size() > 1) { + std::cerr << "More than one token returned from GetNextToken token " << token << ".\n"; + } + std::string decoded_tokens; + if (!tokenizer.Decode(output_tokens, 
&decoded_tokens).ok()) { + std::cerr << "Decode fails\n"; + return 7; + } + if (decoded_tokens.empty()) { + std::cout << "_"; + } + std::cout << decoded_tokens; + std::cout.flush(); + } + } + + return 0; +} diff --git a/apps/hallmark/src/ml_ops/CMakeLists.txt b/apps/hallmark/src/ml_ops/CMakeLists.txt new file mode 100644 index 000000000000..5cd5b164d2a5 --- /dev/null +++ b/apps/hallmark/src/ml_ops/CMakeLists.txt @@ -0,0 +1,4 @@ +# Sigh, header-only libraries shouldn't be special +add_library(hallmark_ml_ops INTERFACE) +target_include_directories(hallmark_ml_ops INTERFACE + $) diff --git a/apps/hallmark/src/ml_ops/batch_matrix_multiply.h b/apps/hallmark/src/ml_ops/batch_matrix_multiply.h new file mode 100644 index 000000000000..4a7f0b796763 --- /dev/null +++ b/apps/hallmark/src/ml_ops/batch_matrix_multiply.h @@ -0,0 +1,78 @@ +#ifndef HALIDE_APPS_HALLMARK_BATCH_MATRIX_MULTIPLY_H_ +#define HALIDE_APPS_HALLMARK_BATCH_MATRIX_MULTIPLY_H_ + +#include + +#include "Halide.h" +#include "absl/log/check.h" + +namespace hallmark { + +// Multiply all the 2D matrices defined by the initial dimensions of two +// Funcs, iterating across the higher dimensions in correspondence fashion. +// (Should implement the standard ML op, though transposition/adjoint is pushed +// outside this interface.) + +struct BatchMatrixMultiply : public Halide::NamesInterface { + BatchMatrixMultiply(const std::string &base_name) + : base_name(base_name), result(base_name + "_batch_matrix_multiply") { + } + + std::string base_name; + Func result; + + RDom r; + Var in1_0, in1_1; + + // TODO: better API needed + // TODO: Likely can infer the processing type here and make this not just float32. + void float32_layer(Func in1, Func in2, Expr shared_dim_size, + Expr in1_dim1_size, Expr in2_dim0_size) { + std::vector in1_args = in1.args(); + std::vector in2_args = in2.args(); + + CHECK(in1_args.size() == in2_args.size()); + CHECK(in1_args.size() > 2); + + r = RDom(0, shared_dim_size, base_name + "_rdom"); + + std::vector result_reduction_args(in1_args.begin(), in1_args.end()); + std::vector in1_reduction_args = result_reduction_args; + std::vector in2_reduction_args = result_reduction_args; + + in1_reduction_args[0] = r; + in1_reduction_args[1] = in1_args[1]; + in2_reduction_args[0] = in1_args[0]; + in2_reduction_args[1] = r; + + result(in1_args) = 0.0f; + result(in1_args) += in1(in1_reduction_args) * in2(in2_reduction_args); + + in1_0 = in1_args[0]; + in1_1 = in1_args[1]; + } + + void default_schedule(LoopLevel result_loop_level, const Target &t, + int parallel_split) { + result.compute_at(result_loop_level); + // Don't vectorize here the pure-init case: it will expand the boundaries + // (which will cause OOB for some use cases), and more importantly, LLVM is + // apparently smart enough to just use memset(0) to clear this anyway. 
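+    // The split + atomic() + vectorize(ri) below is the usual way to
+    // vectorize a Halide reduction: atomic() asserts the update can be
+    // safely reordered so the inner reduction lanes run in parallel. If the
+    // associativity check ever becomes a problem, an rfactor()-based
+    // schedule would be the alternative.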
+ + RVar ro("ro"), ri("ri"); + Var fo("fo"), fi("fi"); + + result.update() + .split(r, ro, ri, t.natural_vector_size() * 4) + .atomic() + .vectorize(ri); + + if (parallel_split != 0) { + result.update().split(in1_1, fo, fi, parallel_split).parallel(fo); + } + } +}; + +} // namespace hallmark + +#endif // HALIDE_APPS_HALLMARK_BATCH_MATRIX_MULTIPLY_H_ diff --git a/apps/hallmark/src/ml_ops/fully_connected.h b/apps/hallmark/src/ml_ops/fully_connected.h new file mode 100644 index 000000000000..ec0c31347ea7 --- /dev/null +++ b/apps/hallmark/src/ml_ops/fully_connected.h @@ -0,0 +1,114 @@ +#ifndef HALIDE_APPS_HALLMARK_FULLY_CONNECTED_H_ +#define HALIDE_APPS_HALLMARK_FULLY_CONNECTED_H_ + +#include + +#include "Halide.h" +#include "absl/log/check.h" + +namespace hallmark { + +// TODO: Move to a common header for GeneratorParams that are used across many ops. +enum QuantizationKind { + None, + QC8NoBias, // Is no bias implied by qc8? +}; + +inline const std::map quantization_names = { + {"none", QuantizationKind::None}, + {"qc8_no_bias", QuantizationKind::QC8NoBias}, +}; + +struct FullyConnected : public Halide::NamesInterface { + FullyConnected(const std::string &base_name) + : base_name(base_name), + result(base_name + "_fc"), + weights(base_name + "_fc_weights"), + scale(base_name + "_fc_scale") { + } + std::string base_name; + Func result; + + QuantizationKind quantization_kind; + Halide::Type processing_type; + int input_features_size; + int output_features_size; + Var i{"i"}; + + Halide::GeneratorInput> *weights_input; + Halide::GeneratorInput> *scale_input; + Func weights; + Func scale; + + RDom r, r_tail; + + void add_inputs(QuantizationKind kind, + const Halide::Type &processing_type_arg, + int input_features_size_arg, int output_features_size_arg, + Halide::Internal::GeneratorBase *generator) { + quantization_kind = kind; + processing_type = processing_type_arg; + input_features_size = input_features_size_arg; + output_features_size = output_features_size_arg; + weights_input = generator->add_input>(base_name + "_weights", + processing_type, 2); + scale_input = + generator->add_input>(base_name + "_scale", Float(32), 1); + } + + // TODO: Should this return result?
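+  // Typical usage (see LlmTransformer): call add_inputs() from the
+  // generator's configure(), apply() from generate(), and default_schedule()
+  // from schedule(). After apply(), result(i, t, b) holds
+  //   sum over r of input(r, t, b) * weights(r, i) * scale(i)
+  // with scale(i) == 1 when unquantized.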
+ void apply(Func input, const Target &target) { + if (!weights.defined()) { + weights_input->dim(0).set_min(0).dim(1).set_min(0); + weights = *weights_input; + CHECK(weights.args().size() == 2); + } + + // Arguments to inner func + std::vector args = input.args(); + CHECK(args.size() == 3); + Var t = args[1]; + Var b = args[2]; + + Expr scale_expr = 1.0f; + if (quantization_kind == QuantizationKind::QC8NoBias) { + scale = *scale_input; + CHECK(scale.args().size() == 1); + scale_expr = scale(i); + } + + r = RDom(0, input_features_size, base_name + "_r"); + + result(i, t, b) += input(r, t, b) * weights(r, i) * scale_expr; + } + + // TODO: better API needed + static FullyConnected float32_layer(Func inputs, Func weights, int input_size, + int output_size, const Target &target) { + FullyConnected result("float32_layer"); + result.quantization_kind = QuantizationKind::None; + result.processing_type = Float(32); + result.input_features_size = input_size; + result.output_features_size = output_size; + result.weights = weights; + result.apply(inputs, target); + return result; + } + + void default_schedule(LoopLevel result_loop_level, const Target &target) { + const int vec_size = target.natural_vector_size(); + result.compute_at(result_loop_level).vectorize(i, vec_size); + + RVar ro("ro"), ri("ri"); + Var fo("fo"), fi("fi"); + result.update() + .split(r, ro, ri, vec_size * 32) + .split(i, fo, fi, 256) + .atomic() + .vectorize(ri) + .parallel(fo); + } +}; + +} // namespace hallmark +#endif // HALIDE_APPS_HALLMARK_FULLY_CONNECTED_H_ diff --git a/apps/hallmark/src/ml_ops/ml_common.h b/apps/hallmark/src/ml_ops/ml_common.h new file mode 100644 index 000000000000..b270ca7c96ca --- /dev/null +++ b/apps/hallmark/src/ml_ops/ml_common.h @@ -0,0 +1,33 @@ +#ifndef HALIDE_APPS_HALLMARK_ML_COMMON_H_ +#define HALIDE_APPS_HALLMARK_ML_COMMON_H_ + +#include "Halide.h" + +namespace hallmark { + +// Allow easy choice between halide_exp and fast_exp. +Halide::Expr default_exp(Halide::Expr x) { + return Halide::exp(x); +} + +// Allow easy choice between halide_log and fast_log. +Halide::Expr default_log(Halide::Expr x) { + return Halide::log(x); +} + +enum class Activation { + None, + GELU, + SILU, + RELU, +}; + +inline const std::map activation_names = { + {"none", Activation::None}, + {"gelu", Activation::GELU}, + {"silu", Activation::SILU}, + {"relu", Activation::RELU}, +}; + +} // namespace hallmark +#endif // HALIDE_APPS_HALLMARK_ML_COMMON_H_ diff --git a/apps/hallmark/src/ml_ops/normalization.h b/apps/hallmark/src/ml_ops/normalization.h new file mode 100644 index 000000000000..4f28ec05c7c4 --- /dev/null +++ b/apps/hallmark/src/ml_ops/normalization.h @@ -0,0 +1,147 @@ +#ifndef HALIDE_APPS_HALLMARK_NORMALIZATION_H_ +#define HALIDE_APPS_HALLMARK_NORMALIZATION_H_ + +#include + +#include "Halide.h" + +namespace hallmark { + +// TODO: Rename to NormalizationMethod or NormalizationKind? 
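+// For reference, the RMS variant below computes, along the innermost axis,
+//   result = x / max(sqrt(mean(x * x)), 1e-6) * (1 + weights)
+// The Layer variant is marked untested further down and should be checked
+// against a reference implementation before relying on it.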
+enum class NormalizationKind { + None, + RMS, + Layer, +}; + +inline const std::map normalization_kind_names = { + {"none", NormalizationKind::None}, + {"rms", NormalizationKind::RMS}, + {"layer", NormalizationKind::Layer}, +}; + +struct Normalization : public Halide::NamesInterface { + Normalization(const std::string &base_name) + : base_name(base_name), + result(base_name + "_apply_norm"), + norm_sum(base_name + "_apply_norm_sum"), + clamped_rms(base_name + "_apply_norm_clamped_rms"), + diff(base_name + "_apply_norm_diff"), + var(base_name + "_apply_norm_var"), + stddev(base_name + "_apply_norm_stddev") { + } + + std::string base_name; + Func result; + Func norm_sum; + Func clamped_rms; + Func diff; + Func var; + Func stddev; + RVar norm_sum_range; + + NormalizationKind norm_kind; + Halide::GeneratorInput> *rms_weight_input; + Func weights; + Halide::GeneratorInput> *gamma_input; + Halide::GeneratorInput> *beta_input; + Func gamma{"gamma"}; + Func beta{"beta"}; + Expr epsilon; + Type processing_type; + std::vector args_norm_sum; + + // TODO: Make into a constructor and use pointers in generator. + void add_inputs(NormalizationKind norm_kind_arg, + const Halide::Type &processing_type_arg, + Halide::Internal::GeneratorBase *generator, + int arg_count = 1) { + processing_type = processing_type_arg; + norm_kind = norm_kind_arg; + if (norm_kind == NormalizationKind::RMS) { + rms_weight_input = generator->add_input>(base_name + "_rms_weights", processing_type, arg_count); + } else if (norm_kind == NormalizationKind::Layer) { + // TODO: fill in. + } + } + + void apply(Func input, Expr size) { + std::vector args = input.args(); + if (norm_kind == NormalizationKind::None) { + // It's important that we always create a distinct function for scheduling purposes + result = input; + } else { + Expr zero = cast(processing_type, 0); + + // Probably should make the splitting up of the operation part of the API + // as it affects the result. There are two goals: avoiding overflow and + // allowing efficient parallel computation. + RDom r(0, size, "apply_norm_sum_range"); + + args_norm_sum = std::vector{args.begin() + 1, args.end()}; + std::vector args_reduction; + args_reduction.emplace_back(r); + args_reduction.insert(args_reduction.end(), + args_norm_sum.begin(), args_norm_sum.end()); + norm_sum(args_norm_sum) = zero; + norm_sum(args_norm_sum) += input(args_reduction) * input(args_reduction); + norm_sum_range = r.x; + if (norm_kind == NormalizationKind::RMS) { + // Can't set this up in the configure call so has to happen here. + if (!weights.defined()) { + weights = *rms_weight_input; + } + clamped_rms(args_norm_sum) = max(cast(processing_type, 1e-6f), + sqrt(norm_sum(args_norm_sum) / size)); + std::vector weights_args(args.begin(), + args.begin() + weights.args().size()); + result(args) = (input(args) / clamped_rms(args_norm_sum)) * + (1 + weights(weights_args)); + } else if (norm_kind == NormalizationKind::Layer) { + diff(args) = input(args) - sqrt(norm_sum(args_norm_sum)); + var(args_norm_sum) = 0; + var(args_norm_sum) += diff(args_reduction) * diff(args_reduction); + stddev(args_norm_sum) = + sqrt(var(args_norm_sum) / size + + (epsilon.defined() ? 
cast(processing_type, epsilon) : zero)); + Expr body = diff(args) / stddev(args_norm_sum); + std::vector gamma_args(args.begin(), + args.begin() + gamma.args().size()); + if (gamma.defined()) { + body = body * gamma(gamma_args); + } + if (beta.defined()) { + std::vector beta_args(args.begin(), + args.begin() + beta.args().size()); + body += beta(beta_args); + } + result(args) = body; + } + } + } + + void default_schedule(LoopLevel result_loop_level, const Target &t) { + if (norm_kind != NormalizationKind::None) { + norm_sum.compute_at(result, Var::outermost()) + .vectorize(args_norm_sum[0], t.natural_vector_size(), TailStrategy::RoundUp) + .update(0) + .atomic() + .vectorize(norm_sum_range, t.natural_vector_size()); + result.compute_at(result_loop_level) + .vectorize(result.args()[0], t.natural_vector_size(result.type())); + } + if (norm_kind == NormalizationKind::Layer) { + // TODO untested + var.compute_at(result, Var::outermost()) + .vectorize(args_norm_sum[0], t.natural_vector_size(), TailStrategy::RoundUp) + .update(0) + .atomic() + .vectorize(norm_sum_range, t.natural_vector_size()); + result.compute_at(result_loop_level) + .vectorize(result.args()[0], t.natural_vector_size(result.type())); + } + } +}; + +} // namespace hallmark +#endif // HALIDE_APPS_HALLMARK_NORMALIZATION_H_ diff --git a/apps/hallmark/src/ml_ops/rope_weights.h b/apps/hallmark/src/ml_ops/rope_weights.h new file mode 100644 index 000000000000..e131203dffab --- /dev/null +++ b/apps/hallmark/src/ml_ops/rope_weights.h @@ -0,0 +1,127 @@ +#ifndef HALIDE_APPS_HALLMARK_ROPE_WEIGHTS_H_ +#define HALIDE_APPS_HALLMARK_ROPE_WEIGHTS_H_ + +#include + +#include "Halide.h" + +namespace hallmark { + +// Produce cos/sin phased sinusoids at different frequencies to provide a +// positioning signal on input. +// TODO: Is rank 1 correct for this? +struct RoPEWeights : public Halide::NamesInterface { + RoPEWeights(const std::string &base_name) + : base_name(base_name), result(base_name + "_rope_weights") { + } + std::string base_name; + Func result; + RDom r; + int num_channels; + + void apply(int32_t num_channels, const Type &generating_type) { + r = RDom(0, num_channels, base_name + "rope_weights_r"); + + Expr e = 2.f / num_channels; + Expr time_scale = pow(1e-4f, e * (r.x % (num_channels / 2))); + + Var h("h"), t("t"); + result(h, t) = undef(generating_type); + result(r.x, t) = select(r.x >= num_channels / 2, sin(t * time_scale), + cos(t * time_scale)); + this->num_channels = num_channels; + } + + void default_schedule(LoopLevel result_loop_level, const Target &t) { + RVar ro("ro"), ri("ri"); + result.compute_at(result_loop_level) + .update() + .split(r.x, ro, ri, num_channels / 2) + .unroll(ro) + .unroll(ri, 4) + .vectorize(ri, t.natural_vector_size()); + } +}; + +// Implemented per https://arxiv.org/pdf/2104.09864v5.pdf, bottom of page 5. +// Effectively treat each pair of features in the input and weights as a +// complex number and multiply them +// +// Complex representation places real values contiguous and imaginary values +// contiguous immediately after the real ones. +// +// TODO: Rewrite to take arguments from embedding rather than +// hardcoding vars.
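+// Concretely, for a feature pair (x_r, x_i) at sequence position t with
+// per-channel angle theta, the rotation computed in RoPE::apply() is
+//   out_r = x_r * cos(t * theta) - x_i * sin(t * theta)
+//   out_i = x_r * sin(t * theta) + x_i * cos(t * theta)
+// where RoPEWeights stores cos in the first half of the channel axis and sin
+// in the second half, matching the select() definitions in RoPEWeights and
+// RoPE::apply().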
+struct RoPE : public Halide::NamesInterface { + RoPE(const std::string &base_name) + : base_name(base_name), + result(base_name + "_rotated"), + inner(base_name + "_rotated_inner") { + } + std::string base_name; + Func result; + Func inner; + Var inner_var{"inner_var"}; + Var is_imaginary{"is_imaginary"}; + int d; + + void apply(Func embedding, Func rope_weights, int d) { + std::vector a = embedding.args(); + CHECK(a.size() == 4); + + std::vector args(a.begin(), a.end()); + + Expr real_h_index = inner_var; + Expr imaginary_h_index = d / 2 + inner_var; + + auto e_args_r = args, e_args_i = args; + e_args_r[0] = real_h_index; + e_args_i[0] = imaginary_h_index; + + Expr e_r = embedding(e_args_r); + Expr e_i = embedding(e_args_i); + + std::vector rw_args_r = {real_h_index, args[2]}; + std::vector rw_args_i = {imaginary_h_index, args[2]}; + + Expr rw_r = rope_weights(rw_args_r); + Expr rw_i = rope_weights(rw_args_i); + + auto ri_args = args; + ri_args[0] = inner_var; + ri_args.insert(ri_args.begin(), is_imaginary); + + // ri(is_imaginary, inner_var, ...) = select(...); + inner(ri_args) = select(is_imaginary == 0, e_r * rw_r - e_i * rw_i, + e_r * rw_i + e_i * rw_r); + + auto r_args_lhs = args; + r_args_lhs[0] = inner_var; + + auto r_args_rhs = args; + r_args_rhs[0] = inner_var % (d / 2); + r_args_rhs.insert(r_args_rhs.begin(), inner_var >= (d / 2)); + + result(r_args_lhs) = inner(r_args_rhs); + + this->d = d; + } + + void default_schedule(LoopLevel result_loop_level, const Target &t) { + Var io("io"), ii("ii"); + inner.compute_at(result, io) + .bound(is_imaginary, 0, 2) + .unroll(is_imaginary) + .unroll(inner_var, 4) + .vectorize(inner_var, t.natural_vector_size()); + result.compute_at(result_loop_level) + .split(inner_var, io, ii, d / 2) + .unroll(io, 2) + .unroll(ii, 4) + .vectorize(ii, t.natural_vector_size()); + } +}; + +} // namespace hallmark + +#endif // HALIDE_APPS_HALLMARK_ROPE_WEIGHTS_H_ diff --git a/apps/hallmark/src/ml_ops/softmax.h b/apps/hallmark/src/ml_ops/softmax.h new file mode 100644 index 000000000000..249cedfd6fd2 --- /dev/null +++ b/apps/hallmark/src/ml_ops/softmax.h @@ -0,0 +1,285 @@ +@// TODO: license. +#ifndef HALIDE_APPS_HALLMARK_SOFTMAX_H_ +#define HALIDE_APPS_HALLMARK_SOFTMAX_H_ + +#include +#include + +#include "Halide.h" + +namespace hallmark { +namespace { + +using Halide::Expr; +using Halide::Float; +using Halide::Type; + +Expr evaluate_polynomial(Expr x, float *coeff, int n) { + Expr x2 = x * x; + + Expr even_terms = coeff[0]; + Expr odd_terms = coeff[1]; + + for (int i = 2; i < n; i++) { + if ((i & 1) == 0) { + if (coeff[i] == 0.0f) { + even_terms *= x2; + } else { + even_terms = even_terms * x2 + coeff[i]; + } + } else { + if (coeff[i] == 0.0f) { + odd_terms *= x2; + } else { + odd_terms = odd_terms * x2 + coeff[i]; + } + } + } + + if ((n & 1) == 0) { + return even_terms * std::move(x) + odd_terms; + } else { + return odd_terms * std::move(x) + even_terms; + } +} + +/* Extended exponential which produces two output values, + * each of the same precision as the input, as described in + * "The Two-Pass Softmax Algorithm" by Marat Dukhan and + * Artsiom Ablavatski [https://arxiv.org/abs/2001.04438]. + * + * The first element of the returned Tuple is a psuedo-mantissa while + * the second is an exponent which is an integer. The product of the + * pseudo-mantissa and 2 raised to the returned exponent is the + * desired result e^a. For arguments up to slightly greater than + * 11629079, the pseudo-mantissa is guaranteed to be within the + * interval (-e, e). 
+ * For larger arguments, the exponent result of the Tuple may not be able
+ * to represent the exact integer necessary to keep the pseudo-mantissa
+ * within bounds. Thus it can become progressively larger in magnitude as
+ * the argument increases.
+ *
+ * Ideally this routine will maintain a degree of accuracy through the
+ * entire range and be able to produce results out to the end of the
+ * numeric range. At present neither of these properties is true, due to
+ * the following issues:
+ *  - Range reduction may overflow when scaling the argument.
+ *  - Range reduction is increasingly inaccurate in reducing the value
+ *    due to the implementation. This results in overflow in the polynomial
+ *    evaluation.
+ *  - Even if the above two issues were resolved, the approximation polynomial
+ *    would have to run on values outside its intended approximation range.
+ */
+Halide::Tuple extended_exp(const Expr &x_full) {
+    float ln2_part1 = 0.6931457519f;
+    float ln2_part2 = 1.4286067653e-6f;
+    float one_over_ln2 = 1.0f / logf(2.0f);
+
+    Expr scaled = x_full * one_over_ln2;
+    Expr k_real = floor(scaled);
+
+    Expr x = x_full - k_real * ln2_part1;
+    x = x - k_real * ln2_part2;
+
+    float coeff[] = {
+        0.00031965933071842413f,
+        0.00119156835564003744f,
+        0.00848988645943932717f,
+        0.04160188091348320655f,
+        0.16667983794100929562f,
+        0.49999899033463041098f,
+        1.0f,
+        1.0f};
+    Expr result = evaluate_polynomial(x, coeff, sizeof(coeff) / sizeof(coeff[0]));
+
+    // Ensure that the mantissa part is not a NaN or itself an infinity.
+    result = strict_float(select(!is_finite(k_real), 1, result));
+    result = common_subexpression_elimination(result);
+
+    return {result, k_real};
+}
+
+} // anonymous namespace
+
+struct Softmax : public Halide::NamesInterface {
+    enum class Algorithm {
+        Naive,
+        TwoPass,
+        ThreePass,
+    };
+
+    Softmax(const std::string &base_name,
+            Algorithm algorithm = Algorithm::TwoPass)
+        : base_name(base_name),
+          algorithm(algorithm),
+          result(base_name + "_softmax"),
+          ext_exp(base_name + "_softmax_ext_exp"),
+          exponentials(base_name + "_softmax_exponentials"),
+          softmax_sum(base_name + "_softmax_sum") {
+    }
+    std::string base_name;
+    Algorithm algorithm;
+    Func result;
+
+    // Naive algorithm
+    Func exponentials;
+
+    // Two pass algorithm
+    Func ext_exp;
+
+    // Three pass algorithm
+    Func max_bias;
+    Func biased_exp;
+
+    // Common to different algorithms
+    Func softmax_sum;
+    Var result_inner;
+    RVar softmax_sum_inner; // TODO: Remove this.
+    Var softmax_sum_inner_var;
+    LoopLevel softmax_sum_compute_at;
+
+    void apply(Func input, Expr size, const Type &generating_type) {
+        switch (algorithm) {
+        case Algorithm::Naive:
+            naive_algorithm(input, size, generating_type);
+            break;
+        case Algorithm::TwoPass:
+            two_pass_algorithm(input, size, generating_type);
+            break;
+        case Algorithm::ThreePass:
+            three_pass_algorithm(input, size, generating_type);
+            break;
+        };
+    }
+
+    void naive_algorithm(Func input, Expr size, const Type &generating_type) {
+        auto args = input.args();
+        RDom r(0, size);
+
+        exponentials(args) =
+            default_exp(cast<double>(clamp(input(args), -1e12f, 1e12f)));
+
+        std::vector<Var> args_sum(args.begin() + 1, args.end());
+        std::vector<Expr> args_reduction;
+        args_reduction.emplace_back(r.x);
+        args_reduction.insert(args_reduction.end(), args_sum.begin(),
+                              args_sum.end());
+
+        softmax_sum(args_sum) = Expr(0.0);
+        softmax_sum(args_sum) += exponentials(args_reduction);
+        softmax_sum_inner = r.x;
+        softmax_sum_inner_var = args_sum[0];
+
+        result(args) = cast(generating_type,
+                            input(args) / select(softmax_sum(args_sum) < Expr(1e-5),
+                                                 1, softmax_sum(args_sum)));
+        result_inner = args[0];
+        softmax_sum_compute_at = LoopLevel(result, args[1]);
+    }
+
+    // Implementation based on the algorithm in
+    // https://arxiv.org/pdf/2001.04438.pdf
+    void two_pass_algorithm(Func input, Expr size, const Type &generating_type) {
+        auto args = input.args();
+        RDom r(0, size);
+
+        // TODO: It should not be necessary to use double for computation here.
+#define USE_DOUBLE 1
+#if USE_DOUBLE
+        ext_exp(args) = extended_exp(cast<double>(input(args)));
+#else
+        ext_exp(args) = extended_exp(input(args));
+#endif
+
+        std::vector<Var> args_inner(args.begin() + 1, args.end());
+        std::vector<Expr> args_reduction;
+        args_reduction.emplace_back(r.x);
+        args_reduction.insert(args_reduction.end(), args_inner.begin(),
+                              args_inner.end());
+
+        // This reduction maintains a Tuple with the running sum and the maximum
+        // exponent seen so far, both as floating point numbers.
+        softmax_sum(args_inner) =
+#if USE_DOUBLE
+            Halide::Tuple(Expr(0.0), Expr(std::numeric_limits<double>::lowest()));
+#else
+            Halide::Tuple(0.0f, Expr(std::numeric_limits<float>::lowest()));
+#endif
+        Expr running_max_exp =
+            max(softmax_sum(args_inner)[1], ext_exp(args_reduction)[1]);
+        Expr m_sub_i_term = ext_exp(args_reduction)[0] *
+                            pow(2.0f, ext_exp(args_reduction)[1] - running_max_exp);
+        Expr m_sum_term = softmax_sum(args_inner)[0] *
+                          pow(2.0f, softmax_sum(args_inner)[1] - running_max_exp);
+        Expr running_sum = m_sub_i_term + m_sum_term;
+        softmax_sum(args_inner) = Tuple(running_sum, running_max_exp);
+        Expr lambda = 1 / softmax_sum(args_inner)[0];
+        Expr t =
+            cast(generating_type,
+                 ext_exp(args)[0] * lambda *
+                     pow(2.0f, ext_exp(args)[1] - softmax_sum(args_inner)[1]));
+        result(args) = t;
+        result_inner = args[0];
+        softmax_sum_inner = r;
+        softmax_sum_inner_var = args_inner[0];
+        softmax_sum_compute_at = LoopLevel(result, args[1]);
+    }
+
+    void three_pass_algorithm(Func input, Expr size, const Type &generating_type) {
+        auto args = input.args();
+        RDom r(0, size);
+
+        std::vector<Var> args_inner(args.begin() + 1, args.end());
+        std::vector<Expr> args_reduction;
+        args_reduction.emplace_back(r.x);
+        args_reduction.insert(args_reduction.end(), args_inner.begin(),
+                              args_inner.end());
+
+        max_bias(args_inner) = std::numeric_limits<float>::lowest();
+        max_bias(args_inner) = max(max_bias(args_inner), input(args_reduction));
+
+        biased_exp(args) = halide_exp(input(args) - max_bias(args_inner));
+        softmax_sum(args_inner) = 0.0f;
+        softmax_sum(args_inner) += biased_exp(args_reduction);
+
+        Expr lambda = 1 / softmax_sum(args_inner);
+        result(args) = halide_exp(input(args) - max_bias(args_inner)) * lambda;
+        result_inner = args[0];
+        softmax_sum_inner = r;
+        softmax_sum_inner_var = args_inner[0];
+        softmax_sum_compute_at = LoopLevel(result, args[1]);
+    }
+
+    // TODO: add support for reuse vs. recompute scheduling on exp operations.
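+
+    // Illustrative scalar sketch (not used by the Funcs above, and not part of
+    // the original pipeline): the same running-max / rescaled-sum recurrence
+    // from https://arxiv.org/abs/2001.04438 that two_pass_algorithm expresses
+    // as a Halide reduction, written with a plain float exp instead of the
+    // extended (pseudo-mantissa, exponent) representation. Assumes <cmath>,
+    // <limits>, and <vector> are available.
+    static void reference_two_pass_softmax(const std::vector<float> &in,
+                                           std::vector<float> &out) {
+        // Pass 1: accumulate the sum of exp(x - max) while tracking the running
+        // maximum, rescaling the partial sum whenever a new maximum appears.
+        float running_max = std::numeric_limits<float>::lowest();
+        float running_sum = 0.0f;
+        for (float x : in) {
+            float new_max = (x > running_max) ? x : running_max;
+            running_sum = running_sum * std::exp(running_max - new_max) +
+                          std::exp(x - new_max);
+            running_max = new_max;
+        }
+        // Pass 2: normalize each element against the rescaled sum.
+        out.resize(in.size());
+        for (size_t i = 0; i < in.size(); i++) {
+            out[i] = std::exp(in[i] - running_max) / running_sum;
+        }
+    }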
+
+    void default_schedule(LoopLevel result_loop_level, const Target &t,
+                          bool vectorize) {
+        if (algorithm == Algorithm::Naive) {
+            exponentials.compute_at(softmax_sum_compute_at);
+        } else if (algorithm == Algorithm::TwoPass) {
+            ext_exp.compute_inline();
+        } else if (algorithm == Algorithm::ThreePass) {
+            max_bias.compute_at(softmax_sum_compute_at);
+            // TODO: vectorize max loop, maybe parallelize
+            biased_exp.compute_at(softmax_sum_compute_at);
+        }
+        softmax_sum.compute_at(softmax_sum_compute_at)
+            .store_in(MemoryType::Register)
+            .vectorize(softmax_sum_inner_var, t.natural_vector_size<float>())
+            .update(0)
+            .unscheduled();
+        result.compute_at(result_loop_level);
+        if (vectorize) {
+            // In some modes, this dimension is narrow and we don't want to
+            // vectorize it.
+#if USE_DOUBLE
+            result.vectorize(result_inner, t.natural_vector_size<double>());
+#else
+            result.vectorize(result_inner, t.natural_vector_size<float>());
+#endif
+        }
+    }
+};
+
+} // namespace hallmark
+
+#endif // HALIDE_APPS_HALLMARK_SOFTMAX_H_
diff --git a/apps/hallmark/test/CMakeLists.txt b/apps/hallmark/test/CMakeLists.txt
new file mode 100644
index 000000000000..4cf5a0605102
--- /dev/null
+++ b/apps/hallmark/test/CMakeLists.txt
@@ -0,0 +1,34 @@
+add_executable(llm_generator_test llm_generator_test.cc)
+target_link_libraries(llm_generator_test
+    PRIVATE
+    absl::flags
+    absl::flags_parse
+    GTest::gtest
+    hallmark_contrib
+    hallmark_llm
+    hallmark_position_embedding
+    hallmark_postprocessor
+    hallmark_preprocessor
+    hallmark_rope_values
+    hallmark_transformer_kv_update_cache
+    hallmark_transformer_kv_use_cache
+    hallmark_transformer_no_kv_cache)
+
+include(GoogleTest)
+gtest_discover_tests(llm_generator_test)
+
+add_executable(llm_generator_bench llm_generator_bench.cc)
+target_link_libraries(llm_generator_bench
+    PRIVATE
+    absl::flags
+    absl::flags_parse
+    benchmark::benchmark
+    hallmark_contrib
+    hallmark_llm
+    hallmark_position_embedding
+    hallmark_postprocessor
+    hallmark_preprocessor
+    hallmark_rope_values
+    hallmark_transformer_kv_update_cache
+    hallmark_transformer_kv_use_cache
+    hallmark_transformer_no_kv_cache)
diff --git a/apps/hallmark/test/llm_generator_bench.cc b/apps/hallmark/test/llm_generator_bench.cc
new file mode 100644
index 000000000000..23d09079927f
--- /dev/null
+++ b/apps/hallmark/test/llm_generator_bench.cc
@@ -0,0 +1,256 @@
+#include <benchmark/benchmark.h>
+
+#include <optional>
+#include <string>
+
+#include "absl/flags/flag.h"
+#include "absl/flags/parse.h"
+#include "absl/log/check.h"
+#include "absl/status/status.h"
+#include "contrib/llm_weights.h"
+#include "contrib/status_helpers.h"
+#include "hallmark_position_embedding.h"
+#include "hallmark_postprocessor.h"
+#include "hallmark_preprocessor.h"
+#include "hallmark_rope_values.h"
+#include "hallmark_transformer_kv_update_cache.h"
+#include "hallmark_transformer_kv_use_cache.h"
+#include "hallmark_transformer_no_kv_cache.h"
+#include "src/llm.h"
+
+ABSL_FLAG(std::optional<std::string>, model_path, std::nullopt,
+          "Path to the tflite model file.");
+
+// TODO just for this model?
+ABSL_FLAG(int, max_tokens, 512,
+          "Maximum number of input and output tokens. This value needs to be "
+          "at least as large as the number of input tokens.");
+
+namespace hallmark {
+
+namespace {
+
+absl::StatusOr<std::unique_ptr<Llm>> LoadLlm() {
+    CHECK(absl::GetFlag(FLAGS_model_path).has_value());
+
+    const std::string model_path = absl::GetFlag(FLAGS_model_path).value();
+    // LOG(INFO) << "Using model from path: " << model_path;
+
+    auto p = LoadLlmParams(model_path);
+    if (!p.ok()) {
+        return p.status();
+    }
+    auto llm_params = std::move(p.value());
+    llm_params.seq_size_T = absl::GetFlag(FLAGS_max_tokens); // TODO: not sure about this
+
+    auto w = LoadLlmWeights(model_path, llm_params);
+    if (!w.ok()) {
+        return w.status();
+    }
+    auto llm_weights = std::move(w.value());
+
+    auto l = Llm::CreateLlm(llm_weights, llm_params);
+    if (!l.ok()) {
+        return l.status();
+    }
+    auto llm = std::move(l.value());
+
+    RETURN_IF_ERROR(llm->Reset());
+    RETURN_IF_ERROR(
+        llm->InitAttentionMaskValues(llm_params.seq_size_T));
+
+    return llm;
+}
+
+} // namespace
+
+void BM_RoPEValues(benchmark::State &state) {
+    auto llm = LoadLlm();
+    CHECK_OK(llm);
+    auto &segment_pos_values = llm.value()->segment_pos_values();
+
+    for (auto _ : state) {
+        CHECK_EQ(0, rope_values(segment_pos_values));
+    }
+}
+
+void BM_Preprocessor(benchmark::State &state) {
+    auto llm = LoadLlm();
+    CHECK_OK(llm);
+    auto input = llm.value()->AllocateSeqBuffer(
+        llm.value()->GetLlmParams().seq_size_T); // TODO just for this model
+    auto output = llm.value()->AllocateSeqBuffer(input.dim(1).extent());
+
+    for (auto _ : state) {
+        CHECK_EQ(0, preprocessor(input, output));
+    }
+}
+
+void BM_transformer_no_kv_cache(benchmark::State &state) {
+    auto llm = LoadLlm();
+    CHECK_OK(llm);
+    auto input = llm.value()->AllocateSeqBuffer(
+        llm.value()->GetLlmParams().seq_size_T); // TODO just for this model
+    auto &segment_pos_values = llm.value()->segment_pos_values();
+    auto &attention_mask_values = llm.value()->attention_mask_values();
+    auto output = llm.value()->AllocateSeqBuffer(input.dim(1).extent());
+    // TODO: we only do the first entry here for now. Should we do all of them?
+    auto sas = llm.value()->sas()[0];
+    auto ffs = llm.value()->ffs()[0];
+
+    for (auto _ : state) {
+        CHECK_EQ(
+            0, transformer_no_kv_cache(
+                   input, segment_pos_values, attention_mask_values,
+                   std::get<0>(*(ffs.pre_norm_weight)).norm_weight.weights,
+                   sas.k_weight.weights, sas.k_weight.scale, sas.q_weight.weights,
+                   sas.q_weight.scale, sas.v_weight.weights, sas.v_weight.scale,
+                   sas.post_proj_weight.weights, sas.post_proj_weight.scale,
+                   std::get<0>(*(ffs.pre_norm_weight)).norm_weight.weights,
+                   ffs.layer_1_weight.weights, ffs.layer_1_weight.scale,
+                   ffs.layer_1_gate_weight.weights, ffs.layer_1_gate_weight.scale,
+                   ffs.layer_2_weight.weights, ffs.layer_2_weight.scale, output));
+    }
+}
+
+void BM_transformer_kv_use_cache(benchmark::State &state) {
+    auto llm = LoadLlm();
+    CHECK_OK(llm);
+    auto input = llm.value()->AllocateSeqBuffer(
+        llm.value()->GetLlmParams().seq_size_T); // TODO just for this model
+    auto &segment_pos_values = llm.value()->segment_pos_values();
+    auto &attention_mask_values = llm.value()->attention_mask_values();
+    auto output = llm.value()->AllocateSeqBuffer(input.dim(1).extent());
+    // TODO: we only do the first entry here for now. Should we do all of them?
+    auto sas = llm.value()->sas()[0];
+    auto ffs = llm.value()->ffs()[0];
+
+    auto k_cache = Halide::Runtime::Buffer<float>(
+        llm.value()->GetLlmParams().head_dim_H,
+        1, // llm.value()->GetLlmParams().model_dim_D /
+           //     llm.value()->GetLlmParams().head_dim_H,
+        llm.value()->GetLlmParams().seq_size_T,
+        llm.value()->GetLlmParams().batch_size_B);
+
+    auto v_cache = Halide::Runtime::Buffer<float>(
+        llm.value()->GetLlmParams().head_dim_H,
+        1, // llm.value()->GetLlmParams().model_dim_D /
+           //     llm.value()->GetLlmParams().head_dim_H,
+        llm.value()->GetLlmParams().seq_size_T,
+        llm.value()->GetLlmParams().batch_size_B);
+
+    constexpr int last_kv_cache_start = 1;
+    auto input_slice = input.cropped(1, last_kv_cache_start, 1);
+    auto output_slice = output.cropped(1, last_kv_cache_start, 1);
+
+    for (auto _ : state) {
+        CHECK_EQ(
+            0, transformer_kv_use_cache(
+                   input_slice, segment_pos_values, attention_mask_values,
+                   std::get<0>(*(sas.pre_norm_weight)).norm_weight.weights,
+                   sas.k_weight.weights, sas.k_weight.scale, sas.q_weight.weights,
+                   sas.q_weight.scale, sas.v_weight.weights, sas.v_weight.scale,
+                   sas.post_proj_weight.weights, sas.post_proj_weight.scale,
+                   std::get<0>(*(ffs.pre_norm_weight)).norm_weight.weights,
+                   ffs.layer_1_weight.weights, ffs.layer_1_weight.scale,
+                   ffs.layer_1_gate_weight.weights, ffs.layer_1_gate_weight.scale,
+                   ffs.layer_2_weight.weights, ffs.layer_2_weight.scale, k_cache,
+                   v_cache, output_slice));
+    }
+}
+
+void BM_transformer_kv_update_cache(benchmark::State &state) {
+    auto llm = LoadLlm();
+    CHECK_OK(llm);
+    auto input = llm.value()->AllocateSeqBuffer(
+        llm.value()->GetLlmParams().seq_size_T); // TODO just for this model
+    auto &segment_pos_values = llm.value()->segment_pos_values();
+    auto &attention_mask_values = llm.value()->attention_mask_values();
+    // TODO: we only do the first entry here for now. Should we do all of them?
+    auto sas = llm.value()->sas()[0];
+    auto ffs = llm.value()->ffs()[0];
+
+    auto k_cache = Halide::Runtime::Buffer<float>(
+        llm.value()->GetLlmParams().head_dim_H,
+        1, // llm.value()->GetLlmParams().model_dim_D /
+           //     llm.value()->GetLlmParams().head_dim_H,
+        llm.value()->GetLlmParams().seq_size_T,
+        llm.value()->GetLlmParams().batch_size_B);
+
+    auto v_cache = Halide::Runtime::Buffer<float>(
+        llm.value()->GetLlmParams().head_dim_H,
+        1, // llm.value()->GetLlmParams().model_dim_D /
+           //     llm.value()->GetLlmParams().head_dim_H,
+        llm.value()->GetLlmParams().seq_size_T,
+        llm.value()->GetLlmParams().batch_size_B);
+
+    constexpr int last_kv_cache_start = 1;
+    auto input_slice = input.cropped(1, last_kv_cache_start, 1);
+
+    const int run_extent = input_slice.dim(1).max() - last_kv_cache_start + 1;
+    auto key_slice = k_cache.cropped(2, last_kv_cache_start, run_extent);
+    auto value_slice = v_cache.cropped(2, last_kv_cache_start, run_extent);
+
+    for (auto _ : state) {
+        CHECK_EQ(
+            0, transformer_kv_update_cache(
+                   input_slice, segment_pos_values, attention_mask_values,
+                   std::get<0>(*(ffs.pre_norm_weight)).norm_weight.weights,
+                   sas.k_weight.weights, sas.k_weight.scale, sas.q_weight.weights,
+                   sas.q_weight.scale, sas.v_weight.weights, sas.v_weight.scale,
+                   sas.post_proj_weight.weights, sas.post_proj_weight.scale,
+                   key_slice, value_slice));
+    }
+}
+
+void BM_Postprocessor(benchmark::State &state) {
+    auto llm = LoadLlm();
+    CHECK_OK(llm);
+    auto input = llm.value()->AllocateSeqBuffer(
+        llm.value()->GetLlmParams().seq_size_T); // TODO just for this model
+    auto logits_output =
+        Halide::Runtime::Buffer<float>(llm.value()->GetLlmParams().voc_size_V, 1,
+                                       llm.value()->GetLlmParams().batch_size_B);
+
+    for (auto _ : state) {
+        CHECK_EQ(0,
+                 postprocessor(
+                     input,
+                     std::get<0>(*llm.value()->final_norm_weight()).norm_weight.weights,
+                     llm.value()->softmax_linear_weights(),
+                     llm.value()->softmax_linear_scale(), logits_output));
+    }
+}
+
+void BM_PositionEmbedding(benchmark::State &state) {
+    auto llm = LoadLlm();
+    CHECK_OK(llm);
+    const auto &params = llm.value()->GetLlmParams();
+    auto pos_embedding =
+        Halide::Runtime::Buffer<float>(static_cast<int32_t>(params.model_dim_D),
+                                       static_cast<int32_t>(params.seq_size_T));
+    int32_t input_length =
+        llm.value()->GetLlmParams().seq_size_T; // TODO just for this model
+
+    for (auto _ : state) {
+        CHECK_EQ(0, position_embedding(input_length, params.seq_size_T,
+                                       params.model_dim_D, 1.0f, 10000.0f,
+                                       pos_embedding));
+    }
+}
+
+BENCHMARK(BM_Preprocessor);
+BENCHMARK(BM_transformer_no_kv_cache);
+BENCHMARK(BM_transformer_kv_use_cache);
+BENCHMARK(BM_transformer_kv_update_cache);
+BENCHMARK(BM_Postprocessor);
+BENCHMARK(BM_PositionEmbedding);
+
+} // namespace hallmark
+
+// The default benchmark main() won't initialize Abseil flags, so we must define our own.
+int main(int argc, char **argv) {
+    benchmark::Initialize(&argc, argv);
+    absl::ParseCommandLine(argc, argv);
+    benchmark::RunSpecifiedBenchmarks();
+    benchmark::Shutdown();
+
+    return 0;
+}
diff --git a/apps/hallmark/test/llm_generator_test.cc b/apps/hallmark/test/llm_generator_test.cc
new file mode 100644
index 000000000000..ff9d1e7c7b66
--- /dev/null
+++ b/apps/hallmark/test/llm_generator_test.cc
@@ -0,0 +1,210 @@
+#include <gtest/gtest.h>
+
+#include <optional>
+#include <string>
+
+#include "absl/flags/flag.h"
+#include "absl/flags/parse.h"
+#include "absl/log/check.h"
+#include "absl/status/status.h"
+#include "contrib/llm_weights.h"
+#include "contrib/status_helpers.h"
+#include "hallmark_position_embedding.h"
+#include "hallmark_postprocessor.h"
+#include "hallmark_preprocessor.h"
+#include "hallmark_rope_values.h"
+#include "hallmark_transformer_kv_update_cache.h"
+#include "hallmark_transformer_kv_use_cache.h"
+#include "hallmark_transformer_no_kv_cache.h"
+#include "src/llm.h"
+
+ABSL_FLAG(std::optional<std::string>, model_path, std::nullopt,
+          "Path to the tflite model file.");
+
+// TODO just for this model?
+ABSL_FLAG(int, max_tokens, 512,
+          "Maximum number of input and output tokens. This value needs to be "
+          "at least as large as the number of input tokens.");
+
+namespace hallmark {
+
+namespace {
+
+absl::StatusOr<std::unique_ptr<Llm>> LoadLlm() {
+    CHECK(absl::GetFlag(FLAGS_model_path).has_value());
+
+    const std::string model_path = absl::GetFlag(FLAGS_model_path).value();
+    // LOG(INFO) << "Using model from path: " << model_path;
+
+    auto p = LoadLlmParams(model_path);
+    if (!p.ok()) {
+        return p.status();
+    }
+    auto llm_params = std::move(p.value());
+    llm_params.seq_size_T = absl::GetFlag(FLAGS_max_tokens); // TODO: not sure about this
+
+    auto w = LoadLlmWeights(model_path, llm_params);
+    if (!w.ok()) {
+        return w.status();
+    }
+    auto llm_weights = std::move(w.value());
+
+    auto l = Llm::CreateLlm(llm_weights, llm_params);
+    if (!l.ok()) {
+        return l.status();
+    }
+    auto llm = std::move(l.value());
+
+    RETURN_IF_ERROR(llm->Reset());
+    RETURN_IF_ERROR(
+        llm->InitAttentionMaskValues(llm_params.seq_size_T));
+
+    return llm;
+}
+} // namespace
+
+class LlmHalideTest : public testing::Test {
+protected:
+    void SetUp() override {
+        auto llm = LoadLlm();
+        CHECK_OK(llm);
+        llm_ = std::move(*llm);
+    }
+
+    void TearDown() override {
+        // nothing
+    }
+
+    std::unique_ptr<Llm> llm_;
+};
+
+TEST_F(LlmHalideTest, RoPEValues) {
+    auto &segment_pos_values = llm_->segment_pos_values();
+    CHECK_OK(StatusFromHalide(rope_values(segment_pos_values)));
+}
+
+TEST_F(LlmHalideTest, Preprocessor) {
+    auto input = llm_->AllocateSeqBuffer(
+        llm_->GetLlmParams().seq_size_T); // TODO just for this model
+    auto output = llm_->AllocateSeqBuffer(input.dim(1).extent());
+    CHECK_OK(StatusFromHalide(preprocessor(input, output)));
+}
+
+TEST_F(LlmHalideTest, transformer_no_kv_cache) {
+    auto input = llm_->AllocateSeqBuffer(
+        llm_->GetLlmParams().seq_size_T); // TODO just for this model
+    auto &segment_pos_values = llm_->segment_pos_values();
+    auto &attention_mask_values = llm_->attention_mask_values();
+    auto output = llm_->AllocateSeqBuffer(input.dim(1).extent());
+    // TODO: we only do the first entry here for now. Should we do all of them?
+    auto sas = llm_->sas()[0];
+    auto ffs = llm_->ffs()[0];
+
+    CHECK_OK(StatusFromHalide(transformer_no_kv_cache(
+        input, segment_pos_values, attention_mask_values,
+        std::get<0>(*(ffs.pre_norm_weight)).norm_weight.weights, sas.k_weight.weights,
+        sas.k_weight.scale, sas.q_weight.weights, sas.q_weight.scale,
+        sas.v_weight.weights, sas.v_weight.scale, sas.post_proj_weight.weights,
+        sas.post_proj_weight.scale,
+        std::get<0>(*(ffs.pre_norm_weight)).norm_weight.weights,
+        ffs.layer_1_weight.weights, ffs.layer_1_weight.scale,
+        ffs.layer_1_gate_weight.weights, ffs.layer_1_gate_weight.scale,
+        ffs.layer_2_weight.weights, ffs.layer_2_weight.scale, output)));
+}
+
+TEST_F(LlmHalideTest, transformer_kv_use_cache) {
+    auto input = llm_->AllocateSeqBuffer(
+        llm_->GetLlmParams().seq_size_T); // TODO just for this model
+    auto &segment_pos_values = llm_->segment_pos_values();
+    auto &attention_mask_values = llm_->attention_mask_values();
+    auto output = llm_->AllocateSeqBuffer(input.dim(1).extent());
+    // TODO: we only do the first entry here for now. Should we do all of them?
+    auto sas = llm_->sas()[0];
+    auto ffs = llm_->ffs()[0];
+
+    auto k_cache = Halide::Runtime::Buffer<float>(
+        llm_->GetLlmParams().head_dim_H, 1, llm_->GetLlmParams().seq_size_T,
+        llm_->GetLlmParams().batch_size_B);
+
+    auto v_cache = Halide::Runtime::Buffer<float>(
+        llm_->GetLlmParams().head_dim_H, 1, llm_->GetLlmParams().seq_size_T,
+        llm_->GetLlmParams().batch_size_B);
+
+    constexpr int last_kv_cache_start = 1;
+    auto input_slice = input.cropped(1, last_kv_cache_start, 1);
+    auto output_slice = output.cropped(1, last_kv_cache_start, 1);
+
+    CHECK_OK(StatusFromHalide(transformer_kv_use_cache(
+        input_slice, segment_pos_values, attention_mask_values,
+        std::get<0>(*(sas.pre_norm_weight)).norm_weight.weights, sas.k_weight.weights,
+        sas.k_weight.scale, sas.q_weight.weights, sas.q_weight.scale,
+        sas.v_weight.weights, sas.v_weight.scale, sas.post_proj_weight.weights,
+        sas.post_proj_weight.scale,
+        std::get<0>(*(ffs.pre_norm_weight)).norm_weight.weights,
+        ffs.layer_1_weight.weights, ffs.layer_1_weight.scale,
+        ffs.layer_1_gate_weight.weights, ffs.layer_1_gate_weight.scale,
+        ffs.layer_2_weight.weights, ffs.layer_2_weight.scale, k_cache, v_cache,
+        output_slice)));
+}
+
+TEST_F(LlmHalideTest, transformer_kv_update_cache) {
+    auto input = llm_->AllocateSeqBuffer(
+        llm_->GetLlmParams().seq_size_T); // TODO just for this model
+    auto &segment_pos_values = llm_->segment_pos_values();
+    auto &attention_mask_values = llm_->attention_mask_values();
+    // TODO: we only do the first entry here for now. Should we do all of them?
+    auto sas = llm_->sas()[0];
+    auto ffs = llm_->ffs()[0];
+
+    auto k_cache = Halide::Runtime::Buffer<float>(
+        llm_->GetLlmParams().head_dim_H, 1, llm_->GetLlmParams().seq_size_T,
+        llm_->GetLlmParams().batch_size_B);
+
+    auto v_cache = Halide::Runtime::Buffer<float>(
+        llm_->GetLlmParams().head_dim_H, 1, llm_->GetLlmParams().seq_size_T,
+        llm_->GetLlmParams().batch_size_B);
+
+    constexpr int last_kv_cache_start = 1;
+    auto input_slice = input.cropped(1, last_kv_cache_start, 1);
+
+    const int run_extent = input_slice.dim(1).max() - last_kv_cache_start + 1;
+    auto key_slice = k_cache.cropped(2, last_kv_cache_start, run_extent);
+    auto value_slice = v_cache.cropped(2, last_kv_cache_start, run_extent);
+
+    CHECK_OK(StatusFromHalide(transformer_kv_update_cache(
+        input_slice, segment_pos_values, attention_mask_values,
+        std::get<0>(*(ffs.pre_norm_weight)).norm_weight.weights, sas.k_weight.weights,
+        sas.k_weight.scale, sas.q_weight.weights, sas.q_weight.scale,
+        sas.v_weight.weights, sas.v_weight.scale, sas.post_proj_weight.weights,
+        sas.post_proj_weight.scale, key_slice, value_slice)));
+}
+
+TEST_F(LlmHalideTest, Postprocessor) {
+    auto input = llm_->AllocateSeqBuffer(
+        llm_->GetLlmParams().seq_size_T); // TODO just for this model
+    auto logits_output = Halide::Runtime::Buffer<float>(
+        llm_->GetLlmParams().voc_size_V, 1, llm_->GetLlmParams().batch_size_B);
+
+    // Postprocess
+    CHECK_OK(StatusFromHalide(postprocessor(
+        input, std::get<0>(*llm_->final_norm_weight()).norm_weight.weights,
+        llm_->softmax_linear_weights(), llm_->softmax_linear_scale(),
+        logits_output)));
+}
+
+TEST_F(LlmHalideTest, PositionEmbedding) {
+    const auto &params = llm_->GetLlmParams();
+    auto pos_embedding =
+        Halide::Runtime::Buffer<float>(static_cast<int32_t>(params.model_dim_D),
+                                       static_cast<int32_t>(params.seq_size_T));
+    int32_t input_length = params.seq_size_T;
+    CHECK_OK(StatusFromHalide(position_embedding(input_length, params.seq_size_T,
+                                                 params.model_dim_D, 1.0f,
+                                                 10000.0f, pos_embedding)));
+}
+
+} // namespace hallmark
+
+// gtest's main() won't initialize Abseil flags, so we must define our own.
+int main(int argc, char **argv) {
+    testing::InitGoogleTest(&argc, argv);
+    absl::ParseCommandLine(argc, argv);
+    return RUN_ALL_TESTS();
+}