Skip to content

Commit

Permalink
make shifts handle longer displacements
Browse files Browse the repository at this point in the history
  • Loading branch information
jcosborn committed Jan 10, 2025
1 parent 5408a96 commit ec31644
Show file tree
Hide file tree
Showing 8 changed files with 469 additions and 201 deletions.
17 changes: 17 additions & 0 deletions src/base/threading.nim
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,23 @@ macro tFor*(index: untyped; slice: Slice; body: untyped): untyped =
i1 = slice[2]
result = tForX(index, i0, i1, body)

proc adjust[T](i: int, x: openarray[T], b: int): int =
result = i
if i > 0:
let n = x.len
while result<n:
if x[result-1] div b != x[result] div b:
break
inc result

proc splitThreads*[T](x: openarray[T], b: int, nt,myt: int): tuple[a:int,b:int] =
let n = x.len
var i0 = (n*myt) div nt;
var i1 = (n*(myt+1)) div nt
i0 = adjust(i0, x, b)
i1 = adjust(i1, x, b)
result = (i0,i1)

discard """
iterator `.|`*[S, T](a: S, b: T): T {.inline.} =
mixin threadNum
Expand Down
17 changes: 0 additions & 17 deletions src/comms/gather.nim
Original file line number Diff line number Diff line change
Expand Up @@ -9,23 +9,6 @@ template newSeqOfCap[T](x: var seq[T], n: int) =
template newSeqUninitialized[T](x: var seq[T], n: int) =
x = newSeqUninitialized[T](n)

proc adjust[T](i: int, x: openarray[T], b: int): int =
result = i
if i > 0:
let n = x.len
while result<n:
if x[result-1] div b != x[result] div b:
break
inc result

proc splitThreads[T](x: openarray[T], b: int, nt,myt: int): tuple[a:int,b:int] =
let n = x.len
var i0 = (n*myt) div nt;
var i1 = (n*(myt+1)) div nt
i0 = adjust(i0, x, b)
i1 = adjust(i1, x, b)
result = (i0,i1)

type
RecvList* = object
didx*: int32 # destination index on this rank
Expand Down
27 changes: 21 additions & 6 deletions src/layout/layoutTypes.nim
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,14 @@ type
beginOuter*: int
endOuter*: int

type
SendSite* = object
maskidx* {.bitsize:2.}: uint32
site* {.bitsize:30.}: uint32
RecvIdx* = object
maskidx* {.bitsize:1.}: uint32
idx* {.bitsize:31.}: uint32

# FIXME: check int sizes
type ShiftIndicesQ* = object
gi*: ptr GatherIndices
Expand All @@ -59,28 +67,35 @@ type ShiftIndicesQ* = object
recvRemoteSrcs*: ptr cArray[cint]
nSendRanks*: cint
sendRanks*: ptr cArray[cint]
sendRankSizes*: ptr cArray[cint]
sendRankSizes*: ptr cArray[int]
sendRankSizes1*: ptr cArray[cint]
sendRankOffsets*: ptr cArray[cint]
sendRankOffsets1*: ptr cArray[cint]
nSendSites*: cint
nSendSites*: int
nSendSites1*: cint
#sendSites*: ptr cArray[cint]
sendSites*: seq[int32]
#sendSites*: seq[int32]
sendSites*: seq[SendSite]
recvIndex*: seq[RecvIdx]
vv*: cint
perm*: cint
pack*: cint
blend*: cint
packmasks*: array[2,int]
packbits*: array[2,int]
packmasks*: array[4,int] # bit mask for indices in innerGeom to send
packbits*: array[4,int] # number of bits in packmasks
sbufcount*: seq[int32] # index offset in send buffer for sendSite
lbufcount*: seq[int32] # index offset in local buffer for sendSite
recvmasks*: array[2,int] # bit mask for indices in innerGeom that are received
recvbits*: array[2,int] # number of bits in recvmasks

type ShiftIndices* = ref object
sq*: ShiftIndicesQ
nRecvRanks*: int
nRecvDests*: int
nSendRanks*: int
nSendSites*: int
sendSites*: seq[int32]
#sendSites*: seq[int32]
sendSites*: seq[SendSite]
perm*: int
pack*: int
blend*: int
Expand Down
Loading

0 comments on commit ec31644

Please sign in to comment.