TensorFlow学习笔记(3)构建一个线性回归模型

  • A+
所属分类:TensorFlow教程
广告也精彩
 上次推送算是对TF框架的概述,今天我们来一起完成一个linear regression的程序。这个例子本身很简单,但是却能够浓缩用TF构建机器学习程序的基本步骤:

1.定义全部Tensor对象,包括:可更新的参数、输入、参考输出(groundtruth)

2.定义model

3.定义loss function

4.选择一个优化算法:最基本的例如Gradient Descent,Momentum,比较高级的例如Adadelta,Adam,RMSProp等。

5.运行Session,执行优化算法

后文慢慢展开。

代码可以从主页菌的Github上获取,Github用户名SaoYan,仓库名称:

LearningTensorflow

本次推送对应的demo链接如下(戳原文即可获取):

https://github.com/SaoYan/LearningTensorflow/blob/master/exp2_simple_linear_model.py

OK,准备好你喜欢的编辑器,我们开始写代码吧。

TensorFlow学习笔记(3)构建一个线性回归模型 

定义全部Tensor对象

0、import

Python编程日常,无需多言

TensorFlow学习笔记(3)构建一个线性回归模型
1、可更新参数

    基本的线性回归模型为:y=Wx+b。其中W和b是待定参数,需要通过拟合数据得到。上次推送中已经讲过,待学习的模型参数用tf.Variable对象。

TensorFlow学习笔记(3)构建一个线性回归模型
    初始值可以随便设置,这里主页菌把初值设置成了W=0.3,b=-0.3

2、输入、输出

    根据上次推送的内容,这里应该选用tf.placeholder对象,具体的数值在运行Session的时候再指定。

TensorFlow学习笔记(3)构建一个线性回归模型
TensorFlow学习笔记(3)构建一个线性回归模型 

定义模型

    线性回归模型很简单:

TensorFlow学习笔记(3)构建一个线性回归模型
    主页菌忍不住再强调一下上次推送中反复说的一点:我们现在只是搭建了模型的“骨架”,这里面没有任何实际的数据。

TensorFlow学习笔记(3)构建一个线性回归模型 

定义loss function

    线性回归的loss function是均方误差。

TensorFlow学习笔记(3)构建一个线性回归模型

    y是前面定义的placeholder对象,是输出端的groundtruth。out是模型计算出的输出。

tf.square是求平方的函数,tf.reduce_sum是求和函数。注意:对Tensorflow中的Tensor对象作运算操作必须用Tensorflow内置的操作函数。道理很简单,只有内置函数能接受Tensor类型的输入参数。

TensorFlow学习笔记(3)构建一个线性回归模型 

选择优化器

    线性回归这种级别的模型用普通的梯度下降足够了。

TensorFlow学习笔记(3)构建一个线性回归模型
    这里optimizer是一个优化器对象,learning rate=0.001。train是一个操作符对象,也就是运行1步梯度下降。用Session运行的对象就是这个操作符,运行的时候相关的参数(W,b)都会被更新。

TensorFlow学习笔记(3)构建一个线性回归模型 

运行Session

TensorFlow学习笔记(3)构建一个线性回归模型
    1.首先生成两组随机数,用作训练集。

2.不要忘记运行Variable的初始化!

3.将梯度下降运行1000步,每一步都计算出当前loss的值并显示。

注意:

1.因为两个placeholder对象(x和y)的类型是tf.float32,所以实际“喂”的数据也必须是float32类型,因此在生成随机数的时候作了类型强制转换。

2.大家不妨在程序结尾加这样两行代码:

TensorFlow学习笔记(3)构建一个线性回归模型
    输出结果已经不是初始值0.3和-0.3了,因为它们的值在运行优化器的过程中被隐式地更改了。

一些总结

TensorFlow学习笔记(3)构建一个线性回归模型
1.在实践层面的一些提醒

(1)有哪些可以选用的优化器?最好的方法就是去查Tensorflow的API文档,网址:

https://www.tensorflow.org/api_docs/python/

可以动用浏览器的搜索功能,搜索关键字"optimizer'。

TensorFlow学习笔记(3)构建一个线性回归模型
注意:在初始阶段只关注tf.train.XXXXOptimizer,tf.contrib模块是一个高层封装的API,不建议入门阶段使用。

(2)对Tensor可以进行哪些操作?这篇推送中我们使用到了诸如tf.reduce_sum,tf.square等操作符,这些函数不需要一次性学会,最好的方式就是实践中积累。比如你某一次项目中需要实现对Tensor的某种操作,这时候动用搜索引擎一般都能获得很好的答案(忠告:翻墙用Google英文搜索,远离百度,手动微笑)。如果是一些比较复杂的操作,比如主页菌这几天弄得一个unpooling layer,可能需要求助开源社区,比如Github上面的Issues版面(见下图),或者Quora这样的平台。

TensorFlow学习笔记(3)构建一个线性回归模型

2.有了这次推送的基础,大家不妨试着读一下下面这个代码(仓库中的exp3):https://github.com/SaoYan/LearningTensorflow/blob/master/exp3_SoftmaxRegression_mnist.py

  • 微信
  • 扫一扫
  • weinxin
  • 微信公众号
  • 扫一扫
  • weinxin
广告也精彩
Y40 便携头戴式耳机
粉色毛呢大衣
唐人街探案 大朋vr一体机M2 Pro 头戴式VR眼镜  虚拟现实电影视频 1万+部影视 百款游戏
Wireless无线蓝牙运动耳机
广告也精彩

发表评论

:?: :razz: :sad: :evil: :!: :smile: :oops: :grin: :eek: :shock: :???: :cool: :lol: :mad: :twisted: :roll: :wink: :idea: :arrow: :neutral: :cry: :mrgreen: