【TensorFlow系列】基于AlexNet的猫狗大战

  • A+
所属分类:TensorFlow教程
广告也精彩

本文介绍如何基于 TensorFlow 实现 AlexNet，并用 AlexNet 对 Kaggle 猫狗数据集进行分类，完成基于 AlexNet 的猫狗大战。

废话不多说，直接贴代码。

import tensorflow as tf
from PIL import Image
import os

def image2tfrecord(path,
                   output_path=r"D:\BaiduNetdiskDownload\dataset_kaggledogvscat\train.tfrecords",
                   size=(227, 227),
                   shuffle=True,
                   seed=None):
    """Write every image file in ``path`` to a TFRecord file of (image, label) examples.

    Files whose names start with ``"cat"`` get label 0; every other file gets
    label 1. Each image is converted to RGB, resized to ``size`` and stored as
    raw bytes under the ``"image"`` key; the label is stored under ``"label"``.

    Args:
        path: directory containing the training images.
        output_path: destination ``.tfrecords`` file (defaults to the
            previously hard-coded location).
        size: ``(width, height)`` passed to ``PIL.Image.resize``.
        shuffle: shuffle the file order before writing. ``os.listdir`` order
            is filesystem-dependent and frequently alphabetical, which would
            put every cat before every dog -- far more than a small
            ``Dataset.shuffle`` buffer can mix.
        seed: optional seed for the shuffle, for a reproducible record order.
    """
    import random  # local import: only this function needs it

    # Skip sub-directories and other non-file entries, which Image.open cannot read.
    file_list = [name for name in os.listdir(path)
                 if os.path.isfile(os.path.join(path, name))]
    if shuffle:
        random.Random(seed).shuffle(file_list)

    tf_write = tf.python_io.TFRecordWriter(output_path)
    try:
        for image_name in file_list:
            label = 0 if image_name.startswith("cat") else 1
            image_path = os.path.join(path, image_name)
            # `with` releases the file handle. convert("RGB") guarantees 3
            # channels, so each record holds exactly width*height*3 bytes; the
            # decoder in this file reshapes the bytes to [227, 227, 3].
            with Image.open(image_path) as image:
                raw = image.convert("RGB").resize(size).tobytes()
            features = {
                "image": tf.train.Feature(bytes_list=tf.train.BytesList(value=[raw])),
                "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
            }
            example = tf.train.Example(features=tf.train.Features(feature=features))
            tf_write.write(example.SerializeToString())
    finally:
        # Flush and close the writer even if an image fails to decode part-way.
        tf_write.close()

#image2tfrecord(r"D:\BaiduNetdiskDownload\dataset_kaggledogvscat\train")

def pares_tf(example_proto):
    """Decode one serialized ``tf.train.Example`` into an ``(image, label)`` pair.

    The record must carry a raw-bytes ``"image"`` feature holding 227*227*3
    uint8 values and an int64 ``"label"`` feature.

    Returns:
        image: float32 tensor of shape [227, 227, 3], standardised per image
            with ``tf.image.per_image_standardization``.
        label: depth-2 one-hot tensor of the label (integer-valued, since
            ``on_value`` is the int 1).
    """
    feature_spec = {
        'label': tf.FixedLenFeature(shape=[], dtype=tf.int64),
        'image': tf.FixedLenFeature(shape=[], dtype=tf.string),
    }
    # Parse a single serialized example against the spec above.
    parsed = tf.parse_single_example(serialized=example_proto, features=feature_spec)

    pixels = tf.reshape(tf.decode_raw(parsed['image'], out_type=tf.uint8),
                        shape=[227, 227, 3])
    image = tf.image.per_image_standardization(tf.cast(pixels, tf.float32))

    class_id = tf.cast(parsed['label'], tf.int32)
    label = tf.one_hot(class_id, depth=2, on_value=1)
    return image, label

def _he_normal(shape):
    """Return a weight Variable drawn from N(0, 2 / fan_in) (He initialisation).

    ``fan_in`` is the product of every dimension except the last: for a conv
    kernel [kh, kw, in_ch, out_ch] it is kh*kw*in_ch, and for a dense matrix
    [in, out] it is ``in``. A fixed stddev of 0.1 would multiply the activation
    variance by roughly fan_in * 0.01 per layer (about 92x at fc1). That
    saturates the softmax. Scaling by fan_in keeps ReLU activations at a
    roughly constant variance through the network.
    """
    fan_in = 1
    for dim in shape[:-1]:
        fan_in *= dim
    return tf.Variable(tf.random_normal(shape=shape, dtype=tf.float32,
                                        stddev=(2.0 / fan_in) ** 0.5))


# Inputs: a batch of 227x227x3 images and depth-2 one-hot labels.
x = tf.placeholder(dtype=tf.float32, shape=[None, 227, 227, 3])
y_ = tf.placeholder(dtype=tf.float32, shape=[None, 2])

# conv1: 11x11 stride 4 VALID -> 55x55x96; max-pool 3x3 stride 2 -> 27x27x96
w1 = _he_normal([11, 11, 3, 96])
b1 = tf.Variable(tf.zeros(shape=[96]))
conv1 = tf.nn.conv2d(x, w1, strides=[1, 4, 4, 1], padding="VALID")
out1 = tf.nn.bias_add(conv1, b1)
relu1 = tf.nn.relu(out1)
pool1 = tf.nn.max_pool(relu1, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1], padding="VALID")

# conv2: 5x5 SAME -> 27x27x256; max-pool 3x3 stride 2 -> 13x13x256
w2 = _he_normal([5, 5, 96, 256])
b2 = tf.Variable(tf.zeros(shape=[256]))
conv2 = tf.nn.conv2d(pool1, w2, strides=[1, 1, 1, 1], padding="SAME")
out2 = tf.nn.bias_add(conv2, b2)
relu2 = tf.nn.relu(out2)
pool2 = tf.nn.max_pool(relu2, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1], padding="VALID")

# conv3: 3x3 SAME -> 13x13x384
w3 = _he_normal([3, 3, 256, 384])
b3 = tf.Variable(tf.zeros(shape=[384]))
conv3 = tf.nn.conv2d(pool2, w3, strides=[1, 1, 1, 1], padding="SAME")
out3 = tf.nn.bias_add(conv3, b3)
relu3 = tf.nn.relu(out3)

# conv4: 3x3 SAME -> 13x13x384
w4 = _he_normal([3, 3, 384, 384])
b4 = tf.Variable(tf.zeros(shape=[384]))
conv4 = tf.nn.conv2d(relu3, w4, strides=[1, 1, 1, 1], padding="SAME")
out4 = tf.nn.bias_add(conv4, b4)
relu4 = tf.nn.relu(out4)

# conv5: 3x3 SAME -> 13x13x256; max-pool 3x3 stride 2 -> 6x6x256
w5 = _he_normal([3, 3, 384, 256])
b5 = tf.Variable(tf.zeros(shape=[256]))
conv5 = tf.nn.conv2d(relu4, w5, strides=[1, 1, 1, 1], padding="SAME")
out5 = tf.nn.bias_add(conv5, b5)
relu5 = tf.nn.relu(out5)
pool5 = tf.nn.max_pool(relu5, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1], padding="VALID")

# Flatten 6x6x256 feature maps for the dense layers.
fc1_input = tf.reshape(pool5, shape=[-1, 6 * 6 * 256])

# fc1: 9216 -> 4096
fc1_w = _he_normal([6 * 6 * 256, 4096])
fc1_b = tf.Variable(tf.zeros(shape=[4096]))
fc1 = tf.matmul(fc1_input, fc1_w)
fc1 = tf.nn.bias_add(fc1, fc1_b)
fc1 = tf.nn.relu(fc1)

# fc2: 4096 -> 4096
fc2_w = _he_normal([4096, 4096])
fc2_b = tf.Variable(tf.zeros(shape=[4096]))
fc2 = tf.matmul(fc1, fc2_w)
fc2 = tf.nn.bias_add(fc2, fc2_b)
fc2 = tf.nn.relu(fc2)

# fc3: 4096 -> 2 output scores. Deliberately NO ReLU: these are the logits fed
# to softmax / cross-entropy. A ReLU would clamp negative scores to 0 with zero
# gradient. An example whose two scores are both negative would then get a
# uniform prediction and contribute nothing to learning.
fc3_w = _he_normal([4096, 2])
fc3_b = tf.Variable(tf.zeros(shape=[2]))
fc3 = tf.matmul(fc2, fc3_w)
fc3 = tf.nn.bias_add(fc3, fc3_b)

# Loss. fc3 holds the network's raw class scores (logits); the softmax
# probabilities are kept as `y` for prediction and accuracy.
y = tf.nn.softmax(fc3)

# Per-example softmax cross-entropy, averaged over the batch.
# - The fused op works on the logits directly, so it never evaluates log(0).
#   A hand-written -sum(y_ * log(softmax)) returns NaN as soon as any
#   probability underflows to 0.
# - Reducing per example (then taking the mean) keeps the loss independent of
#   batch size. A reduce_sum with no axis would sum over the whole batch and
#   make the outer reduce_mean a no-op.
cross_entropy = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits_v2(labels=y_, logits=fc3))

# Optimiser
train = tf.train.AdamOptimizer(learning_rate=0.0001).minimize(loss=cross_entropy)

# Evaluation: fraction of examples whose arg-max prediction matches the label.
correct_pred = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

init = tf.global_variables_initializer()

# Input pipeline: decode each record with pares_tf, shuffle with a
# 100-element buffer, group into batches of 32, and run for 20 epochs.
dataset = (
    tf.data.TFRecordDataset(r"D:\BaiduNetdiskDownload\dataset_kaggledogvscat\train.tfrecords")
    .map(pares_tf)
    .shuffle(buffer_size=100)
    .batch(32)
    .repeat(20)
)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
    print("start")
    sess.run(init)
    step = 0
    try:
        while True:
            # Pull the next batch of numpy arrays out of the pipeline.
            batch_images, batch_labels = sess.run(fetches=next_element)
            batch_feed = {x: batch_images, y_: batch_labels}
            sess.run(fetches=train, feed_dict=batch_feed)
            if not step % 100:
                # Accuracy on the batch just trained on, every 100 steps.
                train_accuracy = sess.run(fetches=accuracy, feed_dict=batch_feed)
                print(step, "accuracy=", train_accuracy)
            step += 1
    except tf.errors.OutOfRangeError:
        # Raised by get_next() once all 20 repeats have been consumed.
        print("end!")

  • 微信
  • 扫一扫
  • weinxin
  • 微信公众号
  • 扫一扫
  • weinxin
广告也精彩
商品3
高帮鞋
印花修身无袖连衣裙
懒人鞋
广告也精彩

发表评论

:?: :razz: :sad: :evil: :!: :smile: :oops: :grin: :eek: :shock: :???: :cool: :lol: :mad: :twisted: :roll: :wink: :idea: :arrow: :neutral: :cry: :mrgreen: