Cloud Model Inference with Triton in a K8s Environment
Prerequisites: a K8s cluster and Helm
1. Create a directory named after the model:
mkdir resnet50_pytorch
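2. Put the model file and the configuration file (which declares the inputs, outputs, etc.) into the newly created directory. The resulting layout of the model repository is:
model-repository/
└── resnet50_pytorch          # model name; must match the name in config.pbtxt
    ├── 1                     # model version number
    │   └── model.pt          # TorchScript model generated by resnet_pytorch.py
    ├── config.pbtxt          # model configuration file (required)
    ├── labels.txt            # optional class labels, one per line, in output-index order
    ├── resnet_client.py      # client script; does not have to live here
    └── resnet_pytorch.py     # script that generates model.pt; does not have to live here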
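Triton expects the name in config.pbtxt to match the model directory, and each model version to sit in a numeric subdirectory, so a quick local check before copying the repository to the cluster can save a redeploy. The following is a minimal sketch, assuming the layout above; check_model_repository is a hypothetical helper, not part of Triton, and it only checks the rules shown in the tree (name match, numeric version directory, model.pt file).

```python
import re
import tempfile
from pathlib import Path


def check_model_repository(repo, model):
    """Return a list of layout problems for one model (hypothetical helper); empty means OK."""
    problems = []
    model_dir = Path(repo) / model
    config = model_dir / "config.pbtxt"
    if not config.is_file():
        problems.append(f"missing {config}")
    else:
        # Top-level `name:` (no indentation) must equal the directory name
        match = re.search(r'^name:\s*"([^"]+)"', config.read_text(), re.MULTILINE)
        if match is None or match.group(1) != model:
            problems.append(f"name in config.pbtxt does not match directory '{model}'")
    # Versions live in numeric subdirectories; the PyTorch backend loads model.pt by default
    versions = [d for d in model_dir.iterdir() if d.is_dir() and d.name.isdigit()] if model_dir.is_dir() else []
    if not versions:
        problems.append("no numeric version directory such as 1/")
    for version in versions:
        if not (version / "model.pt").is_file():
            problems.append(f"missing {version / 'model.pt'}")
    return problems


if __name__ == "__main__":
    # Build a throwaway repository that mirrors the tree above and check it
    with tempfile.TemporaryDirectory() as tmp:
        model_dir = Path(tmp) / "resnet50_pytorch"
        (model_dir / "1").mkdir(parents=True)
        (model_dir / "1" / "model.pt").write_bytes(b"")
        (model_dir / "config.pbtxt").write_text('name: "resnet50_pytorch"\n')
        print(check_model_repository(tmp, "resnet50_pytorch"))  # → []
```

Run it against the real repository directory (for example `check_model_repository("model-repository", "resnet50_pytorch")`) before uploading it.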
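3. The script resnet_pytorch.py, which generates model.pt:
import torch
import torchvision.models as models
# Load ImageNet-pretrained weights (torchvision >= 0.13; older versions use pretrained=True)
resnet50 = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
resnet50.eval()  # inference mode: BatchNorm uses its running statistics
# Example input for tracing: 1 image, 3 channels, 224x224
image = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    resnet50_traced = torch.jit.trace(resnet50, image)
    # Sanity check: the traced module should reproduce the eager output
    assert torch.allclose(resnet50(image), resnet50_traced(image), atol=1e-4)
# Save as TorchScript, then move the file to resnet50_pytorch/1/model.pt
resnet50_traced.save('model.pt')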
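The repository also lists labels.txt, which label_filename in config.pbtxt points to. Triton reads it as one label per line, with line i naming output index i. A minimal sketch for writing such a file; write_labels is a hypothetical helper and the three sample names are only illustrative.

```python
import tempfile
from pathlib import Path


def write_labels(categories, path="labels.txt"):
    """Write one label per line; line i is the class name for output index i (hypothetical helper)."""
    names = [str(c).strip() for c in categories]
    if not names or any(not n or "\n" in n for n in names):
        raise ValueError("labels must be non-empty, single-line strings")
    text = "\n".join(names) + "\n"
    Path(path).write_text(text, encoding="utf-8")
    return text


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        out = Path(tmp) / "labels.txt"
        # Illustrative names only; a real file needs all 1000 classes in output order
        write_labels(["tench", "goldfish", "great white shark"], out)
        print(out.read_text(encoding="utf-8").splitlines())  # → ['tench', 'goldfish', 'great white shark']
```

With torchvision >= 0.13, the 1000 ImageNet class names in ResNet-50's output order are available as `models.ResNet50_Weights.DEFAULT.meta["categories"]`, so passing that list to write_labels would produce the full file.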
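4. An example model configuration file, config.pbtxt. The model takes images of shape [N, 3, -1, -1] (batch, channels, height, width; -1 means the dimension may vary) and outputs a [N, 1000] vector of class scores. Because max_batch_size is greater than 0, the batch dimension N is implicit and is omitted from dims. label_filename names the label file Triton uses to map output indices to class names when classification results are requested.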
name: "resnet50_pytorch"
platform: "pytorch_libtorch"
max_batch_size: 128
input [
  {
    name: "INPUT__0"
    data_type: TYPE_FP32
    dims: [ 3, -1, -1 ]
  }
]
output [
  {
    name: "OUTPUT__0"
    data_type: TYPE_FP32
    dims: [ 1000 ]
    label_filename: "labels.txt"
  }
]
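How a request shape relates to dims can be made concrete with a small check. This is a simplified sketch, assuming the rule described above (with max_batch_size > 0 the first request dimension is the batch size, and -1 matches any size); request_shape_ok is a hypothetical helper, and Triton's own validation covers more than this.

```python
def request_shape_ok(shape, dims, max_batch_size):
    """Simplified check of a request tensor shape against config.pbtxt `dims` (hypothetical helper).

    With max_batch_size > 0 the first request dimension is the batch size and is
    not listed in `dims`; a -1 in `dims` accepts any size for that dimension.
    """
    shape, dims = list(shape), list(dims)
    if max_batch_size > 0:
        if not shape or not 1 <= shape[0] <= max_batch_size:
            return False
        shape = shape[1:]  # drop the implicit batch dimension
    if len(shape) != len(dims):
        return False
    return all(d == -1 or d == s for s, d in zip(shape, dims))


if __name__ == "__main__":
    dims, max_batch = [3, -1, -1], 128  # values from config.pbtxt above
    print(request_shape_ok((1, 3, 224, 224), dims, max_batch))  # → True  (NCHW, what the client sends)
    print(request_shape_ok((1, 224, 224, 3), dims, max_batch))  # → False (NHWC, before the transpose)
    print(request_shape_ok((3, 224, 224), dims, max_batch))     # → False (batch dimension missing)
```

The second case is why the client transposes the image from HWC to CHW before sending it.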
5. Example client script resnet_client.py, which calls the Triton inference service over HTTP:
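import numpy as np
import tritonclient.http as httpclient
from PIL import Image

if __name__ == '__main__':
    triton_client = httpclient.InferenceServerClient(url='10.114.242.30:30381')  # dev environment
    # triton_client = httpclient.InferenceServerClient(url='20.102.236.169:8000')  # production environment

    # Preprocess: force RGB, resize to 224x224, scale to [0, 1]
    image = Image.open('./cat.jpg').convert('RGB')
    image = image.resize((224, 224), Image.LANCZOS)  # Image.ANTIALIAS was removed in Pillow 10
    image = np.asarray(image, dtype=np.float32) / 255.0
    # Normalize with the ImageNet mean/std that torchvision's pretrained ResNet-50 expects
    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
    image = (image - mean) / std
    image = np.expand_dims(image, axis=0)           # HWC -> NHWC with N = 1
    image = np.transpose(image, axes=[0, 3, 1, 2])  # NHWC -> NCHW, matching dims [3, -1, -1]
    image = np.ascontiguousarray(image, dtype=np.float32)

    # Declare the input tensor and send its data as JSON (binary_data=False)
    inputs = [httpclient.InferInput('INPUT__0', list(image.shape), "FP32")]
    inputs[0].set_data_from_numpy(image, binary_data=False)

    # Ask for OUTPUT__0 as the top-3 classes, returned as "score:index:label" strings
    outputs = [httpclient.InferRequestedOutput('OUTPUT__0', binary_data=False, class_count=3)]

    results = triton_client.infer('resnet50_pytorch', inputs=inputs, outputs=outputs)
    output_data0 = results.as_numpy('OUTPUT__0')
    print(output_data0.shape)
    print(output_data0)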
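On K8s the Triton Pod may still be starting, or still loading the model, when the client above runs. A minimal sketch that polls the server first, assuming a tritonclient InferenceServerClient, whose is_server_live, is_server_ready and is_model_ready methods it calls; wait_until_ready and its timeout values are hypothetical.

```python
import time


def wait_until_ready(client, model_name, timeout_s=60.0, interval_s=2.0):
    """Poll until the server is live and ready and `model_name` is loaded (hypothetical helper).

    `client` is expected to provide is_server_live(), is_server_ready() and
    is_model_ready(name), as tritonclient's InferenceServerClient does.
    Returns True once everything is ready, False after timeout_s seconds.
    """
    deadline = time.monotonic() + timeout_s
    while True:
        try:
            if (client.is_server_live() and client.is_server_ready()
                    and client.is_model_ready(model_name)):
                return True
        except Exception:
            pass  # e.g. connection refused while the Pod or Service is not up yet
        if time.monotonic() >= deadline:
            return False
        time.sleep(interval_s)
```

In resnet_client.py it would be called right after creating triton_client, for example `if not wait_until_ready(triton_client, 'resnet50_pytorch'): raise SystemExit('Triton is not ready')`. Passing the client in, rather than a URL, keeps the helper testable with a stand-in object.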
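Run python3 resnet_client.py. It prints the shape of the result followed by the top-3 predicted classes, showing that the model recognizes the content of the input image.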
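Because class_count is set, each returned element is a string of the form "<score>:<index>:<label>" (the label part is present only when a label file is configured) rather than the raw 1000 scores. A minimal sketch for turning one row of output_data0 into tuples; parse_classification is a hypothetical helper and the sample strings are made up, not real output.

```python
def parse_classification(row):
    """Parse Triton classification strings "<score>:<index>[:<label>]" into tuples (hypothetical helper)."""
    parsed = []
    for item in row:
        # JSON responses give str, binary responses give bytes
        text = item.decode("utf-8") if isinstance(item, bytes) else str(item)
        parts = text.split(":", 2)  # maxsplit=2 keeps any ':' inside the label
        label = parts[2] if len(parts) == 3 else None
        parsed.append((float(parts[0]), int(parts[1]), label))
    return parsed


if __name__ == "__main__":
    # Made-up sample row in the shape of output_data0[0]; not real model output
    row = ["15.3:281:tabby", b"13.1:285:Egyptian cat", "9.8:282"]
    for score, index, label in parse_classification(row):
        print(f"{index:4d} {score:6.2f} {label}")
```

In the client, `parse_classification(output_data0[0])` would give the top-3 results for the single image in the batch.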
References:
https://github.com/zzk0/triton/tree/master/quick/resnet50_pytorch
https://www.cnblogs.com/zzk0/p/15543824.html