Skip to content

Latest commit

 

History

History
87 lines (70 loc) · 2.71 KB

lc-2157-groups-of-strings.org

File metadata and controls

87 lines (70 loc) · 2.71 KB

LC 2157. 字符串分组

https://leetcode-cn.com/problems/groups-of-strings/

如果按照修改状态去枚举的话,那么修改状态可能会多达 26*26 种,如果乘以N = 2*10^4的话,那么肯定会出现超时。我花了很多时间在优化这个操作上,但是方向却是错误的。

题解里面这个解法非常不错 https://leetcode-cn.com/problems/groups-of-strings/solution/jiang-ti-huan-cao-zuo-de-fu-za-du-you-o2-951t/ 大致思想是:

  • 对于增加和删除一个字符操作,可以使用 `st ^ (1<<i)` 来完成
  • 对于更换一个字符操作,可以可以使用 `st ^ (1<<i) | (1 << 26)`, 相当于做个标记
  • 这个标记意味着可以替换任何字符串,类似 “*bc” 这样

然后就是使用并集查找的数据结构,不过这里稍微有点特殊的是,对于更换一个字符操作可能产生的状态是不确定的。

leetcode上面总是有这些好题,工程量不是那么大,对于基础知识要求也不是特别高,但实现的时候需要想想和稍微绕点弯子。

class UnionFind:
    def __init__(self, values):
        r, c, = {}, {}
        for v in values:
            r[v], c[v] = v, 1
        self.r, self.c = r, c

    def size(self, a):
        ra = self.find(a)
        return self.c[ra]

    def find(self, a):
        if a not in self.r:
            self.r[a] = a
            self.c[a] = 1
            return a

        # find root.
        x = a
        while True:
            ra = self.r[x]
            if ra == x:
                break
            x = ra

        # compress path.
        x = a
        while x != ra:
            rx = self.r[x]
            self.r[x] = ra
            x = rx
        return ra

    def merge(self, a, b):
        ra, rb = self.find(a), self.find(b)
        if ra == rb:
            return
        ca, cb = self.c[ra], self.c[rb]
        if ca > cb:
            ca, cb, ra, rb = cb, ca, rb, ra
        self.r[ra] = rb
        self.c[rb] += ca


class Solution:
    def groupStrings(self, words: List[str]) -> List[int]:
        def value(w):
            st = 0
            for c in w:
                c2 = ord(c) - ord('a')
                st = st | (1 << c2)
            return st

        values = [value(w) for w in words]
        un = UnionFind(values)
        tmp = set(values)

        for st in values:
            for i in range(26):
                st2 = st ^ (1 << i)
                if st2 in tmp:
                    un.merge(st, st2)

            for i in range(26):
                if st & (1 << i):
                    st2 = (1 << 26) | (st ^ (1 << i))
                    un.merge(st, st2)

        d = Counter(un.find(st) for st in values)
        return [len(d), max(d.values())]