# -*- coding: UTF-8 -*-
# !/usr/bin/env python

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import redis
import sys
sys.path.append("/home/gmuser/gm_mab/")
from libs.cache import redis_client

import logging
import traceback
import json
import pickle
import pymysql
import random
import time


class LinUCB:
    """LinUCB contextual bandit whose per-arm state lives in Redis hashes.

    Every arm (a tag/content id) is stored as a pickled dict with the keys
    ``"Aa"`` (matrix A), ``"AaI"`` (A^{-1}), ``"ba"`` (vector b) and
    ``"theta"`` (A^{-1} b). Arms are scored with
    ``theta^T x + alpha * sqrt(x^T A^{-1} x)`` for a user context vector x.
    """

    # Dimension of the user context vector and of every per-arm matrix.
    d = 6
    # Weight of the confidence (exploration) term in the UCB score.
    alpha = 0.01
    # Reward values applied for feedback 1 (r1) and feedback 0 (r0).
    r1 = 10
    r0 = -0.1
    default_tag_list = list()

    # NOTE(review): database credentials are hard-coded in source; consider
    # loading them from configuration or the environment instead.
    zhengxing_host = "172.16.30.141"
    zhengxing_user = "work"
    zhengxing_password = "BJQaT9VzDcuPBqkd"
    zhengxing_database = "zhengxing"

    # Prefix of the Redis hash name; the card content type is appended.
    redis_name_linucb_matrix_prefix = "strategy:linucb:content_type:"

    @classmethod
    def _new_arm_state(cls):
        """Return a fresh arm state: A = A^{-1} = I (d x d), b = theta = 0 (d x 1)."""
        return {
            "Aa": np.identity(cls.d),
            "theta": np.zeros((cls.d, 1)),
            "ba": np.zeros((cls.d, 1)),
            "AaI": np.identity(cls.d),
        }

    @classmethod
    def linucb_recommend_tag(cls, redis_linucb_tag_data_dict, user_features_list, tag_list, top_n=20):
        """Get recommended tags: score every candidate arm and keep the best ``top_n``.

        Score per arm: ``theta^T x + alpha * sqrt(x^T A^{-1} x)``.

        :param redis_linucb_tag_data_dict: mapping tag id -> pickled arm state
            (the format written by this class).
        :param user_features_list: user context vector of length ``d``; entries
            must be convertible to float.
        :param tag_list: candidate tag ids; each must be a key of
            ``redis_linucb_tag_data_dict``.
        :param top_n: maximum number of tags to return (default 20).
        :return: ``(top_tag_dict, top_tag_set)`` where ``top_tag_dict`` maps tag
            id -> score in descending score order (ties keep ``tag_list`` order)
            and ``top_tag_set`` holds the same ids. ``({}, set())`` when there
            are no candidates or on any error (the error is logged).
        """
        try:
            if len(tag_list) == 0 or top_n <= 0:
                return ({}, set())

            aai_list = list()
            theta_list = list()
            for tag_id in tag_list:
                # NOTE(review): pickle.loads can execute code for crafted
                # payloads; this is only safe while the Redis hash is written
                # exclusively by trusted code such as this class.
                arm_state = pickle.loads(redis_linucb_tag_data_dict[tag_id])
                aai_list.append(arm_state["AaI"])
                theta_list.append(arm_state["theta"])

            xaT = np.array([user_features_list], dtype=float)  # row vector (1, d)
            xa = np.transpose(xaT)                               # column vector (d, 1)
            aai_arr = np.array(aai_list)       # stacked A^{-1}: (n, d, d)
            theta_arr = np.array(theta_list)   # stacked theta:  (n, d, 1)

            begin = time.time()
            # np.dot contracts the last axis of xaT with the second-to-last axis
            # of the stacked arrays, so both terms come out as (1, n, 1).
            # The confidence width must use the INVERSE A^{-1}; using A itself
            # would make exploration grow as evidence accumulates.
            scores = (np.dot(xaT, theta_arr)
                      + cls.alpha * np.sqrt(np.dot(np.dot(xaT, aai_arr), xa))).ravel()
            logging.debug("linucb score time: %s", time.time() - begin)

            # Stable sort on negated scores: highest score first, equal scores
            # keep their position in tag_list.
            top_tag_dict = dict()
            for idx in np.argsort(-scores, kind="stable"):
                if len(top_tag_dict) >= top_n:
                    break
                top_tag_dict[tag_list[idx]] = scores[idx]

            return (top_tag_dict, set(top_tag_dict))
        except Exception:
            logging.error("catch exception,err_msg:%s" % traceback.format_exc())
            return ({}, set())

    @classmethod
    def _fetch_online_diary_ids(cls):
        """Return ids of online diaries with content_level 5 or 6 from MySQL."""
        zhengxing_conn = pymysql.connect(
            host=cls.zhengxing_host,
            user=cls.zhengxing_user,
            password=cls.zhengxing_password,
            database=cls.zhengxing_database,
            charset="utf8")
        try:
            zhengxing_cursor = zhengxing_conn.cursor()
            zhengxing_cursor.execute(
                "select id from api_diary where is_online=true and content_level in (5,6);")
            return [int(item[0]) for item in zhengxing_cursor.fetchall()]
        finally:
            # Always release the connection, even when the query fails.
            zhengxing_conn.close()

    @classmethod
    def init_all_arm_by_card_content(cls, card_content="diary", user_features_list=None):
        """(Re)initialise and seed every arm of a card content type in Redis.

        Only ``card_content == "diary"`` is handled: for each diary id read
        from MySQL, the arm is reset to a fresh state in the Redis hash
        ``redis_name_linucb_matrix_prefix + card_content``, then given one
        positive reward (reward=1) using a feature vector picked at random
        from ``user_features_list``. Any other ``card_content`` is a no-op.

        :param card_content: content type appended to the Redis hash name.
        :param user_features_list: non-empty sequence of feature vectors used
            to seed the arms.
        :return: True on success (or no-op), False on error (logged).
        """
        if card_content != "diary":
            return True
        if user_features_list is None or len(user_features_list) == 0:
            logging.error("user_features_list is empty, cannot seed linucb arms")
            return False

        redis_name_linucb_matrix = cls.redis_name_linucb_matrix_prefix + card_content
        try:
            diary_id_list = cls._fetch_online_diary_ids()
            for diary_id in diary_id_list:
                # Reset to the prior first, so a re-run starts from scratch
                # instead of accumulating on top of an existing state.
                redis_client.hset(redis_name_linucb_matrix, diary_id,
                                  pickle.dumps(cls._new_arm_state()))

                user_feature = random.choice(user_features_list)
                cls.update_linucb_info(user_feature, 1, diary_id,
                                       redis_name_linucb_matrix, redis_client)

                print(str(user_feature) + "\t" + str(diary_id))
            return True
        except Exception:
            logging.error("catch exception,err_msg:%s" % traceback.format_exc())
            return False

    @classmethod
    def init_device_id_linucb_info(cls, redis_cli, redis_name_linucb_matrix, tag_list):
        """Write a fresh arm state for every id in ``tag_list`` into one Redis hash.

        :param redis_cli: Redis client used for the write.
        :param redis_name_linucb_matrix: name of the Redis hash.
        :param tag_list: arm ids (hash fields) to initialise.
        :return: True on success (including an empty ``tag_list``), False on
            error (logged).
        """
        try:
            # Every arm starts from the same prior, so serialise it only once.
            pickle_data = pickle.dumps(cls._new_arm_state())
            user_tag_linucb_dict = dict((tag_id, pickle_data) for tag_id in tag_list)
            if not user_tag_linucb_dict:
                # Nothing to write; redis-py rejects an empty hmset mapping.
                return True

            redis_cli.hmset(redis_name_linucb_matrix, user_tag_linucb_dict)
            return True
        except Exception:
            logging.error("catch exception,err_msg:%s" % traceback.format_exc())
            return False

    @classmethod
    def update_linucb_info(cls, user_features, reward, content_id, redis_name_linucb_matrix, redis_cli):
        """Apply one LinUCB observation to an arm stored in Redis.

        ``A <- A + x x^T``, ``b <- b + r x``, ``A^{-1} = solve(A, I)``,
        ``theta = A^{-1} b``, with ``r = r1`` for reward 1 and ``r = r0`` for
        reward 0. An arm missing from Redis starts from a fresh state, and this
        observation is applied to it (so the first feedback is not lost).

        :param user_features: context vector of length ``d``; entries must be
            convertible to float.
        :param reward: 1 or 0. -1 and any other value only log a warning.
        :param content_id: hash field of the arm.
        :param redis_name_linucb_matrix: name of the Redis hash.
        :param redis_cli: Redis client used for both the read and the write.
        :return: True unless an exception occurred (logged, then False).
        """
        try:
            if reward == -1:
                logging.warning("reward val error!")
                return True
            if not (reward == 1 or reward == 0):
                logging.warning("not standard linucb reward")
                return True

            r = cls.r1 if reward == 1 else cls.r0

            xaT = np.array([user_features], dtype=float)  # row vector (1, d)
            xa = np.transpose(xaT)                         # column vector (d, 1)

            ori_redis_tag_data = redis_cli.hget(redis_name_linucb_matrix, content_id)
            if ori_redis_tag_data:
                # NOTE(review): pickle.loads on Redis data; safe only while the
                # hash is written exclusively by trusted code.
                ori_redis_tag_dict = pickle.loads(ori_redis_tag_data)
            else:
                ori_redis_tag_dict = cls._new_arm_state()

            new_Aa_matrix = ori_redis_tag_dict["Aa"] + np.dot(xa, xaT)
            new_AaI_matrix = np.linalg.solve(new_Aa_matrix, np.identity(cls.d))
            new_ba_matrix = ori_redis_tag_dict["ba"] + r * xa

            user_tag_dict = {
                "Aa": new_Aa_matrix,
                "ba": new_ba_matrix,
                "AaI": new_AaI_matrix,
                "theta": np.dot(new_AaI_matrix, new_ba_matrix),
            }
            redis_cli.hset(redis_name_linucb_matrix, content_id, pickle.dumps(user_tag_dict))
            return True
        except Exception:
            logging.error("catch exception,err_msg:%s" % traceback.format_exc())
            return False


