import pymysql
import datetime
import json
import redis


def get_esmm_users():
    """Return distinct ``(device_id, city_id)`` rows for yesterday's ``stat_date``.

    Queries ``data_feed_exposure_precise`` in the ``jerry_prod`` database for
    rows whose ``stat_date`` equals yesterday's local date (``YYYY-MM-DD``),
    printing that date first.

    Returns:
        list of row tuples ``(device_id, city_id)`` as returned by pymysql.
    """
    stat_date = (datetime.date.today() - datetime.timedelta(days=1)).strftime("%Y-%m-%d")
    print(stat_date)
    # Parameterized instead of str.format so the driver does the quoting.
    sql = ("select distinct device_id,city_id from data_feed_exposure_precise "
           "where stat_date = %s")
    # NOTE(review): credentials are hard-coded; consider moving them to config.
    db = pymysql.connect(host='172.16.40.158', port=4000, user='root',
                         passwd='3SYz54LS9#^9sBvC', db='jerry_prod')
    try:
        with db.cursor() as cursor:
            cursor.execute(sql, (stat_date,))
            return list(cursor.fetchall())
    finally:
        # Always release the connection, even if the query fails.
        db.close()


def get_user_profile(device_id="9C5E7C73-380C-4623-8F48-A64C8034E315", tag_name_map=None, top_n=5):
    """Return the highest-scoring portrait tags for a device.

    Reads the JSON list stored in redis (db 2) under
    ``user:portrait_tags:cl_id:<device_id>``. Each entry is expected to carry
    ``type``, ``content`` and ``score`` keys:

    - entries with ``type == "tag"`` contribute ``content`` directly;
    - other entries contribute ``tag_name_map[content]`` when ``content`` is a
      key of ``tag_name_map``, and are skipped otherwise.

    Args:
        device_id: device identifier used in the redis key.
        tag_name_map: mapping from ``content`` value to tag. When ``None``, the
            module-level ``name_tag`` mapping is used if one has been defined
            (e.g. from ``get_searchworlds_to_tagid()``), otherwise ``{}``.
        top_n: maximum number of tags to return.

    Returns:
        Up to ``top_n`` tags, sorted by descending score. Returns ``[]`` when the
        redis key does not exist.
    """
    if tag_name_map is None:
        # Backward compatibility: the original read a module global ``name_tag``
        # that is only assigned by the (commented-out) driver script.
        tag_name_map = globals().get("name_tag") or {}

    r = redis.Redis(host="172.16.40.135", port=5379, password="", db=2)
    raw = r.get("user:portrait_tags:cl_id:" + str(device_id))
    if raw is None:
        # Missing key: the original crashed calling .decode() on None.
        return []
    portrait = json.loads(raw.decode('utf-8'))

    tag_score = {}
    for item in portrait:
        if item["type"] == "tag":
            tag_score[item["content"]] = item["score"]
        elif item["content"] in tag_name_map:
            tag_score[tag_name_map[item["content"]]] = item["score"]

    ranked = sorted(tag_score.items(), key=lambda kv: kv[1], reverse=True)
    # Slicing fixes the original `for i in range(5): i[0]` TypeError.
    return [tag for tag, _ in ranked[:top_n]]


def get_searchworlds_to_tagid():
    """Map tag names to tag ids from ``zhengxing.api_tag``.

    Only rows with ``is_online = 1`` and ``tag_type < 4`` are read. If two rows
    share a name, the later row's id wins.

    Returns:
        dict ``{name: id}``. On any error the exception is printed and an empty
        dict is returned, so callers can still use ``name in mapping``. The
        original returned ``None`` in that case.
    """
    sql = 'select id, name from api_tag where is_online = 1 and tag_type < 4'
    try:
        db = pymysql.connect(host='172.16.30.141', port=3306, user='work',
                             passwd='BJQaT9VzDcuPBqkd', db='zhengxing')
        try:
            with db.cursor() as cursor:
                cursor.execute(sql)
                rows = cursor.fetchall()
        finally:
            # Close the connection even when the query fails.
            db.close()
    except Exception as e:
        # Best-effort lookup: report the failure and fall back to no mapping.
        print(e)
        return {}
    return {name: tag_id for tag_id, name in rows}


def get_queues(device_id,city_id):
    """Fetch the stored diary queues for one device/city pair.

    Reads the first matching row of ``esmm_device_diary_queue`` in the
    ``jerry_test`` database.

    Returns:
        ``[native_queue, nearby_queue, nation_queue, megacity_queue]`` for the
        matching row, or ``[]`` if no row matches.
    """
    sql = ("select native_queue, nearby_queue, nation_queue, megacity_queue "
           "from esmm_device_diary_queue where device_id = %s and city_id = %s")
    db = pymysql.connect(host='172.16.40.158', port=4000, user='root',
                         passwd='3SYz54LS9#^9sBvC', db='jerry_test')
    try:
        with db.cursor() as cursor:
            # str() keeps the original string comparison ('{}' formatting) while
            # letting the driver escape the values.
            cursor.execute(sql, (str(device_id), str(city_id)))
            row = cursor.fetchone()
    finally:
        db.close()
    # fetchone() returns None when nothing matches; the original called len(None).
    return list(row) if row else []

