# coding=utf-8
from __future__ import unicode_literals, print_function, absolute_import

import time
import datetime
import logging
import traceback
import django.db.models
from django.conf import settings
import elasticsearch
import elasticsearch.helpers
import sys
from api.models import brand, item, product, collect, wikitag
from gm_types.mq.dbmw import DBMWEndPoints
from api.models import hot_wiki_keyword

import api.models as am
import answer.models as aw
from api.models.ranklist import RankBoard
from search.utils.es import get_es, get_talos_es, get_talos_es6

from injection.utils.table_scan import ITableChunk
from search.utils.es import es_index_adapt
from rpc.context import get_rpc_remote_invoker
import functools
from talos.models.topic import Problem
from talos.models.diary import Diary

# Lazily-created Elasticsearch clients; each is populated on first use by the
# matching get_*_instance() accessor below and then reused.
__es = __talos_es = __es6 = None


def get_elasticsearch_instance():
    """Return the main Elasticsearch client, creating it via get_es() on first call.

    The client is cached in the module-level ``__es`` and reused afterwards.
    """
    global __es
    client = __es
    if client is None:
        client = __es = get_es()
    return client


# TODO(zhangyunyu): move this along when talos is migrated to mimas.
def get_talos_elasticsearch_instance():
    """Return the talos Elasticsearch client, creating it via get_talos_es() on first call.

    The client is cached in the module-level ``__talos_es`` and reused afterwards.
    """
    global __talos_es
    client = __talos_es
    if client is None:
        client = __talos_es = get_talos_es()
    return client


def get_talos_es6_instance():
    """Return the talos ES6 client, creating it via get_talos_es6() on first call.

    The client is cached in the module-level ``__es6`` and reused afterwards.
    """
    global __es6
    client = __es6
    if client is None:
        client = __es6 = get_talos_es6()
    return client


def get_es_list_by_type(es_type):
    """Return a one-element list holding the Elasticsearch client for *es_type*.

    'service' and 'sku' documents go to the talos cluster; every other type
    goes to the main cluster.
    """
    talos_backed = es_type in ['service', 'sku']
    client = get_talos_elasticsearch_instance() if talos_backed else get_elasticsearch_instance()
    return [client]


