Skip to content

Commit

Permalink
rename then remove from parsed module
Browse files Browse the repository at this point in the history
  • Loading branch information
Pangoraw committed Dec 10, 2024
1 parent 68138f6 commit 8eb71cc
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 12 deletions.
2 changes: 2 additions & 0 deletions src/Ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1123,7 +1123,9 @@ function hlo_call(
# Change function name
MLIR.IR.attr!(op, symbol_attr_name, MLIR.IR.Attribute(new_name))
end
end

for op in operations
MLIR.IR.rmfromparent!(op)
push!(top_level_block, op)
end
Expand Down
28 changes: 16 additions & 12 deletions test/ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ end
]
x = ConcreteRArray([1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9, 10.0])
@test [3.0, 3.0, 3.3, 4.4, 5.5, 6.6, 7.0, 7.0, 7.0, 7.0] ==
@jit Ops.clamp(_min, x, _max)
@jit Ops.clamp(_min, x, _max)
end
end

Expand Down Expand Up @@ -166,7 +166,7 @@ end
0.0 + 0.0im, π / 2 + 0.0im, π + 0.0im, 3π / 2 + 0.0im, 2π + 0.0im
])
@test [1.0 + 0.0im, 0.0 + 0.0im, -1.0 + 0.0im, 0.0 + 0.0im, 1.0 + 0.0im]
@jit Ops.cosine(x)
@jit Ops.cosine(x)
end
end

Expand Down Expand Up @@ -216,7 +216,7 @@ end
# NOTE `LinearAlgebra.dot` is not equal to `sum(a .* b)` on complex numbers due to conjugation
@test sum(a .* b) @jit f1(a, b)
@test kron(reshape(Array(a), length(a), 1), reshape(Array(b), 1, length(b)))
@jit fouter(a, b)
@jit fouter(a, b)
@test a .* b @jit fouter_batch1(a, b)
end

Expand Down Expand Up @@ -415,7 +415,7 @@ end
# on unsigned integers: (1) bitcast, (2) change sign and (3) bitcast
x = ConcreteRArray(UInt[0, 1, 10])
@test reinterpret(UInt, Base.checked_neg.(reinterpret.(Int, Array(x)))) ==
@jit Ops.negate(x)
@jit Ops.negate(x)

x = ConcreteRArray([-1.0, 0.0, 1.0, 10.0])
@test [1.0, 0.0, -1.0, -10.0] @jit Ops.negate(x)
Expand Down Expand Up @@ -639,7 +639,7 @@ end
0.0 + 0.0im, π / 2 + 0.0im, π + 0.0im, 3π / 2 + 0.0im, 2π + 0.0im
])
@test [0.0 + 0.0im, 1.0 + 0.0im, 0.0 + 0.0im, -1.0 + 0.0im, 0.0 + 0.0im]
@jit Ops.sine(x)
@jit Ops.sine(x)
end
end

Expand Down Expand Up @@ -847,7 +847,7 @@ end
x = ConcreteRArray([-1.0, 0.0, 1.0, 1.0, 2.5])
m = ConcreteRArray([3.0, 3.0, 2.0, 3.0, 4.0])
@test SpecialFunctions.polygamma.(Int.(Array(m)), Array(x))
@jit Ops.polygamma(m, x)
@jit Ops.polygamma(m, x)
end
end

Expand Down Expand Up @@ -926,14 +926,14 @@ end
Ops.hlo_call(
"""
module {
func.func @add(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xf32> {
%0 = stablehlo.add %arg0, %arg1 : tensor<3xf32>
return %0 : tensor<3xf32>
}
func.func @main(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xf32> {
%0 = func.call @add(%arg0, %arg1) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32>
return %0 : tensor<3xf32>
}
func.func @add(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xf32> {
%0 = stablehlo.add %arg0, %arg1 : tensor<3xf32>
return %0 : tensor<3xf32>
}
}
""",
Reactant.to_rarray(Float32[1, 2, 3]),
Expand All @@ -951,7 +951,9 @@ function f_multiple_hlo_calls(x, y)
return %0 : tensor<3xf32>
}
}
""", x, y,
""",
x,
y,
)
return Ops.hlo_call(
"""
Expand All @@ -961,7 +963,9 @@ function f_multiple_hlo_calls(x, y)
return %0 : tensor<3xf32>
}
}
""", x, y
""",
x,
y,
)
end

Expand Down

0 comments on commit 8eb71cc

Please sign in to comment.