#! /usr/bin/env python
# -*- coding: utf-8 -*-
# __author__ = "chenwei"
# Date: 2019/4/4

from gm_dataquery.db import DB
from django.db import transaction
from django.db.models import Q
from gm_dataquery.dict_mixin import to_dict
from gm_dataquery.dataquery import DataBuilder, DataSQLQuery

from category.models import (
    CategoryGroup, CategoryElement, CategoryElementTag,
    CategoryArea, CategoryAreaElement, CategoryPolymer,
    CategoryGadget, CategoryPolymerBrand, CategoryPolymerBanner,
    CategoryPolymerGroup, CategoryPolymerArea, CategoryPolymerTag,
    CategoryPolymerRecommend, CategoryPolymerRecommendRelatedCities,
    Category, CategoryElementTagV3, CategoryPolymerTagV3
)

from api.models.city import City


class CategoryGroupDB(DataBuilder):
    """Data builder for CategoryGroup rows; adds no custom ``getval_*`` fields."""


@DB
class CategoryGroupDQ(DataSQLQuery):
    """Query handler for CategoryGroup.

    Searches on ``start_time`` and ``end_time`` are both delegated to the
    inherited ``_qry_time_range`` helper.
    """

    model = CategoryGroup
    data_model = CategoryGroupDB

    def filter_start_time(self, srch_key, srch_val, regex=False):
        """Build the filter for a time-column search via ``_qry_time_range``."""
        return self._qry_time_range(srch_key, srch_val, regex)

    # The end-time search uses exactly the same range logic as the start-time one.
    filter_end_time = filter_start_time


class CategoryElementDB(DataBuilder):
    """Builds extra output fields for CategoryElement rows."""

    def getval_tags(self, obj):
        """Return the element's tags as a comma-separated ``"tag_id:tag_name"`` string."""
        rows = obj.tags.values('tag_id', 'tag__name')
        pairs = (
            u'{tag_id}:{tag_name}'.format(tag_id=row['tag_id'], tag_name=row['tag__name'])
            for row in rows
        )
        return ','.join(pairs)

    def getval_tag_ids(self, obj):
        """Return the ids of the element's (v1) tags as a list."""
        tag_ids = obj.tags.values_list('tag_id', flat=True)
        return list(tag_ids)

    def getval_tag_v3_ids(self, obj):
        """Return the v3 tag ids linked to the element through CategoryElementTagV3."""
        links = CategoryElementTagV3.objects.filter(element_id=obj.id)
        return list(links.values_list("tag_v3_id", flat=True))


@DB
class CategoryElementDQ(DataSQLQuery):
    """Query handler for CategoryElement; keeps its v1 and v3 tag links in sync on write."""

    model = CategoryElement
    data_model = CategoryElementDB

    def update_tags(self, obj, tags):
        """Make ``obj``'s CategoryElementTag links match ``tags`` exactly.

        Links whose tag id is not in ``tags`` are deleted and missing ones are
        created; existing matching links are left untouched. A non-list ``tags``
        is ignored.
        """
        if not isinstance(tags, list):
            return
        old_tags = list(CategoryElementTag.objects.filter(element=obj).values_list('tag_id', flat=True))
        CategoryElementTag.objects.filter(tag_id__in=set(old_tags) - set(tags), element=obj).delete()
        for tag_id in set(tags) - set(old_tags):
            CategoryElementTag.objects.get_or_create(
                element=obj, tag_id=tag_id
            )

    def update_tags_v3(self, element_id, tags):
        """Make the CategoryElementTagV3 links of ``element_id`` match ``tags`` exactly.

        Same semantics as ``update_tags`` but keyed by the element id. A non-list
        ``tags`` is ignored.
        """
        if not isinstance(tags, list):
            return
        old_tags = list(CategoryElementTagV3.objects.filter(element_id=element_id).values_list('tag_v3_id', flat=True))

        CategoryElementTagV3.objects.filter(tag_v3_id__in=set(old_tags) - set(tags), element_id=element_id).delete()

        for tag_id in set(tags) - set(old_tags):
            CategoryElementTagV3.objects.get_or_create(
                element_id=element_id, tag_v3_id=tag_id
            )

    def create(self, **kwargs):
        """Create an element plus its tag links in one transaction.

        ``tag_ids`` and ``tag_v3_ids`` are popped from ``kwargs``; the rest are
        passed to the model constructor. Returns ``{'id': <new id>}``.
        """
        tag_ids = kwargs.pop('tag_ids', [])
        tag_v3_ids = kwargs.pop("tag_v3_ids", [])

        # Atomic so a failing link write does not leave a half-created element.
        with transaction.atomic():
            obj = self.model.objects.create(**kwargs)
            self.update_tags(obj, tag_ids)
            self.update_tags_v3(obj.id, tag_v3_ids)

        return {'id': obj.id}

    def update(self, updates, **kwargs):
        """Update element ``kwargs['id']`` and, when present, its tag links.

        ``tag_ids`` / ``tag_v3_ids`` are popped from ``updates`` and synced only
        if their key is present; remaining keys are set on the instance.
        Raises ``DoesNotExist`` (before any write) if the element is missing.
        """
        tag_ids = updates.pop('tag_ids', None)
        tag_v3_ids = updates.pop("tag_v3_ids", None)

        with transaction.atomic():
            # Fetch first: previously the v3 tag links were written before this
            # lookup, so an unknown id still produced orphan link rows.
            obj = self.model.objects.get(id=kwargs['id'])
            for k, v in updates.items():
                setattr(obj, k, v)
            obj.save()

            if tag_ids is not None:
                self.update_tags(obj, tag_ids)
            if tag_v3_ids is not None:
                self.update_tags_v3(obj.id, tag_v3_ids)

        return {'id': obj.id}


