在树莓派上使用numpy实现简单的神经网络推理,pytorch在服务器或PC上训练好模型保存成numpy格式的数据,推理在树莓派上加载模型
2023-05-30 17:24:09 博客园


(资料图片)

这几天又在玩树莓派,先是搞了个物联网,又在尝试在树莓派上搞一些简单的神经网络,这次搞得是mlp识别mnist手写数字识别

训练代码在电脑上,cpu就能训练,很快的:

1 import torch 2 import torch.nn as nn 3 import torch.optim as optim 4 from torchvision import datasets, transforms 5  6 # 设置随机种子 7 torch.manual_seed(42) 8  9 # 定义MLP模型10 class MLP(nn.Module):11     def __init__(self):12         super(MLP, self).__init__()13         self.fc1 = nn.Linear(784, 256)14         self.fc2 = nn.Linear(256, 128)15         self.fc3 = nn.Linear(128, 10)16 17     def forward(self, x):18         x = x.view(-1, 784)19         x = torch.relu(self.fc1(x))20         x = torch.relu(self.fc2(x))21         x = self.fc3(x)22         return x23 24 # 加载MNIST数据集25 transform = transforms.Compose([26     transforms.ToTensor(),27     # transforms.Normalize((0.1307,), (0.3081,))28 ])29 30 train_dataset = datasets.MNIST(root="./data", train=True, download=True, transform=transform)31 test_dataset = datasets.MNIST(root="./data", train=False, download=True, transform=transform)32 33 train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)34 test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)35 36 # 创建模型实例37 model = MLP()38 39 # 定义损失函数和优化器40 criterion = nn.CrossEntropyLoss()41 optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)42 43 # 训练模型44 def train(model, train_loader, optimizer, criterion, epochs):45     model.train()46     for epoch in range(1, epochs + 1):47         for batch_idx, (data, target) in enumerate(train_loader):48             optimizer.zero_grad()49             output = model(data)50             loss = criterion(output, target)51             loss.backward()52             optimizer.step()53             54             if batch_idx % 100 == 0:55                 print("Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(56                     epoch, batch_idx * len(data), len(train_loader.dataset),57                     100. * batch_idx / len(train_loader), loss.item()))58 59 # 训练模型60 train(model, train_loader, optimizer, criterion, epochs=5)61 62 # 保存模型为NumPy格式63 numpy_model = {}64 numpy_model["fc1.weight"] = model.fc1.weight.detach().numpy()65 numpy_model["fc1.bias"] = model.fc1.bias.detach().numpy()66 numpy_model["fc2.weight"] = model.fc2.weight.detach().numpy()67 numpy_model["fc2.bias"] = model.fc2.bias.detach().numpy()68 numpy_model["fc3.weight"] = model.fc3.weight.detach().numpy()69 numpy_model["fc3.bias"] = model.fc3.bias.detach().numpy()70 71 # 保存为NumPy格式的数据72 import numpy as np73 np.savez("mnist_model.npz", **numpy_model)

然后需要自己倒出一些图片在dataset里:我保存在了mnist_pi文件夹下,“_”后面的是标签,主要是在pc端导出保存到树莓派下

树莓派推理端的代码,需要numpy手动重新搭建网络,然后加载那些保存的矩阵参数,做矩阵乘法和加法

1 import numpy as np 2 import os 3 from PIL import Image 4  5 # 加载模型 6 model_data = np.load("mnist_model.npz") 7 weights1 = model_data["fc1.weight"] 8 biases1 = model_data["fc1.bias"] 9 weights2 = model_data["fc2.weight"]10 biases2 = model_data["fc2.bias"]11 weights3 = model_data["fc3.weight"]12 biases3 = model_data["fc3.bias"]13 14 # 进行推理15 def predict(image, weights1, biases1,weights2, biases2,weights3, biases3):16     image = image.flatten()/255  # 将输入图像展平并进行归一化17     output = np.dot(weights1, image) + biases118     output = np.dot(weights2, output) + biases219     output = np.dot(weights3, output) + biases320     predicted_class = np.argmax(output)21     return predicted_class22 23 24 25 26 folder_path = "./mnist_pi"  # 替换为图片所在的文件夹路径27 def infer_images_in_folder(folder_path):28     for file_name in os.listdir(folder_path):29         file_path = os.path.join(folder_path, file_name)30         if os.path.isfile(file_path) and file_name.endswith((".jpg", ".jpeg", ".png")):31             image = Image.open(file_path)32             label = file_name.split(".")[0].split("_")[1]33             image = np.array(image)34             print("file_path:",file_path,"img size:",image.shape,"label:",label)35             predicted_class = predict(image, weights1, biases1,weights2, biases2,weights3, biases3)36             print("Predicted class:", predicted_class)37         38 infer_images_in_folder(folder_path)

结果:

效果还不错:

这次内容就到这里了,下次争取做一个卷积的神经网络在树莓派上推理,然后争取做一个目标检测的模型在树莓派上

热门推荐

文章排行

  1. 2023-05-30在树莓派上使用numpy实现简单的神经网络推理,pytorch在服务器或PC上训练好模型保存成numpy格式的数据,推理在树莓派上加载模型
  2. 2023-05-30天天消息!泰坦女巨人官网在哪下载 最新官方下载安装地址
  3. 2023-05-30【当前独家】外交部:欢迎包括马斯克先生在内各国工商界人士访华
  4. 2023-05-30佛跳墙一般多少钱一份,正宗的佛跳墙价格在1000元左右|世界快看
  5. 2023-05-30深“V”:A股三大指数午后跌超1%后反抽收涨,中特估走强
  6. 2023-05-30【环球热闻】对华“去风险”本身就是在制造风险(国际论坛)
  7. 2023-05-30北京的理财公司(知钱 北京理财顾问有限责任公司) 每日资讯
  8. 2023-05-30天天信息:永春县气象台发布高温橙色预警信号【2023-05-30】
  9. 2023-05-30C919首航成功具有多重意义
  10. 2023-05-30天天报道:德力大力牛魔王01,电池容量9.98kwh,售价3.78万
  11. 2023-05-30全球视点!神舟十六号载人飞船发射圆满成功
  12. 2023-05-30冶金工业出版社网站_冶金工业出版社
  13. 2023-05-30铁电材料的介电常数_铁的相对介电常数是多少简介介绍 每日讯息
  14. 2023-05-30美豆整体偏弱为主 豆粕主力小幅下跌-天天快资讯
  15. 2023-05-30消息!国内油价调价窗口30日开启 或小幅上调
  16. 2023-05-30空气质量的健康革命,美的如何「鲜净」解题?
  17. 2023-05-30天天观天下!临沂:杏子丰收乐农家
  18. 2023-05-30实时焦点:硬核科技论 | 终于加上激光雷达 全新蔚来ES6黑科技升级盘点
  19. 2023-05-30把婴儿聚一起,不教他们说话,会产生新语言吗?这事埃及法老干过 焦点消息
  20. 2023-05-30excel怎么统计相同项的数据(excel怎么统计相同项)