'''
Created on Jul 19, 2011

@author: svenkratz
'''

# TODO: allow selecting the other images via a pull-down list

import numpy as np
import matplotlib.pyplot as plt

from Tkinter import *
import tkMessageBox
import tkFont
from haralick import *


classifier = None

class Dialog:
	def __init__(self, master, img1, train_callback=None, verify_callback=None):
		
		self.master = master
		self.img1_arr = img1
		self.img1_draw_arr = np.dstack((img1.copy(), img1.copy(), img1.copy()))
		print self.img1_draw_arr.shape
		self.img1_shp = img1.shape

		self.bgFrame = Frame(master, bd=0)
		self.bgFrame.pack(side=RIGHT)
		
		self.class_radioGroup_selection = IntVar();
		self.mode_selection = IntVar();
		
		self.leftFrame = Frame(master)
		self.leftFrame.pack(side=LEFT, anchor=N)
		
	
		
		Label(self.leftFrame, text="Class Selection", font=tkFont.Font(weight=tkFont.BOLD)).pack(anchor=W)
		
		# RadioGroup for the Class selection
		
		r = Radiobutton(self.leftFrame, text="C1", variable=self.class_radioGroup_selection, value=1, command=self.class_radioGroup_selected)
		r.select()
		r.pack(anchor=W)
		Radiobutton(self.leftFrame, text="C2", variable=self.class_radioGroup_selection, value=2, command=self.class_radioGroup_selected).pack(anchor=W)
		Radiobutton(self.leftFrame, text="C3", variable=self.class_radioGroup_selection, value=3, command=self.class_radioGroup_selected).pack(anchor=W)

		self.selectedClass = 1
		# colors for classses
		self.class_colors = {1: (255, 0, 0),
							 2: (0, 255, 0),
							 3: (0, 0, 255)}
		
		
		# Clear the selections
		Button(self.leftFrame, text="Clear Selections", command=self.btn_clearSelections).pack()
		
		# placeholder
		Label(self.leftFrame, text="", font=tkFont.Font(weight=tkFont.BOLD)).pack(anchor=W)
		
		# Label and Radiogroup for Mode Selection
		Label(self.leftFrame, text="Action", font=tkFont.Font(weight=tkFont.BOLD)).pack(anchor=W)
		r = Radiobutton(self.leftFrame, text="Select", variable=self.mode_selection, value=1, command=self.mode_radioGroup_selected, indicatoron=False, padx=5, pady=5)
		r.select()
		r.pack(anchor=W)
		Radiobutton(self.leftFrame, text="Train", variable=self.mode_selection, value=2, command=self.mode_radioGroup_selected, indicatoron=False, padx=5, pady=5).pack(anchor=W)
		Radiobutton(self.leftFrame, text="Verify", variable=self.mode_selection, value=3, command=self.mode_radioGroup_selected, indicatoron=False, padx=5, pady=5).pack(anchor=W)
		
		self.mode = 1
		


		# the Image Canvas
		self.picture_canvas = Canvas(self.bgFrame, width=self.img1_shp[1], height=self.img1_shp[0], bd=0, highlightthickness=0)
		self.picture_canvas.bind("<Button-1>", self.img_mouseClicked)
		
		self.image = None
		
		self.to_Tkinter_img(self.picture_canvas, self.img1_draw_arr)
		self.picture_canvas.pack()
		
		# Pixel and detected class
		self.bottomFrame = Frame(self.bgFrame)
		self.bottomFrame.pack(anchor=W)
		Label(self.bottomFrame, text="Last Click Location:").pack(side=LEFT)
		
		self.label_lastX_var = StringVar()
		self.label_lastY_var = StringVar()
		self.label_classDetected_var = StringVar()
		
		self.label_lastX_var.set("--")
		self.label_lastY_var.set("--")
		self.label_classDetected_var.set("None")
		
		self.label_lastX = Label(self.bottomFrame, textvariable=self.label_lastX_var, font=tkFont.Font(weight=tkFont.BOLD))
		self.label_lastX.pack(side=LEFT)
		
		self.label_lastY = Label(self.bottomFrame, textvariable=self.label_lastY_var, font=tkFont.Font(weight=tkFont.BOLD))
		self.label_lastY.pack(side=LEFT)
		
		Label(self.bottomFrame, text="Detected Class:").pack(side=LEFT)
		
		self.label_classDetected = Label(self.bottomFrame, textvariable=self.label_classDetected_var, font=tkFont.Font(weight=tkFont.BOLD))
		self.label_classDetected.pack(side=LEFT)
		
		#selections
		self.selections = {1: [],
						   2: [],
						   3: []}
		
		# callbacks
		self.train_callback = train_callback
		self.verify_callback = verify_callback
		
	def to_Tkinter_img(self, canv, img):
		config = {'width':img.shape[1],
				'height':img.shape[0]}
		
		if self.image == None:
			self.image = PhotoImage("Image1", config , self.master)
			
		
		for j in xrange(img.shape[0]):
			for k in xrange(img.shape[1]):
				#self.i.put('#%02x%02x%02x' % tuple(color),(row,col))
				color = img[j, k]
				self.image.put('#%02x%02x%02x' % (color[0], color[1], color[2]), (k, j))
				
				
		self.originalImage = self.image.copy()
		canv.create_image(0, 0, image=self.image, anchor=NW)
	
	def redrawWithNewSelection(self, x, y, w=5):
	
		color = self.class_colors[self.selectedClass]
		for j in range(0, w + 1, 1):
				# draw a diamond
				#print x+w, y+w
				self.image.put('#%02x%02x%02x' % color, (x + j, y + w - j))
				self.image.put('#%02x%02x%02x' % color, (x + j, y - w + j))
				self.image.put('#%02x%02x%02x' % color, (x - j, y + w - j))
				self.image.put('#%02x%02x%02x' % color, (x - j, y - w + j))
		
		#self.to_Tkinter_img(self.picture_canvas, self.img1_draw_arr)
		self.picture_canvas.create_image(0, 0, image=self.image, anchor=NW)
				
	def btn_clearSelections(self):
		''' clear all selections ''' 
		self.selections = {1: [],
						   2: [],
						   3: []}
		# buffer tricks... 
		# repaint the buffer first for feedback
		self.picture_canvas.create_image(0, 0, image=self.originalImage, anchor=NW)
		
		# copy original image to new working image
		self.image = self.originalImage.copy()
		
		
	def img_mouseClicked(self, event):
		x, y = (event.x, event.y)
		
		self.label_lastX_var.set(str(x))
		self.label_lastY_var.set(str(y))
		
		print "Click Location", x, y
		if self.mode == 1:		#select
			self.redrawWithNewSelection(x, y)
			self.selections[self.selectedClass].append((x, y))
		elif self.mode == 2:	#train
			pass # nothing to do		
		elif self.mode == 3:	#verify
			classDetected = self.verify_callback(x, y, self.img1_arr)
			self.label_classDetected_var.set(str(classDetected))
		
	def class_radioGroup_selected(self):
		self.selectedClass = self.class_radioGroup_selection.get()
		print "Class Selected:", self.selectedClass
	
	def mode_radioGroup_selected(self):
		self.mode = self.mode_selection.get()
		print "Selected mode", self.mode
		if self.mode == 1:   #select
			pass
		elif self.mode == 2:   # train
			if self.train_callback != None:
				if (len(self.selections[1]) > 0 and
					len(self.selections[2]) > 0):
					print "Set selections for class 3 to None - WHY???????????"
					if len(self.selections[3]): self.selections[3] = None
					self.train_callback(self.selections, self.img1_arr)
				else:
					tkMessageBox.showwarning("Cannot Train", "Make selections for at least two classes!", icon=tkMessageBox.WARNING)
			else:
				tkMessageBox.showerror("Cannot Train", "No Callback Function Defined!", icon=tkMessageBox.ERROR)
		elif self.mode == 3:       # verify
			if self.verify_callback != None:
				pass
			else:
				tkMessageBox.showwarning("Cannot Verify", "No Callback Function Defined!", icon=tkMessageBox.WARNING)
				
		
		

##### demo Callback Receivers

def training_callback(selections, image):
	""" performs training on the image, with the selections """
	print "Training Callback, your selected classes and their pixels"
	
	global classifier
	classifier = Classifier()
	
	for key, s in selections.iteritems():
		if s is not None and len(s) > 0:		
			for x, y in s:
				region = getRegion(x, y, image)
				h = Haralick(region, 64, [(0, 1), (90, 1)])
				classifier.addToClass(h, key)

	
def verification_callback(x, y, image):
	"""
	Demo verify callback: return the class detected at (x, y) in image.

	The region around (x, y) is cut out with getRegion. Its Haralick features
	use 64 levels and offsets (0, 1) and (90, 1), the same parameters as
	training_callback. The features are classified with the global
	classifier's single-nearest-neighbour rule; classifyByCentroid is an
	alternative.

	Returns None when no classifier has been trained yet. Previously this case
	crashed, because `classifier` is None and `h` was never assigned.
	"""
	if classifier is None:
		return None
	region = getRegion(x, y, image)
	h = Haralick(region, 64, [(0, 1), (90, 1)])
	return classifier.classifyBySingleNN(h)
	


def getRegion(x, y, image, r=20):
	"""
	Return the 2r x 2r window of image centred, as far as possible, on (x, y).

	x, y  -- column and row of the centre pixel
	image -- array indexed as image[row, col, ...]; any trailing axes (e.g.
	         colour channels) are kept in the result
	r     -- half the window edge length (default 20)

	Near the border the window is shifted inwards so that it lies entirely
	inside the image (the "cheating": it is then no longer centred on the
	click). If the image is smaller than 2r along an axis, the full extent of
	that axis is returned. Previously such images produced truncated windows.
	"""
	h, w = image.shape[:2]

	def windowStart(c, size):
		# top-left offset c - r, clamped to [0, size - 2r] (0 if size < 2r)
		return min(max(c - r, 0), max(size - 2 * r, 0))

	x0 = windowStart(x, w)
	y0 = windowStart(y, h)
	return image[y0:y0 + 2 * r, x0:x0 + 2 * r]


#### 


# Demo application: show the texture image in the classification dialog,
# wired to the demo callbacks above, and run the Tk event loop.
root = Tk()
root.wm_title("Classification GUI")

# plt.imread returns PNG intensities as floats in [0, 1]; scale to 0..255
img1 = 255 * plt.imread('./textur6.png')

dialog = Dialog(root, img1, training_callback, verification_callback)

root.mainloop()
