MAX78000-AI宝可梦图鉴
基于图像识别等技术,利用MAX78000开发板完成宝可梦的识别任务,实现“宝可梦图鉴”的基本功能
标签
嵌入式系统
氢化脱氯次氯酸
更新2023-01-31
中国科学技术大学
4400

1 项目介绍

宝可梦是由任天堂、GAME FREAK和Creatures三家公司共同持有版权的连锁品牌,制作包括游戏、动画、电影、卡牌游戏、漫画和特许商品等方面的系列作品,作品描述了一个存在一种被称为宝可梦的生物所生存的世界。目前宝可梦的种类已经超过1000种,它们形态各异,包括动物形的宝可梦(如皮卡丘,图1‑1所示)、人形宝可梦(如腕力,图1‑2所示)、物体形宝可梦(如小磁怪,图1‑3所示)等等。在宝可梦世界中,宝可梦图鉴这一道具可以识别出不同的宝可梦,并提供这只宝可梦的有关信息等。

在现实世界中,随着计算机视觉、图像处理和机器学习等技术的发展,动植物识别等应用逐渐走入了我们的生活中,游戏中万能的“宝可梦图鉴”也在变为可能。因此本项目的目标是基于图像识别等技术,利用MAX78000开发板完成宝可梦的识别任务,实现“宝可梦图鉴”的基本功能,即对输入图像中的宝可梦进行准确的识别,并显示宝可梦的名称。由于所有宝可梦的种类数量较多(超过1000种),识别全部的宝可梦工作量较大,且开发板的资源有限,本项目只对第一世代游戏的部分宝可梦(150种)进行识别。

9k=

图1‑1 皮卡丘图像

Z

图1‑2 腕力图像

v85G0bu+kTQQAAAAABJRU5ErkJggg==

图1‑3 小磁怪图像

Afmmj9Xw95l3AAAAAElFTkSuQmCC

图1‑4 《宝可梦》第一世代游戏的图鉴外观

图形用户界面, 应用程序  描述已自动生成

图1‑5 Pokemon HOME游戏中的图鉴

2 项目设计思路

“宝可梦图鉴”有两大功能:宝可梦识别和浏览,灵感分别来源于动画和游戏中的宝可梦图鉴。宝可梦识别为主要功能,它的工作流程包含宝可梦图像采集、宝可梦识别和识别结果展示三大部分。图像采集方面,由于MAX78000开发板带有一个摄像头,因此可以将摄像头拍摄到的图片作为输入,传给识别模块。识别模块利用预训练的神经网络对输入图像进行处理,生成分类结果。最后的识别结果将展示在LCD屏幕上,同时还会显示SD卡中储存的宝可梦相关信息,以模拟图鉴的功能。宝可梦浏览为图鉴的次要功能,即从SD卡中读取所有宝可梦的信息并依次展示。

此外为了提升项目的可复现性,SD卡和LCD屏幕均为可选模块,只使用MAX78000开发板的板载摄像头也可以完成宝可梦的识别,识别结果将通过串口发送至电脑。

3 效果展示

开发板通电后,首先显示开机界面,按下SW1进入识别模式,按下SW2进入浏览模式,浏览模式必须插入SD卡。

图形用户界面  描述已自动生成

图3‑1 开机界面

3.1 识别模式

开始识别前,屏幕会持续显示摄像头拍摄的内容,按下SW1拍照。

屏幕上有字  描述已自动生成

图3‑2 显示摄像头实时画面

拍照后,MAX78000将照片送入卷积网络,得到分类结果,并读取SD卡的宝可梦信息和缩略图并显示在屏幕上。如图3‑3所示,左上显示摄像头拍摄的照片,坐下显示识别结果宝可梦的缩略图,右侧为宝可梦的信息,从上至下依次为宝可梦名称、图鉴编号、第一属性、第二属性、身高、体重、网络识别结果Top5宝可梦名称及概率。同时串口也会输出识别结果和相关信息,如图3‑4所示。

如果没有插入SD卡,仍然可以进行识别,不过不会显示宝可梦的缩略图和身高、体重、属性信息,如图3‑5所示。

Z

图3‑3 识别结果显示(显示屏)

WzN6BwAAgOIT+f8DtKJZoustIGQAAAAASUVORK5CYII=

图3‑4 识别结果显示(串口)

Z

图3‑5 识别结果显示(未插入SD卡)

按下SW1即可继续进行拍摄,按下SW2则以下一个编号的宝可梦进入浏览模式。

3.2 浏览模式

在开机或识别完成后按下SW2且插入SD卡即可进入浏览模式。当开机后按下SW2时,则从图鉴编号#001开始浏览;当从识别模式进入浏览模式时,从被识别宝可梦下一个图鉴编号的宝可梦开始浏览。每按下一次SW2即可让图鉴编号+1,一共有152只宝可梦可以浏览(#001妙蛙种子-#151梦幻+阿罗拉穿山王)。注:识别模式只能识别150种宝可梦,数据集未收录尼多兰(Nidoran♀)和尼多朗(Nidoran♂)的图片,可能是因为它们的英文名是一样的,只差了一个性别符号。

9k=

图3‑6 浏览模式显示

4 项目实现过程

图示  描述已自动生成

图4‑1 宝可梦图鉴项目开发流程

4.1 训练素材搜集

4.1.1 素材来源

本项目中用到的宝可梦数据集来源于该网站:https://www.kaggle.com/datasets/lantian773030/pokemonclassification,其包含了150种不同宝可梦共7000张图像。数据集中部分图片如图4‑2所示,可以看到包含了同一种宝可梦游戏、动画、同人创作和实物玩偶等不同种类的图像,类型相当丰富。

mmE0AAAAASUVORK5CYII=

图4‑2 数据集的部分图像(同一种宝可梦)

4.1.2 数据集生成

为了保证模型的泛化性,需要将数据集图片随机划分为训练集和测试集两部分,其中训练集占图片总数的80%,测试集占20%,划分后的数据集可以在https://rec.ustc.edu.cn/share/b0146640-9330-11ed-ac24-e105c3f0767c下载。

