二维数组Python代码(算法模板)


背景

最近在准备ACM校赛,遇到一个深度优先、广度优先搜索的问题,求一个地图块中有多少连接成一片的地方

整个数据是用二维数组表示的,于是我打算利用练习这个题的机会总结一个代码模板,遇到求连接片的问题可以直接套用。

甚至遇到“生命游戏”、“兰顿蚂蚁” 等等二维数组类问题也可以直接套用,同时也打算更新这类算法

众所周知,python的numpy库是一个优秀的第三方库,但是如果是算法比赛和做算法题,就不能引用这个库了。但是可以直接来这里复制~这是代码模板

待更新

  1. 水平、垂直反转容器方法
  2. 提供一个遍历操作二维容器里每个值的方法
  3. 添加生命游戏算法(如果是原地算法更好)
  4. 添加兰顿蚂蚁算法

遇到新的算法问题会继续添加

使用代码

def solution(ww, hh):
    width = ww
    height = hh
    m = Container2D(width, height)
    for y in range(height):
        line = input()
        for x, char in enumerate(line):
            if char == "*":
                m.set(x, y, 0)
            else:
                m.set(x, y, 1)
    return m.bfsGetGroupCount(1)
    # 生成完毕


def main():
    while True:
        height, width = input().split()
        height = int(height)
        width = int(width)
        if height == 0 and width == 0:
            break
        print(solution(width, height))

算法题运算结果

5 5 
****@
*@@*@
*@**@
@@@*@
@@**@

>>> 2

可以看到,算法求出来@符号连接成的区域有两块。

代码模板

# -*- encoding: utf-8 -*-
"""
二维数组、二维容器代码模板 算法模板
2021年11月28日
by littlefean
"""
from collections import deque

from typing import Tuple, Set, List


