from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
from scipy import misc
import sys
import os
import argparse
import numpy as np
import mxnet as mx
import random
import cv2
import sklearn
from sklearn.decomposition import PCA
from time import sleep
from easydict import EasyDict as edict
from AgeGenderDist.mtcnn_detector import MtcnnDetector
sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'src', 'common'))
from AgeGenderDist import face_image
from AgeGenderDist import face_preprocess
from pkg_resources import resource_filename


# Checkpoint prefix of the bundled model, resolved by pkg_resources relative
# to this module's package; get_model() passes it to mx.model.load_checkpoint.
MODEL_STR = resource_filename(__name__, "model/model")


def do_flip(data):
  """Mirror every ``data[i]`` left-to-right, modifying ``data`` in place.

  Each slice along the first axis has its second axis reversed (what
  ``np.fliplr`` does to that slice). Nothing is returned.
  """
  for plane in data:
    # ``plane`` is a view into ``data``, so writing through it updates the
    # caller's array; reversing axis 1 of the slice is the left/right flip.
    plane[...] = plane[:, ::-1]

def get_model(ctx, image_size, layer):
  """Load the bundled checkpoint and return a bound inference module.

  Args:
    ctx: MXNet context the module is created on.
    image_size: pair whose two entries become the last two dimensions of
      the ``(1, 3, image_size[0], image_size[1])`` input shape.
    layer: name of the internal layer to read from; the symbol
      ``layer + '_output'`` becomes the module's output.

  Returns:
    An ``mx.mod.Module`` bound to a single ``'data'`` input, with the
    checkpoint parameters set.
  """
  # Epoch 0 of the checkpoint stored under MODEL_STR is always the one loaded.
  full_sym, arg_params, aux_params = mx.model.load_checkpoint(MODEL_STR, 0)
  # Cut the graph at the requested layer so its output is what the module returns.
  truncated = full_sym.get_internals()[layer+'_output']
  module = mx.mod.Module(symbol=truncated, context=ctx, label_names=None)
  input_shape = (1, 3, image_size[0], image_size[1])
  module.bind(data_shapes=[('data', input_shape)])
  module.set_params(arg_params, aux_params)
  return module


class FaceModelArgs:
  """Plain container for the options that ``FaceModel`` reads.

  Attributes:
    image_size: ``'<int>,<int>'`` string; ``FaceModel`` splits it on the
      comma to size the network input.
    gpu: device id; ``FaceModel`` uses the CPU when it is negative.
    det: detector mode; ``FaceModel`` uses its standard detector thresholds
      for 0 and relaxed ones otherwise.
    flip: stored as given.
    threshold: stored as given.
  """

  def __init__(self, image_size='112,112', gpu=0, det=0, flip=0, threshold=1.24):
    # Network input size and compute device.
    self.image_size, self.gpu = image_size, gpu
    # Detector mode plus the remaining options, kept verbatim.
    self.det, self.flip, self.threshold = det, flip, threshold


class FaceModel:
  """Age/gender estimator pairing an MTCNN face detector with an MXNet model.

  The detector finds a face and its landmarks, ``face_preprocess`` produces
  the network input from them, and the network's ``'fc1'`` output is decoded
  into a gender label and an age estimate.
  """

  def __init__(self, args):
    """Load the network and build the face detector.

    Args:
      args: options object (for example ``FaceModelArgs``) providing
        ``image_size`` (a ``'<int>,<int>'`` string), ``gpu`` (device id; a
        negative value selects the CPU) and ``det`` (0 selects the standard
        detector thresholds, anything else the relaxed ones).
    """
    self.args = args
    ctx = mx.gpu(args.gpu) if args.gpu >= 0 else mx.cpu()
    _vec = args.image_size.split(',')
    assert len(_vec)==2
    image_size = (int(_vec[0]), int(_vec[1]))
    self.model = get_model(ctx, image_size, 'fc1')

    self.det_minsize = 50
    self.det_threshold = [0.6,0.7,0.8]
    self.image_size = image_size
    mtcnn_path = os.path.join(os.path.dirname(__file__), 'mtcnn-model')
    # Only the thresholds differ between the two detector modes.
    threshold = self.det_threshold if args.det == 0 else [0.0,0.0,0.2]
    self.detector = MtcnnDetector(
        model_folder=mtcnn_path, ctx=ctx, num_worker=1,
        accurate_landmark=True, threshold=threshold)

  def get_input(self, face_img):
    """Detect the first face in ``face_img`` and build a model input batch.

    Args:
      face_img: image handed to the detector and to
        ``face_preprocess.preprocess``. The preprocessed result is converted
        with ``cv2.COLOR_BGR2RGB``, so BGR channel order is presumably
        expected -- confirm against callers.

    Returns:
      ``(batch, detection)`` where ``batch`` is an ``mx.io.DataBatch`` holding
      the preprocessed face with its last axis moved to the front and a
      leading axis of length 1 added, and ``detection`` is the detector's raw
      ``(bbox, points)`` result. Returns ``None`` when the detector returns
      ``None`` or reports zero boxes.
    """
    ret = self.detector.detect_face(face_img, det_type = self.args.det)
    if ret is None:
      return None
    bbox, points = ret
    if bbox.shape[0]==0:
      return None
    # Only the first detection is used; its first four bbox values and its
    # landmarks (reshaped to five rows of two values) drive the preprocessing.
    bbox = bbox[0,0:4]
    points = points[0,:].reshape((2,5)).T
    # NOTE(review): the size is hard-coded here rather than taken from
    # args.image_size; a model bound to another size would presumably reject
    # the batch -- confirm before changing.
    nimg = face_preprocess.preprocess(face_img, bbox, points, image_size='112,112')
    nimg = cv2.cvtColor(nimg, cv2.COLOR_BGR2RGB)
    aligned = np.transpose(nimg, (2,0,1))
    input_blob = np.expand_dims(aligned, axis=0)
    data = mx.nd.array(input_blob)
    db = mx.io.DataBatch(data=(data,))
    return db, ret

  def get_ga(self, data):
    """Run the network on ``data`` and decode gender and age.

    Columns 0-1 of the first output are gender scores; the index of the
    larger one is the gender label. Columns 2-201 are read as 100 pairs of
    scores, and the age is the number of pairs whose second score is strictly
    larger than the first. The reshape to ``(100, 2)`` only works for a batch
    of one.

    Args:
      data: batch passed to ``self.model.forward``.

    Returns:
      ``(gender, age)`` with ``age`` a Python ``int``.
    """
    self.model.forward(data, is_train=False)
    ret = self.model.get_outputs()[0].asnumpy()
    gender = np.argmax(ret[:,0:2].flatten())
    age_votes = np.argmax(ret[:,2:202].reshape((100,2)), axis=1)
    age = int(age_votes.sum())

    return gender, age

  def get_age_gender_dist(self, img_src):
    """Estimate age, gender and eye distance for the first detected face.

    Args:
      img_src: image passed to ``get_input``.

    Returns:
      ``(age, gender, dist)`` where ``dist`` is the Euclidean distance between
      the first two landmarks of the first detection (treated as the left and
      right eye), or ``None`` (after printing a message) when ``get_input``
      finds no face.
    """
    found = self.get_input(img_src)
    # get_input() returns a bare None when nothing is detected, so the result
    # must be checked before it is unpacked.
    if found is None:
      print('ret is none')
      return None
    batch, (_, points) = found
    points = points[0,:].reshape((2,5)).T
    eye_delta = points[1] - points[0]
    dist = math.hypot(eye_delta[0], eye_delta[1])
    gender, age = self.get_ga(batch)
    return age, gender, dist

