Skip to content

Commit

Permalink
[test] Implement stricter Memory checks
Browse files Browse the repository at this point in the history
  • Loading branch information
herumi committed Oct 16, 2024
1 parent cf209c9 commit 22642cb
Showing 1 changed file with 59 additions and 17 deletions.
76 changes: 59 additions & 17 deletions test/test_by_xed.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def newReg(s):
return s

class Memory:
def __init__(self, size=0, base=None, index=None, scale=0, disp=0, broadcast=False):
def __init__(self, size=0, base=None, index=None, scale=0, disp=0, broadcast=0):
self.size = size
self.base = newReg(base)
self.index = newReg(index)
Expand All @@ -85,8 +85,12 @@ def __init__(self, size=0, base=None, index=None, scale=0, disp=0, broadcast=Fal
self.broadcast = broadcast

def __str__(self):
s = 'ptr' if self.size == 0 else g_sizeTbl[int(math.log2(self.size))]
if self.broadcast:
if self.size == 0:
s = 'ptr'
else:
idx = self.size * max(self.broadcast, 1)
s = g_sizeTbl[int(math.log2(idx))]
if self.broadcast > 0:
s += '_b'
s += ' ['
needPlus = False
Expand All @@ -107,23 +111,36 @@ def __str__(self):
s += ']'
return s

# Xbyak uses 'ptr' when it can be automatically detected, so we should consider this in the comparison.
def __eq__(self, rhs):
# xbyak uses ptr if it is automatically detected, so xword == ptr is true
if self.broadcast != rhs.broadcast: return False
# if not self.broadcast and 0 < self.size <= 8 and 0 < rhs.size <= 8 and self.size != rhs.size: return False
if not self.broadcast and self.size > 0 and rhs.size > 0 and self.size != rhs.size: return False
if self.broadcast > rhs.broadcast:
return rhs == self
assert(self.broadcast <= rhs.broadcast)
if self.broadcast == 0:
if rhs.broadcast > 0: return False
# Xbyak uses 'ptr' when it is automatically detected.
# Therefore, the comparison is true if 'ptr' (i.e., size = 0) is used.
if 0 < self.size and 0 < rhs.size and self.size != rhs.size: return False
if self.broadcast == 1: # _b
if rhs.broadcast == 1: # compare ptr_b with ptr_b
if self.size != rhs.size:
return False
if self.size > 0 and (self.size != rhs.size * rhs.broadcast): # compare ptr_b with {1toX}
return False
else:
if self.broadcast != rhs.broadcast: return False
r = self.base == rhs.base and self.index == rhs.index and self.scale == rhs.scale and self.disp == rhs.disp
return r

def parseBroadcast(s):
if '_b' in s:
return (s.replace('_b', ''), True)
r = re.search(r'({1to\d+})', s)
return (s.replace('_b', ''), 1)
r = re.search(r'({1to(\d+)})', s)
if not r:
return (s, False)
return (s.replace(r.group(1), ''), True)
return (s, 0)
return (s.replace(r.group(1), ''), int(r.group(2)))

def parseMemory(s, broadcast=False):
def parseMemory(s, broadcast=0):
org_s = s

s = s.replace(' ', '').lower()
Expand All @@ -133,7 +150,7 @@ def parseMemory(s, broadcast=False):
scale = 0
disp = 0

if not broadcast:
if broadcast == 0:
(s, broadcast) = parseBroadcast(s)

# Parse size
Expand All @@ -157,7 +174,7 @@ def parseMemory(s, broadcast=False):
s = s[3:]

if s.startswith('_b'):
broadcast = True
broadcast = 1
s = s[2:]

# Extract the content inside brackets
Expand Down Expand Up @@ -335,7 +352,7 @@ def parseMemoryTest():
('[]', Memory()),
('[rax]', Memory(0, rax)),
('ptr[rax]', Memory(0, rax)),
('ptr_b[rax]', Memory(0, rax, broadcast=True)),
('ptr_b[rax]', Memory(0, rax, broadcast=1)),
('dword[rbx]', Memory(4, rbx)),
('xword ptr[rcx]', Memory(16, rcx)),
('xmmword ptr[rcx]', Memory(16, rcx)),
Expand All @@ -344,11 +361,36 @@ def parseMemoryTest():
('[0x12345]', Memory(0, None, None, 0, 0x12345)),
('yword [rax+rdx*4]', Memory(32, rax, rdx, 4)),
('zword [rax+rdx*4+123]', Memory(64, rax, rdx, 4, 123)),
('xword_b [rax]', Memory(16, rax, None, 0, 0, 1)),
('dword [rax]{1to4}', Memory(16, rax, None, 0, 0, 1)),
('yword_b [rax]', Memory(32, rax, None, 0, 0, 1)),
('dword [rax]{1to8}', Memory(32, rax, None, 0, 0, 1)),
]
for (s, expected) in tbl:
my = parseMemory(s)
assertEqualStr(my, expected)

