在二维数组的迷宫中寻找最短路径


背景

算法原题 迷宫问题 如果这个链接打不开了,可能说明时间有点久远了,学校已经出了新的oj系统,或者出现了点意外

在做一个迷宫表示的二维数组中寻找最短路径时想到这个算法可以放到自己做的“二维数组”类里,以后再做寻找最短路径问题的时候就可以直接套用了。

目前这个二维数组类里面包含了不属于这个题目的方法。可以忽视。

算法说明

使用的是深度优先搜索。最终返回的是一个最短的从起点到终点的路径,如果有多条路径都有相同的长度,那么只会返回一个。

整个实现的原理:其实是一边深度优先搜索,一边构建一个树结构,这个树的每个节点都有指向父元素的指针,但是不需要指向子元素的指针。一旦发现终点的时候,这个位置节点就会一直向上找到父节点,向上找父节点直到父节点是起始位置节点的时候,这一串节点,就是一条能够从起点走到终点的路径了。把这条路径加入路径列表。最后从路径列表里找到一个最短的路径就可以了。

需要注意的是向上找到父节点,初始位置节点的时候,还需要再把整个路径翻转,才是从起点到终点。

进行深度优先遍历的时候,会有很多停止递归的条件:

  1. 这个位置不在地图里
  2. 这个位置是不能走的位置(墙体)
  3. 这个位置之前已经访问过了(为了放置一直转圈。所以需要额外再创建一个集合来存储已经走过了的位置)

使用方法

先用循环输入的方式构造好二维数组对象,

调用.dfsGetMinPath(...)方法,这里展示一下这个函数的说明

    def dfsGetMinPath(self, pathVal, startLoc: Tuple[int, int], endLoc: Tuple[int, int], debug=False) -> List[
        Tuple[int, int]]:
        """
        深度优先算法获取最短路径
        pathVal:可以走的路径值,如果0表示路径元素,1表示墙体,则应该写0
        startLoc:起点位置
        endLoc:终点位置
        debug:是否显示打印信息,查bug用
        """

如果是1表示墙体,0表示道路,二维数组是这样的,那么pathVal参数就应该写0,因为0表示的是可以在上面走的元素。

如果左上角是起点,可以调用类内部的 getLocLeftTop() 方法直接返回容器左上角定点的值,其实是固定的 (0, 0)

另外还可以调用 getLocLeftBottomgetLocRightTopgetLocRightBottom 方法返回其他角落的位置

0 1 0 0 0
0 1 0 1 0
0 0 0 0 0
0 1 1 1 0
0 0 0 1 0

代码模板

# -*- encoding: utf-8 -*-
"""
二维数组、二维容器代码模板 算法模板
2021年11月28日
by littlefean
"""
from collections import deque

from typing import Tuple, Set, List


