# -*- coding: UTF-8 -*-
# !/usr/bin/env python

import numpy as np
import redis
from libs.cache import redis_client
from trans2es.models.tag import Tag
import logging
import traceback
import json
import pickle

class LinUCB:
    """Per-device LinUCB contextual bandit over tags, with state kept in redis.

    For each (device, tag) pair, a pickled dict is stored in the redis hash
    ``redis_prefix + str(device_id)`` under the tag id. The dict has four keys:

    - ``Aa``: the d x d design matrix.
    - ``AaI``: its inverse.
    - ``ba``: the d x 1 reward-weighted feature sum.
    - ``theta``: equal to ``AaI @ ba``.
    """

    # Length of the user feature vector.
    d = 2
    # Weight of the upper-confidence (exploration) term.
    alpha = 0.25
    # Reward applied when feedback is positive (reward == 1).
    r1 = 1
    # Reward applied when feedback is negative (reward == 0).
    r0 = -16
    # Lazily filled cache of online tag ids, shared at class level.
    default_tag_list = list()

    @classmethod
    def get_default_tag_list(cls):
        """Return up to 20 online tag ids, loading them from the DB on first use.

        The ids of every ``Tag`` with ``is_online=True`` are cached on the class.
        If the query fails, the error is logged and an empty list is returned.
        """
        try:
            if not cls.default_tag_list:
                # Collect into a local list first, so a failure half-way through
                # the query cannot leave a partially filled cache behind.
                online_tag_ids = [item.id for item in Tag.objects.filter(is_online=True)]
                cls.default_tag_list.extend(online_tag_ids)

            return cls.default_tag_list[:20]
        except Exception:
            logging.error("catch exception,err_msg:%s" % traceback.format_exc())
            return list()

    @classmethod
    def _new_tag_state(cls):
        """Return the prior LinUCB state for one tag (A = I, b = 0, theta = 0)."""
        return {
            "Aa": np.identity(cls.d),
            "theta": np.zeros((cls.d, 1)),
            "ba": np.zeros((cls.d, 1)),
            "AaI": np.identity(cls.d),
        }

    @classmethod
    def linucb_recommend_tag(cls, redis_linucb_tag_data_dict, user_features_list, tag_list):
        """Get recommended tags: the top third of tags ranked by LinUCB score.

        The score of each tag is ``theta^T x + alpha * sqrt(x^T AaI x)``.

        :param redis_linucb_tag_data_dict: mapping of tag id -> pickled per-tag
            state, in the format written by ``init_device_id_linucb_info`` and
            ``update_linucb_info``.
        :param user_features_list: user feature vector of length ``cls.d``.
        :param tag_list: candidate tags. Only its length is used, to size the
            result: ``len(tag_list) // 3``, clamped to at least 1 and at most
            the number of scored tags.
        :return: tag ids (keys of ``redis_linucb_tag_data_dict``), highest score
            first. Returns ``[]`` on empty input or on error (the error is logged).
        """
        try:
            if not redis_linucb_tag_data_dict:
                return []

            tag_ids = list()
            theta_rows = list()
            AaI_list = list()
            for tag_id, pickled_state in redis_linucb_tag_data_dict.items():
                # NOTE(review): pickle.loads is only safe while this redis hash
                # is written exclusively by this class; never feed it external data.
                tag_dict = pickle.loads(pickled_state)
                tag_ids.append(tag_id)
                theta_rows.append(np.asarray(tag_dict["theta"]).reshape(-1))
                AaI_list.append(tag_dict["AaI"])

            x = np.asarray(user_features_list, dtype=float).reshape(-1)
            theta = np.array(theta_rows)   # (n_tags, d)
            AaI = np.array(AaI_list)       # (n_tags, d, d)

            # One score per tag. The confidence term uses the *inverse* design
            # matrix. Clip at 0 so float round-off cannot produce NaN in sqrt.
            expected = theta.dot(x)
            quad_form = np.einsum("i,nij,j->n", x, AaI, x)
            scores = expected + cls.alpha * np.sqrt(np.maximum(quad_form, 0.0))

            n_tags = len(tag_ids)
            top_k = min(max(len(tag_list) // 3, 1), n_tags)
            top_idx = np.argpartition(scores, -top_k)[-top_k:]
            # argpartition leaves the top-k unordered; rank them best first.
            top_idx = top_idx[np.argsort(-scores[top_idx], kind="mergesort")]

            return [tag_ids[i] for i in top_idx.tolist()]
        except Exception:
            logging.error("catch exception,err_msg:%s" % traceback.format_exc())
            return []

    @classmethod
    def init_device_id_linucb_info(cls, redis_cli, redis_prefix, device_id, tag_list):
        """Write a fresh (prior) LinUCB state for every tag in ``tag_list``.

        The states go into the hash ``redis_prefix + str(device_id)`` in a single
        ``hmset`` call. Existing fields for these tags are overwritten.

        :return: True on success (including an empty ``tag_list``). False if an
            exception occurred; the exception is logged.
        """
        try:
            if not tag_list:
                # Nothing to write, and hmset rejects an empty mapping.
                return True

            redis_key = redis_prefix + str(device_id)
            user_tag_linucb_dict = {
                tag_id: pickle.dumps(cls._new_tag_state()) for tag_id in tag_list
            }
            redis_cli.hmset(redis_key, user_tag_linucb_dict)

            return True
        except Exception:
            logging.error("catch exception,err_msg:%s" % traceback.format_exc())
            return False

    @classmethod
    def update_linucb_info(cls, user_features, reward, tag_id, device_id, redis_prefix, redis_cli):
        """Apply one feedback observation to the LinUCB state of ``tag_id``.

        :param user_features: user feature vector of length ``cls.d``.
        :param reward: 1 means positive feedback (``cls.r1``); 0 means negative
            feedback (``cls.r0``). -1 logs a warning. Any other value is ignored.
        :return: True unless an exception occurred; the exception is logged and
            False is returned.

        If the tag has no stored state yet, the update starts from the prior, so
        the observation is applied rather than dropped.
        """
        try:
            if reward == -1:
                logging.warning("reward val error!")
            elif reward == 1 or reward == 0:
                r = cls.r1 if reward == 1 else cls.r0

                xaT = np.array([user_features])
                xa = np.transpose(xaT)

                redis_key = redis_prefix + str(device_id)
                ori_redis_tag_data = redis_cli.hget(redis_key, tag_id)

                if ori_redis_tag_data:
                    # NOTE(review): pickle.loads trusts redis contents; see above.
                    ori_redis_tag_dict = pickle.loads(ori_redis_tag_data)
                else:
                    ori_redis_tag_dict = cls._new_tag_state()

                new_Aa_matrix = ori_redis_tag_dict["Aa"] + np.dot(xa, xaT)
                # Solving A X = I gives A^{-1}.
                new_AaI_matrix = np.linalg.solve(new_Aa_matrix, np.identity(cls.d))
                new_ba_matrix = ori_redis_tag_dict["ba"] + r * xa

                user_tag_dict = {
                    "Aa": new_Aa_matrix,
                    "ba": new_ba_matrix,
                    "AaI": new_AaI_matrix,
                    "theta": np.dot(new_AaI_matrix, new_ba_matrix)
                }

                # Write through the caller's client, not a module-global one.
                redis_cli.hset(redis_key, tag_id, pickle.dumps(user_tag_dict))
            return True
        except Exception:
            logging.error("catch exception,err_msg:%s" % traceback.format_exc())
            return False