def _fetch_tag_diaries(cids, tag_ids):
    """Return ``(tag_id, "diary,diary,...")`` rows for diaries in *cids* tagged with any of *tag_ids*."""
    # One placeholder per value: avoids the invalid "(x,)" that tuple() renders
    # for single-element lists, and lets the driver escape the values.
    cid_marks = ", ".join(["%s"] * len(cids))
    tag_marks = ", ".join(["%s"] * len(tag_ids))
    sql = ("select id,group_concat(diary_id) from "
           "(select a.diary_id,b.id from src_mimas_prod_api_diary_tags a left join src_zhengxing_api_tag b "
           "on a.tag_id = b.id where b.tag_type < '4' and a.diary_id in ({})) tmp "
           "where id in ({}) group by id").format(cid_marks, tag_marks)
    db = pymysql.connect(host='172.16.40.158', port=4000, user='root', passwd='3SYz54LS9#^9sBvC', db='eagle')
    try:
        with db.cursor() as cursor:
            cursor.execute(sql, list(cids) + list(tag_ids))
            return cursor.fetchall()
    finally:
        db.close()


def _round_robin(groups):
    """Yield one item from each group in turn until every group is exhausted."""
    longest = max((len(g) for g in groups), default=0)
    for idx in range(longest):
        for group in groups:
            if idx < len(group):
                yield group[idx]


def tag_boost(cid_str, tag_list=None):
    """Re-rank a comma-separated diary-id queue by interleaving tagged diaries.

    Looks up (in the ``eagle`` database) which diaries in ``cid_str`` carry one
    of the tags in ``tag_list``. It then rebuilds the queue round-robin:

    - one diary from each tag group, in ``tag_list`` order;
    - then one diary that matched none of the tags;
    - repeated until every group is used up.

    A diary appearing in several groups keeps only its first position.

    Args:
        cid_str: comma-separated diary ids. ``None`` or ``""`` is returned as-is.
        tag_list: tag ids to boost, in priority order. Defaults to
            ``[15, 21, 22, 85, 86]``. The list is never modified.

    Returns:
        A list of diary-id strings when boosting happened. Otherwise ``cid_str``
        unchanged: when it is empty/None, has 6 or fewer ids, ``tag_list`` is
        empty, or no diary carries any of the tags.
        NOTE(review): the two return types differ (list vs str); callers must
        handle both.
    """
    if tag_list is None:
        # Fresh default per call; the original mutated a shared default list.
        tag_list = [15, 21, 22, 85, 86]
    # The original `is not None or != ""` was always true and crashed on None.
    if not cid_str:
        return cid_str
    cids = cid_str.split(",")
    if len(cids) <= 6 or not tag_list:
        return cid_str

    rows = _fetch_tag_diaries(cids, tag_list)
    if not rows:
        return cid_str

    tag_cids = {}
    tagged = set()
    for tag_id, diary_csv in rows:
        group = diary_csv.split(",")
        tag_cids[tag_id] = group
        tagged.update(group)
    right_cids = [cid for cid in cids if cid not in tagged]

    # Groups follow tag_list priority. Any returned tag id that does not compare
    # equal to a tag_list entry is appended after them; the original loop would
    # spin forever in that case. Untagged diaries always come last in a round.
    order = [t for t in dict.fromkeys(tag_list) if t in tag_cids]
    order += [t for t in tag_cids if t not in order]
    groups = [tag_cids[t] for t in order] + [right_cids]

    # dict.fromkeys de-duplicates while keeping first-occurrence order.
    return list(dict.fromkeys(_round_robin(groups)))



def to_data_base(df,table_name = "tag_boost_device_diary_queue"):
    """Placeholder with no implementation yet; accepts its arguments and returns ``None``.

    TODO(review): presumably intended to persist ``df`` into ``table_name``;
    confirm the expected behavior before implementing.
    """
    return None


def make_sample(uid,city_id,native_queue,nearby_queue,megacity_queue,nation_queue):
    """Placeholder with no implementation yet; ignores its arguments and returns ``None``.

    TODO(review): presumably meant to assemble one output record from a user's
    queues; confirm the expected shape before implementing.
    """
    return None


if __name__ == "__main__":
    # TODO: intended pipeline (draft, not yet wired up):
    #   users_list = get_esmm_users()
    #   name_tag = get_searchworlds_to_tagid()
    #   for each (device_id, city_id) in users_list:
    #       tags = get_user_profile(device_id)
    #       queues = get_queues(device_id, city_id)
    #       re-rank the native / nearby queues with tag_boost(queue, tags)
    #       build a record with make_sample(...) and collect it
    #   persist the collected records with to_data_base(...) and to the KV store.
    # NOTE(review): the draft passed an undefined `tag_score` to tag_boost;
    # presumably the tag list from get_user_profile was meant — confirm.

    # Manual smoke test of tag_boost on a fixed queue.
    cid_str = "16473983,16296886,16199213,16193883,16419499,16372783,16430184,16617593,16498902,16238415,16214258,15715721,16213338,15349114,14091428,16268804,15485655,16448547,16179842,16685025,16612412,16683132,15646229,16482213,16485831,16436136,16353856,16400696,16193006,16294202,16393228,16716816,16713343,16780702,16107140,16647027,16112786,16503037,16372681,16207971,16179934,16480641,16295094,16204980,16317847,16434907,16117929,15633591,16116818"
    print(tag_boost(cid_str))