划分完毕后,需要用pytorch的ImageFolder类读取图片并进行处理,分别生成训练集和测试集,代码如下:

train_dataset = ImageFolder(root=os.path.join(data_dir, 'train'), transform=train_transform)

其中root参数为数据集目录,transform参数包含了对图片的处理。由于数据集中的图片尺寸不一,将图片送入网络处理前需要统一图片尺寸,即统一将图片大小设为128×128。另外训练过程中,为了提升模型在不同场景下的鲁棒性并防止过拟合,还需要对数据集进行增广操作,包括对图片进行仿射变换、透视变换、翻转、亮度改变等等操作。最后需要将图片变换成pytorch张量并进行归一化,将数据从[0,1]变换至[-128,127],从而能够在MAX78000上运行。这些对数据集的变换操作在pytorch中可以整合进一个变换函数中,代码如下:

train_transform = transforms.Compose([
            transforms.RandomAffine(degrees=10, translate=(0.05, 0.05), shear=5),
            transforms.RandomPerspective(distortion_scale=0.3, p=0.2),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(brightness=0.7),
            transforms.Resize((128, 128)),
            transforms.ToTensor(),
            ai8x.normalize(args=args),
            ])

以同样的方法生成测试集(测试集不使用数据增广操作)。完成的数据集生成代码如下(ai8x-training/datasets/pokemon.py):

import os

from torchvision import transforms
from torchvision.datasets import ImageFolder
import ai8x

def pokemon_get_datasets(data, load_train=True, load_test=True):
    (data_dir, args) = data
    
    if load_train:
        train_transform = transforms.Compose([
            transforms.RandomAffine(degrees=10, translate=(0.05, 0.05), shear=5),
            transforms.RandomPerspective(distortion_scale=0.3, p=0.2),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(brightness=0.7),
            transforms.Resize((128, 128)),
            transforms.ToTensor(),
            ai8x.normalize(args=args),
            ])
        
        train_dataset = ImageFolder(root=os.path.join(data_dir, 'train'), transform=train_transform)
    else:
        train_dataset = None
        
    if load_test:
        test_transform = transforms.Compose([
            transforms.Resize((128, 128)),
            transforms.ToTensor(),
            ai8x.normalize(args=args),
            ])
        
        test_dataset = ImageFolder(root=os.path.join(data_dir, 'test'), transform=test_transform)
    else:
        test_dataset = None
        
    return train_dataset, test_dataset
    

datasets = [
    {
        'name': 'pokemon',
        'input': (3, 128, 128),
        'output': ('Abra', 'Aerodactyl', 'Alakazam', 'Alolan Sandslash', 'Arbok', 'Arcanine', 'Articuno', 'Beedrill', 'Bellsprout', 'Blastoise', 'Bulbasaur', 'Butterfree', 'Caterpie', 'Chansey', 'Charizard', 'Charmander', 'Charmeleon', 'Clefable', 'Clefairy', 'Cloyster', 'Cubone', 'Dewgong', 'Diglett', 'Ditto', 'Dodrio', 'Doduo', 'Dragonair', 'Dragonite', 'Dratini', 'Drowzee', 'Dugtrio', 'Eevee', 'Ekans', 'Electabuzz', 'Electrode', 'Exeggcute', 'Exeggutor', 'Farfetchd', 'Fearow', 'Flareon', 'Gastly', 'Gengar', 'Geodude', 'Gloom', 'Golbat', 'Goldeen', 'Golduck', 'Golem', 'Graveler', 'Grimer', 'Growlithe', 'Gyarados', 'Haunter', 'Hitmonchan', 'Hitmonlee', 'Horsea', 'Hypno', 'Ivysaur', 'Jigglypuff', 'Jolteon', 'Jynx', 'Kabuto', 'Kabutops', 'Kadabra', 'Kakuna', 'Kangaskhan', 'Kingler', 'Koffing', 'Krabby', 'Lapras', 'Lickitung', 'Machamp', 'Machoke', 'Machop', 'Magikarp', 'Magmar', 'Magnemite', 'Magneton', 'Mankey', 'Marowak', 'Meowth', 'Metapod', 'Mew', 'Mewtwo', 'Moltres', 'MrMime', 'Muk', 'Nidoking', 'Nidoqueen', 'Nidorina', 'Nidorino', 'Ninetales', 'Oddish', 'Omanyte', 'Omastar', 'Onix', 'Paras', 'Parasect', 'Persian', 'Pidgeot', 'Pidgeotto', 'Pidgey', 'Pikachu', 'Pinsir', 'Poliwag', 'Poliwhirl', 'Poliwrath', 'Ponyta', 'Porygon', 'Primeape', 'Psyduck', 'Raichu', 'Rapidash', 'Raticate', 'Rattata', 'Rhydon', 'Rhyhorn', 'Sandshrew', 'Sandslash', 'Scyther', 'Seadra', 'Seaking', 'Seel', 'Shellder', 'Slowbro', 'Slowpoke', 'Snorlax', 'Spearow', 'Squirtle', 'Starmie', 'Staryu', 'Tangela', 'Tauros', 'Tentacool', 'Tentacruel', 'Vaporeon', 'Venomoth', 'Venonat', 'Venusaur', 'Victreebel', 'Vileplume', 'Voltorb', 'Vulpix', 'Wartortle', 'Weedle', 'Weepinbell', 'Weezing', 'Wigglytuff', 'Zapdos', 'Zubat'),
        'loader': pokemon_get_datasets,
    },
]

4.2 模型训练及部署

MAX78000的模型训练及部署分为四步,分别为模型训练、网络参数量化、代码生成和代码下载。下面将对这四步的具体流程以及模型的选择思路进行介绍。

4.2.1 模型选择

本项目选择的模型由SDK中自带的ai85net-simplenet-wide2x模型(下称“原模型”)修改而来(修改模型下称“宝可梦模型”)。原模型被用于cifar100分类任务上,输入为尺寸32×32×3的RGB图像,输出为该图像属于100个类中各类的概率,共包含了14个卷积层和5个最大池化层,除最后一个卷积层外,每个卷积层后还包含批归一化操作和ReLU激活层,最后输出100维的one-hot向量经过softmax,作为分类结果。

由于宝可梦分类与cifar100分类任务类似,都具有较多的类别数,因此对原模型进行简单修改即可应用到宝可梦分类任务中。宝可梦模型的输入被修改为128×128×3的RGB图像,输出为150维的向量,同时添加了两个最大池化层来保证最后一层的输出向量只有channel维度。宝可梦模型架构如图4‑3所示。

NSMAAAAASUVORK5CYII=

图4‑3 宝可梦模型架构

4.2.2 模型训练

模型训练参数设置为150epoch,batchsize128,学习率0.001,优化器为Adam,并随着训练进度阶梯式调整学习率(learning rate schedule)来使模型得到充分训练。此外,为了使模型能在MAX78000上部署,还采用了量化训练(quantization-aware training),来防止由于网络参数的量化造成的性能损失。

训练环境为Ubuntu 20.04,使用2张NVIDIA RTX 3090显卡进行训练,训练时间约为1小时。

训练脚本如下(ai8x-training/scripts/train_pokemon.sh):

python train.py --epochs 150 --optimizer Adam --lr 0.001 --wd 0 --compress policies/schedule-pokemon.yaml --model ai85pokemon --dataset pokemon --device MAX78000 --batch-size 128 --print-freq 10 --validation-split 0 --qat-policy policies/qat_pokemon.yaml --use-bias "$@" --data PokemonData_split --gpus 0,2

learning rate schedule配置文件如下(ai8x-training/policies/schedule-pokemon.yaml):

---
lr_schedulers:
  training_lr:
    class: MultiStepLR
    milestones: [40, 60, 80]
    gamma: 0.25

policies:
  - lr_scheduler:
      instance_name: training_lr
    starting_epoch: 0
    ending_epoch: 150
    frequency: 1

量化训练配置文件如下(ai8x-training/policies/qat_pokemon.yaml):

---
start_epoch: 50
shift_quantile: 0.985
weight_bits: 2
overrides:
  conv1:
    weight_bits: 8
  conv2:
    weight_bits: 4
  conv11:
    weight_bits: 4
  conv12:
    weight_bits: 4
  conv13:
    weight_bits: 4
  conv14:
    weight_bits: 4

训练过程的输出如下:

2023-01-12 17:22:41,240 - Training epoch: 5389 samples (128 per mini-batch)
2023-01-12 17:22:50,182 - Epoch: [0][   10/   43]    Overall Loss 5.005712    Objective Loss 5.005712                                        LR 0.001000    Time 0.894080    
2023-01-12 17:22:54,505 - Epoch: [0][   20/   43]    Overall Loss 4.996409    Objective Loss 4.996409                                        LR 0.001000    Time 0.663122    
2023-01-12 17:22:59,222 - Epoch: [0][   30/   43]    Overall Loss 4.975625    Objective Loss 4.975625                                        LR 0.001000    Time 0.599292    
2023-01-12 17:23:02,516 - Epoch: [0][   40/   43]    Overall Loss 4.951369    Objective Loss 4.951369                                        LR 0.001000    Time 0.531800    
2023-01-12 17:23:04,245 - Epoch: [0][   43/   43]    Overall Loss 4.942789    Objective Loss 4.942789    Top1 5.673759    Top5 18.439716    LR 0.001000    Time 0.534906    
2023-01-12 17:23:04,372 - --- validate (epoch=0)-----------
2023-01-12 17:23:04,373 - 1431 samples (128 per mini-batch)
2023-01-12 17:23:08,012 - Epoch: [0][   10/   12]    Loss 5.009361    Top1 0.781250    Top5 3.515625    
2023-01-12 17:23:08,200 - Epoch: [0][   12/   12]    Loss 5.009094    Top1 0.768693    Top5 3.354298    
2023-01-12 17:23:08,391 - ==> Top1: 0.769    Top5: 3.354    Loss: 5.009

2023-01-12 17:23:08,394 - ==> Best [Top1: 0.769   Top5: 3.354   Sparsity:0.00   Params: 704560 on epoch: 0]
2023-01-12 17:23:08,394 - Saving checkpoint to: logs/2023.01.12-172236/checkpoint.pth.tar
2023-01-12 17:23:08,433 - 

2023-01-12 17:23:08,433 - Training epoch: 5389 samples (128 per mini-batch)
2023-01-12 17:23:15,396 - Epoch: [1][   10/   43]    Overall Loss 4.758103    Objective Loss 4.758103                                        LR 0.001000    Time 0.696099    
2023-01-12 17:23:18,252 - Epoch: [1][   20/   43]    Overall Loss 4.716858    Objective Loss 4.716858                                        LR 0.001000    Time 0.490836    
2023-01-12 17:23:23,579 - Epoch: [1][   30/   43]    Overall Loss 4.676186    Objective Loss 4.676186                                        LR 0.001000    Time 0.504744    
2023-01-12 17:23:27,941 - Epoch: [1][   40/   43]    Overall Loss 4.631547    Objective Loss 4.631547                                        LR 0.001000    Time 0.487583    
2023-01-12 17:23:29,851 - Epoch: [1][   43/   43]    Overall Loss 4.622380    Objective Loss 4.622380    Top1 3.546099    Top5 19.148936    LR 0.001000    Time 0.497972    
2023-01-12 17:23:30,075 - --- validate (epoch=1)-----------
2023-01-12 17:23:30,075 - 1431 samples (128 per mini-batch)
2023-01-12 17:23:32,773 - Epoch: [1][   10/   12]    Loss 4.983619    Top1 0.781250    Top5 6.171875    
2023-01-12 17:23:32,852 - Epoch: [1][   12/   12]    Loss 4.982425    Top1 0.978337    Top5 6.429071    
2023-01-12 17:23:33,028 - ==> Top1: 0.978    Top5: 6.429    Loss: 4.982

