TensorFlow学习笔记(4)建一个CNN用于MNIST手写字符分类

  • A+
所属分类:TensorFlow教程
广告也精彩
第3次的推送中,我们构建了一个线性回归模型。模型本身不重要,重要的是体会到用Tensorflow写程序的模式:定义Tensor对象,定义模型,定义loss function,选择优化器,运行Session

今天,我们把线性回归模型升级为卷积神经网路,并且依然按照这个模式进行。

推送中不涉及CNN的理论细节,想学习CNN本身的请移步斯坦福CS231n课程主页:http://cs231n.stanford.edu/

    本次推送对应的源码(戳原文获取):https://github.com/SaoYan/LearningTensorflow/blob/master/exp4_CNN_mnist.py

艾伯特(http://www.aibbt.com/)国内第一家人工智能门户

辅助函数

    为了后面代码的简洁直观,我们首先定义以下几个函数

weight()

TensorFlow学习笔记(4)建一个CNN用于MNIST手写字符分类
    这个函数返回一个从截断正太分布(truncated normal distribution)中初始化的权重参数。后面调用这个函数或得每一次卷积的“核”。注意返回类型是tf.Variable,因为卷积核是可更新参数。(请复习推送2、3)

bias()

TensorFlow学习笔记(4)建一个CNN用于MNIST手写字符分类
    这个函数返回一个常量0.1构成的偏置项参数。注意返回类型也是tf.Variable。

conv2d()

TensorFlow学习笔记(4)建一个CNN用于MNIST手写字符分类
    这个函数构造一个卷积层,给定输入、卷积核,调用Tensorflow函数tf.nn.conv2d计算卷积输出。

关于tf.nn.conv2d的输入参数:

这个是Tensorflow内置的二维卷积操作函数。

参数1:输入Tensor,尺寸必须为(batch, in_height, in_width, in_channels),例如送入一批320x320x3的图像,数量为64个,那么输入参数就是一个尺寸为64x320x320x3的Tensor。

参数2:卷积核,尺寸必须满足(filter_height, filter_width, in_channels, out_channels),例如处理上面那个64x320x320x3的输入,采用5x5的卷积核,输出特征为32个(32个核),那么卷积核应该是一个5x5x3x32的Tensor。

参数3:卷积步长。在二维卷积中,只有height和width这两个维度上能调整步长的参数,因此这个参数必须等于[1,s,s,1],其中s是某正整数。对应输入Tensor尺寸的四个维度batch, height, width, channels,也就是在batch和channels两个维度上面不设步长(也就是步长为1),在height和width两个维度上面设置卷积步长s。

参数4:填充类型。如果是"SAME", 那卷积之前填充0以保证输入和输出的height、width一致;如果是"VALID",那么不进行填充,输出的height、width小于输入。

max_pool_2x2()

TensorFlow学习笔记(4)建一个CNN用于MNIST手写字符分类
    这个函数定义了一个2x2的池化层。调用了Tensorflow内置函数tf.nn.max_pool。

关于tf.nn.max_pool的输入参数:

这个是Tensorflow内置的二维池化操作函数。

参数1:输入Tensor,尺寸要求与tf.nn.conv2d相同。

参数2:每个池化单元的尺寸,与tf.nn.conv2d中strides参数同理,必须等于[1,k,k,1]。也就是只能在height、width两个维度上面进行池化。

参数3:步长,同上。

TensorFlow学习笔记(4)建一个CNN用于MNIST手写字符分类 

输入和输出数据

    我们使用经典的MNIST手写字符数据集。Tensorflow已经内置了数据读取和预处理的函数,直接调用即可。输入数据是单通道图像,输出数据是0~9的标签值,标签采用one-hot编码(例如标签0编码为1000000000,标签1编码为0100000000,以此类推)。

TensorFlow学习笔记(4)建一个CNN用于MNIST手写字符分类
    然后我们定义placeholder对象(忘记相关概念的请参见推送1)

TensorFlow学习笔记(4)建一个CNN用于MNIST手写字符分类
TensorFlow学习笔记(4)建一个CNN用于MNIST手写字符分类 

定义模型

两层卷积

TensorFlow学习笔记(4)建一个CNN用于MNIST手写字符分类
    这里调用了前面定义的四个函数,还调用了Tensorflow内置的ReLU激活函数tf.nn.relu。整段代码思路是很清晰的。

全连接层+dropout

TensorFlow学习笔记(4)建一个CNN用于MNIST手写字符分类
    注意这里两个代码实现上的技巧:

第一,用卷积来实现全连接层。假设前面所有卷积以后得到了WxHxC的一个特征(feature map),如果我们希望通过全连接层得到1xK的特征,那么等效的卷积操作就是用一个WxHxCxK的卷积核,并且不进行填充(这点不要忘记!)

    第二,tf.nn.dropout的第二个输入参数采用placeholder类型。为什么要这样呢?因为这个参数在Session运行的过程中需要改变。在进行梯度下降的时候,要实施dropout防止过拟合;但是在测试结果的时候,不进行dropout,也就是“p_keep=1”。

输出层

    最后再进行一次全连接,或得输出。

TensorFlow学习笔记(4)建一个CNN用于MNIST手写字符分类
    注意:这里依然使用了卷积操作来等效全连接层,但是由于使用的卷积操作,那么输出的尺寸就必定符合卷积的要求,也就是说,输出的尺寸是Nx1x1x10,其中N是batch的大小。然而我们希望的输出是Nx10(对batch里的每一个样本,输出一个类别向量),所以我们调用了tf.squeeze函数,这个函数的作用是去掉尺寸为1的维度。也就是Nx1x1x10降维成Nx10。

TensorFlow学习笔记(4)建一个CNN用于MNIST手写字符分类 

定义loss function和优化器

    交叉商损失函数(cross entropy)

TensorFlow学习笔记(4)建一个CNN用于MNIST手写字符分类
    Adam优化器

TensorFlow学习笔记(4)建一个CNN用于MNIST手写字符分类
    另外,我们还想衡量一下准确率,也就是预测正确的样本/总样本。

TensorFlow学习笔记(4)建一个CNN用于MNIST手写字符分类
    这里的train_step,accuracy等都是定义了一个操作符,后面需要用Session运行它们或得结果。

TensorFlow学习笔记(4)建一个CNN用于MNIST手写字符分类 

训练CNN

    首先是一些参数

TensorFlow学习笔记(4)建一个CNN用于MNIST手写字符分类
    然后开始运行Session

TensorFlow学习笔记(4)建一个CNN用于MNIST手写字符分类
    这里每一句代码都作了注释,思路也很清晰,因此不做赘述。

TensorFlow学习笔记(4)建一个CNN用于MNIST手写字符分类 

一点总结

1.在实际编程过程中,有很多需要注意的问题,最常见的一个就是Tensor的尺寸。不同的函数,输入、输出的尺寸都有严格的规定。实践中,特别注意一下几个问题

(1)WHC不等于CHW。一个图像数据,既可以是320x320x3的矩阵,也可以通过转置操作变换成3x320x320的矩阵,但是Tensorflow各种函数的输入参数格式是固定的。因此,果你使用的数据集不符合格式,注意进行合理的数据预处理。

(2)N不等于Nx1。前面提到,输出层之后调用了tf.squeeze函数,去除了冗余的维度。这一点也很重要。请时刻检查你的网络中各处数据的维度,确保它们是合适的。

2.这次推送中很多Tensorflow的内置函数没有一一讲解,一是因为太多太杂,二是因为它们都很简单很直观,通过查阅官方文档完全可以理解。(例如tf.argmax,tf.reduce_mean,tf.cast)

  • 微信
  • 扫一扫
  • weinxin
  • 微信公众号
  • 扫一扫
  • weinxin
广告也精彩
半身裙时尚
宽松衬衫
唐人街探案 大朋vr一体机M2 Pro 头戴式VR眼镜  虚拟现实电影视频 1万+部影视 百款游戏
iPhone 配件
广告也精彩

发表评论

:?: :razz: :sad: :evil: :!: :smile: :oops: :grin: :eek: :shock: :???: :cool: :lol: :mad: :twisted: :roll: :wink: :idea: :arrow: :neutral: :cry: :mrgreen: