import datetime
import json
import logging
from collections import deque

import pandas as pd
import pymysql
import redis
from sqlalchemy import create_engine


def get_mysql_data(host, port, user, passwd, db, sql, args=None):
    """Run a single statement against MySQL and return every row.

    Args:
        host, port, user, passwd, db: Connection settings, forwarded
            unchanged to ``pymysql.connect``.
        sql: The statement to execute.
        args: Optional parameters bound by the driver to ``%s`` placeholders
            in ``sql``. With the default ``None`` the statement is executed
            verbatim, exactly as before this parameter existed.

    Returns:
        Whatever ``cursor.fetchall()`` returns for the statement.

    Raises:
        Any exception raised by PyMySQL while connecting or executing; the
        connection is closed before the exception propagates.
    """
    # A separate name keeps the `db` argument (the schema name) from being
    # shadowed by the connection object.
    conn = pymysql.connect(host=host, port=port, user=user, passwd=passwd, db=db)
    try:
        with conn.cursor() as cursor:
            cursor.execute(sql, args)
            return cursor.fetchall()
    finally:
        # Previously the connection leaked whenever execute()/fetchall() raised.
        conn.close()


def get_esmm_users():
    """Return yesterday's distinct (device_id, city_id) rows.

    Queries ``data_feed_exposure_precise`` for rows whose ``stat_date`` is
    yesterday's date (local clock, ``YYYY-MM-DD``).

    Returns:
        A list of the rows returned by the query, or ``[]`` when the query
        fails for any reason (the failure is logged, not raised).

    NOTE(review): the connection credentials are hard-coded here; consider
    moving them to configuration.
    """
    stat_date = (datetime.date.today() - datetime.timedelta(days=1)).strftime("%Y-%m-%d")
    # stat_date is generated locally, so formatting it into the SQL is safe.
    sql = "select distinct device_id,city_id from data_feed_exposure_precise " \
          "where stat_date = '{}'".format(stat_date)
    try:
        rows = get_mysql_data('172.16.40.170', 4000, 'root', '3SYz54LS9#^9sBvC', 'jerry_prod', sql)
        return list(rows)
    except Exception:
        # Best-effort by design, but no longer silent and no longer trapping
        # KeyboardInterrupt/SystemExit as the bare `except:` did.
        logging.getLogger(__name__).exception("get_esmm_users failed for stat_date=%s", stat_date)
        return []


def get_user_profile(device_id, top_k=5):
    """Return up to ``top_k`` tags for a device, highest score first.

    Reads the JSON value stored in Redis under
    ``"user:portrait_tags:cl_id:<device_id>"``. Each entry is expected to
    carry ``"type"``, ``"content"`` and ``"score"`` keys. Entries of type
    ``"tag"`` contribute ``content`` directly; any other entry contributes
    ``name_tag[content]`` when ``content`` is a key of the module-level
    ``name_tag`` mapping, and is ignored otherwise.

    Args:
        device_id: Device identifier; converted with ``str`` to build the key.
        top_k: Maximum number of tags to return. Values <= 0 yield ``[]``.

    Returns:
        A list of tags ordered by descending score (ties keep first-seen
        order), or ``[]`` when the key is missing or anything fails (the
        failure is logged, not raised).

    NOTE(review): ``name_tag`` is a global that is only assigned inside the
    ``__main__`` block of this file; when this function is used from an
    import, the lookup raises ``NameError`` and the function returns ``[]``.
    """
    try:
        client = redis.Redis(host="172.16.40.135", port=5379, password="", db=2)
        # A single GET replaces the former EXISTS + GET pair: one round trip
        # and no window in which the key can expire between the two calls.
        raw = client.get("user:portrait_tags:cl_id:" + str(device_id))
        if raw is None:
            return []

        tag_score = {}
        for entry in json.loads(raw.decode('utf-8')):
            if entry["type"] == "tag":
                tag_score[entry["content"]] = entry["score"]
            elif entry["content"] in name_tag:
                tag_score[name_tag[entry["content"]]] = entry["score"]

        ranked = sorted(tag_score.items(), key=lambda kv: kv[1], reverse=True)
        # max(..., 0) keeps the old result for negative top_k (empty list)
        # instead of letting a negative slice bound drop items from the end.
        return [tag for tag, _ in ranked[:max(top_k, 0)]]
    except Exception:
        logging.getLogger(__name__).exception("get_user_profile failed for device_id=%s", device_id)
        return []


def get_searchworlds_to_tagid():
    """Build a ``name -> id`` mapping from the online rows of ``api_tag``.

    Selects ``id, name`` from ``api_tag`` where ``is_online = 1`` and
    ``tag_type < 4``. When several rows share a name, the last row returned
    wins.

    Returns:
        A dict keyed by tag name with the tag id as value, or ``{}`` when the
        query fails for any reason (the failure is logged, not raised).

    NOTE(review): the connection credentials are hard-coded here; consider
    moving them to configuration.
    """
    sql = 'select id, name from api_tag where is_online = 1 and tag_type < 4'
    try:
        rows = get_mysql_data('172.16.30.141', 3306, 'work', 'BJQaT9VzDcuPBqkd', 'zhengxing', sql)
        return {row[1]: row[0] for row in rows}
    except Exception:
        # An empty mapping silently disables name-based tags downstream, so
        # make the failure visible in the logs.
        logging.getLogger(__name__).exception("get_searchworlds_to_tagid failed")
        return {}


def get_queues(device_id, city_id):
    """Fetch the stored diary queues for one device/city pair.

    Reads ``native_queue, nearby_queue, nation_queue, megacity_queue`` from
    ``esmm_device_diary_queue`` for the given ``device_id`` and ``city_id``.

    Args:
        device_id: Device identifier, compared as a string.
        city_id: City identifier, compared as a string.

    Returns:
        A four-element list in the column order above taken from the first
        matching row, or ``[]`` when there is no such row or the lookup fails
        (the failure is logged, not raised).

    NOTE(review): the connection credentials are hard-coded here; consider
    moving them to configuration.
    """
    # Values are bound by the driver instead of being formatted into the SQL,
    # so a quote inside a device id can no longer break the statement.
    # str() keeps the old behaviour of comparing both columns against strings.
    sql = "select native_queue, nearby_queue, nation_queue, megacity_queue from esmm_device_diary_queue " \
          "where device_id = %s and city_id = %s"
    try:
        db = pymysql.connect(host='172.16.40.170', port=4000, user='root',
                             passwd='3SYz54LS9#^9sBvC', db='jerry_test')
        try:
            with db.cursor() as cursor:
                cursor.execute(sql, (str(device_id), str(city_id)))
                result = cursor.fetchone()
        finally:
            # Close even when execute() raises; the connection used to leak.
            db.close()
        return list(result) if result is not None else []
    except Exception:
        logging.getLogger(__name__).exception(
            "get_queues failed for device_id=%s city_id=%s", device_id, city_id)
        return []


def _sql_in_list(values):
    """Render values as a parenthesised SQL ``IN`` list, e.g. ``('1', '2')``."""
    # Unlike str(tuple(values)), this never emits the trailing comma that a
    # one-element tuple has ("(5,)"), which is not valid SQL.
    # NOTE(review): values are rendered with repr(), as before; they are not
    # escaped by the driver.
    return "({})".format(", ".join(repr(value) for value in values))


def _interleave_by_tag(cids, tag_list, rows):
    """Round-robin ``cids`` across tag buckets; return the de-duplicated order."""
    buckets = {}
    tagged = set()
    for tag_id, diary_ids in rows:
        members = set(diary_ids.split(","))
        matched = [cid for cid in cids if cid in members]
        # Keys are normalised to str so an int id from the database matches a
        # str tag from the profile (and vice versa).
        buckets[str(tag_id)] = deque(matched)
        tagged.update(matched)

    # Ids that matched none of the returned tags (the former "right" bucket).
    untagged = deque(cid for cid in cids if cid not in tagged)

    tag_queues = [buckets[str(tag)] for tag in tag_list if str(tag) in buckets]
    rotation = tag_queues + [untagged]

    ordered = []
    # Each pass takes one id per tag (in tag_list order) and then one untagged
    # id; stop after the pass in which the last tag bucket runs dry. Looping
    # only over reachable buckets guarantees termination.
    while any(tag_queues):
        for queue in rotation:
            if queue:
                ordered.append(queue.popleft())

    ordered.extend(untagged)
    # Safety net: anything not emitted yet keeps its original relative order,
    # so the result is always a permutation of the distinct input ids.
    ordered.extend(cids)
    return list(dict.fromkeys(ordered))


def tag_boost(cid_str, tag_list):
    """Reorder a comma-separated diary id queue to favour the user's tags.

    For queues of more than six ids, looks up which of the ids carry one of
    the tags in ``tag_list`` (tags with ``tag_type < 4`` only) and interleaves
    the queue: one id per tag in ``tag_list`` order, then one id matching no
    tag, repeated until every tag is exhausted; the remaining ids follow.
    Duplicate ids are emitted once, at their first position.

    Args:
        cid_str: Comma-separated ids, or ``None``/``""``.
        tag_list: The user's tags in priority order. It is not modified.

    Returns:
        The reordered comma-separated string, or ``cid_str`` unchanged when it
        is empty, ``tag_list`` is empty, the queue has six ids or fewer, no
        id matches any tag, or the lookup fails (the failure is logged).
    """
    if not cid_str or not tag_list:
        return cid_str

    cids = cid_str.split(",")
    if len(cids) <= 6:
        return cid_str

    try:
        sql = "select id,group_concat(diary_id) from " \
              "(select a.diary_id,b.id from src_mimas_prod_api_diary_tags a left join src_zhengxing_api_tag b " \
              "on a.tag_id = b.id where b.tag_type < '4' and a.diary_id in {}) tmp " \
              "where id in {} group by id".format(_sql_in_list(cids), _sql_in_list(tag_list))
        rows = get_mysql_data('172.16.40.170', 4000, 'root', '3SYz54LS9#^9sBvC', 'eagle', sql)
        if not rows:
            return cid_str
        return ",".join(str(cid) for cid in _interleave_by_tag(cids, tag_list, rows))
    except Exception:
        # TODO: also report this to Sentry.
        logging.getLogger(__name__).exception("tag_boost failed; keeping the original queue order")
        return cid_str


def to_data_base(df):
    """Replace the rows of ``esmm_resort_diary_queue`` for the devices in ``df``.

    Deletes every existing row whose ``device_id`` appears in ``df`` and then
    appends ``df`` to the table in chunks of 200 rows.

    Args:
        df: DataFrame with a ``device_id`` column plus the other columns of
            ``jerry_test.esmm_resort_diary_queue``.

    Raises:
        Any exception raised by PyMySQL, SQLAlchemy or pandas; connections are
        released before the exception propagates.

    NOTE(review): the delete and the insert are separate transactions, and
    the connection credentials are hard-coded; both are unchanged from the
    previous implementation.
    """
    # tolist() yields plain Python values rather than numpy scalars.
    device_ids = list(set(df["device_id"].tolist()))
    if device_ids:
        # Deleting the batch's ids directly has the same effect as first
        # reading every device_id in the table and intersecting, without the
        # full-table scan per batch. Driver-side binding also fixes the
        # single-id case, where str(tuple(...)) produced "('x',)".
        placeholders = ", ".join(["%s"] * len(device_ids))
        sql = "delete from esmm_resort_diary_queue where device_id in ({})".format(placeholders)
        db = pymysql.connect(host='172.16.40.170', port=4000, user='root',
                             passwd='3SYz54LS9#^9sBvC', db='jerry_test')
        try:
            with db.cursor() as cursor:
                cursor.execute(sql, device_ids)
            db.commit()
        finally:
            db.close()

    yconnect = create_engine('mysql+pymysql://root:3SYz54LS9#^9sBvC@172.16.40.170:4000/jerry_test?charset=utf8')
    try:
        df.to_sql("esmm_resort_diary_queue", yconnect, schema='jerry_test', if_exists='append',
                  index=False, chunksize=200)
    finally:
        # A new engine is created on every call; release its pooled connections.
        yconnect.dispose()
    print("insert done")


if __name__ == "__main__":
    # Batch job: for every device/city pair exposed yesterday, reorder the
    # native and nearby queues by the user's tags and store the result.
    users_list = get_esmm_users()
    print("user number")
    print(len(users_list))

    if len(users_list) > 0:
        # Must stay a module-level name: get_user_profile() reads `name_tag`
        # as a global.
        name_tag = get_searchworlds_to_tagid()
        columns = ["device_id", "city_id", "native_queue", "nearby_queue",
                   "nation_queue", "megacity_queue", "time"]
        n = 500  # users per database write
        for start in range(0, len(users_list), n):
            total_samples = []
            for uid_city in users_list[start:start + n]:
                device_id, city_id = uid_city[0], uid_city[1]
                tag_list = get_user_profile(device_id)
                queues = get_queues(device_id, city_id)
                if len(queues) > 0:
                    # Each call gets its own copy so that a tag_boost() that
                    # modifies its tag list cannot affect the next call.
                    new_native = tag_boost(queues[0], list(tag_list))
                    new_nearby = tag_boost(queues[1], list(tag_list))

                    insert_time = datetime.datetime.now().strftime('%Y%m%d%H%M')
                    total_samples.append([device_id, city_id, new_native, new_nearby,
                                          queues[2], queues[3], insert_time])

            if total_samples:
                to_data_base(pd.DataFrame(total_samples, columns=columns))
    else:
        print("没有获取到用户")
