class Container2D:
    """二维容器类"""
    SPEED_CHANGE_DIC = {
        # 速度改变字典
        (1, 0): (0, 1),
        (0, 1): (-1, 0),
        (-1, 0): (0, -1),
        (0, -1): (1, 0)
    }

    def __init__(self, w: int, h: int, initVal=0):
        """
        按照宽高初始化一个二维容器
        w: 宽度
        h: 高度
        initVal: 每一个格子上的初始值
        """
        self.width = w
        self.height = h
        self.size = w * h  # 容器中能存放多少个物品,就是面积
        self.mapArr = []
        self.initVal = initVal

        for y in range(self.height):
            line = []
            for x in range(self.width):
                line.append(initVal)
            self.mapArr.append(line)

    @classmethod
    def getInstanceInit(cls):
        """获取全是默认值的实例对象"""
        pass

    def set(self, x, y, val):
        """
        甚至一个位置上的值
        Time: O(1)
        """
        if self.inContainer(x, y):
            self.mapArr[y][x] = val

    def get(self, x, y):
        """
        获取一个位置上的值
        Time: O(1)
        """
        if self.inContainer(x, y):
            return self.mapArr[y][x]
        else:
            raise Exception("访问越界")

    def getColumn(self, index: int) -> list:
        """获取某一列的元素"""
        res = []
        for y in range(self.height):
            res.append(self.get(index, y))
        return res

    @classmethod
    def getInstanceByList2D(cls, list2D: List[List]):
        """
        适用于做leet code题
        参数传入一个二维列表,返回一个该类的实例
        缺点是需要额外花费 On平方的时间
        """
        res = cls(len(list2D[0]), len(list2D))
        res.mapArr = list2D
        return res

    def rotate90(self, nag=False):
        """
        将整个二维容器旋转90度,当前只适用于正方形状
        如果二维容器不是正方形,则什么都不会发生
        nag为True的时候表示逆时针旋转
        原地算法实现
        """
        if self.width != self.height:
            return
        if nag:
            # 待做
            ...
        else:
            # 先关于 x = y 对称,然后再镜像反转
            # 转置
            for y in range(self.height):
                for x in range(0, y):
                    a, b = self.get(x, y), self.get(y, x)
                    self.set(x, y, b)
                    self.set(y, x, a)
            # 反转
            for y in range(self.height):
                for x in range(int(self.width / 2)):
                    a, b = self.get(x, y), self.get(self.width - x - 1, y)
                    self.set(x, y, b)
                    self.set(self.width - x - 1, y, a)
        return None

    def getSpiralOrder(self) -> list:
        """按照顺时针螺旋顺序返回容器里的所有元素,以列表的形式返回"""
        res = []
        visited = set()
        loc = [0, 0]
        speed = (1, 0)

        def run():
            nonlocal speed, loc
            loc[0] += list(speed)[0]
            loc[1] += list(speed)[1]

        def getNextPos() -> tuple:
            return loc[0] + speed[0], loc[1] + speed[1]

        while len(res) < self.size:
            if getNextPos() in visited or not self.inContainer(*getNextPos()):
                # 已经访问过了  下一个位置超出地图
                speed = self.SPEED_CHANGE_DIC[speed]
            # 把当前位置加入集合,当前位置值加入结果数组
            visited.add(tuple(loc))
            res.append(self.get(*loc))
            # 向前进
            run()
        return res

    @classmethod
    def getInstanceSpiral(cls, w, h, startValue):
        """获取一个螺旋矩阵实例"""
        res = cls(w, h, startValue)
        loc = [0, 0]
        speed = (1, 0)

        def run():
            nonlocal speed, loc
            loc[0] += list(speed)[0]
            loc[1] += list(speed)[1]

        def getNextPos() -> tuple:
            return loc[0] + speed[0], loc[1] + speed[1]

        visited = set()
        stuffNum = 0
        while res.size > stuffNum:
            if getNextPos() in visited or not res.inContainer(*getNextPos()):
                # 已经访问过了  下一个位置超出地图
                speed = res.SPEED_CHANGE_DIC[speed]
            # 把当前位置加入集合
            visited.add(tuple(loc))
            res.set(loc[0], loc[1], startValue + stuffNum)
            # 向前进
            run()
            stuffNum += 1
        return res

    @staticmethod
    def getLocLeftTop() -> Tuple[int, int]:
        """获取左上角的坐标位置"""
        return 0, 0

    def getLocLeftBottom(self) -> Tuple[int, int]:
        """获取左下角坐标位置"""
        return 0, self.height - 1

    def getLocRightTop(self) -> Tuple[int, int]:
        """获取右上角坐标位置"""
        return self.width - 1, 0

    def getLocRightBottom(self) -> Tuple[int, int]:
        """获取右下角坐标位置"""
        return self.width - 1, self.height - 1

    def inContainer(self, x, y):
        """
        某个下标位置是不是在里面,避免出现越界错误
        Time: O(1)"""
        return x in range(0, self.width) and y in range(0, self.height)

    def show(self):
        """
        打印情景
        Time: O(n^2)
        """
        print("==" * self.width)
        for y in range(self.height):
            for x in range(self.width):
                val = self.get(x, y)
                if self.initVal == val:
                    print(".", end=" ")
                else:
                    print(val, end=" ")
            print()

    def showPos(self, loc: Tuple[int, int]):
        """打印一个位置"""
        targetX, targetY = loc
        print("==" * self.width)
        for y in range(self.height):
            for x in range(self.width):

                if targetX == x and targetY == y:
                    print("#", end=" ")
                else:
                    print(".", end=" ")
            print()

    @staticmethod
    def roundLoc(x, y) -> list:
        """返回一个点周围一圈8个点"""
        return [(x + 1, y), (x + 1, y + 1), (x, y + 1), (x - 1, y + 1), (x - 1, y), (x - 1, y - 1), (x, y - 1),
                (x + 1, y - 1),
                ]

    @staticmethod
    def adjoinLoc(x, y) -> list:
        """返回一个点周围 上下左右四个点"""
        return [(x + 1, y), (x, y + 1), (x - 1, y), (x, y - 1)]

    def isHaveValueFast(self, target) -> bool:
        """
        快速判断一个数字是否在矩阵中,
        注意:当前容器里所有的东西都是数字
        且数字要遵循写字顺序的递增排列
        """
        for y in range(self.height):
            if self.mapArr[y][0] <= target <= self.mapArr[y][-1]:
                left = 0
                right = self.width - 1
                while left <= right:
                    mid = left + (right - left) // 2
                    if self.get(mid, y) > target:
                        right = mid - 1
                    elif self.get(mid, y) < target:
                        left = mid + 1
                    else:
                        return True
                return False
        return False

    def bfsGetGroupSet(self, targetVal, startLoc: Tuple[int, int]) -> Set[Tuple[int, int]]:
        """
        在二维数组中的一个位置开始广度优先搜索,
        返回这个位置开始成功扩散的所有点的位置集合
        """
        getRoundFunc = self.roundLoc  # 获取点的方法

        if self.get(*startLoc) != targetVal:
            return set()
        # 广度优先遍历
        q = deque()
        q.append(startLoc)
        smallSet = {startLoc}  # 用来防止和自身重复
        while len(q) != 0:
            loc = q.popleft()
            for aLoc in getRoundFunc(*loc):
                # 遍历当前这个点的邻居点
                if not self.inContainer(*aLoc):
                    # 越接
                    continue
                if aLoc in smallSet:
                    # 已经问过
                    continue
                if self.get(*aLoc) == targetVal:
                    q.append(aLoc)
                    smallSet.add(aLoc)  # 添加广度优先自身不重复集合
        # 广度优先结束
        return smallSet

    def bfsGetGroupData(self, targetVal, xJoin=False):
        """
        把二维容器想象成岛屿,返回这些岛屿的信息
        targetVal: 表示岛屿的值
        xJoin: 斜着是否算相连接
        """
        if xJoin:
            getRoundFunc = self.roundLoc  # 获取点的方法
        else:
            getRoundFunc = self.adjoinLoc  # 获取点的方法

        resObject = {
            "groupCount": 0,  # 群组的数量
            "GroupSizeList": [],  # 群组的大小列表
        }
        visitedSet = set()
        for yIndex in range(self.height):
            for xIndex in range(self.width):
                # 遍历每一个格子
                if self.get(xIndex, yIndex) == targetVal:
                    # 该格子是要寻找的值
                    if (xIndex, yIndex) in visitedSet:
                        # 这个格子已经访问过了之前
                        continue
                    else:
                        # 这个格子 是 新访问的
                        resObject["groupCount"] += 1
                        visitedSet.add((xIndex, yIndex))
                        # 广度优先遍历
                        q = deque()
                        q.append((xIndex, yIndex))
                        smallSet = {(xIndex, yIndex)}  # 用来防止和自身重复
                        while len(q) != 0:
                            loc = q.popleft()
                            for aLoc in getRoundFunc(*loc):
                                # 遍历当前这个点的邻居点
                                if not self.inContainer(*aLoc):
                                    # 越接
                                    continue
                                if aLoc in smallSet:
                                    # 已经问过
                                    continue
                                if self.get(*aLoc) == targetVal:
                                    q.append(aLoc)
                                    visitedSet.add(aLoc)  # 添加访问过了的集合
                                    smallSet.add(aLoc)  # 添加广度优先自身不重复集合
                        # 广度优先结束
                        resObject["GroupSizeList"].append(len(smallSet))
        return resObject

    def dfsGetMinPath(self, pathVal, startLoc: Tuple[int, int], endLoc: Tuple[int, int], debug=False) -> List[
        Tuple[int, int]]:
        """
        深度优先算法获取最短路径
        pathVal:可以走的路径值,如果0表示路径元素,1表示墙体,则应该写0
        startLoc:起点位置
        endLoc:终点位置
        debug:是否显示打印信息,查bug用
        """
        pathList = []

        if self.get(*startLoc) != pathVal:
            # 起点直接在墙里,没有结果
            return []

        visitedLoc = {startLoc}  # 已经访问过了的路径

        class Node:
            def __init__(self, val):
                self.loc = val
                self.father = None

            def getPathList(self):
                res = []
                node = self
                while node is not None:
                    res.append(node.loc)
                    node = node.father
                return res[::-1]

        def dfs(nowLocNode: Node):
            print("dfs in") if debug else ...
            nextList = self.adjoinLoc(*nowLocNode.loc)
            print("nextList:", nextList) if debug else ...
            if debug:
                self.showPos(nowLocNode.loc)
            for pos in nextList:
                if pos == endLoc:
                    # 走到了重点,把当前链路加入到链路列表里
                    node = Node(pos)
                    node.father = nowLocNode
                    pathList.append(node.getPathList())
                    print("发现终点!") if debug else ...
                    return
                if not self.inContainer(*pos):
                    # 不在地图中
                    print("不在地图中") if debug else ...
                    continue
                if self.get(*pos) != pathVal:
                    # 是墙
                    print("墙体", self.get(*pos), "!=", pathVal) if debug else ...
                    continue
                if pos in visitedLoc:
                    # 已经访问过
                    print("已经访问过") if debug else ...
                    continue

                visitedLoc.add(pos)

                node = Node(pos)
                node.father = nowLocNode
                dfs(node)
            ...

        dfs(Node(startLoc))
        if debug:
            for path in pathList:
                print("path: ", end="\t")
                for loc in path:
                    print(loc, end=" ")
                print()
        if len(pathList) == 1:
            return pathList[0]
        else:
            minPath = pathList[0]
            for i, path in enumerate(pathList):
                if len(path) < len(minPath):
                    minPath = path
            return minPath

    ...