if __name__ == "__main__":
    # Demo: seed every online diary arm, then time one recommendation request.
    # Feature vectors must be numeric: LinUCB multiplies them into float
    # matrices, so non-numeric entries (e.g. "b") make every update and
    # recommend call fail (the failure is only logged).
    user_features = [
        [1, 2, 1, 1, 3, 1],
        [1, 4, 3, 1, 3, 1],
        [3, 2, 1, 5, 5, 1],
        [1, 2, 4, 2, 3, 2],
        [1, 5, 7, 1, 4, 4],
        [3, 4, 1, 1, 3, 1],
        [5, 2, 1, 6, 3, 1],
        [1, 2, 3, 2, 3, 5],
        [1, 2, 1, 1, 2, 4],
        [1, 2, 6, 4, 2, 1],
    ]

    LinUCB.init_all_arm_by_card_content(user_features_list=user_features)

    test_user_feature = [1, 2, 1, 1, 3, 1]

    # Same hash name init_all_arm_by_card_content writes for "diary".
    redis_name_diary = LinUCB.redis_name_linucb_matrix_prefix + "diary"

    begin = time.time()
    all_diary_content_redis_dict = redis_client.hgetall(redis_name_diary)
    top_tag_dict, top_tag_set = LinUCB.linucb_recommend_tag(
        all_diary_content_redis_dict,
        test_user_feature,
        list(all_diary_content_redis_dict.keys()))
    print(time.time() - begin)
    print(top_tag_dict)