算法模式:回溯

算法模式:回溯

在上一篇文章 算法模式:变治法 介绍一种有魔力的,可以将复杂问题化繁为简,化腐朽为神奇的算法模式:变治法。本篇文章,介绍一种“一步三回头”、“落棋有悔”的算法模式:回溯。

回溯

“回溯”算法也叫“回溯搜索”算法,主要用于在一个庞大的空间里搜索我们所需要的问题的解。我们每天使用的“搜索引擎”就是帮助我们在庞大的互联网上搜索我们需要的信息。“搜索”引擎的“搜索”和“回溯搜索”算法的“搜索”意思是一样的。

“回溯”指的是“状态重置”,可以理解为“回到过去”、“恢复现场”,是在编码的过程中,是为了节约空间而使用的一种技巧。而回溯其实是“深度优先遍历”特有的一种现象。之所以是“深度优先遍历”,是因为我们要解决的问题通常是在一棵树上完成的,在这棵树上搜索需要的答案,一般使用深度优先遍历。

“全排列”就是一个非常经典的“回溯”算法的应用。我们知道,N 个数字的全排列一共有 \$N!\$ 这么多个。

使用编程的方法得到全排列,就是在这样的一个树形结构中进行编程,具体来说,就是执行一次深度优先遍历,从树的根结点到叶子结点形成的路径就是一个全排列。

0046 01

说明:

  1. 每一个结点表示了“全排列”问题求解的不同阶段,这些阶段通过变量的“不同的值”体现;

  2. 这些变量的不同的值,也称之为“状态”;

  3. 使用深度优先遍历有“回头”的过程,在“回头”以后,状态变量需要设置成为和先前一样;

  4. 因此在回到上一层结点的过程中,需要撤销上一次选择,这个操作也称之为“状态重置”;

  5. 深度优先遍历,可以直接借助系统栈空间,为我们保存所需要的状态变量,在编码中只需要注意遍历到相应的结点的时候,状态变量的值是正确的,具体的做法是:往下走一层的时候,path 变量在尾部追加,而往回走的时候,需要撤销上一次的选择,也是在尾部操作,因此 path 变量是一个栈。

  6. 深度优先遍历通过“回溯”操作,实现了全局使用一份状态变量的效果。

解决一个回溯问题,实际上就是一个决策树的遍历过程。只需要思考 3 个问题:

  1. 路径:也就是已经做出的选择。

  2. 选择列表:也就是你当前可以做的选择。

  3. 结束条件:也就是到达决策树底层,无法再做选择的条件。

这三个问题也就对应回溯三部曲:

  1. 定义递归函数以及参数

  2. 确定递归终止条件

  3. 思考递归单层搜索逻辑

代码方面,回溯算法的框架:

result = []
def backtrack(路径, 选择列表):
    if 满足结束条件:
        result.add(路径)
        return

    for 选择 in 选择列表:
        做选择
        backtrack(路径, 选择列表)
        撤销选择

其核心就是 for 循环里面的递归,在递归调用之前「做选择」,在递归调用之后「撤销选择」,特别简单。

必须说明的是,不管怎么优化,都符合回溯框架,而且时间复杂度都不可能低于 \$O(N!)\$,因为穷举整棵决策树是无法避免的。这也是回溯算法的一个特点,不像动态规划存在重叠子问题可以优化,回溯算法就是纯暴力穷举,复杂度一般都很高。

玩回溯,一定要画出递归调用树。这样可以帮助我们更深入地理解整个回溯的过程,方便进一步剪枝优化。

回溯优化,重要的是,要学会剪枝!

LeetCode 46. 全排列

给定一个不含重复数字的数组 nums ,返回其 所有可能的全排列 。你可以 按任意顺序 返回答案。

示例 1:

输入:nums = [1,2,3]
输出:[[1,2,3],[1,3,2],[2,1,3],[2,3,1],[3,1,2],[3,2,1]]

示例 2:

输入:nums = [0,1]
输出:[[0,1],[1,0]]

示例 3:

输入:nums = [1]
输出:[[1]]

提示:

  • 1 <= nums.length <= 6

  • -10 <= nums[i] <= 10

  • nums 中的所有整数 互不相同

思路分析

全排列的递归调用树如下:

0046 01

参考代码框架,代码如下:

/**
 * @author D瓜哥 · https://www.diguage.com
 * @since 2025-04-06 16:50
 */
public List<List<Integer>> permute(int[] nums) {
  List<List<Integer>> result = new ArrayList<>();
  backtrack(nums, result, new ArrayList<>(), new boolean[nums.length]);
  return result;
}

private void backtrack(int[] nums, List<List<Integer>> result,
                       List<Integer> path, boolean[] used) {
  if (path.size() == nums.length) {
    result.add(new ArrayList<>(path));
    return;
  }
  for (int i = 0; i < nums.length; i++) {
    if (used[i]) {
      continue;
    }
    // 选择
    used[i] = true;
    path.add(nums[i]);
    backtrack(nums, result, path, used);
    // 撤销
    path.removeLast();
    used[i] = false;
  }
}