#!/usr/bin/env python
import tempfile
import os
from subprocess import call
from time import time as clock


def make_tempdir(name):
    """Create a new private temporary directory and return a fresh, empty
    subdirectory called *name* inside it.

    The parent comes from ``tempfile.mkdtemp()``; nothing here removes it.
    """
    subdir = os.path.join(tempfile.mkdtemp(), name)
    os.mkdir(subdir)
    return subdir

# Scratch directories, created in this order: TEMPDIR receives the copied input
# images (save_to_temp), LEFT_TD / RIGHT_TD receive warped frames
# (set_frame_filename) and ANIM_TD receives blended frames (combine).
# A generator expression is used so no loop variable leaks into the module.
TEMPDIR, LEFT_TD, RIGHT_TD, ANIM_TD = (make_tempdir(tag) for tag in ("m", "l", "r", "a"))

class Morph():
    """Build morph stills and animations between a left and a right image.

    All image work is done by shelling out to ImageMagick (``convert``,
    ``composite``) and, for non-GIF animations, ``ffmpeg``.

    p_list must provide ``get_all(side)`` for ``"left"`` and ``"right"``,
    each returning a sequence of ``(x, y)`` control points; points are
    paired by index between the two sides (extra points are ignored).
    """

    def __init__(self, p_list):
        self.p_list = p_list
        self.num_frames = 10
        self.l_image = None         # temp copy of the left input (save_to_temp)
        self.r_image = None         # temp copy of the right input (save_to_temp)
        self.l_renders = []         # left image warped toward the right points
        self.r_renders = []         # right image warped toward the left points
        self.combined = []          # blended frames written to ANIM_TD
        self.distort = "Shepards"   # ImageMagick -distort method name
        self.ext = ".jpg"           # extension used for every temp/frame file

    def reset(self):
        """Clear the frame lists and delete every file in the frame directories.

        TEMPDIR (the copied input images) is left untouched.
        """
        self.l_renders = []
        self.r_renders = []
        self.combined = []
        # Join paths instead of chdir-ing, so the process cwd is never
        # changed (and cannot be left changed if a removal fails).
        for frame_dir in (LEFT_TD, RIGHT_TD, ANIM_TD):
            for name in os.listdir(frame_dir):
                os.remove(os.path.join(frame_dir, name))

    def set_distort(self, d, w):
        """Select the -distort method: falsy d -> Perspective, 1 -> Shepards,
        anything else -> Bilinear.

        w is accepted for interface compatibility but is not used.
        """
        if not d:
            self.distort = "Perspective"
        elif d == 1:
            self.distort = "Shepards"
        else:
            self.distort = "Bilinear"

    def set_ext(self, side, ext):
        """Return the extension to use for *side*'s temp copy.

        *ext* becomes the shared extension only when the other side's temp
        image already has it; otherwise the current self.ext is kept, so
        both temp images always share one extension.
        """
        if side == "left":
            other = self.r_image
        elif side == "right":
            other = self.l_image
        else:
            other = None
        if other and os.path.splitext(other)[1] == ext:
            self.ext = ext
        return self.ext

    def get_prefix(self, inverse):
        """Return the base name, without extension, of the left image
        (inverse false) or the right image (inverse true)."""
        fn = self.r_image if inverse else self.l_image
        return os.path.splitext(os.path.basename(fn))[0]

    def save_to_temp(self, side, filename):
        """Copy *filename* into TEMPDIR as ``<side><ext>`` using ImageMagick convert.

        The extension comes from set_ext, so convert may re-encode the file.
        The resulting path is stored in l_image / r_image for side
        "left" / "right".
        """
        ext = os.path.splitext(os.path.basename(filename))[1]
        tmp_filename = os.path.join(TEMPDIR, side + self.set_ext(side, ext))
        if side == "left":
            self.l_image = tmp_filename
        elif side == "right":
            self.r_image = tmp_filename
        call(["convert", filename, tmp_filename])

    def _image_for(self, side):
        """Return the temp image path for side "left" or "right"."""
        if side == "left":
            return self.l_image
        if side == "right":
            return self.r_image
        raise ValueError("side must be 'left' or 'right', got %r" % (side,))

    def crop(self, side, x, y, w, h):
        """Crop *side*'s temp image in place to w x h at offset (x, y).

        Returns the (unchanged) path of the cropped image.
        Raises ValueError for an unknown side.
        """
        fn = self._image_for(side)
        geometry = "%sx%s+%s+%s" % (w, h, x, y)
        call(["convert", fn, "-crop", geometry, fn])
        return fn

    def set_frame_filename(self, frame_num, inverse):
        """Return the output path ``<prefix>-NNN<ext>`` for one warped frame.

        Non-inverse frames (left image warped) are placed in RIGHT_TD and
        inverse frames (right image warped) in LEFT_TD; save_animation reads
        them back from those same directories.
        """
        base_filename = "%s-%03d%s" % (self.get_prefix(inverse), frame_num, self.ext)
        frame_dir = LEFT_TD if inverse else RIGHT_TD
        return os.path.join(frame_dir, base_filename)

    def change_fn_dir(self, fn, new_dir):
        """Return *fn*'s base name joined onto *new_dir*."""
        return os.path.join(new_dir, os.path.basename(fn))

    def save_animation(self, fn, num_frames, reverse, mode, fps):
        """Render num_frames frames and encode them into *fn*.

        mode 5: only the right image warped toward the left points;
        mode 4: only the left image warped toward the right points;
        any other mode: both warps cross-faded (see combine).
        A name ending in "gif" is assembled by ImageMagick convert; anything
        else by ffmpeg at *fps* frames per second. A truthy *reverse* appends
        the frames in reverse order (ping-pong playback).
        """
        # Resolve fn first: the encoders run with cwd set to the temp frame
        # directory, so a relative fn would otherwise be written there.
        fn = os.path.abspath(fn)
        self.num_frames = num_frames
        curdir = os.getcwd()
        self.reset()
        # ImageMagick -delay is in 1/100 s ticks; // keeps Python 2's int result.
        delay = str(100 // int(fps))
        start = clock()
        if mode == 5:
            self.render_animation_frames(True)
            frame_dir, prefix_side = LEFT_TD, True
        elif mode == 4:
            self.render_animation_frames(False)
            frame_dir, prefix_side = RIGHT_TD, False
        else:
            self.combine()
            frame_dir, prefix_side = ANIM_TD, False
        os.chdir(frame_dir)
        try:
            if reverse:
                self.duplicate_reversed()
            if fn[-3:] == "gif":
                # ImageMagick expands the "*" pattern itself (no shell involved).
                call(["convert", "-delay", delay, "*" + self.ext, fn])
            else:
                # Every frame is named <prefix>-NNN<self.ext> (set_frame_filename,
                # and combine reuses those names), so the pattern is known.
                pattern = self.get_prefix(prefix_side) + "-%03d" + self.ext
                call(["ffmpeg", "-r", str(fps), "-y", "-i", pattern, fn])
        finally:
            os.chdir(curdir)
        end = clock()
        print("Total render time %d seconds" % int(end - start))
        print("Saved as %s" % fn)

    @staticmethod
    def _mirror_plan(names):
        """Return (source, destination) copy pairs that append *names* in reverse.

        *names* are frame files ``<prefix>NNN<ext>`` with a 3-digit frame
        number and a 4-character extension, so a lexical sort is a numeric
        sort. Frame k is mapped to 2*last - k; the last frame is not repeated.
        """
        ordered = sorted(names, reverse=True)
        if not ordered:
            return []
        last = ordered[0]
        prefix, ext = last[:-7], last[-4:]
        count = len(ordered)
        return [(src, "%s%03d%s" % (prefix, count - 1 + i, ext))
                for i, src in enumerate(ordered) if i]

    def duplicate_reversed(self):
        """Copy the frames in the current directory so they play forward then back.

        Must be called with cwd set to the frame directory (save_animation
        does this).
        """
        for src, dst in self._mirror_plan(os.listdir(os.getcwd())):
            call(['cp', src, dst])

    @staticmethod
    def _blend_percent(index, num_frames):
        """Return the integer -blend percentage of the left frame for blend step *index*.

        Multiplying before the integer division keeps the endpoints exact:
        the final step is 100, whereas ``index * (100 // (num_frames - 1))``
        fell short (e.g. 99 for 10 frames).
        """
        return index * 100 // (num_frames - 1)

    def combine(self):
        """Render both warps and cross-fade them into ANIM_TD.

        Output frame k blends the left image warped to step k with the right
        image warped to the same step, weighting the left frame by
        (num_frames - 1 - k) / (num_frames - 1). Outputs reuse the left
        frames' file names and are appended to self.combined from the last
        frame to the first.
        """
        self.render_animation_frames(False)
        self.render_animation_frames(True)
        # Reversed so that popping pairs left step k with the right frame
        # warped to that same geometric step.
        self.r_renders.reverse()
        start = clock()
        for i in range(self.num_frames):
            percent = str(self._blend_percent(i, self.num_frames))
            l_frame = self.l_renders.pop()
            r_frame = self.r_renders.pop()
            out = self.change_fn_dir(l_frame, ANIM_TD)
            call(["composite", l_frame, "-blend", percent, r_frame, out])
            self.combined.append(out)
        end = clock()
        print("images blended in %d seconds" % int(end - start))

    def render_animation_frames(self, inverse):
        """Render frames 0 .. num_frames-1 of the left image (inverse false)
        or the right image (inverse true)."""
        start = clock()
        for n in range(self.num_frames):
            self.render_frame(n, inverse)
        end = clock()
        print("rendering animation frames in %d seconds" % int(end - start))

    def render_frame(self, frame_num, inverse):
        """Warp one image for *frame_num* with ImageMagick ``-distort``.

        Non-inverse: the left image goes from the frame-0 points to the
        frame_num points. Inverse: the right image goes from the last-frame
        points to the points of frame (num_frames - 1 - frame_num). The
        output path is appended to l_renders / r_renders.
        """
        outfile = self.set_frame_filename(frame_num, inverse)
        if inverse:
            pts1 = self.get_frame_points(self.num_frames - 1)
            pts2 = self.get_frame_points(self.num_frames - 1 - frame_num)
            self.r_renders.append(outfile)
            infile = self.r_image
        else:
            pts1 = self.get_frame_points(0)
            pts2 = self.get_frame_points(frame_num)
            self.l_renders.append(outfile)
            infile = self.l_image
        points = self.pairs_of_points_to_str(pts1, pts2)
        call(["convert", infile, "-distort", self.distort, points, outfile])
        print("frame %s rendered" % frame_num)

    def pairs_of_points_to_str(self, pts1, pts2):
        """Format index-paired points as ``"x1,y1 x2,y2  "`` repeated.

        Pairs stop at the shorter list; every pair ends with two spaces.
        """
        return "".join("%s,%s %s,%s  " % (p1[0], p1[1], p2[0], p2[1])
                       for p1, p2 in zip(pts1, pts2))

    def get_frame_points(self, frame_num):
        """Return the control points for *frame_num*.

        Frame 0 is the left points, frame num_frames-1 the right points, and
        frames in between are linear interpolations of index-paired points.
        With int coordinates under Python 2, ``/`` floors the result.
        """
        nf = self.num_frames - 1
        lp = self.p_list.get_all("left")
        rp = self.p_list.get_all("right")
        if nf == frame_num:
            return rp
        if frame_num == 0:
            return lp
        return [(lx + ((rx - lx) * frame_num) / nf,
                 ly + ((ry - ly) * frame_num) / nf)
                for (lx, ly), (rx, ry) in zip(lp, rp)]

    def save_single(self, out_fn, d):
        """Write one still image to *out_fn* via ImageMagick convert.

        Falsy d: the 50% cross-fade (middle of a 3-frame combine).
        d == 1: the left image fully warped to the right points.
        d == 2: the right image fully warped to the left points.
        Raises ValueError for any other d.
        """
        self.reset()
        if not d:
            self.num_frames = 3
            self.combine()
            in_fn = self.combined[1]
            self.combined = []
        elif d in (1, 2):
            self.num_frames = 2
            inverse = d == 2
            self.render_frame(1, inverse)
            in_fn = self.r_renders.pop() if inverse else self.l_renders.pop()
        else:
            raise ValueError("d must be falsy, 1 or 2, got %r" % (d,))
        call(['convert', in_fn, out_fn])