import argparse
import datetime
import math
import os
import sys

import cv2
import mxnet as mx
import numpy as np

import AgeGenderDist.face_model
from AgeGenderDist import face_model
from AgeGenderDist.mtcnn_detector import MtcnnDetector

# NOTE(review): this runs after every import above, so it cannot influence
# them; presumably it is kept for modules imported later at runtime -- confirm.
sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'src', 'common'))

# Command-line options. They are parsed at import time because
# get_age_gender_dist / get_gender_age read the module-level ``args``.
parser = argparse.ArgumentParser(description='face model test')
# general
parser.add_argument('--image-size', default='112,112',
                    help='model input image size, comma separated (e.g. "112,112")')
parser.add_argument('--image', default='Tom_Hanks_54745.png',
                    help='path of the input image(s)')
parser.add_argument('--model', default='model/model,0', help='path to load model.')
parser.add_argument('--gpu', default=0, type=int,
                    help='gpu id (a negative value selects the CPU)')
parser.add_argument('--det', default=0, type=int,
                    help='mtcnn option, 1 means using R+O, 0 means detect from beginning')

if __name__ == '__main__':
    # Run as a script: reject unknown flags as before.
    args = parser.parse_args()
else:
    # Imported as a library: the host program's argv may carry flags this
    # parser does not know, and parse_args() would sys.exit() on them.
    # Known flags are still honoured; unknown ones are ignored.
    args, _unknown_args = parser.parse_known_args()

def get_age_gender_dist(img_src):
    """Detect a face in *img_src* and return its age, gender and eye distance.

    The MTCNN detector is configured from the module-level CLI ``args``
    (``--gpu`` selects the MXNet context, ``--det`` the detection mode).

    Args:
        img_src: the image to analyse, as loaded by ``cv2.imread`` (presumably
            an HxWx3 uint8 array -- TODO confirm against MtcnnDetector).

    Returns:
        ``(age, gender, dist)``. ``age`` and ``gender`` are whatever
        ``FaceModel.get_ga`` returns. ``dist`` is the Euclidean distance
        between the first two landmarks of the first detected face, which are
        treated as the left and right eye.

    Raises:
        ValueError: if the detector finds no face.
    """
    ctx = mx.gpu(args.gpu) if args.gpu >= 0 else mx.cpu()
    # --det 0 runs the full cascade with the usual thresholds; any other value
    # uses permissive thresholds (see the --det help text).
    det_threshold = [0.6, 0.7, 0.8] if args.det == 0 else [0.0, 0.0, 0.2]
    mtcnn_path = os.path.join(os.path.dirname(__file__), 'mtcnn-model')
    # NOTE(review): the detector and FaceModel are rebuilt on every call, which
    # presumably reloads network weights each time; consider caching them if
    # this function is called in a loop.
    detector = MtcnnDetector(model_folder=mtcnn_path, ctx=ctx, num_worker=1,
                             accurate_landmark=True, threshold=det_threshold)
    ret = detector.detect_face(img_src, det_type=args.det)
    if ret is None:
        # Fail loudly here; unpacking None below would only give an opaque
        # TypeError.
        raise ValueError('ret is none: no face detected in the input image')
    _bbox, points = ret

    # First face only: reshape to (2, 5) and transpose to five rows, one per
    # landmark (assumes the detector stores all x coords then all y coords).
    landmarks = points[0, :].reshape((2, 5)).T
    left_eye, right_eye = landmarks[0], landmarks[1]
    dx, dy = right_eye - left_eye
    dist = math.hypot(dx, dy)

    model = face_model.FaceModel(args)
    aligned = model.get_input(img_src)
    gender, age = model.get_ga(aligned)
    return age, gender, dist

def get_gender_age(img):
    """Estimate the gender and age of the face in *img*.

    Builds a fresh ``face_model.FaceModel`` from the module-level CLI ``args``,
    prepares the input with ``get_input`` and classifies it with ``get_ga``.

    Returns:
        ``(gender, age)`` as produced by ``FaceModel.get_ga``.
    """
    ga_model = face_model.FaceModel(args)
    aligned_face = ga_model.get_input(img)
    predicted_gender, predicted_age = ga_model.get_ga(aligned_face)
    return predicted_gender, predicted_age

if __name__ == '__main__':
    # --image may name a single image file or a directory of images.
    # Directory entries are joined with the directory itself (not a fixed
    # folder name) and sorted for a deterministic processing order.
    if os.path.isdir(args.image):
        image_paths = [os.path.join(args.image, name)
                       for name in sorted(os.listdir(args.image))]
    else:
        image_paths = [args.image]

    for image_path in image_paths:
        print('file', os.path.basename(image_path))
        img_src = cv2.imread(image_path)
        if img_src is None:
            # cv2.imread returns None for unreadable / non-image files.
            print('skipping unreadable image', image_path)
            continue
        # get_age_gender_dist already returns gender and age, so there is no
        # need to run get_gender_age (and load a second model) separately.
        age, gender, dist = get_age_gender_dist(img_src)
        print(age, gender, dist)


