# coding=utf-8
from __future__ import absolute_import, unicode_literals


from django.conf import settings
from django.core.exceptions import ObjectDoesNotExist
from api.models import JpushBindUser, Tag
from gm_types.gaia import TAG_TYPE
from django.db import close_old_connections
from django.core.management import BaseCommand
import time
import jpush
import sys
import os
import multiprocessing
from jpush.common import JPushFailure
from optparse import make_option


class Job(object):
    """Callable worker that adds a city tag to JPush devices lacking one.

    Calling an instance walks ``JpushBindUser`` rows with primary keys in
    ``[start_pk, end_pk)``: starting at ``start_pk`` it handles a window of
    ``chunk_size`` ids, then jumps ahead by ``step_size`` ids, and so on.
    Each window is padded with a sleep so that it takes at least
    ``chunk_min_duration`` seconds.
    """

    # Client shared by all instances; built once when the class is defined.
    _jpush = jpush.JPush(settings.USER_JPUSH_APP_KEY, settings.USER_JPUSH_MASTER_SECRET)

    def __init__(self, start_pk, end_pk, chunk_size, step_size, chunk_min_duration):
        """Store the id range and pacing parameters.

        :param start_pk: first primary key (inclusive) of the first window.
        :param end_pk: primary key (exclusive) at which processing stops.
        :param chunk_size: number of consecutive ids in one window.
        :param step_size: distance between the starts of two windows.
        :param chunk_min_duration: minimum seconds to spend per window.
        """
        self.start_pk = start_pk
        self.end_pk = end_pk
        self.chunk_size = chunk_size
        self.step_size = step_size
        self.chunk_min_duration = chunk_min_duration

    def build_jpush_tag(self, tag_id):
        """Return the JPush tag name used for ``tag_id`` ('_' + id)."""
        return '_'+str(tag_id)

    def parse_jpush_tag(self, jpush_tag):
        """Return the integer id encoded in ``jpush_tag``, or ``None``.

        ``None`` is returned both for tags that do not start with '_' and
        for tags whose remainder is not an integer, so that one unrelated
        tag on a device does not make the whole binding fail.
        """
        if not jpush_tag.startswith('_'):
            return None
        try:
            return int(jpush_tag[1:])
        except ValueError:
            return None

    def _has_city_tag(self, device_info):
        """Tell whether the device already carries a tag of type CITY."""
        if 'tags' not in device_info or not device_info['tags']:
            return False
        parsed = (self.parse_jpush_tag(t) for t in device_info['tags'])
        tag_ids = [tag_id for tag_id in parsed if tag_id is not None]
        if not tag_ids:
            return False
        # Only existence matters, so let the database answer it directly.
        return Tag.objects.filter(pk__in=tag_ids, tag_type=TAG_TYPE.CITY).exists()

    def _city_tag_id(self, binding):
        """Return the id of the tag of the bound user's city, or ``None``.

        Any missing related object on the way (including the user itself)
        is treated as "no city known" instead of aborting the worker.
        """
        try:
            user = binding.user
            if hasattr(user, 'userextra') and user.userextra:
                if user.userextra.city is not None:
                    return user.userextra.city.tag.id
        except ObjectDoesNotExist:
            pass
        return None

    def _sync_binding(self, binding):
        """Add the user's city tag to one device unless it has a city tag.

        Failures are reported on stderr and never propagate, so a single
        bad binding cannot stop the chunk.
        """
        registration_id = binding.registration_id
        try:
            device = self._jpush.create_device()
            device_info = device.get_deviceinfo(registration_id=registration_id).payload
            if self._has_city_tag(device_info):
                return

            # The user is only looked up once we know a tag may be needed.
            city_tag_id = self._city_tag_id(binding)
            if city_tag_id is None:
                return

            entity = jpush.device_tag(jpush.add(self.build_jpush_tag(city_tag_id)))
            device.set_deviceinfo(registration_id=registration_id,
                                  entity=entity)
        except JPushFailure as e:
            sys.stderr.write('binding for registration_id[%s] failed: %s\n' % (registration_id, e.details))
        except Exception as e:
            sys.stderr.write('binding for registration_id[%s] failed: %s\n' % (registration_id, repr(e)))

    def __call__(self):
        """Process every window assigned to this worker, with pacing."""
        close_old_connections()  # close all conn forked from parent process
        for pk_st in xrange(self.start_pk, self.end_pk, self.step_size):
            pk_ed = min(self.end_pk, pk_st+self.chunk_size)
            bindings = JpushBindUser.objects.filter(pk__gte=pk_st, pk__lt=pk_ed)

            start_time = time.time()

            for binding in bindings:
                self._sync_binding(binding)

            print('process[%d] binding id [%d, %d) done' % (os.getpid(), pk_st, pk_ed))
            sys.stdout.flush()

            # Sleep off the rest of the window's time budget, if any.
            duration = time.time()-start_time
            if duration < self.chunk_min_duration:
                time.sleep(self.chunk_min_duration-duration)


class Command(BaseCommand):
    """This command should be run ONCE AND ONLY ONCE

    Starts ``PROCESS_NUM`` worker processes, each running a ``Job`` over an
    interleaved share of the ``JpushBindUser`` ids from ``--start-pk`` up to
    the latest id, and waits for all of them to finish.
    """
    FREQ_LIMIT = 30  # query limit 30hz
    CHUNK_SIZE = 100  # bindings handled by one worker per window
    PROCESS_NUM = 2  # number of worker processes
    # Minimum seconds per window so that all workers together handle at most
    # FREQ_LIMIT bindings per second.
    # NOTE(review): Job may issue two JPush calls per binding (get_deviceinfo
    # and set_deviceinfo); confirm whether the limit counts requests.
    DURATION_LIMIT = float(PROCESS_NUM*CHUNK_SIZE)/FREQ_LIMIT

    option_list = BaseCommand.option_list+(
        make_option('-s', '--start-pk', dest='start_pk', type='int', default=1),
    )

    def handle(self, *args, **options):
        """Fan the id range out to worker processes and wait for them."""
        global_start_pk = options['start_pk']
        try:
            global_end_pk = JpushBindUser.objects.latest('id').id+1
        except ObjectDoesNotExist:
            # Empty table: latest() raises instead of returning None.
            sys.stderr.write('no JpushBindUser rows found, nothing to do\n')
            return

        p_list = []
        for process_id in range(self.PROCESS_NUM):
            # Worker k starts k windows in and skips the other workers' windows.
            job = Job(
                start_pk=global_start_pk+self.CHUNK_SIZE*process_id,
                end_pk=global_end_pk,
                chunk_size=self.CHUNK_SIZE,
                step_size=self.CHUNK_SIZE*self.PROCESS_NUM,
                chunk_min_duration=self.DURATION_LIMIT)
            p = multiprocessing.Process(target=job)
            p.start()
            p_list.append(p)

        for p in p_list:
            p.join()
            if p.exitcode != 0:
                # Make a crashed worker visible so its range can be re-run.
                sys.stderr.write('process[%s] exited with code %s\n' % (p.pid, p.exitcode))