class Container2D:
    def __init__(self, w: int, h: int, initVal=0):
        """
        按照宽高初始化一个二维容器
        w: 宽度
        h: 高度
        initVal: 每一个格子上的初始值
        """
        self.width = w
        self.height = h
        self.mapArr = []
        self.initVal = initVal
        for y in range(self.height):
            line = []
            for x in range(self.width):
                line.append(initVal)
            self.mapArr.append(line)

    def set(self, x, y, val):
        """
        甚至一个位置上的值
        Time: O(1)
        """
        if self.inContainer(x, y):
            self.mapArr[y][x] = val

    def get(self, x, y):
        """
        获取一个位置上的值
        Time: O(1)
        """
        if self.inContainer(x, y):
            return self.mapArr[y][x]
        else:
            raise Exception("访问越界")

    def getColumn(self, index: int) -> list:
        """获取某一列的元素"""
        res = []
        for y in range(self.height):
            res.append(self.get(index, y))
        return res

    @staticmethod
    def getLocLeftTop() -> Tuple[int, int]:
        """获取左上角的坐标位置"""
        return 0, 0

    def getLocLeftBottom(self) -> Tuple[int, int]:
        """获取左下角坐标位置"""
        return 0, self.height - 1

    def getLocRightTop(self) -> Tuple[int, int]:
        """获取右上角坐标位置"""
        return self.width - 1, 0

    def getLocRightBottom(self) -> Tuple[int, int]:
        """获取右下角坐标位置"""
        return self.width - 1, self.height - 1

    def inContainer(self, x, y):
        """
        某个下标位置是不是在里面,避免出现越界错误
        Time: O(1)"""
        return x in range(0, self.width) and y in range(0, self.height)

    def show(self):
        """
        打印情景
        Time: O(n^2)
        """
        print("==" * self.width)
        for y in range(self.height):
            for x in range(self.width):
                val = self.get(x, y)
                if self.initVal == val:
                    print(".", end=" ")
                else:
                    print(val, end=" ")
            print()

    def showPos(self, loc: Tuple[int, int]):
        """打印一个位置"""
        targetX, targetY = loc
        print("==" * self.width)
        for y in range(self.height):
            for x in range(self.width):

                if targetX == x and targetY == y:
                    print("#", end=" ")
                else:
                    print(".", end=" ")
            print()

    @staticmethod
    def roundLoc(x, y) -> list:
        """返回一个点周围一圈8个点"""
        return [(x + 1, y), (x + 1, y + 1), (x, y + 1), (x - 1, y + 1), (x - 1, y), (x - 1, y - 1), (x, y - 1),
                (x + 1, y - 1),
                ]

    @staticmethod
    def adjoinLoc(x, y) -> list:
        """返回一个点周围 上下左右四个点"""
        return [(x + 1, y), (x, y + 1), (x - 1, y), (x, y - 1)]

    def bfsGetGroupSet(self, targetVal, startLoc: Tuple[int, int]) -> Set[Tuple[int, int]]:
        """
        在二维数组中的一个位置开始广度优先搜索,
        返回这个位置开始成功扩散的所有点的位置集合
        """
        getRoundFunc = self.roundLoc  # 获取点的方法

        if self.get(*startLoc) != targetVal:
            return set()
        # 广度优先遍历
        q = deque()
        q.append(startLoc)
        smallSet = {startLoc}  # 用来防止和自身重复
        while len(q) != 0:
            loc = q.popleft()
            for aLoc in getRoundFunc(*loc):
                # 遍历当前这个点的邻居点
                if not self.inContainer(*aLoc):
                    # 越接
                    continue
                if aLoc in smallSet:
                    # 已经问过
                    continue
                if self.get(*aLoc) == targetVal:
                    q.append(aLoc)
                    smallSet.add(aLoc)  # 添加广度优先自身不重复集合
        # 广度优先结束
        return smallSet

    def bfsGetGroupCount(self, targetVal):
        """
        在二维数组中广度优先搜索,有多少个目标值连接成块的块数
        targetVal: 寻找的目标值
        当前默认是斜着也算连接着,如果想改成上下左右算连接,
        更改 getRoundFunc = self.roundLoc
        改成 getRoundFunc = self.adjoinLoc
        Time: >= O(n^2)
        """

        getRoundFunc = self.roundLoc  # 获取点的方法

        visitedSet = set()
        resFinal = 0
        for yIndex in range(self.height):
            for xIndex in range(self.width):
                # 遍历每一个格子
                if self.get(xIndex, yIndex) == targetVal:
                    # 该格子是要寻找的值
                    if (xIndex, yIndex) in visitedSet:
                        # 这个格子已经访问过了之前
                        continue
                    else:
                        # 这个格子 是 新访问的
                        resFinal += 1
                        visitedSet.add((xIndex, yIndex))
                        # 广度优先遍历
                        q = deque()
                        q.append((xIndex, yIndex))
                        smallSet = {(xIndex, yIndex)}  # 用来防止和自身重复
                        while len(q) != 0:
                            loc = q.popleft()
                            for aLoc in getRoundFunc(*loc):
                                # 遍历当前这个点的邻居点
                                if not self.inContainer(*aLoc):
                                    # 越接
                                    continue
                                if aLoc in smallSet:
                                    # 已经问过
                                    continue
                                if self.get(*aLoc) == targetVal:
                                    q.append(aLoc)
                                    visitedSet.add(aLoc)  # 添加访问过了的集合
                                    smallSet.add(aLoc)  # 添加广度优先自身不重复集合
                        # 广度优先结束
        return resFinal

    def dfsGetMinPath(self, pathVal, startLoc: Tuple[int, int], endLoc: Tuple[int, int], debug=False) -> List[
        Tuple[int, int]]:
        """
        深度优先算法获取最短路径
        pathVal:可以走的路径值,如果0表示路径元素,1表示墙体,则应该写0
        startLoc:起点位置
        endLoc:终点位置
        debug:是否显示打印信息,查bug用
        """
        pathList = []

        if self.get(*startLoc) != pathVal:
            # 起点直接在墙里,没有结果
            return []

        visitedLoc = {startLoc}  # 已经访问过了的路径

        class Node:
            def __init__(self, val):
                self.loc = val
                self.father = None

            def getPathList(self):
                res = []
                node = self
                while node is not None:
                    res.append(node.loc)
                    node = node.father
                return res[::-1]

        def dfs(nowLocNode: Node):
            print("dfs in") if debug else ...
            nextList = self.adjoinLoc(*nowLocNode.loc)
            print("nextList:", nextList) if debug else ...
            if debug:
                self.showPos(nowLocNode.loc)
            for pos in nextList:
                if pos == endLoc:
                    # 走到了重点,把当前链路加入到链路列表里
                    node = Node(pos)
                    node.father = nowLocNode
                    pathList.append(node.getPathList())
                    print("发现终点!") if debug else ...
                    return
                if not self.inContainer(*pos):
                    # 不在地图中
                    print("不在地图中") if debug else ...
                    continue
                if self.get(*pos) != pathVal:
                    # 是墙
                    print("墙体", self.get(*pos), "!=", pathVal) if debug else ...
                    continue
                if pos in visitedLoc:
                    # 已经访问过
                    print("已经访问过") if debug else ...
                    continue

                visitedLoc.add(pos)

                node = Node(pos)
                node.father = nowLocNode
                dfs(node)
            ...

        dfs(Node(startLoc))
        if debug:
            for path in pathList:
                print("path: ", end="\t")
                for loc in path:
                    print(loc, end=" ")
                print()
        if len(pathList) == 1:
            return pathList[0]
        else:
            minPath = pathList[0]
            for i, path in enumerate(pathList):
                if len(path) < len(minPath):
                    minPath = path
            return minPath

    ...


def main():
    m = Container2D(5, 5)
    for y in range(5):
        line = input().split()
        for x, v in enumerate(line):
            m.set(x, y, int(v))
    # m.show()
    arr = m.dfsGetMinPath(0, m.getLocLeftTop(), m.getLocRightBottom(), debug=False)
    for item in arr:
        print(f"({item[1]}, {item[0]})")
    ...


if __name__ == "__main__":
    main()
    # test()

键盘输入:

0 1 0 0 0
0 1 0 1 0
0 0 0 0 0
0 1 1 1 0
0 0 0 1 0

打印输出

(0, 0)
(1, 0)
(2, 0)
(2, 1)
(2, 2)
(2, 3)
(2, 4)
(3, 4)
(4, 4)
Logo

权威|前沿|技术|干货|国内首个API全生命周期开发者社区

更多推荐