# coding: utf-8

import json

from collections import defaultdict
from math import ceil
from itertools import chain, groupby
from django.db.models import Count
from django.conf import settings
from celery import shared_task

from gm_types.mimas import GRABBING_PLATFORM
from gm_types.gaia import TAG_V3_TYPE
from qa.models.answer import Answer, Question, QuestionTagV3
from communal.cache.push import personalize_push_cache, doris_ctr_cache
from talos.services import TagV3Service
from utils.common import big_qs_iter
from utils.rpc import rpc_client
from utils.group_routine import GroupRoutine

# Batch size used when validating question ids and when flushing per-device
# push lists into the cache.
max_length = 1000
# Batch size used when resolving tag names to tag 3.0 info.
step = 100
# Cache hash: question id -> number of answers.
cache_key_question_answer_count = "mimas:question:answer_count"
# Cache hash: device id -> JSON list of (question_id, tag_name) pairs to push.
cache_key_push_question_by_interesting_tag_v3 = 'demeter:push:push_question_by_interesting_tag_v3'


@shared_task
def update_answers_count_of_question():
    """
    Write the answer count of every question into the cache.

    Online answers (answers from the KYC grabbing platform excluded) are
    counted per question on the slave database and stored, in batches, in the
    ``cache_key_question_answer_count`` hash. Afterwards the hash fields of
    all questions posted by the hard-coded operation accounts are deleted.
    :return:
    """
    batch_size = 50

    answer_rows = list(
        Answer.objects.using(settings.SLAVE_DB_NAME)
        .filter(is_online=True)
        .exclude(platform=GRABBING_PLATFORM.KYC)
        .values('question_id')
        .annotate(answer_cnt=Count('id'))
    )
    for start in range(0, len(answer_rows), batch_size):
        batch = answer_rows[start:start + batch_size]
        personalize_push_cache.hmset(
            cache_key_question_answer_count,
            {str(row['question_id']): row['answer_cnt'] for row in batch},
        )

    # Drop the questions posted by the operation accounts listed below.
    operation_question_ids = list(
        Question.objects.using(settings.SLAVE_DB_NAME)
        .filter(user__in=[22, 29075872, 3161])
        .values_list('id', flat=True)
    )
    for start in range(0, len(operation_question_ids), batch_size):
        personalize_push_cache.hdel(
            cache_key_question_answer_count,
            *operation_question_ids[start:start + batch_size]
        )


# <---------- 等新逻辑验证通过后废弃 --------------->
def get_tag_v3_ids_by_tag_name(tag_name_list):
    """
    Fetch tag 3.0 info for the given tag names, ordered like the input names.

    :param tag_name_list: list of tag names; its order defines the result order
    :return: the tag info dicts returned by the ``api/tag_v3/gets_info`` RPC
        (each is read here through its ``name`` key), sorted by the position
        of their name in ``tag_name_list``; ``[]`` when the RPC returns
        nothing
    """
    tag_v3_info_list = rpc_client['api/tag_v3/gets_info'](
        tag_names=tag_name_list,
        tag_type=TAG_V3_TYPE.NORMAL
    ).unwrap()
    if not tag_v3_info_list:
        return []

    # Position of the first occurrence of each requested name. A dict lookup
    # replaces a linear ``list.index`` scan per sorted element.
    name_position = {}
    for position, name in enumerate(tag_name_list):
        name_position.setdefault(name, position)

    # A name that was not requested sorts after all requested ones instead of
    # raising ValueError (which ``list.index`` would do).
    unknown_position = len(tag_name_list)
    tag_v3_info_list.sort(
        key=lambda info: name_position.get(info['name'], unknown_position)
    )
    return tag_v3_info_list


def get_question_id_by_tag_v3(tag_v3_id):
    """
    Return the ids of the online questions linked to one tag 3.0 id.

    :param tag_v3_id: id of the tag 3.0 to look up in ``QuestionTagV3``
    :return: list of ids of the linked questions that have ``is_online=True``
    """
    # Ids of every question linked to the tag, materialised before the
    # second query.
    tagged_question_ids = list(
        QuestionTagV3.objects.filter(tag_v3_id=tag_v3_id).values_list(
            'question_id', flat=True
        )
    )

    # Keep only the questions that are still online.
    online_questions = Question.objects.filter(
        id__in=tagged_question_ids,
        is_online=True
    )
    return list(online_questions.values_list('id', flat=True))
