import io
import json
import os
import time
import traceback

import cv2
import dlib
import numpy as np
import requests
from PIL import Image

# Model weights are resolved relative to the process's current working
# directory (not this file's location), inside a "_models" sub-directory.
BASE_DIR = os.getcwd()
MODELS_DIR = os.path.join(BASE_DIR, "_models")
FACEREC_PATH, SHAPE_PATH = (
    os.path.join(MODELS_DIR, model_file)
    for model_file in (
        "dlib_face_recognition_resnet_model_v1.dat",
        "shape_predictor_68_face_landmarks.dat",
    )
)

# The dlib models are loaded once, at import time, and shared by the
# functions below.
face_rec = dlib.face_recognition_model_v1(FACEREC_PATH)
face_detector = dlib.get_frontal_face_detector()
shape_predictor = dlib.shape_predictor(SHAPE_PATH)


def url_to_ndarray(url):
    """Download the image at *url* and return its pixels as an RGB ndarray.

    The request uses a 30-second timeout. Exceptions raised by ``requests``
    (connection errors, timeouts) or by Pillow while decoding are not caught.

    Returns:
        ``np.ndarray`` of the RGB-converted image, or ``None`` (after printing
        the HTTP status code) when the response is not OK.
    """
    response = requests.get(url, timeout=30)
    if not response.ok:
        print("http get: {}".format(response.status_code))
        return None

    rgb_image = Image.open(io.BytesIO(response.content)).convert("RGB")
    return np.array(rgb_image)


def file_to_ndarray(path):
    """Load the image file at *path* and return its pixels as an RGB ndarray.

    ``cv2.imread`` (default ``IMREAD_COLOR``) decodes to a 3-channel 8-bit
    array in **BGR** channel order. The channels are swapped to RGB so the
    result matches ``url_to_ndarray`` and the RGB input dlib works on.
    (Wrapping the BGR array in ``Image.fromarray(...).convert("RGB")`` does
    not reorder channels; it only relabels the same bytes as RGB.)

    Raises:
        OSError: if OpenCV cannot read or decode *path* (``imread`` signals
            this by returning ``None``).
    """
    bgr = cv2.imread(path)
    if bgr is None:
        raise OSError("cannot read image: {}".format(path))
    # cvtColor returns a new contiguous array, which dlib requires.
    return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)


def face_to_vec(img, max_size=700):
    """Detect faces in *img* and compute a dlib face descriptor for each one.

    If the longer side of *img* exceeds *max_size*, the image is downscaled
    (keeping its aspect ratio) before detection. Rectangles and landmarks are
    mapped back to the coordinates of the original *img*.

    Args:
        img: image array, or ``None``. Presumably RGB uint8 as produced by the
            loaders in this module -- TODO confirm with callers.
        max_size: maximum length, in pixels, of the longer side used for
            detection.

    Returns:
        A list with one dict per face whose descriptor was computed:
        ``"rect"`` -- (top, right, bottom, left) ints from ``check_rect``;
        ``"landmark"`` -- JSON string of the (x, y) landmark points;
        ``"feature"`` -- JSON string of the descriptor values.
        Returns ``[]`` when *img* is ``None`` or the detector raises. Faces
        whose descriptor computation raises are skipped (the error is printed).
    """
    start = time.time()
    print("detect image...")
    if img is None:
        # url_to_ndarray reports download failure with None; treat that like
        # "no faces" instead of crashing on ``img.shape``.
        print("no image to detect")
        return []

    orig_shape = img.shape
    scale = 1

    height, width = orig_shape[:2]
    longest = max(height, width)
    if longest > max_size:
        scale = max_size / float(longest)
        size = (int(width * scale), int(height * scale))
        img = cv2.resize(img, size, interpolation=cv2.INTER_AREA)

    try:
        # Second argument: number of times dlib upsamples the image first.
        dets = face_detector(img, 1)
    except Exception as e:
        print(e)
        traceback.print_exc()
        return []

    print("Number of faces detected: {}".format(len(dets)))

    faces = []
    for det in dets:
        shape = shape_predictor(img, det)
        # Landmarks and rect are divided by ``scale`` to map them back onto
        # the original, un-resized image.
        landmark = extract_landmark(shape, scale=scale)
        face = {
            "rect": check_rect(shape.rect, orig_shape, scale=scale, landmark=landmark),
            "landmark": json.dumps(landmark),
        }

        try:
            # The descriptor uses the (possibly resized) image together with
            # the shape found on that same image.
            face_descriptor = face_rec.compute_face_descriptor(img, shape)
            face["feature"] = json.dumps(np.array(face_descriptor).tolist())
        except Exception as e:
            print(e)
            traceback.print_exc()
        else:
            faces.append(face)

    print("Compute face cost: {}".format(time.time() - start))
    return faces


