@@ -2313,4 +2313,101 @@ Produces a [`Reactant.MLIR.Dialects.sdy.sharding_constraint`](@ref) operation wi
2313
2313
end
2314
2314
end
2315
2315
2316
+ """
2317
+ reduce(
2318
+ x::TracedRArray{T},
2319
+ init_values::TracedRNumber{T},
2320
+ dimensions::Vector{Int},
2321
+ fn::Function,
2322
+ location=mlir_stacktrace("rand", @__FILE__, @__LINE__),
2323
+ )
2324
+
2325
+ Applies a reduction function `fn` along the specified `dimensions` of input `x`, starting from `init_values`.
2326
+
2327
+ # Arguments
2328
+
2329
+ - `x`: The input array.
2330
+ - `init_values`: The initial value.
2331
+ - `dimensions`: The dimensions to reduce along.
2332
+ - `fn`: A binary operator.
2333
+
2334
+ !!! warning
2335
+ This reduction operation follows StableHLO semantics. The key difference between this operation and Julia's built-in `reduce` is explained below:
2336
+
2337
+ - The function `fn` and the initial value `init_values` must form a **monoid**, meaning:
2338
+ - `fn` must be an **associative** binary operation.
2339
+ - `init_values` must be the **identity element** associated with `fn`.
2340
+ - This constraint ensures consistent results across all implementations.
2341
+
2342
+ If `init_values` is not the identity element of `fn`, the results may vary between CPU and GPU executions. For example:
2343
+
2344
+ ```julia
2345
+ A = [1 3; 2 4;;; 5 7; 6 8;;; 9 11; 10 12]
2346
+ init_values = 2
2347
+ dimensions = [1, 3]
2348
+ ```
2349
+
2350
+ - **CPU version & Julia's `reduce`**:
2351
+ - Reduce along dimension 1 → `[(15) (21); (18) (24)]`
2352
+ - Reduce along dimension 3 → `[(33 + 2) (45 + 2)]` → `[35 47]`
2353
+
2354
+ - **GPU version**:
2355
+ - Reduce along dimension 1 → `[(15 + 2) (21 + 2); (18 + 2) (24 + 2)]`
2356
+ - Reduce along dimension 3 → `[37 49]`
2357
+ """
2358
+ @noinline function reduce (
2359
+ x:: TracedRArray{T} ,
2360
+ init_values:: TracedRNumber{T} ,
2361
+ dimensions:: Vector{Int} ,
2362
+ fn:: Function ,
2363
+ location= mlir_stacktrace (" reduce" , @__FILE__ , @__LINE__ ),
2364
+ ) where {T}
2365
+ reduced_shape = Tuple (deleteat! (collect (size (x)), dimensions))
2366
+
2367
+ result_type = mlir_type (TracedRArray{T,length (reduced_shape)}, reduced_shape)
2368
+
2369
+ sample_inputs = [
2370
+ Reactant. TracedUtils. promote_to (TracedRNumber{T}, 0 ),
2371
+ Reactant. TracedUtils. promote_to (TracedRNumber{T}, 0 )
2372
+ ]
2373
+
2374
+ func =
2375
+ Reactant. TracedUtils. make_mlir_fn (
2376
+ fn,
2377
+ (sample_inputs),
2378
+ (),
2379
+ " reduce_fn" ,
2380
+ false ;
2381
+ args_in_result= :none ,
2382
+ return_dialect= :stablehlo ,
2383
+ ). f
2384
+ @assert MLIR. IR. nregions (func) == 1
2385
+ fn_name = String (
2386
+ MLIR. IR. attr (func, String (MLIR. API. mlirSymbolTableGetSymbolAttributeName ()))
2387
+ )
2388
+ ftype_attr = MLIR. IR. attr (func, " function_type" )
2389
+ ftype = MLIR. IR. Type (ftype_attr)
2390
+ @assert MLIR. IR. result (ftype) == MLIR. IR. TensorType ((), MLIR. IR. Type (T)) error (
2391
+ " $fn return type is not tensor<i1>"
2392
+ )
2393
+ fn = MLIR. IR. Region ()
2394
+ MLIR. API. mlirRegionTakeBody (fn, MLIR. IR. region (func, 1 ))
2395
+ MLIR. IR. rmfromparent! (func)
2396
+
2397
+ dimensions = MLIR. IR. Attribute (dimensions .- 1 )
2398
+
2399
+ res = MLIR. IR. result (
2400
+ stablehlo. reduce (
2401
+ [x. mlir_data],
2402
+ [init_values. mlir_data];
2403
+ result_0= [result_type],
2404
+ dimensions= dimensions,
2405
+ body= fn,
2406
+ location= location,
2407
+ ),
2408
+ )
2409
+
2410
+ return TracedRArray {T,length(reduced_shape)} ((), res, reduced_shape)
2411
+ end
2412
+
2316
2413
end # module Ops
0 commit comments