Skip to content

Wrapping nfloat #202

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 8 commits into
base: master
Choose a base branch
from
93 changes: 62 additions & 31 deletions src/ArbCall/ArbArgTypes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@ Struct for conversion between C argument types in the Arb
documentation and Julia types.
"""
struct ArbArgTypes
supported::Dict{String,DataType}
supported::Dict{String,Union{DataType,UnionAll}}
unsupported::Set{String}
supported_reversed::Dict{DataType,String}
supported_reversed::Dict{Union{DataType,UnionAll},String}
end

function Base.getindex(arbargtypes::ArbArgTypes, key::AbstractString)
Expand All @@ -22,65 +22,96 @@ end

# Define the conversions we use for the rest of the code
const arbargtypes = ArbArgTypes(
Dict{String,DataType}(
Dict{String,Union{DataType,UnionAll}}(
# Primitive
"void" => Cvoid,
"void *" => Ptr{Cvoid},
"int" => Cint,
"slong" => Int,
"ulong" => UInt,
"double" => Cdouble,
"double *" => Vector{Float64},
"double" => Float64,
"complex_double" => ComplexF64,
"void *" => Ptr{Cvoid},
"char *" => Cstring,
"slong *" => Vector{Int},
"ulong *" => Vector{UInt},
"double *" => Vector{Float64},
"complex_double *" => Vector{ComplexF64},
# gmp.h
"mpz_t" => BigInt,
# mpfr.h
"mpfr_t" => BigFloat,
"mpfr_rnd_t" => Base.MPFR.MPFRRoundingMode,
# mag.h
"mag_t" => Mag,
# nfloat.h
"nfloat_ptr" => NFloat,
"nfloat_srcptr" => NFloat,
"gr_ctx_t" => nfloat_ctx_struct, # Actually in gr_types.h
# arf.h
"arf_t" => Arf,
"arf_rnd_t" => arb_rnd,
# acf.h
"acf_t" => Acf,
# arb.h
"arb_t" => Arb,
"acb_t" => Acb,
"mag_t" => Mag,
"arb_srcptr" => ArbVector,
"arb_ptr" => ArbVector,
"acb_srcptr" => AcbVector,
"arb_srcptr" => ArbVector,
# acb.h
"acb_t" => Acb,
"acb_ptr" => AcbVector,
"acb_srcptr" => AcbVector,
# arb_poly.h
"arb_poly_t" => ArbPoly,
# acb_poly.h
"acb_poly_t" => AcbPoly,
# arb_mat.h
"arb_mat_t" => ArbMatrix,
# acb_mat.h
"acb_mat_t" => AcbMatrix,
"arf_rnd_t" => arb_rnd,
"mpfr_t" => BigFloat,
"mpfr_rnd_t" => Base.MPFR.MPFRRoundingMode,
"mpz_t" => BigInt,
"char *" => Cstring,
"slong *" => Vector{Int},
"ulong *" => Vector{UInt},
),
Set(["FILE *", "fmpr_t", "fmpr_rnd_t", "flint_rand_t", "bool_mat_t"]),
Dict{DataType,String}(
Set(["FILE *", "flint_rand_t"]),
Dict{Union{DataType,UnionAll},String}(
# Primitive
Cvoid => "void",
Ptr{Cvoid} => "void *",
Cint => "int",
Int => "slong",
UInt => "ulong",
Cdouble => "double",
Vector{Float64} => "double *",
Float64 => "double",
ComplexF64 => "complex_double",
Ptr{Cvoid} => "void *",
Cstring => "char *",
Vector{Int} => "slong *",
Vector{UInt} => "ulong *",
Vector{Float64} => "double *",
Vector{ComplexF64} => "complex_double *",
# gmp.h
BigInt => "mpz_t",
# mpfr.h
BigFloat => "mpfr_t",
Base.MPFR.MPFRRoundingMode => "mpfr_rnd_t",
# mag.h
Mag => "mag_t",
# nfloat.h
NFloat => "nfloat_ptr",
nfloat_ctx_struct => "gr_ctx_t", # Actually in gr_types.h
# arf.h
Arf => "arf_t",
arb_rnd => "arf_rnd_t",
# acf.h
Acf => "acf_t",
# arb.h
Arb => "arb_t",
Acb => "acb_t",
Mag => "mag_t",
ArbVector => "arb_ptr",
# acb.h
Acb => "acb_t",
AcbVector => "acb_ptr",
# arb_poly.h
ArbPoly => "arb_poly_t",
# acb_poly.h
AcbPoly => "acb_poly_t",
# arb_mat.h
ArbMatrix => "arb_mat_t",
# acb_mat.h
AcbMatrix => "acb_mat_t",
arb_rnd => "arf_rnd_t",
BigFloat => "mpfr_t",
Base.MPFR.MPFRRoundingMode => "mpfr_rnd_t",
BigInt => "mpz_t",
Cstring => "char *",
Vector{Int} => "slong *",
Vector{UInt} => "ulong *",
),
)
4 changes: 4 additions & 0 deletions src/ArbCall/ArbCall.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ import ..Arblib:
cstructtype,
arb_rnd,
mag_struct,
nfloat_struct,
nfloat_ctx_struct,
_get_nfloat_ctx_struct,
arf_struct,
acf_struct,
arb_struct,
Expand All @@ -23,6 +26,7 @@ import ..Arblib:
arb_mat_struct,
acb_mat_struct,
MagLike,
NFloatLike,
ArfLike,
AcfLike,
ArbLike,
Expand Down
2 changes: 1 addition & 1 deletion src/ArbCall/ArbFPWrapFunction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ function jlargs(af::ArbFPWrapFunction)
cargs[end] == Carg{Cint}(:flags, false) ||
throw(ArgumentError("expected last argument to be flags::Cint, got $(cargs[end])"))

args = [:($(name(carg))::$(jltype(carg))) for carg in cargs[n+1:end-1]]
args = [jlarg(carg) for carg in cargs[n+1:end-1]]

if basetype(af) == Float64
kwargs = [
Expand Down
52 changes: 37 additions & 15 deletions src/ArbCall/ArbFunction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,21 +64,23 @@ is_series_method(af::ArbFunction) =
(jltype(first(arguments(af))) <: Union{Arblib.ArbPolyLike,Arblib.AcbPolyLike})

const jlfname_prefixes = (
"double",
"cdouble",
"mag",
"nfloat",
"ctx",
"arf",
"acf",
"arb",
"acb",
"mag",
"mat",
"vec",
"poly",
"scalar",
"mat",
"fpwrap",
"double",
"cdouble",
"scalar",
)
const jlfname_suffixes =
("si", "ui", "d", "mag", "arf", "acf", "arb", "acb", "mpz", "mpfr", "str")
("si", "ui", "d", "str", "mpz", "mpfr", "mag", "nfloat", "arf", "acf", "arb", "acb")

function jlfname(
arbfname::AbstractString;
Expand Down Expand Up @@ -117,15 +119,15 @@ jlfname_series(af::ArbFunction) = jlfname_series(arbfname(af))
function jlargs(af::ArbFunction; argument_detection::Bool = true)
cargs = arguments(af)

jl_arg_names_types = Tuple{Symbol,Any}[]
args = Expr[]
kwargs = Expr[]

prec_kwarg = false
rnd_kwarg = false
flags_kwarg = false
for (i, carg) in enumerate(cargs)
if !argument_detection
push!(jl_arg_names_types, (name(carg), jltype(carg)))
push!(args, jlarg(carg))
continue
end

Expand All @@ -146,13 +148,13 @@ function jlargs(af::ArbFunction; argument_detection::Bool = true)
push!(kwargs, extract_rounding_argument(carg))
elseif i > 1 && is_length_argument(carg, cargs[i-1])
push!(kwargs, extract_length_argument(carg, cargs[i-1]))
elseif is_ctx_argument(carg)
push!(kwargs, extract_ctx_argument(carg, first(cargs)))
else
push!(jl_arg_names_types, (name(carg), jltype(carg)))
push!(args, jlarg(carg))
end
end

args = [:($a::$T) for (a, T) in jl_arg_names_types]

return args, kwargs
end

Expand Down Expand Up @@ -199,13 +201,22 @@ function jlcode(af::ArbFunction, jl_fname = jlfname(af))

returnT = returntype(af)
cargs = arguments(af)
where_type_parameters = unique(reduce(vcat, type_parameters.(cargs)))

func_full_args_call = :($jl_fname($(jl_full_args...)))

func_full_args_header = if isempty(where_type_parameters)
func_full_args_call
else
Expr(:where, func_full_args_call, where_type_parameters...)
end

func_full_args = :(
function $jl_fname($(jl_full_args...))
func_full_args_body = :(
begin
__ret = ccall(
Arblib.@libflint($(arbfname(af))),
$returnT,
$(Expr(:tuple, ctype.(cargs)...)),
$(Expr(:tuple, carg_expr.(cargs)...)),
$(name.(cargs)...),
)
$(
Expand All @@ -220,7 +231,11 @@ function jlcode(af::ArbFunction, jl_fname = jlfname(af))
end
)

func_full_args = Expr(:function, func_full_args_header, func_full_args_body)

if is_series_method(af)
@assert isempty(where_type_parameters) # Currently not supported for series methods

# Note that this currently doesn't respect any custom function
# name given as an argument.
jl_fname_series = jlfname_series(af)
Expand All @@ -242,9 +257,16 @@ function jlcode(af::ArbFunction, jl_fname = jlfname(af))
if isempty(jl_kwargs)
return code
else
func_kwarg_args_call = :($jl_fname($(jl_args...); $(jl_kwargs...)))
func_kwarg_args_header = if isempty(where_type_parameters)
func_kwarg_args_call
else
Expr(:where, func_kwarg_args_call, where_type_parameters...)
end

return quote
$code
$jl_fname($(jl_args...); $(jl_kwargs...)) = $jl_fname($(name.(cargs)...))
$func_kwarg_args_header = $jl_fname($(name.(cargs)...))
end
end
end
Expand Down
Loading
Loading