class TypeInfo(object):
    """Describe how one model type is synchronised into Elasticsearch.

    :param name: key of this type in the map built by ``get_type_info_map``;
        also the key looked up in the ``pk_blacklist`` setting.
    :param type: document type; selects the target index (see
        ``elasticsearch_bulk_insert_data``) and the cluster (see
        ``get_es_list_by_type``).
    :param model: Django model class the rows belong to.
    :param query_deferred: zero-argument callable returning a Django ``Query``;
        called on every access to ``query`` / ``queryset``.
    :param get_data_func: callable turning ONE model instance into a document
        dict (which must contain ``'id'``).
    :param bulk_insert_chunk_size: stored only; presumably the chunk size used
        by full imports elsewhere.
    :param round_insert_chunk_size: stored only; presumably used by periodic
        imports elsewhere.
    :param round_insert_period: stored only; presumably the period of periodic
        imports (unit not visible here -- TODO confirm).
    :param batch_get_data_func: optional callable receiving a LIST OF PKS (not
        instances) and returning a list of document dicts; when set it is used
        instead of ``get_data_func``.
    :param gm_mq_endpoint: optional message-queue endpoint; stored only.
    :param logic_database_id: stored only.
    """

    # Document types that all share the single "newwiki" index, use the generic
    # "_doc" mapping type and a composite document id.
    _NEWWIKI_TYPES = frozenset(['newitemwiki', 'collectwiki', 'brandwiki', 'productwiki'])

    # CPU-time clock for timing stats. time.clock() was removed in Python 3.8;
    # time.process_time() (3.3+) replaces it. Python 2 only has time.clock().
    # Note: getattr's default would be evaluated eagerly, hence the ``or``.
    _cpu_clock = staticmethod(getattr(time, 'process_time', None) or time.clock)

    def __init__(
            self,
            name,
            type,
            model,
            query_deferred,
            get_data_func,
            bulk_insert_chunk_size,
            round_insert_chunk_size,
            round_insert_period,
            batch_get_data_func=None,  # receive a list of pks, not instance
            gm_mq_endpoint=None,
            logic_database_id=None,
    ):
        self.name = name
        self.type = type
        self.model = model
        self.query_deferred = query_deferred
        self.get_data_func = get_data_func
        self.batch_get_data_func = batch_get_data_func
        self.pk_blacklist = ()  # goes through the setter -> frozenset
        self.bulk_insert_chunk_size = bulk_insert_chunk_size
        self.round_insert_chunk_size = round_insert_chunk_size
        self.round_insert_period = round_insert_period
        self.gm_mq_endpoint = gm_mq_endpoint
        self.logic_database_id = logic_database_id

    @property
    def query(self):
        """A fresh Django ``Query`` built by ``query_deferred``."""
        return self.query_deferred()

    @property
    def queryset(self):
        """A ``QuerySet`` over ``model`` wrapping a fresh ``query``."""
        return django.db.models.QuerySet(model=self.model, query=self.query)

    @property
    def pk_blacklist(self):
        """Frozenset of pks that must never be indexed."""
        return self.__pk_blacklist

    @pk_blacklist.setter
    def pk_blacklist(self, value):
        self.__pk_blacklist = frozenset(value)

    def bulk_get_data(self, instance_iterable):
        """Convert model instances into Elasticsearch documents (best effort).

        Instances without a ``pk`` and blacklisted pks are skipped. Conversion
        errors are logged and swallowed: in per-instance mode only the failing
        instance is dropped; in batch mode the whole batch is dropped.

        :param instance_iterable: iterable of model instances (iterated once).
        :return: list of document dicts.
        """
        if self.batch_get_data_func:
            return self._batch_get_data(instance_iterable)
        return self._per_instance_get_data(instance_iterable)

    def _batch_get_data(self, instance_iterable):
        """Filter pks, then convert them all with one ``batch_get_data_func`` call."""
        pk_list = []
        missing_pk_count = 0
        blacklisted_pk_list = []
        for instance in instance_iterable:
            pk = getattr(instance, 'pk', None)
            if pk is None:
                missing_pk_count += 1
            elif pk in self.__pk_blacklist:
                blacklisted_pk_list.append(pk)
            else:
                pk_list.append(pk)

        if missing_pk_count:
            # Not inside an ``except`` block, so error() rather than exception().
            logging.error('{} instances without pk for name={}, doc_type={}'.format(
                missing_pk_count,
                self.name,
                self.type,
            ))
        if blacklisted_pk_list:
            logging.info('those pks are in blacklist for name={}, doc_type={}, pk_list={}'.format(
                self.name,
                self.type,
                str(blacklisted_pk_list),
            ))
        if not pk_list:
            # Nothing left to convert; skip the (possibly remote) call.
            return []

        try:
            return self.batch_get_data_func(pk_list)
        except Exception:
            traceback.print_exc()
            logging.exception('bulk_get_data for name={}, doc_type={}, pk_list={}'.format(
                self.name,
                self.type,
                str(pk_list),
            ))
            return []

    def _per_instance_get_data(self, instance_iterable):
        """Convert instances one by one with ``get_data_func``."""
        data_list = []
        for instance in instance_iterable:
            pk = getattr(instance, 'pk', None)
            if pk is None:
                logging.error('bulk_get_data for name={}, doc_type={}, pk={}: pk not found'.format(
                    self.name,
                    self.type,
                    pk,
                ))
                continue
            if pk in self.__pk_blacklist:
                logging.info('bulk_get_data for name={}, doc_type={}, pk={}: ignore blacklisted pk'.format(
                    self.name,
                    self.type,
                    pk,
                ))
                continue
            try:
                data = self.get_data_func(instance)
            except Exception:
                traceback.print_exc()
                logging.exception('bulk_get_data for name={}, doc_type={}, pk={}'.format(
                    self.name,
                    self.type,
                    pk,
                ))
            else:
                data_list.append(data)
        return data_list

    @property
    def _index_type(self):
        """Doc type used to resolve the index: all wiki tab types share 'newwiki'."""
        return 'newwiki' if self.type in self._NEWWIKI_TYPES else self.type

    def _build_bulk_actions(self, index, data_list):
        """Build ``elasticsearch.helpers.bulk`` 'index' actions for *data_list*."""
        if self.type in self._NEWWIKI_TYPES:
            # NOTE(review): assuming integer ids, this composite id collides once
            # an id reaches 10000 (id=10001/wikitype=1 == id=1/wikitype=2). Kept
            # as-is because existing documents are keyed this way -- confirm id
            # ranges before changing.
            return [{
                '_op_type': 'index',
                '_index': index,
                '_type': "_doc",
                '_id': data['id'] + data["wikitype"] * 10000,
                '_source': data,
            } for data in data_list]
        index_type = self._index_type
        return [{
            '_op_type': 'index',
            '_index': index,
            '_type': index_type,
            '_id': data['id'],
            '_source': data,
        } for data in data_list]

    def elasticsearch_bulk_insert_data(self, index_prefix, data_list, es=None):
        """Index *data_list* into every target Elasticsearch client.

        :param index_prefix: passed to ``es_index_adapt`` to resolve the write index.
        :param data_list: document dicts with an ``'id'`` key (plus ``'wikitype'``
            for the wiki tab types).
        :param es: a client, a list/tuple of clients, or ``None`` for
            ``get_es_list_by_type(self.type)``.
        :return: ``None`` if there is nothing to index; ``'error'`` if the bulk
            request failed on any client; otherwise the ``helpers.bulk`` result
            of the last client.
        """
        if es is None:
            es = get_es_list_by_type(self.type)
        if not isinstance(es, (list, tuple,)):
            es = [es]

        index = es_index_adapt(
            index_prefix=index_prefix,
            doc_type=self._index_type,
            rw='write',
        )
        bulk_actions = self._build_bulk_actions(index, data_list)

        # Lazy %-args: the payload is only rendered when DEBUG logging is on
        # (this used to be an unconditional print of every document to stdout).
        logging.debug('bulk insert index_name:%s, data_list:%s', index, data_list)

        if not bulk_actions:
            return None

        es_result = None
        for client in es:
            try:
                result = elasticsearch.helpers.bulk(client=client, actions=bulk_actions)
            except Exception:
                traceback.print_exc()
                logging.exception('elasticsearch bulk insert failed for name={}, doc_type={}, index={}'.format(
                    self.name,
                    self.type,
                    index,
                ))
                result = 'error'
            # Keep a failure sticky so a later success on another client
            # does not mask it.
            if es_result != 'error':
                es_result = result
        return es_result

    def elasticsearch_bulk_insert(self, index_prefix, instance_iterable, es=None):
        """Convert *instance_iterable* with ``bulk_get_data`` and index the result."""
        data_list = self.bulk_get_data(instance_iterable)
        return self.elasticsearch_bulk_insert_data(
            index_prefix=index_prefix,
            data_list=data_list,
            es=es,
        )

    def insert_table_by_pk_list(self, index_prefix, pk_list, es=None, use_batch_query_set=False):
        """Load rows whose pk is in *pk_list* and index them.

        :param use_batch_query_set: filter ``self.queryset`` (the configured
            query) if true, otherwise ``self.model.objects.all()``.
        :return: the result of ``elasticsearch_bulk_insert_data`` (it used to be
            discarded; returning it is backward compatible).
        """
        qs = self.queryset if use_batch_query_set else self.model.objects.all()
        return self.elasticsearch_bulk_insert(
            index_prefix=index_prefix,
            instance_iterable=qs.filter(pk__in=pk_list),
            es=es,
        )

    def insert_table_chunk(self, index_prefix, table_chunk, es=None):
        """Index one table chunk and return a one-line timing/status report.

        The report contains: timestamp, index prefix, type name, chunk pk range,
        row count, wall-clock seconds for load / convert / index stages, CPU
        seconds for the whole call, and the bulk result.
        """
        assert isinstance(table_chunk, ITableChunk)

        start_clock = self._cpu_clock()
        start_time = time.time()

        instance_list = list(table_chunk)
        stage_1_time = time.time()

        data_list = self.bulk_get_data(instance_list)
        stage_2_time = time.time()

        es_result = self.elasticsearch_bulk_insert_data(
            index_prefix=index_prefix,
            data_list=data_list,
            es=es,
        )
        stage_3_time = time.time()
        end_clock = self._cpu_clock()

        return ('{datetime} {index_prefix} {type_name:10s} {pk_start:>15s} {pk_stop:>15s} {count:5d} '
                '{stage_1_duration:6.3f} {stage_2_duration:6.3f} {stage_3_duration:6.3f} {clock_duration:6.3f} '
                '{response}').format(
            datetime=datetime.datetime.now().strftime('%Y-%m-%dT%H:%M:%S.%f'),
            index_prefix=index_prefix,
            type_name=self.name,
            pk_start=repr(table_chunk.get_pk_start()),
            pk_stop=repr(table_chunk.get_pk_stop()),
            count=len(instance_list),
            stage_1_duration=stage_1_time - start_time,
            stage_2_duration=stage_2_time - stage_1_time,
            stage_3_duration=stage_3_time - stage_2_time,
            clock_duration=end_clock - start_clock,
            response=es_result,
        )


# Cache for get_type_info_map(): None until the first call builds the
# name -> TypeInfo dict, which is then returned on every later call.
_get_type_info_map_result = None


def rpc2batch_get_data_func(api, pk_param_name):
    """Build a ``batch_get_data_func`` that fetches documents over RPC.

    The RPC invoker is obtained once, when this factory is called.

    :param api: RPC method name to invoke.
    :param pk_param_name: name of the RPC parameter that receives the pk list.
    :return: callable ``f(pk_list)`` returning the unwrapped RPC result.
    """
    invoker = get_rpc_remote_invoker()

    def _rpc_call(pk_list):
        params = {pk_param_name: pk_list}
        response = invoker.invoke(method=api, params=params)
        return response.unwrap()

    return _rpc_call


def _load_pk_blacklist_map():
    """Return ``settings.DATA_SYNC['pk_blacklist']``: a dict keyed by TypeInfo name.

    Falls back to an empty dict, printing the traceback and a notice to stderr,
    when the setting cannot be read (ImportError / AttributeError, e.g. no
    ``DATA_SYNC`` on settings).

    :raises TypeError: if the configured value is not a dict. This explicit
        check replaces an ``assert``, which ``python -O`` would strip.
    """
    try:
        pk_blacklist_map = settings.DATA_SYNC.get('pk_blacklist', {})
    except (ImportError, AttributeError):
        traceback.print_exc()
        print('fallback to empty pk_blacklist_map', file=sys.stderr)
        return {}
    if not isinstance(pk_blacklist_map, dict):
        raise TypeError('settings.DATA_SYNC["pk_blacklist"] must be a dict, got {!r}'.format(
            type(pk_blacklist_map)))
    return pk_blacklist_map


def get_type_info_map():
    """Return the cached mapping ``name -> TypeInfo`` for every synced type.

    The map is built on the first call and stored in the module-level
    ``_get_type_info_map_result``; later calls return that same dict. The
    ``trans2es.utils`` transfer modules are imported lazily here. Each
    TypeInfo's ``pk_blacklist`` is filled from ``_load_pk_blacklist_map()``,
    and the loaded blacklist is printed to stderr.
    """
    global _get_type_info_map_result
    if _get_type_info_map_result is not None:
        return _get_type_info_map_result

    from trans2es.utils import (transfer, service_transfer, sku_transfer,
                                diary_transfer, topic_transfer,
                                doctor_transfer,
                                user_transfer, board_transfer, WikiTab_transfer, user_album_transfer)
    type_info_list = [
        TypeInfo(
            name='tag',  # circle / tag
            type='tag',
            model=am.Tag,
            query_deferred=lambda: am.Tag.objects.all().query,
            get_data_func=transfer.get_tag,
            bulk_insert_chunk_size=100,
            round_insert_chunk_size=5,
            round_insert_period=2,
            gm_mq_endpoint=DBMWEndPoints.GAIA_TAG,
        ),
        TypeInfo(
            name='itemwiki',  # wiki
            type='itemwiki',
            model=am.ItemWiki,
            query_deferred=lambda: am.ItemWiki.objects.select_related('tag').query,
            get_data_func=transfer.get_itemwiki,
            bulk_insert_chunk_size=100,
            round_insert_chunk_size=5,
            round_insert_period=2,
        ),
        TypeInfo(
            name='newitemwiki',  # item wiki tab
            type='newitemwiki',
            model=item.NewItemWiki,
            query_deferred=lambda: item.NewItemWiki.objects.all().query,
            get_data_func=WikiTab_transfer.get_ItemWiki,
            bulk_insert_chunk_size=100,
            round_insert_chunk_size=5,
            round_insert_period=2,
        ),
        TypeInfo(
            name='collectwiki',  # aggregated (collection) wiki tab
            type='collectwiki',
            model=collect.CollectWiki,
            query_deferred=lambda: collect.CollectWiki.objects.all().query,
            get_data_func=WikiTab_transfer.get_CollectWiki,
            bulk_insert_chunk_size=100,
            round_insert_chunk_size=5,
            round_insert_period=2,
        ),
        TypeInfo(
            name='brandwiki',  # brand / product tab
            type='brandwiki',
            model=brand.BrandWiki,
            query_deferred=lambda: brand.BrandWiki.objects.all().query,
            get_data_func=WikiTab_transfer.get_BrandWiki,
            bulk_insert_chunk_size=100,
            round_insert_chunk_size=5,
            round_insert_period=2,
        ),
        TypeInfo(
            name='productwiki',  # product wiki tab
            type='productwiki',
            model=product.ProductWiki,
            query_deferred=lambda: product.ProductWiki.objects.all().query,
            get_data_func=WikiTab_transfer.get_ProductWiki,
            bulk_insert_chunk_size=100,
            round_insert_chunk_size=5,
            round_insert_period=2,
        ),
        TypeInfo(
            name='wiki_keyword',  # wiki keyword
            type='wiki_keyword',
            model=hot_wiki_keyword.Wiki_Keyword,
            query_deferred=lambda: hot_wiki_keyword.Wiki_Keyword.objects.all().query,
            get_data_func=transfer.Wiki_keyword,
            bulk_insert_chunk_size=100,
            round_insert_chunk_size=5,
            round_insert_period=2,
        ),
        TypeInfo(
            name='diary',  # diary book
            type='diary',
            model=Diary,
            # WARNING:
            # This code *MUST* be synchronized with trans2es.utils.transfer.get_diary
            # REGION BEGIN
            query_deferred=lambda: (
                Diary.objects
                .annotate(annotate__topicreply__reply_date__max=django.db.models.Max('topicreply__reply_date'))
                .annotate(
                    annotate__topics__topicreply__reply_date__max=django.db.models.Max(
                        'topics__topicreply__reply_date'))
                .query
            ),
            # REGION END
            get_data_func=diary_transfer.get_diary,
            # batch_get_data_func=rpc2batch_get_data_func('talos/diary/data_sync/get_diaries', 'diary_ids'),
            bulk_insert_chunk_size=100,
            round_insert_chunk_size=50,
            round_insert_period=12,
        ),
        TypeInfo(
            name='service',  # service (deal)
            type='service',
            model=am.Service,
            query_deferred=lambda: am.Service.objects.select_related('doctor__hospital__city__province').query,
            get_data_func=service_transfer.get_service,
            bulk_insert_chunk_size=200,
            round_insert_chunk_size=50,
            round_insert_period=6,
        ),
        TypeInfo(
            name='sku',  # sku
            type='sku',
            model=am.ServiceItem,
            query_deferred=lambda: am.ServiceItem.objects.select_related(
                'service__doctor__hospital__city__province').query,
            get_data_func=sku_transfer.get_sku,
            bulk_insert_chunk_size=200,
            round_insert_chunk_size=50,
            round_insert_period=6,
        ),
        TypeInfo(
            name='doctor',  # doctor
            type='doctor',
            model=am.Doctor,
            query_deferred=lambda: (
                am.Doctor.objects
                .select_related('hospital__city__province')
                .prefetch_related('doctortag_set')
                .prefetch_related('services')
                .query
            ),
            get_data_func=doctor_transfer.get_doctor,
            bulk_insert_chunk_size=100,
            round_insert_chunk_size=25,
            round_insert_period=6,
        ),
        TypeInfo(
            name='user',  # user
            type='user',
            model=am.User,
            query_deferred=lambda: am.User.objects.select_related('userextra').query,
            get_data_func=user_transfer.get_user,
            bulk_insert_chunk_size=1000,
            round_insert_chunk_size=50,
            round_insert_period=4,
        ),
        TypeInfo(
            name='board',  # Q&A -- NOTE(review): model is RankBoard; label may be stale
            type='board',
            model=RankBoard,
            query_deferred=lambda: RankBoard.objects.prefetch_related('data_tags').query,
            get_data_func=board_transfer.get_board,
            bulk_insert_chunk_size=500,
            round_insert_chunk_size=50,
            round_insert_period=12,
        ),
        TypeInfo(
            name='user_album',  # user photo album
            type='user_album',
            model=am.User,
            query_deferred=lambda: am.User.objects.select_related('userextra').query,
            get_data_func=user_album_transfer.get_user_album,
            bulk_insert_chunk_size=500,
            round_insert_chunk_size=50,
            round_insert_period=12,
        ),
    ]

    pk_blacklist_map = _load_pk_blacklist_map()
    for type_info in type_info_list:
        type_info.pk_blacklist = pk_blacklist_map.get(type_info.name, ())
        print('loaded pk_blacklist for {}: {}'.format(
            type_info.name,
            sorted(type_info.pk_blacklist),
        ), file=sys.stderr)

    type_info_map = {type_info.name: type_info for type_info in type_info_list}

    _get_type_info_map_result = type_info_map
    return type_info_map