class CategoryAreaDB(DataBuilder):
    """Builds extra output fields for CategoryArea rows."""

    @staticmethod
    def _element_ids(obj, is_main_category):
        """Return ids of the area's linked elements whose ``is_main_category`` equals the flag."""
        linked = obj.elements.filter(element__is_main_category=is_main_category)
        return list(linked.values_list('element_id', flat=True))

    def getval_category_ids(self, obj):
        """Ids of linked elements that are not main categories."""
        return self._element_ids(obj, False)

    def getval_c_category_ids(self, obj):
        """Ids of linked elements that are main categories."""
        return self._element_ids(obj, True)

    def getval_cc_category_ids(self, obj):
        """Ids of linked elements that are main categories.

        NOTE(review): identical to ``getval_c_category_ids``; confirm whether a
        different filter was intended.
        """
        return self._element_ids(obj, True)

    def getval_categorys(self, obj):
        """Return all linked elements as a comma-separated ``"element_id:element_name"`` string."""
        rows = obj.elements.values('element_id', 'element__name')
        pairs = (
            u'{element_id}:{element_name}'.format(
                element_id=row['element_id'], element_name=row['element__name']
            )
            for row in rows
        )
        return ','.join(pairs)


@DB
class CategoryAreaDQ(DataSQLQuery):
    """Query handler for CategoryArea; keeps the area's element links in sync on write."""

    model = CategoryArea
    data_model = CategoryAreaDB

    @staticmethod
    def _as_id_list(value):
        """Return ``value`` as a new flat list of ids.

        Lists/tuples are copied, a falsy value yields ``[]``, and any other single
        value is wrapped in a one-element list.
        """
        if not value:
            return []
        if isinstance(value, (list, tuple)):
            return list(value)
        return [value]

    def update_category(self, obj, category_ids):
        """Make ``obj``'s CategoryAreaElement links match ``category_ids`` exactly.

        Links not in ``category_ids`` are deleted; missing ones are created in the
        order they appear in ``category_ids``. A non-list argument is ignored.
        """
        if not isinstance(category_ids, list):
            return

        old_category = list(CategoryAreaElement.objects.filter(area=obj).values_list('element_id', flat=True))
        CategoryAreaElement.objects.filter(element_id__in=set(old_category) - set(category_ids), area=obj).delete()

        c_list = list(set(category_ids) - set(old_category))
        c_list.sort(key=category_ids.index)

        for element_id in c_list:
            CategoryAreaElement.objects.get_or_create(
                area=obj, element_id=element_id
            )

    def create(self, **kwargs):
        """Create an area and link it to its elements in one transaction.

        ``category_ids``, ``c_category_ids`` and ``cc_category_ids`` are popped
        from ``kwargs`` and merged into one element-id list; each may be a list
        of ids or a single id. Returns ``{'id': <new id>}``.
        """
        # Build a fresh flat list: extending with ``[c_category_ids]`` would nest
        # a list inside the ids and make ``set()`` in update_category raise.
        category_ids = self._as_id_list(kwargs.pop('category_ids', []))
        for extra_key in ('c_category_ids', 'cc_category_ids'):
            category_ids.extend(self._as_id_list(kwargs.pop(extra_key, [])))

        with transaction.atomic():
            obj = self.model.objects.create(**kwargs)
            self.update_category(obj, category_ids)
        return {'id': obj.id}

    def update(self, updates, **kwargs):
        """Update area ``kwargs['id']``; resync element links when ``category_ids`` is given.

        NOTE(review): unlike ``create``, ``c_category_ids``/``cc_category_ids`` are
        not merged here; confirm whether callers send them on update.
        """
        category_ids = updates.pop('category_ids', None)

        with transaction.atomic():
            obj = self.model.objects.get(id=kwargs['id'])
            for k, v in updates.items():
                setattr(obj, k, v)
            obj.save()

            if category_ids is not None:
                self.update_category(obj, category_ids)

        return {'id': obj.id}


class CategoryPolymerDB(DataBuilder):
    """Builds the linked-id list fields of CategoryPolymer rows."""

    @staticmethod
    def _related_ids(manager, column):
        """Return ``column`` of every row in the related ``manager`` as a list."""
        return list(manager.values_list(column, flat=True))

    def getval_banner_ids(self, obj):
        return self._related_ids(obj.banners, 'banner_id')

    def getval_gadget_ids(self, obj):
        return self._related_ids(obj.gadgets, 'gadget_id')

    def getval_area_ids(self, obj):
        return self._related_ids(obj.areas, 'area_id')

    def getval_group_ids(self, obj):
        return self._related_ids(obj.groups, 'group_id')

    def getval_brand_ids(self, obj):
        return self._related_ids(obj.brands, 'brand_id')

    def getval_tag_ids(self, obj):
        return self._related_ids(obj.tags, 'tag_id')

    def getval_tag_v3_ids(self, obj):
        """v3 tags are read from CategoryPolymerTagV3 by category id."""
        v3_links = CategoryPolymerTagV3.objects.filter(category_id=obj.id)
        return self._related_ids(v3_links, "tag_v3_id")


@DB
class CategoryPolymerDQ(DataSQLQuery):
    """Query handler for CategoryPolymer.

    Besides the model's own columns, ``create``/``update`` accept id lists for
    the polymer's banners, gadgets, areas, groups, brands, tags and v3 tags, and
    sync the corresponding link tables so they match those lists exactly.
    """

    model = CategoryPolymer
    data_model = CategoryPolymerDB

    @staticmethod
    def _sync_links(link_model, owner, id_field, new_ids, keep_order=False):
        """Make the ``link_model`` rows selected by ``owner`` hold exactly ``new_ids``.

        :param link_model: link-table model class, e.g. ``CategoryPolymerBanner``.
        :param owner: filter kwargs selecting the parent's rows, e.g. ``{'category': obj}``.
        :param id_field: column on ``link_model`` that holds the linked id.
        :param new_ids: desired ids; anything other than a ``list`` is ignored (no-op).
        :param keep_order: create missing rows in their ``new_ids`` order instead
            of arbitrary set order.

        Rows whose id is not in ``new_ids`` are deleted, missing ones are created,
        and rows already present are left untouched.
        """
        if not isinstance(new_ids, list):
            return
        wanted = set(new_ids)
        existing = set(link_model.objects.filter(**owner).values_list(id_field, flat=True))

        stale = dict(owner, **{id_field + '__in': existing - wanted})
        link_model.objects.filter(**stale).delete()

        missing = list(wanted - existing)
        if keep_order:
            missing.sort(key=new_ids.index)
        for linked_id in missing:
            link_model.objects.get_or_create(**dict(owner, **{id_field: linked_id}))

    def update_banner(self, obj, banner_ids):
        """Sync ``obj``'s CategoryPolymerBanner links to ``banner_ids``."""
        self._sync_links(CategoryPolymerBanner, {'category': obj}, 'banner_id', banner_ids)

    def update_gadget(self, obj, gadget_ids):
        """Sync ``obj``'s CategoryGadget links to ``gadget_ids``."""
        self._sync_links(CategoryGadget, {'category': obj}, 'gadget_id', gadget_ids)

    def update_area(self, obj, area_ids):
        """Sync ``obj``'s CategoryPolymerArea links to ``area_ids``.

        New links are created in request order to keep the areas' absolute positions.
        """
        self._sync_links(CategoryPolymerArea, {'category': obj}, 'area_id', area_ids, keep_order=True)

    def update_group(self, obj, group_ids):
        """Sync ``obj``'s CategoryPolymerGroup links to ``group_ids``."""
        self._sync_links(CategoryPolymerGroup, {'category': obj}, 'group_id', group_ids)

    def update_brand(self, obj, brand_ids):
        """Sync ``obj``'s CategoryPolymerBrand links to ``brand_ids``."""
        self._sync_links(CategoryPolymerBrand, {'category': obj}, 'brand_id', brand_ids)

    def update_tags(self, obj, tag_ids):
        """Sync ``obj``'s CategoryPolymerTag links to ``tag_ids``."""
        self._sync_links(CategoryPolymerTag, {'category': obj}, 'tag_id', tag_ids)

    def update_tag_v3(self, category_id, tag_ids):
        """Sync the CategoryPolymerTagV3 links of ``category_id`` to ``tag_ids``."""
        self._sync_links(CategoryPolymerTagV3, {'category_id': category_id}, 'tag_v3_id', tag_ids)

    def create(self, **kwargs):
        """Create a polymer and all its links in one transaction; returns ``{'id': <new id>}``."""
        banner_ids = kwargs.pop('banner_ids', [])
        gadget_ids = kwargs.pop('gadget_ids', [])
        area_ids = kwargs.pop('area_ids', [])
        group_ids = kwargs.pop('group_ids', [])
        brand_ids = kwargs.pop('brand_ids', [])
        tag_ids = kwargs.pop('tag_ids', [])
        tag_v3_ids = kwargs.pop('tag_v3_ids', [])

        # A missing or falsy level-one id (e.g. '' from a form) is stored as NULL.
        if not kwargs.get('operation_level_one_id'):
            kwargs["operation_level_one_id"] = None

        with transaction.atomic():
            obj = self.model.objects.create(**kwargs)
            self.update_banner(obj, banner_ids)
            self.update_group(obj, group_ids)
            self.update_brand(obj, brand_ids)
            self.update_area(obj, area_ids)
            self.update_gadget(obj, gadget_ids)
            self.update_tags(obj, tag_ids)
            self.update_tag_v3(obj.id, tag_v3_ids)

        return {'id': obj.id}

    def filter_start_time(self, srch_key, srch_val, regex=False):
        return self._qry_time_range(srch_key, srch_val, regex)

    def filter_end_time(self, srch_key, srch_val, regex=False):
        return self._qry_time_range(srch_key, srch_val, regex)

    def update(self, updates, **kwargs):
        """Update polymer ``kwargs['id']`` and any link lists present in ``updates``.

        Link lists (``banner_ids``, ``gadget_ids``, ...) are popped from ``updates``
        and synced only when their key is present; remaining keys are set on the
        instance. All writes happen in one transaction.
        """
        relation_ids = {}
        for key in ('banner_ids', 'gadget_ids', 'area_ids', 'group_ids',
                    'brand_ids', 'tag_ids', 'tag_v3_ids'):
            if key in updates:
                relation_ids[key] = updates.pop(key)

        # Normalize a falsy level-one id to NULL only when the caller sent it;
        # doing it unconditionally wiped the field on every partial update.
        if 'operation_level_one_id' in updates and not updates['operation_level_one_id']:
            updates['operation_level_one_id'] = None

        with transaction.atomic():
            obj = self.model.objects.get(id=kwargs['id'])
            for k, v in updates.items():
                setattr(obj, k, v)
            obj.save()

            # Absent keys yield None, which the update_* helpers ignore (not a list).
            self.update_gadget(obj, relation_ids.get('gadget_ids'))
            self.update_banner(obj, relation_ids.get('banner_ids'))
            self.update_area(obj, relation_ids.get('area_ids'))
            self.update_group(obj, relation_ids.get('group_ids'))
            self.update_brand(obj, relation_ids.get('brand_ids'))
            self.update_tags(obj, relation_ids.get('tag_ids'))
            self.update_tag_v3(obj.id, relation_ids.get('tag_v3_ids'))

        return {'id': obj.id}


