torch.hub.load()把联网加载修改为本地加载模型
我们常用torch.load导入本地模型,但是最近在一个yolov5+gradio的项目中遇到了torch.hub.load方法,这是一个需要联网从GitHub或其他远程位置加载预训练模型的方法。但是联网从github加载通常会出现连接超时的情况,因此转为从本地加载会是更好的选择。
1. torch.load()
torch.load函数用于从磁盘加载已保存的模型或张量,以便进行后续的操作。这也是我们常用的一种导入预训练模型的方式,可以使用以下方式调用该函数:
model = torch.load('model.pth')
其中,model.pth就是我们存放模型的路径。
2. torch.hub.load()
最近在复现某一个关于yolo的项目中遇到了这个方法,从该方法的hub可以看出,它在每次加载模型时都要联网进行加载。比如:
model = torch.hub.load(
"ultralytics/yolov5",
"custom",
path=f"{local_model_path}/{model_name}",
device=device,
force_reload=[True if "refresh_yolov5" in opt else False][0],
_verbose=True,
)
其中custom表示自定义的模型,path是本地权重文件的路径,而"ultralytics/yolov5"表示该load方法每次加载模型时,都会访问到GitHub - ultralytics/yolov5: YOLOv5 🚀 in PyTorch > ONNX > CoreML > TFLite这个网址。不过有些时候国内加载github没有那么稳定,就会导致这个load方法经常报“远程连接失败”的错误。
3. 如何把torch.hub.load()改为每次从本地加载?
1)将所要加载的存储库直接搬到项目中来
比如我需要的存储库在GitHub - ultralytics/yolov5: YOLOv5 🚀 in PyTorch > ONNX > CoreML > TFLite,就可以直接访问该github网站把整个包克隆下来,放到项目中来(我放在了根目录下)。
2)修改hub.load代码
修改代码如下:
model = torch.hub.load(
"./ultralytics_yolov5_master",
"custom",
path=f"{local_model_path}/{model_name}",
device=device,
source='local',
force_reload=[True if "refresh_yolov5" in opt else False][0],
_verbose=True,
)
主要是两处发生了变化,一个是增加了参数source='local',指明我们是要从本地加载而不是联网加载(因为默认是source='github'),另外就是第一个参数中的路径(即加载路径)发生了变化,因为我们在第一步中已经将存储库拷贝到本地项目包的根目录下了。
到这里,之后再运行项目就会默认从本地加载啦。(>_< 联网加载真的太折磨人了)
---------------------------------------------------------------------------------------------------------------------------------
新人发帖,多多关照 ~
更多推荐
所有评论(0)