class Solution:
    def spiralOrder(self, matrix: List[List[int]]) -> List[int]:
        m = Container2D.getInstanceByList2D(matrix)
        return m.getSpiralOrder()

    def generateMatrix(self, n: int) -> List[List[int]]:
        m = Container2D.getInstanceSpiral(n, n, 1)
        return m.mapArr

    def maxAreaOfIsland(self, grid: List[List[int]]) -> int:
        m = Container2D.getInstanceByList2D(grid)
        arr = m.bfsGetGroupData(1, False)["GroupSizeList"]
        if len(arr) == 0:
            return 0
        else:
            return max(arr)

    def searchMatrix(self, matrix: List[List[int]], target: int) -> bool:
        return Container2D.getInstanceByList2D(matrix).isHaveValueFast(target)


def main():
    print(Solution().spiralOrder([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]))

    ...


def test():
    count = int(input())
    for i in range(count):
        n = input()
        line1 = input()
        line2 = input()
        if line1[0] == "1" and line2[0] == "1":
            print("NO")
            continue
        if line1[-1] == "1" and line2[-1] == "1":
            print("NO")
            continue
        for x in range(len(line1)):
            if line1[x] == "1" and line2[x] == "1":
                print("NO")
                break
        else:
            print("YES")


if __name__ == "__main__":
    main()
    # test()


Logo

权威|前沿|技术|干货|国内首个API全生命周期开发者社区

更多推荐