# <---------- 等新逻辑验证通过后废弃 --------------->


def check_tag_valid(tag_name):
    """
    Tell whether a tag name may be used for picking push content.

    A name is rejected when it is purely numeric, contains a ``-`` or is one
    of the two placeholder names listed below.

    :param tag_name: tag name string
    :return: ``True`` when the name is usable, ``False`` otherwise
    """
    placeholder_names = ("不感兴趣", "没有想法")
    is_rejected = (
        tag_name.isnumeric()
        or "-" in tag_name
        or tag_name in placeholder_names
    )
    return not is_rejected


# <---------- 等新逻辑验证通过后废弃 --------------->
@shared_task
def get_push_questions_for_device(device_id, tag_name_list):
    """
    Cache the top 30 questions to push to a device, based on its tag names.

    Tags are processed in the order returned by
    ``get_tag_v3_ids_by_tag_name``. For each tag, its online questions that
    have a non-empty cached answer count are taken in descending answer-count
    order. The first 30 distinct ``(question_id, tag_name)`` pairs are stored
    as JSON under ``device_id`` in the
    ``cache_key_push_question_by_interesting_tag_v3`` hash.

    :param device_id: device the push list is built for
    :param tag_name_list: tag 3.0 names of the device
    :return:
    """
    cache_key = cache_key_push_question_by_interesting_tag_v3
    # Keep 30 candidates so that later read-filtering still leaves content.
    push_limit = 30
    push_info_list = []
    # Pairs already collected. A set gives O(1) duplicate checks while the
    # list keeps the order (deduplicating through set() alone would lose it).
    seen_pairs = set()

    for tag_v3_info in get_tag_v3_ids_by_tag_name(tag_name_list):
        tag_name = tag_v3_info['name']
        question_id_list = get_question_id_by_tag_v3(tag_v3_info['id'])
        if not question_id_list:
            continue

        answer_counts = personalize_push_cache.hmget(
            cache_key_question_answer_count, question_id_list
        )
        # Drop questions without a cached answer count.
        answered = [
            (question_id, answer_count)
            for question_id, answer_count in zip(question_id_list, answer_counts)
            if answer_count
        ]
        # Most answered questions first.
        answered.sort(key=lambda pair: int(pair[1]), reverse=True)

        for question_id, _ in answered:
            push_pair = (question_id, tag_name)
            if push_pair in seen_pairs:
                continue
            seen_pairs.add(push_pair)
            push_info_list.append(push_pair)
            # Only the first ``push_limit`` pairs are ever written.
            if len(push_info_list) >= push_limit:
                break

        if len(push_info_list) >= push_limit:
            break

    personalize_push_cache.hset(
        cache_key, device_id, json.dumps(push_info_list[:push_limit])
    )
# <---------- 等新逻辑验证通过后废弃 --------------->


