Skip to content

Latest commit

 

History

History
91 lines (85 loc) · 3.47 KB

378_有序矩阵中第K小的元素.org

File metadata and controls

91 lines (85 loc) · 3.47 KB

题目

Screen-Pictures/%E9%A2%98%E7%9B%AE/2020-06-11_14-46-26_%E6%88%AA%E5%B1%8F2020-06-11%20%E4%B8%8B%E5%8D%882.46.23.png

思路

  • 二分查找

Screen-Pictures/%E6%80%9D%E8%B7%AF/2020-06-11_14-47-20_%E6%88%AA%E5%B1%8F2020-06-11%20%E4%B8%8B%E5%8D%882.47.18.png

  • 优先队列:维护一个长度为k的优先队列(堆),遍历矩阵时,把元素加到队列中,直到队列长度>k,此时对队列调整排序,弹出最大的元素。可以对遍历过程进行优化:当k超过元素总数的一半时,改为从右下角反向遍历,求第 m*n-k+1 大的元素;并且由于每行有序,队列已满后一旦当前元素不可能进入队列,即可提前结束该行的遍历。

code

# Binary search over the value range.
# Fragment: the enclosing function header is not part of this snippet;
# `matrix` and `k` come from that scope.
# NOTE(review): the first statement below lost its leading indentation in the
# paste; it presumably belongs at the same level as the following lines.
n = len(matrix)
        # Search bounds are the top-left and bottom-right entries.
        # NOTE(review): uses n for both dimensions, so this assumes a square
        # matrix whose corners hold the minimum and maximum - confirm.
        lo = matrix[0][0]
        hi = matrix[n-1][n-1]
        while lo <= hi:
            # Midpoint written as lo + (hi - lo) // 2 rather than (lo + hi) // 2.
            mid = lo + (hi - lo) // 2
            # Count the entries that are <= mid.
            cnt = 0
            for i in range(n):
                for j in range(n):
                    # Stop scanning this row at the first entry above mid
                    # (presumably valid because each row is sorted ascending).
                    if matrix[i][j] > mid:
                        break
                    cnt += 1

            # Fewer than k entries are <= mid: the answer must be above mid.
            if cnt < k:
                lo = mid + 1
            # Otherwise mid is large enough; keep looking for a smaller bound.
            else:
                hi = mid - 1
        
        # On exit lo is the smallest value with at least k entries <= it.
        return lo

# Priority queue (heap) approach
def build_heap(root, k_list, end):
            """Sift the value at index ``root`` down within ``k_list[:end + 1]``.

            The value is repeatedly swapped with its larger child until neither
            child is larger or it has no child at an index <= ``end``.  The list
            is modified in place and nothing is returned.

            Args:
                root: Index of the element to sift down.
                k_list: List holding the implicit binary heap.
                end: Last index (inclusive) that is still part of the heap.
            """
            parent = root
            left = 2 * parent + 1
            while left <= end:
                # Pick the larger of the (up to two) children.
                larger = left
                right = left + 1
                if right <= end and k_list[left] < k_list[right]:
                    larger = right
                # Parent is not smaller than its larger child: nothing to do.
                if not k_list[parent] < k_list[larger]:
                    return
                k_list[parent], k_list[larger] = k_list[larger], k_list[parent]
                parent = larger
                left = 2 * parent + 1

        def heapsort(k_list):
            """Sort ``k_list`` in place in ascending order using a max-heap.

            First every internal node is sifted down (bottom-up) so the list
            becomes a max-heap; then the current maximum is swapped to the end
            of the shrinking heap until a single element is left.
            """
            last = len(k_list) - 1
            # Heapify: internal nodes are the indices below len // 2.
            for parent in reversed(range(len(k_list) // 2)):
                build_heap(parent, k_list, last)
            # Move the maximum behind the heap, then repair the smaller heap.
            for tail in range(last, 0, -1):
                k_list[0], k_list[tail] = k_list[tail], k_list[0]
                build_heap(0, k_list, tail - 1)

        # Main body of the priority-queue approach.
        # Fragment: the enclosing function header is not part of this snippet;
        # `matrix` and `k` come from that scope.
        m, n = len(matrix), len(matrix[0])
        # print(m, n)
        # k lies in the upper half: look for the (m*n - k + 1)-th largest value
        # instead, so fewer values have to be kept.
        if k > m*n //2:
            k = m*n - k + 1
            k_list = []
            # Walk the matrix backwards, starting at the bottom-right corner.
            for i in range(m-1, -1, -1):
                for j in range(n-1, -1, -1):
                    if k_list:
                        # With k values held, k_list is sorted ascending, so
                        # k_list[0] is the smallest kept value.  An entry below
                        # it cannot be kept; skip the rest of this row
                        # (presumably valid because each row is sorted).
                        if matrix[i][j] < k_list[0] and len(k_list)==k:
                            break
                    k_list.append(matrix[i][j])
                    # print(k_list)
                    # Re-sort (ascending) once at least k values are held.
                    if len(k_list)>=k:
                        heapsort(k_list)
                        # print(k_list)   
                    # Drop the smallest so only the k largest seen so far stay.
                    if len(k_list) > k:
                        k_list.pop(0)
                    # print(k_list)
            # Smallest of the kept k largest values.
            return k_list[0]
        else:
            k_list = []
            # Walk the matrix forwards, starting at the top-left corner.
            for i in range(m):
                for j in range(n):
                    if k_list:
                        # With k values held, k_list is sorted ascending, so
                        # k_list[-1] is the largest kept value.  An entry above
                        # it cannot be kept; skip the rest of this row
                        # (presumably valid because each row is sorted).
                        if matrix[i][j] > k_list[-1] and len(k_list)==k:
                            break
                    k_list.append(matrix[i][j])
                    # print(k_list)
                    # Re-sort (ascending) once at least k values are held.
                    if len(k_list)>=k:
                        heapsort(k_list)
                        # print(k_list)   
                    # Drop the largest so only the k smallest seen so far stay.
                    if len(k_list) > k:
                        k_list.pop()
                    # print(k_list)
            # Largest of the kept k smallest values.
            return k_list[-1]