Skip to content

Commit 7f7d678

Browse files
committed
nfloat: Proof of concept wrapping of new nfloat module in Flint
1 parent 5844ecc commit 7f7d678

File tree

10 files changed

+505
-14
lines changed

10 files changed

+505
-14
lines changed

src/ArbCall/ArbArgTypes.jl

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@ Struct for conversion between C argument types in the Arb
99
documentation and Julia types.
1010
"""
1111
struct ArbArgTypes
12-
supported::Dict{String,DataType}
12+
supported::Dict{String,Union{DataType,UnionAll}}
1313
unsupported::Set{String}
14-
supported_reversed::Dict{DataType,String}
14+
supported_reversed::Dict{Union{DataType,UnionAll},String}
1515
end
1616

1717
function Base.getindex(arbargtypes::ArbArgTypes, key::AbstractString)
@@ -22,7 +22,7 @@ end
2222

2323
# Define the conversions we use for the rest of the code
2424
const arbargtypes = ArbArgTypes(
25-
Dict{String,DataType}(
25+
Dict{String,Union{DataType,UnionAll}}(
2626
# Primitive
2727
"void" => Cvoid,
2828
"int" => Cint,
@@ -43,6 +43,10 @@ const arbargtypes = ArbArgTypes(
4343
"mpfr_rnd_t" => Base.MPFR.MPFRRoundingMode,
4444
# mag.h
4545
"mag_t" => Mag,
46+
# nfloat.h
47+
"nfloat_ptr" => NFloat,
48+
"nfloat_srcptr" => NFloat,
49+
"gr_ctx_t" => nfloat_ctx_struct, # Actually in gr_types.h
4650
# arf.h
4751
"arf_t" => Arf,
4852
"arf_rnd_t" => arb_rnd,
@@ -66,7 +70,7 @@ const arbargtypes = ArbArgTypes(
6670
"acb_mat_t" => AcbMatrix,
6771
),
6872
Set(["FILE *", "flint_rand_t"]),
69-
Dict{DataType,String}(
73+
Dict{Union{DataType,UnionAll},String}(
7074
# Primitive
7175
Cvoid => "void",
7276
Cint => "int",
@@ -87,6 +91,9 @@ const arbargtypes = ArbArgTypes(
8791
Base.MPFR.MPFRRoundingMode => "mpfr_rnd_t",
8892
# mag.h
8993
Mag => "mag_t",
94+
# nfloat.h
95+
NFloat => "nfloat_ptr",
96+
nfloat_ctx_struct => "gr_ctx_t", # Actually in gr_types.h
9097
# arf.h
9198
Arf => "arf_t",
9299
arb_rnd => "arf_rnd_t",

src/ArbCall/ArbCall.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@ import ..Arblib:
1212
cstructtype,
1313
arb_rnd,
1414
mag_struct,
15+
nfloat_struct,
16+
nfloat_ctx_struct,
17+
_get_nfloat_ctx_struct,
1518
arf_struct,
1619
acf_struct,
1720
arb_struct,
@@ -23,6 +26,7 @@ import ..Arblib:
2326
arb_mat_struct,
2427
acb_mat_struct,
2528
MagLike,
29+
NFloatLike,
2630
ArfLike,
2731
AcfLike,
2832
ArbLike,

src/ArbCall/ArbFunction.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,8 @@ const jlfname_prefixes = (
6767
"double",
6868
"cdouble",
6969
"mag",
70+
"nfloat",
71+
"ctx",
7072
"arf",
7173
"acf",
7274
"arb",
@@ -78,7 +80,7 @@ const jlfname_prefixes = (
7880
"scalar",
7981
)
8082
const jlfname_suffixes =
81-
("si", "ui", "d", "str", "mpz", "mpfr", "mag", "arf", "acf", "arb", "acb")
83+
("si", "ui", "d", "str", "mpz", "mpfr", "mag", "nfloat", "arf", "acf", "arb", "acb")
8284

8385
function jlfname(
8486
arbfname::AbstractString;
@@ -146,6 +148,8 @@ function jlargs(af::ArbFunction; argument_detection::Bool = true)
146148
push!(kwargs, extract_rounding_argument(carg))
147149
elseif i > 1 && is_length_argument(carg, cargs[i-1])
148150
push!(kwargs, extract_length_argument(carg, cargs[i-1]))
151+
elseif is_ctx_argument(carg)
152+
push!(kwargs, extract_ctx_argument(carg, first(cargs)))
149153
else
150154
push!(args, jlarg(carg))
151155
end

src/ArbCall/Carg.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,8 @@ jltype(::Carg{Vector{ComplexF64}}) = Vector{ComplexF64}
7171
jltype(::Carg{Base.MPFR.MPFRRoundingMode}) = Union{Base.MPFR.MPFRRoundingMode,RoundingMode}
7272
# mag.h
7373
jltype(::Carg{Mag}) = MagLike
74+
# nfloat.h
75+
jltype(::Carg{NFloat}) = NFloatLike
7476
# arf.h
7577
jltype(::Carg{Arf}) = ArfLike
7678
jltype(::Carg{arb_rnd}) = Union{arb_rnd,RoundingMode}
@@ -103,6 +105,8 @@ ctype(::Carg{T}) where {T<:Union{Mag,Arf,Acf,Arb,Acb,ArbPoly,AcbPoly,ArbMatrix,A
103105
Ref{cstructtype(T)}
104106
ctype(::Carg{T}) where {T<:Union{ArbVector,arb_vec_struct}} = Ptr{arb_struct}
105107
ctype(::Carg{T}) where {T<:Union{AcbVector,acb_vec_struct}} = Ptr{acb_struct}
108+
ctype(::Carg{T}) where {T<:NFloat} = Ref{nfloat_struct}
109+
ctype(::Carg{T}) where {T<:nfloat_ctx_struct} = Ref{nfloat_ctx_struct}
106110

107111
"""
108112
jlarg(ca::Carg{T}) where {T}
@@ -131,6 +135,8 @@ is_length_argument(ca::Carg, prev_ca::Carg) =
131135
rawtype(ca) == Int &&
132136
rawtype(prev_ca) (ArbVector, AcbVector)
133137

