import json
import os
import random
import sys

sys.path.append(os.path.realpath("."))

import faiss
import numpy as np
from bert_serving.client import BertClient
from bs4 import BeautifulSoup
from utils.cache import redis_client_db
from utils.es import get_answer_info_from_es
from utils.files import MODEL_PATH


def cos_sim(vector_a, vector_b):
    """
    Compute the cosine similarity of two vectors, rescaled to the range [0, 1].

    The raw cosine ``cos`` (in [-1, 1]) is mapped to ``0.5 + 0.5 * cos``, so
    vectors pointing the same way give 1.0, orthogonal vectors give 0.5 and
    opposite vectors give 0.0.

    :param vector_a: vector a; any array-like (a 1 x N row is flattened)
    :param vector_b: vector b; must have the same number of elements as a
    :return: the rescaled similarity as a Python float; nan (with a NumPy
        RuntimeWarning) if either vector has zero norm
    """
    # np.mat / np.matrix arithmetic was used here before; np.mat was removed
    # in NumPy 2.0 and float() on a 1x1 matrix is deprecated, so work on
    # flattened 1-D arrays and a plain dot product instead.
    a = np.asarray(vector_a, dtype=float).ravel()
    b = np.asarray(vector_b, dtype=float).ravel()
    num = np.dot(a, b)
    denom = np.linalg.norm(a) * np.linalg.norm(b)
    # NumPy scalar division: a zero denominator yields nan instead of raising.
    cos = num / denom
    return float(0.5 + 0.5 * cos)


def write_result(bert_host="172.16.44.82", top_k=10, max_distance=1.0):
    """
    Embed every answer from Elasticsearch, index the embeddings with faiss and
    print the near-duplicate neighbours of each answer.

    Each answer's HTML is reduced to plain text with BeautifulSoup, encoded by
    the bert-serving server at ``bert_host`` and added to an exact
    ``IndexFlatL2`` wrapped in an ``IndexIDMap`` keyed by the answer id. Every
    answer is then searched against that index. For each answer with at least
    one *other* answer within ``max_distance``, a line of the form
    ``"<content_level>:<id>:<id1>,<id2>,..."`` is printed to stdout.

    Answers whose document cannot be parsed or encoded are skipped (best
    effort); each skip and a final summary are reported on stderr.

    :param bert_host: address of the bert-serving server
    :param top_k: number of nearest neighbours retrieved per answer (the answer
        itself is normally among them and is excluded from the output)
    :param max_distance: inclusive threshold on the distance reported by
        ``IndexFlatL2`` (a squared L2 distance)
    :return: None
    """
    bc = BertClient(bert_host, check_length=False)

    content_level_dict = {}
    embedding_dict = {}
    total = 0
    skipped = 0
    try:
        for item in get_answer_info_from_es(["id", "answer", "content_level"]):
            total += 1
            try:
                answer_id = int(item["_id"])
                soup = BeautifulSoup(item["_source"]["answer"], "html.parser")
                content = soup.get_text()
                content_level = str(item["_source"]["content_level"])
                embedding = bc.encode([content]).tolist()[0]
            except Exception as e:
                # Best effort: one malformed document or failed encode must not
                # abort the whole run, but it should not vanish silently either.
                skipped += 1
                print("skip answer #{}: {!r}".format(total, e), file=sys.stderr)
                continue
            content_level_dict[answer_id] = content_level
            embedding_dict[answer_id] = embedding
    finally:
        bc.close()

    if skipped:
        print("skipped {} of {} answers".format(skipped, total), file=sys.stderr)
    if not embedding_dict:
        # faiss needs the embedding dimension; with no vectors there is
        # nothing to index or compare.
        print("no answers could be embedded", file=sys.stderr)
        return

    # faiss expects 64-bit ids; "int" is only 32-bit on some platforms.
    answer_ids = np.array(list(embedding_dict.keys())).astype("int64")
    answer_embeddings = np.array(list(embedding_dict.values())).astype("float32")
    print(answer_embeddings.shape)

    # Keep a named reference to the flat index for the lifetime of the
    # IndexIDMap that wraps it (it was previously rebound in the loop below).
    flat_index = faiss.IndexFlatL2(answer_embeddings.shape[1])
    id_index = faiss.IndexIDMap(flat_index)
    id_index.add_with_ids(answer_embeddings, answer_ids)
    print("trained: " + str(id_index.is_trained))
    print("total index: " + str(id_index.ntotal))

    # save_result() reads a persisted index from
    # os.path.join(MODEL_PATH, "faiss_answer_similarity6.index"); to produce it,
    # call faiss.write_index(id_index, <that path>) here.

    for answer_id, row in zip(answer_ids.tolist(), answer_embeddings):
        distances, neighbor_ids = id_index.search(row.reshape(1, -1), top_k)
        # faiss pads missing neighbours with id -1 and a huge distance, which
        # the threshold filters out; the answer's own id is excluded.
        similar = [
            str(neighbor_id)
            for distance, neighbor_id in zip(distances[0].tolist(), neighbor_ids[0].tolist())
            if distance <= max_distance and neighbor_id != answer_id
        ]
        if similar:
            data = "{}:{}:{}".format(
                content_level_dict.get(answer_id, "-1"), str(answer_id), ",".join(similar)
            )
            print(data)


