#!/usr/bin/env python
# aggregate sound from mfcc or fft similarity of chunks

import numpy as np
import scipy.io.wavfile
from features import mfcc
from features import logfbank
from features import base
import copy
import os
import platform;

if int(platform.python_version_tuple()[0])>2:
	from tkinter import *
	from tkinter.filedialog import *
	from tkinter.messagebox import *
else:
	from Tkinter import *
	from tkFileDialog import *
	from tkMessageBox import *


# NOTE(review): source_dir and render_dir are not referenced anywhere in the
# visible code -- confirm whether anything else still relies on them.
source_dir = "../sound/source/"
render_dir = "../sound/render/"
# version string, shown in the window title
version = "0.0.2"

def msg(msg):
    """Print a status message.

    This is the default reporting hook; the end of the script rebinds the
    module-level name ``msg`` to the window's own msg method.
    """
    print(msg)

def clear_msg():
    """Do nothing.

    Placeholder hook; the end of the script rebinds the module-level name
    ``clear_msg`` to the window's own clear_msg method.
    """
    pass

def render_stats():
    """Do nothing.

    Placeholder hook; the end of the script rebinds the module-level name
    ``render_stats`` to the window's own render_stats method.
    """
    pass

def fadeinout(s,slength,elength):
    """Return a copy of ``s`` with a linear fade-in and fade-out applied.

    s       -- mutable sequence of samples (list, numpy array, ...); it is
               deep-copied first, so the caller's data is left untouched.
    slength -- length of the fade-in ramp in samples.
    elength -- length of the fade-out ramp in samples.

    The i-th sample from the start is scaled by i/slength and the i-th sample
    from the end by i/elength, so the first and the last sample end up at zero.
    A ramp longer than the signal is simply cut off at the signal's length; the
    slope of the ramp stays the one implied by slength/elength.
    """
    s = copy.deepcopy(s)
    total = len(s)
    # clamp to the signal length: running past it would raise IndexError
    for i in range(min(slength,total)):
        s[i]*=float(i)/slength
    last = total-1
    # clamp here too: without it last-i turns negative and python's negative
    # indexing would quietly fade samples at the other end a second time
    for i in range(min(elength,total)):
        s[last-i]*=float(i)/elength
    return s

def normalise(s):
    """Return a copy of ``s`` rescaled so its largest absolute sample is 10000.

    s -- array of samples (anything numpy.asarray accepts).

    The result has the same dtype as the input; for integer dtypes the scaled
    values are rounded to the nearest integer.  Empty or all-zero input comes
    back as an unscaled copy.

    The input itself is never modified.  chop() hands in slices of the loaded
    sample data, and numpy slices share memory with the array they were taken
    from, so scaling in place would rescale the samples shared by overlapping
    chunks more than once.
    """
    samples = np.asarray(s)
    if samples.size==0:
        return samples.copy()
    # work in float64: an in-place divide cannot store into integer arrays,
    # and abs() of the most negative integer overflows in its own dtype
    work = samples.astype(np.float64)
    peak = float(np.abs(work).max())
    if not peak>0:
        return samples.copy()
    work /= peak/10000.0
    if np.issubdtype(samples.dtype,np.integer):
        # round instead of truncating so the peak lands on 10000, not 9999
        work = np.rint(work)
    return work.astype(samples.dtype)

def chop(wav,size,overlap,rand,norm):
    """Cut the samples of ``wav`` into overlapping, faded chunks.

    wav     -- (rate, samples) pair; only the samples (element 1) are used.
    size    -- chunk length in samples.
    overlap -- number of samples shared by consecutive chunks; chunk starts
               are size-overlap samples apart.
    rand    -- unused, kept so existing callers keep working.
    norm    -- when true every chunk is passed through normalise() first.

    Returns a list of [0, chunk] pairs.  Every chunk is exactly ``size`` samples
    long and gets a fade in/out of 500 samples (or ``size`` samples, if the
    chunk is shorter than that); trailing samples that do not fill a whole
    chunk are dropped.

    Raises ValueError if size is not positive or overlap is not smaller than
    size, since the chunk start would then never move forward.
    """
    step = size-overlap
    if size<=0 or step<=0:
        raise ValueError("chunk size must be positive and larger than the overlap")
    samples = wav[1]
    # a fade longer than the chunk would index past its end
    fade = min(500,size)
    ret = []
    pos = 0
    # <= so that a chunk ending exactly on the last sample is kept
    while pos+size<=len(samples):
        chunk = samples[pos:pos+size]
        if norm:
            chunk = normalise(chunk)
        ret.append([0,fadeinout(chunk,fade,fade)])
        pos+=step
    return ret

def fftify(chopped):
    """Return a list with the FFT of every chunk in ``chopped``.

    chopped -- sequence of chunks; element 1 of each chunk is the sample data
               handed to numpy.fft.fft.

    A real list is returned (not a lazy map object) because callers take its
    length and walk over it more than once.
    """
    return [np.fft.fft(chunk[1]) for chunk in chopped]

def mfccify(chopped,rate):
    """Return a list with the log filterbank features of every chunk.

    chopped -- sequence of chunks; element 1 of each chunk is passed on as
               the signal.
    rate    -- sample rate handed to logfbank together with each signal.

    Note that, despite the function's name, the features are computed with
    logfbank and not with mfcc.
    """
    return [logfbank(chunk[1],rate) for chunk in chopped]

def fftdiff(a,b):
    """Return the sum of the absolute element-wise differences of ``a`` and ``b``.

    a, b -- numpy arrays of matching (broadcastable) shape; for complex input
            the magnitude of each difference is summed.

    The sum is accumulated in numpy's extended precision float.  It is named
    via np.longdouble because the np.float128 alias is not defined on every
    platform.
    """
    return (abs(a-b)).sum(dtype=np.longdouble)

def diffify(a,b):
    """Return a list with fftdiff() of every pair of items from ``a`` and ``b``.

    a, b -- sequences of arrays, compared position by position; if they differ
            in length the surplus items of the longer one are ignored.

    A real list is returned so the result has a length and can be indexed under
    both python 2 and python 3.
    """
    return [fftdiff(x,y) for x,y in zip(a,b)]

def search(fft,bank):
    """Return the index of the entry in ``bank`` that is closest to ``fft``.

    fft  -- feature array to look up.
    bank -- sequence of feature arrays to compare against, using fftdiff().

    The first entry wins when several are equally close.  Returns -1 when bank
    is empty.  The chosen index is also reported through msg().
    """
    # start from infinity rather than a large magic number, so that a match is
    # found no matter how large the distances are
    closest = float("inf")
    ret = -1
    for i,tfft in enumerate(bank):
        dist = fftdiff(fft,tfft)
        if dist<closest:
            ret = i
            closest = dist
    msg(str(ret))
    return ret

#unit_test()

class transponge():
    """Rebuilds a target sound out of chunks taken from a set of source sounds.

    Typical use: set_target() and add_source() to load wavs, prepare() to chop
    them up and compute the features, then process() to write the result.
    """
    def __init__(self):
        # loaded sources, one [filename, (rate, samples)] entry per file
        self.src=[]
        # loaded target as (rate, samples); an empty placeholder until set
        self.dst=[0,[]]
        # chunks of all sources / of the target, as returned by chop()
        self.src_chp=[]
        self.dst_chp=[]
        # features (fft or log filterbank) of those chunks, in the same order
        self.src_fft=[]
        self.dst_fft=[]

    def set_target(self,dst_filename):
        """Load the wav file ``dst_filename`` as the target sound."""
        self.dst = scipy.io.wavfile.read(dst_filename)
        msg("succesfully loaded "+dst_filename+" as the target")

    def add_source(self,src_filename):
        """Load the wav file ``src_filename`` as an extra source.

        Returns the new [filename, (rate, samples)] entry.
        """
        ret = [src_filename,scipy.io.wavfile.read(src_filename)]
        self.src.append(ret)
        msg("succesfully loaded "+src_filename+" as a source")
        msg("(now have "+str(len(self.src))+" sounds in brain...)")
        return ret

    def delete_source(self,src_filename):
        """Forget every source that was loaded from ``src_filename``."""
        # build a real list: filter() is a one-shot iterator on python 3, which
        # has neither append() nor len()
        self.src = [i for i in self.src if i[0]!=src_filename]

    def prepare(self,chp_size,chp_overlap,mfcc,norm):
        """Chop up the target and all sources and compute their features.

        chp_size    -- chunk length in samples.
        chp_overlap -- overlap between consecutive chunks in samples.
        mfcc        -- true to compare chunks with mfccify() features, false to
                       compare them with fftify() features.
        norm        -- passed on to chop(): normalise every chunk first.
        """
        self.chp_size = chp_size
        self.chp_overlap = chp_overlap
        self.src_chp=[]
        self.dst_chp=[]
        self.src_fft=[]
        self.dst_fft=[]

        msg("chopping up target wav...")
        self.dst_chp = chop(self.dst,self.chp_size,self.chp_overlap,0,norm)
        msg("number of target blocks: "+str(len(self.dst_chp)))
        # list(): the feature lists are measured with len() and iterated later
        if mfcc:
            self.dst_fft = list(mfccify(self.dst_chp,self.dst[0]))
        else:
            self.dst_fft = list(fftify(self.dst_chp))
        render_stats()

        self.dst_chp = [] # only the target's features are needed from here on
        self.dst_size = len(self.dst[1])

        for entry in self.src:
            name = entry[0]
            wav = entry[1]
            chopped=chop(wav,self.chp_size,self.chp_overlap,0,norm)
            msg(name+" has "+str(len(chopped))+" blocks")
            self.src_chp+=chopped
            if mfcc:
                self.src_fft+=list(mfccify(chopped,wav[0]))
            else:
                self.src_fft+=list(fftify(chopped))
            render_stats()


    def process(self,filename):
        """Rebuild the target from source chunks and write it to ``filename``.

        prepare() has to be called first.  For every target block the closest
        source block is looked up with search() and mixed into the output at a
        gain of 0.25.  The file is written with the target's sample rate; it is
        refreshed every 10 blocks while rendering and once more at the end.

        Raises ValueError when there are no source blocks to build from.
        """
        if not self.src_chp:
            raise ValueError("no source blocks to build from")
        out = np.zeros(self.dst_size,dtype=self.src_chp[0][1].dtype)
        rate = self.dst[0]
        step = self.chp_size-self.chp_overlap
        total = float(len(self.dst_fft))
        pos = 0
        last = -1
        for i,seg in enumerate(self.dst_fft):
            # index of the closest source block
            ii = search(seg,self.src_fft)
            clear_msg()
            # mix in the whole block at once, cut off at the end of the output
            n = min(self.chp_size,self.dst_size-pos)
            if n>0:
                chunk = self.src_chp[ii][1]
                out[pos:pos+n] = out[pos:pos+n]+(chunk[:n]*0.25)
            pos+=step
            msg(str((i/total)*100.0)[:5]+"%")
            if i%10==0: scipy.io.wavfile.write(filename,rate,out)
            last = i
        # the periodic write above misses the blocks after the last multiple
        # of 10, so save the finished result as well
        if last>=0 and last%10!=0:
            scipy.io.wavfile.write(filename,rate,out)



class win:
    """Tkinter window for driving a transponge.

    Lets the user load source sounds ("the brain") and a target sound, pick the
    output file and the chopping options, and start a render.
    """
    def __init__(self):
        """Create the transponge and build all widgets (the event loop is not started here)."""
        self.sponge = transponge()

        # create window
        self.root = Tk()
        self.root.title("sample brain "+version)
        top = Frame(self.root)
        top.pack(fill=BOTH)

        # statistics of the last run, filled in by render_stats()
        f=Frame(top)
        f.pack(side=LEFT,fill=NONE)
        Label(f,text="statistix (last run)").pack()
        self.stats_brain_blocks = StringVar()
        Label(f,textvariable=self.stats_brain_blocks).pack()
        self.stats_brain_blocks.set("brain blocks: not run yet...")
        self.stats_target_blocks = StringVar()
        Label(f,textvariable=self.stats_target_blocks).pack()
        self.stats_target_blocks.set("target blocks: not run yet...")

        # buttons and options
        f=Frame(top)
        f.pack(side=LEFT)
        Button(f, text="add to brain", command=self.load_source).grid(row=0, column=0, sticky="we")
        Button(f, text="set the target", command=self.load_target).grid(row=0, column=1, sticky="we")
        Button(f, text="save as", command=self.save_as).grid(row=0, column=2, sticky="we")
        self.output_filename = "brain_out.wav"
        self.run_button = Button(f, text="run", command=self.run)
        self.run_button.grid(row=0, column=3, sticky="we")
        self.mfcc_var = IntVar()
        cb=Checkbutton(f, text="use mfcc", variable=self.mfcc_var, command=self.mfcc)
        cb.grid(row=1, column=0)
        cb.select()

        self.norm_var = IntVar()
        cb=Checkbutton(f, text="normalise", variable=self.norm_var, command=self.norm)
        cb.grid(row=1, column=1)

        # overlap, as a fraction of the window size
        rf = Frame(f)
        rf.grid(row=1, column=3)
        Label(rf, text="overlap").grid(row=0,column=0)
        self.overlap_entry = Entry(rf, width=5)
        self.overlap_entry.grid(row=1, column=0)
        self.overlap_entry.delete(0, END)
        self.overlap_entry.insert(0, "0.75")

        # window (chunk) size in samples
        rf = Frame(f)
        rf.grid(row=1, column=4)
        Label(rf, text="window size").grid(row=0,column=0)
        self.window_entry = Entry(rf, width=5)
        self.window_entry.grid(row=1, column=0)
        self.window_entry.delete(0, END)
        self.window_entry.insert(0, "3000")

        self.target_name = StringVar()
        Label(top,textvariable=self.target_name).pack()
        self.target_name.set("no target yet...")

        # one row per loaded source, built by build_sample_gui()
        Label(top,text="brain contents:").pack()
        self.samples=Frame(top)
        self.samples.pack(fill=NONE,side=RIGHT)

        # message area written to by msg()
        self.debug = Text(self.root, font = "Helvetica 24 bold", height=10, width=60)
        self.debug.pack()
        self.debug.insert(END, "ready...\n")

    def build_sample_gui(self,name,size):
        """Add a row for source ``name`` (``size`` is its length in seconds) with a delete button."""
        f = Frame(self.samples)
        f.pack()
        Button(f, text="x", command=lambda: self.delete_sample(name)).pack(side=LEFT)
        Label(f, text=os.path.basename(name)+" ("+str(size)[:5]+" secs)").pack(side=RIGHT)

    def delete_sample(self,name):
        """Remove source ``name`` from the transponge and rebuild the list of source rows."""
        self.sponge.delete_source(name)
        msg("deleted "+name+" from my brain")
        for widget in self.samples.winfo_children():
            widget.destroy()
        for i in self.sponge.src:
            self.build_sample_gui(i[0],len(i[1][1])/float(i[1][0]))


    def msg(self,msg):
        """Put ``msg`` on a new line at the top of the message area and redraw."""
        # str() so that non-string values (an exception, a number) can be shown;
        # "1.0" is the first position of a Tk text widget
        self.debug.insert("1.0", str(msg)+"\n")
        self.root.update()

    def clear_msg(self):
        """Empty the message area and redraw."""
        self.debug.delete("1.0", END)
        self.root.update()

    def load_target(self):
        """Ask for a wav file and load it as the target."""
        filename = askopenfilename(title = "load target wav")
        # a cancelled dialog can give back an empty string or an empty tuple
        if filename:
            self.sponge.set_target(filename)
            self.target_name.set("target: "+os.path.basename(filename))

    def load_source(self):
        """Ask for a wav file, load it as a source and list it in the window."""
        filename = askopenfilename(title = "load source wav into brain")
        if filename:
            i=self.sponge.add_source(filename)
            self.build_sample_gui(i[0],len(i[1][1])/float(i[1][0]))

    def save_as(self):
        """Ask for the name of the output file used by the next run."""
        filename = asksaveasfilename(title = "set the output wav")
        if filename:
            self.output_filename = filename

    def render_stats(self):
        """Show the current number of source and target blocks."""
        self.stats_brain_blocks.set("brain blocks: "+str(len(self.sponge.src_chp)))
        self.stats_target_blocks.set("target blocks: "+str(len(self.sponge.dst_fft)))

    def run(self):
        """Read the options from the window, then prepare and render the output.

        The render is refused, with a message, when the entries are not usable
        numbers or when no source or no target has been loaded yet.
        """
        try:
            window_size = float(self.window_entry.get())
            overlap = float(self.overlap_entry.get())
            chp_size = int(window_size)
            chp_overlap = int(window_size*overlap)
        except (ValueError, OverflowError):
            msg("window size and overlap have to be numbers")
            return

        # the overlap has to stay below the window size, or the chopping
        # position would never move forward
        if chp_size<1 or chp_overlap<0 or chp_overlap>=chp_size:
            msg("window size has to be at least 1 and overlap between 0 and 1")
            return
        if not self.sponge.src:
            msg("add some sounds to the brain first")
            return
        if len(self.sponge.dst[1])==0:
            msg("set the target first")
            return

        if self.norm_var.get()==1:
            msg("normalising: on")
        else:
            msg("normalising: off")

        self.sponge.prepare(chp_size,chp_overlap,
                            self.mfcc_var.get()==1,
                            self.norm_var.get()==1)


        msg("processing...")
        self.sponge.process(self.output_filename)
        msg("done, saved in "+self.output_filename)


    def mfcc(self):
        """Callback of the "use mfcc" checkbox; nothing to do, run() reads the variable."""
        pass
    def norm(self):
        """Callback of the "normalise" checkbox; nothing to do, run() reads the variable."""
        pass


# build the window
w = win()

# from here on the reporting hooks used by the functions above write to the
# window instead of stdout
msg = w.msg
clear_msg = w.clear_msg
render_stats = w.render_stats

try:
    w.root.mainloop()
# "as" form: the old "except Exception,e" spelling is a syntax error on python 3
except Exception as e:
    msg(str(e))
