1. torch.load()

torch.load函数用于从磁盘加载已保存的模型或张量,以便进行后续的操作。这也是我们常用的一种导入预训练模型的方式,可以使用以下方式调用该函数:

model = torch.load('model.pth')

其中,model.pth就是我们存放模型的路径。

2.  torch.hub.load()

最近在复现某一个关于yolo的项目中遇到了这个方法,从该方法的hub可以看出,它在每次加载模型时都要联网进行加载。比如:

model = torch.hub.load(
            "ultralytics/yolov5",
            "custom",
            path=f"{local_model_path}/{model_name}",
            device=device,
            force_reload=[True if "refresh_yolov5" in opt else False][0],
            _verbose=True,
        )

其中custom表示自定义的模型,path是本地权重文件的路径,而"ultralytics/yolov5"表示该load方法每次加载模型时,都会访问到GitHub - ultralytics/yolov5: YOLOv5 🚀 in PyTorch > ONNX > CoreML > TFLite这个网址。不过有些时候国内加载github没有那么稳定,就会导致这个load方法经常报“远程连接失败”的错误。

3. 如何把torch.hub.load()改为每次从本地加载?

1)将所要加载的存储库直接搬到项目中来

比如我需要的存储库在GitHub - ultralytics/yolov5: YOLOv5 🚀 in PyTorch > ONNX > CoreML > TFLite,就可以直接访问该github网站把整个包克隆下来,放到项目中来(我放在了根目录下)。

2)修改hub.load代码

修改代码如下:

model = torch.hub.load(
            "./ultralytics_yolov5_master",
            "custom",
            path=f"{local_model_path}/{model_name}",
            device=device,
            source='local',
            force_reload=[True if "refresh_yolov5" in opt else False][0],
            _verbose=True,
        )

主要是两处发生了变化,一个是增加了参数source='local',指明我们是要从本地加载而不是联网加载(因为默认是source='github'),另外就是第一个参数中的路径(即加载路径)发生了变化,因为我们在第一步中已经将存储库拷贝到本地项目包的根目录下了。

到这里,之后再运行项目就会默认从本地加载啦。(>_<  联网加载真的太折磨人了)

---------------------------------------------------------------------------------------------------------------------------------

新人发帖,多多关照 ~

Logo

旨在为数千万中国开发者提供一个无缝且高效的云端环境,以支持学习、使用和贡献开源项目。

更多推荐