class CategoryPolymerRecommendDB(DataBuilder):
    """Builds extra output fields for CategoryPolymerRecommend rows."""

    def getval_cities_for_list(self, obj):
        """Delegate to the model's ``get_related_cities``."""
        return obj.get_related_cities()

    def getval_cities_for_detail(self, obj):
        """Return the ids of the recommendation's related cities."""
        city_ids = obj.category_recommend_cities.values_list('city_id', flat=True)
        return list(city_ids)

    def getval_mix_id(self, obj):
        """Return ``"<polymer id>:<polymer name>"`` for the linked CategoryPolymer."""
        polymer = obj.categorypolymer
        return u'{id}:{word}'.format(id=polymer.id, word=polymer.name)


@DB
class CategoryPolymerRecommendDQ(DataSQLQuery):
    """Query handler for CategoryPolymerRecommend, including its related-city links."""

    model = CategoryPolymerRecommend
    data_model = CategoryPolymerRecommendDB

    def filter_cities_for_list(self, srch_key, srch_val, regex=False):
        """Match recommendations linked to city ``srch_val`` or not city-restricted at all."""
        return Q(
            id__in=CategoryPolymerRecommendRelatedCities.objects.filter(
                city_id=srch_val).values_list("categorypolymerrecommend_id", flat=True)
        ) | Q(is_related_city=False)

    def filter_start_time(self, srch_key, srch_val, regex=False):
        return self._qry_time_range(srch_key, srch_val, regex)

    def filter_end_time(self, srch_key, srch_val, regex=False):
        return self._qry_time_range(srch_key, srch_val, regex)

    def update(self, updates, **kwargs):
        """Update recommendation ``kwargs['id']`` and its city links.

        ``updates['cities_for_detail']`` selects the mode passed to
        ``obj.add_related_cities``:

        * ``[]``                 -> mode 2, ``is_related_city = False`` (no related cities)
        * any other truthy value -> mode 3, ``is_related_city = True``
        * absent or other falsy  -> mode 1 (cities not being updated)
        """
        update_city = updates.get('cities_for_detail', '')
        if isinstance(update_city, list) and not update_city:
            is_related_city = 2
            updates['is_related_city'] = False
        elif update_city:
            is_related_city = 3
            updates['is_related_city'] = True
        else:
            # Also covers None and other falsy values, which previously left
            # ``is_related_city`` unbound and raised UnboundLocalError.
            is_related_city = 1

        with transaction.atomic():
            obj = self.model.objects.get(id=kwargs['id'])
            # The is_related_city flag is decided above, before saving; it used to
            # be added to ``updates`` only after save() and was never applied here.
            for k, v in updates.items():
                setattr(obj, k, v)
            obj.save()
            obj.add_related_cities(update_city, is_related_city)

        return {'id': kwargs['id']}

    def create(self, **kwargs):
        """Create a recommendation; link ``cities_for_detail`` and flag it as city-restricted if given."""
        cities = kwargs.pop("cities_for_detail", "")
        with transaction.atomic():
            polymer_obj = self.model.objects.create(**kwargs)
            if cities:
                polymer_obj.add_related_cities(cities)
                polymer_obj.is_related_city = True
                polymer_obj.save()
        return to_dict(polymer_obj)


