CNN Handwritten Digit Recognition

Deep Learning Platform

Dataset download (account: hhzhu, password: 123456)

The platform is built on top of the university's supercomputing environment; it currently has two GPU cards, and CPU resources are plentiful. The framework is PyTorch. You can also set up a PyTorch environment locally, and if your machine has an NVIDIA GPU the code runs quite fast; for setting up PyTorch on Windows, see the earlier article click here. The deep learning platform simply provides a ready-made, well-configured environment; everything else works the same as running locally.
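
Before queuing a long run, whether on the platform or locally, it is worth confirming that PyTorch can actually see the GPUs. A minimal check using only standard PyTorch calls:

import torch

print(torch.__version__)            # installed PyTorch version
print(torch.cuda.is_available())    # True if a usable CUDA GPU is present
print(torch.cuda.device_count())    # number of visible GPU cards
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))  # model name of the first GPU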

代码

# -*- coding: utf-8 -*-
"""
Created on Mon Feb 3 13:15:59 2020
@author: hhzhu
"""
# MNIST classification with a CNN

import torch
import numpy as np
from torch.utils.data import DataLoader
from torchvision.datasets import mnist
from torch import nn
from torch import optim
from torchvision import transforms

# Define the CNN
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=5),        # 1x28x28 -> 32x24x24
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2))  # 32x24x24 -> 32x12x12
        self.layer2 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=5),       # 32x12x12 -> 64x8x8
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2))  # 64x8x8 -> 64x4x4
        self.fc = nn.Sequential(
            nn.Linear(1024, 1000),                  # 64*4*4 = 1024 flattened features
            nn.ReLU(inplace=True),
            nn.Linear(1000, 10))                    # 10 digit classes

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = x.view(x.size(0), -1)  # flatten to (batch, 1024)
        x = self.fc(x)
        return x
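
# Sanity check on the flattened feature size (a small added sketch, not part
# of the original script): push one dummy image through the conv stack to
# confirm that layer2 outputs 64 x 4 x 4 = 1024 features, which is exactly
# the input size of nn.Linear(1024, 1000) above.
_probe = CNN()
_feat = _probe.layer2(_probe.layer1(torch.zeros(1, 1, 28, 28)))
print(_feat.shape)  # torch.Size([1, 64, 4, 4])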


# Download the MNIST dataset with the built-in loader
#train_set = mnist.MNIST('./data', train=True)
#test_set = mnist.MNIST('./data', train=False)

# Preprocessing => compose the individual transforms
# Define a custom data_tf that divides by 255 to normalize pixels to [0, 1]

def data_tf(x):
    x = np.array(x, dtype='float32')
    x = x / 255                   # scale pixel values from [0, 255] to [0, 1]
    x = transforms.ToTensor()(x)  # (28, 28) ndarray -> (1, 28, 28) float tensor
    return x

#data_tf = transforms.Compose(
#    [transforms.ToTensor(),
#     transforms.Normalize([0.5], [0.5])])

train_set = mnist.MNIST('./data', train=True, transform=data_tf, download=False)
test_set = mnist.MNIST('./data', train=False, transform=data_tf, download=False)

import pylab
#%matplotlib inline

# Peek at one raw training image
print(type(train_set))
print(len(train_set.train_data))  # note: renamed to train_set.data in newer torchvision
#print(train_set.train_data[0])
im = train_set.train_data[0]
im = im.reshape(-1, 28)
print(im)
pylab.imshow(im)
pylab.show()

train_data = DataLoader(train_set, batch_size=64, shuffle=True)
test_data = DataLoader(test_set, batch_size=128, shuffle=False)


# Train on the GPU when one is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = CNN().to(device)
criterion = nn.CrossEntropyLoss()
#optimizer = optim.SGD(net.parameters(), 1e-1)
optimizer = optim.Adam(net.parameters(), lr=0.0001)
nums_epoch = 10

# Start training
losses = []
acces = []
eval_losses = []
eval_acces = []
for epoch in range(nums_epoch):
    train_loss = 0
    train_acc = 0
    net.train()
    for img, label in train_data:
        img = img.to(device)
        label = label.to(device)
        out = net(img)
        loss = criterion(out, label)
        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Accumulate the loss
        train_loss += loss.item()
        # Compute the classification accuracy
        _, pred = out.max(1)
        num_correct = (pred == label).sum().item()
        acc = num_correct / img.shape[0]
        train_acc += acc
    losses.append(train_loss / len(train_data))
    acces.append(train_acc / len(train_data))

    # Evaluate on the test set
    eval_loss = 0
    eval_acc = 0
    net.eval()
    with torch.no_grad():
        for img, label in test_data:
            img = img.to(device)
            label = label.to(device)
            out = net(img)
            loss = criterion(out, label)
            # Accumulate the loss
            eval_loss += loss.item()
            _, pred = out.max(1)
            num_correct = (pred == label).sum().item()
            acc = num_correct / img.shape[0]
            eval_acc += acc
    eval_losses.append(eval_loss / len(test_data))
    eval_acces.append(eval_acc / len(test_data))
    print('Epoch {} Train Loss {} Train Accuracy {} Test Loss {} Test Accuracy {}'.format(
        epoch + 1, train_loss / len(train_data), train_acc / len(train_data),
        eval_loss / len(test_data), eval_acc / len(test_data)))
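
The script records the per-epoch averages in losses, acces, eval_losses, and eval_acces but never plots them. A minimal sketch for visualizing the training curves (the variable names come from the script above; the plotting choices are my own):

import matplotlib.pyplot as plt

epochs = range(1, nums_epoch + 1)
plt.figure()
plt.plot(epochs, losses, label='train loss')
plt.plot(epochs, eval_losses, label='test loss')
plt.plot(epochs, acces, label='train accuracy')
plt.plot(epochs, eval_acces, label='test accuracy')
plt.xlabel('epoch')
plt.legend()
plt.show()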

Results after 10 epochs

(screenshot of the per-epoch training log)

