# coding=utf-8
from __future__ import unicode_literals, print_function, absolute_import

import functools
import itertools
import os
import traceback
from multiprocessing import current_process
from multiprocessing.pool import Pool

from django.conf import settings
from django.core.management import BaseCommand, CommandError

from message.utils.common import get_update_retry_on_conflict
from message.utils.es_abstract import create_esop_for_database, table_schema_map, ESBulkAction
from message.utils.es_abstract import table_conversation, table_message


class ProcessLocal(object):
    """Hold a lazily built value that is private to each OS process.

    ``constructor`` is called (with no arguments) by :meth:`get` whenever the
    calling process's pid differs from the pid that built the cached value,
    e.g. in a forked pool worker, so each process gets its own instance.
    """

    def __init__(self, constructor):
        self.constructor = constructor
        # pid of the process that produced ``value``; None until first get().
        self.pid = None
        self.value = None

    def get(self):
        """Return this process's value, building it on first use in the process."""
        current_pid = os.getpid()
        if self.pid == current_pid:
            return self.value
        # Build first, record the pid only afterwards: if the constructor
        # raises, the next call will try again.
        self.value = self.constructor()
        self.pid = current_pid
        return self.value


def do_chunk_once(action_chunk, dst_database):
    """Send ``action_chunk`` to ``dst_database`` in a single bulk call.

    Returns ``(success, report)``. ``report`` is a printable line containing
    the current process name, the number of actions, and the repr of either
    the bulk helper's result (on success) or the formatted traceback (on
    failure). KeyboardInterrupt and SystemExit are re-raised, not reported.
    """
    size = len(action_chunk)
    ok = False
    try:
        # Resolving the destination esop is inside the try as well, so a
        # failure to create it is reported like any other bulk failure.
        esop = Command.get_dst_esop(dst_database)
        outcome = esop.helper_bulk(action_chunk, chunk_size=size)
    except (KeyboardInterrupt, SystemExit):
        raise
    except Exception:
        outcome = traceback.format_exc()
    else:
        ok = True
    report = '{:15s} {:8d} {:s}'.format(current_process().name, size, repr(outcome))
    return ok, report


def do_chunk(chunk, dst_database, table_name, retry=5):
    """Convert ``chunk`` to bulk actions and write them, retrying on failure.

    The hits are converted once via ``Command.doc_to_action`` for the table
    named ``table_name``, and the same action list is re-sent on every
    attempt. At most ``retry`` attempts are made, stopping at the first
    success. If the first attempt succeeds, its report line is returned
    unchanged. Otherwise a ``[*] Retried`` header is returned, followed by one
    indented report line per attempt. With ``retry <= 0`` no attempt is made
    and the header reports ``count=0``.
    """
    actions = list(Command.doc_to_action(chunk, table=table_schema_map[table_name]))
    reports = []
    ok = False
    for _attempt in range(retry):
        ok, line = do_chunk_once(action_chunk=actions, dst_database=dst_database)
        reports.append(line)
        if ok:
            break

    if ok and len(reports) == 1:
        return reports[0]
    header = '[*] Retried, success={}, count={}:'.format(ok, len(reports))
    return '\n    '.join([header] + reports)


