Skip to content

Commit

Permalink
Support more tropical types (#24)
Browse files Browse the repository at this point in the history
* more tropical element types

* upgrade TropicalNumbers
  • Loading branch information
GiggleLiu authored Sep 24, 2023
1 parent 5ed2f0a commit 88b9e2b
Show file tree
Hide file tree
Showing 5 changed files with 188 additions and 132 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "TropicalGEMM"
uuid = "a4ad3063-64a7-4bad-8738-34ed09bc0236"
authors = ["GiggleLiu <[email protected]> and contributors"]
version = "0.1.10"
version = "0.2.0"

[deps]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand All @@ -15,7 +15,7 @@ VectorizationBase = "3d5dd08c-fd9d-11e8-17fa-ed2836048c2f"
LoopVectorization = "0.12.4"
Octavian = "0.2.18, 0.3"
PrecompileTools = "1"
TropicalNumbers = "0.2.3, 0.3, 0.4, 0.5"
TropicalNumbers = "0.6"
VectorizationBase = "0.21"
julia = "1"

Expand Down
26 changes: 13 additions & 13 deletions src/TropicalGEMM.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,22 @@ using VectorizationBase: OffsetPrecalc, StaticBool, Bit, static, NativeTypes, In
using VectorizationBase: contiguous_batch_size, contiguous_axis, val_stride_rank, bytestrides, offsets, memory_reference,
vmaximum, fmap, FloatingTypes, IntegerIndex, LazyMulAdd

export Tropical, TropicalF64, TropicalF32
export Tropical, TropicalF64, TropicalF32, TropicalMinPlus, TropicalMinPlusF64, TropicalMinPlusF32, TropicalMaxMul, TropicalMaxMulF64, TropicalMaxMulF32, TropicalMaxPlus, TropicalMaxPlusF64, TropicalMaxPlusF32, BlasSemiringTypes

include("fallbacks.jl")
include("gemm.jl")

import PrecompileTools
PrecompileTools.@setup_workload begin
# Putting some things in `@setup_workload` instead of `@compile_workload` can reduce the size of the
# precompile file and potentially make loading faster.
PrecompileTools.@compile_workload begin
for T in (Float32, Float64, Int64)
A = Tropical.(rand(T, 10, 10))
O = Tropical.(rand(T, 10, 10))
LinearAlgebra.mul!(O, A, A)
end
end
end
# import PrecompileTools
# PrecompileTools.@setup_workload begin
# # Putting some things in `@setup_workload` instead of `@compile_workload` can reduce the size of the
# # precompile file and potentially make loading faster.
# PrecompileTools.@compile_workload begin
# for T in (Float32, Float64, Int64)
# A = Tropical.(rand(T, 10, 10))
# O = Tropical.(rand(T, 10, 10))
# LinearAlgebra.mul!(O, A, A)
# end
# end
# end

end
10 changes: 6 additions & 4 deletions src/fallbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,12 @@ end

# For types not nativelly supported, go to fallback.
# Overwrite the `mul!` in LinearAlgebra (also changes the behavior of `*` in Base)!
function LinearAlgebra.mul!(o::MaybeAdjOrTransMat{TO}, a::MaybeAdjOrTransMat{<:Tropical}, b::MaybeAdjOrTransMat{<:Tropical}, α::Number, β::Number) where TO<:Tropical
α = _convert_to_static(TO, α)
β = _convert_to_static(TO, β)
naive_mul!(o, a, b, α, β)
for TT in [:Tropical, :TropicalMinPlus, TropicalMaxMul]
@eval function LinearAlgebra.mul!(o::MaybeAdjOrTransMat{TO}, a::MaybeAdjOrTransMat{<:$TT}, b::MaybeAdjOrTransMat{<:$TT}, α::Number, β::Number) where TO<:$TT
α = _convert_to_static(TO, α)
β = _convert_to_static(TO, β)
naive_mul!(o, a, b, α, β)
end
end

Base.:*(a::T, b::StaticInt{0}) where T<:TropicalTypes = zero(T)
Expand Down
Loading

0 comments on commit 88b9e2b

Please sign in to comment.