# coding=utf-8

import xml.etree.ElementTree as ET
import traceback
import re
from api.tool.log_tool import logging_exception
import logging
import redis
import os
import sys
from django.conf import settings


class DataPoint:
    """One sample parsed from a text line of the form
    ``<label> <name>:<id> <fid>:<value> ... # <description>``.

    Attributes:
        fvalsDict: feature id (int) -> feature value (float).
        description: trailing comment of the parsed line, including the "#".
        label: relevance label as float; -1 until a line has been parsed.
        id: integer taken after the ":" of the second token; "" until parsed.
    """

    def __init__(self):
        # feature id (int) -> feature value (float)
        self.fvalsDict = dict()

        # trailing comment of the parsed line; begins with "#"
        self.description = ""

        self.label = -1
        self.id = ""

    def getFeatureValue(self, fid):
        """Return the value stored for feature ``fid``.

        Returns 0.0 when the feature is absent, when ``fid`` is not positive
        (an error is logged in that case), or when the lookup fails.
        """
        try:
            if fid <= 0:
                logging.error("error fid val:%f" % fid)
                return 0.0

            # Look the id up directly instead of comparing it with the number
            # of stored features: with sparse input the ids are not contiguous,
            # so a valid id can be larger than len(self.fvalsDict).
            return self.fvalsDict.get(fid, 0.0)
        except Exception:
            logging_exception()
            return 0.0

    def parse(self, content):
        """Fill this object from one text line.

        Everything from the first "#" onwards is stored as ``description``.
        The remaining tokens are the label, the id token and ``fid:value``
        pairs.

        Returns:
            True on success; False when the label or a feature id is negative
            or when the line cannot be parsed (the error is logged).
        """
        try:
            idx = content.find("#")
            if idx != -1:
                self.description = content[idx:]
                content = content[:idx]

            # split() without arguments collapses runs of whitespace, so
            # doubled spaces or tabs cannot shift the token positions.
            itemList = content.split()

            self.label = float(itemList[0])
            if self.label < 0:
                logging.error("Relevance label cannot be negative. System will now exit!")
                return False

            self.id = int(itemList[1].split(":")[1])

            for item in itemList[2:]:
                keyList = item.split(":")
                key = int(keyList[0])
                val = float(keyList[1])

                if key < 0:
                    logging.error("key val error!")
                    return False
                self.fvalsDict[key] = val

            return True
        except Exception:
            logging_exception()
            return False

    def getId(self):
        """Return the id parsed from the second token of the line."""
        return self.id

class Split:
    """Node of a binary regression tree.

    A node whose ``featureId`` is -1 is a leaf and carries its result in
    ``avgLabel``. Any other node compares one feature value against
    ``threshold`` and delegates to ``leftObj`` or ``rightObj``.
    """

    def __init__(self,output=0.0,featureId=-1,threshold=0.0,deviance=0.0):
        # value returned when evaluation ends on this node
        self.avgLabel = output

        # decision parameters; featureId == -1 marks a leaf
        self.featureId = featureId
        self.threshold = threshold
        self.deviance = deviance

        # bookkeeping fields; only initialised here
        self.isRoot = False
        self.sumLabel = 0.0
        self.sqSumLabel = 0.0

        # child nodes (Split instances), attached via setLeft / setRight
        self.leftObj = None
        self.rightObj = None

    def setLeft(self,s):
        """Attach ``s`` as the left child."""
        self.leftObj = s

    def setRight(self,s):
        """Attach ``s`` as the right child."""
        self.rightObj = s

    def leaves(self):
        """Return the leaves below this node, left to right, as a new list."""
        collected = []
        self.leavesOne(collected)
        return collected

    def leavesOne(self,leavesList):
        """Append every leaf of this subtree to ``leavesList`` (left first)."""
        if self.featureId == -1:
            leavesList.append(self)
            return

        for child in (self.leftObj, self.rightObj):
            child.leavesOne(leavesList)

    def eval(self, dataPointObj):
        """Walk from this node down to a leaf and return that leaf's ``avgLabel``.

        At each inner node the left child is taken when the data point's
        feature value is <= the node's threshold, the right child otherwise.
        Returns None if anything goes wrong on the way (the error is logged).
        """
        try:
            node = self
            while node.featureId != -1:
                goesLeft = dataPointObj.getFeatureValue(node.featureId) <= node.threshold
                node = node.leftObj if goesLeft else node.rightObj

            return node.avgLabel
        except:
            logging_exception()
            return None