class Command(BaseCommand):
    """Copy or merge one Elasticsearch table from a source to a destination database.

    Hits of ``--table`` are scanned from ``--src-database`` and grouped into
    chunks of ``--chunk-size``. Each chunk is converted to bulk actions (see
    :meth:`doc_to_action`) and written to ``--dst-database``. Chunks are
    written sequentially, or with a pool of ``--parallel`` worker processes
    when that option is non-zero.
    """

    # One {dst_database: esop} cache per process. ProcessLocal rebuilds the
    # dict when the pid changes, so pool workers do not reuse the parent's
    # esop objects.
    dst_esop_map = ProcessLocal(dict)

    def add_arguments(self, parser):
        parser.add_argument('--src-database')
        parser.add_argument('--dst-database')
        parser.add_argument('--table')
        parser.add_argument('--chunk-size', type=int, default=1000)
        parser.add_argument('--parallel', type=int, default=0)

    def handle(self, *args, **options):
        """Validate the options, then stream the source table into the destination.

        Raises:
            CommandError: a required option is missing, the table is unknown,
                ``--chunk-size`` is not positive, or ``--parallel`` is negative.
        """
        src_database = options['src_database']
        dst_database = options['dst_database']
        table = options['table']
        chunk_size = options['chunk_size']
        parallel = options['parallel'] or 0

        # Explicit errors instead of ``assert``: the checks survive
        # ``python -O``, and Django prints a clean message.
        if not (src_database and dst_database and table):
            raise CommandError('--src-database, --dst-database and --table are required')
        if chunk_size is None or chunk_size < 1:
            raise CommandError('--chunk-size must be a positive integer')
        if parallel < 0:
            raise CommandError('--parallel must be >= 0')
        try:
            table_schema = table_schema_map[table]
        except KeyError:
            raise CommandError('unknown table: {}'.format(table))

        src_iter = create_esop_for_database(src_database).helper_scan(
            table=table_schema,
            fields=('_source', '_parent', '_routing', '_timestamp'),
        )
        chunk_iter = self._group(src_iter, size=chunk_size)

        # Bind the table *name*; do_chunk looks the schema up itself.
        # Presumably this keeps the partial cheap to pickle for pool workers.
        func = functools.partial(do_chunk, dst_database=dst_database, table_name=table)
        if parallel:
            self.do_parallel(func, chunk_iter, parallel=parallel)
        else:
            self.do_sequencial(func, chunk_iter)

    @staticmethod
    def _group(it, size):
        """Yield consecutive lists of up to ``size`` items taken from ``it``.

        The last list may be shorter than ``size``. It is still yielded, so
        trailing items are never dropped.

        Raises:
            ValueError: ``size`` is smaller than 1 (raised on first iteration).
                Without this check, the generator would yield empty lists forever.
        """
        if size < 1:
            raise ValueError('size must be >= 1, got {!r}'.format(size))
        it = iter(it)
        while True:
            chunk = list(itertools.islice(it, size))
            if not chunk:
                return
            yield chunk

    @staticmethod
    def doc_to_copy_action(hits, table):
        """Yield one ESBulkAction per hit, using the (modified) hit as params.

        ``_index`` and ``_type`` are removed if present, so the action is not
        pinned to the source index. Any ``fields`` entry is flattened into the
        top level of the hit; these are the ``_parent``/``_routing``/
        ``_timestamp`` fields the scan requests. Hits are modified in place.
        """
        for h in hits:
            # pop() rather than del: a hit without these keys must not abort
            # the whole chunk. Conversion runs outside do_chunk's retry/try.
            h.pop('_index', None)
            h.pop('_type', None)
            if 'fields' in h:
                h.update(h.pop('fields'))
            yield ESBulkAction(
                table=table,
                params=h,
            )

    @staticmethod
    def doc_to_conversation_merge_action(hits, table, retry_on_conflict=None):
        """Yield a scripted-upsert ``update`` action per conversation hit.

        The hit's ``messages`` list is removed from ``_source``. The rest of
        ``_source`` is used as the upsert document, and the messages are
        passed as ``MESSAGE_LIST`` to the
        ``update_conversation-add-all-and-unique-messages`` script.

        Best effort: a hit that cannot be converted is skipped and its
        traceback is printed to stderr. Examples are a missing key, a
        non-numeric ``_id``, or ``_source['id']`` not matching ``int(_id)``.
        """
        retry_on_conflict = get_update_retry_on_conflict(retry_on_conflict)
        for h in hits:
            try:
                conversation_head = h['_source']
                messages = conversation_head.pop('messages', [])
                conversation_id_outer = h['_id']
                conversation_id_inner = conversation_head['id']
                # Explicit check instead of ``assert`` so it still runs under -O.
                if conversation_id_inner != int(conversation_id_outer):
                    raise ValueError('conversation id mismatch: _id={!r}, id={!r}'.format(
                        conversation_id_outer, conversation_id_inner))
                action = ESBulkAction(
                    table=table,
                    params={
                        '_op_type': 'update',
                        '_id': conversation_id_outer,
                        '_retry_on_conflict': retry_on_conflict,
                        '_source': {
                            'scripted_upsert': True,
                            'upsert': conversation_head,
                            'script': {
                                'lang': settings.ES_SCRIPT_LANG,
                                'script_file': 'update_conversation-add-all-and-unique-messages',
                                'params': {
                                    'MESSAGE_LIST': messages,
                                }
                            }
                        }
                    }
                )
            except Exception:
                traceback.print_exc()
                continue
            # Yield outside the try, so only failures while building this
            # action are caught and skipped.
            yield action

    @classmethod
    def doc_to_action(cls, hits, table):
        """Return an iterator of bulk actions for ``hits`` suited to ``table``.

        ``table_conversation`` hits are merged via scripted upsert.
        ``table_message`` hits are copied as-is.

        Raises:
            ValueError: ``table`` is neither of the supported tables.
        """
        if table is table_conversation:
            action_chunk_getter = cls.doc_to_conversation_merge_action
        elif table is table_message:
            action_chunk_getter = cls.doc_to_copy_action
        else:
            raise ValueError('unsupported table: {}'.format(table))
        return action_chunk_getter(hits=hits, table=table)

    @classmethod
    def get_dst_esop(cls, dst_database):
        """Return this process's esop for ``dst_database``, creating it on first use."""
        d = cls.dst_esop_map.get()
        try:
            return d[dst_database]
        except KeyError:
            value = create_esop_for_database(dst_database)
            d[dst_database] = value
            return value

    @classmethod
    def do_parallel(cls, func, data_iter, parallel):
        """Map ``func`` over ``data_iter`` using ``parallel`` worker processes.

        Results are printed in completion order (``imap_unordered``).

        The pool is always shut down:
        - on success it is closed and joined;
        - if iteration raises or is interrupted, it is terminated and joined,
          so no workers are left behind.
        """
        pool = Pool(processes=parallel)
        try:
            # NOTE(review): the pool's task feeder may read data_iter well ahead
            # of the workers, buffering many chunks in memory for large sources.
            # Consider throttling if memory becomes a problem.
            for result in pool.imap_unordered(func, data_iter):
                print(result)
        except BaseException:
            # Includes KeyboardInterrupt: stop the workers instead of waiting
            # for the remaining queued chunks.
            pool.terminate()
            raise
        else:
            pool.close()
        finally:
            pool.join()

    @classmethod
    def do_sequencial(cls, func, data_iter):
        """Apply ``func`` to each item of ``data_iter`` in this process, printing each result."""
        # A plain loop instead of ``itertools.imap``: it stays lazy on both
        # Python 2 and 3 (``imap`` does not exist on Python 3).
        for data in data_iter:
            print(func(data))
