这题目就是做回溯搜索,但是关键在如何剪枝上。有效的剪枝可以从TLE到520ms. 所谓的剪枝,就是观察数据分布和程序的执行情况,看哪些地方可以提前返回。 剪枝要尽可能地减掉起始搜索的可能性,所以剪枝的好坏很大程度上在于搜索空间的顺序选择。
这道题目最主要的剪枝,就是范围检查。上了范围检查之后,能从TLE到2816ms. 比如对case: words = [“SEND”,”MORE”], result = “MONEY” 来说
- 如果S被安排到1, M安排在9的话
- SEND的取值范围在 [1000, 1999] (当然可以更小)
- MORE取值范围在 [9000, 9999]
- MONEY在 [90000, 99999]
从取值范围来看,很明显这个分布是不成立的,这个搜索空间可以立刻被剪去。
为了配合这种剪枝策略,搜索空间的顺序就要尽可能地满足靠前面的字母,以上面为例
- 第一轮选择 S, M, M
- 第二轮选择 E, O, O
- 第三轮选择 N, R, N
- 第四轮选择 D, E, E
- 最后选择 Y
所以我们搜索空间顺序可以是 [S, M, E, O, N, R, D, Y]. 这样可以确保开头的字母优先被分配到,可以更快地做范围检查。
其实还有一个优化是尾字母的和检查,比如 (“D” + “E”) % 10 = “Y”这样。但是尾字母的和检查和范围检查要求是相悖的, 而且都是在搜索空间的后期才能被使用上,能够被剪去的情况也不多,所以只能当做辅助检查。
class Solution:
def isSolvable(self, words: List[str], result: str) -> bool:
chars = set(result)
for w in words:
chars.update(w)
chars = list(chars)
chars.sort()
chars_to_idx = {c: i for i, c in enumerate(chars)}
# print(chars_to_idx)
words = [[chars_to_idx[x] for x in w] for w in words]
result = [chars_to_idx[x] for x in result]
lead = {w[0] for w in words}
lead.add(result[0])
# 每个字母的可选数字集合
mat = []
for i in range(len(chars)):
if i in lead:
mat.append(list(range(1, 10)))
else:
mat.append(list(range(10)))
# 优先挑选在开头的数字,这样可以通过范围判定是否可行
# 挑选顺序是从每个字符串开头选择一个
head = set()
orders = []
for p in range(7):
for w in words:
if p < len(w):
x = w[p]
if x not in head:
orders.append(x)
head.add(x)
if p < len(result):
x = result[p]
if x not in head:
orders.append(x)
head.add(x)
print(head, orders, result, words)
# print(orders, tail)
# for x in mat:
# print(x)
assert len(orders) == len(chars)
mapping = [-1] * 10
used = [0] * 10
def qc():
res = 0
for w in words:
if mapping[w[-1]] == -1:
return True
res += mapping[w[-1]]
if mapping[result[-1]] == -1:
return True
exp = mapping[result[-1]]
return res % 10 == exp
def to_int(w):
res = 0
for c in w:
res = res * 10 + mapping[c]
return res
def to_int_range(w):
res = 0
for (idx, c) in enumerate(w):
if mapping[c] != -1:
res = res * 10 + mapping[c]
else:
shift = (len(w) - idx)
base = 10 ** shift
return (res * base, (res + 1) * base - 1)
# note(yan): 下面这个优化还是不太好用,时间反而提升了200-400ms
# 这里如果做更加准确的估计可以缩小范围
# min_v, max_v = 0, 9
# base = 10 ** (shift - 1)
# for v in mat[c]:
# if used[v]:
# continue
# min_v = min(min_v, v)
# max_v = max(max_v, v)
# a = (res * 10 + min_v) * base
# b = (res * 10 + max_v + 1) * base - 1
# return (a, b)
return (res, res)
def range_check():
xs = [to_int_range(w) for w in words]
x0, x1 = sum([x[0] for x in xs]), sum([x[1] for x in xs])
ys = to_int_range(result)
y0, y1 = ys
if y1 < x0 or y0 > x1:
return False
return True
def test(i):
# if i == len(tail) and not qc():
# return False
if i == len(orders):
a = sum((to_int(w) for w in words))
b = to_int(result)
return a == b
# 对范围做检查. 现在所有字符的第一位数字都安排好了
# if i >= (len(words) + 1) and not range_check():
# return False
# note(yan): 不定等待所有数字都安排好就开始快速检查范围 2000ms->516ms.
if not range_check():
return False
# 针对结尾字符做检查
if not qc():
return False
x = orders[i]
if mapping[x] != -1:
if test(i + 1):
return True
else:
for v in mat[x]:
if not used[v]:
mapping[x] = v
used[v] = 1
if test(i + 1):
return True
used[v] = 0
mapping[x] = -1
return False
ans = test(0)
return ans