Skip to content
This repository has been archived by the owner on Jul 7, 2024. It is now read-only.

Commit

Permalink
Remove unnecessary helper function
Browse files Browse the repository at this point in the history
  • Loading branch information
jofrevalles committed Jun 19, 2024
1 parent ff981fa commit 7375cb5
Showing 1 changed file with 26 additions and 42 deletions.
68 changes: 26 additions & 42 deletions src/Ansatz/Chain.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ function Chain(::State, boundary::Periodic, arrays::Vector{<:AbstractArray}; ord
n = length(arrays)
symbols = [nextindex() for _ in 1:2n]

function get_index(directions, i)
map(directions) do dir
_tensors = map(enumerate(arrays)) do (i, array)
inds = map(order) do dir
if dir == :o
symbols[i]
elseif dir == :r
Expand All @@ -53,10 +53,6 @@ function Chain(::State, boundary::Periodic, arrays::Vector{<:AbstractArray}; ord
throw(ArgumentError("Invalid direction: $dir"))
end
end
end

_tensors = map(enumerate(arrays)) do (i, array)
inds = get_index(order, i)
Tensor(array, inds)
end

Expand All @@ -69,20 +65,22 @@ function Chain(::State, boundary::Open, arrays::Vector{<:AbstractArray}; order =
@assert ndims(arrays[1]) == 2 "First array must have 2 dimensions"
@assert all(==(3) ndims, arrays[2:end-1]) "All arrays must have 3 dimensions"
@assert ndims(arrays[end]) == 2 "Last array must have 2 dimensions"
issetequal(order, defaultorder(State())) ||
throw(ArgumentError("order must be a permutation of $(String.(defaultorder(State())))"))
issetequal(order, defaultorder(Chain, State())) ||
throw(ArgumentError("order must be a permutation of $(String.(defaultorder(Chain, State())))"))

n = length(arrays)
symbols = [nextindex() for _ in 1:2n]

function get_index(directions, i, is_first, is_last)
if is_first
directions = filter(x -> x != :l, directions)
elseif is_last
directions = filter(x -> x != :r, directions)
_tensors = map(enumerate(arrays)) do (i, array)
if i == 1
_order = filter(x -> x != :l, order)
elseif i == n
_order = filter(x -> x != :r, order)
else
_order = order
end

map(directions) do dir
inds = map(_order) do dir
if dir == :o
symbols[i]
elseif dir == :r
Expand All @@ -93,12 +91,6 @@ function Chain(::State, boundary::Open, arrays::Vector{<:AbstractArray}; order =
throw(ArgumentError("Invalid direction: $dir"))
end
end
end

_tensors = map(enumerate(arrays)) do (i, array)
is_first = (i == 1)
is_last = (i == n)
inds = get_index(order, i, is_first, is_last)
Tensor(array, inds)
end

Expand All @@ -109,14 +101,14 @@ end

function Chain(::Operator, boundary::Periodic, arrays::Vector{<:AbstractArray}; order = defaultorder(Chain, Operator()))
@assert all(==(4) ndims, arrays) "All arrays must have 4 dimensions"
issetequal(order, defaultorder(Operator())) ||
throw(ArgumentError("order must be a permutation of $(String.(defaultorder(Operator())))"))
issetequal(order, defaultorder(Chain, Operator())) ||
throw(ArgumentError("order must be a permutation of $(String.(defaultorder(Chain, Operator())))"))

n = length(arrays)
symbols = [nextindex() for _ in 1:3n]

function get_index(directions, i)
map(directions) do dir
_tensors = map(enumerate(arrays)) do (i, array)
inds = map(order) do dir
if dir == :o
symbols[i]
elseif dir == :i
Expand All @@ -129,10 +121,6 @@ function Chain(::Operator, boundary::Periodic, arrays::Vector{<:AbstractArray};
throw(ArgumentError("Invalid direction: $dir"))
end
end
end

_tensors = map(enumerate(arrays)) do (i, array)
inds = get_index(order, i)
Tensor(array, inds)
end

Expand All @@ -146,20 +134,22 @@ function Chain(::Operator, boundary::Open, arrays::Vector{<:AbstractArray}; orde
@assert ndims(arrays[1]) == 3 "First array must have 3 dimensions"
@assert all(==(4) ndims, arrays[2:end-1]) "All arrays must have 4 dimensions"
@assert ndims(arrays[end]) == 3 "Last array must have 3 dimensions"
issetequal(order, defaultorder(Operator())) ||
throw(ArgumentError("order must be a permutation of $(String.(defaultorder(Operator())))"))
issetequal(order, defaultorder(Chain, Operator())) ||
throw(ArgumentError("order must be a permutation of $(String.(defaultorder(Chain, Operator())))"))

n = length(arrays)
symbols = [nextindex() for _ in 1:3n-1]

function get_index(directions, i, is_first, is_last)
if is_first
directions = filter(x -> x != :l, directions)
elseif is_last
directions = filter(x -> x != :r, directions)
_tensors = map(enumerate(arrays)) do (i, array)
if i == 1
_order = filter(x -> x != :l, order)
elseif i == n
_order = filter(x -> x != :r, order)
else
_order = order
end

map(directions) do dir
inds = map(_order) do dir
if dir == :o
symbols[i]
elseif dir == :i
Expand All @@ -172,12 +162,6 @@ function Chain(::Operator, boundary::Open, arrays::Vector{<:AbstractArray}; orde
throw(ArgumentError("Invalid direction: $dir"))
end
end
end

_tensors = map(enumerate(arrays)) do (i, array)
is_first = (i == 1)
is_last = (i == n)
inds = get_index(order, i, is_first, is_last)
Tensor(array, inds)
end

Expand Down

0 comments on commit 7375cb5

Please sign in to comment.