# -*- coding: UTF-8 -*-
# !/usr/bin/env python

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import sys
sys.path.append("/home/gmuser/gm_mab/")
from libs.cache import redis_client
from linucb.core.Linucb import *

import logging
import traceback
import json
import pickle
import pymysql
import random
import time
import pymysql
from sklearn.preprocessing import LabelEncoder
from elasticsearch import Elasticsearch


class Generate_Feature_Info(object):
    """Build LinUCB feature vectors and store/read them in Redis hashes.

    Content features are stored as JSON ``[tag_label, content_level]`` under
    ``redis_name_content_linucb_feature_prefix + content_type`` keyed by
    content id; user features are stored as JSON ``[tag_label]`` under
    ``redis_name_user_linucb_feature_prefix`` keyed by device id.  A label
    of ``-1`` means that no label could be produced by the encoder.

    All public methods are best-effort: failures are logged and a neutral
    value (empty set / ``None`` / empty list) is returned instead of raising.
    """

    # NOTE(review): connection credentials are hard-coded in source control;
    # consider loading them from configuration or the environment.
    host = "172.16.30.141"
    user = "work"
    password = "BJQaT9VzDcuPBqkd"
    database = "zhengxing"

    redis_name_content_linucb_feature_prefix = "strategy:linucb:feature:content_type:"
    redis_name_user_linucb_feature_prefix = "strategy:linucb:feature:user"

    @classmethod
    def get_tagv3_word_list(cls):
        """Return the set of UTF-8 encoded tag names read from ``api_tag_3_0``.

        Only rows with ``is_online=1`` and ``tag_type=1`` are selected.
        Rows whose name is NULL are skipped.  On any failure the error is
        logged and an empty set is returned.
        """
        zhengxing_conn = None
        try:
            zhengxing_conn = pymysql.connect(
                host=cls.host,
                user=cls.user,
                password=cls.password,
                database=cls.database,
                charset="utf8")

            tag_v3_project_sql = """
                    select name from api_tag_3_0 where is_online=1 and tag_type=1;
            """

            zhengxing_cursor = zhengxing_conn.cursor()
            try:
                zhengxing_cursor.execute(tag_v3_project_sql)
                sql_tag_results = zhengxing_cursor.fetchall()
            finally:
                zhengxing_cursor.close()

            return set(
                item[0].encode("utf-8")
                for item in sql_tag_results
                if item[0] is not None
            )
        except Exception:
            logging.error("catch exception,err_msg:%s" % traceback.format_exc())
            return set()
        finally:
            # Always release the connection, including on the error path.
            if zhengxing_conn is not None:
                try:
                    zhengxing_conn.close()
                except Exception:
                    logging.error("catch exception,err_msg:%s" % traceback.format_exc())

    @classmethod
    def get_tagv3_label_encode(cls, tagv3_name_list):
        """Fit a ``LabelEncoder`` on ``tagv3_name_list`` and return it.

        Returns ``None`` (after logging) if fitting fails.
        """
        try:
            return LabelEncoder().fit(tagv3_name_list)
        except Exception:
            logging.error("catch exception,err_msg:%s" % traceback.format_exc())
            return None

    @staticmethod
    def _encode_first_tag_label(label_encoder, tag_list):
        """Return the encoder label of the first tag in ``tag_list`` as ``int``.

        The whole list is passed to ``label_encoder.transform`` and the first
        result is used.  Returns ``-1`` when the list is empty or when the
        transform fails for any reason (including ``label_encoder`` being
        ``None``).  The result is converted to a built-in ``int`` so that it
        can always be serialized with ``json.dumps``.
        """
        if not tag_list:
            return -1
        try:
            return int(label_encoder.transform(tag_list)[0])
        except Exception:
            # Deliberate best-effort: a failed encoding maps to the sentinel.
            logging.debug("tag label encoding failed,err_msg:%s" % traceback.format_exc())
            return -1

    @staticmethod
    def _top_project_tag(user_portrait_dict):
        """Return the UTF-8 encoded key with the highest value in ``"projects"``.

        Returns ``None`` when the portrait is not a dict or has no non-empty
        ``"projects"`` mapping.  On ties the first key in iteration order wins.
        """
        if not isinstance(user_portrait_dict, dict):
            return None
        projects = user_portrait_dict.get("projects")
        if not projects:
            return None
        best_tag = max(projects.items(), key=lambda pair: pair[1])[0]
        return best_tag.encode("utf-8")

    @classmethod
    def generate_content_feature_to_redis(cls, label_encoder, content_type="diary"):
        """Scroll diaries from Elasticsearch and write their features to Redis.

        Queries index ``gm-dbmw-diary-read`` for online documents whose
        ``content_level`` is one of 3, 3.5, 4, 5, 6 and, for each hit, stores
        ``json.dumps([tag_label, content_level])`` in the Redis hash
        ``redis_name_content_linucb_feature_prefix + content_type`` under the
        document id.  ``content_type`` only affects the Redis key; the index
        and doc type queried are fixed.

        Returns ``None``.  Errors are logged; a failing document is skipped,
        and a failing scroll request stops the scan.
        """
        try:
            redis_name_content_linucb_feature = cls.redis_name_content_linucb_feature_prefix + content_type

            es = Elasticsearch([
                {
                    'host': '172.16.31.17',
                    'port': 9000,
                }
            ])
            page = es.search(
                index='gm-dbmw-diary-read',
                doc_type='diary',
                scroll='10m',
                search_type='scan',
                size=10,
                body={
                    "query": {
                        "filtered": {
                            "filter": {
                                "bool": {
                                    "must": [
                                        {"term": {"is_online": True}},
                                        {"terms": {"content_level": [3, 3.5, 4, 5, 6]}}
                                    ]
                                }
                            }
                        }
                    },
                    "_source": {"include": ["id", "tags_v3", "content_level"]}
                }
            )
            sid = page['_scroll_id']
            scroll_size = page['hits']['total']
        except Exception:
            logging.error("catch exception,err_msg:%s" % traceback.format_exc())
            return

        while scroll_size > 0:
            try:
                page = es.scroll(scroll_id=sid, scroll='10m')
                sid = page['_scroll_id']
                hits = page['hits']['hits']
            except Exception:
                # The cursor cannot advance, so retrying here would loop
                # forever; stop the scan instead.
                logging.error("catch exception,err_msg:%s" % traceback.format_exc())
                break

            scroll_size = len(hits)

            for item in hits:
                # Handle each document independently so that one malformed
                # hit does not discard the rest of the page.
                try:
                    source = item["_source"]
                    diary_id = source["id"]
                    tags_v3 = source.get("tags_v3") or list()
                    content_level = source.get("content_level") or -1

                    offi_tags_v3 = [tag_item.encode("utf-8") for tag_item in tags_v3]
                    tag_label = cls._encode_first_tag_label(label_encoder, offi_tags_v3)

                    diary_feature_list = [tag_label, content_level]
                    redis_client.hset(redis_name_content_linucb_feature, diary_id,
                                      json.dumps(diary_feature_list))
                except Exception:
                    logging.error("catch exception,err_msg:%s" % traceback.format_exc())

    @classmethod
    def _get_user_portrait_tag3_redis_key(cls, device_id):
        """Return the Redis key holding the tag3 user portrait of ``device_id``."""
        return "doris:user_portrait:tag3:device_id:" + str(device_id)

    @classmethod
    def generate_user_feature_to_redis(cls, device_id, label_encoder):
        """Compute the user feature for ``device_id`` and store it in Redis.

        Reads the JSON user portrait from Redis, takes the highest-scoring
        key of its ``"projects"`` mapping, encodes it with ``label_encoder``
        and stores ``json.dumps([label])`` in the user feature hash under
        ``device_id``.  The label is ``-1`` when there is no portrait, no
        projects, or the tag cannot be encoded.

        Returns ``None``.  Errors are logged and swallowed.
        """
        try:
            redis_name_user_linucb_feature = cls.redis_name_user_linucb_feature_prefix

            user_portrait_redis_name = cls._get_user_portrait_tag3_redis_key(device_id)
            user_portrait_redis_data = redis_client.get(user_portrait_redis_name)
            user_portrait_dict = json.loads(user_portrait_redis_data) if user_portrait_redis_data else dict()

            user_tag_label = -1
            user_max_score_tag = cls._top_project_tag(user_portrait_dict)
            if user_max_score_tag is not None:
                user_tag_label = cls._encode_first_tag_label(label_encoder, [user_max_score_tag])

            user_feature_list = [user_tag_label]
            redis_client.hset(redis_name_user_linucb_feature, device_id, json.dumps(user_feature_list))
        except Exception:
            logging.error("catch exception,err_msg:%s" % traceback.format_exc())

    @classmethod
    def get_user_feature_by_device_id(cls, device_id):
        """Return the stored user feature of ``device_id`` as ``[float]``.

        Returns an empty list when nothing is stored or when reading or
        parsing fails (the failure is logged).
        """
        try:
            redis_name_user_linucb_feature = cls.redis_name_user_linucb_feature_prefix
            user_feature_redis_data = redis_client.hget(redis_name_user_linucb_feature, device_id)

            if not user_feature_redis_data:
                return list()
            # Only the first stored element is used as the user feature.
            return [float(json.loads(user_feature_redis_data)[0])]
        except Exception:
            logging.error("catch exception,err_msg:%s" % traceback.format_exc())
            return list()

    @classmethod
    def get_content_feature(cls, card_id, content_type="diary"):
        """Return the stored content feature of ``card_id`` as a list of floats.

        The value is read from the Redis hash
        ``redis_name_content_linucb_feature_prefix + content_type``.  Returns
        an empty list when nothing is stored or when reading or parsing
        fails (the failure is logged).
        """
        try:
            redis_name_content_linucb_feature = cls.redis_name_content_linucb_feature_prefix + content_type

            content_feature_redis_data = redis_client.hget(redis_name_content_linucb_feature, card_id)
            if not content_feature_redis_data:
                return list()

            content_feature_redis_list = json.loads(content_feature_redis_data)
            return [float(item) for item in content_feature_redis_list]
        except Exception:
            logging.error("catch exception,err_msg:%s" % traceback.format_exc())
            return list()