print('compare test')
tbl = [
('ptr[rax]', 'dword[rax]', True),
('byte[rax]', 'dword[rax]', False),
('yword_b[rax]', 'dword [rax]{1to8}', True),
('yword_b[rax]', 'word [rax]{1to16}', True),
('zword_b[rax]', 'word [rax]{1to32}', True),
('zword_b[rax]', 'word [rax]{1to16}', False),
('dword [rax]{1to2}', 'dword [rax] {1to4}', False),
('zword_b[rax]', 'xword_b [rax]', False),
('ptr_b[rax]', 'word [rax]{1to32}', True), # ignore size
]
for (lhs, rhs, eq) in tbl:
a = parseMemory(lhs)
b = parseMemory(rhs)
if eq:
assertEqual(a, b)
assertEqual(b, a)
else:
assert(parseMemory(lhs) != parseMemory(rhs))

def parseNmemonicTest():
print('parseNmemonicTest')
tbl = [
Expand All @@ -364,8 +406,8 @@ def parseNmemonicTest():
('vpcompressw(zmm30 | k2 |T_z, zmm1);', Nmemonic('vpcompressw', [zmm30, zmm1], [k2, T_z])),
('vpcompressw zmm30{k2}{z}, zmm1', Nmemonic('vpcompressw', [zmm30, zmm1], [k2, T_z])),
('vpshldw(xmm9|k3|T_z, xmm2, ptr [rax + 0x40], 5);', Nmemonic('vpshldw', [xmm9, xmm2, Memory(0, rax, None, 0, 0x40), 5], [k3, T_z])),
('vpshrdd(xmm5|k3|T_z, xmm2, ptr_b [rax + 0x40], 5);', Nmemonic('vpshrdd', [xmm5, xmm2, Memory(0, rax, None, 0, 0x40, True), 5], [k3, T_z])),
('vpshrdd xmm5{k3}{z}, xmm2, dword ptr [rax+0x40]{1to4}, 0x5', Nmemonic('vpshrdd', [xmm5, xmm2, Memory(0, rax, None, 0, 0x40, True), 5], [k3, T_z])),
('vpshrdd(xmm5|k3|T_z, xmm2, ptr_b [rax + 0x40], 5);', Nmemonic('vpshrdd', [xmm5, xmm2, Memory(0, rax, None, 0, 0x40, 1), 5], [k3, T_z])),
('vpshrdd xmm5{k3}{z}, xmm2, dword ptr [rax+0x40]{1to4}, 0x5', Nmemonic('vpshrdd', [xmm5, xmm2, Memory(0, rax, None, 0, 0x40, 4), 5], [k3, T_z])),
('vcmpph(k1, xmm15, ptr[rax+64], 1);', Nmemonic('vcmpph', [k1, xmm15, Memory(0, rax, None, 0, 64), 1])),
]
for (s, expected) in tbl:
Expand Down

0 comments on commit 22642cb

Please sign in to comment.