2023-01-12 17:23:33,031 - ==> Best [Top1: 0.978   Top5: 6.429   Sparsity:0.00   Params: 704560 on epoch: 1]
2023-01-12 17:23:33,031 - Saving checkpoint to: logs/2023.01.12-172236/checkpoint.pth.tar
2023-01-12 17:23:33,111 - 

2023-01-12 17:23:33,111 - Training epoch: 5389 samples (128 per mini-batch)
2023-01-12 17:23:39,162 - Epoch: [2][   10/   43]    Overall Loss 4.390907    Objective Loss 4.390907                                        LR 0.001000    Time 0.604864    
2023-01-12 17:23:42,430 - Epoch: [2][   20/   43]    Overall Loss 4.359883    Objective Loss 4.359883                                        LR 0.001000    Time 0.465793    
2023-01-12 17:23:47,960 - Epoch: [2][   30/   43]    Overall Loss 4.316997    Objective Loss 4.316997                                        LR 0.001000    Time 0.494837    
2023-01-12 17:23:52,032 - Epoch: [2][   40/   43]    Overall Loss 4.291897    Objective Loss 4.291897                                        LR 0.001000    Time 0.472925    
2023-01-12 17:23:52,534 - Epoch: [2][   43/   43]    Overall Loss 4.285234    Objective Loss 4.285234    Top1 11.347518    Top5 31.914894    LR 0.001000    Time 0.451596    
2023-01-12 17:23:52,747 - --- validate (epoch=2)-----------
2023-01-12 17:23:52,747 - 1431 samples (128 per mini-batch)
2023-01-12 17:23:55,682 - Epoch: [2][   10/   12]    Loss 4.213376    Top1 8.125000    Top5 26.015625    
2023-01-12 17:23:55,758 - Epoch: [2][   12/   12]    Loss 4.185180    Top1 8.315863    Top5 26.624738    
2023-01-12 17:23:55,943 - ==> Top1: 8.316    Top5: 26.625    Loss: 4.185

2023-01-12 17:23:55,945 - ==> Best [Top1: 8.316   Top5: 26.625   Sparsity:0.00   Params: 704560 on epoch: 2]
2023-01-12 17:23:55,946 - Saving checkpoint to: logs/2023.01.12-172236/checkpoint.pth.tar
2023-01-12 17:23:56,012 -

…

2023-01-12 18:19:20,585 - Training epoch: 5389 samples (128 per mini-batch)
2023-01-12 18:19:26,638 - Epoch: [149][   10/   43]    Overall Loss 0.264605    Objective Loss 0.264605                                        LR 0.000016    Time 0.605087    
2023-01-12 18:19:30,181 - Epoch: [149][   20/   43]    Overall Loss 0.263862    Objective Loss 0.263862                                        LR 0.000016    Time 0.479651    
2023-01-12 18:19:35,059 - Epoch: [149][   30/   43]    Overall Loss 0.252918    Objective Loss 0.252918                                        LR 0.000016    Time 0.482333    
2023-01-12 18:19:40,680 - Epoch: [149][   40/   43]    Overall Loss 0.242886    Objective Loss 0.242886                                        LR 0.000016    Time 0.502244    
2023-01-12 18:19:41,724 - Epoch: [149][   43/   43]    Overall Loss 0.238903    Objective Loss 0.238903    Top1 97.163121    Top5 100.000000    LR 0.000016    Time 0.491487    
2023-01-12 18:19:41,933 - --- validate (epoch=149)-----------
2023-01-12 18:19:41,934 - 1431 samples (128 per mini-batch)
2023-01-12 18:19:45,018 - Epoch: [149][   10/   12]    Loss 1.044160    Top1 75.625000    Top5 89.609375    
2023-01-12 18:19:45,113 - Epoch: [149][   12/   12]    Loss 1.018712    Top1 75.821104    Top5 89.797345    
2023-01-12 18:19:45,334 - ==> Top1: 75.821    Top5: 89.797    Loss: 1.019

2023-01-12 18:19:45,339 - ==> Best [Top1: 76.171   Top5: 90.496   Sparsity:0.00   Params: 704560 on epoch: 141]
2023-01-12 18:19:45,339 - Saving checkpoint to: logs/2023.01.12-172236/qat_checkpoint.pth.tar
2023-01-12 18:19:45,393 - --- test ---------------------
2023-01-12 18:19:45,394 - 1431 samples (128 per mini-batch)
2023-01-12 18:19:48,331 - Test: [   10/   12]    Loss 1.029740    Top1 76.171875    Top5 89.843750    
2023-01-12 18:19:48,432 - Test: [   12/   12]    Loss 1.021020    Top1 75.821104    Top5 89.797345    
2023-01-12 18:19:48,632 - ==> Top1: 75.821    Top5: 89.797    Loss: 1.021

可以看到,经过训练在测试集上的Top1准确率可以达到76%,Top5准确率可以达到90%,说明网络可以对大部分的宝可梦图片进行有效的识别。

4.2.3 参数量化

训练完毕后,对模型参数进行量化,使其从浮点数转换为整数。量化操作的脚本如下(ai8x-synthesis/scripts/quantize_pokemon.sh):

python quantize.py ../ai8x-training/logs/2023.01.12-172236/qat_best.pth.tar trained/pokemon-q.pth.tar --device MAX78000 -v -c networks/pokemon.yaml "$@"

该脚本以训练过程中保存的网络参数(qat_best.pth.tar)为输入,输出量化后的网络参数(pokemon-q.pth.tar),还需要指定描述网络的yaml文件(networks/pokemon.yaml),该文件的编写方法可以参考https://github.com/MaximIntegratedAI/MaximAI_Documentation/blob/master/Guides/YAML%20Quickstart.md,其内容如下:

arch: ai85pokemon
dataset: pokemon

layers:
  - out_offset: 0x2000
    processors: 0x0000000000000007  # 1
    operation: conv2d
    kernel_size: 3x3
    pad: 1
    activate: ReLU
    data_format: HWC
    streaming: true
  - max_pool: 2
    pool_stride: 2
    pad: 1
    operation: conv2d
    kernel_size: 3x3
    activate: ReLU
    out_offset: 0x0000
    processors: 0x00000000000ffff0  # 2
    streaming: true
  - out_offset: 0x2000
    processors: 0x0ffffffff0000000  # 3
    operation: conv2d
    kernel_size: 3x3
    pad: 1
    activate: ReLU
  - out_offset: 0x0000
    processors: 0x00000000ffffffff  # 4
    operation: conv2d
    kernel_size: 3x3
    pad: 1
    activate: ReLU
  - max_pool: 2
    pool_stride: 2
    pad: 1
    operation: conv2d
    kernel_size: 3x3
    activate: ReLU
    out_offset: 0x2000
    processors: 0xffffffff00000000  # 5
  - out_offset: 0x0000
    processors: 0x00000000ffffffff  # 6
    operation: conv2d
    kernel_size: 3x3
    pad: 1
    activate: ReLU
  - out_offset: 0x2000
    processors: 0xffffffff00000000  # 7
    operation: conv2d
    kernel_size: 3x3
    pad: 1
    activate: ReLU
  - out_offset: 0x0000
    processors: 0x00000000ffffffff  # 8
    operation: conv2d
    kernel_size: 3x3
    pad: 1
    activate: ReLU
  - max_pool: 2
    pool_stride: 2
    pad: 1
    operation: conv2d
    kernel_size: 3x3
    activate: ReLU
    out_offset: 0x2000
    processors: 0xffffffffffffffff  # 9
  - max_pool: 2
    pool_stride: 2
    pad: 1
    operation: conv2d
    kernel_size: 3x3
    activate: ReLU
    out_offset: 0x0000
    processors: 0xffffffffffffffff  # 10
  - max_pool: 2
    pool_stride: 2
    pad: 0
    operation: conv2d
    kernel_size: 1x1
    activate: ReLU
    out_offset: 0x2000
    processors: 0xffffffffffffffff  # 11
  - max_pool: 2
    pool_stride: 2
    out_offset: 0x0000
    processors: 0xffffffffffffffff  # 12
    operation: conv2d
    kernel_size: 1x1
    pad: 0
    activate: ReLU
  - max_pool: 2
    pool_stride: 2
    pad: 1
    operation: conv2d
    kernel_size: 3x3
    activate: ReLU
    out_offset: 0x2000
    processors: 0xffffffffffffffff  # 13
  - out_offset: 0x0000
    processors: 0xffffffffffffffff  # 14
    operation: conv2d
    kernel_size: 1x1
    pad: 0
    output_width: 32
    activate: None

4.2.4 代码生成与下载

模型量化后的下一步是生成可以在MAX78000上运行的c语言代码,生成该项目主要需要三个文件:

(1) 量化后的模型(trained/pokemon-q.pth.tar)

(2) 网络描述文件(networks/pokemon.yaml)

(3) 输入样本文件(sample.npy)

其中输入样本文件可以由ai8x-synthesis/tests/convert_sample.py脚本通过已有图片生成,或由ai8x-synthesis/tests/make_sample.py脚本随机生成。该样本文件只有测试和检查网络输出的功能,它的内容对本项目没有影响。

生成操作的脚本如下(ai8x-synthesis/scripts/gen_pokemon.sh):

python ai8xize.py --test-dir "sdk/Examples/MAX78000/CNN" --prefix pokemon --checkpoint-file trained/pokemon-q.pth.tar --config-file networks/pokemon.yaml --fifo --softmax --device MAX78000 --timer 0 --display-checkpoint --verbose --compact-data --mexpress --sample-input 'sample.npy' --boost 2.5 "$@"

之后即可在sdk/Examples/MAX78000/CNN/pokemon目录下找到自动生成的项目,该项目可以直接下载到MAX78000中,功能为将输入样本送给网络得到输出,并通过串口打印结果。

至此,我们得到了可以部署在MAX78000上的宝可梦模型,完成了宝可梦分类的任务。但该模型只会对固定的输入样本进行分类,不能提供宝可梦的其他信息,也没有一个友好的UI界面。下面的部分将介绍宝可梦图鉴的图像采集、SD卡数据读取和信息显示部分,来解决以上两个问题。

4.3 图像采集

为实现宝可梦图鉴使用的便捷性和识别的实时性,待分类的宝可梦图片应该由MAX78000的板载摄像头直接采集,而非像上一节那样写死进代码中,每次识别都要重新下载代码。

图像采集部分的代码主要参考了官方例程cats-dogs_demo,例程包含了相机模块初始化、拍照、图像处理和导入模型等步骤。下面为图像处理的相关代码(capture_process_camera()函数),包含逐行读取图像、将图像转换为0x00bbggrr格式提供给模型作为输入、将图像转换为RGB565格式提供给TFT显示屏模块进行显示等操作。

    // Get image line by line
    for (int row = 0; row < h; row++) {
        // Wait until camera streaming buffer is full
        while ((data = get_camera_stream_buffer()) == NULL) {
            if (camera_is_image_rcv()) {
                break;
            }
        }

        j = 0;
        for (int k = 0; k < 4 * w; k += 4) {
            // data format: 0x00bbggrr
            r = data[k];
            g = data[k + 1];
            b = data[k + 2];
            //skip k+3

            // change the range from [0,255] to [-128,127] and store in buffer for CNN
            input_0[cnt++] = ((b << 16) | (g << 8) | r) ^ 0x00808080;

            // convert to RGB565 for display
            rgb = ((r & 0b11111000) << 8) | ((g & 0b11111100) << 3) | (b >> 3);
            data565[j] = (rgb >> 8) & 0xFF;
            data565[j + 1] = rgb & 0xFF;
            j += 2;
        }
#ifdef TFT_ENABLE
        MXC_TFT_ShowImageCameraRGB565(CAM_X_START, CAM_Y_START + row, data565, w, 1);
#endif
        // Release stream buffer
        release_camera_stream_buffer();
    }

    //camera_sleep(1);
    stat = get_camera_stream_statistic();

    if (stat->overflow_count > 0) {
        printf("OVERFLOW DISP = %d\n", stat->overflow_count);
        LED_On(LED2); // Turn on red LED if overflow detected
        while (1) {}
    }

图像处理完毕后,利用cnn_load_input()函数将图像导入进模型中:

void cnn_load_input(void)
{
    int i;
    const uint32_t *in0 = input_0;

    for (i = 0; i < IMAGE_SIZE_X * IMAGE_SIZE_Y; i++) {
        // Remove the following line if there is no risk that the source would overrun the FIFO:
        while (((*((volatile uint32_t *)0x50000004) & 1)) != 0) {}
        // Wait for FIFO 0
        *((volatile uint32_t *)0x50000008) = *in0++; // Write FIFO 0
    }
}

4.4 SD卡数据读取

本项目利用SD卡储存了宝可梦的图片、属性、身高、体重信息,这些信息在宝可梦识别后或在浏览模式中可以被读取并显示,从而丰富图鉴内容。SD卡为可选模块,如果没有SD卡,程序也会运行,不过识别后只会显示宝可梦名称和编号,而不会显示上述只储存在SD卡中的信息。此外没有SD卡图鉴的浏览模式也将不可用。下面将对这些数据的准备流程和MAX78000读取数据的流程进行介绍。

4.4.1 数据准备

SD卡根目录包含了若干以宝可梦名称命名的文件夹,每个文件夹包含data.bin和image.bmp两个文件,分别储存了宝可梦的数据和图片,其结构如下:

├───Abra

│ data.bin

│ image.bmp

├───Aerodactyl

│ data.bin

│ image.bmp

├───Alakazam

│ data.bin

│ image.bmp

