import os
import time
from d2l import torch as d2l
from net_load import Net
import torch
from imageio import imread
from torchvision import transforms
import matplotlib.pyplot as plt
def nonDisplay(ax):
    """Hide both axis rulers and the four border spines of a plot.

    Args:
        ax: Object whose ``axes`` attribute provides ``get_xaxis()``,
            ``get_yaxis()`` and a ``spines`` mapping keyed by ``'top'``,
            ``'right'``, ``'bottom'`` and ``'left'``. The script passes the
            result of ``plt.gca()``.

    Returns:
        None. The visibility flags are changed in place.
    """
    target = ax.axes
    # Ticks and tick labels live on the axis objects.
    target.get_yaxis().set_visible(False)
    target.get_xaxis().set_visible(False)
    # The frame around the plot is made of four separate spines.
    for side in ('top', 'right', 'bottom', 'left'):
        target.spines[side].set_visible(False)
# Run on the GPU when one is available and fall back to the CPU otherwise,
# so the demo does not crash on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _predict(model, image, side):
    """Return the softmax class probabilities of ``model`` for one image.

    Args:
        model: Callable network; it receives a ``(1, 1, side, side)`` tensor
            located on ``device``.
        image: Tensor assigned into the single channel of the batch.
        side: Height and width of the square input batch.

    Returns:
        The first (only) row of the softmax output as a list of floats.
    """
    batch = torch.zeros((1, 1, side, side))
    batch[0][0] = image
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        probabilities = d2l.F.softmax(model(batch.to(device)), dim=1)
    return probabilities.tolist()[0]


def _plot_prediction(probabilities):
    """Draw the class-probability bar chart on the current axes.

    All bars are drawn first, then the highest one is redrawn in red and
    labelled with its class index.
    """
    plt.bar(range(len(probabilities)), probabilities)
    values = d2l.np.array(probabilities)
    best_class, best_value = values.argmax(), values.max()
    plt.bar([best_class], [best_value], color='r')
    # The label sits at a fixed height of 0.5, centred on the winning bar.
    plt.text(best_class, 0.5, f"{best_class}", fontsize=20, ha="center")
    plt.xticks(range(10))
    plt.xlabel("class")
    plt.ylabel("accuracy")
    plt.ylim((0, 1))
    chart = plt.gca()
    chart.set_aspect(10)
    chart.spines['top'].set_visible(False)
    chart.spines['right'].set_visible(False)


arrowRiSoP = imread("../img/arrow_RiSoP-Net.png")
arrowCNN = imread("../img/arrow_StandardCNN.png")

# NOTE: torch.load unpickles arbitrary Python objects; only load checkpoint
# files that come from a trusted source.
net = d2l.torch.load('mnist_model.pkl', map_location=device)
net.eval()
cnn_net = d2l.torch.load('mnist_CNN.pkl', map_location=device)
cnn_net.eval()

# Quick check on a single sample before the animation starts.
# NOTE(review): this sample is fed as raw pixel values, while the loop below
# passes images through ToTensor first -- confirm which scaling the model
# expects.
img = imread("../img/7.png")
bar_data = _predict(net, torch.tensor(img), 28)
print(bar_data)

# Loop-invariant pipeline that upsamples a frame to the 224x224 input used
# for cnn_net.
resize = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((224, 224)),
    transforms.ToTensor()])

fig = plt.figure(1, figsize=(5, 5))
path = "../img/MNIST/"
for p in os.listdir(path):
    img = imread(f"{path}{p}")
    for degree in range(0, 365, 15):
        plt.subplot(2, 4, 1)
        plt.imshow(img, "gray")
        nonDisplay(plt.gca())

        # A (degree, degree) range makes RandomRotation rotate by exactly
        # `degree`.
        plt.subplot(2, 4, 2)
        rotation = transforms.Compose([
            transforms.ToPILImage(),
            transforms.RandomRotation((degree, degree)),
            transforms.ToTensor()
        ])
        r_img = rotation(img)
        plt.imshow(r_img[0], "gray")
        nonDisplay(plt.gca())

        plt.subplot(2, 4, 3)
        plt.imshow(arrowRiSoP, "gray")
        nonDisplay(plt.gca())

        # Rotated-image demo: prediction of `net` on the rotated frame.
        bar_data = _predict(net, r_img, 28)
        plt.subplot(2, 4, 4)
        _plot_prediction(bar_data)

        # ---- Standard CNN display ----
        plt.subplot(2, 4, 6)
        plt.imshow(r_img[0], "gray")
        nonDisplay(plt.gca())

        plt.subplot(2, 4, 7)
        plt.imshow(arrowCNN, "gray")
        nonDisplay(plt.gca())

        cnn_img = resize(r_img[0])
        bar_data = _predict(cnn_net, cnn_img, 224)
        plt.subplot(2, 4, 8)
        _plot_prediction(bar_data)

        plt.pause(0.2)
        if degree == 0:
            # Hold the unrotated frame a little longer.
            time.sleep(1)
        # Clear the current canvas before drawing the next frame.
        fig.clf()

# NOTE(review): the original nesting of this final pause could not be
# recovered; it is placed after all images here -- confirm it was not meant
# to run once per image.
time.sleep(5)