# -*- coding: UTF-8 -*-
# !/usr/bin/env python

import numpy as np
import redis
from libs.cache import redis_client
from trans2es.models.tag import Tag
import logging
import traceback
import json
import pickle
from django.conf import settings


class LinUCB:
    """LinUCB tag recommender whose per-device state lives in a Redis hash.

    For every device the hash ``redis_prefix + str(device_id)`` holds one
    field per tag id.  Each field is a pickled dict with four numpy arrays:

    * ``"Aa"``    -- ``d x d``; starts as the identity and grows by
      ``x . x^T`` on every update
    * ``"ba"``    -- ``d x 1``; starts at zero and grows by ``payoff * x``
    * ``"AaI"``   -- inverse of ``"Aa"``
    * ``"theta"`` -- ``AaI . ba``

    A tag is scored as ``x^T . theta + alpha * sqrt(x^T . AaI . x)``.
    """

    d = 2  # size of the identity / zero matrices created for a new tag
    alpha = 0.1  # multiplier of the square-root (exploration) term of the score
    r1 = 6  # payoff applied when update_linucb_info receives reward == 1
    r0 = -0.5  # payoff applied when update_linucb_info receives reward == 0
    max_recommend_count = 50  # upper bound on tags returned by linucb_recommend_tag
    default_tag_list = list()  # class-level cache filled by get_default_tag_list

    @classmethod
    def get_default_tag_list(cls):
        """Return up to 100 ids of online tags with ``collection=1``.

        The ids are read from the database named by
        ``settings.SLAVE_DB_NAME`` the first time they are needed and cached
        on the class afterwards.  An empty result is not cached, so the query
        is retried on the next call.

        :return: list of tag ids; an empty list if the lookup fails
        """
        try:
            if not cls.default_tag_list:
                tag_ids = Tag.objects.using(settings.SLAVE_DB_NAME).filter(
                    is_online=True, collection=1
                ).values_list("id", flat=True)[0:100]
                # Materialise the lazy QuerySet so the class-level cache holds
                # plain ids rather than an object bound to a DB connection.
                cls.default_tag_list = list(tag_ids)

            return cls.default_tag_list
        except Exception:
            logging.error("catch exception,err_msg:%s", traceback.format_exc())
            return list()

    @staticmethod
    def _tag_id_to_str(tag_id):
        """Return ``tag_id`` as text, decoding UTF-8 when it is a bytes value."""
        if isinstance(tag_id, (bytes, bytearray)):
            return str(tag_id, encoding="utf-8")
        return str(tag_id)

    @classmethod
    def _initial_tag_state(cls):
        """Return the state dict of a tag that has received no feedback yet."""
        return {
            "Aa": np.identity(cls.d),
            "theta": np.zeros((cls.d, 1)),
            "ba": np.zeros((cls.d, 1)),
            "AaI": np.identity(cls.d)
        }

    @classmethod
    def linucb_recommend_tag(cls, device_id, redis_linucb_tag_data_dict, user_features_list, tag_list):
        """Score every tag in ``tag_list`` and return the best ones.

        :param device_id: only used in the log line
        :param redis_linucb_tag_data_dict: mapping from tag id (the same
            objects that appear in ``tag_list``) to the pickled state dict
        :param user_features_list: user feature vector; assumed to have
            ``cls.d`` entries -- TODO confirm with callers
        :param tag_list: indexable sequence of candidate tag ids, either
            ``bytes`` (decoded as UTF-8) or any value accepted by ``str()``
        :return: dict mapping tag id (as ``str``) to its score, inserted from
            the highest score downwards and holding at most
            ``cls.max_recommend_count`` entries; ``{}`` when ``tag_list`` is
            empty or when any error occurs
        """
        try:
            if len(tag_list) == 0:
                return {}

            AaI_list = list()
            theta_list = list()
            for tag_id in tag_list:
                # NOTE(security): pickle.loads executes arbitrary code on a
                # crafted payload -- only feed it data written by this class.
                tag_dict = pickle.loads(redis_linucb_tag_data_dict[tag_id])
                # The confidence term needs the INVERSE of Aa.  Reading "Aa"
                # here would make the bonus grow with every observation
                # instead of shrinking.
                AaI_list.append(tag_dict["AaI"])
                theta_list.append(tag_dict["theta"])

            xaT = np.array([user_features_list])
            xa = np.transpose(xaT)

            AaI_tmp = np.array(AaI_list)
            theta_tmp = np.array(theta_list)

            # One vectorised evaluation for all tags; the result holds one
            # value per tag, in tag_list order once flattened.
            np_array = np.dot(xaT, theta_tmp) + cls.alpha * np.sqrt(np.dot(np.dot(xaT, AaI_tmp), xa))

            np_score_list = list(np_array.flat)

            # score -> positions in tag_list that produced it (kept for the log)
            np_score_dict = dict()
            for score_index, score in enumerate(np_score_list):
                np_score_dict.setdefault(score, []).append(score_index)

            # sorted() is stable even with reverse=True, so tags with equal
            # scores keep their tag_list order.
            ranked_index_list = sorted(
                range(len(np_score_list)),
                key=lambda score_index: np_score_list[score_index],
                reverse=True
            )
            sorted_np_score_list = [np_score_list[score_index] for score_index in ranked_index_list]

            top_tag_dict = dict()
            for score_index in ranked_index_list:
                tag_id = cls._tag_id_to_str(tag_list[score_index])
                top_tag_dict[tag_id] = np_score_list[score_index]
                if len(top_tag_dict) >= cls.max_recommend_count:
                    break

            logging.info(
                "duan add,device_id:%s,sorted_np_score_list:%s,np_score_dict:%s",
                device_id, sorted_np_score_list, np_score_dict
            )
            return top_tag_dict
        except Exception:
            logging.error("catch exception,err_msg:%s", traceback.format_exc())
            return {}

    @classmethod
    def init_device_id_linucb_info(cls, redis_cli, redis_prefix, device_id, tag_list):
        """Write the initial state of every tag in ``tag_list`` for a device.

        Existing fields for the same tag ids are overwritten.

        :param redis_cli: Redis client providing ``hmset``
        :param redis_prefix: prefix of the per-device hash key
        :param device_id: device identifier appended to ``redis_prefix``
        :param tag_list: iterable of tag ids used as hash fields
        :return: True on success, False if anything raised
        """
        try:
            redis_key = redis_prefix + str(device_id)

            user_tag_linucb_dict = dict()
            for tag_id in tag_list:
                user_tag_linucb_dict[tag_id] = pickle.dumps(cls._initial_tag_state())

            redis_cli.hmset(redis_key, user_tag_linucb_dict)

            return True
        except Exception:
            logging.error("catch exception,err_msg:%s", traceback.format_exc())
            return False

    @classmethod
    def update_linucb_info(cls, user_features, reward, tag_id, device_id, redis_prefix, redis_cli):
        """Fold one observed reward into the stored state of a (device, tag).

        ``reward == 1`` applies payoff ``cls.r1`` and ``reward == 0`` applies
        ``cls.r0``.  ``reward == -1`` only logs a warning and any other value
        is ignored; both cases still return True.

        When the tag has no stored state yet, the update is applied on top of
        the initial state, so the first observation is not lost.

        :param user_features: user feature vector; assumed to have ``cls.d``
            entries -- TODO confirm with callers
        :param reward: 1, 0 or -1 as described above
        :param tag_id: hash field of the tag
        :param device_id: device identifier appended to ``redis_prefix``
        :param redis_prefix: prefix of the per-device hash key
        :param redis_cli: Redis client providing ``hget`` and ``hset``
        :return: True unless an exception was raised, in which case False
        """
        try:
            if reward == -1:
                logging.warning("reward val error!")
                return True
            if reward == 1:
                r = cls.r1
            elif reward == 0:
                r = cls.r0
            else:
                return True

            xaT = np.array([user_features])
            xa = np.transpose(xaT)

            redis_key = redis_prefix + str(device_id)
            ori_redis_tag_data = redis_cli.hget(redis_key, tag_id)

            if ori_redis_tag_data:
                # NOTE(security): pickle.loads executes arbitrary code on a
                # crafted payload -- only feed it data written by this class.
                ori_redis_tag_dict = pickle.loads(ori_redis_tag_data)
            else:
                ori_redis_tag_dict = cls._initial_tag_state()

            new_Aa_matrix = ori_redis_tag_dict["Aa"] + np.dot(xa, xaT)
            # Solving Aa . X = I yields the inverse of Aa.
            new_AaI_matrix = np.linalg.solve(new_Aa_matrix, np.identity(cls.d))
            new_ba_matrix = ori_redis_tag_dict["ba"] + r * xa

            user_tag_dict = {
                "Aa": new_Aa_matrix,
                "ba": new_ba_matrix,
                "AaI": new_AaI_matrix,
                "theta": np.dot(new_AaI_matrix, new_ba_matrix)
            }

            redis_cli.hset(redis_key, tag_id, pickle.dumps(user_tag_dict))
            return True
        except Exception:
            logging.error("catch exception,err_msg:%s", traceback.format_exc())
            return False