import multiprocessing
import os
import time
import traceback

from gensim.models import Word2Vec, word2vec
from gm_rpcd.all import bind
from utils.es import es_scan
from utils.message import send_msg_to_dingtalk

# All model/data locations are resolved relative to the current working
# directory of the process, not relative to this file.
base_dir = os.getcwd()
print("base_dir: " + base_dir)
model_dir = os.path.join(base_dir, "_models")
data_dir = os.path.join(base_dir, "_data")

model_output_name = "w2v_model"
model_path = os.path.join(model_dir, model_output_name)
# Bind the name up front so it always exists: if loading fails the model is
# None instead of the module-level name being undefined (NameError on use).
WORD2VEC_MODEL = None
try:
    WORD2VEC_MODEL = word2vec.Word2Vec.load(model_path)
except Exception as e:
    # Best effort: the module must stay importable (e.g. to train the model
    # for the first time) even when no saved model is available yet.
    print(e)

# NOTE: "tracate" in the next name is a misspelling of "tractate"; it is kept
# unchanged because other modules may reference this module-level name.
tracate_click_ids_model_name = "tractate_click_ids_item2vec_model"
tractate_click_ids_model_path = os.path.join(model_dir, tracate_click_ids_model_name)
# Same pattern as above: default to None when the saved model cannot be loaded.
TRACTATE_CLICK_IDS_MODEL = None
try:
    TRACTATE_CLICK_IDS_MODEL = word2vec.Word2Vec.load(tractate_click_ids_model_path)
except Exception as e:
    print(e)


class W2vSentences:
    """Re-iterable corpus yielding one whitespace-tokenised sentence per line.

    The file is reopened on every ``__iter__`` call, so the same instance can
    be iterated any number of times.

    Args:
        f_name: Path of the UTF-8 text file to read.
    """

    def __init__(self, f_name):
        self.f_name = f_name

    def __iter__(self):
        # errors="ignore" silently drops bytes that are not valid UTF-8.
        with open(self.f_name, mode="r", encoding="utf-8", errors="ignore") as f:
            # Iterate the file object directly so lines are streamed instead
            # of loading the entire corpus into memory with readlines().
            for line in f:
                yield line.split()


def w2v_train(f_name, model_output_name):
    """Train a word2vec model on a text file and save it to disk.

    The input is read from ``data_dir`` through ``W2vSentences`` (one
    whitespace-separated sentence per line) and the trained model is written
    under ``model_dir``.

    Args:
        f_name: File name of the training corpus, relative to ``data_dir``.
        model_output_name: File name of the saved model, relative to
            ``model_dir``.
    """
    input_file = os.path.join(data_dir, f_name)
    print("input: " + input_file)
    sentences = W2vSentences(input_file)
    w2v_model = word2vec.Word2Vec(sentences, min_count=2, workers=2, size=100, window=10)
    model_path = os.path.join(model_dir, model_output_name)
    print("output: " + model_path)
    # Make sure the target directory exists so a finished training run is not
    # lost to a missing "_models" directory at save time.
    os.makedirs(os.path.dirname(model_path), exist_ok=True)
    w2v_model.save(model_path)


@bind("strategy_embedding/word_vector/word_similarity")
def word_similarity(word):
    """Return the ``most_similar`` result for ``word`` from ``WORD2VEC_MODEL``.

    Any exception raised during the lookup is not propagated: the formatted
    traceback is passed to ``send_msg_to_dingtalk`` and an empty list is
    returned instead.
    """
    try:
        neighbours = WORD2VEC_MODEL.wv.most_similar(word)
    except Exception:
        report = traceback.format_exc()
        send_msg_to_dingtalk(str(report))
        neighbours = []
    return neighbours


def get_user_portrait_projects(score_limit=5):
    """Collect, per device, the names of projects scoring at least ``score_limit``.

    Iterates everything returned by ``es_scan("device", {}, rw=None)`` and
    reads ``device_id`` and ``projects`` from each hit's ``_source``. Devices
    with no qualifying project are omitted. Hits without a ``device_id`` are
    stored under the empty-string key.

    Args:
        score_limit: Minimum ``score`` a project entry needs to be included.

    Returns:
        A dict mapping device id to a list of project names, e.g.::

            {
              '6231F098-9E72-448E-B8D2-19FCB9687005': ['<project name>', '<project name>'],
              '862538030266882': ['<project name>', '<project name>']
            }
    """
    es_res = es_scan("device", {}, rw=None)
    res = {}
    for count, hit in enumerate(es_res, start=1):
        # Progress output: running number of the hit being processed.
        print(count)
        source = hit["_source"]
        device_id = source.get("device_id", "")
        # `or []` also covers a document whose "projects" field is present
        # but null, which would otherwise abort the whole scan.
        projects = [
            project["name"]
            for project in (source.get("projects") or [])
            if project["score"] >= score_limit
        ]
        if projects:
            res[device_id] = projects
    return res


