Skip to content

Commit

Permalink
hotfix: Assembly syntax and coverage for DirectToLds has changed (#1598
Browse files Browse the repository at this point in the history
…) (#1610)
  • Loading branch information
AlexBrownAMD authored Oct 29, 2022
1 parent b33ca97 commit 006a5d6
Show file tree
Hide file tree
Showing 8 changed files with 155 additions and 118 deletions.
4 changes: 2 additions & 2 deletions .jenkins/common.groovy
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def runCompileCommand(platform, project, jobName, boolean debug=false)
project.paths.construct_build_prefix()

String compiler = 'hipcc'
String pythonVersion = 'py36'
String pythonVersion = 'py3'
String cov = "V3"
String buildType = debug ? 'Debug' : 'RelWithDebInfo'
String parallelJobs = "export HIPCC_COMPILE_FLAGS_APPEND=-parallel-jobs=2"
Expand Down Expand Up @@ -110,7 +110,7 @@ def runTestCommand (platform, project, jobName, test_marks, boolean skipHostTest
def test_dir = "Tensile/Tests"

String compiler = 'hipcc'
String pythonVersion = 'py36'
String pythonVersion = 'py3'
String markSkipHostTest = skipHostTest ? "#" : ""
String markSkipExtendedTest = !test_marks.contains("extended") ? "--gtest_filter=-\"*Extended*\"" : ""

Expand Down
177 changes: 91 additions & 86 deletions Tensile/Common.py

Large diffs are not rendered by default.

81 changes: 52 additions & 29 deletions Tensile/KernelWriterAssembly.py
Original file line number Diff line number Diff line change
Expand Up @@ -743,7 +743,10 @@ def initKernel(self, kernel, tPA, tPB ):
self.AsmBugs["ExplicitCO"] = globalParameters["AsmCaps"][self.version]["HasExplicitCO"]
self.AsmBugs["ExplicitNC"] = globalParameters["AsmCaps"][self.version]["HasExplicitNC"]

if not globalParameters["AsmCaps"][self.version]["HasDirectToLds"]:
hasDtl = globalParameters["AsmCaps"][self.version]["HasDirectToLdsDest"] or globalParameters["AsmCaps"][self.version]["HasDirectToLdsNoDest"]
if not hasDtl:
if kernel["DirectToLds"]:
printExit("DirectToLds requested, but not available on this architecture ( {} )".format(self.version))
kernel["DirectToLdsA"] = False
kernel["DirectToLdsB"] = False
kernel["LocalWriteUseSgprA"] = False # Requires DirectToLdsA
Expand Down Expand Up @@ -2223,6 +2226,16 @@ def defineBufferMemoryMacros(self):
replace = f'{type_list[t]}' if (self.version[0] < 11) else f'{t}'
kStr += self.generalMacro('buffer_load_', origin, replace, 'dst', 'voffset', 'base', 'soffset', 'offen', 'ioffset', 'md0', 'md1', 'md2') + self.endLine

# Extra macro for DirectToLds loads with no destination register
type_list = {
'b32' : 'dword',
'u16' : 'ushort'
}
for t in type_list:
origin = f'{t}'
replace = f'{type_list[t]}' if (self.version[0] < 11) else f'{t}'
kStr += self.generalMacro('buffer_load_', origin + '_dtl', replace, 'voffset', 'base', 'soffset', 'offen', 'ioffset', 'md0', 'md1', 'md2') + self.endLine

type_list = {
'b32' : 'dword',
'b64' : 'dwordx2',
Expand Down Expand Up @@ -7294,8 +7307,11 @@ def globalReadGuardK(self, kernel, tP, vregSetIdx):
extraFields += " glc"
if tP["NonTemporal"]//2==1:
extraFields += " slc"
dtlNoDestVgpr = False
if kernel["DirectToLds%s"%tc]:
extraFields += " lds"
dtlNoDestVgpr = globalParameters["AsmCaps"][self.version]["HasDirectToLdsNoDest"]


directToLdsLoads = 0
prevLdsOffset = 0
Expand Down Expand Up @@ -7503,6 +7519,7 @@ def globalReadGuardK(self, kernel, tP, vregSetIdx):
addr0=vgpr(offsetVgpr), addr1=sgpr("Srd%s"%tc, 4), \
soffset=soffset, offset=offset, \
extraFields=extraFields, \
dtlNoDestVgpr=dtlNoDestVgpr, \
hi16=hi16, \
comment=comment).toStr()

Expand All @@ -7529,6 +7546,7 @@ def globalReadGuardK(self, kernel, tP, vregSetIdx):
addr0=vgpr("GlobalReadAddr%s+%u"%(tc,graIdx),2), addr1="", \
soffset=0, offset=0, \
extraFields=extraFields, \
dtlNoDestVgpr=dtlNoDestVgpr, \
hi16=hi16, \
comment="load one flat value").toStr()

Expand Down Expand Up @@ -7793,8 +7811,10 @@ def globalReadDo(self, kernel, mode, tP, vregSetIdx=0):
extraFields += " glc"
if tP["NonTemporal"]//2==1:
extraFields += " slc"
dtlNoDestVgpr = False
if kernel["DirectToLds%s"%tc]:
extraFields += " lds"
dtlNoDestVgpr = globalParameters["AsmCaps"][self.version]["HasDirectToLdsNoDest"]

directToLdsLoads = 0
instOffset = 0
Expand Down Expand Up @@ -7900,6 +7920,7 @@ def globalReadDo(self, kernel, mode, tP, vregSetIdx=0):
addr0=vgpr(offsetVgpr), addr1=sgpr("Srd%s"%tc, 4), \
soffset=soffset, offset=instOffset, \
extraFields=extraFields, \
dtlNoDestVgpr=dtlNoDestVgpr, \
hi16=(kernel["ProblemType"]["DataType"].isHalf() or kernel["ProblemType"]["DataType"].isBFloat16()) and loopCnt%2==1, \
comment="G -> Reg %u_%u_%u_%u"%(para, sPara, perp, sPerp)))

