python


TensorFlow linear classifier not training


I'm trying to create a simple linear classifier for MNIST data, but I cannot get my loss to go down. What could be the problem?
Here is my code:
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

class LinearClassifier(object):
    def __init__(self):
        print("LinearClassifier loading MNIST")
        self._mnist = input_data.read_data_sets("mnist_data/", one_hot=True)
        self._buildGraph()

    def _buildGraph(self):
        self._tf_TrainX = tf.placeholder(tf.float32, [None, self._mnist.train.images.shape[1]])
        self._tf_TrainY = tf.placeholder(tf.float32, [None, self._mnist.train.labels.shape[1]])
        # dtype must be passed by keyword: the second positional argument of tf.Variable is `trainable`
        self._tf_Weights = tf.Variable(tf.random_normal([784, 10]), dtype=tf.float32)
        self._tf_Bias = tf.Variable(tf.zeros([10]), dtype=tf.float32)
        self._tf_Y = tf.nn.softmax(tf.matmul(self._tf_TrainX, self._tf_Weights) + self._tf_Bias)
        self._tf_Loss = tf.reduce_mean(-tf.reduce_sum(self._tf_TrainY * tf.log(self._tf_Y), reduction_indices=[1]))
        self._tf_TrainStep = tf.train.GradientDescentOptimizer(learning_rate=0.001).minimize(self._tf_Loss)
        self._tf_CorrectGuess = tf.equal(tf.argmax(self._tf_Y, 1), tf.argmax(self._tf_TrainY, 1))
        self._tf_Accuracy = tf.reduce_mean(tf.cast(self._tf_CorrectGuess, tf.float32))
        self._tf_Initializers = tf.global_variables_initializer()

    def train(self, epochs, batch_size):
        self._sess = tf.Session(config=tf.ConfigProto(log_device_placement=True))
        self._sess.run(self._tf_Initializers)
        for i in range(epochs):
            batchX, batchY = self._mnist.train.next_batch(batch_size)
            # One training step; also fetch the loss and accuracy on this batch
            self._loss, _, self._accuracy = self._sess.run(
                [self._tf_Loss, self._tf_TrainStep, self._tf_Accuracy],
                feed_dict={self._tf_TrainX: batchX, self._tf_TrainY: batchY})
            print("Epoch: {0}, Loss: {1}, Accuracy: {2}".format(i, self._loss, self._accuracy))
When I run this via:
lc = LinearClassifier()
lc.train(1000, 100)
... I get output like this:
Epoch: 969, Loss: 8.19491195678711, Accuracy: 0.17999999225139618
Epoch: 970, Loss: 9.09421157836914, Accuracy: 0.1899999976158142
....
Epoch: 998, Loss: 7.865959167480469, Accuracy: 0.17000000178813934
Epoch: 999, Loss: 9.281349182128906, Accuracy: 0.10999999940395355
What could be the reason why the tf.train.GradientDescentOptimizer is not training my weights and bias correctly?
The main problem is that your learning rate (0.001) is too low. After changing it to 0.5, as in the TensorFlow MNIST tutorial, I get loss and accuracy more like:
Epoch: 997, Loss: 0.6437355875968933, Accuracy: 0.8999999761581421
Epoch: 998, Loss: 0.6129786968231201, Accuracy: 0.8899999856948853
Epoch: 999, Loss: 0.6442205905914307, Accuracy: 0.8999999761581421
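The effect of the step size can be reproduced without TensorFlow. The sketch below is plain Python on made-up toy data (the names `softmax`, `logits`, `mean_loss` and `train` are hypothetical helpers, not TensorFlow APIs). It trains the same kind of model, softmax regression with a cross-entropy loss, using plain gradient descent for the same number of steps at 0.001 and at 0.5. As an assumption for simplicity, it starts from zero weights rather than the question's `tf.random_normal` initialisation.

```python
import math

def softmax(z):
    # Subtract the largest logit before exponentiating so exp() cannot overflow
    m = max(z)
    exps = [math.exp(v - m) for v in z]
    total = sum(exps)
    return [e / total for e in exps]

def logits(W, b, x):
    # z_k = sum_i x_i * W[i][k] + b[k], the same form as tf.matmul(X, W) + b
    return [sum(xi * W[i][k] for i, xi in enumerate(x)) + b[k] for k in range(len(b))]

def mean_loss(W, b, X, y):
    # Mean cross-entropy: -log(probability assigned to the true class)
    return sum(-math.log(softmax(logits(W, b, x))[t]) for x, t in zip(X, y)) / len(X)

def train(lr, steps, X, y, n_classes):
    n_features = len(X[0])
    W = [[0.0] * n_classes for _ in range(n_features)]
    b = [0.0] * n_classes
    for _ in range(steps):
        gW = [[0.0] * n_classes for _ in range(n_features)]
        gb = [0.0] * n_classes
        for x, t in zip(X, y):
            p = softmax(logits(W, b, x))
            for k in range(n_classes):
                # d(loss)/d(z_k) = p_k - onehot_k, averaged over the batch
                d = (p[k] - (1.0 if k == t else 0.0)) / len(X)
                gb[k] += d
                for i in range(n_features):
                    gW[i][k] += d * x[i]
        # Plain gradient-descent update: parameter -= learning_rate * gradient
        for i in range(n_features):
            for k in range(n_classes):
                W[i][k] -= lr * gW[i][k]
        for k in range(n_classes):
            b[k] -= lr * gb[k]
    return mean_loss(W, b, X, y)

# Hypothetical toy data: three well-separated 2-D clusters with labels 0, 1, 2
X = [(1.0, 0.0), (0.8, 0.2), (0.0, 1.0), (0.2, 0.8), (-1.0, -1.0), (-0.8, -0.9)]
y = [0, 0, 1, 1, 2, 2]

slow = train(lr=0.001, steps=100, X=X, y=y, n_classes=3)
fast = train(lr=0.5, steps=100, X=X, y=y, n_classes=3)
print(f"lr=0.001: loss {slow:.4f}")
print(f"lr=0.5:   loss {fast:.4f}")
```

With the small learning rate the loss barely moves away from its starting value of log(3); with 0.5 it falls much further in the same 100 steps.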
Another thing that's a little unusual is this part of your original code:
self._tf_Y = tf.nn.softmax(tf.matmul(self._tf_TrainX, self._tf_Weights) + self._tf_Bias)
self._tf_Loss = tf.reduce_mean(-tf.reduce_sum(self._tf_TrainY * tf.log(self._tf_Y), reduction_indices=[1]))
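Strictly speaking, these two lines apply the softmax only once: tf.nn.softmax turns the logits into probabilities, and the second line writes the cross-entropy out by hand. The softmax would be applied twice if you kept tf.nn.softmax here and switched the loss to tf.nn.softmax_cross_entropy_with_logits, which applies its own softmax to whatever it receives as logits, and a double softmax is also hard to interpret theoretically. I did run the code in this form before changing it, and train accuracy was around 85%, so it does make some difference.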
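To see what applying the softmax twice would do, here is a small plain-Python sketch with hypothetical 10-class logits. `once` is the normal softmax output; `twice` is what you get when those probabilities go through another softmax, which is what would happen if the output of `tf.nn.softmax` were passed as the `logits` argument of `tf.nn.softmax_cross_entropy_with_logits`.

```python
import math

def softmax(z):
    # Numerically safe softmax: shift by the max logit before exponentiating
    m = max(z)
    exps = [math.exp(v - m) for v in z]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical logits for one example where class 3 is strongly preferred
z = [0.0, 1.0, -2.0, 8.0, 0.5, -1.0, 2.0, 0.0, -0.5, 1.5]
once = softmax(z)      # a confident prediction for class 3
twice = softmax(once)  # probabilities treated as logits a second time
print(max(once), max(twice))

# The second softmax only sees inputs in [0, 1], so with 10 classes no output
# can exceed e / (e + 9): the cross-entropy can never get close to zero.
bound = math.e / (math.e + 9)
print(bound)
```

Because the second softmax squashes everything towards uniform (at most about 0.23 per class with 10 classes), the loss stays high even when the prediction is right, which makes the gradients less informative.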
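Finally, the tutorial mentions that the hand-written cross-entropy above, -reduce_sum(label * log(y)), is numerically unstable, so it's better to use the built-in tf.nn.softmax_cross_entropy_with_logits, which computes an analytically equivalent but more numerically stable loss directly from the logits. Since softmax does not change which logit is largest, the accuracy computation with tf.argmax(self._tf_Y, 1) can stay as it is. After applying these two changes (dropping tf.nn.softmax from _tf_Y and using the built-in loss), the affected lines look like: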
self._tf_Y = tf.matmul(self._tf_TrainX, self._tf_Weights) + self._tf_Bias
self._tf_Loss = tf.reduce_mean(
tf.nn.softmax_cross_entropy_with_logits(labels=self._tf_TrainY, logits=self._tf_Y))
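The numerical-stability point can also be illustrated in plain Python. The sketch below uses hypothetical function names and shows the general log-sum-exp idea behind fused softmax/cross-entropy losses; it is not TensorFlow's actual kernel. It compares taking log(softmax(z)) in two steps with computing it directly as z[label] - logsumexp(z).

```python
import math

def naive_cross_entropy(z, label):
    # Two steps, like -reduce_sum(y * log(softmax(z))): probability first, then log.
    # Shifting by the max avoids overflow, but p can still underflow to exactly 0.0.
    m = max(z)
    exps = [math.exp(v - m) for v in z]
    p = exps[label] / sum(exps)
    return -math.log(p)  # raises ValueError when p == 0.0

def stable_cross_entropy(z, label):
    # -log(softmax(z)[label]) = logsumexp(z) - z[label], never forming p itself
    m = max(z)
    log_sum_exp = m + math.log(sum(math.exp(v - m) for v in z))
    return log_sum_exp - z[label]

z = [1000.0, 0.0]  # hypothetical logits: very confident in the wrong class
print(stable_cross_entropy(z, 1))  # 1000.0
try:
    naive_cross_entropy(z, 1)
except ValueError as err:
    print("naive version failed:", err)
```

For ordinary logits the two functions agree; the difference only shows up for extreme values. In TensorFlow the same underflow typically appears as `inf` or `NaN` in the loss rather than as an exception.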
