Skip to content

Commit

Permalink
C interface for hsl_subset (#221)
Browse files Browse the repository at this point in the history
* Remove hsl_subset/lapack.jl

* Update the C interface of libhsl

* Update the generator of wrappers for hsl_subset

* Update Project.toml

* Add the C interface for hsl_subset
  • Loading branch information
amontoison authored Dec 26, 2024
1 parent 486f1a9 commit 8df68fa
Show file tree
Hide file tree
Showing 30 changed files with 4,551 additions and 85 deletions.
6 changes: 5 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,13 @@ Quadmath = "be4d8f0f-7fa4-5f49-b795-2f01399ab2dd"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

[compat]
HSL_jll = "4, 2024"
Libdl = "1.9"
LinearAlgebra = "1.9"
OpenBLAS32_jll = "0.3.9"
Quadmath = "0.5.10"
julia = "^1.6.0"
SparseArrays = "1.9"
julia = "1.9"

[extras]
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Expand Down
2 changes: 1 addition & 1 deletion gen/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@ JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"

[compat]
julia = "1.6"
HSL_jll = "=2024.11.28"
HSL_jll = "=2024.12.10"
158 changes: 129 additions & 29 deletions gen/rewriter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,35 +36,45 @@ structure_modifications = Dict("_control_s}" => "_control{Float32}}",
"_sinfo_s}" => "_sinfo{Float32}}",
"_sinfo_d}" => "_sinfo{Float64}}")

function rewrite!(path::String, name::String, optimized::Bool)
function rewrite!(library::String, path::String, name::String, optimized::Bool)
if library == "libhsl"
libhsl_rewrite!(path, name, optimized)
elseif library == "hsl_subset"
hsl_subset_rewrite!(path, name, optimized)
else
error("The library $library is not supported.")
end
end

function libhsl_rewrite!(path::String, name::String, optimized::Bool)
text = read(path, String)
if name == "libhsl"
updated_text = replace(text, "# no prototype is found for this function at libhsl.h:44:6, please use with caution\n" => "")
updated_text = replace(updated_text, "major, minor, patch)\n" => ")\n major = Ref{Cint}(0)\n minor = Ref{Cint}(0)\n patch = Ref{Cint}(0)\n")
updated_text = replace(updated_text, "Ptr{Cint}" => "Ref{Cint}")
updated_text = replace(updated_text, " @ccall" => " @ccall")
updated_text = replace(updated_text, "Cvoid\n" => "Cvoid\n return VersionNumber(major[], minor[], patch[])\n")
text = replace(text, "# no prototype is found for this function at libhsl.h:44:6, please use with caution\n" => "")
text = replace(text, "major, minor, patch)\n" => ")\n major = Ref{Cint}(0)\n minor = Ref{Cint}(0)\n patch = Ref{Cint}(0)\n")
text = replace(text, "Ptr{Cint}" => "Ref{Cint}")
text = replace(text, " @ccall" => " @ccall")
text = replace(text, "Cvoid\n" => "Cvoid\n return VersionNumber(major[], minor[], patch[])\n")
else
solver = split(name, "_")[2]
updated_text = replace(text, "struct $solver" => "mutable struct $solver")
text = replace(text, "struct $solver" => "mutable struct $solver")
if optimized
for (keys, vals) in type_modifications
updated_text = replace(updated_text, solver * keys => vals)
text = replace(text, solver * keys => vals)
end
for (keys, vals) in structure_modifications
updated_text = replace(updated_text, solver * keys => solver * vals)
text = replace(text, solver * keys => solver * vals)
end
for structure in ("control", "info", "solve_control", "ainfo", "sinfo", "finfo")
updated_text = replace(updated_text, "mutable struct $(solver)_$(structure)_s" => "mutable struct $(solver)_$(structure){T}")
updated_text = replace(updated_text, "mutable struct $(solver)_$(structure)_i" => "mutable struct $(solver)_$(structure){T}")
updated_text = replace(updated_text, "Ptr{$(solver)_$(structure)" => "Ref{$(solver)_$(structure)")
text = replace(text, "mutable struct $(solver)_$(structure)_s" => "mutable struct $(solver)_$(structure){T}")
text = replace(text, "mutable struct $(solver)_$(structure)_i" => "mutable struct $(solver)_$(structure){T}")
text = replace(text, "Ptr{$(solver)_$(structure)" => "Ref{$(solver)_$(structure)")
end
updated_text = replace(updated_text, "::Float32\n" => "::T\n")
updated_text = replace(updated_text, "Float32}\n" => "T}\n") # NTuple{N, Float32} → NTuple{N, T}
text = replace(text, "::Float32\n" => "::T\n")
text = replace(text, "Float32}\n" => "T}\n") # NTuple{N, Float32} → NTuple{N, T}

# Add two constructors for each structure
blocks = split(updated_text, "end\n", keepempty=false)
updated_text = ""
blocks = split(text, "end\n", keepempty=false)
text = ""
for code in blocks
if contains(code, "mutable struct")
structure = code * "end\n"
Expand All @@ -84,39 +94,129 @@ function rewrite!(path::String, name::String, optimized::Bool)
end
end
structure = replace(structure, "end\n" => "\n " * structure_name * "($arguments) where T = new($arguments)\nend\n")
updated_text = updated_text * structure
text = text * structure
else
updated_text = updated_text * code * "end\n"
text = text * code * "end\n"
end
end

# Special cases where the structures are not parameterized.
if name == "hsl_ma48"
for type in ("T", "Float32", "Float64")
updated_text = replace(updated_text, "$(solver)_sinfo{$type}" => "$(solver)_sinfo")
updated_text = replace(updated_text, Regex("$(solver)_sinfo(\\([^)]*\\)) where T") => SubstitutionString("$(solver)_sinfo\\1"))
text = replace(text, "$(solver)_sinfo{$type}" => "$(solver)_sinfo")
text = replace(text, Regex("$(solver)_sinfo(\\([^)]*\\)) where T") => SubstitutionString("$(solver)_sinfo\\1"))
end
end

if name == "hsl_mc64"
for type in ("T", "Float32", "Float64")
updated_text = replace(updated_text, "$(solver)_control{$type}" => "$(solver)_control")
updated_text = replace(updated_text, Regex("$(solver)_control(\\([^)]*\\)) where T") => SubstitutionString("$(solver)_control\\1"))
updated_text = replace(updated_text, "$(solver)_info{$type}" => "$(solver)_info")
updated_text = replace(updated_text, Regex("$(solver)_info(\\([^)]*\\)) where T") => SubstitutionString("$(solver)_info\\1"))
text = replace(text, "$(solver)_control{$type}" => "$(solver)_control")
text = replace(text, Regex("$(solver)_control(\\([^)]*\\)) where T") => SubstitutionString("$(solver)_control\\1"))
text = replace(text, "$(solver)_info{$type}" => "$(solver)_info")
text = replace(text, Regex("$(solver)_info(\\([^)]*\\)) where T") => SubstitutionString("$(solver)_info\\1"))
end
end

if name == "hsl_mc68" || name == "hsl_mc78" || name == "hsl_mc79"
for type in ("T", "Cint", "Clong")
updated_text = replace(updated_text, "$(solver)_control{$type}" => "$(solver)_control")
updated_text = replace(updated_text, Regex("$(solver)_control(\\([^)]*\\)) where T") => SubstitutionString("$(solver)_control\\1"))
updated_text = replace(updated_text, "$(solver)_info{$type}" => "$(solver)_info")
updated_text = replace(updated_text, Regex("$(solver)_info(\\([^)]*\\)) where T") => SubstitutionString("$(solver)_info\\1"))
text = replace(text, "$(solver)_control{$type}" => "$(solver)_control")
text = replace(text, Regex("$(solver)_control(\\([^)]*\\)) where T") => SubstitutionString("$(solver)_control\\1"))
text = replace(text, "$(solver)_info{$type}" => "$(solver)_info")
text = replace(text, Regex("$(solver)_info(\\([^)]*\\)) where T") => SubstitutionString("$(solver)_info\\1"))
end
end
end
end
write(path, updated_text)
write(path, text)
(name != "libhsl") && format_file(path, YASStyle())
end

function hsl_subset_rewrite!(path::String, name::String, optimized::Bool)
text = read(path, String)
structures = ""
info_structures = Tuple{String, String, Bool}[]
if optimized
text = replace(text, "struct " => "mutable struct ")
text = replace(text, "hsl_longc_" => Int64)

blocks = split(text, "end\n")
text = ""
for (index, code) in enumerate(blocks)
if contains(code, "function")
for (ipc_, rpc_, suffix, lib) in (("Int32", "Float32" , "_s" , "libhsl_subset" ),
("Int32", "Float64" , "_d" , "libhsl_subset" ),
("Int32", "Float128", "_q" , "libhsl_subset" ),
("Int64", "Float32" , "_s_64", "libhsl_subset_64"),
("Int64", "Float64" , "_d_64", "libhsl_subset_64"),
("Int64", "Float128", "_q_64", "libhsl_subset_64"))
# We only want to generate two methods (Int32 / Int64) for hsl_mc68
(name == "hsl_mc68") && (rpc_ != "Float64") && continue

fname = split(split(code, "function ")[2], "(")[1]
fname_generic = fname[1:end-2]
pp_fname = fname[1:end-2] * suffix
routine = code * "end\n"
if name == "hsl_mc68"
endswith(fname, "_i") || error("The symbol $fname should have the suffix _i")
routine = replace(routine, "function $fname(" => "function $(fname_generic)(::Type{$ipc_}, ")
else
endswith(fname, "_d") || error("The symbol $fname should have the suffix _d")
routine = replace(routine, "function $fname(" => "function $(fname_generic)(::Type{$rpc_}, ::Type{$ipc_}, ")
end
routine = replace(routine, "libhsl.$fname(" => "$lib.$(pp_fname)(")
routine = replace(routine, "ipc_" => ipc_)
routine = replace(routine, "rpc_" => rpc_)

# Update the type of the structures
routine = replace(routine, "_d}" => "_d{$rpc_,$ipc_}}")
routine = replace(routine, "_i}" => "_i{$rpc_,$ipc_}}")

# Float128 should be passed by value as a Cfloat128
routine = replace(routine, "::Float128" => "::Cfloat128")

text = text * routine * "\n"
end
elseif contains(code, "struct")
structure = code * "end\n"
structure_name = split(split(code, "struct ")[2], "\n")[1] |> String
generic_structure_name = structure_name[1:end-2] |> String
generic_structure_name = 'M' * generic_structure_name[2:end]
generic_structure_name = replace(generic_structure_name, "_solve_control" => "SolveControl")
generic_structure_name = replace(generic_structure_name, "_control" => "Control")
generic_structure_name = replace(generic_structure_name, "_ainfo" => "Ainfo")
generic_structure_name = replace(generic_structure_name, "_finfo" => "Finfo")
generic_structure_name = replace(generic_structure_name, "_sinfo" => "Sinfo")
generic_structure_name = replace(generic_structure_name, "_info" => "Info")
structure = replace(structure, "rpc_" => "T")
structure = replace(structure, "ipc_" => "INT")
if !contains(code, "rpc_")
structure = replace(structure, structure_name => generic_structure_name * "{INT}")
push!(info_structures, (structure_name, generic_structure_name, false))
else
structure = replace(structure, structure_name => generic_structure_name * "{T,INT}")
push!(info_structures, (structure_name, generic_structure_name, true))
end
structures = structures * structure * "\n"
else
text = text * code
end
end
end
text = structures * "\n" * text
startswith(text, '\n') && (text = text[2:end])

# Rename the structures in the wrappers
for (old_struct, new_struct, bool) in info_structures
if bool
text = replace(text, "Ptr{$old_struct" => "Ref{$new_struct")
else
for precision in ("Float32", "Float64", "Float128")
text = replace(text, "Ptr{$old_struct{$precision,Int32}}" => "Ref{$new_struct{Int32}}")
text = replace(text, "Ptr{$old_struct{$precision,Int64}}" => "Ref{$new_struct{Int64}}")
end
end
end

write(path, text)
format_file(path, YASStyle())
end
Loading

0 comments on commit 8df68fa

Please sign in to comment.