138+
is_ctx_argument(ca::Carg{T}) where {T} = T <: nfloat_ctx_struct
139+
134140
function extract_precision_argument(ca::Carg, first_ca::Carg)
135141
is_precision_argument(ca) ||
136142
throw(ArgumentError("argument is not a valid precision argument, $ca"))
@@ -161,6 +167,14 @@ function extract_length_argument(ca::Carg, prev_ca::Carg)
161167
return Expr(:kw, jlarg(ca), :(length($(name(prev_ca)))))
162168
end
163169

170+
# TODO: This needs to handle the case when it is not the first
171+
# argument we should get the context from. This happens for e.g.
172+
# nfloat_get_arf.
173+
function extract_ctx_argument(ca::Carg, first_ca::Carg)
174+
is_ctx_argument(ca) || throw(ArgumentError("argument is not a valid ctx argument, $ca"))
175+
return Expr(:kw, jlarg(ca), :(_get_nfloat_ctx_struct($(name(first_ca)))))
176+
end
177+
164178
"""
165179
is_fpwrap_res_argument(ca::Carg, T::Union{Float64,ComplexF64})
166180

src/ArbCall/parse.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,7 @@ function parse_and_generate_arbdoc(
192192
out_dir = "src/arbcalls/";
193193
filenames = [
194194
"mag",
195+
"nfloat",
195196
"arf",
196197
"acf",
197198
"arb",

src/Arblib.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@ include("types.jl")
5454
include("hash.jl")
5555
include("serialize.jl")
5656

57+
include("nfloat.jl")
58+
5759
include("ArbCall/ArbCall.jl")
5860
import .ArbCall: @arbcall_str, @arbfpwrapcall_str
5961
include("manual_overrides.jl")
@@ -84,6 +86,7 @@ include("calc_integrate.jl")
8486
include("special-functions.jl")
8587

8688
include("arbcalls/mag.jl")
89+
include("arbcalls/nfloat.jl")
8790
include("arbcalls/arf.jl")
8891
include("arbcalls/acf.jl")
8992
include("arbcalls/arb.jl")

0 commit comments

Comments
 (0)