Skip to content
This repository was archived by the owner on May 27, 2021. It is now read-only.

Commit 2acd20b

Browse files
Revert "Split translate function"
This reverts commit 3f64767.
1 parent 3f64767 commit 2acd20b

File tree

2 files changed

+5
-23
lines changed

2 files changed

+5
-23
lines changed

src/device/matmul_kernels/kernel.jl

+4-4
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ function matmul_impl(a, b, c, d,
4545

4646
@unroll for i = 1 : NUM_FRAGMENTS_M
4747
@unroll for j = 1 : NUM_FRAGMENTS_N
48-
tile = translate_const(warp_tile, (M = (i-1)*COMPUTE_OP_SHAPE.M, N = (j-1)*COMPUTE_OP_SHAPE.N))
48+
tile = translate(warp_tile, (M = (i-1)*COMPUTE_OP_SHAPE.M, N = (j-1)*COMPUTE_OP_SHAPE.N))
4949
@inbounds c_frags[i, j] = transf_sh2rf_c(Operator.load_c(OPERATOR, SHARED_C_LAYOUT, shmem_c, tile), tile)
5050
end
5151
end
@@ -84,15 +84,15 @@ function matmul_impl(a, b, c, d,
8484
a_frags = MArray{Tuple{NUM_FRAGMENTS_M}, Operator.fragtype_a(OPERATOR, SHARED_A_LAYOUT)}(undef)
8585

8686
@unroll for i = 1 : NUM_FRAGMENTS_M
87-
a_tile = translate_const(warp_tile.MK, (M = (i-1)*COMPUTE_OP_SHAPE.M, K = 0))
87+
a_tile = translate(warp_tile.MK, (M = (i-1)*COMPUTE_OP_SHAPE.M, K = 0))
8888
@inbounds a_frags[i] = transf_sh2rf_a(Operator.load_a(OPERATOR, SHARED_A_LAYOUT, shmem_a, a_tile), a_tile)
8989
end
9090

9191
# (3.3.2) Load a COMPUTE_WARP.K x COMPUTE_WARP.N tile of B from shared memory into registers
9292
b_frags = MArray{Tuple{NUM_FRAGMENTS_N}, Operator.fragtype_b(OPERATOR, SHARED_B_LAYOUT)}(undef)
9393

9494
@unroll for j = 1 : NUM_FRAGMENTS_N
95-
b_tile = translate_const(warp_tile.KN, (K = 0, N = (j-1)*COMPUTE_OP_SHAPE.N))
95+
b_tile = translate(warp_tile.KN, (K = 0, N = (j-1)*COMPUTE_OP_SHAPE.N))
9696
@inbounds b_frags[j] = transf_sh2rf_b(Operator.load_b(OPERATOR, SHARED_B_LAYOUT, shmem_b, b_tile), b_tile)
9797
end
9898

@@ -114,7 +114,7 @@ function matmul_impl(a, b, c, d,
114114

115115
@unroll for i = 1 : NUM_FRAGMENTS_M
116116
@unroll for j = 1 : NUM_FRAGMENTS_N
117-
tile = translate_const(warp_tile, (M = (i-1)*COMPUTE_OP_SHAPE.M, N = (j-1)*COMPUTE_OP_SHAPE.N))
117+
tile = translate(warp_tile, (M = (i-1)*COMPUTE_OP_SHAPE.M, N = (j-1)*COMPUTE_OP_SHAPE.N))
118118
Operator.store_d(OPERATOR, SHARED_D_LAYOUT, shmem_d, transf_rf2sh_d(c_frags[i, j], tile), tile)
119119
end
120120
end

src/device/tiling.jl

+1-19
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ export translate
119119
"""
120120
translate(tile::Tile{names, T}, offset::NamedTuple{names, T})
121121
122-
Translate (i.e. move) a [`Tile`](@ref) by an `offset`.
122+
Translate (i.e. move) a [`Tile`](@ref) by a constant `offset`.
123123
124124
# Arguments
125125
- `tile`: The [`Tile`](@ref) to translate.
@@ -132,24 +132,6 @@ end
132132

133133
@inline translate(tile::Tile{size, names, T}, offset::Tuple) where {names, T, size} = translate(tile, NamedTuple{names}(offset))
134134

135-
export translate_const
136-
137-
"""
138-
translate_const(tile::Tile{names, T}, offset::NamedTuple{names, T})
139-
140-
Translate (i.e. move) a [`Tile`](@ref) by a constant `offset`.
141-
142-
# Arguments
143-
- `tile`: The [`Tile`](@ref) to translate.
144-
- `offset`: The `offset` in each dimension.
145-
"""
146-
@inline function translate_const(tile::Tile{size, names, T}, offset::NamedTuple{names, T}) where {names, T, size}
147-
offset = map(+, tile.offset, offset)
148-
return Tile{size, names, T}(tile.base, offset)
149-
end
150-
151-
@inline translate_const(tile::Tile{size, names, T}, offset::Tuple) where {names, T, size} = translate_const(tile, NamedTuple{names}(offset))
152-
153135
# -------------
154136
# TileIterators
155137
# -------------

0 commit comments

Comments
 (0)