Expand All @@ -7925,6 +7946,7 @@ def globalReadDo(self, kernel, mode, tP, vregSetIdx=0):
addr0=vgpr("GlobalReadAddr%s+%u"%(tc,graIdx),2), addr1="", \
soffset=0, offset=0, \
extraFields=extraFields, \
dtlNoDestVgpr=dtlNoDestVgpr, \
hi16=(kernel["ProblemType"]["DataType"].isHalf() or kernel["ProblemType"]["DataType"].isBFloat16()) and loopCnt%2==1, \
comment="G -> Reg %u_%u_%u_%u"%(para, sPara, perp, sPerp )))

Expand Down Expand Up @@ -10982,7 +11004,7 @@ def globalWriteElements(self, kernel, vectorWidths, elements,
# bpl = bytes per load op
##############################################################################
def chooseGlobalRead(self, useBuffer, bpl, destVgpr, \
addr0, addr1, soffset, offset, extraFields, hi16=0, comment="load C"):
addr0, addr1, soffset, offset, extraFields, dtlNoDestVgpr, hi16=0, comment="load C"):
# rpv = regs per vector
rpv = bpl/4.0

Expand All @@ -11000,34 +11022,25 @@ def chooseGlobalRead(self, useBuffer, bpl, destVgpr, \
assert 0, "offset too large and soffset set"
if extraFields != "":
tailFields += ", %s"% extraFields
globalReadInst = None
if bpl==1 and hi16:
rv.addCode(Code.GlobalReadInst("_buffer_load_d16_hi_u8", vgpr(destVgpr, rpv*4), addr0, \
addr1, soffset, tailFields, comment))
return rv
globalReadInst = "_buffer_load_d16_hi_u8"
rpv *= 4
elif bpl==1 and not hi16:
rv.addCode(Code.GlobalReadInst("_buffer_load_d16_u8", vgpr(destVgpr, rpv*4), addr0, \
addr1, soffset, tailFields, comment))
return rv
globalReadInst = "_buffer_load_d16_u8"
rpv *= 4
elif bpl==2 and hi16:
rv.addCode(Code.GlobalReadInst("_buffer_load_d16_hi_b16", vgpr(destVgpr, rpv*2), addr0, \
addr1, soffset, tailFields, comment))
return rv
globalReadInst = "_buffer_load_d16_hi_b16"
rpv *= 2
elif bpl==2 and not hi16:
rv.addCode(Code.GlobalReadInst("_buffer_load_d16_b16", vgpr(destVgpr, rpv*2), addr0, \
addr1, soffset, tailFields, comment))
return rv
globalReadInst = "_buffer_load_d16_b16"
rpv *= 2
elif bpl==4:
rv.addCode(Code.GlobalReadInst("_buffer_load_b32", vgpr(destVgpr, rpv), addr0, \
addr1, soffset, tailFields, comment))
return rv
globalReadInst = "_buffer_load_b32"
elif bpl==8:
rv.addCode(Code.GlobalReadInst("_buffer_load_b64", vgpr(destVgpr, rpv), addr0, \
addr1, soffset, tailFields, comment))
return rv
globalReadInst = "_buffer_load_b64"
elif bpl==16:
rv.addCode(Code.GlobalReadInst("_buffer_load_b128", vgpr(destVgpr, rpv), addr0, \
addr1, soffset, tailFields, comment))
return rv
globalReadInst = "_buffer_load_b128"
elif bpl==32:
# split into two dwordx4 loads. Second load offset is +0.5 bpl
tailFields1 = "offen offset:%u"%(offset + bpl/2)
Expand All @@ -11039,9 +11052,17 @@ def chooseGlobalRead(self, useBuffer, bpl, destVgpr, \
addr1, soffset, tailFields, comment))
rv.addCode(Code.GlobalReadInst("_buffer_load_b128", vgpr(int(destVgpr + rpv/2), rpv/2), addr0, \
addr1, soffset, tailFields1, comment))
return rv
else:
assert 0, "chooseGlobalRead: bad bpl"

