# -*- coding:utf-8 -*-
# author:gm
# mail: zhangguodong@igengmei.com
# datetime:2020/4/24 3:32 下午
# software: PyCharm
from preprocesser.processors import token_processor
from preprocesser.filter import stopwords_filter
from collections import Counter
from config import config
import os
import codecs
import json


class SELECTED_CONTENT_TYPE():
    """Content categories a piece of text can be assigned to.

    Each member is an ``(id, label)`` tuple: the integer id is what
    ``TextClassifical.predict`` returns, the label is the Chinese display name.
    """
    # Medical-aesthetics project content.
    BEAUTY_PROJECT = (1, "医美项目")
    # Medical-aesthetics content centred on stars.
    BEAUTY_STAR = (2, "明星医美")
    # Medical-aesthetics content centred on internet influencers.
    BEAUTY_CELEBRITY = (3, "网红医美")
    # Star gossip (no project words present).
    STAR_GOSSIP = (4, "明星八卦")
    # Internet-influencer gossip (no project words present).
    CELEBRITY_GOSSIP = (5, "网红八卦")


class TextClassifical(object):
    """Dictionary-driven content-type classifier and tag extractor.

    The classifier tokenizes a text, intersects the tokens with three word
    dictionaries (network influencers, projects, stars), and derives a content
    type from the relative weight of each group. It also produces rule-based
    "inference tags" from fixed word-pair templates and weighted
    "info inference tags" from a word -> {tag: weight} table.
    """

    def __init__(self, network_influencer_path, project_path, star_path, synonym_path, tag_info_path,
                 support_words_path, encoding="utf-8"):
        """Load every dictionary from disk.

        :param network_influencer_path: text file, one influencer name per line.
        :param project_path: text file, one project word per line.
        :param star_path: text file, one star name per line.
        :param synonym_path: text file, one comma-separated synonym group per
            line; the first entry is the canonical word.
        :param tag_info_path: JSON file mapping word -> {tag: weight}.
        :param support_words_path: JSON file holding a list of words that are
            allowed to contribute to info inference tags.
        :param encoding: encoding used to read all of the files above.
        """
        self.encoding = encoding
        self.network_influencer_words = self.build_network_influencer_words(network_influencer_path)
        self.project_words = self.build_project_words(project_path)
        self.star_words = self.build_star_words(star_path)
        self.tokenprocessor = token_processor
        self.stopwords_filter = stopwords_filter
        self.project_synonym = self.build_projects_synonym(synonym_path)
        # Each entry is "<word A> <word B> <tag>": when both words occur in a
        # text, the tag is emitted by get_inference_tags().
        self.template = [
            u"自体脂肪 泪沟 自体脂肪填充泪沟",
            u"自体脂肪 黑眼圈 自体脂肪填充黑眼圈",
            u"自体脂肪 隆鼻 自体脂肪隆鼻",
            u"自体脂肪 下巴 自体脂肪丰下巴",
            u"自体脂肪 面颊 自体脂肪丰面颊",
            u"自体脂肪 太阳穴 自体脂肪丰太阳穴",
            u"自体脂肪 额头 自体脂肪丰额头",
            u"自体脂肪 唇 自体脂肪丰唇",
            u"自体脂肪 颈纹 自体脂肪除颈纹",
            u"自体脂肪 法令纹 自体脂肪除法令纹",
            u"自体脂肪 臀 自体脂肪丰臀",
            u"自体脂肪 胸 自体脂肪隆胸",
            u"自体脂肪 卧蚕 自体脂肪填充卧蚕",
            u"自体脂肪 鼻基底 自体脂肪垫鼻基底",
            u"自体脂肪 眉弓 自体脂肪垫眉弓",
            u"自体脂肪 臀 自体脂肪填充臀部",
            u"耳软骨 鼻子 耳软骨隆鼻"
        ]
        self.template_logic = self.build_template()
        self.tag_info = self.build_tag_info_pro(tag_info_path)
        self.support_words = self.build_support_words(support_words_path)

    def _read_json(self, path):
        """Parse the JSON document stored at ``path`` and close the file."""
        with codecs.open(path, "r", encoding=self.encoding) as handle:
            return json.loads(handle.read())

    def _load_word_dict(self, word_path):
        """Read a one-word-per-line file into a ``{word: 1}`` dict.

        Undecodable bytes are ignored and blank lines are skipped so that an
        empty string never becomes a dictionary entry.
        """
        ret = {}
        with codecs.open(word_path, "r", errors="ignore", encoding=self.encoding) as handle:
            for line in handle:
                word = line.strip()
                if word:
                    ret[word] = 1
        return ret

    def build_support_words(self, support_words_path):
        """Return the set of words listed in the JSON file at ``support_words_path``."""
        return set(self._read_json(support_words_path))

    def build_tag_info_pro(self, tag_info_path):
        """Return the parsed JSON content of ``tag_info_path``."""
        return self._read_json(tag_info_path)

    def build_template(self):
        """Compile ``self.template`` into ``(set_of_two_words, tag)`` tuples.

        Entries that do not consist of exactly three space-separated fields
        are dropped.
        """
        ret = []
        for item in self.template:
            infos = item.split(" ")
            if len(infos) == 3:
                ret.append((set(infos[:2]), infos[-1]))
        return ret

    def build_network_influencer_words(self, word_path):
        """Load the network-influencer dictionary as ``{word: 1}``."""
        return self._load_word_dict(word_path)

    def build_project_words(self, project_path):
        """Load the project dictionary as ``{word: 1}``."""
        return self._load_word_dict(project_path)

    def build_star_words(self, star_path):
        """Load the star dictionary as ``{word: 1}``."""
        return self._load_word_dict(star_path)

    def build_projects_synonym(self, project_synonym_path):
        """Build a ``{synonym: canonical_word}`` mapping.

        Every line holds a comma-separated group whose first entry is the
        canonical word; lines with a single entry define no synonym. Empty
        synonym entries (e.g. from a trailing comma) are skipped.
        """
        ret = {}
        with codecs.open(project_synonym_path, "r", errors="ignore", encoding=self.encoding) as handle:
            for line in handle:
                words = [word.strip() for word in line.strip().split(",")]
                if len(words) > 1:
                    core_word = words[0]
                    for word in words[1:]:
                        if word:
                            ret[word] = core_word
        return ret

    def standard_project(self, word):
        """Return the canonical form of ``word``, or ``word`` itself if it has none."""
        return self.project_synonym.get(word, word)

    def get_inference_tags(self, words):
        """Return ``[{tag: 1.0}, ...]`` for every template whose words all occur in ``words``."""
        corpus = set(words)
        return [{tag: 1.0} for candidates, tag in self.template_logic if candidates.issubset(corpus)]

    def get_info_inference_tags(self, words, proba_threshold=0.3, topk=10):
        """Accumulate tag weights contributed by the support words found in ``words``.

        :param words: iterable of tokens.
        :param proba_threshold: only tags whose summed weight is strictly
            greater than this value are kept.
        :param topk: maximum number of tags returned.
        :return: list of ``(tag, weight)`` tuples sorted by descending weight.
        """
        tag_proba = {}
        for word in self.support_words & set(words):
            # A support word without an entry in tag_info contributes nothing
            # instead of raising KeyError.
            for tag, weight in self.tag_info.get(word, {}).items():
                tag_proba[tag] = tag_proba.get(tag, 0) + weight
        kept = [(tag, weight) for tag, weight in tag_proba.items() if weight > proba_threshold]
        return sorted(kept, key=lambda x: x[1], reverse=True)[:topk]

    def run(self, content, proba_threshold=0.3, topk=10):
        """Classify ``content`` and extract its entities and tags.

        :param content: raw text to analyse.
        :param proba_threshold: weight threshold forwarded to
            :meth:`get_info_inference_tags`.
        :param topk: maximum number of info inference tags to return.
        :return: dict with keys ``content_type`` (int, ``-1`` when no
            dictionary word was found), ``star``, ``celebrity``, ``projects``
            (lists of ``{word: share}``), ``inference_tags`` (list of
            ``{tag: 1.0}``) and ``info_inference_tags`` (list of
            ``(tag, weight)``).
        """
        words = self.tokenprocessor.lcut(content, cut_all=True)
        words = self.stopwords_filter.filter(words)
        word_set = set(words)
        # Test membership against the dictionaries directly rather than
        # rebuilding a set from each whole dictionary on every call.
        netword_influencer_concurrence = {word for word in word_set if word in self.network_influencer_words}
        project_word_concurrence = {word for word in word_set if word in self.project_words}
        star_words_concurrence = {word for word in word_set if word in self.star_words}
        # Forward the caller's settings (they used to be silently replaced by
        # the hard-coded defaults).
        info_inference_tags = self.get_info_inference_tags(words, proba_threshold=proba_threshold, topk=topk)
        counter = Counter(words)
        content_type, words_proba = self.predict(counter, netword_influencer_concurrence, project_word_concurrence,
                                                 star_words_concurrence)
        # words_proba is ordered [influencer, project, star].
        return {
            "content_type": content_type,
            "star": [{word: words_proba[2].get(word, 0.0)} for word in star_words_concurrence],
            "celebrity": [{word: words_proba[0].get(word, 0.0)} for word in netword_influencer_concurrence],
            "projects": [{self.standard_project(word): words_proba[1].get(word, 0.0)}
                         for word in project_word_concurrence],
            "inference_tags": list(self.get_inference_tags(words)),
            "info_inference_tags": list(info_inference_tags)
        }

    def score(self, counter, concurrence_words):
        """Placeholder; not implemented and currently returns ``None``."""
        pass

    @staticmethod
    def _weighted_share(counter, words, weight):
        """Return ``(total, {word: share})`` for ``words`` weighted by ``weight``.

        ``share`` is each word's weighted count divided by the weighted total;
        when the total is zero every share is ``0.0``.
        """
        total = sum(counter[word] * weight for word in words)
        if not total:
            return total, {word: 0.0 for word in words}
        return total, {word: float(counter[word] * weight) / total for word in words}

    def predict(self, counter, netword_influencer_concurrence, project_word_concurrence, star_words_concurrence):
        """Derive the content type from the matched dictionary words.

        Influencer and star occurrences count double, project occurrences
        count once.

        :param counter: mapping token -> occurrence count.
        :param netword_influencer_concurrence: matched influencer words.
        :param project_word_concurrence: matched project words.
        :param star_words_concurrence: matched star words.
        :return: ``(content_type, words_proba)`` where ``content_type`` is a
            ``SELECTED_CONTENT_TYPE`` id or ``-1`` if nothing matched, and
            ``words_proba`` is ``[influencer, project, star]`` dicts mapping
            each word to its share inside its own group.
        """
        net_influencer_total, net_influencer_proba = self._weighted_share(
            counter, netword_influencer_concurrence, 2)
        project_words_total, project_words_proba = self._weighted_share(counter, project_word_concurrence, 1)
        star_words_total, star_words_proba = self._weighted_share(counter, star_words_concurrence, 2)
        words_proba = [net_influencer_proba, project_words_proba, star_words_proba]

        total_word = net_influencer_total + project_words_total + star_words_total
        if total_word <= 0:
            return -1, words_proba

        influencer_share = float(net_influencer_total) / total_word
        project_share = float(project_words_total) / total_word
        star_share = float(star_words_total) / total_word

        if project_share <= 0:
            # No project word at all: pure gossip, attributed to whichever
            # group dominates (stars win ties).
            if star_share >= influencer_share:
                return SELECTED_CONTENT_TYPE.STAR_GOSSIP[0], words_proba
            return SELECTED_CONTENT_TYPE.CELEBRITY_GOSSIP[0], words_proba
        if project_share > 0.75:
            return SELECTED_CONTENT_TYPE.BEAUTY_PROJECT[0], words_proba
        if influencer_share > star_share:
            return SELECTED_CONTENT_TYPE.BEAUTY_CELEBRITY[0], words_proba
        return SELECTED_CONTENT_TYPE.BEAUTY_STAR[0], words_proba


# Project root: the directory two levels above the one containing this file.
# Resolved from the absolute path so it does not depend on the current working
# directory or on "/" being the path separator.
root_path = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# Module-level classifier instance built from the dictionary paths in config.
model = TextClassifical(
    os.path.join(root_path, config.network_influcer_dic),
    os.path.join(root_path, config.projects_dic),
    os.path.join(root_path, config.star_dic),
    os.path.join(root_path, config.synonym_path),
    os.path.join(root_path, config.tag_info_path),
    os.path.join(root_path, config.support_words_path))