if __name__ == "__main__":
    # Replay a click log: for every click, refresh the user's feature in
    # Redis, read back the user and content features, and feed the combined
    # vector to LinUCB as a positive (reward=1) observation.

    tagv3_name_set = Generate_Feature_Info.get_tagv3_word_list()

    label_encoder = Generate_Feature_Info.get_tagv3_label_encode(tagv3_name_list=list(tagv3_name_set))

    # device_id="868771031984211"
    # Generate_Feature_Info.generate_user_feature_to_redis(device_id,label_encoder)
    # Generate_Feature_Info.generate_content_feature_to_redis(label_encoder)

    linucb_matrix_redis_name = "strategy:linucb:matrix:content_type:diary"

    diary_click_file = "/data/log/duan_test/feed_query_data/feed_click_info.txt"

    # The context manager closes the file even if an update raises; iterating
    # the handle streams the log instead of loading it all into memory.
    with open(diary_click_file, "r") as diary_fd:
        for line in diary_fd:
            item_list = line.strip().split(",")
            # Comma-separated record; fields 3 and 4 are read as the device
            # id and diary id.  Skip blank or truncated lines instead of
            # crashing with IndexError.
            if len(item_list) < 5:
                continue

            device_id = item_list[3]
            diary_id = item_list[4]
            Generate_Feature_Info.generate_user_feature_to_redis(device_id, label_encoder)

            user_feature_list = Generate_Feature_Info.get_user_feature_by_device_id(device_id)
            content_feature_list = Generate_Feature_Info.get_content_feature(diary_id)

            # NOTE(review): either list is empty when no feature is stored in
            # Redis, which makes the combined vector shorter -- confirm that
            # LinUCB.update_linucb_info tolerates this.
            combined_feature_list = user_feature_list + content_feature_list
            print(combined_feature_list)

            LinUCB.update_linucb_info(user_features=combined_feature_list, reward=1, content_id=diary_id,
                                      redis_name_linucb_matrix=linucb_matrix_redis_name, redis_cli=redis_client)

    # Manual sanity check (disabled): call label_encoder.transform([...]) with
    # a tag name known to exist in api_tag_3_0 and inspect the returned label.
