MAX78000 Season 2: Chinese Zodiac Detection
MAX78000FTHR · machine learning · images · training · Python · training set · test set · image cleaning
Tags
embedded systems
testing
development board
MAX78000
aramy
Updated 2025-03-24

The MAX78000FTHR is a small module (66 × 23 mm) built around the MAX78000. It carries a micro VGA camera, a digital microphone, stereo audio I/O, a microSD card slot, 1 MB of QSPI RAM, a USB port for the on-board SWD debugger/programmer, and a lithium-polymer battery charger. There are also two user RGB LEDs, two user buttons, and expansion connectors compatible with the Adafruit Feather form factor. A separate JTAG connector can be used to program and debug the RISC-V core.

Project introduction: For Season 2 of the MAX78000 AI application design contest, I planned to use the MAX78000 development board for image recognition. The chosen task is item 1 in the vision category: Chinese zodiac detection.
The basic workflow has two parts:
1. Collect pictures of the twelve zodiac animals, train a recognition model, and convert it into model data the MAX78000 can use.
2. On the development board, load the model, capture images of the surroundings with the on-board camera, run image recognition, and output the result.

Project design: the work splits into two parts, collecting data and training the model, and programming the development board. Let's start with data collection and model training.

  1. Setting up the training environment: my PC runs Windows 10 with an RTX 2060 GPU. Training depends on a Python environment, so I used Anaconda to create a dedicated environment for this project:
    conda create -n pytorch python=3.8.2

    Note that the Python version apparently has to be 3.8; it cannot be too new. I tried 3.10, and some of the packages later failed to install.
    Then install PyTorch, which is straightforward: visit the official site and choose the build matching your machine.
    Finally, following the documentation at https://github.com/MaximIntegratedAI, copy the ai8x-synthesis and ai8x-training folders to your machine. Each contains a requirements-win-cu11.txt file; install the required packages with:

    pip install -r requirements-win-cu11.txt
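
    After installation, a quick sanity check (a minimal sketch, nothing project-specific) confirms that PyTorch sees the GPU:

    import torch

    # Verify the freshly created environment: PyTorch version and CUDA visibility.
    print(torch.__version__)
    print(torch.cuda.is_available())          # expect True with the RTX 2060
    if torch.cuda.is_available():
        print(torch.cuda.get_device_name(0))  # e.g. "NVIDIA GeForce RTX 2060"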
  2. Data collection: the most laborious step in machine learning. Since the task is zodiac image recognition, I needed a fair number of pictures of these twelve animals for the training and validation sets. Rather than collect them myself, I used the twelve-zodiac image set published by PaddlePaddle (飞桨). It contains 12 classes and 8508 images, pre-split into training, validation, and test sets at a ratio of 85 : 7.5 : 7.5.
    My project only needs a training set and a validation set, so I merged the dataset's validation and test images into a single validation set (a sketch of this merge follows the cleaning script below). A few of the images are corrupt and cannot be opened; also, some zodiac-stamp pictures from the web had ended up in the training set with a PNG extension. So I ran one cleaning pass over the raw images: delete any image that fails to open or opens abnormally, and normalize every extension to "jpg".
    import os
    import cv2 as cv
    import numpy as np
    
    sourimgpath = "E:\\MAX78000\\2023\\ai8x-training\\data\\zodiac"
    desimgpath = "..\\img"
    target_size = (128, 128)
    
    def findAllFile(base):      # walk a directory tree, yielding (dir, filename) pairs
        for root, ds, fs in os.walk(base):
            for f in fs:
                yield root, f
    
    def resizeimg(sourpath, sourfile, despath):   # resize helper: source dir, source file, destination subdir
        # imdecode via np.fromfile handles non-ASCII (Chinese) paths that cv.imread chokes on
        img = cv.imdecode(np.fromfile(sourpath + os.sep + sourfile, dtype=np.uint8), cv.IMREAD_COLOR)
        img_des = cv.resize(img, target_size, interpolation=cv.INTER_AREA)
        # imencode + tofile for the same non-ASCII-path reason as above
        cv.imencode('.jpg', img_des)[1].tofile(desimgpath + os.sep + despath + os.sep + sourfile)
    
    def getLab(filepath):       # derive the class label from the last path component
        labinfo = filepath.split(os.sep)[-1]
        return labinfo
    
    def checkimg(sourpath, sourfile):
        # verify the image decodes correctly; if so, rewrite it with a normalized .jpg extension
        img = cv.imdecode(np.fromfile(sourpath + os.sep + sourfile, dtype=np.uint8), cv.IMREAD_COLOR)
        if img is None:
            return False
        os.remove(sourpath + os.sep + sourfile)
        sourfile = os.path.splitext(sourfile)[0] + ".jpg"   # .jpeg/.png/... -> .jpg
        cv.imencode('.jpg', img)[1].tofile(sourpath + os.sep + sourfile)
        return True
    
    if __name__ == "__main__":
        filenum = 1
        for filepath, i in findAllFile(sourimgpath):
            print(filenum, filepath, i)
            filenum = filenum + 1
            if checkimg(filepath, i) == False:
                print(filepath, i)
                os.remove(filepath + os.sep + i)    # unreadable image: delete it
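
    And the merge mentioned above, as a minimal sketch: it folds the downloaded test split into the validation split. The folder names (valid, test) are assumptions about the PaddlePaddle archive layout; adjust to the actual one:

    import os
    import shutil
    
    root = "E:\\MAX78000\\2023\\ai8x-training\\data\\zodiac"
    valid_path = os.path.join(root, "valid")
    test_path = os.path.join(root, "test")
    
    # fold the test split into the validation split, class by class
    for cls in os.listdir(test_path):
        os.makedirs(os.path.join(valid_path, cls), exist_ok=True)
        for fname in os.listdir(os.path.join(test_path, cls)):
            shutil.move(os.path.join(test_path, cls, fname),
                        os.path.join(valid_path, cls, fname))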
  3. Model training: studying the official examples is, in my experience, the most effective way to learn. Among them is a cats-vs-dogs classifier which, like my zodiac project, is an image-classification problem; the only difference is the number of output classes. So I worked directly from that example.
    #!/bin/sh
    python train.py --epochs 250 --optimizer Adam --lr 0.001 --wd 0 --deterministic --compress policies/schedule-catsdogs.yaml --model ai85cdnet --dataset cats_vs_dogs --confusion --param-hist --embedding --device MAX78000 "$@"


    Looking at ai8x-training\scripts\train_catsdogs.sh: the cats-vs-dogs example uses a model called ai85cdnet. I didn't fully work out what policies/schedule-catsdogs.yaml does; it appears to just hold training parameters (in ai8x-training it is a Distiller-style schedule, e.g. the learning-rate policy), so I copied it as-is. I then created my own model file, models/ai85net-zodiac.py, changing the output layer to 12 classes. The --dataset cats_vs_dogs argument routes the input data through datasets/cats_vs_dogs.py for preprocessing, so I wrote an analogous data-handling file, datasets/zodiac.py. The model file first:
    ###################################################################################################
    #
    # Copyright (C) 2022 Maxim Integrated Products, Inc. All Rights Reserved.
    #
    # Maxim Integrated Products, Inc. Default Copyright Notice:
    # https://www.maximintegrated.com/en/aboutus/legal/copyrights.html
    #
    ###################################################################################################
    """
    12生肖训练模型 classification network for AI85
    """
    from torch import nn
    
    import ai8x
    
    
    class AI85CatsDogsNet(nn.Module):
        """
        Define CNN model for image classification.
        """
        def __init__(self, num_classes=12, num_channels=3, dimensions=(128, 128),
                     fc_inputs=16, bias=False, **kwargs):
            super().__init__()
    
            # AI85 Limits
            assert dimensions[0] == dimensions[1]  # Only square supported
    
            # Keep track of image dimensions so one constructor works for all image sizes
            dim = dimensions[0]
    
            self.conv1 = ai8x.FusedConv2dReLU(num_channels, 16, 3,
                                              padding=1, bias=bias, **kwargs)
            # padding 1 -> no change in dimensions -> 16x128x128
    
            pad = 2 if dim == 28 else 1
            self.conv2 = ai8x.FusedMaxPoolConv2dReLU(16, 32, 3, pool_size=2, pool_stride=2,
                                                     padding=pad, bias=bias, **kwargs)
            dim //= 2  # pooling, padding 0 -> 32x64x64
            if pad == 2:
                dim += 2  # padding 2 -> 32x32x32
    
            self.conv3 = ai8x.FusedMaxPoolConv2dReLU(32, 64, 3, pool_size=2, pool_stride=2, padding=1,
                                                     bias=bias, **kwargs)
            dim //= 2  # pooling, padding 0 -> 64x32x32
    
            self.conv4 = ai8x.FusedMaxPoolConv2dReLU(64, 32, 3, pool_size=2, pool_stride=2, padding=1,
                                                     bias=bias, **kwargs)
            dim //= 2  # pooling, padding 0 -> 32x16x16
    
            self.conv5 = ai8x.FusedMaxPoolConv2dReLU(32, 32, 3, pool_size=2, pool_stride=2, padding=1,
                                                     bias=bias, **kwargs)
            dim //= 2  # pooling, padding 0 -> 32x8x8
    
            self.conv6 = ai8x.FusedConv2dReLU(32, fc_inputs, 3, padding=1, bias=bias, **kwargs)
    
            self.fc = ai8x.Linear(fc_inputs*dim*dim, num_classes, bias=True, wide=True, **kwargs)
    
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
    
        def forward(self, x):  # pylint: disable=arguments-differ
            """Forward prop"""
            x = self.conv1(x)
            x = self.conv2(x)
            x = self.conv3(x)
            x = self.conv4(x)
            x = self.conv5(x)
            x = self.conv6(x)
            x = x.view(x.size(0), -1)
            x = self.fc(x)
    
            return x
    
    
    def ai85cdnet(pretrained=False, **kwargs):
        """
        Constructs a AI85CatsDogsNet model.
        """
        assert not pretrained
        return AI85CatsDogsNet(**kwargs)
    
    
    models = [
        {
            'name': 'ai85cdnet',
            'min_input': 1,
            'dim': 2,
        },
    ]
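
    A quick way to sanity-check the modified network (a minimal sketch, assuming it runs inside the ai8x-training environment with the class above in scope; ai8x.set_device must be called before constructing ai8x layers):

    import torch
    import ai8x
    
    ai8x.set_device(85, simulate=False, round_avg=False)   # 85 = MAX78000
    
    net = AI85CatsDogsNet(num_classes=12, dimensions=(128, 128))
    x = torch.randn(1, 3, 128, 128)    # one fake RGB frame
    print(net(x).shape)                # expect torch.Size([1, 12])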
    
    And the dataset-handling file, datasets/zodiac.py:
    ###################################################################################################
    #
    # Copyright (C) 2022 Maxim Integrated Products, Inc. All Rights Reserved.
    #
    # Maxim Integrated Products, Inc. Default Copyright Notice:
    # https://www.maximintegrated.com/en/aboutus/legal/copyrights.html
    #
    ###################################################################################################
    """
    12生肖 Datasets
    """
    import errno
    import os
    import shutil
    import sys
    
    import torch
    import torchvision
    from torchvision import transforms
    
    from PIL import Image
    
    import ai8x
    
    torch.manual_seed(0)
    
    
    def augment_affine_jitter_blur(orig_img):
        """
        Augment with multiple transformations
        """
        train_transform = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.RandomAffine(degrees=10, translate=(0.05, 0.05), shear=5),
            transforms.RandomPerspective(distortion_scale=0.3, p=0.2),
            transforms.CenterCrop((180, 180)),
            transforms.ColorJitter(brightness=.7),
            transforms.GaussianBlur(kernel_size=(5, 5), sigma=(0.1, 5)),
            transforms.RandomHorizontalFlip(),
            ])
        return train_transform(orig_img)
    
    
    def augment_blur(orig_img):
        """
        Augment with center crop and bluring
        """
        train_transform = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.CenterCrop((220, 220)),
            transforms.GaussianBlur(kernel_size=(5, 5), sigma=(0.1, 5))
            ])
        return train_transform(orig_img)
    
    
    def zodiac_get_datasets(data, load_train=True, load_test=True, aug=2):
        """
        Load the Chinese zodiac dataset
        """
        (data_dir, args) = data
        path = data_dir
        dataset_path = os.path.join(path, "zodiac")
        is_dir = os.path.isdir(dataset_path)
        if not is_dir:
            print(dataset_path, "does not exist!")
            sys.exit("Dataset not found!")
        else:
            processed_dataset_path = os.path.join(dataset_path, "augmented")
    
            if os.path.isdir(processed_dataset_path):
                print("augmented folder exists. Remove it if you want to regenerate")
    
            train_path = os.path.join(dataset_path, "train")
            test_path = os.path.join(dataset_path, "test")
            processed_train_path = os.path.join(processed_dataset_path, "train")
            processed_test_path = os.path.join(processed_dataset_path, "test")
            if not os.path.isdir(processed_dataset_path):
                os.makedirs(processed_dataset_path, exist_ok=True)
                os.makedirs(processed_test_path, exist_ok=True)
                os.makedirs(processed_train_path, exist_ok=True)
    
                # create label folders
                for d in os.listdir(test_path):
                    mk = os.path.join(processed_test_path, d)
                    try:
                        os.mkdir(mk)
                    except OSError as e:
                        if e.errno == errno.EEXIST:
                            print(f'{mk} already exists!')
                        else:
                            raise
                for d in os.listdir(train_path):
                    mk = os.path.join(processed_train_path, d)
                    try:
                        os.mkdir(mk)
                    except OSError as e:
                        if e.errno == errno.EEXIST:
                            print(f'{mk} already exists!')
                        else:
                            raise
    
                # copy test folder files
                test_cnt = 0
                for (dirpath, _, filenames) in os.walk(test_path):
                    print(f'copying {dirpath} -> {processed_test_path}')
                    for filename in filenames:
                        if filename.endswith('.jpg'):
                            relsourcepath = os.path.relpath(dirpath, test_path)
                            destpath = os.path.join(processed_test_path, relsourcepath)
    
                            destfile = os.path.join(destpath, filename)
                            shutil.copyfile(os.path.join(dirpath, filename), destfile)
                            test_cnt += 1
    
                # copy and augment train folder files
                train_cnt = 0
                for (dirpath, _, filenames) in os.walk(train_path):
                    print(f'copying and augmenting {dirpath} -> {processed_train_path}')
                    for filename in filenames:
                        if filename.endswith('.jpg'):
                            relsourcepath = os.path.relpath(dirpath, train_path)
                            destpath = os.path.join(processed_train_path, relsourcepath)
                            srcfile = os.path.join(dirpath, filename)
                            destfile = os.path.join(destpath, filename)
    
                            # original file
                            shutil.copyfile(srcfile, destfile)
                            train_cnt += 1
    
                            orig_img = Image.open(srcfile)
    
                            # crop center & blur only
                            aug_img = augment_blur(orig_img)
                            augfile = destfile[:-4] + '_ab' + str(0) + '.jpg'
                            aug_img.save(augfile)
                            train_cnt += 1
    
                            # random jitter, affine, brightness & blur
                            for i in range(aug):
                                aug_img = augment_affine_jitter_blur(orig_img)
                                augfile = destfile[:-4] + '_aj' + str(i) + '.jpg'
                                aug_img.save(augfile)
                                train_cnt += 1
                print(f'Augmented dataset: {test_cnt} test, {train_cnt} train samples')
    
        # Loading and normalizing train dataset
        if load_train:
            train_transform = transforms.Compose([
                transforms.Resize((128, 128)),
                transforms.ToTensor(),
                ai8x.normalize(args=args)
            ])
    
            train_dataset = torchvision.datasets.ImageFolder(root=processed_train_path,
                                                             transform=train_transform)
        else:
            train_dataset = None
    
        # Loading and normalizing test dataset
        if load_test:
            test_transform = transforms.Compose([
                transforms.Resize((128, 128)),
                transforms.ToTensor(),
                ai8x.normalize(args=args)
            ])
    
            test_dataset = torchvision.datasets.ImageFolder(root=processed_test_path,
                                                            transform=test_transform)
    
            if args.truncate_testset:
                test_dataset.data = test_dataset.data[:1]
        else:
            test_dataset = None
    
        return train_dataset, test_dataset
    
    
    datasets = [
        {
            'name': 'zodiac',
            'input': (3, 128, 128),
        'output': ('rat', 'ox', 'tiger', 'rabbit', 'dragon', 'snake', 'horse', 'goat', 'monkey', 'rooster', 'dog', 'pig'),
            'loader': zodiac_get_datasets,
        },
    ]
    

    The data handling here is quite interesting: each original image is put through several transforms (random affine, random perspective, color jitter, Gaussian blur, and so on), so one source picture becomes four training pictures, greatly expanding the data.
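
    A tiny demonstration of that fan-out (a sketch only; the path is an assumption, and the two augment functions come from the file above): one source image yields the original, one blurred center crop, and aug=2 jittered variants, i.e. four images in total.

    from PIL import Image
    
    # one source image -> 4 training images (original + blur + 2 jitter variants)
    img = Image.open("data/zodiac/train/dog/0001.jpg")
    outs = [img, augment_blur(img)]
    outs += [augment_affine_jitter_blur(img) for _ in range(2)]
    print(len(outs), "images from one source")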

    Next comes the long wait: the training run.

    python train.py --epochs 1000 --optimizer Adam --lr 0.0001 --wd 0 --deterministic --compress policies/schedule-zodiac.yaml --model ai85cdnet --dataset zodiac --confusion --param-hist --embedding --device MAX78000

    The training really was long: on a reasonably capable office laptop it took a full 5 days, fans roaring the whole time. Note the "--epochs 1000" argument, the number of training epochs; it must not be too small. With fewer than about 10 epochs, no "qat_best.pth.tar" file seems to be produced, presumably because quantization-aware training only begins at the start epoch given by the QAT policy (see the sketch after this paragraph).
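    For reference: in ai8x-training, quantization-aware training is governed by a QAT policy file (policies/qat_policy.yaml by default, overridable with --qat-policy), separate from the --compress schedule. A sketch of its fields, with illustrative values that are my assumptions:

    ---
    # QAT starts at this epoch; qat_best.pth.tar is only written afterwards
    start_epoch: 10
    weight_bits: 8

    The (abridged) training log: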

    2023-10-04 09:00:30,052 - Training epoch: 25578 samples (256 per mini-batch)
    2023-10-04 09:01:01,476 - Epoch: [0][   10/  100]    Overall Loss 2.494613    Objective Loss 2.494613                                        LR 0.000100    Time 3.142342    
    2023-10-04 09:01:30,465 - Epoch: [0][   20/  100]    Overall Loss 2.485947    Objective Loss 2.485947                                        LR 0.000100    Time 3.020593    
    2023-10-04 09:01:58,385 - Epoch: [0][   30/  100]    Overall Loss 2.482571    Objective Loss 2.482571                                        LR 0.000100    Time 2.944394    
    2023-10-04 09:02:26,262 - Epoch: [0][   40/  100]    Overall Loss 2.478276    Objective Loss 2.478276                                        LR 0.000100    Time 2.905194    
    2023-10-04 09:02:54,531 - Epoch: [0][   50/  100]    Overall Loss 2.474552    Objective Loss 2.474552                                        LR 0.000100    Time 2.889531    
    2023-10-04 09:03:22,387 - Epoch: [0][   60/  100]    Overall Loss 2.470973    Objective Loss 2.470973                                        LR 0.000100    Time 2.872203    
    2023-10-04 09:03:50,947 - Epoch: [0][   70/  100]    Overall Loss 2.466363    Objective Loss 2.466363                                        LR 0.000100    Time 2.869874    
    2023-10-04 09:04:19,228 - Epoch: [0][   80/  100]    Overall Loss 2.461283    Objective Loss 2.461283                                        LR 0.000100    Time 2.864655    
    2023-10-04 09:04:47,698 - Epoch: [0][   90/  100]    Overall Loss 2.455534    Objective Loss 2.455534                                        LR 0.000100    Time 2.862682    
    2023-10-04 09:05:11,753 - Epoch: [0][  100/  100]    Overall Loss 2.450443    Objective Loss 2.450443    Top1 20.204082    Top5 60.816327    LR 0.000100    Time 2.816965    
    2023-10-04 09:05:11,784 - --- validate (epoch=0)-----------
    2023-10-04 09:05:11,785 - 2842 samples (256 per mini-batch)
    2023-10-04 09:05:24,565 - Epoch: [0][   10/   12]    Loss 2.398368    Top1 17.890625    Top5 59.570312    
    2023-10-04 09:05:25,559 - Epoch: [0][   12/   12]    Loss 2.390640    Top1 17.874736    Top5 59.359606    
    2023-10-04 09:05:25,588 - ==> Top1: 17.875    Top5: 59.360    Loss: 2.391
    
    2023-10-04 09:05:25,589 - ==> Confusion:
    [[  6  52  12  26   9  15  34   0  18  13   3  44]
     [  3  79   9  28   1  11   6   1  11  21   3  71]
     [  8  39  29  28   4  13  39   0  14  12   3  55]
     [  2  24   3  76   0  17  22   2   6  20   2  66]
     [ 10  13  10  34  11   6  28   1  20  29   3  88]
     [  3  25   7  45   5  33  30   0   9  10   0  53]
     [  2  24   7  34   6   3  75   1  13   6   1  62]
     [ 14  49  21  20   1   4  23   1  19  13   7  64]
     [  4  54  17  18   3  10  28   2  15   5   5  55]
     [  9  35   7  29   1  18  22   2  10  36   2  60]
     [  5  60   7  15   3  13  13   0  25  14  13  78]
     [  6  10   7  30   2   6  17   1  12  18   3 134]]
    
    2023-10-04 09:05:25,591 - ==> Best [Top1: 17.875   Top5: 59.360   Sparsity:0.00   Params: 68016 on epoch: 0]
    2023-10-04 09:05:25,591 - Saving checkpoint to: logs/2023.10.04-090029/checkpoint.pth.tar
    2023-10-04 09:05:25,596 - 
    
    2023-10-04 09:05:25,596 - Training epoch: 25578 samples (256 per mini-batch)
    2023-10-04 09:05:55,332 - Epoch: [1][   10/  100]    Overall Loss 2.391924    Objective Loss 2.391924                                        LR 0.000100    Time 2.973487    
    2023-10-04 09:06:23,182 - Epoch: [1][   20/  100]    Overall Loss 2.383263    Objective Loss 2.383263                                        LR 0.000100    Time 2.879239    
    2023-10-04 09:06:51,060 - Epoch: [1][   30/  100]    Overall Loss 2.379587    Objective Loss 2.379587                                        LR 0.000100    Time 2.848744    
    2023-10-04 09:07:19,035 - Epoch: [1][   40/  100]    Overall Loss 2.368326    Objective Loss 2.368326                                        LR 0.000100    Time 2.835905    
    2023-10-04 09:07:47,092 - Epoch: [1][   50/  100]    Overall Loss 2.358375    Objective Loss 2.358375                                        LR 0.000100    Time 2.829871    
    2023-10-04 09:08:15,099 - Epoch: [1][   60/  100]    Overall Loss 2.352214    Objective Loss 2.352214                                        LR 0.000100    Time 2.824993    
    2023-10-04 09:08:42,563 - Epoch: [1][   70/  100]    Overall Loss 2.343251    Objective Loss 2.343251                                        LR 0.000100    Time 2.813764    
    2023-10-04 09:09:10,168 - Epoch: [1][   80/  100]    Overall Loss 2.334533    Objective Loss 2.334533                                        LR 0.000100    Time 2.807103    
    2023-10-04 09:09:37,946 - Epoch: [1][   90/  100]    Overall Loss 2.325909    Objective Loss 2.325909                                        LR 0.000100    Time 2.803839    
    2023-10-04 09:10:01,328 - Epoch: [1][  100/  100]    Overall Loss 2.317868    Objective Loss 2.317868    Top1 26.326531    Top5 68.571429    LR 0.000100    Time 2.757268    
    2023-10-04 09:10:01,368 - --- validate (epoch=1)-----------
    2023-10-04 09:10:01,368 - 2842 samples (256 per mini-batch)
    2023-10-04 09:10:14,737 - Epoch: [1][   10/   12]    Loss 2.232340    Top1 26.289063    Top5 69.257812    
    2023-10-04 09:10:15,722 - Epoch: [1][   12/   12]    Loss 2.229931    Top1 26.143561    Top5 69.071077    
    2023-10-04 09:10:15,760 - ==> Top1: 26.144    Top5: 69.071    Loss: 2.230
    
    ……
    
    2023-10-08 07:31:22,234 - Training epoch: 25578 samples (256 per mini-batch)
    2023-10-08 07:31:57,106 - Epoch: [999][   10/  100]    Overall Loss 0.560685    Objective Loss 0.560685                                        LR 0.000022    Time 3.487077    
    2023-10-08 07:32:29,275 - Epoch: [999][   20/  100]    Overall Loss 0.569232    Objective Loss 0.569232                                        LR 0.000022    Time 3.352005    
    2023-10-08 07:33:01,742 - Epoch: [999][   30/  100]    Overall Loss 0.564238    Objective Loss 0.564238                                        LR 0.000022    Time 3.316882    
    2023-10-08 07:33:34,864 - Epoch: [999][   40/  100]    Overall Loss 0.559806    Objective Loss 0.559806                                        LR 0.000022    Time 3.315700    
    2023-10-08 07:34:07,270 - Epoch: [999][   50/  100]    Overall Loss 0.561057    Objective Loss 0.561057                                        LR 0.000022    Time 3.300660    
    2023-10-08 07:34:40,223 - Epoch: [999][   60/  100]    Overall Loss 0.558504    Objective Loss 0.558504                                        LR 0.000022    Time 3.299774    
    2023-10-08 07:35:12,637 - Epoch: [999][   70/  100]    Overall Loss 0.557877    Objective Loss 0.557877                                        LR 0.000022    Time 3.291423    
    2023-10-08 07:35:44,734 - Epoch: [999][   80/  100]    Overall Loss 0.557789    Objective Loss 0.557789                                        LR 0.000022    Time 3.281204    
    2023-10-08 07:36:17,023 - Epoch: [999][   90/  100]    Overall Loss 0.558885    Objective Loss 0.558885                                        LR 0.000022    Time 3.275387    
    2023-10-08 07:36:44,814 - Epoch: [999][  100/  100]    Overall Loss 0.560229    Objective Loss 0.560229    Top1 80.204082    Top5 98.367347    LR 0.000022    Time 3.225748    
    2023-10-08 07:36:44,864 - --- validate (epoch=999)-----------
    2023-10-08 07:36:44,864 - 2842 samples (256 per mini-batch)
    2023-10-08 07:37:02,058 - Epoch: [999][   10/   12]    Loss 1.512327    Top1 53.476562    Top5 91.093750    
    2023-10-08 07:37:03,570 - Epoch: [999][   12/   12]    Loss 1.480232    Top1 53.589022    Top5 90.957072    
    2023-10-08 07:37:03,600 - ==> Top1: 53.589    Top5: 90.957    Loss: 1.480
    
    2023-10-08 07:37:03,601 - ==> Confusion:
    [[103   3  23  11  15  14  16  21  12   5   4   5]
     [  9 129   7  13   4   4   4   6  13  21  28   6]
     [ 33   5 102  16  19  18  21  11   7   5   3   4]
     [ 13   9  22 125   8  32   5   4   4   9   3   6]
     [ 25   4  20   6 130   6  18  17  13   2   8   4]
     [  6   3  19  26   5 117  19   4   6   4   5   6]
     [ 17   0  10   9  12  19 132   7  20   6   2   0]
     [ 12   7  22   2  19   5  15 122  14   4   9   5]
     [ 15  15  14   6  10   9  21  14  94   8   9   1]
     [  3  19   5   8   6   4  11   6   9 155   3   2]
     [ 12  27  15   4   5   9   1  13  13   9 128  10]
     [  5   5   4  10  10   6   2   1   3   1  13 186]]
    
    2023-10-08 07:37:03,605 - ==> Best [Top1: 54.574   Top5: 91.027   Sparsity:0.00   Params: 68016 on epoch: 887]
    2023-10-08 07:37:03,605 - Saving checkpoint to: logs/2023.10.04-090029/qat_checkpoint.pth.tar
    2023-10-08 07:37:03,609 - --- test ---------------------
    2023-10-08 07:37:03,609 - 1293 samples (256 per mini-batch)
    2023-10-08 07:37:21,180 - Test: [    6/    6]    Loss 1.795043    Top1 48.955916    Top5 86.852282    
    2023-10-08 07:37:21,216 - ==> Top1: 48.956    Top5: 86.852    Loss: 1.795
    
    2023-10-08 07:37:21,217 - ==> Confusion:
    [[44  0 14  5  8  4  6 10  9  2  2  4]
     [ 2 50  1  8  0  4  3  0  5 11 20  3]
     [10  2 40  8  3 13 10  9  5  5  3  1]
     [ 9  5 10 54  1 12  7  1  4  1  1  3]
     [16  0  5  7 49  2 10  5  4  2  5  2]
     [ 2  2 11 17  2 50 10  2  5  3  1  1]
     [ 6  4  7  6  5  9 61  1  7  1  0  1]
     [ 6  0 10  2  5  3  4 58 12  3  3  2]
     [ 6  8  6  3  6  2 12  7 41  8  6  3]
     [ 3 16  1  0  2  2  2  3  4 70  5  0]
     [ 5 17  6  2  0  1  0 10  6  7 50  4]
     [10  2  5  4  2  5  2  4  2  1  5 66]]
    
    2023-10-08 07:37:21,225 - 
    2023-10-08 07:37:21,225 - Log file for this run: /home/mcudev/MAX78000/ai8x-training/logs/2023.10.04-090029/2023.10.04-090029.log

    The training result is not very satisfying. I don't understand why the accuracy is still below 50% on the test set, and I honestly don't know which parameters to adjust; tuning would mean retraining, and the time cost is just too high. A second attempt ran for another 2 days and still only barely cleared 50%. Never mind for now; onward!


    With training complete, the next step is quantizing the model. In ai8x-synthesis, run:

    python quantize.py ..\ai8x-training\logs\2023.10.04-090029\qat_best.pth.tar trained\zodiac-q.pth.tar --device MAX78000 -v -c networks/zodiac.yaml
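
    Optionally, the quantized checkpoint can be inspected; despite the .pth.tar name it is an ordinary PyTorch archive (a minimal sketch):

    import torch
    
    # load the quantized checkpoint and list its top-level entries
    ckpt = torch.load("trained/zodiac-q.pth.tar", map_location="cpu")
    print(list(ckpt.keys()))    # typically includes 'state_dict', 'arch', 'epoch'
    print(ckpt.get("arch"))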

    This step produces the file zodiac-q.pth.tar in the trained directory. The networks/zodiac.yaml used above is my adaptation of "cats-dogs-hwc.yaml". I'm not entirely clear on what this file does either; it describes how each layer is mapped onto the MAX78000's CNN accelerator (operations, memory offsets, processor masks, pooling, and so on).

    ---
    # HWC (little data) configuration for the zodiac network
    # zodiac Model (modified from 2x Wider Simple Model)
    
    arch: ai85zodiac
    dataset: zodiac
    
    layers:
      - out_offset: 0x2000
        processors: 0x0000000000000007  # 1
        operation: conv2d
        kernel_size: 3x3
        pad: 1
        activate: ReLU
        data_format: HWC
        streaming: true
      - max_pool: 2
        pool_stride: 2
        pad: 1
        operation: conv2d
        kernel_size: 3x3
        activate: ReLU
        out_offset: 0x0000
        processors: 0x00000000000ffff0  # 2
        streaming: true
      - out_offset: 0x2000
        processors: 0x0ffffffff0000000  # 3
        operation: conv2d
        kernel_size: 3x3
        pad: 1
        activate: ReLU
      - out_offset: 0x0000
        processors: 0x00000000ffffffff  # 4
        operation: conv2d
        kernel_size: 3x3
        pad: 1
        activate: ReLU
      - max_pool: 2
        pool_stride: 2
        pad: 1
        operation: conv2d
        kernel_size: 3x3
        activate: ReLU
        out_offset: 0x2000
        processors: 0xffffffff00000000  # 5
      - out_offset: 0x0000
        processors: 0x00000000ffffffff  # 6
        operation: conv2d
        kernel_size: 3x3
        pad: 1
        activate: ReLU
      - out_offset: 0x2000
        processors: 0xffffffff00000000  # 7
        operation: conv2d
        kernel_size: 3x3
        pad: 1
        activate: ReLU
      - out_offset: 0x0000
        processors: 0x00000000ffffffff  # 8
        operation: conv2d
        kernel_size: 3x3
        pad: 1
        activate: ReLU
      - max_pool: 2
        pool_stride: 2
        pad: 1
        operation: conv2d
        kernel_size: 3x3
        activate: ReLU
        out_offset: 0x2000
        processors: 0xffffffffffffffff  # 9
      - max_pool: 2
        pool_stride: 2
        pad: 1
        operation: conv2d
        kernel_size: 3x3
        activate: ReLU
        out_offset: 0x0000
        processors: 0xffffffffffffffff  # 10
      - max_pool: 2
        pool_stride: 2
        pad: 0
        operation: conv2d
        kernel_size: 1x1
        activate: ReLU
        out_offset: 0x2000
        processors: 0xffffffffffffffff  # 11
      - max_pool: 2
        pool_stride: 2
        out_offset: 0x0000
        processors: 0xffffffffffffffff  # 12
        operation: conv2d
        kernel_size: 1x1
        pad: 0
        activate: ReLU
      - max_pool: 2
        pool_stride: 2
        pad: 1
        operation: conv2d
        kernel_size: 3x3
        activate: ReLU
        out_offset: 0x2000
        processors: 0xffffffffffffffff  # 13
      - out_offset: 0x0000
        processors: 0xffffffffffffffff  # 14
        operation: conv2d
        kernel_size: 1x1
        pad: 0
        output_width: 32
        activate: None
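
    One detail worth decoding: each processors field is a 64-bit mask selecting which of the MAX78000's 64 CNN processors the layer uses, and the number of set bits should match the layer's input channel count (3 for the RGB input layer, 16 after the first conv, and so on). A quick check, as a sketch:

    # popcount of each processor mask = input channels of that layer
    masks = [0x0000000000000007, 0x00000000000ffff0, 0xffffffffffffffff]
    for m in masks:
        print(f"{m:#018x} -> {bin(m).count('1')} processors")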
    

    Next, per the official documentation, comes model evaluation. (This step feels skippable: it produces no artifact needed later and merely verifies the model's accuracy.) In ai8x-training, run:

    python train.py --model ai85cdnet --dataset zodiac --confusion --evaluate --exp-load-weights-from ../ai8x-synthesis/trained/zodiac-q.pth.tar -8 --device MAX78000 --use-bias --data data/zodiac


    Then generate a test sample. This step produces sample_zodiac.npy, which is needed later when generating the demo:

    python train.py --model ai85cdnet --save-sample 10 --dataset zodiac --evaluate --exp-load-weights-from ../ai8x-synthesis/trained/zodiac-q.pth.tar -8 --device MAX78000 --use-bias --data data\zodiac
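
    The sample can be checked with NumPy before moving on (a minimal sketch; the shape should match the dataset's declared input, here 3×128×128):

    import numpy as np
    
    # the saved input sample that ai8xize.py will embed into the generated demo
    sample = np.load("sample_zodiac.npy")
    print(sample.shape, sample.dtype)   # expect something like (3, 128, 128)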

    The last step on the training side: code generation (synthesis). Copy sample_zodiac.npy into ai8x-synthesis/tests, then run:

    python ai8xize.py --test-dir "sdk/Examples/MAX78000/CNN" --prefix zodiac --checkpoint-file trained/zodiac-q.pth.tar --config-file networks/zodiac.yaml --fifo --softmax --device MAX78000 --timer 0 --display-checkpoint --verbose --compact-data --mexpress --sample-input tests/sample_zodiac.npy --boost 2.5

    On success, a zodiac project is generated under sdk/Examples/MAX78000/CNN. This is the project that runs on the development board. Next up: the board-side work!

Team introduction
A hands-on tinkerer.
Team members
aramy
MCU hobbyist and hands-on tinkerer.