我的博客

tensorflow 手写数字识别1

目录
  • 代码概览

main.py: https://paste.ubuntu.com/p/VMDgyHKKcD/

  • 载入数据集

input_data.py: https://paste.ubuntu.com/p/jvhTc7YzCT/ 这个代码可以自动下载或从本地载入训练数据和测试数据。

  • 查看测试数据

import input_data
mnist = input_data.read_data_sets(“MNIST_data/“, one_hot=True)

mnist是DataSet类型的对象,数据分为三部分:mnist.test,mnist.train,mnist.validation这三部分又都包含images和labels,他们的类型均为numpy.ndarray。可以通过他们的shape查看其维度。

> mnist.train.images.shape
(55000, 784)
> mnist.train.labels.shape
(55000, 10)

train包含55000幅图像,validation包含5000幅,test包含10000幅。图像大小是28*28=784。 可以使用matplotlib显示数据集中的图像。

> from matplotlib import pyplot as plt
> plt.imshow(mnist.train.images[0].reshape(28,28))

<matplotlib.image.AxesImage object at 0x7f738147e790>

> plt.show()
> mnist.train.labels[0]
array([0., 0., 0., 0., 0., 0., 0., 1., 0., 0.])

通过matplotlib查看数据图像 通过查看标签可以看到原来这个形似3的数字实际上是个7。

  • 定义模型

import tensorflow as tf
x = tf.placeholder(“float”, [None, 784])
W = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([10]))

y = tf.nn.softmax(tf.matmul(x,W) + b)

  • 训练模型

y_ = tf.placeholder(“float”, [None,10])
cross_entropy = -tf.reduce_sum(y_*tf.log(y))
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)
for i in range(1000):
batch_xs, batch_ys = mnist.train.next_batch(100)
sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})

  • 测试准确率

correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, “float”))
print sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels})

准确率大概是91%

评论无需登录,可以匿名,欢迎评论!