def check_rect(rect, shape, scale=1, landmark=None):
    """Turn a detection rectangle into a clamped, roughly square face box.

    The rectangle is first mapped to original-image coordinates (divided by
    *scale*) and grown to contain every landmark. Next, the shorter side is
    padded to make it square. Finally, it is enlarged by up to 10% per side.
    Every step is clamped to the image bounds.

    Args:
        rect: object exposing ``top()``, ``right()``, ``bottom()`` and
            ``left()`` (e.g. a ``dlib.rectangle``) in the coordinates of the
            image that was resized by *scale*.
        shape: shape of the original image; only ``shape[0]`` (height) and
            ``shape[1]`` (width) are read.
        scale: factor the detection image was resized by.
        landmark: optional iterable of (x, y) points, already in
            original-image coordinates.

    Returns:
        ``(top, right, bottom, left)`` as ints (truncated).
    """
    w, h = shape[1], shape[0]
    points = [] if landmark is None else list(landmark)
    xs = [p[0] for p in points]
    ys = [p[1] for p in points]

    # Maximize the face region: cover the detection box and all landmarks.
    top = max(min([rect.top() / scale] + ys), 0)
    right = min(max([rect.right() / scale] + xs), w)
    bottom = min(max([rect.bottom() / scale] + ys), h)
    left = max(min([rect.left() / scale] + xs), 0)

    # Make the box square by padding the shorter side evenly on both ends.
    box_h = bottom - top
    box_w = right - left
    diff = abs(box_h - box_w) / 2
    if box_w > box_h:
        bottom = min(bottom + diff, h)
        top = max(top - diff, 0)
    elif box_w < box_h:
        right = min(right + diff, w)
        left = max(left - diff, 0)

    # Enlarge the box by the same fraction on every side: at most 10%, and
    # no further than the nearest image edge allows.
    # NOTE(review): box_h/box_w are the sizes measured *before* squaring, so
    # a box squared above is enlarged unevenly; kept to preserve output.
    if box_h > 0 and box_w > 0:
        rect_scale = min(
            (h - bottom) / box_h,
            top / box_h,
            (w - right) / box_w,
            left / box_w,
            0.1,
        )
    else:
        # Degenerate (zero-height or zero-width) box: skip the enlargement
        # instead of dividing by zero.
        rect_scale = 0
    top -= box_h * rect_scale
    bottom += box_h * rect_scale
    left -= box_w * rect_scale
    right += box_w * rect_scale
    return int(top), int(right), int(bottom), int(left)


def extract_landmark(shape, scale=1, num_parts=None):
    """Extract landmark points from a shape as integer ``(x, y)`` tuples.

    Args:
        shape: object with ``part(i)`` returning a point that has ``x`` and
            ``y`` attributes (e.g. ``dlib.full_object_detection``).
        scale: each coordinate is divided by this before ``int()`` truncation
            (toward zero), mapping points from a resized image back to the
            original one.
        num_parts: number of points to read. Defaults to ``shape.num_parts``,
            or 68 (the point count of the 68-landmark predictor this module
            loads) when the shape has no such attribute.

    Returns:
        List of ``(x, y)`` int tuples in part-index order.
    """
    if num_parts is None:
        # Fall back to the historical hard-coded count for duck-typed shapes.
        num_parts = getattr(shape, "num_parts", 68)
    points = (shape.part(i) for i in range(num_parts))
    return [(int(p.x / scale), int(p.y / scale)) for p in points]
