|
| 1 | +using MacroTools |
| 2 | + |
# Build an expression that applies the dimension permutation `p` to the
# expression `x`.
#
# Returns `x` itself when `p` is the identity permutation, a `transpose(x)`
# call when `p == [2, 1]`, and a `permutedims(x, (p...,))` call otherwise.
function _permutedims(x, p)
  if p == collect(1:length(p)) # TODO 0.7
    return x
  elseif p == [2, 1]
    return Expr(:call, :transpose, x)
  else
    return Expr(:call, :permutedims, x, Expr(:tuple, p...))
  end
end
| 7 | + |
# Return the expressions `size(x, 1)`, ..., `size(x, n)` as a `Vector{Any}`.
#
# The element type is deliberately `Any`: `_expanddims` inserts literal `1`s
# into the returned vector. Note that `x` is spliced into every entry.
function _size(x, n)
  sizes = Any[]
  for d in 1:n
    push!(sizes, Expr(:call, :size, x, d))
  end
  return sizes
end
| 9 | + |
# Build an expression that reshapes the `n`-dimensional expression `x` to the
# shape `s` (a vector of size expressions and/or integers).
#
# When `s` is exactly the list `size(x, 1), ..., size(x, n)` the reshape would
# be a no-op, so `x` is returned unchanged.
# TODO use transpose
function _reshape(x, n, s)
  s == _size(x, n) && return x
  return Expr(:call, :reshape, x, Expr(:tuple, s...))
end
| 13 | + |
# Build an expression that reshapes the `n`-dimensional expression `x` so that
# singleton dimensions appear at the positions listed in `ds`.
#
# Positions are inserted one at a time, in iteration order, into the list
# `size(x, 1), ..., size(x, n)`, so each position refers to the list as it
# stands after the previous insertions.
# TODO use transpose
function _expanddims(x, n, ds)
  shape = _size(x, n)
  for d in ds
    insert!(shape, d, 1)
  end
  return Expr(:call, :reshape, x, Expr(:tuple, shape...))
end
| 19 | + |
# Build an expression that reshapes the `n`-dimensional expression `x` to the
# sizes of only those dimensions whose position is not listed in `ds`.
#
# Note that `x` is spliced into the result once for the `reshape` call and once
# per kept dimension (inside each `size` call).
function _squeezedims(x, n, ds)
  sizes = _size(x, n)
  kept = filter(i -> !(i in ds), 1:n)
  return :(reshape($x, ($(sizes[kept]...),)))
end
| 24 | + |
# Build the expression for a pairwise einsum step.
#
# Arguments:
# - `a`, `b`: two-element collections `(expr, labels)` holding an array
#   expression and the list of its index labels, one per dimension.
# - `dims`: labels to sum over. Only labels that occur in BOTH operands are
#   contracted here; any other entry is ignored and its dimension is kept in
#   the result, so the caller can reduce it afterwards.
#
# Returns a tuple `(expr, labels)` for the combined array.
#
# Labels shared by both operands but not contracted are "preserved"; labels
# belonging to one operand only are "free". With nothing to contract the
# result is a broadcasted `a .* b`; otherwise both operands are flattened to
# matrices, multiplied, and the product is reshaped back.
# NOTE: the matrix-multiply path does not handle preserved labels yet (see the
# `TODO preserve` markers below).
function _einsum_pair(a, b, dims)
  (a, adims), (b, bdims) = a, b # TODO 0.7
  shared = intersect(adims, bdims)
  # The index arithmetic below requires every contracted label to be present
  # in both operands, so drop anything else.
  contracted = filter(d -> d in shared, dims)
  preserved = setdiff(shared, contracted)
  afree = setdiff(setdiff(adims, preserved), contracted)
  bfree = setdiff(setdiff(bdims, preserved), contracted)
  # TODO move preserved dims last
  # Layout after permuting: a = (preserved..., free..., contracted...)
  #                         b = (preserved..., contracted..., free...)
  aperm = sortperm(adims, by = i -> i in preserved ? -1 : i in afree ? 0 : 1)
  bperm = sortperm(bdims, by = i -> i in preserved ? -1 : i in contracted ? 0 : 1)
  a, b = _permutedims(a, aperm), _permutedims(b, bperm)
  adims, bdims = adims[aperm], bdims[bperm]
  npreserved = length(preserved)

  if isempty(contracted)
    # Give `b` one singleton dimension per free dimension of `a`, right after
    # the preserved block, so that `.*` lines the operands up. `.+` is used
    # because `Int + range` is not defined on newer Julia versions.
    b = _expanddims(b, length(bdims), npreserved .+ (1:length(afree)))
    # `adims`/`bdims` are already permuted above; do not permute them again.
    return :($a .* $b), vcat(adims, bdims[npreserved+1:end])
  end

  # Symbolic product of a list of size expressions; the empty product is 1.
  sizeprod(xs) = isempty(xs) ? 1 : length(xs) == 1 ? xs[1] : :(prod(($(xs...),)))
  ncontracted = length(contracted)

  ashape = _size(a, length(adims))
  aaxes = npreserved+1:length(adims)-ncontracted
  afreesize = sizeprod(ashape[aaxes])
  asumsize = sizeprod(ashape[end-ncontracted+1:end])
  a = _reshape(a, length(adims), [afreesize, asumsize]) # TODO preserve

  bshape = _size(b, length(bdims))
  baxes = length(bdims)-length(bfree)+1:length(bdims)
  bsumsize = sizeprod(bshape[npreserved+1:end-length(bfree)])
  bfreesize = sizeprod(bshape[baxes])
  b = _reshape(b, length(bdims), [bsumsize, bfreesize]) # TODO preserve

  ab = :($a*$b)
  akeep = vcat(1:npreserved, aaxes)
  shape = vcat(ashape[akeep], bshape[baxes])
  # Skip the reshape when the matrix product already has the final shape.
  shape == [afreesize, bfreesize] || (ab = _reshape(ab, 2, shape))
  labels = vcat(adims[akeep], bdims[baxes])

  return ab, labels
