Skip to content

Commit 1e33b7f

Browse files
committed
NFloat: Improve handling of type parameters in ArbCall wrapper
1 parent 7f7d678 commit 1e33b7f

File tree

5 files changed

+83
-11
lines changed

5 files changed

+83
-11
lines changed

src/ArbCall/ArbFunction.jl

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -201,13 +201,22 @@ function jlcode(af::ArbFunction, jl_fname = jlfname(af))
201201

202202
returnT = returntype(af)
203203
cargs = arguments(af)
204+
where_type_parameters = unique(reduce(vcat, type_parameters.(cargs)))
204205

205-
func_full_args = :(
206-
function $jl_fname($(jl_full_args...))
206+
func_full_args_call = :($jl_fname($(jl_full_args...)))
207+
208+
func_full_args_header = if isempty(where_type_parameters)
209+
func_full_args_call
210+
else
211+
Expr(:where, func_full_args_call, where_type_parameters...)
212+
end
213+
214+
func_full_args_body = :(
215+
begin
207216
__ret = ccall(
208217
Arblib.@libflint($(arbfname(af))),
209218
$returnT,
210-
$(Expr(:tuple, ctype.(cargs)...)),
219+
$(Expr(:tuple, carg_expr.(cargs)...)),
211220
$(name.(cargs)...),
212221
)
213222
$(
@@ -222,7 +231,11 @@ function jlcode(af::ArbFunction, jl_fname = jlfname(af))
222231
end
223232
)
224233

234+
func_full_args = Expr(:function, func_full_args_header, func_full_args_body)
235+
225236
if is_series_method(af)
237+
@assert isempty(where_type_parameters) # Currently not supported for series methods
238+
226239
# Note that this currently doesn't respect any custom function
227240
# name given as an argument.
228241
jl_fname_series = jlfname_series(af)
@@ -244,9 +257,16 @@ function jlcode(af::ArbFunction, jl_fname = jlfname(af))
244257
if isempty(jl_kwargs)
245258
return code
246259
else
260+
func_kwarg_args_call = :($jl_fname($(jl_args...); $(jl_kwargs...)))
261+
func_kwarg_args_header = if isempty(where_type_parameters)
262+
func_kwarg_args_call
263+
else
264+
Expr(:where, func_kwarg_args_call, where_type_parameters...)
265+
end
266+
247267
return quote
248268
$code
249-
$jl_fname($(jl_args...); $(jl_kwargs...)) = $jl_fname($(name.(cargs)...))
269+
$func_kwarg_args_header = $jl_fname($(name.(cargs)...))
250270
end
251271
end
252272
end

src/ArbCall/Carg.jl

Lines changed: 39 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,12 @@ jltype(::Carg{ArbMatrix}) = ArbMatrixLike
9393
# acb_mat.h
9494
jltype(::Carg{AcbMatrix}) = AcbMatrixLike
9595

96+
type_parameters_from_type(type::UnionAll) =
97+
[type.var.name; type_parameters_from_type(type.body)]
98+
type_parameters_from_type(_) = Symbol[]
99+
100+
type_parameters(ca::Carg) = type_parameters_from_type(jltype(ca))
101+
96102
"""
97103
ctype(ca::Carg)
98104
@@ -105,8 +111,8 @@ ctype(::Carg{T}) where {T<:Union{Mag,Arf,Acf,Arb,Acb,ArbPoly,AcbPoly,ArbMatrix,A
105111
Ref{cstructtype(T)}
106112
ctype(::Carg{T}) where {T<:Union{ArbVector,arb_vec_struct}} = Ptr{arb_struct}
107113
ctype(::Carg{T}) where {T<:Union{AcbVector,acb_vec_struct}} = Ptr{acb_struct}
108-
ctype(::Carg{T}) where {T<:NFloat} = Ref{nfloat_struct}
109-
ctype(::Carg{T}) where {T<:nfloat_ctx_struct} = Ref{nfloat_ctx_struct}
114+
ctype(::Carg{T}) where {T<:NFloat} = Ref{nfloat_struct{P,F}} where {P,F}
115+
ctype(::Carg{T}) where {T<:nfloat_ctx_struct} = Ref{nfloat_ctx_struct{P,F}} where {P,F}
110116

111117
"""
112118
jlarg(ca::Carg{T}) where {T}
@@ -119,9 +125,39 @@ julia> Arblib.ArbCall.jlarg(Arblib.ArbCall.Carg("const arb_t x"))
119125
:(x::ArbLike)
120126
julia> Arblib.ArbCall.jlarg(Arblib.ArbCall.Carg("slong prec"))
121127
:(prec::Integer)
128+
julia> Arblib.ArbCall.jlarg(Arblib.ArbCall.Carg("nfloat_ptr res"))
129+
:(res::(NFloatLike){P,F})
122130
```
123131
"""
124-
jlarg(ca::Carg) = :($(name(ca))::$(jltype(ca)))
132+
function jlarg(ca::Carg)
133+
if jltype(ca) isa UnionAll
134+
:($(name(ca))::$(jltype(ca)){$(type_parameters(ca)...)})
135+
else
136+
:($(name(ca))::$(jltype(ca)))
137+
end
138+
end
139+
140+
"""
141+
carg_expr(ca::Carg{T}) where {T}
142+
143+
Return a value for representing the argument in a `ccall`.
144+
145+
```jldoctest
146+
julia> Arblib.ArbCall.carg_expr(Arblib.ArbCall.Carg("const arb_t x"))
147+
Ref{Arblib.arb_struct}
148+
julia> Arblib.ArbCall.carg_expr(Arblib.ArbCall.Carg("slong prec"))
149+
Int64
150+
julia> Arblib.ArbCall.carg_expr(Arblib.ArbCall.Carg("nfloat_ptr res"))
151+
:((Ref{Arblib.nfloat_struct{P, F}} where {P, F}){P, F})
152+
```
153+
"""
154+
function carg_expr(ca::Carg)
155+
if ctype(ca) isa UnionAll
156+
:($(ctype(ca)){$(type_parameters(ca)...)})
157+
else
158+
:($(ctype(ca)))
159+
end
160+
end
125161

126162
is_precision_argument(ca::Carg) = ca == Carg{Int}(:prec, false)
127163

src/Arblib.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,4 +112,6 @@ include("arbcalls/arb_fpwrap.jl")
112112
include("arbcalls/fmpz_extras.jl")
113113
include("arbcalls/eigen.jl")
114114

115+
include("nfloat_init_contexts.jl")
116+
115117
end # module

src/nfloat.jl

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ mutable struct nfloat_struct{P,F}
2828
function nfloat_struct{P,F}() where {P,F}
2929
@assert P isa Int && F isa Int
3030
res = new{P,F}()
31-
init!(res, nfloat_ctx_struct{P,F}())
31+
init!(res)
3232
return res
3333
end
3434
end
@@ -49,9 +49,11 @@ const NFloatLike{P,F} = Union{NFloat{P,F},NFloatRef{P,F},nfloat_struct{P,F}}
4949
nfloat_ctx_struct(::NFloatLike{P,F}) where {P,F} = nfloat_ctx_struct{P,F}()
5050
nfloat_ctx_struct(::Type{NFloatLike{P,F}}) where {P,F} = nfloat_ctx_struct{P,F}()
5151

52-
# TODO: Precompute these for all(?) possible values
53-
function _get_nfloat_ctx_struct(::Union{Type{NFloatLike{P,F}},NFloatLike{P,F}}) where {P,F}
54-
return nfloat_ctx_struct{P,F}()
52+
# The contexts are precomputed in nfloat_late.jl
53+
@generated function _get_nfloat_ctx_struct(
54+
::Union{Type{<:NFloatLike{P,F}},NFloatLike{P,F}},
55+
) where {P,F}
56+
Symbol(:_nfloat_ctx_struct_, P, :_, F)
5557
end
5658

5759
# Helper function for constructing a flag argument for NFloat
@@ -179,6 +181,13 @@ end
179181

180182
Base.show(io::IO, ::Type{NFloatLike}) = print(io, :NFloatLike)
181183

184+
# As in promotion.jl
185+
186+
Base.promote_rule(
187+
::Type{<:NFloatOrRef{P1,F1}},
188+
::Type{<:Union{NFloatOrRef{P2,F2}}},
189+
) where {P1,P2,F1,F2} = NFloat{max(P1, P2),F1 | F2}
190+
182191
# As in arithmetic.jl
183192

184193
for (jf, af) in [(:+, :add!), (:-, :sub!), (:*, :mul!), (:/, :div!)]

src/nfloat_init_contexts.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
for P = 1:66
2+
for F = 0:7
3+
@eval const $(Symbol(:_nfloat_ctx_struct_, P, :_, F)) = nfloat_ctx_struct{$P,$F}()
4+
end
5+
end

0 commit comments

Comments
 (0)