首页
人工智能
网络安全
手机
搜索
登录
搜索
golden81
累计撰写
154
篇文章
累计收到
0
条评论
首页
栏目
首页
人工智能
网络安全
手机
包含标签 【2d】 的文章
2025-4-27
PyTorch 实现FCN网络用于图像语义分割
本文主要介绍了如何在昇腾上,使用pytorch对图像分割任务的开山之作FCN网络在VOC2012数据集上进行训练的实战过程讲解。主要内容包括FCN网络的创新点分析、FCN网络架构分析,实战训练代码分析等等。 本文的目录结构安排如下所示: FCN创新点分析 FCN网络架构分析 FCN网络搭建过程及代码详解 端到端训练Voc2012数据集全过程分析 FCN(Fully Convolutional Networks)网络创新点分析 采用全卷积结构替换全连接层 FCN将传统分类网络中的全连接层替换为卷积层,使得网络能够接受任意尺寸的输入图像,而不需要固定输入大小。这种设计使得FCN能够直接对图像进行端到端的像素级预测,适用于语义分割任务。 使用多尺度特征融合的思想 通过改进的上采样和下采样技术及高效的跳跃连接结构,FCN能够整合网络中不同层的特有特征。采用多尺度特征融合可以显著地提高模型在复杂场景分析中的精度和鲁棒性。 自适应和动态卷积核 FCN中的卷积核可以根据输入数据的特性动态调整其大小和形状,从而更有效地提取特征。这种自适应能力使得FCN在多种不同类型的图像处理任务中都能表现出色。 跳跃连接(跟第二点结合) FCN使用了类似ResNet的跳跃连接结构,将深层的粗糙语义信息和浅层的精细表征信息融合,以实现更加精细的语义分割。这种跳跃连接结构在上采样过程中融合了不同维度的特征,保留了更多细节,帮助模型更精细地重建图像信息。 使用反卷积操作用于上采样 FCN使用反卷积层进行上采样,将最后一个卷积层的特征图恢复到与输入图像相同的尺寸。反卷积层通过不同尺度的上采样操作,保留了原始输入图像的空间信息,使得网络对每个像素都能实现高效的类别预测。 FCN网络架构分析 从网络的总体架构图中可以看出FCN的结构非常简单。首先,输入图片经过若干个卷积层实现特征提取后,再通过反卷积操作将图像大小还原到指定大小实现像素级别的类别预测。最后输出的channel维度是21,这是因为论文中使用的pascal voc数据集总共20个类别,加上背景一起总共21个类别。 根据这21个值进行softmax处理就能得到图像中每个像素属于这21个类别的概率值,取最大的那个值作为该像素最终的类别预测结果,这样就可以得到整张图所有像素点的类别预测情况,然后,每个类别用不同的颜色区分从而整张图片的背景与类别被清晰的分割开(例如:图中的猫、狗及背景分别用蓝色、棕色与绿色区分)。 上图通过使用全连接层得到最终的维度为1000的向量,由于全连接层要求的输入大小必须是固定的,因此作者将网络中的全连接层转换为卷积层,输入图像的大小可以是任意的。 那么最后的输出就不是一个一维向量了,就变成了(m,n,c),对应每个channel就是一个2D的数据,可以可视化成一个heatmap图。 图中将全连接层全部替换成了卷积层,其中全计算量与卷积层的计算量分别为:全连接是25088 × 4096 = 102760448,卷积的计算量是7 × 7 × 512 × 4096 = 102760448,可以看到他们的计算量是一模一样的相当于把全连接的权重进行了reshape操作。 论文中FCN网络有三种模式的模型,分别是FCN-32s,16s,8s,其中数字的含义是将最后得到的特征图通过上采样多少倍后能够恢复到原图尺寸的大小。图中省略了卷积层与其他层级信息,只保留了池化层用于展示多尺度特征融合的过程。 整个网络第一步通过将特征图上采样32倍得到原图大小的输出,此时得到的是FCN-32s模型。然后将该特征图进行2倍的上采样与pool4层的特征图进行结合得到FCN-16s模型,此时的网络能够预测更精细的细节,同时保留高级语义信息。同理,将得到的FCN-16s模型进行2倍的上采样后与pool3的特征图进行融合得到FCN-8s模型,该模型可以得到更加精准的预测。 除此以外,从上述的分析可以发现,FCN网络在结合不同尺度特征信息的过程中,还可以继续往深层次的继续结合得到FCN-4s,2s模型,这里可以根据需要结合前面pool2与pool1层的信息即可。 FCN网络搭建过程及代码详解 基于torch搭建FCN网络,需要导入torch相关模块,其中nn.Module是各个神经网络模型需要继承的基类。 import torch import torch.nn as nn 由于FCN网络采用的是全类卷积层操作,论文中分别使用Alexnet、VGG16与GoogleNet网络作为backone后用VOC数据集进行微调对比,得到FCN-VGG16的mean IU最高,IU与目标检测模型中的IOU意思一样,用来反映模型预测与框定的效果好坏。 目前PyTorch官方实现中使用ResNet-50作为backbone,原始论文中提出的FCN使用的是VGG16作为backbone,但在PyTorch的官方实现中,由于ResNet-50在性能上有更好的表现,因此一般都会选择ResNet-50作为backbone后用数据做微调。本文的实现不采用任何backbone,从零到一搭建一个FCN8s网络模型。 整个FCN8s网络模型通过一个FCN8类来实现,其中FCN8类中继承了'nn.Module'模块,网络总共包含两部分,前一部分对图像输入进行特征提取并不断降维,后一部分通过对得到的特征图进行不同倍率的上采样,从而融合不同尺度特征得到FCN8s模型。 前一部分总共包含5个stage,每个stage最后都用一个'Maxpool'操作用于特征图提取与降维,对应类中'nn.MaxPool2d(kernel_size=2,padding=0)'。stage1、stage2与stage5均定义了一层'Conv2d'、'Relu'与'BatchNorm2d'操作组合并结合'Maxpool'操作。stage3与stage4分别定义了三层与两层Conv2d、Relu与BatchNorm2d操作组合并结合Maxpool操作。 后一部分定义了upsample_2、upsample_4、upsample_81与upsample_82,也就是2、4与8三种不同倍率的下采样。VOC数据集图片的输入后本文会将其裁剪到224x224,因此网络的输入size是224x224。 class FCN8(nn.Module): def __init__(self, num_classes): # 调用super方法调用父类nn.Module的初始化函数 super(FCN8, self).__init__() ''' 定义stage1, Conv2d中in_channels=3与输入图像3通道相对应,out_channels=96表示输出的通道维度是96,kernel=3表示卷积核大小是3x3,对输入图像padding=1。 根据卷积的size计算公式output= ((i + 2p -k) /s + 1),i表示输入图像的尺寸,p表示padding,k表示卷积核大小,s表示步长。 假设batch=1的情况下输入图像为224x224x3,通过'Conv2d'后输出的size为 (224 + 2 -3)/1 +1 = 224,因此conv2d后输出图像为224x224x96。 BatchNorm2d输入前后不改变图像size大小,通过MaxPool2d操作降维后得到最终输出图像大小为112x112x96。 ''' self.stage1 = nn.Sequential( nn.Conv2d(in_channels=3, out_channels=96, kernel_size=3, padding=1), nn.ReLU(), nn.BatchNorm2d(num_features=96), nn.MaxPool2d(kernel_size=2, padding=0) ) ''' 定义stage2, Conv2d中in_channels=96与stage1中输出维度相对应,out_channels=256表示输出的通道维度是256。同理通过Conv2d后得到输出size为112x112x256 通过'MaxPool2d'操作后变为56x56x256是stage2的最终输出。 ''' self.stage2 = nn.Sequential( nn.Conv2d(in_channels=96, out_channels=256, kernel_size=3, padding=1), nn.ReLU(), nn.BatchNorm2d(num_features=256), nn.MaxPool2d(kernel_size=2, padding=0) ) # 定义stage3, 假设batch=1,input = 56x56x256,则output= 28x28x256。 self.stage3 = nn.Sequential( nn.Conv2d(in_channels=256, out_channels=384, kernel_size=3, padding=1), nn.ReLU(), nn.BatchNorm2d(num_features=384), nn.Conv2d(in_channels=384, out_channels=384, kernel_size=3, padding=1), nn.ReLU(), nn.BatchNorm2d(num_features=384), nn.Conv2d(in_channels=384, out_channels=256, kernel_size=3, padding=1), nn.ReLU(), nn.BatchNorm2d(num_features=256), nn.MaxPool2d(kernel_size=2, padding=0) ) # 定义stage4, 假设batch=1,input = 28x28x256,则output= 14x14x512。 self.stage4 = nn.Sequential( nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, padding=1), nn.ReLU(), nn.BatchNorm2d(num_features=512), nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding=1), nn.ReLU(), nn.BatchNorm2d(num_features=512), nn.MaxPool2d(kernel_size=2, padding=0) ) # 定义stage5, 假设batch=1,input = 14x14x512,则output= 7x7xnum_classes。 self.stage5 = nn.Sequential( nn.Conv2d(in_channels=512, out_channels=num_classes, kernel_size=3, padding=1), nn.ReLU(), nn.BatchNorm2d(num_features=num_classes), nn.MaxPool2d(kernel_size=2,padding=0) ) ''' 定义2倍率的上采样过程,可以分别得到一个2倍率上采样的特征图便于做特征融合。 上采样过程 out_size = (i -1)*S-2P + k,其中i、S、P与k分别表示输入图像size,步长、padding与卷积核大小。 upsample_2结合的是stage4的输出特征图,因此input = 14x14x512,output = 28x28x512。 ''' self.upsample_2 = nn.ConvTranspose2d(in_channels=512, out_channels=512, kernel_size=4, padding= 1, stride=2) ''' 定义4倍率的上采样过程,可以分别得到一个4倍率上采样的特征图便于做特征融合。 upsample_4结合的是stage5的输出特征图,因此input = 7x7xnum_classes,output = 28x28xnum_classes。 ''' self.upsample_4 = nn.ConvTranspose2d(in_channels=num_classes, out_channels=num_classes, kernel_size=4, padding= 0,stride=4) ''' 下述upsample_81与upsample_82的作用是将上述特征融合后的图像分别通过4倍与2的上采样将图像还原成原始输入大小(28 * 4 * 2 = 224)。 其中in_channels与out_channels中512 + num_classes + 256表示三个不同维度的channel进行拼接得到。 ''' self.upsample_81 = nn.ConvTranspose2d(in_channels=512 + num_classes + 256, out_channels=512 + num_classes + 256, kernel_size=4, padding= 0,stride=4) self.upsample_82 = nn.ConvTranspose2d(in_channels=512 + num_classes + 256, out_channels=512 + num_classes + 256, kernel_size=4, padding= 1,stride=2) # 最后的预测模块,input:224x224x(512 + num_classes + 256), output:224x224xnum_classes。 self.final = nn.Sequential( nn.Conv2d(512 + num_classes + 256, num_classes, kernel_size=7, padding=3), ) def forward(self, x): x = x.float() # conv1->pool1->输出 x = self.stage1(x) # conv2->pool2->输出 x = self.stage2(x) # conv3->pool3->输出, 经过上采样后, 需要用pool3暂存 x = self.stage3(x) pool3 = x # conv4->pool4->输出, 经过上采样后, 需要用pool4暂存 x = self.stage4(x) pool4 = self.upsample_2(x) x = self.stage5(x) conv7 = self.upsample_4(x) # 对所有上采样过的特征图进行concat, 在channel维度上进行叠加 x = torch.cat([pool3, pool4, conv7], dim = 1) # 经过一个分类网络,输出结果(这里采样到原图大小,分别一次2倍一次4倍上采样来实现8倍上采样) output = self.upsample_81(x) output = self.upsample_82(output) output = self.final(output) return output 将网络模型结构进行打印,可以看到网络的整体结构与上述描述相一致。至此,FCN8s网络架构全部搭建完成,接下来将用该网络来介绍如何训练VOC数据集。 print(FCN8(21)) FCN8( (stage1): Sequential( (0): Conv2d(3, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (1): ReLU() (2): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) ) (stage2): Sequential( (0): Conv2d(96, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (1): ReLU() (2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) ) (stage3): Sequential( (0): Conv2d(256, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (1): ReLU() (2): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (3): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (4): ReLU() (5): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (6): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (7): ReLU() (8): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) ) (stage4): Sequential( (0): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (1): ReLU() (2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (3): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (4): ReLU() (5): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) ) (stage5): Sequential( (0): Conv2d(512, 21, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (1): ReLU() (2): BatchNorm2d(21, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) ) (upsample_2): ConvTranspose2d(512, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1)) (upsample_4): ConvTranspose2d(21, 21, kernel_size=(4, 4), stride=(4, 4)) (upsample_81): ConvTranspose2d(789, 789, kernel_size=(4, 4), stride=(4, 4)) (upsample_82): ConvTranspose2d(789, 789, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1)) (final): Sequential( (0): Conv2d(789, 21, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3)) ) ) 端到端训练Voc2012数据集全过程分析 VOC数据集介绍 VOC数据集,全称Visual Object Classes,是一个广泛使用的计算机视觉数据集,主要用于目标检测、图像分割和图像分类等任务。该数据集最初由英国牛津大学的计算机视觉小组创建,并在PASCAL VOC挑战赛中使用。VOC数据集包含了大量不同类别的标记图像,每个图像都有与之相关联的边界框(bounding box)和对象类别的标签。 VOC数据集在类别上可以分为4大类,20小类,涵盖了人、汽车、猫、狗等常见目标类别。此外,VOC数据集还提供了用于图像分割任务的像素级标注,该数据集分为21类,其中20类为前景物体,1类为背景。数据集量级方面,VOC2007和VOC2012是两个最流行的版本,分别包含了约10000张和20000张标注图像,本文采用VOC2012数据集。 下载地址:http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar VOC数据集结构 VOC数据集的结构相对复杂,但非常有序。它主要包含以下几个文件夹: ImageSets:包含三个子文件夹(Layout、Main、Segmentation),用于存放不同数据集划分(训练集、验证集、测试集)的文件名列表。 JPEGImages:存放所有的图片,包括训练、验证和测试用到的所有图片。 SegmentationClass:包含已经标注好的图像。 Annotations:存放每张图片相关的标注信息,以XML格式的文件存储。这些文件包含了图像中每个目标的类别、边界框坐标等详细信息。 SegmentationObject:文件夹中包含实例分割用到的标签图像。 其中本文实验需要用到的3个文件夹均已标粗。 如图所示,图像分割任务需要将图中的物体与物体间,物体与背景间信息区分开来,不同物体标记不同颜色,本文实验用到的VOC数据集总共包含21种类别,'VOC_COLORMAP'定义了每一个类别的颜色信息,包含RGB三个,'VOC_CLASSES'对应数据集中21个类别。 VOC_COLORMAP = [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128], [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0], [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128], [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0], [0, 64, 128]] VOC_CLASSES = ['background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person', 'potted plant', 'sheep', 'sofa', 'train', 'tv/monitor'] # 定义一个一维向量colormap2label,含有256^3个元素,目的是为了让三通道图像的每一点像素特征都有所对应类别所对应。 colormap2label = torch.zeros(256 ** 3, dtype=torch.uint8) # 给包含类别的物体赋予颜色标签,不属于类别内的rgb是全为0,也就是整个图片中除了背景与物体以外的颜色为全黑。 for i, colormap in enumerate(VOC_COLORMAP): colormap2label[(colormap[0] * 256 + colormap[1]) * 256 + colormap[2]] = i VOC数据集patch 为了方便整个训练过程快速有效的进行,我们对输入的数据按照一定的patch输入进行训练,因此需要定义一个数据个数转换类'VOCSegDataset',用来生成每一个批次送给网络所需要的数据。 在定义该类以前,我们需要定义一些对于文件及标签处理操作函数,分别是'voc_label_indices'、'read_file_list'与'voc_rand_crop'。 import numpy as np def voc_label_indices(colormap): """ convert colormap (PIL image) to colormap2label (uint8 tensor). """ colormap = np.array(colormap).astype('int32') idx = ((colormap[:, :, 0] * 256 + colormap[:, :, 1]) * 256 + colormap[:, :, 2]) return colormap2label[idx] # 针对于VOC2012数据集读取训练与验证集文件返回训练集或验集所有图片路径及标签 def read_file_list(root, is_train=True): txt_fname = root + '/ImageSets/Segmentation/' + ('train.txt' if is_train else 'val.txt') with open(txt_fname, 'r') as f: filenames = f.read().split() images = [os.path.join(root, 'JPEGImages', i + '.jpg') for i in filenames] labels = [os.path.join(root, 'SegmentationClass', i + '.png') for i in filenames] return images, labels # file list # 对输入的VOC图片进行裁剪到指定的height与width def voc_rand_crop(image, label, height, width): """ Random crop image (PIL image) and label (PIL image). """ i, j, h, w = transforms.RandomCrop.get_params( image, output_size=(height, width)) image = transforms.functional.crop(image, i, j, h, w) label = transforms.functional.crop(label, i, j, h, w) return image, label 数转换类'VOCSegDataset',继承'torch.utils.data.Dataset'用于迭代送入图片给模型进行训练及验证, class VOCSegDataset(torch.utils.data.Dataset): def __init__(self, is_train, crop_size, voc_root): """ crop_size: (h, w) """ self.transform = transforms.Compose([ transforms.ToTensor(), #transforms.Normalize(mean=self.rgb_mean, std=self.rgb_std) ]) # (h, w) self.crop_size = crop_size images, labels = read_file_list(root=voc_root, is_train=is_train) # images list self.images = self.filter(images) # labels list self.labels = self.filter(labels) print('Read ' + str(len(self.images)) + ' valid examples') # 过滤掉尺寸小于crop_size的图片 def filter(self, imgs): return [img for img in imgs if ( Image.open(img).size[1] >= self.crop_size[0] and Image.open(img).size[0] >= self.crop_size[1])] def __getitem__(self, idx): image = self.images[idx] label = self.labels[idx] image = Image.open(image).convert('RGB') label = Image.open(label).convert('RGB') image, label = voc_rand_crop(image, label, *self.crop_size) image = self.transform(image) label = voc_label_indices(label) # float32 tensor, uint8 tensor return image, label def __len__(self): return len(self.images) 调用上述定义好的数据格式转化类'VOCSegDataset'生成训练集与验证集集合'voc_train'与'voc_val',从打印可以看出本次实验只读取了1456与1436张图片用于训练与测试。 import os from torchvision import transforms from PIL import Image voc_train = VOCSegDataset(is_train = True, crop_size=(224,224), voc_root = '/home/pengyongrong/workspace/VocData') voc_val = VOCSegDataset(is_train = False, crop_size=(224,224), voc_root = '/home/pengyongrong/workspace/VocData') Read 1456 valid examples Read 1436 valid examples 接下来对训练集中的部分图片进行可视化,通过引入matplotlib库来进行可视化,这里展示了5张图片及对应标签,如果想要展示更多,可以设置i的取值即可。从可视化结果可以看出图中每一个不同类别与背景都被用不同颜色区分开啦,例如图一中的飞机、人与背景分别用暗紫色、黄色与紫色区分开来。 import matplotlib.pyplot as plt for i, (img, label) in enumerate(voc_train): plt.figure(figsize=(10,10)) plt.subplot(221) plt.imshow(img.moveaxis(0,2)) plt.subplot(222) plt.imshow(label) plt.show() plt.close() if i ==5: break 导入昇腾npu相关库transfer_to_npu、该模块可以使能模型自动迁移至昇腾上。 import torch_npu from torch_npu.contrib import transfer_to_npu /home/pengyongrong/miniconda3/envs/AscendCExperiments/lib/python3.9/site-packages/torch_npu/dynamo/__init__.py:18: UserWarning: Register eager implementation for the 'npu' backend of dynamo, as torch_npu was not compiled with torchair. warnings.warn( /home/pengyongrong/miniconda3/envs/AscendCExperiments/lib/python3.9/site-packages/torch_npu/contrib/transfer_to_npu.py:164: ImportWarning: ************************************************************************************************************* The torch.Tensor.cuda and torch.nn.Module.cuda are replaced with torch.Tensor.npu and torch.nn.Module.npu now.. The torch.cuda.DoubleTensor is replaced with torch.npu.FloatTensor cause the double type is not supported now.. The backend in torch.distributed.init_process_group set to hccl now.. The torch.cuda.* and torch.cuda.amp.* are replaced with torch.npu.* and torch.npu.amp.* now.. The device parameters have been replaced with npu in the function below: torch.logspace, torch.randint, torch.hann_window, torch.rand, torch.full_like, torch.ones_like, torch.rand_like, torch.randperm, torch.arange, torch.frombuffer, torch.normal, torch._empty_per_channel_affine_quantized, torch.empty_strided, torch.empty_like, torch.scalar_tensor, torch.tril_indices, torch.bartlett_window, torch.ones, torch.sparse_coo_tensor, torch.randn, torch.kaiser_window, torch.tensor, torch.triu_indices, torch.as_tensor, torch.zeros, torch.randint_like, torch.full, torch.eye, torch._sparse_csr_tensor_unsafe, torch.empty, torch._sparse_coo_tensor_unsafe, torch.blackman_window, torch.zeros_like, torch.range, torch.sparse_csr_tensor, torch.randn_like, torch.from_file, torch._cudnn_init_dropout_state, torch._empty_affine_quantized, torch.linspace, torch.hamming_window, torch.empty_quantized, torch._pin_memory, torch.Tensor.new_empty, torch.Tensor.new_empty_strided, torch.Tensor.new_full, torch.Tensor.new_ones, torch.Tensor.new_tensor, torch.Tensor.new_zeros, torch.Tensor.to, torch.nn.Module.to, torch.nn.Module.to_empty ************************************************************************************************************* warnings.warn(msg, ImportWarning) from torch.utils.data import Dataset, DataLoader #创建dataloader,定义每一批次送入模型进行训练的batch_size这里设置成8,也可以根据需要改成任意>=2的取值。 trainloader = DataLoader(voc_train, batch_size = 8, shuffle=True,) testloader = DataLoader(voc_val, batch_size = 4) optim实现了各种优化算法的库(例如:SGD与Adam),在使用optimizer时候需要构建一个optimizer对象,这个对象能够保持当前参数状态并基于计算得到的梯度进行参数更新。 # 导入torch及相关模块库,便于后续搭建神经网络模型使用 import torch.optim as optim import torch.nn.functional as F #定义模型训练在哪种类型的设备上跑 device = 'npu' # 构建模型,这里VOC数据类别是21,因此入参num_classes=21,若是其他的类别,此处可以根据需要进行设置。 net = FCN8(num_classes=21) #将网络模型加载到指定设备上,这里device是昇腾的npu net = net.to(device) criterion = nn.CrossEntropyLoss() optimizer = optim.SGD(net.parameters(), lr=1.0, weight_decay=5e-4) lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer,0.1,steps_per_epoch=len(trainloader), epochs=150,div_factor=25,final_div_factor=10000,pct_start=0.3) 训练模块: 根据传入的迭代次数开始训练网络模型,这里需要在model开始前加入net.train(),使用随机梯度下降算法是将梯度值初始化为0(zero_grad()),计算梯度、通过梯度下降算法更新模型参数的值以及统计每次训练后的loss值(每隔100次打印一次) from tqdm import tqdm def train(epoch): net.train() train_loss = 0.0 epoch_loss = 0.0 for batch_idx, (inputs, targets) in enumerate(tqdm(trainloader, 0)): inputs, targets = inputs.to(device), targets.to(device) optimizer.zero_grad() outputs = net(inputs) loss = criterion(outputs, targets) loss.backward() optimizer.step() lr_scheduler.step() train_loss += loss.item() epoch_loss += loss.item() if batch_idx % 100 == 99: # 每100次迭代打印一次损失 print(f'[Epoch {epoch + 1}, Iteration {batch_idx + 1}] loss: {train_loss / 100:.3f}') train_loss = 0.0 return epoch_loss / len(trainloader) 测试模块: 每训练一轮将会对最新得到的训练模型效果进行测试,使用的是数据集准备时期划分得到的测试集。 def test(): net.eval() val_loss = 0 val_loss_all=[] val_num = 0 total = 0 with torch.no_grad(): for batch_idx, (inputs, targets) in enumerate(tqdm(testloader)): inputs, targets = inputs.to(device), targets.to(device) outputs = net(inputs) out = F.log_softmax(outputs, dim=1) loss = criterion(out, targets) val_loss += loss.item() * len(targets) val_num += len(targets) # 计算一个epoch在验证集上的损失和精度 val_loss_all.append(val_loss / val_num) return val_loss_all[-1] 训练与测试的次数为2次,这里用户可以根据需要自行选择设置更高或更低,每个epoch的准确率都会被打印出来,如果不需要将代码注释掉即可,这里可以看到两个epoch间的loss在下降(从1.94-\>1.63)。 #开启模型训练与测试过程 for epoch in range(2): epoch_loss = train(epoch) test_accuray = test() print(f'Epoch loss for FCN8s at epoch {epoch + 1}: {epoch_loss:.3f}') 55%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍ | 100/182 [01:49<01:28, 1.08s/it] [Epoch 1, Iteration 100] loss: 1.940 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 182/182 [03:14<00:00, 1.07s/it] 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 359/359 [02:36<00:00, 2.29it/s] Epoch loss for FCN8s at epoch 1: 1.825 55%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍ | 100/182 [01:42<01:22, 1.01s/it] [Epoch 2, Iteration 100] loss: 1.628 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 182/182 [03:04<00:00, 1.01s/it] 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 359/359 [02:39<00:00, 2.25it/s] Epoch loss for FCN8s at epoch 2: 1.630 内存使用情况: 整个训练过程的内存使用情况可以通过"npu-smi info"命令在终端查看,因此本文实验只用到了单个npu卡(也就是chip 0),内存占用约13G,对内存、精度或性能优化有兴趣的可以自行尝试进行优化。 Reference ========= \[1\] Long, Jonathan , E. Shelhamer , and T. Darrell . "Fully Convolutional Networks for Semantic Segmentation." IEEE Transactions on Pattern Analysis and Machine Intelligence 39.4(2015):640-651.
2025年-4月-27日
11 阅读
0 评论
人工智能