问题描述

使用matplotlib显示彩色图像出现问题

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-25-edd857df93f0> in <module>
     20 # import numpy as np
     21 # # np_img = np.array(pil_img)
---> 22 plt.imshow(np.array(img))
     23 # plt.imshow(img.permute(1,2,0))
     24 plt.show()
    
TypeError: Invalid shape (3, 224, 224) for image data

原因分析:

使用matplotlib显示彩色图像需要数据的维度为 【width, height, channel】,就是224 * 224 * 3

报错原因是我这里的tensor的维度为 3 * 224 * 224

x_train_tensor = torch.from_numpy(x_train)
y_train_tensor = torch.from_numpy(y_train)

解决方案:

将tensor或者数组的维度交换即可

可以使用permute函数,这个函数的参数就是我们交换之后新维度的排序,下面为1,2,0就是我们需要将原来1和2维度的内容排在前面,而通道维度放在最后

img.permute(1,2,0)

或者还可以使用transpose函数直接交换维度

img.transpose(0,2)
Logo

旨在为数千万中国开发者提供一个无缝且高效的云端环境,以支持学习、使用和贡献开源项目。

更多推荐