def save_result(bert_host="172.16.44.82", top_k=10, max_distance=1.0):
    """
    Print near-duplicate answers by querying a faiss index saved on disk.

    The index is read from ``MODEL_PATH/faiss_answer_similarity6.index``. Every
    answer from Elasticsearch is reduced to plain text with BeautifulSoup
    (the same preprocessing ``write_result`` applies before indexing), encoded
    by the bert-serving server at ``bert_host`` and searched against the
    index. For each answer with at least one *other* answer within
    ``max_distance``, a line ``"<content_level>:<id>:<id1>,<id2>,..."`` is
    printed to stdout, followed by ``"done"`` at the end.

    Answers whose document cannot be parsed, encoded or searched are skipped
    (best effort); each skip and a final summary are reported on stderr.

    :param bert_host: address of the bert-serving server
    :param top_k: number of nearest neighbours retrieved per answer
    :param max_distance: inclusive threshold on the distance returned by the
        index search
    :return: None
    """
    index_path = os.path.join(MODEL_PATH, "faiss_answer_similarity6.index")
    # Load the index before connecting to the BERT server so a missing file
    # does not leave an open client behind.
    faiss_index = faiss.read_index(index_path)

    bc = BertClient(bert_host, check_length=False)
    total = 0
    skipped = 0
    try:
        for item in get_answer_info_from_es(["id", "answer", "content_level"]):
            total += 1
            try:
                answer_id = int(item["_id"])
                # Strip HTML so queries are embedded from the same plain text
                # the indexed vectors were built from; raw HTML would put the
                # query in a different region of the embedding space.
                soup = BeautifulSoup(item["_source"]["answer"], "html.parser")
                content = soup.get_text()
                content_level = str(item["_source"]["content_level"])
                emb = np.array([bc.encode([content]).tolist()[0]]).astype("float32")
                distances, neighbor_ids = faiss_index.search(emb, top_k)
            except Exception as e:
                # Best effort: keep going, but report what was dropped.
                skipped += 1
                print("skip answer #{}: {!r}".format(total, e), file=sys.stderr)
                continue
            similar = [
                str(neighbor_id)
                for distance, neighbor_id in zip(distances[0].tolist(), neighbor_ids[0].tolist())
                if distance <= max_distance and neighbor_id != answer_id
            ]
            if similar:
                data = "{}:{}:{}".format(content_level, str(answer_id), ",".join(similar))
                print(data)
    finally:
        bc.close()

    if skipped:
        print("skipped {} of {} answers".format(skipped, total), file=sys.stderr)
    print("done")


if __name__ == "__main__":
    # Script entry point: runs write_result() only; save_result() is not
    # invoked from here.
    write_result()