import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import time
# Load the MNIST dataset through the TensorFlow tutorial helper, using the
# 'MNIST_data' directory and requesting one-hot encoded labels.
# NOTE(review): the helper presumably downloads the files into that directory
# when they are missing — confirm before running offline.
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
def set_layer(inputs, in_size, out_size, layer_name, activate_function=None):
    """Build one fully connected layer and return its output tensor.

    The layer owns a weight matrix of shape [in_size, out_size] initialised
    uniformly between -1.0 and 1.0, and a bias vector of length out_size
    initialised to 0.1.  Dropout is applied to ``inputs @ W + bias`` before
    the optional activation.

    Args:
        inputs: Tensor fed into the layer; multiplied by the weight matrix.
        in_size: Number of rows of the weight matrix.
        out_size: Number of columns of the weight matrix and bias length.
        layer_name: Suffix used in the variable names ("W" + layer_name and
            "bias_" + layer_name).
        activate_function: Callable applied to the dropped-out pre-activation,
            or None to return the pre-activation unchanged.

    Returns:
        The layer output tensor.

    Note:
        ``keep_prob`` is not a parameter; it is looked up as a module-level
        name when this function is called, so it must be defined before the
        first call.
    """
    weights = tf.Variable(
        tf.random_uniform([in_size, out_size], -1.0, 1.0),
        name="W" + layer_name,
    )
    offsets = tf.Variable(
        tf.constant(0.1, shape=[out_size]),
        name="bias_" + layer_name,
    )
    # Dropout guards against overfitting; keep_prob is the probability that
    # each element is kept.
    pre_activation = tf.nn.dropout(tf.matmul(inputs, weights) + offsets, keep_prob)
    if activate_function is not None:
        return activate_function(pre_activation)
    return pre_activation
## Hyperparameters
# NOTE(review): hidden_layers is not read anywhere in the visible code, and
# two hidden layers (h1, h2) are built below — confirm whether it is needed.
hidden_layers = 1
# Output widths of the first and second hidden layers.
hidden_units1 = 200
hidden_units2 = 50
# Width of the input placeholder and of the first layer's weight matrix.
n_input = 784
# Width of the label placeholder and of the final layer's output.
n_classes = 10
# Step size handed to GradientDescentOptimizer.
learning_rate = 0.8
## Network construction
# Placeholders for a batch of inputs and the matching labels; the leading
# dimension is left as None so any batch size can be fed.
xs = tf.placeholder(tf.float32, [None, n_input], name="input")
ys = tf.placeholder(tf.float32, [None, n_classes], name="output")
# Dropout keep probability; set_layer reads this module-level name directly.
keep_prob = tf.placeholder(tf.float32)
# Two tanh hidden layers followed by an output layer with no activation,
# whose result is used as logits by the loss below.
h1 = set_layer(xs, n_input, hidden_units1, 'hidden_layer_1', activate_function=tf.nn.tanh)
h2 = set_layer(h1, hidden_units1, hidden_units2, 'hidden_layer_2', activate_function=tf.nn.tanh)
prediction = set_layer(h2, hidden_units2, n_classes, 'prediction_layer', activate_function=None)
# Mean softmax cross-entropy between the labels and the logits.
cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=ys, logits=prediction))
# One plain gradient-descent step on the loss.
train_op = tf.train.GradientDescentOptimizer(learning_rate).minimize(cross_entropy)
# Register a scalar summary named 'loss' for the cross-entropy value.
tf.summary.scalar('loss', cross_entropy)
## Accuracy of the predictions
# A prediction is correct when the index of its largest logit equals the
# index of the largest label entry; accuracy is the mean of those matches.
correct_prediction = tf.equal(tf.argmax(prediction, 1), tf.argmax(ys, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
## Training
# tf.initialize_all_variables() is deprecated; global_variables_initializer()
# is its direct replacement.
init = tf.global_variables_initializer()
# Merge every registered summary (the 'loss' scalar) into one op so it can be
# written to the event file; merge_all() returns None when none are registered.
merged = tf.summary.merge_all()
n_epochs = 40
batch_size = 100
# The number of batches per epoch is the same for every epoch, so compute it
# once; integer division avoids the float round-trip of int(a / b).
n_batch = mnist.train.num_examples // batch_size
# Evaluation always uses the full test split with dropout disabled.
test_feed = {keep_prob: 1.0, xs: mnist.test.images, ys: mnist.test.labels}
with tf.Session() as sess:
    st = time.time()
    write = tf.summary.FileWriter('logs/', sess.graph)
    try:
        sess.run(init)
        step = 0
        for epoch in range(n_epochs):
            for _ in range(n_batch):
                batch_xs, batch_ys = mnist.train.next_batch(batch_size)
                train_feed = {xs: batch_xs, ys: batch_ys, keep_prob: 0.75}
                if merged is None:
                    sess.run(train_op, feed_dict=train_feed)
                else:
                    # Fetch the summary in the same run as the training step
                    # so the logged loss costs no extra forward pass.
                    summary = sess.run([train_op, merged], feed_dict=train_feed)[1]
                    write.add_summary(summary, step)
                step += 1
            # Report test accuracy once per epoch.
            print('epoch', epoch, 'accuracy:', sess.run(accuracy, feed_dict=test_feed))
        end = time.time()
        print('*' * 30)
        print('training finish. cost time:', int(end - st), 'seconds; accuracy:', sess.run(accuracy, feed_dict=test_feed))
    finally:
        # Flush and release the event file even if training is interrupted.
        write.close()