rank / strategy_embedding · Commits

Commit 49040396, authored 4 years ago by 赵威
update validate item2vec
Parent: d8cf717f

Showing 2 changed files, with 133 additions and 8 deletions (+133 / -8):

utils/es.py              +89  -7
word_vector/tractate.py  +44  -1

utils/es.py (view file @ 49040396)
...
@@ -12,7 +12,7 @@ def es_index_adapt(index_prefix, doc_type, rw=None):
     return index


-def get_es():
+def get_es(is_es7=False, es_router_key="default"):
     init_args = {
         "sniff_on_start": False,
         "sniff_on_connection_fail": False,
...
@@ -27,21 +27,37 @@ def get_es():
         "host": "172.16.31.13",
         "port": 9000,
     }]
-    new_es = Es(hosts=new_hosts, **init_args)
+    if is_es7:
+        clusters = {
+            "content": [
+                "172.16.52.33:9200", "172.16.52.19:9200", "172.16.52.27:9200",
+                "172.16.52.34:9200", "172.16.52.48:9200"
+            ],
+            "service": ["172.16.52.25:9200", "172.16.52.36:9200", "172.16.52.26:9200"]
+        }
+        es_router = {
+            "default": "content",
+            "diary": "content",
+            "service": "service",
+            "tractate": "content",
+            "answer": "content"
+        }
+        http_auth = ("elastic", "gengmei!@#")
+        new_es = Es(hosts=clusters[es_router[es_router_key]], http_auth=http_auth, **init_args)
+    else:
+        new_es = Es(hosts=new_hosts, **init_args)
     return new_es
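
For context, a minimal usage sketch of the new routing, assuming this repo's utils/es.py is importable; the cluster choice follows the es_router mapping added above:

from utils.es import get_es

legacy_es = get_es()                                         # pre-existing path: legacy host list (172.16.31.13:9000 among them)
content_es7 = get_es(is_es7=True, es_router_key="tractate")  # "tractate" routes to the "content" ES7 cluster
service_es7 = get_es(is_es7=True, es_router_key="service")   # "service" routes to the "service" ES7 cluster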

-def es_query(doc, body, offset=0, size=100, es=None, rw="read"):
+def es_query(doc, body, offset=0, size=100, es=None, rw="read", is_es7=False, es_router_key="default"):
     if es is None:
-        es = get_es()
+        es = get_es(is_es7=is_es7, es_router_key=es_router_key)
     index = es_index_adapt(index_prefix="gm-dbmw", doc_type=doc, rw=rw)
-    res = es.search(index=index, timeout="10s", body=body, from_=offset, size=size)
+    res = es.search(index=index, timeout="2s", body=body, from_=offset, size=size)
     return res


+def es_msearch(query_body, es=None, es_router_key="default"):
+    if es is None:
+        es = get_es(is_es7=True, es_router_key=es_router_key)
+    res = es.msearch(query_body)
+    return res
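
es_msearch hands query_body straight to the client's msearch call, so the caller supplies the alternating header/query layout of a multi-search request. A hedged sketch, with illustrative index names (the real names come from es_index_adapt elsewhere in this module):

query_body = [
    {"index": "gm-dbmw-tractate-read"},                    # header for the first sub-search (illustrative index)
    {"query": {"term": {"is_online": True}}, "size": 10},  # body for the first sub-search
    {"index": "gm-dbmw-answer-read"},                      # header for the second sub-search (illustrative index)
    {"query": {"match_all": {}}, "size": 5},               # body for the second sub-search
]
responses = es_msearch(query_body)["responses"]            # one response dict per sub-search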

-def es_scan(doc, body, es=None, rw="read"):
+def es_scan(doc, body, es=None, rw="read", is_es7=False, es_router_key="default"):
     if es is None:
-        es = get_es()
+        es = get_es(is_es7=is_es7, es_router_key=es_router_key)
     index = es_index_adapt(index_prefix="gm-dbmw", doc_type=doc, rw=rw)
     return helpers.scan(es, index=index, query=body, request_timeout=100, scroll="300m", raise_on_error=False)
...
@@ -160,3 +176,69 @@ def get_online_ids(content_type):
     q = {"query": {"bool": {"must": [{"term": {"is_online": True}}]}}}
     results = es_scan(content_type, q)
     return results
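
Since get_online_ids simply forwards to es_scan, it yields raw scan hits lazily rather than a list of ids. A hedged consumption sketch, assuming the standard Elasticsearch hit layout:

online_hits = get_online_ids("tractate")          # generator produced by helpers.scan
online_ids = [hit["_id"] for hit in online_hits]  # pull ids out of the standard hit dicts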


+def get_tag3_info_from_content(id, es_doc_name, action=""):
+    first_demands = []
+    second_demands = []
+    first_solutions = []
+    second_solutions = []
+    first_positions = []
+    second_positions = []
+    projects = []
+    anecdote_tags = []
+    gossip_tags = []
+    business_tags = []
+    fields = [
+        "first_demands", "second_demands", "first_solutions", "second_solutions",
+        "positions", "second_positions", "tags_v3", "anecdote_tags", "portrait_tag_name"
+    ]
+    if es_doc_name in ["answer", "tractate"]:
+        fields.append("gossip_tags")
+    q = {"query": {"term": {"id": id}}, "_source": {"includes": fields}}
+    try:
+        es_res = es_query(es_doc_name, q, is_es7=True)
+        hits = es_res["hits"]["hits"]
+        for hit in hits:
+            first_demands = hit["_source"].get("first_demands", [])
+            second_demands = hit["_source"].get("second_demands", [])
+            first_solutions = hit["_source"].get("first_solutions", [])
+            second_solutions = hit["_source"].get("second_solutions", [])
+            first_positions = hit["_source"].get("positions", [])
+            second_positions = hit["_source"].get("second_positions", [])
+            projects = hit["_source"].get("tags_v3", [])
+            anecdote_tags = hit["_source"].get("anecdote_tags", [])
+            business_tags = hit["_source"].get("portrait_tag_name", [])
+            gossip_tags = hit["_source"].get("gossip_tags", [])
+        res = {
+            "action": action,
+            "es_doc_name": es_doc_name,
+            "first_demands": first_demands,
+            "second_demands": second_demands,
+            "first_solutions": first_solutions,
+            "second_solutions": second_solutions,
+            "first_positions": first_positions,
+            "second_positions": second_positions,
+            "projects": projects,
+            "anecdote_tags": anecdote_tags,
+            "business_tags": business_tags,
+            "gossip_tags": gossip_tags
+        }
+        return res
+    except Exception as e:
+        res = {
+            "action": action,
+            "es_doc_name": es_doc_name,
+            "first_demands": first_demands,
+            "second_demands": second_demands,
+            "first_solutions": first_solutions,
+            "second_solutions": second_solutions,
+            "first_positions": first_positions,
+            "second_positions": second_positions,
+            "projects": projects,
+            "anecdote_tags": anecdote_tags,
+            "business_tags": business_tags,
+            "gossip_tags": gossip_tags
+        }
+        return res
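
A hedged sketch of calling the new helper; the document id is one of the ids exercised in word_vector/tractate.py below, and the comments restate the _source mapping used above:

info = get_tag3_info_from_content("368399", "tractate", action="click")
projects = info["projects"]            # taken from the document's tags_v3 field
business_tags = info["business_tags"]  # taken from portrait_tag_name
gossip_tags = info["gossip_tags"]      # only requested for "answer" and "tractate" docs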

word_vector/tractate.py (view file @ 49040396)
 import multiprocessing
 import os
+import random
 import sys
 import time
 from collections import defaultdict
...
@@ -8,6 +9,7 @@ sys.path.append(os.path.realpath("."))
 from gensim.models import Word2Vec, word2vec
 from utils.date import get_ndays_before_no_minus, get_ndays_before_with_format
+from utils.es import get_tag3_info_from_content
 from utils.files import DATA_PATH, MODEL_PATH
 from utils.spark import get_spark
...
@@ -151,6 +153,39 @@ def save_clicked_tractate_ids_item2vec():
     return model


+def get_tractate_item2vec_by_id(tractate_id, model, score_limit=0.8, topn=100):
+    res = []
+    for (id, score) in model.wv.most_similar(tractate_id, topn=topn):
+        if score >= score_limit:
+            res.append(id)
+    return res
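
get_tractate_item2vec_by_id is a thin filter over gensim's most_similar, which returns (item_id, cosine_similarity) pairs sorted by similarity. A hedged usage sketch (gensim 3.x API, as implied by the .wv.vocab access in the __main__ block; the id is one of the ids printed there):

neighbours = get_tractate_item2vec_by_id("84375", TRACTATE_CLICK_IDS_MODEL,
                                         score_limit=0.8, topn=100)
# neighbours keeps only the ids whose cosine similarity to "84375" is at least 0.8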
+
+
+def validate_tractate_item2vec_by_id(tractate_id, related_tractate_ids):
+    projects = set(get_tag3_info_from_content(tractate_id, "tractate").get("projects", []))
+    res = []
+    if not projects:
+        return 0
+    for id in related_tractate_ids:
+        res.append(set(get_tag3_info_from_content(id, "tractate").get("projects", [])))
+    total = len(res)
+    count = 0
+    for i in res:
+        if len(projects.intersection(i)) > 0:
+            count += 1
+    if total == 0:
+        return 0
+    return count / total
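
In other words, validate_tractate_item2vec_by_id scores a tractate by the fraction of its item2vec neighbours that share at least one tags_v3 project with it. A self-contained worked example with hypothetical tag sets:

source_projects = {"A", "B"}                     # projects of the source tractate (hypothetical)
neighbour_projects = [{"A", "C"}, {"D"}, {"B"}]  # projects of three retrieved neighbours (hypothetical)
score = sum(1 for p in neighbour_projects if source_projects & p) / len(neighbour_projects)
print(score)  # 0.666..., i.e. two of the three neighbours overlap with the source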
+
+
+def validate_tractate_item2vec(tractate_ids, n=30):
+    tractate_ids = random.sample(list(tractate_ids), n)
+    for tractate_id in tractate_ids:
+        related_tractate_ids = get_tractate_item2vec_by_id(tractate_id, TRACTATE_CLICK_IDS_MODEL)
+        score = validate_tractate_item2vec_by_id(tractate_id, related_tractate_ids)
+        print("{}: {}".format(tractate_id, score))


 if __name__ == "__main__":
     begin_time = time.time()
...
@@ -168,9 +203,17 @@ if __name__ == "__main__":
     save_clicked_tractate_ids_item2vec()

     TRACTATE_CLICK_IDS_MODEL = word2vec.Word2Vec.load(tractate_click_ids_model_path)
     TRACTATE_CLICK_IDS = set(TRACTATE_CLICK_IDS_MODEL.wv.vocab.keys())
     for id in ["84375", "148764", "368399"]:
-        print(TRACTATE_CLICK_IDS_MODEL.wv.most_similar(id, topn=5))
+        try:
+            print(TRACTATE_CLICK_IDS_MODEL.wv.most_similar(id, topn=5))
+        except Exception as e:
+            print(e)

     print("total cost: {:.2f}mins".format((time.time() - begin_time) / 60))

 # validate_tractate_item2vec(TRACTATE_CLICK_IDS, 10)
 # spark-submit --master yarn --deploy-mode client --queue root.strategy --driver-memory 16g --executor-memory 1g --executor-cores 1 --num-executors 70 --conf spark.default.parallelism=100 --conf spark.storage.memoryFraction=0.5 --conf spark.shuffle.memoryFraction=0.3 --conf spark.locality.wait=0 --jars /srv/apps/tispark-core-2.1-SNAPSHOT-jar-with-dependencies.jar,/srv/apps/spark-connector_2.11-1.9.0-rc2.jar,/srv/apps/mysql-connector-java-5.1.38.jar /srv/apps/strategy_embedding/word_vector/tractate.py