# coding=utf-8

from __future__ import absolute_import, unicode_literals, print_function

from .node import PreCalcNode, AdhocNode
from django.conf import settings
from math import log
from random import Random
from django.utils import timezone

from search.utils.es import get_es


class CreatedTimeNode(AdhocNode):
    """Ad-hoc score node that decays with the age of ``obj.created_time``.

    The (log-space) score is ``beta * log(1 + age / alpha)`` where ``age`` is
    the number of seconds between ``ext_params.now`` and ``obj.created_time``.
    With the default parameters the score is 0 for a brand-new object and
    about ``log(0.1)`` after 30 days.
    """

    def __init__(self, weight=1.0, alpha=29624104.0, beta=-27.452):
        super(CreatedTimeNode, self).__init__(weight=weight)
        self.alpha = alpha  # time scale of the decay, in seconds
        self.beta = beta  # decay exponent; negative => older objects score lower
        self.node_name = 'created_time'

    def _calc_score(self, obj, ext_params):
        """Return ``beta * log(1 + age_seconds / alpha)`` for ``obj``."""
        age_seconds = (ext_params.now - obj.created_time).total_seconds()
        return self.beta * log(1.0 + age_seconds / self.alpha)

    @classmethod
    def solve_params(cls, t1, y1, t2, y2):
        '''
        Fit ``(alpha, beta)`` so that the curve passes through two points.

        Model: log(y) = beta*log(1+t/alpha), with t1, t2 in seconds and
        y1, y2 > 0 (``math.log`` raises ValueError otherwise).

        ``alpha`` is the root of
            f(alpha) = log(y1)*log(1+t2/alpha) - log(y2)*log(1+t1/alpha)
        i.e. the alpha for which both points imply the same beta. It is found
        with 99 Newton iterations starting at alpha = 1.0 (no convergence
        check). ``beta`` then follows from the first point.

        Returns:
            (alpha, beta) tuple of floats.
        '''
        def f(alpha, t1, lgy1, t2, lgy2):
            return lgy1*log(1+t2/alpha)-lgy2*log(1+t1/alpha)

        def fd(alpha, t1, lgy1, t2, lgy2):
            # Analytic derivative of f with respect to alpha.
            return (t1*(t2+alpha)*lgy2-t2*(t1+alpha)*lgy1)/(alpha*(t1+alpha)*(t2+alpha))

        lgy1 = log(y1)
        lgy2 = log(y2)
        alpha = 1.0
        for _ in range(99):
            alpha = alpha-f(alpha, t1, lgy1, t2, lgy2)/fd(alpha, t1, lgy1, t2, lgy2)
        # From log(y1) = beta*log(1+t1/alpha). The previous "(1+lgy1)"
        # numerator made the fitted curve miss both input points.
        beta = lgy1/log(1+t1/alpha)

        return alpha, beta


class UserRelatedTagNode(AdhocNode):
    """Ad-hoc score node that favours objects carrying one of the user's tags.

    Scores 0.0 when ``obj`` has an online tag whose id is listed in
    ``ext_params.user_related_tag_ids``. Otherwise, including when that list
    is empty, it scores ``log(not_related_value)``.
    """

    def __init__(self, weight=1.0, not_related_value=0.8):
        super(UserRelatedTagNode, self).__init__(weight=weight)
        # Kept in log space so it can be returned directly as a score.
        self.not_related_value = log(not_related_value)
        self.node_name = 'user_related_tag'

    def _calc_score(self, obj, ext_params):
        wanted_tag_ids = ext_params.user_related_tag_ids
        # Short-circuit: no DB query is issued when the user has no tags.
        is_related = bool(wanted_tag_ids) and obj.problemtag_set.select_related('tag').filter(
            tag__is_online=True,
            tag__id__in=wanted_tag_ids,
        ).exists()
        return 0.0 if is_related else self.not_related_value


class ViewCountNode(PreCalcNode):
    """Pre-calculated score node driven by a topic's view count.

    Below ``thres`` views the score is ``log(base + coeff * views ** pw)``;
    at or above ``thres`` it is 0.0. With the defaults the formula reaches
    approximately log(1.0) = 0 at the threshold (0.2 + 0.08 * 1000 ** (1/3)).
    """

    def __init__(self, weight=1.0, thres=1000, base=0.2, coeff=0.08, pw=1.0/3):
        super(ViewCountNode, self).__init__(weight=weight)
        self.thres, self.base, self.coeff, self.pw = thres, base, coeff, pw
        self.node_name = 'view_count'

    def _calc_score(self, obj):
        views = int(obj.view_amount)
        if views < self.thres:
            return log(self.base + self.coeff * pow(views, self.pw))
        # Saturated: popular enough, no further penalty.
        return 0.0


class VoteCountNode(PreCalcNode):
    """Pre-calculated score node driven by a topic's vote count.

    Scores ``log(base + coeff * votes ** pw)`` while ``votes < thres`` and
    0.0 once ``votes >= thres``. With the defaults the formula reaches
    log(0.2 + 0.08 * 10) = 0 exactly at the threshold of 100.
    """

    def __init__(self, weight=1.0, thres=100, base=0.2, coeff=0.08, pw=0.5):
        super(VoteCountNode, self).__init__(weight=weight)
        self.node_name = 'vote_count'
        self.thres = thres
        self.base = base
        self.coeff = coeff
        self.pw = pw

    def _calc_score(self, obj):
        votes = int(obj.vote_amount)
        return 0.0 if votes >= self.thres else log(self.base + self.coeff * votes ** self.pw)


class ReplyCountNode(PreCalcNode):
    """Pre-calculated score node driven by a topic's reply count.

    Same shape as the vote-count node: ``log(base + coeff * replies ** pw)``
    below ``thres`` and 0.0 at or above it.
    """

    def __init__(self, weight=1.0, thres=100, base=0.2, coeff=0.08, pw=0.5):
        super(ReplyCountNode, self).__init__(weight=weight)
        for attr_name, attr_value in (('thres', thres), ('base', base),
                                      ('coeff', coeff), ('pw', pw)):
            setattr(self, attr_name, attr_value)
        self.node_name = 'reply_count'

    def _calc_score(self, obj):
        replies = int(obj.reply_num)
        saturated = replies >= self.thres
        if saturated:
            return 0.0
        growth = self.coeff * pow(replies, self.pw)
        return log(self.base + growth)


class HasServiceNode(PreCalcNode):
    """Pre-calculated score node that checks for an attached diary service.

    Scores 0.0 when ``obj.diary.service`` is reachable, otherwise
    ``log(not_service_value)``.
    """

    def __init__(self, weight=1.0, not_service_value=0.1):
        super(HasServiceNode, self).__init__(weight=weight)
        self.node_name = 'has_service'
        # Stored in log space so it can be returned directly as a score.
        self.not_service_value = log(not_service_value)

    def _calc_score(self, obj):
        # hasattr (rather than getattr with a default) is kept on purpose:
        # on Python 2 it swallows any exception raised by the attribute lookup.
        has_service = hasattr(obj, 'diary') and hasattr(obj.diary, 'service')
        return 0.0 if has_service else self.not_service_value

# Specification of the pre-calculated score tree as (weight, child nodes),
# handed to PreCalcNode.build(). ``pre_calc_node.weight`` is later sent to
# Elasticsearch as ``pre_calc_weight`` by _suggest_topic.
pre_calc = (
    6.0,  # presumably the weight of the composite node -- TODO confirm in PreCalcNode.build
    [
        # Popularity signals: log(base + coeff * count ** pw) below the
        # threshold, 0.0 at or above it.
        ViewCountNode(thres=1000, base=0.2, coeff=0.08, pw=1.0 / 3, weight=1.0),
        VoteCountNode(thres=100, base=0.2, coeff=0.08, pw=0.5, weight=2.0),
        ReplyCountNode(thres=100, base=0.2, coeff=0.08, pw=0.5, weight=2.0),
        # log(0.1) unless the topic exposes diary.service.
        HasServiceNode(not_service_value=0.1, weight=1.0),
    ],
)
pre_calc_node = PreCalcNode.build(pre_calc)


def get_topic_pre_calc_score(topic):
    """Score ``topic`` with the module-level pre-calculated node tree.

    Thin wrapper that returns whatever ``pre_calc_node.calc_score`` returns.
    """
    scorer = pre_calc_node
    return scorer.calc_score(topic)


def _suggest_topic(offset, size, user_params=None, suggest_type=0):
    """Run the topic-suggestion search against Elasticsearch.

    Only online, non-doctor problems that have ``suggest.pre_calc_score``
    are considered. Results are sorted by ``is_sink`` ascending.

    Args:
        offset: index of the first hit to return (ES ``from``).
        size: number of hits to return.
        user_params: per-user parameters; ``None`` is treated as ``{}``.
            When ``suggest_type == 0``, its ``user_related_tag_ids`` entry is
            forwarded to the sort script. An empty list is sent if the key is
            absent.
        suggest_type: 0 adds the ``suggest_topic-default`` script sort
            (descending) after ``is_sink``. Any other value sorts by
            ``is_sink`` only.

    Returns:
        A list of ``{'id': int, 'user_id': ...}`` dicts in hit order.
        ``user_id`` is None when the hit has no ``user.id`` in its source.
    """
    # None default instead of a shared mutable {} literal.
    if user_params is None:
        user_params = {}

    q = {
        '_source': 'user.id',
        'query': {
            'filtered': {
                'filter': {
                    'bool': {
                        'must': [
                            {'exists': {'field': 'suggest.pre_calc_score'}},
                            {'term': {'is_doctor': False}},
                            {'term': {'is_online': True}},
                        ],
                    },
                },
            },
        },
    }

    sort_list = [
        {'is_sink': {'order': 'asc'}},  # grep field in project to find document
    ]

    if suggest_type == 0:
        sort_list += [
            {'_script': {
                'lang': settings.ES_SCRIPT_LANG,
                'script_file': 'suggest_topic-default',
                'type': 'number',
                'order': 'desc',
                'params': {
                    # .get(): the function's own defaults (user_params={},
                    # suggest_type=0) previously raised KeyError here.
                    'user_related_tag_ids': user_params.get('user_related_tag_ids', []),
                    'pre_calc_weight': float(pre_calc_node.weight),
                    'adhoc_user_weight': 1.0,
                    'adhoc_time_weight': 1.0,
                }
            }},
        ]

    q['sort'] = sort_list

    es = get_es()
    res = es.search(
        index=settings.ES_READ_INDEX,
        doc_type='problem',
        timeout=settings.ES_SEARCH_TIMEOUT,
        body=q,
        from_=offset,
        size=size)

    def get_user_id(topic):
        # Missing or null intermediate objects yield None instead of raising.
        try:
            return topic['_source']['user']['id']
        except (KeyError, TypeError):
            return None

    return [
        {
            'id': int(topic['_id']),
            'user_id': get_user_id(topic),
        }
        for topic in res['hits']['hits']
    ]


def _move_repeated_users_back(topics):
    """Stable-partition ``topics``: each user's first topic, then the repeats.

    A topic whose ``user_id`` is None is never treated as a repeat.
    """
    seen_user_ids = set()
    distinct_topics = []
    duplicated_topics = []
    for topic in topics:
        user_id = topic['user_id']
        if user_id in seen_user_ids:
            duplicated_topics.append(topic)
        else:
            if user_id is not None:
                seen_user_ids.add(user_id)
            distinct_topics.append(topic)
    return distinct_topics + duplicated_topics


def _shuffle_head(topics, seed, count):
    """Shuffle the first ``count`` items of ``topics`` in place using ``Random(seed)``."""
    head = topics[:count]
    Random(seed).shuffle(head)
    topics[:count] = head


def suggest_topic(offset=0, size=5, user_params=None, suggest_type=0):
    """Return one page of suggested topic ids.

    For pages that start inside the first ``REORDER_BUFFER_SIZE`` results,
    the head of the ranking is post-processed:
      1. within the first REORDER_BUFFER_SIZE hits, each user's first topic
         keeps its relative order and that user's further topics are moved
         behind all of them;
      2. the first SHUFFLE_SIZE items are shuffled with a seed built from
         ``user_params['user_id']`` and the current hour, so a user sees a
         stable order within an hour.
    Pages starting at or beyond REORDER_BUFFER_SIZE are fetched directly.
    Because step 1 only permutes the first REORDER_BUFFER_SIZE hits,
    pagination stays consistent across the boundary.

    Args:
        offset: index of the first topic to return.
        size: number of topics to return.
        user_params: per-user parameters forwarded to ``_suggest_topic``.
            ``'user_id'`` seeds the shuffle; ``None`` or a missing key
            seeds with ``'None'``.
        suggest_type: forwarded to ``_suggest_topic``.

    Returns:
        ``{'topic_ids': [...]}``.
    """
    REORDER_BUFFER_SIZE = 1000
    SHUFFLE_SIZE = 200

    # None default instead of a shared mutable {} literal.
    if user_params is None:
        user_params = {}

    lower_bound = offset
    upper_bound = offset + size

    if lower_bound >= REORDER_BUFFER_SIZE:
        result_topics = _suggest_topic(offset, size, user_params, suggest_type)
    else:
        # Always fetch at least REORDER_BUFFER_SIZE hits so the dedupe and
        # shuffle see the same window regardless of the requested page.
        fetch_upper_bound = max(REORDER_BUFFER_SIZE, upper_bound)
        fetched_topics = _suggest_topic(0, fetch_upper_bound, user_params, suggest_type)

        reordered_topics = (
            _move_repeated_users_back(fetched_topics[:REORDER_BUFFER_SIZE])
            + fetched_topics[REORDER_BUFFER_SIZE:]
        )

        # Seed per user and per hour; .get() avoids a KeyError for callers
        # that do not supply a user_id.
        seed = str(user_params.get('user_id')) + timezone.now().strftime('_%Y%m%d%H')
        _shuffle_head(reordered_topics, seed, SHUFFLE_SIZE)

        result_topics = reordered_topics[lower_bound:upper_bound]

    return {
        'topic_ids': [topic['id'] for topic in result_topics],
    }