if dtlNoDestVgpr:
globalReadInst += "_dtl"
args = [globalReadInst]
if not dtlNoDestVgpr:
args.append(vgpr(destVgpr, rpv))
args.extend([addr0, addr1, soffset, tailFields, comment])
rv.addCode(Code.GlobalReadInst(*args))
return rv

else:
Expand Down Expand Up @@ -11310,7 +11331,7 @@ def readCInput(self, kernel, ss, addrCalc, vc0, data, gwvw, addr, tmpS01):
if kernel["ProblemType"]["DestDataType"].isHalf():
kStr += self.chooseGlobalRead(useBuffer, bps, data, \
addr0, addr1, soffset=0, offset=addrCalc.globalOffset, \
extraFields=extraStr, hi16=vc0 % 2,
extraFields=extraStr, dtlNoDestVgpr=False, hi16=vc0 % 2,
comment="load C for beta calc").toStr()
elif kernel["ProblemType"]["DestDataType"].isBFloat16() or \
kernel["ProblemType"]["DestDataType"].isInt32() or \
Expand All @@ -11321,6 +11342,7 @@ def readCInput(self, kernel, ss, addrCalc, vc0, data, gwvw, addr, tmpS01):
kStr += self.chooseGlobalRead(useBuffer, bps, data, \
addr0, addr1, soffset=0, offset=addrCalc.globalOffset, \
extraFields=extraStr, \
dtlNoDestVgpr=False, \
comment="load C for beta calc").toStr()

