Skip to content

Commit

Permalink
Add umull to the N1 cost model and define more variants of simple int…
Browse files Browse the repository at this point in the history
…eger ops

This patch adds `umull_wform` (`pattern = "umull <Xd>, <Wa>, <Wb>"`) to the N1
cost model, and defines more variants of simple integer ops.
  • Loading branch information
aqjune-aws committed Jul 23, 2024
1 parent 443d222 commit 874b23f
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 0 deletions.
38 changes: 38 additions & 0 deletions slothy/targets/aarch64/aarch64_neon.py
Original file line number Diff line number Diff line change
Expand Up @@ -1665,12 +1665,32 @@ class neg(AArch64BasicArithmetic): # pylint: disable=missing-docstring,invalid-n
inputs = ["Xa"]
outputs = ["Xd"]

class negs(AArch64BasicArithmetic): # pylint: disable=missing-docstring,invalid-name
pattern = "negs <Xd>, <Xa>"
inputs = ["Xa"]
outputs = ["Xd"]
modifiesFlags=True

class ngc_zero(AArch64BasicArithmetic): # pylint: disable=missing-docstring,invalid-name
pattern = "ngc <Xd>, xzr"
inputs = []
outputs = ["Xd"]
dependsOnFlags=True

class ngcs(AArch64BasicArithmetic): # pylint: disable=missing-docstring,invalid-name
pattern = "ngcs <Xd>, <Xa>"
inputs = ["Xa"]
outputs = ["Xd"]
modifiesFlags=True
dependsOnFlags=True

class ngcs_zero(AArch64BasicArithmetic): # pylint: disable=missing-docstring,invalid-name
pattern = "ngcs <Xd>, xzr"
inputs = []
outputs = ["Xd"]
modifiesFlags=True
dependsOnFlags=True

class adds(AArch64BasicArithmetic): # pylint: disable=missing-docstring,invalid-name
pattern = "adds <Xd>, <Xa>, <imm>"
inputs = ["Xa"]
Expand Down Expand Up @@ -1720,6 +1740,13 @@ class sbcs_zero(AArch64BasicArithmetic): # pylint: disable=missing-docstring,inv
modifiesFlags=True
dependsOnFlags=True

class sbcs_to_zero(AArch64BasicArithmetic): # pylint: disable=missing-docstring,invalid-name
pattern = "sbcs xzr, <Xa>, <Xb>"
inputs = ["Xa", "Xb"]
outputs = []
modifiesFlags=True
dependsOnFlags=True

class sbcs_zero_to_zero(AArch64BasicArithmetic): # pylint: disable=missing-docstring,invalid-name
pattern = "sbcs xzr, <Xa>, xzr"
inputs = ["Xa"]
Expand Down Expand Up @@ -1980,6 +2007,12 @@ class csel_xzr_ne(AArch64ConditionalSelect): # pylint: disable=missing-docstring
outputs = ["Xd"]
dependsOnFlags=True

class csel_xzr2_ne(AArch64ConditionalSelect): # pylint: disable=missing-docstring,invalid-name
pattern = "csel <Xd>, xzr, <Xe>, <flag>"
inputs = ["Xe"]
outputs = ["Xd"]
dependsOnFlags=True

class csel_ne(AArch64ConditionalSelect): # pylint: disable=missing-docstring,invalid-name
pattern = "csel <Xd>, <Xe>, <Xf>, <flag>"
inputs = ["Xe", "Xf"]
Expand Down Expand Up @@ -2125,6 +2158,11 @@ class tst_xform(Tst): # pylint: disable=missing-docstring,invalid-name
inputs = ["Xa", "Xb"]
modifiesFlags=True

class cmp(Tst): # pylint: disable=missing-docstring,invalid-name
pattern = "cmp <Xa>, <Xb>"
inputs = ["Xa","Xb"]
modifiesFlags=True

class cmp_xzr(Tst): # pylint: disable=missing-docstring,invalid-name
pattern = "cmp <Xa>, xzr"
inputs = ["Xa"]
Expand Down
3 changes: 3 additions & 0 deletions slothy/targets/aarch64/neoverse_n1_experimental.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ def get_min_max_objective(slothy):
Tst : ExecutionUnit.I(),
AArch64ShiftedArithmetic : ExecutionUnit.M(),
Fmov : ExecutionUnit.M(),
umull_wform : ExecutionUnit.M(),
(AArch64HighMultiply,
AArch64Multiply) : ExecutionUnit.M(),
vdup : ExecutionUnit.M(),
Expand Down Expand Up @@ -158,6 +159,7 @@ def get_min_max_objective(slothy):
(AArch64HighMultiply) : 4,
(AArch64Multiply) : 3,
(vdup) : 1,
umull_wform : 1,
}

default_latencies = {
Expand Down Expand Up @@ -193,6 +195,7 @@ def get_min_max_objective(slothy):
AArch64HighMultiply : 5,
AArch64Multiply : 4,
(vdup) : 3,
umull_wform : 2,
}

def get_latency(src, out_idx, dst):
Expand Down

0 comments on commit 874b23f

Please sign in to comment.