import threading
from itertools import imap
from multiprocessing import cpu_count
from threading import Thread, Lock

from django.conf import settings
from django.core.management import BaseCommand
from elasticsearch import helpers

from api.models.message import Message
from message.utils.es import load_mapping
from message.utils.es_abstract import get_migrate_esop, get_esop, table_message

# Serializes console output so lines printed by concurrent worker threads
# do not interleave.
console_lock = threading.Lock()


def get_message_bulk_action(msg, index='gm_msg-message', doc_type='message'):
    """Build one ``elasticsearch.helpers.bulk`` index action for a message.

    :param msg: message document; its ``'id'`` value becomes the document
        ``_id`` and the whole dict is sent as ``_source``.
    :param index: target index name (defaults to the previously hard-coded
        ``'gm_msg-message'``).
    :param doc_type: value for ``_type`` (defaults to ``'message'``).
    :return: action dict suitable for ``helpers.bulk``.
    :raises KeyError: if ``msg`` has no ``'id'`` key.
    """
    return {
        '_index': index,
        '_type': doc_type,
        '_id': msg['id'],
        '_source': msg,
    }


def read_msg_from_es(msg_ids):
    """Fetch message documents by id via ``get_esop().mget``.

    :param msg_ids: collection of message ids; each is converted with ``str``.
        An empty/falsy value short-circuits without touching Elasticsearch.
    :return: list of ``_source`` dicts for documents that were found and whose
        source contains an ``'id'`` field; everything else is skipped.
    """
    if not msg_ids:
        return []
    resp = get_esop().mget(
        table=table_message,
        body={'ids': [str(i) for i in msg_ids]},
    )
    # Use .get(): per the ES multi-get API, an entry for a failed lookup may
    # carry an 'error' key and no 'found' key, and a found document may lack
    # '_source' (e.g. source disabled) -- indexing would raise KeyError.
    return [
        doc['_source']
        for doc in resp['docs']
        if doc.get('found') and 'id' in doc.get('_source', {})
    ]


def worker(start_id, stop_id, chunk_size=100):
    """Copy messages with ids in ``[start_id, stop_id]`` (both inclusive).

    Ids are read in chunks of ``chunk_size`` with ``read_msg_from_es`` and
    written with ``helpers.bulk`` to ``get_migrate_esop().client``. Ids that
    ``read_msg_from_es`` does not return are simply not written.

    :param start_id: first message id to copy (inclusive).
    :param stop_id: last message id to copy (inclusive).
    :param chunk_size: ids requested per mget/bulk round trip; must be >= 1.
    :raises ValueError: if ``start_id > stop_id`` or ``chunk_size < 1``.
    """
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if start_id > stop_id:
        raise ValueError(
            'start_id {0} is greater than stop_id {1}'.format(start_id, stop_id))
    if chunk_size < 1:
        raise ValueError('chunk_size must be >= 1, got {0}'.format(chunk_size))

    thread_name = threading.current_thread().name
    with console_lock:
        print("[thread-{0}]messages between {1} and {2} are starting".format(
            thread_name, start_id, stop_id))

    # The client does not change between chunks; fetch it once.
    client = get_migrate_esop().client
    # `stop_id + 1` because range() excludes its end while stop_id is
    # inclusive; with `stop_id` alone the last id was skipped whenever
    # (stop_id - start_id) % chunk_size == 0 (including start_id == stop_id).
    for left in range(start_id, stop_id + 1, chunk_size):
        right = min(stop_id, left + chunk_size - 1)

        # read from older and write to new one.
        with console_lock:
            print("range: {}-{}".format(left, right))
        msgs = read_msg_from_es(range(left, right + 1))
        if msgs:
            helpers.bulk(client, (get_message_bulk_action(m) for m in msgs))

    with console_lock:
        print("[thread-{0}]messages between {1} and {2} are done".format(
            thread_name, start_id, stop_id))


class Command(BaseCommand):
    """Copy message documents, by id range, using several worker threads.

    Documents are read with ``get_esop()`` and written with
    ``get_migrate_esop()`` (see ``worker``); the id range is split into at
    most ``-c`` contiguous slices, one thread per slice.
    """

    help = "Copy message documents in an id range using several threads."

    @staticmethod
    def setup():
        """
        Create the index 'gm_msg_test-message' with the 'message.v2' mapping.

        An existing index with that name is deleted first, so its data is lost.
        NOTE(review): ``handle`` does not call this, and the bulk actions built
        by ``get_message_bulk_action`` target 'gm_msg-message' by default, not
        this index -- confirm which index is intended.
        """
        index = 'gm_msg_test-message'
        setting = {
            "settings": {
                "number_of_shards": 8,
                "number_of_replicas": 1
            },
            "mappings": load_mapping('message.v2')
        }
        es = get_migrate_esop().client
        if es.indices.exists(index):
            es.indices.delete(index)

        es.indices.create(index=index, body=setting)

    def add_arguments(self, parser):
        parser.add_argument('-c', help="number of threads (defaults to the number of CPU cores)", type=int)
        parser.add_argument('-s', help="sync start message id", type=int)
        parser.add_argument('-t', help="sync stop message id", type=int)

    def handle(self, *args, **options):
        """Split ``[start, stop]`` across threads and wait for all of them.

        Options (a missing or zero value falls back to the default):
            c: number of worker threads (default: ``cpu_count()``).
            s: first message id, inclusive (default: 1).
            t: last message id, inclusive (default: the largest ``id`` in the
               ``settings.SLAVE_DB_NAME`` database).

        Raises:
            CommandError: if no messages exist and ``-t`` is not given, if the
                thread count is not positive, or if start > stop.
        """
        # Local import so this block does not depend on the file's imports.
        from django.core.management import CommandError

        num = options.get('c') or cpu_count()
        start = options.get('s') or 1
        stop = options.get('t')
        if not stop:
            # Only query the database when no explicit stop id was given.
            # Order by id explicitly: .last() honours Meta.ordering when the
            # model defines one, which need not yield the largest id.
            stop = (
                Message.objects.using(settings.SLAVE_DB_NAME)
                .order_by('-id')
                .values_list('id', flat=True)
                .first()
            )
            if stop is None:
                raise CommandError('no messages found; pass -t to set the stop id')

        # CommandError instead of `assert`: user input must be validated even
        # under `python -O`, and Django reports CommandError cleanly.
        if num < 1:
            raise CommandError('thread count (-c) must be positive, got {0}'.format(num))
        if start > stop:
            raise CommandError('start id {0} is greater than stop id {1}'.format(start, stop))

        print('max_message_id: {0}, thread_num: {1}'.format(stop, num))
        # Ceiling division so at most `num` slices cover the whole range
        # (floor division could spawn more threads than requested).
        step = (stop - start + 1 + num - 1) // num

        threads = []
        # `stop + 1`: range() excludes its end, but `stop` must be copied too.
        for idx, left in enumerate(range(start, stop + 1, step)):
            right = min(stop, left + step - 1)
            t = Thread(target=worker, args=(left, right), name=str(idx + 1))
            threads.append(t)
            t.start()

        for t in threads:
            t.join()