return kStr
Expand Down Expand Up @@ -11501,6 +11523,7 @@ def globalWriteBatch(self, kernel, ss, batchIdx, applyAlpha, beta, edge, atomic,
vgprIdx = 1*(bpm//4)
kStr += self.chooseGlobalRead(useBuffer, bpm, dataV+vgprIdx, \
addr0, addr1, soffset=0, offset=addrCalc.globalOffset, extraFields="",
dtlNoDestVgpr=False, \
comment="load D (atomic) bpm=%u vaw=%u"%(bpm,atomicW)).toStr()

if kernel["InterleaveAlpha"] and applyAlpha:
Expand Down Expand Up @@ -12172,19 +12195,19 @@ def globalWriteBatch(self, kernel, ss, batchIdx, applyAlpha, beta, edge, atomic,
if kernel["ProblemType"]["DestDataType"].isHalf() or kernel["ProblemType"]["DestDataType"].isBFloat16():
if not kernel["ProblemType"]["HighPrecisionAccumulate"]:
kStr += self.chooseGlobalRead(useBuffer, bps, sumIdx//2, \
addr0, addr1, soffset=0, offset=0, extraFields="", hi16=sumIdx%2).toStr()
addr0, addr1, soffset=0, offset=0, extraFields="", dtlNoDestVgpr=False, hi16=sumIdx%2).toStr()
else:
kStr += self.chooseGlobalRead(useBuffer, bps, sumIdx, \
addr0, addr1, soffset=0, offset=0, extraFields="", hi16=0).toStr()
addr0, addr1, soffset=0, offset=0, extraFields="", dtlNoDestVgpr=False, hi16=0).toStr()
elif kernel["ProblemType"]["DestDataType"].isInt32() or kernel["ProblemType"]["DestDataType"].isSingle():
kStr += self.chooseGlobalRead(useBuffer, bps, sumIdx, \
addr0, addr1, soffset=0, offset=0, extraFields="").toStr()
addr0, addr1, soffset=0, offset=0, extraFields="", dtlNoDestVgpr=False).toStr()
elif kernel["ProblemType"]["DestDataType"].isDouble() or kernel["ProblemType"]["DestDataType"].isSingleComplex() :
kStr += self.chooseGlobalRead(useBuffer, bps, sumIdx*2, \
addr0, addr1, soffset=0, offset=0, extraFields="").toStr()
addr0, addr1, soffset=0, offset=0, extraFields="", dtlNoDestVgpr=False).toStr()
elif kernel["ProblemType"]["DestDataType"].isDoubleComplex():
kStr += self.chooseGlobalRead(useBuffer, bps, sumIdx*4, \
addr0, addr1, soffset=0, offset=0, extraFields="").toStr()
addr0, addr1, soffset=0, offset=0, extraFields="", dtlNoDestVgpr=False).toStr()
kStr += inst("s_waitcnt", "vmcnt(0)", "CheckStoreC, wait for stores to complete" )
if self.archCaps["SeparateVscnt"]:
kStr += inst("s_waitcnt_vscnt", "null", "0", "writes")
Expand Down
8 changes: 8 additions & 0 deletions Tensile/SolutionStructs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2435,6 +2435,11 @@ def isDirectToLdsDoable(state, tc):
#TN
# use for all precisions with TransposeLDS=1

numRegisters = state["ProblemType"]["DataType"].numRegisters()
if numRegisters * state["GlobalLoadVectorWidth%c"%tc] != 1:
reject(state, "DirectToLds can only be used with buffer loads requiring 1 register")
return False

if state["ProblemType"]["DataType"].isHalf():
if state["AssertSummationElementMultiple"] % (2 * state["GlobalLoadVectorWidth%c"%tc]) != 0:
reject(state, "can't use DirectToLds for FP16 with AssertSummationElementMultiple %u" % state["AssertSummationElementMultiple"])
Expand Down Expand Up @@ -3437,6 +3442,9 @@ def assignDerivedParameters(state):
state["DirectToLdsB"] = True
state["LocalWriteUseSgprB"] = True
#print("DirectToLdsB", state["DirectToLdsB"])

if state["Valid"] and state["DirectToLds"] and not (state["DirectToLdsA"] or state["DirectToLdsB"]):
printWarning("DirectToLds requested, but not enabled for A or B, check kernel configuration!")

# Update parent variable so kernel display is accurate
state["DirectToLds"] = state["DirectToLdsA"] or state["DirectToLdsB"]
Expand Down
3 changes: 2 additions & 1 deletion tox.ini
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
[tox]
envlist = py35,py36,py27,lint
envlist = py35,py36,py38,py27,lint


[testenv]
# Some versions of Pytest versions have a bug:
Expand Down

0 comments on commit 006a5d6

Please sign in to comment.