Prerequisites: a Kubernetes (K8s) cluster and Helm

1. Create a directory named after the model

mkdir resnet50_pytorch

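2. Put the model file and the configuration file (inputs, outputs, etc.) into the directory you just created. The file layout under resnet50_pytorch is as follows: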

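model-repository/
└── resnet50_pytorch            # model name; must match the name in config.pbtxt
    ├── 1                       # model version number
    │   └── model.pt            # TorchScript model generated in step 3
    ├── config.pbtxt            # model configuration file (required)
    ├── labels.txt              # optional class labels; one label per line, in output-index order
    ├── resnet_client.py        # client script; does not have to live here
    └── resnet_pytorch.py       # script that generates model.pt; does not have to live here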
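Before deploying, it can help to check this layout programmatically. The sketch below uses a hypothetical helper, check_model_dir, which is not part of Triton or tritonclient. It checks only the conventions listed above:

- the name in config.pbtxt equals the directory name;
- at least one numeric version directory contains model.pt;
- labels.txt, if present, has one line per class. The sketch assumes 1000 classes for ResNet-50.

```python
from pathlib import Path
import re


def check_model_dir(model_dir, num_classes=1000):
    """Hypothetical helper: return a list of problems in a Triton model directory.

    An empty list means the layout matches the conventions described in this post.
    """
    model_dir = Path(model_dir)
    problems = []

    config = model_dir / "config.pbtxt"
    if not config.is_file():
        return [f"missing {config}"]

    # The top-level 'name' field in config.pbtxt must equal the directory name
    match = re.search(r'^\s*name:\s*"([^"]+)"', config.read_text(), re.MULTILINE)
    if match is None or match.group(1) != model_dir.name:
        problems.append("config.pbtxt name does not match directory name")

    # At least one numeric version directory (e.g. 1/) must contain model.pt
    versions = [p for p in model_dir.iterdir() if p.is_dir() and p.name.isdigit()]
    if not any((v / "model.pt").is_file() for v in versions):
        problems.append("no <version>/model.pt found")

    # labels.txt is optional: one label per line, line i = output index i
    labels = model_dir / "labels.txt"
    if labels.is_file():
        n = len(labels.read_text().splitlines())
        if n != num_classes:
            problems.append(f"labels.txt has {n} lines, expected {num_classes}")

    return problems
```

For example, check_model_dir("model-repository/resnet50_pytorch") returns an empty list when the layout is as shown above.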

3. The script resnet_pytorch.py, which generates model.pt, is shown below:

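import torch
import torchvision.models as models

# Load ImageNet-pretrained ResNet-50 (torchvision >= 0.13 API; older versions use pretrained=True)
resnet50 = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
resnet50.eval()  # put BatchNorm/Dropout into inference mode before tracing

# Example input for tracing: batch of 1, 3 channels, 224x224 (standard ImageNet size)
image = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    resnet50_traced = torch.jit.trace(resnet50, image)

# Triton's pytorch_libtorch backend loads a TorchScript file; move it into the 1/ version directory
resnet50_traced.save('model.pt')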

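4. An example model configuration file, config.pbtxt, is shown below. The model input is an image tensor of shape [ N, 3, -1, -1 ], and the output is a classification vector of shape [ N, 1000 ]. Because max_batch_size is greater than 0, the batch dimension N is implicit and is left out of dims. The config also names a label file (label_filename), which Triton uses to map output indices to class names.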

name: "resnet50_pytorch"
platform: "pytorch_libtorch"
max_batch_size: 128
input [
  {
    name: "INPUT__0"
    data_type: TYPE_FP32
    dims: [ 3, -1, -1 ]
  }
]
output [
  {
    name: "OUTPUT__0"
    data_type: TYPE_FP32
    dims: [ 1000 ]
    label_filename: "labels.txt"
  }
]
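In Triton, when max_batch_size is greater than 0, dims describes a single sample without the batch dimension, and -1 marks a variable-size dimension. The helper below, input_shape_ok, is hypothetical and illustrative, not a Triton API. It sketches that rule and shows why the client must send NCHW tensors with a batch of at most 128.

```python
def input_shape_ok(shape, dims, max_batch_size):
    """Hypothetical helper: does a request tensor shape satisfy a Triton input declaration?

    With max_batch_size > 0, `dims` excludes the batch dimension, so the full
    request shape is [N] + dims with 1 <= N <= max_batch_size; -1 matches any size.
    """
    if max_batch_size > 0:
        if len(shape) != len(dims) + 1 or not 1 <= shape[0] <= max_batch_size:
            return False
        shape = shape[1:]  # drop the batch dimension before comparing
    elif len(shape) != len(dims):
        return False
    return all(d == -1 or d == s for d, s in zip(dims, shape))


# INPUT__0 from the config above: dims [3, -1, -1], max_batch_size 128
print(input_shape_ok((1, 3, 224, 224), [3, -1, -1], 128))    # True
print(input_shape_ok((200, 3, 224, 224), [3, -1, -1], 128))  # False: batch too large
print(input_shape_ok((1, 224, 224, 3), [3, -1, -1], 128))    # False: NHWC instead of NCHW
```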

5. Example client script, resnet_client.py, that calls the Triton inference service over HTTP:

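import numpy as np
import tritonclient.http as httpclient
from PIL import Image


if __name__ == '__main__':
    triton_client = httpclient.InferenceServerClient(url='10.114.242.30:30381')  # dev environment
    # triton_client = httpclient.InferenceServerClient(url='20.102.236.169:8000')  # production environment

    # Force 3 channels so PNG/grayscale images do not break the transpose below
    image = Image.open('./cat.jpg').convert('RGB')
    # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter
    image = image.resize((224, 224), Image.LANCZOS)
    image = np.asarray(image) / 255.0                     # HWC, values in [0, 1]
    # ImageNet mean/std normalization expected by torchvision's pretrained weights
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    image = (image - mean) / std
    image = np.expand_dims(image, axis=0)                 # add batch dim: NHWC
    image = np.transpose(image, axes=[0, 3, 1, 2])        # NHWC -> NCHW
    image = image.astype(np.float32)                      # matches TYPE_FP32 in config.pbtxt

    inputs = [httpclient.InferInput('INPUT__0', image.shape, "FP32")]
    inputs[0].set_data_from_numpy(image, binary_data=False)
    # class_count=3: return the top-3 classes (labelled via labels.txt) instead of raw scores
    outputs = [httpclient.InferRequestedOutput('OUTPUT__0', binary_data=False, class_count=3)]

    results = triton_client.infer('resnet50_pytorch', inputs=inputs, outputs=outputs)
    output_data0 = results.as_numpy('OUTPUT__0')
    print(output_data0.shape)
    print(output_data0)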

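Run python3 resnet_client.py. The script prints the shape of the returned array and then the top-3 classes predicted for the input image, which identify what the picture shows.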
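With class_count set, Triton's classification extension returns strings rather than raw scores. Each string has the form score:index, or score:index:label when labels.txt is configured.

The sketch below turns those strings into tuples. parse_classification is a hypothetical helper. Elements may arrive as str or bytes depending on binary_data, so both are handled. The example input is illustrative only, not real model output.

```python
import numpy as np


def parse_classification(entries):
    """Hypothetical helper: split "score:index[:label]" strings into (score, index, label)."""
    parsed = []
    for entry in np.asarray(entries).ravel():
        text = entry.decode("utf-8") if isinstance(entry, bytes) else str(entry)
        # maxsplit=2 keeps any ':' that appears inside the label text
        parts = text.split(":", 2)
        label = parts[2] if len(parts) == 3 else None
        parsed.append((float(parts[0]), int(parts[1]), label))
    return parsed


# Illustrative input only (not real model output)
example = np.array([["0.9:1:label_a", "0.5:0:label_b"]], dtype=object)
for score, index, label in parse_classification(example):
    print(f"{label} (index {index}): {score:.2f}")
```

In resnet_client.py, the same function would be applied to results.as_numpy('OUTPUT__0').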

References:

https://github.com/zzk0/triton/tree/master/quick/resnet50_pytorch

https://www.cnblogs.com/zzk0/p/15543824.html