class RegressionTree:
    """Thin wrapper around the root node of one regression tree.

    Attributes:
        splitRoot: the root node, or None when no (truthy) root was given.
        leavesList: result of ``splitRoot.leaves()``, or None without a root.
    """

    def __init__(self, splitRoot=None):
        hasRoot = bool(splitRoot)
        self.splitRoot = splitRoot if hasRoot else None
        # the leaves are collected once, at construction time
        self.leavesList = splitRoot.leaves() if hasRoot else None

    def eval(self, dataPointObj):
        """Return the root node's evaluation of ``dataPointObj``.

        Returns None when evaluation raises (for instance when there is no
        root); the error is logged.
        """
        try:
            root = self.splitRoot
            return root.eval(dataPointObj)
        except:
            logging_exception()
            return None

class Ensemble:
    """Weighted sum of regression trees built from an XML model string.

    XML layout as read by this class: the root element has ``tree`` children,
    each with a ``weight`` attribute and the root split node as first child.
    A split node whose first child is tagged ``feature`` is an inner node with
    children (feature id, threshold, left split, right split). Otherwise the
    text of its first child is the output value of a leaf.
    """

    def __init__(self,strXmlModel):
        # tree index (int) -> RegressionTree
        self.treesDict = dict()

        # tree index (int) -> weight (float)
        self.weightsDict = dict()

        # (running index, feature id) tuples for every feature id used by a split
        self.featuresList = list()

        self.strXmlMod = strXmlModel

    def create(self,splitNode,fids_dict):
        """Recursively turn the XML element ``splitNode`` into a Split tree.

        Every feature id met on the way is recorded as a key of ``fids_dict``.
        Returns the Split for this element, or None when it cannot be built
        (the error is logged).
        """
        try:
            splitObj = None
            # list(element) yields the direct children. Element.getchildren()
            # was removed in Python 3.9 and must not be used any more.
            splitChildren = list(splitNode)

            if splitChildren[0].tag == "feature":
                fid = int(splitChildren[0].text)
                threshold = float(splitChildren[1].text)

                fids_dict[fid] = 0
                splitObj = Split(featureId=fid,threshold=threshold,deviance=0)
                splitObj.setLeft(self.create(splitChildren[2],fids_dict))
                splitObj.setRight(self.create(splitChildren[3],fids_dict))

            else:# this is a stump
                output = float(splitChildren[0].text)
                splitObj = Split(output=output)

            return splitObj
        except Exception:
            logging_exception()
            return None

    def getTreesDict(self):
        """Parse the XML model and populate trees, weights and the feature list.

        Returns:
            ``self.treesDict`` on success, None when parsing fails (the error
            is logged).
        """
        try:
            rootTree = ET.fromstring(self.strXmlMod)

            # feature id (int) -> 0; only the keys are used afterwards
            fids_dict = dict()
            # iterate over the <tree> elements of the model
            for dictIndex, elem in enumerate(rootTree.iterfind("tree")):
                splitRoot = self.create(list(elem)[0],fids_dict)
                weight = float(elem.get("weight"))

                self.treesDict[dictIndex] = RegressionTree(splitRoot=splitRoot)
                self.weightsDict[dictIndex] = weight

            self.featuresList.extend(enumerate(fids_dict))

            return self.treesDict
        except Exception:
            logging_exception()
            return None

    def eval(self, dataPointObj):
        """Return the weighted sum of all tree outputs for ``dataPointObj``.

        Returns -1 when any tree fails to evaluate (the error is logged).
        Note that -1 can also be a legitimate score.
        """
        try:
            score = 0.0
            for keyIndex in self.treesDict:
                score += self.treesDict[keyIndex].eval(dataPointObj)*self.weightsDict[keyIndex]

            return score
        except Exception:
            logging_exception()
            return -1

