Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in
Toggle navigation
S
strategy_embedding
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
rank
strategy_embedding
Commits
6d492c13
Commit
6d492c13
authored
Nov 16, 2020
by
赵威
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
save index
parent
086d0f85
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
62 additions
and
8 deletions
+62
-8
get_data.py
personas_vector/get_data.py
+3
-3
to_vector.py
personas_vector/to_vector.py
+59
-5
No files found.
personas_vector/get_data.py
View file @
6d492c13
...
...
@@ -14,9 +14,9 @@ DATA_PATH = os.path.join(base_dir, "_data")
if
__name__
==
"__main__"
:
spark
=
get_spark
(
"personas_vector_data"
)
card_type
=
"user_post"
days
=
5
# TODO days 30
start
,
end
=
get_ndays_before_no_minus
(
days
),
get_ndays_before_no_minus
(
1
)
#
card_type = "user_post"
#
days = 5 # TODO days 30
#
start, end = get_ndays_before_no_minus(days), get_ndays_before_no_minus(1)
# click_df = get_click_data(spark, card_type, start, end)
# save_df_to_csv(click_df, "personas_tractate_click.csv")
...
...
personas_vector/to_vector.py
View file @
6d492c13
...
...
@@ -7,10 +7,13 @@ sys.path.append(os.path.realpath("."))
import
multiprocessing
import
faiss
import
numpy
as
np
import
pandas
as
pd
from
gensim.models
import
Word2Vec
,
word2vec
from
utils.defs
import
nth_element
from
utils.files
import
get_df
from
utils.cache
import
redis_client_db
def
device_tractate_fe
():
...
...
@@ -40,17 +43,68 @@ if __name__ == "__main__":
tags_data
=
tractate_tags_df
[
"business_tags"
]
.
to_list
()
model
=
tractate_business_tags_word2vec
(
tags_data
)
# all business tags
tags_set
=
set
()
for
i
in
tags_data
:
for
j
in
i
:
tags_set
.
add
(
j
)
# tag vector dict
tags_vector_dict
=
{}
for
i
in
tags_set
:
tags_vector_dict
[
i
]
=
json
.
dumps
(
model
.
wv
.
get_vector
(
i
))
try
:
# vec = json.dumps(model.wv.get_vector(i).tolist())
vec
=
model
.
wv
.
get_vector
(
i
)
tags_vector_dict
[
i
]
=
vec
except
Exception
as
e
:
pass
redis_client_db
.
hmset
(
"personas_tags_embedding"
,
tags_vector_dict
)
print
(
random
.
choice
(
tags_vector_dict
.
items
()))
print
(
len
(
tags_vector_dict
.
items
()))
# print(random.choice(list(tags_vector_dict.items())))
# for i in ["自体脂肪面部年轻化", "自体脂肪填充面部", "自体脂肪全面部填充", "自体脂肪面部填充", "鼻综合", "鼻部综合"]:
# print(model.wv.most_similar(i))
# print(model.wv.get_vector(i))
# tractate vector dict
tractate_vector_dict
=
{}
for
_
,
row
in
tractate_tags_df
.
iterrows
():
vecs
=
[]
for
i
in
row
[
"business_tags"
]:
vec
=
tags_vector_dict
.
get
(
i
,
np
.
array
([]))
if
vec
.
any
():
vecs
.
append
(
vec
)
if
vecs
:
tractate_vector_dict
[
row
[
"tractate_id"
]]
=
np
.
average
(
vecs
,
axis
=
0
)
print
(
len
(
tractate_vector_dict
.
items
()))
# print(random.choice(list(tractate_vector_dict.items())))
# tractate vector index
tractate_ids
=
np
.
array
(
list
(
tractate_vector_dict
.
keys
()))
.
astype
(
"int"
)
tractate_embeddings
=
np
.
array
(
list
(
tractate_vector_dict
.
values
()))
.
astype
(
"float32"
)
index
=
faiss
.
IndexFlatL2
(
tractate_embeddings
.
shape
[
1
])
print
(
"trained: "
+
str
(
index
.
is_trained
))
index2
=
faiss
.
IndexIDMap
(
index
)
index2
.
add_with_ids
(
tractate_embeddings
,
tractate_ids
)
print
(
"trained: "
+
str
(
index2
.
is_trained
))
print
(
"total index: "
+
str
(
index2
.
ntotal
))
base_dir
=
os
.
getcwd
()
model_dir
=
os
.
path
.
join
(
base_dir
,
"_models"
)
index_path
=
os
.
path
.
join
(
model_dir
,
"faiss_personas_vector.index"
)
faiss
.
write_index
(
index2
,
index_path
)
print
(
index_path
)
# device vector
# for _, row in device_tags_df.iterrows():
# vecs = []
# for i in row["business_tags"]:
# vec = tags_vector_dict.get(i, np.array([]))
# if vec.any():
# vecs.append(vec)
# if vecs:
# t = np.array([np.average(vecs, axis=0)]).astype("float32")
# D, I = index2.search(t, 10)
# print(row["cl_id"], row["business_tags"])
# print(I)
# curl "http://172.16.31.17:9000/gm-dbmw-tractate-read/_search?pretty" -d '{"query": {"term": {"id": "10269"}}, "_source": {"include": ["content", "portrait_tag_name"]}}'
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment