From 07fd6041f90b0b84f21a11541922784d6825ad9b Mon Sep 17 00:00:00 2001
From: eedalong
Date: Thu, 13 Jun 2024 15:28:09 +0800
Subject: [PATCH] support shape_propagate for gather,reduce,scatter,transpose

---
 .../disc/transforms/disc_shape_propagate.cc   | 234 +++++++++++++++++-
 .../tests/disc-shape-propagate.mlir           | 152 +++++++++++-
 2 files changed, 383 insertions(+), 3 deletions(-)
 mode change 100644 => 100755 tao_compiler/mlir/disc/transforms/tests/disc-shape-propagate.mlir

diff --git a/tao_compiler/mlir/disc/transforms/disc_shape_propagate.cc b/tao_compiler/mlir/disc/transforms/disc_shape_propagate.cc
index ca9185a13c2..64156322509 100644
--- a/tao_compiler/mlir/disc/transforms/disc_shape_propagate.cc
+++ b/tao_compiler/mlir/disc/transforms/disc_shape_propagate.cc
@@ -67,6 +67,9 @@ struct DiscShapePropagatePass
     DiscShapePropagatePassBase<DiscShapePropagatePass>::getDependentDialects(
         registry);
     registry.insert();
+    registry.insert();
+    registry.insert();
+    registry.insert();
   }
   void runOnOperation() override;
 };
@@ -75,7 +78,9 @@ bool isBinaryOp(Operation* op) {
          isa(*op) || isa(*op);
 }
-bool isUnaryOp(Operation* op) { return isa(op); }
+bool isUnaryOp(Operation* op) {
+  return isa(op);
+}
 
 bool isConcreteShape(ShapeContext& ctx) {
   for (auto dim : ctx.shape) {
     if (dim == ShapedType::kDynamic) return false;
@@ -136,10 +141,24 @@ std::optional<ShapeContext> propagateHelper(OpBuilder& b, Operation* op,
                                             ShapeContext& inputCtx) {
   return std::nullopt;
 }
+
+template <>
+std::optional<ShapeContext> propagateHelper<mhlo::ScatterOp>(
+    OpBuilder& b, Operation* op, ShapeContext& inputCtx) {
+  auto dim_op = dyn_cast_or_null<mhlo::ScatterOp>(op);
+  if (!dim_op) return std::nullopt;
+
+  SmallVector<int64_t> new_shape(
+      op->getResult(0).getType().cast<RankedTensorType>().getShape());
+  return ShapeContext(op->getResult(0), new_shape);
+}
+
 template <>
 std::optional<ShapeContext> propagateHelper<mhlo::DotOp>(
     OpBuilder& b, Operation* op, ShapeContext& inputCtx) {
-  auto dot_op = cast<mhlo::DotOp>(op);
+  auto dot_op = dyn_cast_or_null<mhlo::DotOp>(op);
+  if (!dot_op) return std::nullopt;
+
   auto lhs_shape =
       dot_op.getOperand(0).getType().cast<RankedTensorType>().getShape();
   auto rhs_shape =
       dot_op.getOperand(1).getType().cast<RankedTensorType>().getShape();
@@ -152,6 +171,210 @@ std::optional<ShapeContext> propagateHelper<mhlo::DotOp>(
   return ShapeContext(op->getResult(0), new_shape);
 }
+template <>
+std::optional<ShapeContext> propagateHelper<mhlo::ConcatenateOp>(
+    OpBuilder& b, Operation* op, ShapeContext& inputCtx) {
+  auto concat_op = dyn_cast_or_null<mhlo::ConcatenateOp>(op);
+  if (!concat_op) return std::nullopt;
+
+  auto operands = op->getOperands();
+  SmallVector<int64_t> new_shape(
+      op->getResult(0).getType().cast<RankedTensorType>().getRank(),
+      ShapedType::kDynamic);
+  new_shape[concat_op.getDimension()] =
+      op->getResult(0)
+          .getType()
+          .cast<RankedTensorType>()
+          .getShape()[concat_op.getDimension()];
+
+  for (auto operand : operands) {
+    auto shape = operand.getType().cast<RankedTensorType>().getShape();
+    if (inputCtx.value == operand) {
+      shape = inputCtx.shape;
+    }
+
+    for (int dim_idx = 0; dim_idx < new_shape.size(); dim_idx++) {
+      if (dim_idx == concat_op.getDimension() &&
+          shape[dim_idx] == ShapedType::kDynamic) {
+        new_shape[dim_idx] = ShapedType::kDynamic;
+      } else if (dim_idx != concat_op.getDimension() &&
+                 shape[dim_idx] != ShapedType::kDynamic) {
+        new_shape[dim_idx] = shape[dim_idx];
+      }
+    }
+  }
+
+  return ShapeContext(op->getResult(0), new_shape);
+}
+
+template <>
+std::optional<ShapeContext> propagateHelper<mhlo::TransposeOp>(
+    OpBuilder& b, Operation* op, ShapeContext& inputCtx) {
+  auto transpose_op = dyn_cast_or_null<mhlo::TransposeOp>(op);
+  if (!transpose_op) return std::nullopt;
+
+  SmallVector<int64_t> new_shape;
+
+  for (auto it = transpose_op.getPermutation().begin();
+       it != transpose_op.getPermutation().end(); it++) {
+    int64_t src_dim = (*it).getSExtValue();
+    new_shape.push_back(inputCtx.shape[src_dim]);
+  }
+
+  return ShapeContext(op->getResult(0), new_shape);
+}
+
+template <>
+std::optional<ShapeContext> propagateHelper<mhlo::ReduceOp>(
+    OpBuilder& b, Operation* op, ShapeContext& inputCtx) {
+  auto reduce_op = dyn_cast_or_null<mhlo::ReduceOp>(op);
+  if (!reduce_op) return std::nullopt;
+
+  SmallVector<int64_t> new_shape;
+
+  for (int dim = 0; dim < inputCtx.shape.size(); dim++) {
+    bool add_dim = true;
+    for (auto it = reduce_op.getDimensions().begin();
+         it != reduce_op.getDimensions().end(); it++) {
+      int64_t src_dim = (*it).getSExtValue();
+      add_dim = add_dim && !(dim == src_dim);
+    }
+    if (add_dim) {
+      new_shape.push_back(inputCtx.shape[dim]);
+    }
+  }
+
+  return ShapeContext(op->getResult(0), new_shape);
+}
+
+template <>
+std::optional<ShapeContext> propagateHelper<mhlo::DynamicGatherOp>(
+    OpBuilder& b, Operation* op, ShapeContext& inputCtx) {
+  auto dynamic_gather_op = dyn_cast_or_null<mhlo::DynamicGatherOp>(op);
+  if (!dynamic_gather_op) return std::nullopt;
+
+  SmallVector<int64_t> new_shape(dynamic_gather_op.getResult()
+                                     .getType()
+                                     .cast<RankedTensorType>()
+                                     .getShape());
+
+  auto attr = dynamic_gather_op.getDimensionNumbers();
+  auto slice_sizes =
+      op->getOperand(2).getType().cast<RankedTensorType>().getShape();
+
+  auto offset_dims = attr.getOffsetDims();
+  auto index_vector_dim = attr.getIndexVectorDim();
+  auto collapsed_slice_dims = attr.getCollapsedSliceDims();
+
+  if (inputCtx.value == op->getOperand(1)) {
+    // start_indices
+    int shape_dim_idx = 0;
+    for (int dim_idx = 0; dim_idx < inputCtx.shape.size(); dim_idx++) {
+      if (dim_idx != index_vector_dim) {
+        new_shape[shape_dim_idx++] = inputCtx.shape[dim_idx];
+      }
+    }
+  } else if (inputCtx.value == op->getOperand(2)) {
+    int shape_dim_idx =
+        op->getOperand(0).getType().cast<RankedTensorType>().getRank() - 1;
+    for (int dim_idx = 0; dim_idx < inputCtx.shape.size(); dim_idx++) {
+      bool include_this_dim = true;
+      for (auto collapsed_slice_dim : collapsed_slice_dims) {
+        if (dim_idx == collapsed_slice_dim) {
+          include_this_dim = false;
+        }
+      }
+      if (include_this_dim) {
+        // need to decide whether it is a constant value or value from operand
+        new_shape[shape_dim_idx++] = inputCtx.shape[dim_idx];
+      }
+    }
+  }
+
+  return ShapeContext(op->getResult(0), new_shape);
+}
+
+template <>
+std::optional<ShapeContext> propagateHelper<mhlo::GatherOp>(
+    OpBuilder& b, Operation* op, ShapeContext& inputCtx) {
+  auto gather_op = dyn_cast_or_null<mhlo::GatherOp>(op);
+  if (!gather_op) return std::nullopt;
+
+  // batch_dims = [d for d in axes(result) and d not in offset_dims].
+  auto attr = gather_op.getDimensionNumbers();
+  auto offset_dims = attr.getOffsetDims();
+  auto index_vector_dim = attr.getIndexVectorDim();
+  auto slice_sizes = gather_op.getSliceSizes();
+  auto collapsed_slice_dims = attr.getCollapsedSliceDims();
+  auto src_shape =
+      op->getOperand(0).getType().cast<RankedTensorType>().getShape();
+  SmallVector<Value> slice_sizes_vec;
+  SmallVector<int64_t> new_shape;
+  auto start_indices_shape =
+      op->getOperand(1).getType().cast<RankedTensorType>().getShape();
+
+  b.setInsertionPoint(op);
+  // process offset_dim_sizes, offset dims
+  for (int dim_idx = 0; dim_idx < start_indices_shape.size(); dim_idx++) {
+    if (dim_idx != index_vector_dim) {
+      new_shape.push_back(start_indices_shape[dim_idx]);
+    }
+  }
+
+  int dim_idx = 0;
+  for (auto dim_size : slice_sizes) {
+    bool include_this_dim = true;
+    for (auto collapsed_slice_dim : collapsed_slice_dims) {
+      if (dim_idx == collapsed_slice_dim) {
+        include_this_dim = false;
+      }
+    }
+    // need to decide whether it is a constant value or value from operand
+    if (src_shape[dim_idx] == dim_size.getSExtValue()) {
+      auto dim_value =
+          b.create<tensor::DimOp>(op->getLoc(), op->getOperand(0),
+                                  b.create<arith::ConstantIndexOp>(
+                                       op->getLoc(), dim_idx)
+                                      .getResult())
+              .getResult();
+      slice_sizes_vec.push_back(
+          b.create<arith::IndexCastOp>(op->getLoc(), b.getI64Type(), dim_value)
+              .getResult());
+    } else {
+      slice_sizes_vec.push_back(b.create<arith::ConstantIntOp>(
+          op->getLoc(), dim_size.getSExtValue(), b.getI64Type()));
+    }
+
+    if (include_this_dim && src_shape[dim_idx] == dim_size.getSExtValue()) {
+      new_shape.push_back(ShapedType::kDynamic);
+    } else if (include_this_dim &&
+               src_shape[dim_idx] != dim_size.getSExtValue()) {
+      new_shape.push_back(dim_size.getSExtValue());
+    }
+
+    dim_idx += 1;
+  }
+
+  // create a dynamic gather op
+  auto dynamic_gather_op = b.create<mhlo::DynamicGatherOp>(
+      op->getLoc(),
+      RankedTensorType::get(new_shape, gather_op.getResult()
+                                           .getType()
+                                           .cast<RankedTensorType>()
+                                           .getElementType()),
+      op->getOperand(0), op->getOperand(1),
+      b.create<tensor::FromElementsOp>(op->getLoc(), slice_sizes_vec)
+          .getResult(),
+      mhlo::GatherDimensionNumbersAttr::get(
+          attr.getContext(), attr.getOffsetDims(),
+          attr.getCollapsedSliceDims(), attr.getStartIndexMap(),
+          attr.getIndexVectorDim()),
+      gather_op.getIndicesAreSorted());
+  gather_op.getResult().replaceAllUsesWith(dynamic_gather_op.getResult());
+
+  // Update DynamicGatherOp result shape information
+  return propagateHelper<mhlo::DynamicGatherOp>(
+      b, dynamic_gather_op.getOperation(), inputCtx);
+}
+
 
 LogicalResult parseInputDynamicDims(
     func::FuncOp main,
     std::vector<std::pair<int, std::vector<int64_t>>>& input_dynamic_dims) {
@@ -203,7 +426,13 @@ std::optional<ShapeContext> propagateOpShape(OpBuilder& rewriter, Operation* op,
   using PropagationFunc =
       std::optional<ShapeContext> (*)(OpBuilder&, Operation*, ShapeContext&);
   const std::vector<PropagationFunc> propagationFunctions = {
+      propagateHelper<mhlo::ScatterOp>,
       propagateHelper<mhlo::DotOp>,
+      propagateHelper<mhlo::ConcatenateOp>,
+      propagateHelper<mhlo::TransposeOp>,
+      propagateHelper<mhlo::ReduceOp>,
+      propagateHelper<mhlo::DynamicGatherOp>,
+      propagateHelper<mhlo::GatherOp>,
   };
   // Iterate over the propagation functions and apply each one
   for (const auto& propagate : propagationFunctions) {
@@ -226,6 +455,7 @@ void visitOperator(ModuleOp& m, OpBuilder& rewriter, Operation* op,
                       op->getName().stripDialect().str());
     return;
   }
+
   for (auto user : op->getResult(0).getUsers()) {
     visitOperator(m, rewriter, user, resultShapeCtx.value());
   }
diff --git a/tao_compiler/mlir/disc/transforms/tests/disc-shape-propagate.mlir b/tao_compiler/mlir/disc/transforms/tests/disc-shape-propagate.mlir
old mode 100644
new mode 100755
index ae112a3a903..57b52fffb02
--- a/tao_compiler/mlir/disc/transforms/tests/disc-shape-propagate.mlir
+++ b/tao_compiler/mlir/disc/transforms/tests/disc-shape-propagate.mlir
@@ -1,4 +1,4 @@
-// RUN: disc-opt -split-input-file --disc-shape-propagate %s | FileCheck %s
+// RUN: disc-opt -split-input-file -disc-shape-propagate -canonicalize %s | FileCheck %s
 
 // CHECK-LABEL: main
 func.func @main(%arg0: tensor<4x101xi64>, %arg1: tensor<4x101xi64>) -> tensor<4x101xi1> attributes{tf.entry_function = {input_dynamic_dims = "0:1|1:1"}}{
@@ -16,4 +16,154 @@ func.func @main(%arg0: tensor<4x101xi64>) -> tensor<4x101xi1> attributes{tf.entr
   %0 = mhlo.constant dense<0> : tensor<4x101xi64>
   %1 = "mhlo.compare"(%arg0, %0) {comparison_direction = #mhlo} : (tensor<4x101xi64>, tensor<4x101xi64>) -> tensor<4x101xi1>
   return %1 : tensor<4x101xi1>
+}
+
+// -----
+// CHECK-LABEL: main
+func.func @main(%arg0: tensor<4x101x32x128xbf16>) -> tensor<4x32x101x128xbf16> attributes{tf.entry_function = {input_dynamic_dims = "0:0,1"}}{
+  // CHECK: %0 = "mhlo.transpose"(%arg0) {permutation = dense<[0, 2, 1, 3]> : tensor<4xi64>, result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "bf16[4,32,101,128]{3,1,2,0}"} : (tensor<?x?x32x128xbf16>) -> tensor<?x32x?x128xbf16>
+  %1 = "mhlo.transpose"(%arg0) {permutation = dense<[0, 2, 1, 3]> : tensor<4xi64>, result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "bf16[4,32,101,128]{3,1,2,0}"} : (tensor<4x101x32x128xbf16>) -> tensor<4x32x101x128xbf16>
+  return %1 : tensor<4x32x101x128xbf16>
+}
+
+// -----
+// CHECK-LABEL: main
+func.func @main(%arg0: tensor<4x101x4096xf32>) -> tensor<4x101xf32> attributes{tf.entry_function = {input_dynamic_dims = "0:0,1"}}{
+  %1 = mhlo.constant dense<0.000000e+00> : tensor<f32>
+  // CHECK: %1 = mhlo.reduce(%arg0 init: %0) across dimensions = [2] : (tensor<?x?x4096xf32>, tensor<f32>) -> tensor<?x?xf32>
+  // CHECK: reducer(%arg1: tensor<f32>, %arg2: tensor<f32>) {
+  // CHECK:   %2 = mhlo.add %arg1, %arg2 : tensor<f32>
+  // CHECK:   mhlo.return %2 : tensor<f32>
+  // CHECK: }
+  %2 = mhlo.reduce(%arg0 init: %1) across dimensions = [2] : (tensor<4x101x4096xf32>, tensor<f32>) -> tensor<4x101xf32>
+   reducer(%arg216: tensor<f32>, %arg217: tensor<f32>) {
+    %2869 = mhlo.add %arg216, %arg217 : tensor<f32>
+    mhlo.return %2869 : tensor<f32>
+  }
+  return %2 : tensor<4x101xf32>
+}
+
+// -----
+// CHECK-LABEL: main
+func.func @main(%arg0: tensor<4x101x4096xf32>) -> tensor<4x101xf32> attributes{tf.entry_function = {input_dynamic_dims = "0:1"}}{
+  %1 = mhlo.constant dense<0.000000e+00> : tensor<f32>
+  // CHECK: %1 = mhlo.reduce(%arg0 init: %0) across dimensions = [2] : (tensor<4x?x4096xf32>, tensor<f32>) -> tensor<4x?xf32>
+  // CHECK: reducer(%arg1: tensor<f32>, %arg2: tensor<f32>) {
+  // CHECK:   %2 = mhlo.add %arg1, %arg2 : tensor<f32>
+  // CHECK:   mhlo.return %2 : tensor<f32>
+  // CHECK: }
+  %2 = mhlo.reduce(%arg0 init: %1) across dimensions = [2] : (tensor<4x101x4096xf32>, tensor<f32>) -> tensor<4x101xf32>
+   reducer(%arg216: tensor<f32>, %arg217: tensor<f32>) {
+    %2869 = mhlo.add %arg216, %arg217 : tensor<f32>
+    mhlo.return %2869 : tensor<f32>
+  }
+  return %2 : tensor<4x101xf32>
+}
+
+// -----
+// CHECK-LABEL: main
+func.func @main(%arg0: tensor<4x101x4096xf32>) -> tensor<4x101xf32> attributes{tf.entry_function = {input_dynamic_dims = "0:2"}}{
+  %1 = mhlo.constant dense<0.000000e+00> : tensor<f32>
+  // CHECK: %1 = mhlo.reduce(%arg0 init: %0) across dimensions = [2] : (tensor<4x101x?xf32>, tensor<f32>) -> tensor<4x101xf32>
+  // CHECK: reducer(%arg1: tensor<f32>, %arg2: tensor<f32>) {
+  // CHECK:   %2 = mhlo.add %arg1, %arg2 : tensor<f32>
+  // CHECK:   mhlo.return %2 : tensor<f32>
+  // CHECK: }
+  %2 = mhlo.reduce(%arg0 init: %1) across dimensions = [2] : (tensor<4x101x4096xf32>, tensor<f32>) -> tensor<4x101xf32>
+   reducer(%arg216: tensor<f32>, %arg217: tensor<f32>) {
+    %2869 = mhlo.add %arg216, %arg217 : tensor<f32>
+    mhlo.return %2869 : tensor<f32>
+  }
+  return %2 : tensor<4x101xf32>
+}
+
+// -----
+// CHECK-LABEL: main
+func.func @main(%arg0: tensor<4x32x101x64xbf16>, %arg1: tensor<4x32x101x64xbf16>) -> tensor<4x32x101x128xbf16> attributes{tf.entry_function = {input_dynamic_dims = "0:3"}}{
+  // CHECK: %0 = "mhlo.concatenate"(%arg0, %arg1) {dimension = 3 : i64} : (tensor<4x32x101x?xbf16>, tensor<4x32x101x64xbf16>) -> tensor<4x32x101x?xbf16>
+  %1 = "mhlo.concatenate"(%arg0, %arg1) {dimension = 3 : i64} : (tensor<4x32x101x64xbf16>, tensor<4x32x101x64xbf16>) -> tensor<4x32x101x128xbf16>
+  return %1 : tensor<4x32x101x128xbf16>
+}
+
+// -----
+// CHECK-LABEL: main
+func.func @main(%arg0: tensor<4x32x101x64xbf16>, %arg1: tensor<4x32x101x64xbf16>) -> tensor<4x32x101x128xbf16> attributes{tf.entry_function = {input_dynamic_dims = "1:1"}}{
+  // CHECK: %0 = "mhlo.concatenate"(%arg0, %arg1) {dimension = 3 : i64} : (tensor<4x32x101x64xbf16>, tensor<4x?x101x64xbf16>) -> tensor<4x32x101x128xbf16>
+  %1 = "mhlo.concatenate"(%arg0, %arg1) {dimension = 3 : i64} : (tensor<4x32x101x64xbf16>, tensor<4x32x101x64xbf16>) -> tensor<4x32x101x128xbf16>
+  return %1 : tensor<4x32x101x128xbf16>
+}
+
+// -----
+// CHECK-LABEL: main
+func.func @main(%arg0: tensor<4x32x101x64xbf16>, %arg1: tensor<4x32x101x64xbf16>) -> tensor<4x32x101x128xbf16> attributes{tf.entry_function = {input_dynamic_dims = "1:1|0:1"}}{
+  // CHECK: %0 = "mhlo.concatenate"(%arg0, %arg1) {dimension = 3 : i64} : (tensor<4x?x101x64xbf16>, tensor<4x?x101x64xbf16>) -> tensor<4x?x101x128xbf16>
+  %1 = "mhlo.concatenate"(%arg0, %arg1) {dimension = 3 : i64} : (tensor<4x32x101x64xbf16>, tensor<4x32x101x64xbf16>) -> tensor<4x32x101x128xbf16>
+  return %1 : tensor<4x32x101x128xbf16>
+}
+
+// -----
+// CHECK-LABEL: main
+func.func @main(%arg0: tensor<32001x4096xf32>, %arg1: tensor<404x1xi64>, %arg2: tensor<404x4096xf32>) -> tensor<32001x4096xf32> attributes{tf.entry_function = {input_dynamic_dims = "1:0|2:0"}}{
+
+  // CHECK: %0 = "mhlo.scatter"(%arg0, %arg1, %arg2) ({
+  // CHECK: ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
+  // CHECK:   %1 = mhlo.add %arg3, %arg4 : tensor<f32>
+  // CHECK:   mhlo.return %1 : tensor<f32>
+  // CHECK: }) {indices_are_sorted = false, scatter_dimension_numbers = #mhlo.scatter<update_window_dims = [1], inserted_window_dims = [0], scatter_dims_to_operand_dims = [0], index_vector_dim = 1>, unique_indices = false} : (tensor<32001x4096xf32>, tensor<?x1xi64>, tensor<?x4096xf32>) -> tensor
+  %1 = "mhlo.scatter"(%arg0, %arg1, %arg2) ({
+  ^bb0(%arg216: tensor<f32>, %arg217: tensor<f32>):
+    %2869 = mhlo.add %arg216, %arg217 : tensor<f32>
+    mhlo.return %2869 : tensor<f32>
+  }) {indices_are_sorted = false, scatter_dimension_numbers = #mhlo.scatter<update_window_dims = [1], inserted_window_dims = [0], scatter_dims_to_operand_dims = [0], index_vector_dim = 1>, unique_indices = false} : (tensor<32001x4096xf32>, tensor<404x1xi64>, tensor<404x4096xf32>) -> tensor<32001x4096xf32>
+  return %1 : tensor<32001x4096xf32>
+}
+
+// -----
+// CHECK-LABEL: main
+func.func @main(%arg0: tensor<32001x4096xf32>, %arg1: tensor<4x101x1xi64>) -> tensor<4x101x4096xf32> attributes{tf.entry_function = {input_dynamic_dims = "1:0|0:1"}}{
+  // CHECK: %c1_i64 = arith.constant 1 : i64
+  // CHECK: %c1 = arith.constant 1 : index
+  // CHECK: %dim = tensor.dim %arg0, %c1 : tensor<32001x?xf32>
+  // CHECK: %0 = arith.index_cast %dim : index to i64
+  // CHECK: %from_elements = tensor.from_elements %c1_i64, %0 : tensor<2xi64>
+  // CHECK: %1 = "mhlo.dynamic_gather"(%arg0, %arg1, %from_elements) {dimension_numbers = #mhlo.gather<offset_dims = [2], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 2>, indices_are_sorted = false} : (tensor<32001x?xf32>, tensor<?x101x1xi64>, tensor<2xi64>) -> tensor<?x101x?xf32>
+  %1 = "mhlo.gather"(%arg0, %arg1) {dimension_numbers = #mhlo.gather<offset_dims = [2], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 2>, indices_are_sorted = false, slice_sizes = dense<[1, 4096]> : tensor<2xi64>} : (tensor<32001x4096xf32>, tensor<4x101x1xi64>) -> tensor<4x101x4096xf32>
+  return %1 : tensor<4x101x4096xf32>
+}
+
+// -----
+// CHECK-LABEL: main
+func.func @main(%arg0: tensor<32001x4096xf32>, %arg1: tensor<4x101x1xi64>) -> tensor<4x101x4096xf32> attributes{tf.entry_function = {input_dynamic_dims = "1:0,1"}}{
+  // CHECK: %cst = arith.constant dense<[1, 4096]> : tensor<2xi64>
+  // CHECK: %0 = "mhlo.dynamic_gather"(%arg0, %arg1, %cst) {dimension_numbers = #mhlo.gather<offset_dims = [2], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 2>, indices_are_sorted = false} : (tensor<32001x4096xf32>, tensor<?x?x1xi64>, tensor<2xi64>) -> tensor<?x?x?xf32>
+  %1 = "mhlo.gather"(%arg0, %arg1) {dimension_numbers = #mhlo.gather<offset_dims = [2], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 2>, indices_are_sorted = false, slice_sizes = dense<[1, 4096]> : tensor<2xi64>} : (tensor<32001x4096xf32>, tensor<4x101x1xi64>) -> tensor<4x101x4096xf32>
+  return %1 : tensor<4x101x4096xf32>
+}
+
+// -----
+// CHECK-LABEL: main
+func.func @main(%arg0: tensor<32001x4096xf32>, %arg1: tensor<4x101x1xi64>) -> tensor<4x101x4096xf32> attributes{tf.entry_function = {input_dynamic_dims = "1:0"}}{
+  // CHECK: %cst = arith.constant dense<[1, 4096]> : tensor<2xi64>
+  // CHECK: %0 = "mhlo.dynamic_gather"(%arg0, %arg1, %cst) {dimension_numbers = #mhlo.gather<offset_dims = [2], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 2>, indices_are_sorted = false} : (tensor<32001x4096xf32>, tensor<?x101x1xi64>, tensor<2xi64>) -> tensor<?x101x?xf32>
+  %1 = "mhlo.gather"(%arg0, %arg1) {dimension_numbers = #mhlo.gather<offset_dims = [2], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 2>, indices_are_sorted = false, slice_sizes = dense<[1, 4096]> : tensor<2xi64>} : (tensor<32001x4096xf32>, tensor<4x101x1xi64>) -> tensor<4x101x4096xf32>
+  return %1 : tensor<4x101x4096xf32>
+}
+
+// -----
+// CHECK-LABEL: main
+func.func @main(%arg0: tensor<32001x4096xf32>, %arg1: tensor<4x101x1xi64>) -> tensor<4x101x2048xf32> attributes{tf.entry_function = {input_dynamic_dims = "1:0"}}{
+  // CHECK: %cst = arith.constant dense<[1, 2048]> : tensor<2xi64>
+  // CHECK: %0 = "mhlo.dynamic_gather"(%arg0, %arg1, %cst) {dimension_numbers = #mhlo.gather<offset_dims = [2], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 2>, indices_are_sorted = false} : (tensor<32001x4096xf32>, tensor<?x101x1xi64>, tensor<2xi64>) -> tensor<?x101x2048xf32>
+  %1 = "mhlo.gather"(%arg0, %arg1) {dimension_numbers = #mhlo.gather<offset_dims = [2], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 2>, indices_are_sorted = false, slice_sizes = dense<[1, 2048]> : tensor<2xi64>} : (tensor<32001x4096xf32>, tensor<4x101x1xi64>) -> tensor<4x101x2048xf32>
+  return %1 : tensor<4x101x2048xf32>
+}
+
+// -----
+// CHECK-LABEL: main
+func.func @main(%arg0: tensor<32001x4096xf32>, %arg1: tensor<4x101x1xi64>) -> tensor<4x101x2048xf32> attributes{tf.entry_function = {input_dynamic_dims = "1:0|0:1"}}{
+  // CHECK: %cst = arith.constant dense<[1, 2048]> : tensor<2xi64>
+  // CHECK: %0 = "mhlo.dynamic_gather"(%arg0, %arg1, %cst) {dimension_numbers = #mhlo.gather<offset_dims = [2], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 2>, indices_are_sorted = false} : (tensor<32001x?xf32>, tensor<?x101x1xi64>, tensor<2xi64>) -> tensor<?x101x2048xf32>
+  %1 = "mhlo.gather"(%arg0, %arg1) {dimension_numbers = #mhlo.gather<offset_dims = [2], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 2>, indices_are_sorted = false, slice_sizes = dense<[1, 2048]> : tensor<2xi64>} : (tensor<32001x4096xf32>, tensor<4x101x1xi64>) -> tensor<4x101x2048xf32>
+  return %1 : tensor<4x101x2048xf32>
+}
\ No newline at end of file