@DB
class CategoryPolymerAreaDQ(DataSQLQuery):
    """Plain query handler for CategoryPolymerArea rows (no custom data builder or hooks)."""

    model = CategoryPolymerArea



class CategoryDB(DataBuilder):
    """Builds extra output fields for Category rows."""

    def getval_tag_ids(self, obj):
        """Return the ids stored in the comma-joined ``obj.tag_ids`` column as a list of strings.

        Empty segments are skipped, so an empty or NULL column yields ``[]``
        (plain ``''.split(',')`` would return ``['']``).
        """
        raw = obj.tag_ids or ''
        return [tag_id for tag_id in raw.split(',') if tag_id]


@DB
class CategoryDQ(DataSQLQuery):
    """Query handler for Category; ``tag_ids`` is stored as a comma-joined string."""

    model = Category
    data_model = CategoryDB

    @staticmethod
    def _join_tag_ids(tag_ids):
        """Serialize a list of tag ids (str or int) into the comma-joined column value.

        ``None`` or an empty list yields ``''``.
        """
        # u'{0}'.format() accepts ints as well as (unicode) strings; a plain join
        # raised TypeError on integer ids.
        return ','.join(u'{0}'.format(tag_id) for tag_id in (tag_ids or []))

    def create(self, **kwargs):
        """Create a Category, serializing ``tag_ids``; returns the new row as a dict."""
        kwargs['tag_ids'] = self._join_tag_ids(kwargs.pop('tag_ids', []))
        return to_dict(self.model.objects.create(**kwargs))

    def update(self, updates, **kwargs):
        """Apply ``updates`` to every Category matching ``kwargs``; return the count updated.

        When ``tag_ids`` is present it is serialized. An explicit empty list now
        clears the tags instead of being silently ignored.
        """
        if 'tag_ids' in updates:
            updates['tag_ids'] = self._join_tag_ids(updates['tag_ids'])

        count = 0
        # Lock the matched rows for the duration of the update on the model's database.
        with transaction.atomic(self.model.objects.db):
            for obj in self.model.objects.select_for_update().filter(**kwargs):
                count += 1
                for k, v in updates.items():
                    setattr(obj, k, v)
                obj.save()
        return count