end
| 61 | + |
| 62 | +# _einsum_pair([:a, [:i, :j]], [:b, [:j, :k]], [:j]) |
| 63 | +# _einsum_pair([:a, [:i, :j, :N]], [:b, [:j, :k, :N]], [:j]) |
| 64 | + |
"""
    @einsum [out...] -> a[i...] * b[j...] * ...

Expand an index-notation tensor product into calls to `permutedims`,
`reshape`, `*`, `.*` and `sum`.

The left-hand side lists the labels of the result, in order. The right-hand
side is a single indexed array or a `*` product of indexed arrays. Labels that
do not appear in the output are summed over.

Expansion fails with an error when the expression does not have this form,
when a label is repeated within one input ("diagonals"), when a label appears
in more than two inputs, or when the output labels do not match the labels
left after reduction.
"""
macro einsum(ex)
  @capture(ex, [out__] -> *(inputs__) | inputs_) ||
    error("`@einsum [...] -> a[...] * b[...] * ...`")
  inputs isa Vector || (inputs = [inputs])
  # TODO rebinding, check dims
  # Turn every `a[i...]` into an `(escaped array expression, labels)` pair.
  inputs = map(inputs) do x
    @capture(x, a_[i__]) || error("Einsum input should be `a[i...]`, got `$x`")
    esc(a), i
  end
  all(allunique(is) for (_, is) in inputs) || error("Diagonals not supported")
  labels = unique(vcat(map(x -> x[2], inputs)...))
  for i in labels
    count(t -> i in t[2], inputs) > 2 &&
      error("Not supported: index $i appears more than twice")
  end
  # Fold the inputs left to right, contracting each pair as we go.
  y = inputs[1]
  for next in inputs[2:end]
    # Only labels shared by both operands can be contracted by a matrix
    # product; a label held by one operand only is reduced by the `sum` below.
    dims = setdiff(intersect(y[2], next[2]), out)
    y = _einsum_pair(y, next, dims)
  end
  leftover = setdiff(y[2], out)
  if !isempty(leftover)
    r = indexin(leftover, y[2])
    # TODO 0.7: positional `dims` is the pre-0.7 signature of `sum`.
    y = _squeezedims(:(sum($(y[1]), ($(r...),))), length(y[2]), r),
        setdiff(y[2], leftover)
  end
  # This validates user input, so raise a real error instead of `@assert`.
  sort(y[2]) == sort(out) ||
    error("Einsum output indices `$out` do not match the indices of the inputs")
  return _permutedims(y[1], indexin(out, y[2]))
end
| 92 | + |
| 93 | +# @expand @einsum [i] -> a[i,j] |
| 94 | +# @expand @einsum [i,k] -> a[i,j] * b[j,k] |
| 95 | +# @expand @einsum [i,k] -> a[j,k] * b[i,j] |
| 96 | +# @expand @einsum [i,k,N] -> a[i,j,N] * b[j,k,N] |
| 97 | +# @expand @einsum [i,j] -> a[i] * b[j] |
0 commit comments