Python

import os
import time

from d2l import torch as d2l
from net_load import Net
import torch
from imageio import imread
from torchvision import transforms
import matplotlib.pyplot as plt

def nonDisplay(ax):
    """Hide all decoration of a matplotlib axes so only its plotted content shows.

    Makes the y and x axis objects invisible and hides the 'top', 'right',
    'bottom' and 'left' border spines. Other spines (if any) are untouched.
    """
    target = ax.axes
    for axis in (target.get_yaxis(), target.get_xaxis()):
        axis.set_visible(False)
    for side in ('top', 'right', 'bottom', 'left'):
        target.spines[side].set_visible(False)


# Load the arrow graphics drawn between panels and a sample digit image.
arrowRiSoP = imread("../img/arrow_RiSoP-Net.png")
arrowCNN = imread("../img/arrow_StandardCNN.png")
img = imread("../img/7.png")

# Preprocess exactly like the animation loop below (ToTensor scales uint8
# pixels to [0, 1]) so this sanity check and the animated predictions see
# identically scaled inputs. For a 2-D grayscale image ToTensor yields
# (1, H, W); unsqueeze(0) adds the batch dimension -> (1, 1, H, W).
t_data = transforms.ToTensor()(img).unsqueeze(0).cuda()

# NOTE(review): these files hold whole pickled models (torch.load unpickles
# them), so only load trusted files. Newer PyTorch releases default to
# weights_only=True and may refuse full-model pickles — confirm against the
# installed version.
net = d2l.torch.load('mnist_model.pkl')
net.eval()

cnn_net = d2l.torch.load('mnist_CNN.pkl')
cnn_net.eval()

# Inference only: no autograd graph is needed.
with torch.no_grad():
    res = d2l.F.softmax(net(t_data), dim=1)
bar_data = res.tolist()[0]
print(bar_data)

def _show_image(image, position):
    """Draw *image* in grayscale at subplot (2, 4, position) with all axis decoration hidden."""
    plt.subplot(2, 4, position)
    plt.imshow(image, "gray")
    nonDisplay(plt.gca())


def _as_batch(image):
    """Return *image* reshaped to a (1, 1, H, W) batch on the GPU.

    Expects a single-channel tensor of shape (1, H, W) or (H, W), as produced
    by ``transforms.ToTensor()`` on a grayscale image.
    """
    return image.reshape(1, 1, *image.shape[-2:]).cuda()


def _plot_prediction(model, batch, position):
    """Plot *model*'s softmax output for *batch* as a bar chart at subplot (2, 4, position).

    Only the first sample of the batch is plotted. The arg-max class is
    redrawn in red and its index is written above it at y = 0.5.
    """
    # Inference only: avoid building an autograd graph on every frame.
    with torch.no_grad():
        probs = d2l.F.softmax(model(batch), dim=1)[0].tolist()
    # First index of the maximum (same tie-breaking as numpy's argmax).
    pred = max(range(len(probs)), key=probs.__getitem__)

    plt.subplot(2, 4, position)
    plt.bar(range(len(probs)), probs)
    plt.bar([pred], [probs[pred]], color='r')
    plt.text(pred, 0.5, f"{pred}", fontsize=20, ha="center")
    plt.xticks(range(10))
    plt.xlabel("class")
    # Softmax outputs are per-class probabilities, not an accuracy.
    plt.ylabel("probability")
    plt.ylim((0, 1))
    ax = plt.gca()
    # Presumably balances the ~10-unit class axis against the 0-1 y-range
    # so the panel comes out roughly square.
    ax.set_aspect(10)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)


fig = plt.figure(1, figsize=(5, 5))
path = "../img/MNIST/"
# The standard CNN is fed 224x224 inputs; this transform does not depend on
# the rotation angle, so build it once outside the loops.
resize = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((224, 224)),
    transforms.ToTensor()])
# sorted() makes the presentation order deterministic (os.listdir is not).
for p in sorted(os.listdir(path)):
    img = imread(os.path.join(path, p))
    # 0..360 inclusive in 15-degree steps, so the last frame is upright again.
    for degree in range(0, 365, 15):
        # A degenerate (degree, degree) range makes RandomRotation deterministic.
        rotation = transforms.Compose([
            transforms.ToPILImage(),
            transforms.RandomRotation((degree, degree)),
            transforms.ToTensor()
        ])
        r_img = rotation(img)

        # Top row: original -> rotated digit -> arrow -> RiSoP-Net prediction.
        _show_image(img, 1)
        _show_image(r_img[0], 2)
        _show_image(arrowRiSoP, 3)
        _plot_prediction(net, _as_batch(r_img), 4)

        # Bottom row: standard CNN demo on the same rotated digit, upscaled.
        _show_image(r_img[0], 6)
        _show_image(arrowCNN, 7)
        _plot_prediction(cnn_net, _as_batch(resize(r_img[0])), 8)

        # plt.pause (unlike time.sleep) keeps the GUI event loop running,
        # so the window stays responsive while a frame is held.
        plt.pause(0.2)
        if degree == 0:
            plt.pause(1)  # hold the unrotated frame a little longer
        # Clear the current canvas before drawing the next frame.
        fig.clf()
plt.pause(5)
This is just a placeholder img.