Problem description
After changing the b2 bias offset in the book's MNIST training code, the cost becomes nan.
Environment and what I have tried
Just ordinary training. I suspect one parameter is responsible, but I don't understand why changing it to this setting has this effect.
Related code
import tensorflow as tf  # TensorFlow 1.x API
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
tf.reset_default_graph()

# Inputs: flattened 28x28 images and one-hot labels
x = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.float32, [None, 10])

# First layer: 784 -> 10 linear units
W = tf.Variable(tf.random_normal([784, 10]))
b = tf.Variable(tf.zeros([10]))
z = tf.matmul(x, W) + b

# Maxout: keep only the largest of the 10 units per sample -> shape [None, 1]
maxout = tf.reduce_max(z, axis=1, keep_dims=True)

# Output layer: [None, 1] x [1, 10] -> [None, 10], then softmax
W2 = tf.Variable(tf.truncated_normal([1, 10], stddev=0.1))
b2 = tf.Variable(tf.zeros([1]))  # this is the b2 I tried changing to tf.zeros([10])
pred = tf.nn.softmax(tf.matmul(maxout, W2) + b2)

learning_rate = 0.01
# Cross-entropy computed from the softmax probabilities
cost = tf.reduce_mean(-tf.reduce_sum(y * tf.log(pred), reduction_indices=1))
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)

training_epochs = 200
batch_size = 100
display_step = 1

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for epoch in range(training_epochs):
        avg_cost = 0
        total_batch = int(mnist.train.num_examples / batch_size)
        for i in range(total_batch):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            _, c = sess.run([optimizer, cost], feed_dict={x: batch_xs, y: batch_ys})
            avg_cost += c / total_batch
        # Report the average cost over all batches of this epoch
        if (epoch + 1) % display_step == 0:
            print("Epoch:", "%04d" % (epoch + 1), "cost =", "{:.9f}".format(avg_cost))
    print("Finished!")
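For reference, a minimal pure-Python sketch of what the maxout and output layers compute for a single sample (no TensorFlow; `maxout_logits` and the sample numbers are hypothetical). After `tf.reduce_max` the only per-sample information left is one number `m`, so every sample's 10 logits are `m * W2 + b2`.

```python
# Hypothetical illustration (not TensorFlow): one sample through the maxout layers.

def maxout_logits(z_row, w2_row, b2):
    """Logits that tf.matmul(maxout, W2) + b2 would give for one row of z.

    z_row:  the 10 values of one row of z = x @ W + b
    w2_row: the single row of W2 (shape [1, 10])
    b2:     a list of length 1 (broadcast to all classes) or 10 (one per class)
    """
    m = max(z_row)  # what tf.reduce_max(z, axis=1, keep_dims=True) keeps
    if len(b2) == 1:
        b2 = b2 * len(w2_row)  # broadcasting: the same bias is added to every class
    return [m * w + c for w, c in zip(w2_row, b2)]


w2_row = [0.1, -0.2, 0.05, 0.3, -0.1, 0.0, 0.2, -0.3, 0.15, -0.05]
sample_a = [0.5, 2.0, -1.0, 0.0, 1.5, 0.3, -0.7, 0.9, 1.1, -2.0]
sample_b = [3.0, -1.0, 0.2, 0.4, 0.1, 0.0, 2.5, -0.5, 1.0, 0.8]

# Different inputs, but both logit vectors are scalar multiples of w2_row
print(maxout_logits(sample_a, w2_row, [0.0]))  # m = 2.0
print(maxout_logits(sample_b, w2_row, [0.0]))  # m = 3.0
```

Both printed logit vectors point along `w2_row`; only their scale `m` differs between images.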
What result do you expect? What error message do you actually see?
I want to change b2 to 10 dimensions:
b2 = tf.Variable(tf.zeros([1]))
-> b2 = tf.Variable(tf.zeros([10]))
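A short pure-Python check of what the shape of `b2` changes (the `softmax` helper and the numbers are illustrative, not from the book). With `tf.zeros([1])`, broadcasting adds the same value to all 10 logits, and softmax is unchanged when a constant is added to every input, so that bias cannot affect `pred`. In exact arithmetic its gradient is also zero, because the gradient of softmax cross-entropy with respect to a shared shift is the sum over classes of (p_k - y_k) = 1 - 1 = 0. With `tf.zeros([10])` each class gets its own bias, which does change the output; that is consistent with the cost moving from 1.7 to 1.5.

```python
import math


def softmax(logits):
    """Softmax; subtracting the max first avoids overflow and does not change the result."""
    top = max(logits)
    exps = [math.exp(v - top) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


logits = [1.0, 2.0, 0.5]
shared_bias = [v + 3.0 for v in logits]  # like b2 with shape [1], broadcast to every class
per_class_bias = [v + c for v, c in zip(logits, [3.0, -1.0, 0.0])]  # like b2 with one value per class

print(softmax(logits))
print(softmax(shared_bias))     # same probabilities as softmax(logits)
print(softmax(per_class_bias))  # different probabilities
```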
With shape [1] (the same as the book), the book's cost gets close to 0.28, while mine only gets down to 1.7. With shape [10], the cost drops to 1.5, but then it becomes nan.
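One common way `y * tf.log(pred)` turns into nan, sketched in plain Python (the function names are hypothetical; Python floats are float64, while TensorFlow's float32 underflows much sooner). When one logit is far larger than another, softmax rounds the small probability to exactly 0.0, `log(0)` is `-inf`, and for a class whose label is 0 the product `0 * -inf` is nan, which then propagates into the cost and the gradients. Computing the loss directly from the logits with the log-sum-exp trick never takes `log(0)`. In TensorFlow 1.x the usual equivalents are `tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=...)` applied to the pre-softmax values, or clipping `pred` with `tf.clip_by_value` before `tf.log`.

```python
import math


def softmax(logits):
    top = max(logits)
    exps = [math.exp(v - top) for v in logits]  # very negative exponents underflow to 0.0
    total = sum(exps)
    return [e / total for e in exps]


def naive_cross_entropy(probs, one_hot):
    """-sum(y * log(p)), treating log(0) as -inf the way tf.log does."""
    total = 0.0
    for p, y in zip(probs, one_hot):
        log_p = math.log(p) if p > 0.0 else float("-inf")
        total += y * log_p  # 0.0 * -inf is nan
    return -total


def stable_cross_entropy(logits, one_hot):
    """-sum(y * log_softmax(logits)) using log-sum-exp; never evaluates log(0)."""
    top = max(logits)
    log_sum_exp = top + math.log(sum(math.exp(v - top) for v in logits))
    return -sum(y * (v - log_sum_exp) for v, y in zip(logits, one_hot))


logits = [0.0, 800.0]  # extreme gap: exp(-800) underflows to 0.0
one_hot = [0.0, 1.0]

print(softmax(logits))                                # [0.0, 1.0]
print(naive_cross_entropy(softmax(logits), one_hot))  # nan
print(stable_cross_entropy(logits, one_hot))          # finite, about 0
```

On ordinary logits both functions agree; they only differ once a probability underflows.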