学校大创项目做了关于车辆违章检测的模型,现在简单记录一下~~~
项目主要的模块为车辆目标检测+车辆违章行为检测+车牌识别+微信小程序开发
选取网络
在项目中违章行为识别的思想主要是分类问题,可以简化为二分类(违章+非违章),或者复杂一点的多分类(将违章的情况细分为压实线、占用自行车道、占用人行横道等)
当然更好的方法是通过检测一些可能造成违章的标识,如禁止停车、自行车道标志、白色实线等,但考虑到复杂程度,我还是选择了分类【笑哭】
最终选取了较简单的Mobilenet网络,其中心思想是深度可分卷积,所以速度很快,并非常适合分类问题。
Mobilenet 深度可分卷积
准备数据集
由于涉及个人隐私等问题,与交管部门沟通无果,只好通过网络爬虫和自己拍摄来收集数据集。。。
因为数量较少,所以在训练时使用了数据增强
数据集中违章与非违章的比例约为1:2
训练集与数据集的比例约为10:1,没有设置验证集【数据实在是太少了呜呜呜...】
所有图片都转化为灰度图,代码如下
import cv2 as cv img = cv.imread(image) img=cv.cvtColor(img,cv.COLOR_RGB2GRAY)
将数据集组织好后,放入./data文件夹下
网络训练 Pytorch
使用github上Mobilenet公布的源码:pytorch-mobilenet-master
启动训练代码:
CUDA_VISIBLE_DEVICES=3 python main.py -a mobilenet --resume mobilenet_sgd.pth.tar --lr 0.01 ./data > log.txt
网络结构
class Net(nn.Module): def __init__(self): super(Net, self).__init__() def conv_bn(inp, oup, stride): return nn.Sequential( nn.Conv2d(inp, oup, 3, stride, 1, bias=False),#inp:input channel,oup:output channel nn.BatchNorm2d(oup), nn.ReLU(inplace=True) ) def conv_dw(inp, oup, stride): return nn.Sequential( nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False), nn.BatchNorm2d(inp), nn.ReLU(inplace=True), nn.Conv2d(inp, oup, 1, 1, 0, bias=False), nn.BatchNorm2d(oup), nn.ReLU(inplace=True), ) #哈哈哈哈在这里可见pytorch真是简单啊~~~ self.model = nn.Sequential( conv_bn( 3, 32, 2), conv_dw( 32, 64, 1), conv_dw( 64, 128, 2), conv_dw(128, 128, 1), conv_dw(128, 256, 2), conv_dw(256, 256, 1), conv_dw(256, 512, 2), conv_dw(512, 512, 1), conv_dw(512, 512, 1), conv_dw(512, 512, 1), conv_dw(512, 512, 1), conv_dw(512, 512, 1), conv_dw(512, 1024, 2), conv_dw(1024, 1024, 1), nn.AvgPool2d(7), ) self.fc1 = nn.Linear(1024, 2) #这里将输出改为2,因为是二分类 def forward(self, x): x = self.model(x) x = x.view(-1, 1024) x = self.fc1(x) return x
参数设置
parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')#数据集存放的位置parser.add_argument('data', metavar='DIR', help='path to dataset')#使用的网络结构 -a mobilenetparser.add_argument('--arch', '-a', metavar='ARCH', default='resnet18', choices=model_names, help='model architecture: ' + ' | '.join(model_names) + ' (default: resnet18)') parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', help='number of data loading workers (default: 4)')#训练的epoch总数parser.add_argument('--epochs', default=90, type=int, metavar='N', help='number of total epochs to run')#每次训练从第几个epoch开始parser.add_argument('--start-epoch', default=0, type=int, metavar='N', help='manual epoch number (useful on restarts)')#设置batch-sizeparser.add_argument('-b', '--batch-size', default=32, type=int, metavar='N', help='mini-batch size (default: 32)')#设置学习率parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, metavar='LR', help='initial learning rate')#设置动量parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum') parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float, metavar='W', help='weight decay (default: 1e-4)') parser.add_argument('--print-freq', '-p', default=10, type=int, metavar='N', help='print frequency (default: 10)')#设置选用的预训练模型 项目中使用mobilenet提供的模型:mobilenet_sgd.pth.tarparser.add_argument('--resume', default='', type=str, metavar='PATH', help='path to latest checkpoint (default: none)') parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', help='evaluate model on validation set') parser.add_argument('--pretrained', dest='pretrained', action='store_true', help='use pre-trained model')
加载预训练模型
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
注意只挑选共同存在的部分加载
# optionally resume from a checkpoint if args.resume: if os.path.isfile(args.resume): print("=> loading checkpoint '{}'".format(args.resume)) checkpoint = torch.load(args.resume) args.start_epoch=0 best_prec1 = checkpoint['best_prec1'] pretrained_dict=checkpoint['state_dict'] model_dict = model.state_dict() #注意这里!因为对网络结构进行了修改,所以这里加载resume时,只挑选共同存在的部分加载!! pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} model_dict.update(pretrained_dict) model.load_state_dict(model_dict) print("=> loaded checkpoint '{}' (epoch {})" .format(args.resume, checkpoint['epoch'])) else: print("=> no checkpoint found at '{}'".format(args.resume)) cudnn.benchmark = True
数据集加载
點擊查看更多內容
為 TA 點贊
評論
評論
共同學習,寫下你的評論
評論加載中...
作者其他優質文章
正在加載中
感謝您的支持,我會繼續努力的~
掃碼打賞,你說多少就多少
贊賞金額會直接到老師賬戶
支付方式
打開微信掃一掃,即可進行掃碼打賞哦