从一道面试题到单调队列优化 DP

前情提要

一个朋友询问我一道某大厂的面试题,看到题目的第一眼想到了 $\mathrm O(n\times k)$ 的动态规划算法来解决,朋友反馈说记得 $n$ 和 $k$ 至少 $10^5$,后来发现只要维护最后 $k$ 个元素的最大值即可,直接塞到 set 里面就能优化到 $\mathrm O(n\log n)$ 或 $\mathrm O(n\log k)$,应该可以 cover 掉数据范围,朋友也对复杂度满意了。但我总觉得可以 $\mathrm O(n)$ 线性解决,在为朋友写代码时,突然隐隐约约觉得单调队列似乎就是这样子,去 OI-Wiki 查询果然是这样,遂写此篇题解博客记录不断思考优化算法复杂度的心路历程。

题目大意

给定长度为 $n$ 一个数组,有正有负的数。问你在其中取哪些数,可以使得这些数的和最大。但是条件是,每 $k$ 个连续的数,都至少要取一个出来。输入是$n$、$k$、数组;输出是最大的这个和。

样例

输入

5 3
-4 -100 -9 -100 -4

输出

-9

题解

令 $dp[i]$ 表示取了第 $i$ 个数的情况下,取到的数的最大值。则状态转移方程为:
$$
dp[i] = \max_{j \in [1, \min{(k, i)}]}(dp[i - j]) + arr[i]
$$
最终答案为:
$$
\max_{j\in[0, k - 1]}(dp[n - j])
$$

直接按照数学公式计算即可得到答案,时间复杂度为 $\mathrm O(n\times k)$。

我们注意到,我们需要一直维护数组中最后 $k$ 个元素的最大值,这些最大值具有连续性($i$ 每次增加,都需要删掉一个元素,然后增加一个新元素),因此可以使用增加、删除以及查询最大值复杂度均为 $\mathrm O(\log n)$ 的数据结构来维护。

C++ 中我们可以使用 STL 中基于红黑树的 set 容器来维护,因为容器中最多有 $k$ 个元素,所以时间复杂度为 $\mathrm O(\log k)$;同理,在 Python 中我们可以使用基于小根堆的 heapq 优先队列来维护,因为队列中最多有 $n$ 个元素(元素均为非负数),所以时间复杂度为 $\mathrm O(\log n)$。

我们注意到,在优先队列中,其实某些元素永远不会出队,而这些元素共同的性质便是比队头早入队 $k$ 个或以上且比队头的元素小。那么我们能不能在当前队头元素入队时,将上述元素都踢出队列呢?优先队列显然是做不到的:因为上述元素一般集中在堆底。但如果我们尝试换成普通的队列呢?先不管如何查询最值,每次遇到一个元素,我们如果发现队尾元素比该元素更小,那么其实队尾元素在未来永远也用不到了(因为会比该元素早离开最后 $k$ 个元素的范围),不断重复这个过程直到遇到一个比该元素大的队尾,此时再将该元素插入队尾——这时候让我们回过头来看,会惊奇地发现,整个队列中的元素竟然是单调递减的!那么想要查询最大值只要从队头开始找属于最后 $k$ 个元素范围内的第一个元素即可。与此同时我们注意到,如果队头元素已经超出最后 $k$ 个元素的范围,那么该元素未来也不可能会用到了,所以同样可以将符合这一条件的队头元素踢出队列。于是乎,我们总是可以直接查询队头元素来得到当前最后 $k$ 个元素的最大值。因为这个过程中,每个元素只进队和出队一次,而查询最值也是 $\mathrm O(1)$ 的复杂度,因此总时间复杂度为 $\mathrm O(n)$。这种允许在一端插入元素、两端都删除元素的数据结构叫做双端队列 deque,在 Python 中位于标准库 collections 中,在 C++ 中位于 STL 中。而本题中的双端队列永远保持单调递减,所以全程都具有这种单调性质的队列叫做“单调队列”,因而这种对于动态规划的优化方法叫做“单调队列优化”。

代码

  • $\mathrm O(n\times k)$ 解法:

    def solve(arr, n, k):
        dp = [0]
        for i in range(1, n + 1):
            dp.append(max(dp[i - min(k, i):]) + arr[i - 1])
            # 也可以像下面这样写:
            # dp.append(dp[i - 1])
            # for j in range(2, min(k, i) + 1):
            #     dp[i] = max(dp[i], dp[i - j])
            # dp[i] += arr[i - 1]
        return max(dp[n - min(k, n - 1):])
    if __name__ == '__main__':
        n, k = map(int, input().split())
        arr = list(map(int, input().split()))
        print(solve(arr, n, k))
    
  • $\mathrm O(n\times \log n)$ 解法:

    import heapq
    def solve(arr, n, k):
        Q = [(0, 0)]
        for i in range(1, n + 1):
            heapq.heappush(Q, (Q[0][0] - arr[i - 1], i))
            while Q[0][1] <= i - k: heapq.heappop(Q)
        return -Q[0][0]
    if __name__ == '__main__':
        n, k = map(int, input().split())
        arr = list(map(int, input().split()))
        print(solve(arr, n, k))
    
  • $\mathrm O(n)$ 解法:
    保留 $dp$ 数组:

    from collections import deque
    def solve(arr, n, k):
        dp = [0]
        q = deque([0])
        for i in range(1, n + 1):
            dp.append(dp[q[0]] + arr[i - 1])
            while q and dp[q[-1]] <= dp[i]: q.pop() # 将比当前值小的元素全部弹出
            q.append(i)
            while q[0] <= i - k: q.popleft() # 判断队首元素是否在窗口内
        return dp[q[0]]
    if __name__ == '__main__':
        n, k = map(int, input().split())
        arr = list(map(int, input().split()))
        print(solve(arr, n, k))
    

    不保留 $dp$ 数组:

    from collections import deque
    def solve(arr, n, k):
        q = deque([(0, 0)])
        for i in range(1, n + 1):
            t = q[0][0] + arr[i - 1]
            while q and q[-1][0] <= t: q.pop() # 将比当前值小的元素全部弹出
            q.append((t, i))
            while q[0][1] <= i - k: q.popleft() # 判断队首元素是否在窗口内
        return q[0][0]
    if __name__ == '__main__':
        n, k = map(int, input().split())
        arr = list(map(int, input().split()))
        print(solve(arr, n, k))