def projects_item2vec(score_limit=5):
    """Train an item2vec (Word2Vec) model on per-device project lists.

    Each device's list of project names, as returned by
    ``get_user_portrait_projects``, is used as one training sentence.

    Args:
        score_limit: Forwarded to ``get_user_portrait_projects``.

    Returns:
        The trained ``Word2Vec`` model (it is not saved to disk here).
    """
    user_dict = get_user_portrait_projects(score_limit=score_limit)
    # TODO if not redis.get user_dict:
    projects = list(user_dict.values())
    model = Word2Vec(projects, hs=0, min_count=3, workers=multiprocessing.cpu_count(), iter=10)
    print(model)
    print(len(projects))
    # Diagnostic output only: show neighbours of two sample project names.
    for word in ["鼻综合", "吸脂瘦脸"]:
        try:
            print(model.wv.most_similar(word, topn=5))
        except KeyError as e:
            # A sample word may be missing from the vocabulary (e.g. it was
            # seen fewer than min_count times); that must not discard the
            # freshly trained model.
            print(e)
    return model


def save_clicked_tractate_ids_item2vec():
    """Train an item2vec model on clicked tractate ids and save it.

    Reads ``click_tractate_ids.csv`` from ``data_dir``. Each row is parsed as
    ``<field 0>|<id>,<id>,...``; field 0 is ignored (it appears to be an app
    session id) and the comma-separated ids of field 1 form one training
    sentence. The model is written to ``tractate_click_ids_model_path``.

    Returns:
        The trained ``Word2Vec`` model.
    """
    click_ids = []
    with open(os.path.join(data_dir, "click_tractate_ids.csv"), "r", encoding="utf-8") as f:
        # Stream the file line by line instead of materialising it.
        for line in f:
            fields = line.rstrip("\n").split("|")
            if len(fields) < 2:
                # Blank or malformed row (no "|" separator): skip it rather
                # than failing the whole run with an IndexError.
                continue
            # Drop empty entries so "" never becomes a vocabulary token.
            ids = [item for item in fields[1].split(",") if item]
            if ids:
                click_ids.append(ids)
    model = Word2Vec(click_ids, hs=0, min_count=3, workers=multiprocessing.cpu_count(), iter=10)
    print(model)
    print(len(click_ids))

    # Ensure the output directory exists before saving the trained model.
    os.makedirs(os.path.dirname(tractate_click_ids_model_path), exist_ok=True)
    model.save(tractate_click_ids_model_path)
    return model


@bind("strategy_embedding/word_vector/tractate_item2vec")
def clicked_tractate_ids_item2vec_model(id, n=5):
    """Return the ``n`` most similar entries to ``id`` from ``TRACTATE_CLICK_IDS_MODEL``.

    Any exception raised during the lookup is not propagated: the formatted
    traceback is passed to ``send_msg_to_dingtalk`` and an empty list is
    returned instead.
    """
    try:
        similar_items = TRACTATE_CLICK_IDS_MODEL.wv.most_similar(id, topn=n)
    except Exception:
        report = traceback.format_exc()
        send_msg_to_dingtalk(str(report))
        similar_items = []
    return similar_items


if __name__ == "__main__":
    # Manual smoke test: query both lookup endpoints with a few sample inputs
    # and report the elapsed wall-clock time.
    begin_time = time.time()

    # Uncomment to retrain the word2vec model first:
    # w2v_train("dispose_problem.txt", model_output_name)

    for query_word in ["双眼皮", "隆鼻"]:
        print(word_similarity(query_word))

    # Uncomment to rebuild the tractate item2vec model first:
    # save_clicked_tractate_ids_item2vec()

    # Loop variable deliberately not called `id`, so the builtin is not
    # rebound at module scope.
    for tractate_id in ["84375", "148764", "368399"]:
        print(clicked_tractate_ids_item2vec_model(tractate_id, n=5))

    print("total cost: {:.2f}mins".format((time.time() - begin_time) / 60))
