Skip to content

Commit

Permalink
Stop using IdDict on Reactant extension
Browse files Browse the repository at this point in the history
  • Loading branch information
mofeing committed Sep 28, 2024
1 parent 2bf32ca commit 6209829
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions ext/TenetReactantExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@ const stablehlo = MLIR.Dialects.stablehlo
const Enzyme = Reactant.Enzyme

function Reactant.make_tracer(
seen::IdDict, @nospecialize(prev::RT), path::Tuple, mode::Reactant.TraceMode; kwargs...
seen, @nospecialize(prev::RT), path::Tuple, mode::Reactant.TraceMode; kwargs...
) where {RT<:Tensor}
tracedata = Reactant.make_tracer(seen, parent(prev), Reactant.append_path(path, :data), mode; kwargs...)
return Tensor(tracedata, inds(prev))
end

function Reactant.make_tracer(seen::IdDict, prev::TensorNetwork, path::Tuple, mode::Reactant.TraceMode; kwargs...)
function Reactant.make_tracer(seen, prev::TensorNetwork, path::Tuple, mode::Reactant.TraceMode; kwargs...)
tracetensors = Vector{Tensor}(undef, Tenet.ntensors(prev))
for (i, tensor) in enumerate(tensors(prev))
tracetensors[i] = Reactant.make_tracer(seen, tensor, Reactant.append_path(path, i), mode; kwargs...)
Expand All @@ -26,18 +26,18 @@ end

Reactant.traced_getfield(x::TensorNetwork, i::Int) = tensors(x)[i]

function Reactant.make_tracer(seen::IdDict, prev::Quantum, path::Tuple, mode::Reactant.TraceMode; kwargs...)
function Reactant.make_tracer(seen, prev::Quantum, path::Tuple, mode::Reactant.TraceMode; kwargs...)
tracetn = Reactant.make_tracer(seen, TensorNetwork(prev), Reactant.append_path(path, :tn), mode; kwargs...)
return Quantum(tracetn, copy(prev.sites))
end

function Reactant.make_tracer(seen::IdDict, prev::Tenet.Product, path::Tuple, mode::Reactant.TraceMode; kwargs...)
function Reactant.make_tracer(seen, prev::Tenet.Product, path::Tuple, mode::Reactant.TraceMode; kwargs...)
tracequantum = Reactant.make_tracer(seen, Quantum(prev), Reactant.append_path(path, :super), mode; kwargs...)
return Tenet.Product(tracequantum)
end

# TODO try rely on generic fallback for ansatzes -> do it when refactoring to MPS/MPO
function Reactant.make_tracer(seen::IdDict, prev::Tenet.Chain, path::Tuple, mode::Reactant.TraceMode; kwargs...)
function Reactant.make_tracer(seen, prev::Tenet.Chain, path::Tuple, mode::Reactant.TraceMode; kwargs...)
tracequantum = Reactant.make_tracer(seen, Quantum(prev), Reactant.append_path(path, :super), mode; kwargs...)
return Tenet.Chain(tracequantum, boundary(prev))
end
Expand Down

0 comments on commit 6209829

Please sign in to comment.