"""Diary similarity via BERT embeddings and a faiss L2 index.

build_index() embeds diary answers fetched from Elasticsearch and writes a
faiss index to MODEL_PATH; save_result() queries that index and prints
near-duplicate diary ids for each diary.
"""
import json
import os
import random
import sys

# Make repo-local packages (utils.*) importable when run as a script.
sys.path.append(os.path.realpath("."))

import faiss
import numpy as np
from bert_serving.client import BertClient
from utils.cache import redis_client_db
from utils.es import get_diary_info_from_es
from utils.files import MODEL_PATH


def save_result():
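    """Stream diaries from Elasticsearch, embed each answer with BERT, and
    print up to 10 near-duplicate diary ids from the prebuilt faiss index."""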
    # bert-as-service client; check_length=False skips the client-side
    # max_seq_len check so long answers are not rejected.
    bc = BertClient("172.16.44.82", check_length=False)

    # Load the similarity index written by build_index() below.
    index_path = os.path.join(MODEL_PATH, "faiss_diary_similarity.index")
    faiss_index = faiss.read_index(index_path)

    for item in get_diary_info_from_es(["id", "answer", "content_level"]):
        diary_id = int(item["_id"])  # avoid shadowing the builtin id()
        content = item["_source"]["answer"]
        content_level = str(item["_source"]["content_level"])
        try:
            # bc.encode([text]) already returns a (1, dim) array.
            emb = bc.encode([content]).astype("float32")
            # Fetch the 10 nearest neighbours; IndexFlatL2 reports squared
            # L2 distances, so 1.0 is a tight near-duplicate threshold.
            distances, ids = faiss_index.search(emb, 10)
            res = []
            for dist, tmp_id in zip(distances[0].tolist(), ids[0].tolist()):
                # Keep only close matches and skip the diary itself.
                if dist <= 1.0 and tmp_id != diary_id:
                    res.append(str(tmp_id))
            if res:
                data = "{}:{}:{}".format(content_level, str(id), ",".join(res))
                print(data)
        except Exception as e:
            # Don't let one bad document abort the whole run.
            print("failed to process diary {}: {}".format(diary_id, e))
    print("done")


def build_index():
    """Embed every diary answer with BERT, build a faiss L2 index keyed by
    diary id, and write it to MODEL_PATH for save_result() to read."""
    bc = BertClient("172.16.44.82", check_length=False)

    level_dict = {"6": [], "5": [], "4": [], "3.5": [], "3": []}
    count = 0
    embedding_dict = {}
    for item in get_diary_info_from_es(["id", "answer", "content_level"]):
        count += 1
        diary_id = int(item["_id"])
        print(count, diary_id)
        content = item["_source"]["answer"]
        content_level = str(item["_source"]["content_level"])
        level_dict[content_level].append(diary_id)
        try:
            embedding_dict[diary_id] = bc.encode([content]).tolist()[0]
        except Exception as e:
            print("failed to encode diary {}: {}".format(diary_id, e))

    # Persisting the per-level id lists stays disabled: hmset expects a
    # field-to-value mapping, not a JSON string.
    # redis_client_db.hmset("diary:level_dict", json.dumps(level_dict))

    # Sanity-check one (id, embedding) pair before indexing.
    tmp_tuple = random.choice(list(embedding_dict.items()))
    print(tmp_tuple)

    diary_ids = np.array(list(embedding_dict.keys())).astype("int64")  # faiss ids must be int64
    diary_embeddings = np.array(list(embedding_dict.values())).astype("float32")
    print(diary_embeddings.shape)

    # Exact L2 index; IndexFlatL2 needs no training.
    index = faiss.IndexFlatL2(diary_embeddings.shape[1])
    print("trained: " + str(index.is_trained))

    # Wrap in an IDMap so searches return diary ids, not row offsets.
    id_index = faiss.IndexIDMap(index)
    id_index.add_with_ids(diary_embeddings, diary_ids)
    print("trained: " + str(id_index.is_trained))
    print("total index: " + str(id_index.ntotal))

    index_path = os.path.join(MODEL_PATH, "faiss_diary_similarity.index")
    faiss.write_index(id_index, index_path)
    print(index_path)

    # Smoke test: the sampled diary should match itself at distance 0.
    diary_id = tmp_tuple[0]
    emb = np.array([embedding_dict[diary_id]]).astype("float32")
    print(emb)
    distances, ids = id_index.search(emb, 10)
    res = [i for dist, i in zip(distances[0].tolist(), ids[0].tolist()) if dist <= 1.0]
    print(res, "\n")
    print(ids[0].tolist(), "\n")
    print(distances)


if __name__ == "__main__":
    # build_index() must have been run once before save_result() can load
    # the index from MODEL_PATH.
    save_result()