Commit 49040396 authored by 赵威's avatar 赵威

update validate item2vec

parent d8cf717f
......@@ -12,7 +12,7 @@ def es_index_adapt(index_prefix, doc_type, rw=None):
return index
def get_es():
def get_es(is_es7=False, es_router_key="default"):
init_args = {
"sniff_on_start": False,
"sniff_on_connection_fail": False,
......@@ -27,21 +27,37 @@ def get_es():
"host": "",
"port": 9000,
new_es = Es(hosts=new_hosts, **init_args)
if is_es7:
clusters = {
"content": ["", "", "", "", ""],
"service": ["", "", ""]
es_router = {"default": "content", "diary": "content", "service": "service", "tractate": "content", "answer": "content"}
http_auth = ("elastic", "gengmei!@#")
new_es = Es(hosts=clusters[es_router[es_router_key]], http_auth=http_auth, **init_args)
new_es = Es(hosts=new_hosts, **init_args)
return new_es
def es_query(doc, body, offset=0, size=100, es=None, rw="read", is_es7=False, es_router_key="default"):
    """Run one paginated search against the ``gm-dbmw`` index for ``doc``.

    Args:
        doc: Document type name; passed to ``es_index_adapt`` as ``doc_type``.
        body: Query body forwarded unchanged to the search call.
        offset: Index of the first hit to return (sent as ``from_``).
        size: Maximum number of hits to return.
        es: Client to query with. When ``None`` a client is created with
            ``get_es`` using ``is_es7`` and ``es_router_key``.
        rw: Forwarded to ``es_index_adapt`` when resolving the index name.
        is_es7: Forwarded to ``get_es``; ignored when ``es`` is supplied.
        es_router_key: Forwarded to ``get_es``; ignored when ``es`` is supplied.

    Returns:
        Whatever the client's search call returns.
    """
    if es is None:
        es = get_es(is_es7=is_es7, es_router_key=es_router_key)
    index = es_index_adapt(index_prefix="gm-dbmw", doc_type=doc, rw=rw)
    # NOTE(review): the call target was missing in the reviewed text
    # ("res =, timeout=..."); es.search(index=index, ...) is assumed - confirm.
    return es.search(index=index, timeout="2s", body=body, from_=offset, size=size)
def es_msearch(query_body, es=None, es_router_key="default"):
    """Send ``query_body`` to the client's ``msearch`` and return its response.

    Args:
        query_body: Payload handed to ``es.msearch`` as its only argument.
        es: Client to use. When ``None`` a client is created with
            ``get_es(is_es7=True, es_router_key=es_router_key)``.
        es_router_key: Forwarded to ``get_es``; ignored when ``es`` is supplied.

    Returns:
        Whatever ``es.msearch`` returns.
    """
    if es is None:
        es = get_es(is_es7=True, es_router_key=es_router_key)
    return es.msearch(query_body)
def es_scan(doc, body, es=None, rw="read", is_es7=False, es_router_key="default"):
    """Scan the ``gm-dbmw`` index for ``doc`` with ``helpers.scan``.

    Args:
        doc: Document type name; passed to ``es_index_adapt`` as ``doc_type``.
        body: Query forwarded to ``helpers.scan`` as ``query``.
        es: Client to use. When ``None`` a client is created with ``get_es``
            using ``is_es7`` and ``es_router_key``.
        rw: Forwarded to ``es_index_adapt`` when resolving the index name.
        is_es7: Forwarded to ``get_es``; ignored when ``es`` is supplied.
        es_router_key: Forwarded to ``get_es``; ignored when ``es`` is supplied.

    Returns:
        Whatever ``helpers.scan`` returns.
    """
    if es is None:
        es = get_es(is_es7=is_es7, es_router_key=es_router_key)
    index = es_index_adapt(index_prefix="gm-dbmw", doc_type=doc, rw=rw)
    # NOTE(review): ``helpers`` is presumably ``elasticsearch.helpers`` - confirm
    # against the module's imports, which are outside the reviewed text.
    return helpers.scan(es, index=index, query=body, request_timeout=100, scroll="300m", raise_on_error=False)
......@@ -160,3 +176,69 @@ def get_online_ids(content_type):
q = {"query": {"bool": {"must": [{"term": {"is_online": True}}]}}}
results = es_scan(content_type, q)
return results
def get_tag3_info_from_content(id, es_doc_name, action=""):
    """Collect the tag lists stored on one content document.

    Only the ``"answer"`` and ``"tractate"`` document types are queried; for
    any other ``es_doc_name`` every tag list in the result stays empty.

    Args:
        id: Value matched against the document's ``id`` field. The name
            shadows the builtin but is kept because callers may pass it by
            keyword.
        es_doc_name: Document type name handed to ``es_query``.
        action: Opaque value copied into the result under ``"action"``.

    Returns:
        A dict with ``"action"``, ``"es_doc_name"`` and one list per tag
        family. The lookup is best-effort: if it raises, the values gathered
        so far (empty lists by default) are returned instead of the error.
    """
    import logging

    # Result key -> field name inside the hit's ``_source``.
    source_fields = {
        "first_demands": "first_demands",
        "second_demands": "second_demands",
        "first_solutions": "first_solutions",
        "second_solutions": "second_solutions",
        "first_positions": "positions",
        "second_positions": "second_positions",
        "projects": "tags_v3",
        "anecdote_tags": "anecdote_tags",
        "business_tags": "portrait_tag_name",
        # Previously read from ``_source`` but never requested in
        # ``_source.includes``, so it could only ever come back empty.
        "gossip_tags": "gossip_tags",
    }
    res = {"action": action, "es_doc_name": es_doc_name}
    res.update((key, []) for key in source_fields)
    try:
        if es_doc_name in ("answer", "tractate"):
            q = {
                "query": {"term": {"id": id}},
                "_source": {"includes": list(source_fields.values())},
            }
            es_res = es_query(es_doc_name, q, is_es7=True)
            # If several hits match, values from the last hit win.
            for hit in es_res["hits"]["hits"]:
                source = hit["_source"]
                for key, field in source_fields.items():
                    res[key] = source.get(field, [])
    except Exception:
        # Deliberately best-effort: keep returning what was gathered, but
        # leave a trace instead of swallowing the failure silently.
        logging.getLogger(__name__).warning(
            "tag lookup failed for %s id=%s", es_doc_name, id, exc_info=True
        )
    return res
import multiprocessing
import os
import random
import sys
import time
from collections import defaultdict
......@@ -8,6 +9,7 @@ sys.path.append(os.path.realpath("."))
from gensim.models import Word2Vec, word2vec
from import get_ndays_before_no_minus, get_ndays_before_with_format
from import get_tag3_info_from_content
from utils.files import DATA_PATH, MODEL_PATH
from utils.spark import get_spark
......@@ -151,6 +153,39 @@ def save_clicked_tractate_ids_item2vec():
return model
def get_tractate_item2vec_by_id(tractate_id, model, score_limit=0.8, topn=100):
    """Return ids the model ranks as similar to ``tractate_id``.

    Args:
        tractate_id: Key looked up with ``model.wv.most_similar``.
        model: Object exposing ``wv.most_similar(key, topn=...)`` that yields
            ``(id, score)`` pairs.
        score_limit: Minimum score (inclusive) a neighbour needs to be kept.
        topn: Number of neighbours requested from the model.

    Returns:
        The ids whose score is at least ``score_limit``, in the order the
        model returned them.
    """
    # NOTE(review): the statement under ``if score >= score_limit:`` was
    # missing in the reviewed text; collecting the id is assumed because the
    # caller iterates the result as plain ids - confirm.
    return [
        similar_id
        for similar_id, score in model.wv.most_similar(tractate_id, topn=topn)
        if score >= score_limit
    ]
def validate_tractate_item2vec_by_id(tractate_id, related_tractate_ids):
    """Score how many related tractates share a project tag with the source.

    Args:
        tractate_id: Id of the source tractate.
        related_tractate_ids: Ids of the tractates to compare against.

    Returns:
        The fraction of related tractates whose ``"projects"`` overlap the
        source's ``"projects"``. Returns ``0`` when the source has no projects
        or when ``related_tractate_ids`` is empty.
    """
    # ``or []`` guards against a stored null, which would make set() raise.
    projects = set(get_tag3_info_from_content(tractate_id, "tractate").get("projects") or [])
    if not projects:
        return 0

    related_projects = [
        get_tag3_info_from_content(related_id, "tractate").get("projects") or []
        for related_id in related_tractate_ids
    ]
    if not related_projects:
        return 0

    overlapping = sum(1 for tags in related_projects if not projects.isdisjoint(tags))
    return overlapping / len(related_projects)
def validate_tractate_item2vec(tractate_ids, n=30):
    """Print a validation score for a random sample of tractate ids.

    For each sampled id, the similar ids are taken from the module-level
    ``TRACTATE_CLICK_IDS_MODEL`` and scored with
    ``validate_tractate_item2vec_by_id``; one ``"<id>: <score>"`` line is
    printed per sampled id.

    Args:
        tractate_ids: Iterable of candidate ids to sample from.
        n: Number of ids to sample. If fewer candidates exist, all of them
            are used.

    Returns:
        None.
    """
    candidates = list(tractate_ids)
    # random.sample raises ValueError when asked for more items than exist,
    # so clamp the sample size to the population size.
    sampled_ids = random.sample(candidates, min(n, len(candidates)))
    for tractate_id in sampled_ids:
        related_tractate_ids = get_tractate_item2vec_by_id(tractate_id, TRACTATE_CLICK_IDS_MODEL)
        score = validate_tractate_item2vec_by_id(tractate_id, related_tractate_ids)
        print("{}: {}".format(tractate_id, score))
if __name__ == "__main__":
begin_time = time.time()
......@@ -168,9 +203,17 @@ if __name__ == "__main__":
TRACTATE_CLICK_IDS_MODEL = word2vec.Word2Vec.load(tractate_click_ids_model_path)
for id in ["84375", "148764", "368399"]:
print(TRACTATE_CLICK_IDS_MODEL.wv.most_similar(id, topn=5))
print(TRACTATE_CLICK_IDS_MODEL.wv.most_similar(id, topn=5))
except Exception as e:
print("total cost: {:.2f}mins".format((time.time() - begin_time) / 60))
# validate_tractate_item2vec(TRACTATE_CLICK_IDS, 10)
# spark-submit --master yarn --deploy-mode client --queue root.strategy --driver-memory 16g --executor-memory 1g --executor-cores 1 --num-executors 70 --conf spark.default.parallelism=100 --conf --conf spark.shuffle.memoryFraction=0.3 --conf spark.locality.wait=0 --jars /srv/apps/tispark-core-2.1-SNAPSHOT-jar-with-dependencies.jar,/srv/apps/spark-connector_2.11-1.9.0-rc2.jar,/srv/apps/mysql-connector-java-5.1.38.jar /srv/apps/strategy_embedding/word_vector/
