Tkinter Optical Character Recognition Training Data Labeler

In this post I'll demonstrate how to build an object-oriented Tkinter GUI application that associates labels with image filenames, so that you can quickly and easily build a set of training data. The Submit button (or the Enter key) stores the label for the current image in a Python dict keyed by filename and moves on to the next image. The Save and Quit button dumps that dict to a cPickle file for later use, and once the last image is labeled the dict is saved automatically. This is still a little rough around the edges: it assumes that you're looking for PNG data in the current directory, and each run overwrites the previous output, but it's a start.
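
One way to smooth the second rough edge is to pick an output name that doesn't exist yet. Here's a minimal sketch, assuming numbered names like train_complete-1.pkl are acceptable. next_output_path is a hypothetical helper, not part of the application below, and the demo runs in a throwaway temporary directory so no real output is touched:

```python
import os
import shutil
import tempfile


def next_output_path(stem, ext=".pkl"):
    """Return stem + ext if no such file exists yet; otherwise try
    stem-1 + ext, stem-2 + ext, ... until an unused name turns up."""
    candidate = stem + ext
    n = 0
    while os.path.exists(candidate):
        n += 1
        candidate = "{}-{}{}".format(stem, n, ext)
    return candidate


# Demo in a throwaway directory, so no real output is touched.
demo_dir = tempfile.mkdtemp()
demo_stem = os.path.join(demo_dir, "train_complete")

first = next_output_path(demo_stem)   # <demo_dir>/train_complete.pkl
open(first, "wb").close()             # pretend a previous run saved here
second = next_output_path(demo_stem)  # <demo_dir>/train_complete-1.pkl

shutil.rmtree(demo_dir)               # clean up the demo directory
```

In the application, a hard-coded filename would then become something like open( next_output_path( "train_partial" ), "wb" ).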

Important Resources

I would not have been able to do this without the help of ZetCode for simple, complete, object-oriented examples, and effbot for detailed information about widgets and other best practices.

Code

These are my imports: Tkinter for the GUI, PIL.Image and ImageTk (from PIL, or its maintained fork Pillow) for loading the PNG images and turning them into something Tkinter can display, glob for finding the PNG images in the working directory, and cPickle for storing the output.

from Tkinter import *
## Tkinter's star import brings in its own Image class, so keep PIL.Image qualified
import PIL.Image, cPickle, glob
from PIL import ImageTk
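
Because glob.glob("*.png") only searches the working directory and makes no promise about the order of its results, a small helper can take the directory as a parameter and sort what it finds. This is a sketch; find_pngs is a hypothetical name, and the demo uses empty placeholder files in a temporary directory rather than real images:

```python
import glob
import os
import shutil
import tempfile


def find_pngs(directory="."):
    """Return the .png files in `directory`, sorted by name so the
    labeling order is the same on every run (glob makes no promise)."""
    return sorted(glob.glob(os.path.join(directory, "*.png")))


# Demo: empty placeholder files in a throwaway directory, two of them .png.
demo_dir = tempfile.mkdtemp()
for name in ("b.png", "a.png", "notes.txt"):
    open(os.path.join(demo_dir, name), "w").close()

found = [os.path.basename(p) for p in find_pngs(demo_dir)]
print(found)  # ['a.png', 'b.png']

shutil.rmtree(demo_dir)  # clean up the demo directory
```

Note that the returned names keep the directory prefix, so if you pass anything other than the current directory, the keys of the label dict will include it too.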

This is a pretty simple, one-level GUI: an image, an entry box, and two buttons. It is set up so that you can type a label on the keypad and then hit the keypad Enter key (or the main Enter key) to move on to the next image. When you're done, the application writes a cPickle file for later use.

class App( Frame ):

    def __init__( self, parent ):

        Frame.__init__( self, parent )

        self.parent = parent

        ## filename -> label, and the index of the image on screen
        self.train = dict()
        self.i = 0
        ## sort so the images come up in a predictable order
        self.fns = sorted( glob.glob( "*.png" ) )

        ## start up the UI
        self.initUI()

    def initUI( self ):
        ## build the widgets once; submit_callback() only swaps the image

        ## set the key bindings for the Enter key, and the keypad Enter key
        self.parent.bind( '<Return>', self.submit_callback )
        self.parent.bind( '<KP_Enter>', self.submit_callback )

        ## name of the program
        self.parent.title( "Label Training Data" )

        ## a Label object to hold the image; show_image() fills it in
        self.img_label = Label( self )
        self.img_label.grid( row=0, columnspan=2 )

        ## entry
        ent_label = Label( self, text="Label:" )
        ent_label.grid( row=1, column=0 )
        self.entry = Entry( self )
        self.entry.grid( row=1, column=1 )
        self.entry.focus()

        ## submit button; command= runs on a full click, with event=None
        submit_btn = Button( self, text="Submit", command=self.submit_callback )
        submit_btn.grid( row=2, columnspan=2 )

        ## quit button
        quit_btn = Button( self, text="Save and Quit", command=self.quit_callback )
        quit_btn.grid( row=3, columnspan=2 )

        ## pack it up, then show the first image
        self.pack()
        self.show_image()

    def show_image( self ):
        ## open the current image and put it in the Label
        image = PIL.Image.open( self.fns[ self.i ] )
        photo = ImageTk.PhotoImage( image, master=self )
        self.img_label.configure( image=photo )
        ## keep a reference, otherwise the photo is garbage collected
        self.img_label.image = photo

    def submit_callback( self, event=None ):
        ## associate the user input with the filename
        self.train[ self.fns[ self.i ] ] = self.entry.get()
        ## increment the counter
        self.i += 1
        ## if we're at the end of the data, save/dump it and kill the application
        if self.i == len( self.fns ):
            ## pickle protocol -1 is binary, so open the file in 'wb' mode
            with open( "train_complete.pkl", "wb" ) as f:
                cPickle.dump( self.train, f, -1 )
            self.parent.destroy()
            return
        ## otherwise clear the entry and show the next image
        self.entry.delete( 0, END )
        self.show_image()

    def quit_callback( self ):
        ## save/dump the data labeled so far
        with open( "train_partial.pkl", "wb" ) as f:
            cPickle.dump( self.train, f, -1 )
        ## kill the application
        self.parent.destroy()
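
The label bookkeeping in submit_callback doesn't actually depend on Tkinter, so it can be checked without opening a window. Here's a GUI-free sketch of it, assuming you want to test that logic on its own. LabelSession is a hypothetical class that mirrors the fns, i and train attributes of App, and its submit() method stands in for pressing Enter:

```python
class LabelSession(object):
    """GUI-free model of the labeler's state: the list of filenames, the
    index of the image on screen, and the filename -> label dict so far."""

    def __init__(self, filenames):
        self.fns = list(filenames)
        self.i = 0
        self.train = {}

    def submit(self, label):
        """Record `label` for the current file and advance to the next one.
        Returns True once every file has a label (time to save and exit)."""
        self.train[self.fns[self.i]] = label
        self.i += 1
        return self.i == len(self.fns)


session = LabelSession(["a.png", "b.png"])
done_after_first = session.submit("7")   # False: b.png is still unlabeled
done_after_second = session.submit("3")  # True: every file is labeled
print(session.train)
```

When submit() returns True, the application would pickle the dict and close the window, which is exactly the branch where the original code needs its early return.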

Now that the application is written, we run it the usual Pythonic way, from a main() function behind an if __name__ == "__main__" guard:

def main():
    root = Tk()
    root.geometry("250x100+100+100")
    app = App( root )
    root.mainloop()

if __name__=="__main__":
    main()

We can then get our data back by loading the cPickle output: train_complete.pkl if you labeled every image, or train_partial.pkl if you used Save and Quit.

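import cPickle
## use "train_partial.pkl" if you quit before labeling every image
with open( "train_complete.pkl", "rb" ) as f:
    train = cPickle.load( f )
for k, v in train.iteritems():
    print '{:>20}{:>10}'.format(k,v)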
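
For training you'll usually want the filenames and labels as two aligned lists rather than a dict. The sketch below assumes that; load_labels is a hypothetical helper, and it uses pickle instead of cPickle so it runs under both Python 2.7 and Python 3 (Python 3 has no cPickle module, and its pickle uses the C implementation automatically). The demo writes its own small file to a temporary directory instead of reading real output:

```python
import os
import pickle
import shutil
import tempfile


def load_labels(path):
    """Load the filename -> label dict saved by the labeler and return
    two parallel lists (filenames, labels), sorted by filename."""
    with open(path, "rb") as f:   # pickle files are binary
        train = pickle.load(f)
    filenames = sorted(train)
    return filenames, [train[fn] for fn in filenames]


# Demo: save a tiny dict the way the app does (protocol -1, binary mode)
# in a throwaway directory, then read it back.
demo_dir = tempfile.mkdtemp()
demo_path = os.path.join(demo_dir, "train_complete.pkl")
with open(demo_path, "wb") as f:
    pickle.dump({"b.png": "3", "a.png": "7"}, f, -1)

fns, labels = load_labels(demo_path)
for fn, label in zip(fns, labels):
    print("{:>20}{:>10}".format(fn, label))

shutil.rmtree(demo_dir)  # clean up the demo directory
```

Under Python 3, pickle.load can also read a file written by the Python 2 script: Python 2 str values are decoded as ASCII by default (see the encoding argument of pickle.load), which is fine as long as your filenames and labels are plain ASCII.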