#!/usr/bin/env python
# -*- coding: utf-8 -*-

# PUSH 基类
from bs4 import BeautifulSoup
import datetime
from functools import partial
import json

from gm_types.gaia import (
    PLATFORM_CHOICES,
)

from gm_types.push import (
    PUSH_INFO_TYPE,
)
from django.conf import settings

from talos.cache.gaia import push_cache  # TODO 公立出来
from talos.logger import push_logger
from utils.common import big_data_iter

hour = 60 * 60  # number of seconds in one hour
minute = 60  # number of seconds in one minute


def get_datetime_to_str(d, fmt="%Y%m%d%H"):
    """
    Render a date/datetime object as a string.
    :param d: object exposing ``strftime`` (e.g. ``datetime.datetime``)
    :param fmt: strftime pattern; the default yields year, month, day and hour
    :return: the formatted string
    """
    formatted = d.strftime(fmt)
    return formatted


class ParamsBuilderBase(object):
    """
    Helpers for assembling push payloads.
    """
    @staticmethod
    def _build_params(**kwargs):
        """
        Assemble the push parameter dict from keyword arguments.

        A falsy ``platform`` falls back to [ANDROID, IPHONE] and a falsy
        ``push_info_type`` falls back to ``PUSH_INFO_TYPE.GM_PROTOCOL``.
        :param kwargs: recognised keys are user_ids, platform, alert,
            push_info_type, push_url, push_type, labels, eta and others
        :return: dict with the keys user_ids, platform, alert, extra,
            push_type, labels, eta and others
        """
        platforms = kwargs.get("platform")
        if not platforms:
            platforms = [PLATFORM_CHOICES.ANDROID, PLATFORM_CHOICES.IPHONE]

        info_type = kwargs.get("push_info_type")
        if not info_type:
            info_type = PUSH_INFO_TYPE.GM_PROTOCOL

        # The same url is exposed under both spellings of the key.
        target_url = kwargs.get("push_url", "")
        extra_payload = {
            "type": info_type,
            "msgType": 4,
            "pushUrl": target_url,
            "push_url": target_url,
        }

        built = {}
        built["user_ids"] = kwargs.get("user_ids", [])
        built["platform"] = platforms
        built["alert"] = kwargs.get("alert", "")
        built["extra"] = extra_payload
        built["push_type"] = kwargs.get("push_type", "")
        built["labels"] = kwargs.get("labels", {})
        built["eta"] = kwargs.get("eta", None)
        built["others"] = kwargs.get("others", None)
        return built

    @staticmethod
    def truncate_push_content(rich_text, limit=30, extra="..."):
        """
        Strip markup from ``rich_text`` and shorten it for a push message.
        :param rich_text: markup text; a falsy value is treated as empty text
        :param limit: maximum length of the returned string
        :param extra: suffix appended when the text has to be cut
        :return: the plain text when it fits within ``limit``; otherwise the
            text cut so that text + ``extra`` is exactly ``limit`` long.
            NOTE(review): when the text is too long and ``extra`` is empty or
            not shorter than ``limit``, an empty string is returned -- confirm
            that dropping the content is intended.
        """
        plain = BeautifulSoup(rich_text, "lxml").get_text() if rich_text else ""

        # Nothing to cut: empty text or text already within the limit.
        if not plain or len(plain) <= limit:
            return plain

        # Too long: cut only when the suffix leaves room for some text.
        if extra and len(extra) < limit:
            keep = limit - len(extra)
            return "{}{}".format(plain[:keep], extra)

        return ""