class loadModelFromModFile(object):
    """Class-level cache for the parsed model, the Redis client and four
    global maxima read from Redis. All state lives on the class itself.
    """

    # parsed model, created on the first getModelContent() call
    ensembleObj = None

    # NOTE(review): connection details including a password are hard-coded
    # here. getRedisCli() does not use them (it reads settings.GM_KV_URL).
    # They are kept only because other modules might read these attributes;
    # the secret should be rotated and moved out of the source tree.
    redisIp = "r-m5e20bd3a58c36c4.redis.rds.aliyuncs.com"
    redisPort = 6379
    redisPwd = "ahf3iGu4cahLoh"
    redisCli = None

    # cached maxima; 1.0 doubles as the "not loaded yet" marker
    diaryMaxDiaryTotalClickNum = 1.0
    diaryQueryMaxScore = 1.0
    diaryMaxPopularity = 1.0
    diaryMaxVoteNum = 1.0

    @classmethod
    def __init__(cls):
        # nothing to initialise: every member is class-level
        pass

    @classmethod
    def getModelContent(cls, model_file):
        """Return the cached Ensemble, loading it from ``model_file`` first
        if necessary.

        Blank lines and lines containing "##" are skipped; reading stops after
        the line that contains "</ensemble>".

        Returns:
            The Ensemble object, or None when loading raised (error is logged).
        """
        try:
            if cls.ensembleObj is None:
                logging.info("begin load mod file")
                parts = []
                # the with-block closes the file even if reading raises
                with open(model_file) as f:
                    for line in f:
                        line = line.strip()
                        if len(line) > 0 and line.find("##") == -1:
                            parts.append(line)
                            if line.find("</ensemble>") != -1:
                                break

                cls.ensembleObj = Ensemble("".join(parts))
                cls.ensembleObj.getTreesDict()
            return cls.ensembleObj
        except Exception:
            logging_exception()
            return None

    @classmethod
    def getRedisCli(cls):
        """Return the cached Redis client, creating it from
        ``settings.GM_KV_URL`` on first use. Returns None on failure.
        """
        try:
            if cls.redisCli is None:
                cls.redisCli = redis.from_url(settings.GM_KV_URL)

            return cls.redisCli
        except Exception:
            logging_exception()
            return None

    @classmethod
    def getGlobalMaxNum(cls,redisCli):
        """Return (maxDiaryTotalClickNum, maxScore, maxPopularity, maxVoteNum).

        The values are read once from the Redis hash
        "search_ltr:diary_have_been_clicked" and cached on the class. When a
        field is missing or the read fails, (1.0, 1.0, 1.0, 1.0) is returned
        and nothing is cached, so a later call tries again.
        """
        try:
            if cls.diaryMaxDiaryTotalClickNum == 1.0:
                hashKey = "search_ltr:diary_have_been_clicked"
                fields = ("maxDiaryTotalClickNum", "maxScore", "maxPopularity", "maxVoteNum")
                rawVals = [redisCli.hget(hashKey, field) for field in fields]

                if any(raw is None for raw in rawVals):
                    logging.error("missing field in redis hash %s", hashKey)
                    return (1.0,1.0,1.0,1.0)

                # Convert all four before assigning any of them: a bad value
                # must not leave the cache half-updated, because the check
                # above would then never trigger a reload.
                maxVals = [float(raw) for raw in rawVals]
                (cls.diaryMaxDiaryTotalClickNum,
                 cls.diaryQueryMaxScore,
                 cls.diaryMaxPopularity,
                 cls.diaryMaxVoteNum) = maxVals

            return (cls.diaryMaxDiaryTotalClickNum,cls.diaryQueryMaxScore,cls.diaryMaxPopularity,cls.diaryMaxVoteNum)
        except Exception:
            logging_exception()
            return (1.0,1.0,1.0,1.0)

def evalOneContent(content):
    """Score one feature line with the model file
    "learned_lambdamart_model.mod" located next to this module.

    Args:
        content: text line understood by DataPoint.parse().

    Returns:
        (id, score) for the parsed line, or None when the model cannot be
        loaded, the line cannot be parsed, or evaluation raises.
    """
    try:
        mod_file_name = os.path.join(
            os.path.dirname(os.path.realpath(__file__)),
            "learned_lambdamart_model.mod",
        )
        ensembleObj = loadModelFromModFile.getModelContent(mod_file_name)
        if ensembleObj is None:
            logging.error("model could not be loaded: %s", mod_file_name)
            return None

        dataPointObj = DataPoint()
        # parse() reports failure through its return value (and logs the
        # reason); scoring a half-filled data point would yield a bogus result.
        if not dataPointObj.parse(content.strip()):
            return None

        return (dataPointObj.getId(), ensembleObj.eval(dataPointObj))
    except Exception:
        logging_exception()
        return None

def getRedisCli():
    """Module-level shortcut for loadModelFromModFile.getRedisCli()."""
    client = loadModelFromModFile.getRedisCli()
    return client

def getGlobalMaxNum():
    """Return loadModelFromModFile.getGlobalMaxNum() for the shared Redis client."""
    client = getRedisCli()
    return loadModelFromModFile.getGlobalMaxNum(client)