@shared_task
def record_push_content_of_ctr_device():
    """
    Build the question push list of every device that has a user portrait and
    write the lists into the cache.

    Steps:
    1. Scan the portrait keys in ``doris_ctr_cache`` and keep, per device, up
       to 6 valid project tag names with the highest scores.
    2. Resolve all collected tag names to tag 3.0 info through
       ``TagV3Service`` (in batches of ``step``).
    3. For those tags, load the linked questions, keep the ones that are
       online and have a non-empty cached answer count, and remember up to 30
       question ids per tag ordered by answer count (descending).
    4. Per device, walk its tags in score order and collect up to 30
       ``(question_id, tag_name)`` pairs; write them as JSON into the
       ``cache_key_push_question_by_interesting_tag_v3`` hash in batches of
       ``max_length`` devices.
    :return:
    """
    # Local import so this block does not depend on the module import list.
    import logging

    logger = logging.getLogger(__name__)

    tag_step = 10  # number of tag ids handled per database round
    push_info_list_nums = 30  # max questions kept per tag and per device

    def get_question_tag_map_from_sql(tag_ids):
        """
        Load the question/tag relations of the given tag ids.
        :param tag_ids: tag 3.0 ids
        :return: list of ``(question_id, tag_v3_id)`` tuples
        """
        tag_question_ids_map = []
        if not tag_ids:
            return tag_question_ids_map

        map_objs = QuestionTagV3.objects.using(settings.SLAVE_DB_NAME).filter(
            tag_v3_id__in=tag_ids
        ).values_list("question_id", "tag_v3_id")

        for item in big_qs_iter(map_objs):
            tag_question_ids_map.append(item)

        return tag_question_ids_map

    def get_online_question_ids(q_ids):
        """
        Return the ids among ``q_ids`` whose question is online.
        :param q_ids: question ids to check
        :return: list of online question ids
        """
        if not q_ids:
            return []

        return list(Question.objects.using(settings.SLAVE_DB_NAME).filter(
            pk__in=q_ids,
            is_online=True
        ).values_list("id", flat=True))

    pre_mapping_relations = {}  # pre-processed data: {device_id: [tag_name1, tag_name2]}
    temporary_storage_tag_names = set()  # every tag name seen on any device

    # scan_key = "doris:user_portrait:tag3:device_id:*"  # old portrait, all users
    scan_key = "doris:user_portrait:experience:tag3:device_id:*"  # new portrait, only users with experience tags
    key_iter = doris_ctr_cache.scan_iter(scan_key)

    # Read the user portraits.
    for _key in key_iter:
        cache_key = str(_key, encoding="utf8")
        device_id = cache_key.split(':')[-1]
        raw_portrait = doris_ctr_cache.get(cache_key)
        if not raw_portrait:
            # The key can expire or be deleted between SCAN and GET; skip it
            # instead of letting json.loads(None) abort the whole task.
            continue
        user_portrait = json.loads(raw_portrait)

        # Sort all project tags by score (descending), drop invalid names and
        # keep the first 6.
        project_tags = user_portrait.get("projects") or {}
        project_tag_name_list = list(
            filter(
                check_tag_valid,
                map(
                    lambda item: item[0],
                    sorted(
                        project_tags.items(),
                        key=lambda item: float(item[1]), reverse=True
                    )
                )
            )
        )[:6]

        pre_mapping_relations[device_id] = project_tag_name_list
        temporary_storage_tag_names.update(set(project_tag_name_list))

    temporary_storage_tag_names = list(temporary_storage_tag_names)
    if not temporary_storage_tag_names:
        return

    # Resolve the tag names to tag info, ``step`` names per call.
    tag_name_map_info_dic = {}
    for start in range(0, len(temporary_storage_tag_names), step):
        _tag_names = temporary_storage_tag_names[start: start + step]
        try:
            _tag_v3_infos = TagV3Service.get_tag_v3_ids_by_tag_names(
                tag_name_list=_tag_names,
                tag_type=TAG_V3_TYPE.NORMAL
            )
            tag_name_map_info_dic.update({
                item["name"]: item for item in _tag_v3_infos
            })
        except Exception:
            # Best effort: one failed batch must not abort the task, but it
            # should leave a trace instead of being silently dropped.
            logger.exception(
                "record_push_content_of_ctr_device: tag lookup failed for %d tag names",
                len(_tag_names)
            )
            continue

    if not tag_name_map_info_dic:
        return

    # Ids of all resolved tags (entries without an id are dropped).
    _all_tag_v3_ids = list(
        filter(
            None,
            map(
                lambda item: item.get("id", 0),
                tag_name_map_info_dic.values()
            )
        )
    )

    # tag id -> valid question ids, and question id -> cached answer count.
    tag_v3_map_valid_question_dic, question_answer_count_dic = defaultdict(list), dict()
    # Questions already known to be offline / online without an answer count,
    # remembered so they are not checked again for later tag batches.
    _offline_question_ids, _zero_answer_qids = set(), set()
    # Process the tags in slices of ``tag_step``.
    for tag_start in range(0, len(_all_tag_v3_ids), tag_step):
        tag_question_map_infos = get_question_tag_map_from_sql(
            _all_tag_v3_ids[tag_start: tag_start + tag_step]
        )

        _all_valid_question_ids = set(question_answer_count_dic.keys())
        _part_question_ids = set(map(lambda item: item[0], tag_question_map_infos))  # all questions linked to this slice

        # Questions whose state is still unknown: not yet classified as valid,
        # offline or without answers.
        _part_need_check_qids = list(
            _part_question_ids
            - _all_valid_question_ids  # valid questions
            - _offline_question_ids   # offline questions
            - _zero_answer_qids  # questions without answers
        )
        if _part_need_check_qids:
            for check_start in range(0, len(_part_need_check_qids), max_length):
                _check_qids = _part_need_check_qids[check_start: check_start + max_length]

                # First check whether the questions are online.
                _online_qids = get_online_question_ids(_check_qids)
                _offline_question_ids.update(set(_check_qids) - set(_online_qids))  # offline questions

                if not _online_qids:
                    continue

                # Then keep the ones that have a non-empty cached answer count.
                valid_question_dic = dict(
                    filter(
                        lambda item: item[1],
                        zip(
                            _online_qids,
                            personalize_push_cache.hmget(
                                cache_key_question_answer_count,
                                _online_qids
                            )
                        )
                    )
                )

                # Record the outcome of this check.
                _finall_valid_qids = set(valid_question_dic.keys())  # valid questions
                _zero_answer_qids.update(set(_online_qids) - _finall_valid_qids)  # online but without answer count

                if not valid_question_dic:
                    continue

                _all_valid_question_ids.update(_finall_valid_qids)
                question_answer_count_dic.update(valid_question_dic)

        # Convert the valid relations of this slice.
        _valid_qids = _part_question_ids & _all_valid_question_ids
        if not _valid_qids:
            continue

        for tag_v3_id, items in groupby(
            sorted(  # sort by tag so groupby sees each tag as one run
                sorted(  # keep valid relations, most answered first
                    filter(lambda item: item[0] in _valid_qids, tag_question_map_infos),
                    key=lambda item: int(question_answer_count_dic.get(item[0]) or 0),
                    reverse=True
                ),
                # sorted() is stable, so the answer-count order survives
                # inside each tag.
                key=lambda item: item[1],
            ),
            key=lambda x: x[1]  # group by tag
        ):

            _q_ids = []
            for qid, _ in items:
                _q_ids.append(qid)
                if len(_q_ids) >= push_info_list_nums:
                    break

            tag_v3_map_valid_question_dic[tag_v3_id].extend(_q_ids)

    if not tag_v3_map_valid_question_dic:
        return

    # Assemble the per-device push lists.
    write_in_cache_infos = {}
    for device_id, tag_name_list in pre_mapping_relations.items():
        # Tag info in the device's own tag order; unresolved names are dropped.
        sorted_tag_infos = list(
            filter(
                None,
                (tag_name_map_info_dic.get(tag_name) for tag_name in tag_name_list)
            )
        )
        if not sorted_tag_infos:
            continue

        push_info_list, _info_nums_status = [], False
        for sort_tag_info in sorted_tag_infos:
            sort_tag_id = sort_tag_info.get("id", 0)
            tag_name = sort_tag_info.get("name", "")

            question_ids = tag_v3_map_valid_question_dic.get(sort_tag_id) or []
            if not question_ids:
                continue

            for qid in question_ids:
                _tuple_key = (qid, tag_name)
                if _tuple_key not in push_info_list:
                    push_info_list.append(_tuple_key)

                if len(push_info_list) >= push_info_list_nums:
                    _info_nums_status = True
                    break

            if _info_nums_status:
                break

        write_in_cache_infos.update({
            device_id: json.dumps(push_info_list[:push_info_list_nums]),
        })

        # Flush to the cache in batches.
        if len(write_in_cache_infos) >= max_length:
            personalize_push_cache.hmset(
                cache_key_push_question_by_interesting_tag_v3,
                write_in_cache_infos
            )
            write_in_cache_infos = {}  # start a new batch

    # Write whatever is left after the last full batch.
    if write_in_cache_infos:
        personalize_push_cache.hmset(
            cache_key_push_question_by_interesting_tag_v3,
            write_in_cache_infos
        )


def for_test(device_id='', tag_name=''):
    """
    Manual smoke-test helper: seed a user portrait for one device, then run
    both cache-building tasks synchronously.

    :param device_id: device to seed; a fixed sample id is used when empty
    :param tag_name: project tag name put into the portrait with score 1.2;
        a fixed sample name is used when empty
    :return:
    """
    device_id = device_id or '139580A6-85B7-41CD-B7C4-92366C23B4F0'
    tag_name = tag_name or '双眼皮'
    params = json.dumps(dict(projects={tag_name: 1.2}))
    # The key must match the pattern scanned by
    # record_push_content_of_ctr_device
    # ("doris:user_portrait:experience:tag3:device_id:*"); otherwise the
    # seeded portrait is never picked up.
    cache_key = "doris:user_portrait:experience:tag3:device_id:{}".format(device_id)
    doris_ctr_cache.set(cache_key, params)
    update_answers_count_of_question()
    record_push_content_of_ctr_device()