class PushServiceBase(object):
    """
    Base class for push services: allowed push window, per-user rate limiting
    backed by ``push_cache`` hashes, and construction of push parameters.

    Subclasses must provide ``_default_cache_name`` and ``push_handlers``.
    """
    # TODO settings
    # Pushes are allowed while start_push_hour <= current hour < end_push_hour.
    start_push_hour = settings.PUSH_RATE_LIMIT_SETTINGS.get("start_push_hour", 9)
    end_push_hour = settings.PUSH_RATE_LIMIT_SETTINGS.get("end_push_hour", 23)

    # Rate limit -- when True the daily counter hash is not prefixed with
    # _default_cache_name, i.e. it is shared instead of being per service.
    ratelimit_whole_nums_common_use = False
    # Rate limit -- pushes per user per day (real-time and scheduled pushes both consume it).
    ratelimit_whole_rate_limit = 4
    # Rate limit -- pushes of one type per user per hour.
    ratelimit_unit_times = 2

    # Configuration that subclasses are required to set.
    _default_cache_name = ""
    # Mapping action_type -> handler dict; build_push_params reads the keys
    # "params_builder", "limit_rate", "limit_whole" and "rate_limit_unit_times".
    push_handlers = None

    @staticmethod
    def get_now_datetime():
        """
        Return the current local datetime.
        :return: datetime.datetime
        """
        return datetime.datetime.now()

    @classmethod
    def in_push_limit_time(cls, now):
        """
        Tell whether ``now`` falls inside the allowed push window.
        :param now: datetime whose ``hour`` is checked
        :return: True when start_push_hour <= now.hour < end_push_hour
        """
        return cls.start_push_hour <= now.hour < cls.end_push_hour

    @staticmethod
    def current_times(now):
        """
        Format ``now`` with hour precision and with day precision.
        :param now: datetime to format
        :return: dict with "hour_str" (%Y%m%d%H) and "date_str" (%Y%m%d)
        """
        return {
            "hour_str": get_datetime_to_str(now),
            "date_str": get_datetime_to_str(now, fmt="%Y%m%d"),
        }

    @classmethod
    def expire_time(cls, unit_time=60 * 60, need_datetime=False):
        """
        Get the next ``unit_time`` boundary of the push window.

        Boundaries are ``window start + i * unit_time`` for every i that keeps
        the boundary inside today's window (the closing time included when it
        falls on a boundary). The first boundary strictly later than now is
        picked. When today's window has no boundary left, the start of
        tomorrow's window is used so the result is never in the past.
        :param unit_time: slot length in seconds, one hour by default
        :param need_datetime: False returns the remaining seconds (int),
            True returns the boundary as a datetime
        :return: int seconds or datetime.datetime
        """
        # Read the clock once so the date and the time cannot disagree around midnight.
        now = cls.get_now_datetime()
        window_start = datetime.datetime.combine(now.date(), datetime.time(hour=cls.start_push_hour))
        # The first positional argument of timedelta is *days*; the window
        # length is expressed in hours and must be passed by keyword.
        window_seconds = int(
            datetime.timedelta(hours=cls.end_push_hour - cls.start_push_hour).total_seconds()
        )

        # Used when today's window is already over.
        boundary = window_start + datetime.timedelta(days=1)
        for i in range(window_seconds // unit_time + 1):
            candidate = window_start + datetime.timedelta(seconds=i * unit_time)
            if now < candidate:
                boundary = candidate
                break

        if need_datetime:
            return boundary
        return int((boundary - now).total_seconds())

    @staticmethod
    def _get_valid_user_ids(user_ids, rate_limit_nums, cache_name, cache_key):
        """
        Keep the users whose counter is below the limit and bump their counters.

        Counters live in the ``cache_name`` hash of ``push_cache``. They are
        read with hmget and written back with hmset, so the update is a
        read-modify-write and not an atomic increment.
        :param user_ids: user ids to check
        :param rate_limit_nums: a user is dropped once the stored count is >= this value
        :param cache_name: str, name of the hash holding the counters
        :param cache_key: callable building the hash field from a user id
        :return: list of user ids that passed (their counters were increased by one)
        """
        valid_user_ids = []

        # big_data_iter presumably yields the ids in chunks of at most fetch_num -- TODO confirm.
        # A dedicated loop variable keeps the ``user_ids`` argument from being shadowed.
        for batch in big_data_iter(user_ids, fetch_num=200):
            cache_keys = [cache_key(user_id) for user_id in batch]
            stored_counts = push_cache.hmget(cache_name, cache_keys)  # bulk read

            new_counts = {}
            for uid, key, raw in zip(batch, cache_keys, stored_counts):
                sent = int(raw) if raw else 0
                # A missing/empty counter always passes; only a stored count can block.
                if raw and sent >= rate_limit_nums:
                    continue

                valid_user_ids.append(uid)
                new_counts[key] = sent + 1  # count + 1

            if new_counts:  # bulk write
                push_cache.hmset(cache_name, new_counts)

        return valid_user_ids

    # < ----------- per-user daily push count: begin ------------ >
    @staticmethod
    def _get_user_rate_limit_whole_key(user_id):
        """
        Build the hash field holding a user's daily push count.
        :param user_id:
        :return: str
        """
        return "push_user_id_{}".format(user_id)

    @classmethod
    def _get_user_rate_limit_whole_cache_name(cls, times):
        """
        Build the name of the hash holding daily push counts, creating it if needed.

        When the hash does not exist yet, a placeholder field is incremented so
        the key exists and a 26 hour expiry is attached to it.
        :param times: formatted date string
        :return: str, hash name
        """
        _cache_name = "{d_cache_name}_rate_limit_whole_date_{date_str}".format(
            d_cache_name="" if cls.ratelimit_whole_nums_common_use else cls._default_cache_name,
            date_str=times
        )
        if not push_cache.exists(_cache_name):
            push_cache.hincrby(_cache_name, "rate_limit_whole")
            push_cache.expire(_cache_name, hour * 26)

        return _cache_name

    @classmethod
    def _get_unlimit_rate_whole_users(cls, user_ids, time_str, rate_limit_whole, **kwargs):
        """
        Filter users by the daily limit and count this push for those who pass.
        :param user_ids:
        :param time_str: formatted date string selecting the daily hash
        :param rate_limit_whole: daily limit
        :param kwargs: only "sole_sign" is read, for logging
        :return: list of user ids still below the daily limit
        """
        _cache_name = cls._get_user_rate_limit_whole_cache_name(time_str)
        cache_key_func = cls._get_user_rate_limit_whole_key

        valid_user_ids = cls._get_valid_user_ids(user_ids, rate_limit_whole, _cache_name, cache_key_func)

        push_logger.info(json.dumps({
            "step": 4,
            "sole_sign": kwargs.get("sole_sign", ""),
            "resume": "get unlimit rate whole users.",
            "rate_limit_whole": rate_limit_whole,
            "_cache_name": _cache_name,
            "row_user_ids": user_ids,
            "valid_user_ids": valid_user_ids,
        }))

        return valid_user_ids

    # < ----------- per-user daily push count: end ------------ >

    # < ----------- per-user push count per unit of time: begin ------------ >
    @classmethod
    def _get_user_times_control_cache_name(cls, times):
        """
        Build the name of the hash holding per-period push counts, creating it if needed.
        PS: meant to be refined in subclasses.

        When the hash does not exist yet, a placeholder field is incremented so
        the key exists and a 24 hour expiry is attached to it.
        :param times: formatted time string identifying the period
        :return: str, hash name
        """
        _cache_name = "{}_times_{}".format(cls._default_cache_name, times)
        if not push_cache.exists(_cache_name):
            push_cache.hincrby(_cache_name, cls._default_cache_name)
            push_cache.expire(_cache_name, hour * 24)

        return _cache_name

    @classmethod
    def _get_user_times_control_cache_key(cls, user_id, action_type):
        """
        Build the hash field holding a user's per-period count for one push type.
        PS: meant to be refined in subclasses.
        :param user_id:
        :param action_type:
        :return: str
        """
        return "{}:user_id_{}_action_type_{}".format(
            cls._default_cache_name,
            user_id,
            action_type
        )

    @classmethod
    def _get_times_unlimit_control_user_ids(cls, user_ids, action_type, time_str, rate_limit, **kwargs):
        """
        Filter users by the per-period limit and count this push for those who pass.
        :param user_ids: list of user ids
        :param action_type: push type
        :param time_str: formatted time string selecting the period hash
        :param rate_limit: allowed number of pushes in the period
        :param kwargs: only "sole_sign" is read, for logging
        :return: list of user ids still below the per-period limit
        """
        _cache_name = cls._get_user_times_control_cache_name(time_str)
        cache_key_func = partial(cls._get_user_times_control_cache_key, action_type=action_type)

        valid_user_ids = cls._get_valid_user_ids(
            user_ids,
            rate_limit,
            _cache_name,
            cache_key_func
        )

        push_logger.info(json.dumps({
            "step": 3,
            "sole_sign": kwargs.get("sole_sign", ""),
            # Same key as the other log steps (was misspelled "renume").
            "resume": "get times unlimit rate users",
            "action_type": action_type,
            "rate_limit": rate_limit,
            "cache_time": _cache_name,
            "row_user_ids": user_ids,
            "valid_user_ids": valid_user_ids,
        }))

        return valid_user_ids

    # < ----------- per-user push count per unit of time: end ------------ >

    # < ----------- users that are not rate limited: begin ------------ >
    @classmethod
    def get_unlimit_users(cls, user_ids, action_type, **kwargs):
        """
        Return the users that are not rate limited.
        PS: the per-period limit is checked first, then the daily limit.

        Users passing the per-period check have that counter increased even if
        the daily check rejects them afterwards.
        :param user_ids:
        :param action_type:
        :param kwargs: sole_sign, limit_rate, hour_str, rate_limit_hour,
            limit_whole, date_str and rate_limit_whole are read
        :return: list of user ids allowed to receive the push
        """
        sole_sign = kwargs.get("sole_sign", "")

        # First: per-type, per-period limit (only when its switch is on).
        if kwargs.get("limit_rate", False):
            unlimit_times_uids = cls._get_times_unlimit_control_user_ids(
                user_ids=user_ids,
                action_type=action_type,
                time_str=kwargs.get("hour_str", ""),
                rate_limit=kwargs.get("rate_limit_hour"),
                sole_sign=sole_sign
            )
        else:
            unlimit_times_uids = user_ids

        # Then: daily total, only when its switch is on and somebody is left.
        if kwargs.get("limit_whole", False) and unlimit_times_uids:
            unlimit_uids = cls._get_unlimit_rate_whole_users(
                unlimit_times_uids,
                time_str=kwargs.get("date_str", ""),
                rate_limit_whole=kwargs.get("rate_limit_whole"),
                sole_sign=sole_sign,
            )
        else:
            unlimit_uids = unlimit_times_uids

        push_logger.info(json.dumps({
            "step": 5,
            "sole_sign": sole_sign,
            "resume": "get unlimit users.",
            "action_type": action_type,
            "other_params": kwargs,
            "raw_user_ids": user_ids,
            "unlimit_user_ids": unlimit_uids,
        }))

        return unlimit_uids

    # < ----------- users that are not rate limited: end ------------ >

    # < ----------- build push params: begin ------------ >
    @classmethod
    def build_push_params(cls, user_ids, action_type, **kwargs):
        """
        Build the push parameters for ``action_type`` restricted to non-limited users.

        The handler registered in ``push_handlers`` supplies the params builder
        and the rate-limit switches. The builder is called with every keyword
        argument received here (including "time_str_dic" when present).
        :param user_ids:
        :param action_type:
        :param kwargs: forwarded to the params builder; "sole_sign" and
            "time_str_dic" (dict with "date_str"/"hour_str") are also read here
        :return: the params dict with "user_ids" replaced by the allowed users,
            or None when there is no handler, no builder, no params or no
            allowed user
        """
        sole_sign = kwargs.get("sole_sign", "")
        push_logger.info(json.dumps({
            "step": 2,
            "sole_sign": sole_sign,
            "resume": "will build push params",
            "user_ids": user_ids,
            "action_type": action_type,
            "other_params": kwargs,
        }))

        event_handler = cls.push_handlers.get(action_type)
        if not event_handler:
            return None

        params_builder = event_handler.get("params_builder")
        if not params_builder:
            return None

        params = params_builder(user_ids=user_ids, **kwargs)
        if not params:
            return None

        # kwargs is not used past this point, so a plain read is enough.
        times_str_dic = kwargs.get("time_str_dic", {}) or cls.current_times(cls.get_now_datetime())
        push_user_ids = cls.get_unlimit_users(
            user_ids,
            action_type,
            sole_sign=sole_sign,
            date_str=times_str_dic.get("date_str", ""),
            hour_str=times_str_dic.get("hour_str", ""),
            limit_rate=event_handler.get("limit_rate", False),  # switch: per-period limit
            limit_whole=event_handler.get("limit_whole", False),  # switch: daily limit
            rate_limit_whole=cls.ratelimit_whole_rate_limit,  # daily limit
            rate_limit_hour=(event_handler.get("rate_limit_unit_times") or cls.ratelimit_unit_times)  # per-period limit
        )

        if not push_user_ids:
            return None

        params.update({
            "user_ids": push_user_ids,
        })

        return params
    # < ----------- build push params: end ------------ >