SD卡中的宝可梦图片来源于Kaggle上的Pokemon Image Dataset(网址:https://www.kaggle.com/datasets/vishalsubbiah/pokemon-images-and-types),它包含了所有宝可梦的png格式缩略图(尺寸为120×120),每种宝可梦仅一张图片。注意该图片集并非模型训练用到的训练集,SD卡中的图片仅有显示作用,和模型训练和测试均无关联。由于png格式为压缩后的图像格式,不能在MAX78000中直接显示,我们需要把这些图片转换为未压缩的bmp格式后才能够在MAX78000中使用。

宝可梦的属性、身高、体重数据来源于Kaggle上的All Pokemon Dataset(网址:https://www.kaggle.com/datasets/maca11/all-pokemon-dataset),我们可以用python的pandas包对这些数据进行处理,筛选出我们需要的数据,并以二进制格式进行保存。data.bin的数据格式如下:

字节

内容

类型

0

编码后的第一属性

uint8

1

编码后的第二属性

uint8

2~5

身高(m)

float

6~9

体重(kg)

float

属性的编码方式如下:

属性

编码

属性

编码

属性

编码

无属性

0

冰 'Ice'

7

岩石 'Rock'

14

电 'Electric'

1

钢 'Steel'

8

幽灵 'Ghost'

15

水 'Water'

2

地面 'Ground'

9

草 'Grass'

16

火 'Fire'

3

格斗 'Fighting'

10

超能力 'Psychic'

17

飞行 'Flying'

4

龙 'Dragon'

11

毒 'Poison'

18

恶 'Dark'

5

虫 'Bug'

12

   

妖精 'Fairy'

6

一般 'Normal'

13

   

4.4.2 读取数据

一次完整的SD卡数据的读取流程包含检查SD卡可用性、挂载SD卡、打开文件、读取文件和关闭文件这几步,下面的函数均定义在文件sd.c中。

(1) 检查SD卡可用性

当有SD卡插入时,MAX78000的P0.12引脚会输出低电平,所以读取P0.12引脚电平即可判断SD卡是否可用。

uint8_t checkCardInserted(void)
{
    // On the MAX78000FTHR board, P0.12 will be pulled low when a card is inserted.
    mxc_gpio_cfg_t cardDetect;
    cardDetect.port = MXC_GPIO0;
    cardDetect.mask = MXC_GPIO_PIN_12;
    cardDetect.func = MXC_GPIO_FUNC_IN;
    cardDetect.pad = MXC_GPIO_PAD_NONE;
    cardDetect.vssel = MXC_GPIO_VSSEL_VDDIOH;

    MXC_GPIO_Config(&cardDetect);

    // Exit function if card is already inserted
    if (MXC_GPIO_InGet(MXC_GPIO0, MXC_GPIO_PIN_12) == 0) {
        return 1;
    }

    return 0;
}

(2) 挂载SD卡

int mount()
{    
    fs = &fs_obj;

    if ((err = f_mount(fs, "", 1)) != FR_OK) { //Mount the default drive to fs now
        printf("Error opening SD card: %s\n", FF_ERRORS[err]);
        f_mount(NULL, "", 0);
    } else {
        printf("SD card mounted.\n");
        mounted = 1;
    }

    f_getcwd(cwd, sizeof(cwd)); //Set the Current working directory

    return err;
}

(3) 打开文件

    if ((err = f_open(&file, filename, FA_READ)) != FR_OK) {
        printf("Error opening file: %s\n", FF_ERRORS[err]);
        f_mount(NULL, "", 0);
        return err;
    }

(4) 读取文件

读取data.bin的代码如下,其中用到了库函数f_read,该函数有4个参数,第1个参数为文件指针,第2个参数为读取数据的存放地址,第3个参数为想要读取的字节数,第4个参数为已读取的字节数的地址。按照前面规定的数据格式,分别读取1、1、4、4个字节,即可得到第一/第二属性、身高和体重信息。

    if ((err = f_read(&file, type1, 1, &bytes_read)) != FR_OK) {
        printf("Error reading file: %s\n", FF_ERRORS[err]);
        f_mount(NULL, "", 0);
        return err;
    }

    if ((err = f_read(&file, type2, 1, &bytes_read)) != FR_OK) {
        printf("Error reading file: %s\n", FF_ERRORS[err]);
        f_mount(NULL, "", 0);
        return err;
    }

    if ((err = f_read(&file, height, 4, &bytes_read)) != FR_OK) {
        printf("Error reading file: %s\n", FF_ERRORS[err]);
        f_mount(NULL, "", 0);
        return err;
    }

    if ((err = f_read(&file, weight, 4, &bytes_read)) != FR_OK) {
        printf("Error reading file: %s\n", FF_ERRORS[err]);
        f_mount(NULL, "", 0);
        return err;
    }

读取image.bmp的代码如下。根据bmp的数据格式,一般第54字节为图片数据的起始位置,因此利用f_lseek将光标移动至第54个字节再调用f_read读取至文件末尾即可得到图片的RGB值。

    uint8_t *ptr = rgb_image;
 
    if ((err = f_lseek(&file, 54)) != FR_OK) {
        printf("Error seeking file: %s\n", FF_ERRORS[err]);
        f_mount(NULL, "", 0);
        return err;
    }

    do
    {
        if ((err = f_read(&file, ptr, 255, &bytes_read)) != FR_OK) {
            printf("Error reading file: %s\n", FF_ERRORS[err]);
            f_mount(NULL, "", 0);
            return err;
        }
        ptr += 255;
    } while (bytes_read == 255);

(5) 关闭文件

    if ((err = f_close(&file)) != FR_OK) {
        printf("Error closing file: %s\n", FF_ERRORS[err]);
        f_mount(NULL, "", 0);
        return err;
    }

完整的数据读取(read_data())和图片读取(read_bmp())函数如下:

int read_data(char filename[], uint8_t *type1, uint8_t *type2, float *height, float *weight)
{
    FIL file; //FFat File Object
    UINT bytes_read;

    if (!mounted)
        mount();

    if ((err = f_open(&file, filename, FA_READ)) != FR_OK) {
        printf("Error opening file: %s\n", FF_ERRORS[err]);
        f_mount(NULL, "", 0);
        return err;
    }

    if ((err = f_read(&file, type1, 1, &bytes_read)) != FR_OK) {
        printf("Error reading file: %s\n", FF_ERRORS[err]);
        f_mount(NULL, "", 0);
        return err;
    }

    if ((err = f_read(&file, type2, 1, &bytes_read)) != FR_OK) {
        printf("Error reading file: %s\n", FF_ERRORS[err]);
        f_mount(NULL, "", 0);
        return err;
    }

    if ((err = f_read(&file, height, 4, &bytes_read)) != FR_OK) {
        printf("Error reading file: %s\n", FF_ERRORS[err]);
        f_mount(NULL, "", 0);
        return err;
    }

    if ((err = f_read(&file, weight, 4, &bytes_read)) != FR_OK) {
        printf("Error reading file: %s\n", FF_ERRORS[err]);
        f_mount(NULL, "", 0);
        return err;
    }

    if ((err = f_close(&file)) != FR_OK) {
        printf("Error closing file: %s\n", FF_ERRORS[err]);
        f_mount(NULL, "", 0);
        return err;
    }

    return 0;
}

int read_bmp(char filename[], uint8_t rgb_image[])
{
    FIL file; //FFat File Object
    UINT bytes_read;
    uint8_t *ptr = rgb_image;
    // int height = 60, width = 60;
    
    if (!mounted)
        mount();

    if ((err = f_open(&file, filename, FA_READ)) != FR_OK) {
        printf("Error opening file: %s\n", FF_ERRORS[err]);
        f_mount(NULL, "", 0);
        return err;
    }

    if ((err = f_lseek(&file, 54)) != FR_OK) {
        printf("Error seeking file: %s\n", FF_ERRORS[err]);
        f_mount(NULL, "", 0);
        return err;
    }

    do
    {
        if ((err = f_read(&file, ptr, 255, &bytes_read)) != FR_OK) {
            printf("Error reading file: %s\n", FF_ERRORS[err]);
            f_mount(NULL, "", 0);
            return err;
        }
        ptr += 255;
    } while (bytes_read == 255);

    if ((err = f_close(&file)) != FR_OK) {
        printf("Error closing file: %s\n", FF_ERRORS[err]);
        f_mount(NULL, "", 0);
        return err;
    }

    return 0;
}

4.5 显示模块

4.5.1 硬件连接

本项目使用的显示屏为ILI9341驱动的TFT LCD显示屏,分辨率为320×240,使用SPI协议与MAX78000进行通讯。LCD的连接参考群友HonestQiao分享的TFT LCD使用心得,在此感谢这位群友的分享。

日程表  描述已自动生成

图4‑4 LCD连线图

4.5.2 图片显示

显示屏上显示的图片包含摄像头拍摄的图片和从SD卡读取的分类结果图片,二者均需要转换成RGB565格式并使用库函数MXC_TFT_ShowImageCameraRGB565()来显示。摄像头图片的格式转换在前面的小节已经介绍过了,下面是SD卡读取的图片转换为RGB565格式并进行显示的代码。要注意的是bmp格式储存的图片是行倒序的,显示时要逐行反向输出。

void display_img(int x, int y)
{
    int j, cnt = 0;
    uint16_t rgb;
    uint8_t r, g, b;
    

    for (int row = DISP_SIZE_Y - 1; row >= 0; row--) {
        j = 0;
        for (int col = 0; col < DISP_SIZE_X; col++) {
            // convert to RGB565 for display
            b = disp_img[cnt++];
            g = disp_img[cnt++];
            r = disp_img[cnt++];
            rgb = ((r & 0b11111000) << 8) | ((g & 0b11111100) << 3) | (b >> 3);
            data565[j] = (rgb >> 8) & 0xFF;
            data565[j + 1] = rgb & 0xFF;
            j += 2;
        }

        MXC_TFT_ShowImageCameraRGB565(x, y + row, data565, DISP_SIZE_X, 1);
    }
}

4.5.3 数据显示

显示屏上显示的数据包括分类结果(宝可梦名称)、宝可梦编号、SD卡中读取的属性、身高、体重信息和Top5分类结果。可以定义TFT_Print()函数来把这些字符串显示在屏幕上。

void TFT_Print(char *str, int x, int y, int font, int length)
{
    // fonts id
    text_t text;
    text.data = str;
    text.len = length;
    MXC_TFT_PrintFont(x, y, font, &text, NULL);
}

调用:

        // Display Pokemon name
        TFT_Print(buff, TEXT_X_START, TEXT_Y_START, font_1,
                      snprintf(buff, sizeof(buff), "%s", pkm_names[num-1]));
        // Display number
        TFT_Print(buff, TEXT_X_START, TEXT_Y_START + 25, font_1,
                    snprintf(buff, sizeof(buff), "#%03d", num == 152 ? 28 : num));  // 152->28 for Alolan Sandslash
        // Display Top5 classification results
        TFT_Print(buff, TEXT_X_START, TEXT_Y_START + 150, font_2,
                snprintf(buff, sizeof(buff), "Top5:"));
        for (i = 0; i < 5; i++)
            TFT_Print(buff, TEXT_X_START, TEXT_Y_START + 165 + 15 * i, font_2,
                    snprintf(buff, sizeof(buff), "%s (%d%%)", pkm_names[class2num[top_5_index[i]]-1], result[top_5_index[i]]));
        // Display type 1
        MXC_TFT_SetForeGroundColor(types_color[type1]);
        TFT_Print(buff, TEXT_X_START, TEXT_Y_START + 50, font_1,
                    snprintf(buff, sizeof(buff), "%s", types[type1]));
        // Display type 2
        MXC_TFT_SetForeGroundColor(types_color[type2]);
        TFT_Print(buff, TEXT_X_START, TEXT_Y_START + 75, font_1,
                    snprintf(buff, sizeof(buff), "%s", types[type2]));
        // Display height and weight
        MXC_TFT_SetForeGroundColor(WHITE);
        TFT_Print(buff, TEXT_X_START, TEXT_Y_START + 100, font_1,
                    snprintf(buff, sizeof(buff), "Height: %.2fm", height));
        TFT_Print(buff, TEXT_X_START, TEXT_Y_START + 125, font_1,
                    snprintf(buff, sizeof(buff), "Weight: %.2fkg", weight));
        display_img(DISP_X_START, DISP_Y_START);

5 遇到的问题

5.1 识别准确率问题

由于MAX78000摄像头拍摄到的画面颜色相比数据集中的图片和实际画面有较大差异,并且模型主要以颜色为判断依据而非轮廓,所以模型有时无法准确识别这些图片中的宝可梦。尤其是拍摄到的图片偏白时,模型很容易将其识别为白海狮(Dewgong)或小海狮(Seel),因为这些宝可梦几乎全身都是白色的。如图5‑1所示,拍摄到的宝可梦为#151梦幻(Mew),但模型将其识别成了#087白海狮(Dewgong),不过Mew仍然在识别结果的Top5中。

改进方法有以下几种:

(1) 对相机拍摄到的图片进行处理来减少偏色,如直方图均衡;

(2) 搜集同一宝可梦不同颜色的图片(如素描画、或游戏中的异色宝可梦),或者在数据增广过程中将一部分图片变为灰度图片,让模型结合颜色和轮廓给出判断;

(3) 调整网络架构,将模型分为特征提取器(feature extractor)和分类器(classifier)两个串联的网络,并寻找更合适的特征提取网络从而得到更适合分类的特征。但由于已有的特征提取网络都较大,MAX78000的资源有限,因此这种方法难以在MAX78000上实施。

9k=

图5‑1 识别错误的情况

5.2 输入尺寸问题

宝可梦模型的输入尺寸为128×128×4,超过了MAX78000的限制(32768字节),在生成代码过程中会产生错误:ERROR: Layer 0: 4 channels/word 128x128 input (size 65536) with input offset 0x0000 and expansion 1x exceeds data memory instance size of 32768. 查阅网络描述文档(https://github.com/MaximIntegratedAI/MaximAI_Documentation/blob/master/Guides/YAML%20Quickstart.md)可知可以在尺寸较大的层添加streaming: true选项,流式加载数据而非一次性加载,即可解决输入尺寸过大的问题。

6 总结

本次活动板卡较为独特,是一个功能齐全的嵌入式AI开发板。作为一名嵌入式爱好者、AI领域的一名研究生和宝可梦迷,这个板卡完美的将我的爱好和专业融合在了一起,再加上我一直以来实现游戏与动画中宝可梦图鉴的愿望,于是就诞生了这个基于MAX78000的宝可梦图鉴。本次活动的板卡官方提供的说明文档、例程非常丰富,上手较为容易,美中不足的就是计算资源略微紧张,不能支持较大的模型,不过这也是嵌入式AI的魅力所在哈哈哈。最后感谢硬禾学堂和ADI公司举办本次活动!

 

划分后的数据集:https://rec.ustc.edu.cn/share/b0146640-9330-11ed-ac24-e105c3f0767c

 

附件下载
pokedex.zip
MAX78000 VSCode工程文件
training.zip
训练文件
sd.zip
SD卡文件
团队介绍
中国科学技术大学学生
团队成员
氢化脱氯次氯酸
评论
0 / 100
查看更多
目录
硬禾服务号
关注最新动态
0512-67862536
info@eetree.cn
江苏省苏州市苏州工业园区新平街388号腾飞创新园A2幢815室
苏州硬禾信息科技有限公司
Copyright © 2024 苏州硬禾信息科技有限公司 All Rights Reserved 苏ICP备19040198号