Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in
Toggle navigation
G
gm_mab
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
rank
gm_mab
Commits
9ce47566
Commit
9ce47566
authored
Apr 17, 2020
by
段英荣
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
增加初始代码
parents
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
228 additions
and
0 deletions
+228
-0
README.md
README.md
+0
-0
cache.py
libs/cache.py
+8
-0
__init__.py
linucb/__init__.py
+0
-0
Linucb.py
linucb/core/Linucb.py
+220
-0
No files found.
README.md
0 → 100644
View file @
9ce47566
libs/cache.py
0 → 100644
View file @
9ce47566
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import
redis
redis_url
=
"redis://:ReDis!GmTx*0aN9@172.16.40.173:6379"
redis_client
=
redis
.
StrictRedis
.
from_url
(
redis_url
)
\ No newline at end of file
linucb/__init__.py
0 → 100644
View file @
9ce47566
linucb/core/Linucb.py
0 → 100644
View file @
9ce47566
# -*- coding: UTF-8 -*-
# !/usr/bin/env python
import
numpy
as
np
import
redis
from
libs.cache
import
redis_client
import
logging
import
traceback
import
json
import
pickle
import
pymysql
from
libs.error
import
logging_exception
import
random
class
LinUCB
:
d
=
6
alpha
=
0.01
r1
=
10
r0
=
-
0.1
default_tag_list
=
list
()
zhengxing_host
=
"172.16.30.141"
zhengxing_user
=
"work"
zhengxing_password
=
"BJQaT9VzDcuPBqkd"
zhengxing_database
=
"zhengxing"
redis_name_linucb_matrix_prefix
=
"strategy:linucb:content_type:"
@classmethod
def
linucb_recommend_tag
(
cls
,
device_id
,
redis_linucb_tag_data_dict
,
user_features_list
,
tag_list
):
"""
:remark 获取推荐标签
:param redis_linucb_tag_data_dict:
:param user_features_list:
:param tag_list:
:return:
"""
try
:
Aa_list
=
list
()
theta_list
=
list
()
for
tag_id
in
tag_list
:
tag_dict
=
pickle
.
loads
(
redis_linucb_tag_data_dict
[
tag_id
])
Aa_list
.
append
(
tag_dict
[
"Aa"
])
theta_list
.
append
(
tag_dict
[
"theta"
])
xaT
=
np
.
array
([
user_features_list
])
xa
=
np
.
transpose
(
xaT
)
art_max
=
-
1
old_pa
=
0
AaI_tmp
=
np
.
array
(
Aa_list
)
theta_tmp
=
np
.
array
(
theta_list
)
np_array
=
np
.
dot
(
xaT
,
theta_tmp
)
+
cls
.
alpha
*
np
.
sqrt
(
np
.
dot
(
np
.
dot
(
xaT
,
AaI_tmp
),
xa
))
# top_tag_list_len = int(np_array.size/2)
# top_np_ind = np.argpartition(np_array, -top_tag_list_len)[-top_tag_list_len:]
#
# top_tag_list = list()
# top_np_list = top_np_ind.tolist()
# for tag_id in top_np_list:
# top_tag_list.append(tag_id)
#art_max = tag_list[np.argmax(np.dot(xaT, theta_tmp) + cls.alpha * np.sqrt(np.dot(np.dot(xaT, AaI_tmp), xa)))]
top_tag_set
=
set
()
top_tag_dict
=
dict
()
np_score_list
=
list
()
np_score_dict
=
dict
()
for
score_index
in
range
(
0
,
np_array
.
size
):
score
=
np_array
.
take
(
score_index
)
np_score_list
.
append
(
score
)
if
score
not
in
np_score_dict
:
np_score_dict
[
score
]
=
[
score_index
]
else
:
np_score_dict
[
score
]
.
append
(
score_index
)
sorted_np_score_list
=
sorted
(
np_score_list
,
reverse
=
True
)
for
top_score
in
sorted_np_score_list
:
for
top_score_index
in
np_score_dict
[
top_score
]:
tag_id
=
str
(
tag_list
[
top_score_index
],
encoding
=
"utf-8"
)
top_tag_dict
[
tag_id
]
=
top_score
top_tag_set
.
add
(
tag_id
)
if
len
(
top_tag_dict
)
>=
20
:
break
if
len
(
top_tag_dict
)
>=
20
:
break
logging
.
info
(
"duan add,device_id:
%
s,sorted_np_score_list:
%
s,np_score_dict:
%
s"
%
(
str
(
device_id
),
str
(
sorted_np_score_list
),
str
(
np_score_dict
)))
return
(
top_tag_dict
,
top_tag_set
)
except
:
logging_exception
()
logging
.
error
(
"catch exception,err_msg:
%
s"
%
traceback
.
format_exc
())
return
({},())
@classmethod
def
init_all_arm_by_card_content
(
cls
,
card_content
=
"diary"
,
user_features_list
=
list
()):
try
:
redis_name_linucb_matrix
=
cls
.
redis_name_linucb_matrix_prefix
+
card_content
if
card_content
==
"diary"
:
zhengxing_conn
=
pymysql
.
connect
(
host
=
cls
.
zhengxing_host
,
user
=
cls
.
zhengxing_user
,
password
=
cls
.
zhengxing_password
,
database
=
cls
.
zhengxing_database
,
charset
=
"utf8"
)
zhengxing_cursor
=
zhengxing_conn
.
cursor
()
diary_id_sql
=
"select id from api_diary where is_online=true and content_level in (5,6);"
diary_id_list
=
list
()
zhengxing_cursor
.
execute
(
diary_id_sql
)
sql_tag_results
=
zhengxing_cursor
.
fetchall
()
for
item
in
sql_tag_results
:
diary_id
=
int
(
item
[
0
])
diary_id_list
.
append
(
diary_id
)
for
diary_id
in
diary_id_list
:
init_dict
=
{
"Aa"
:
np
.
identity
(
cls
.
d
),
"theta"
:
np
.
zeros
((
cls
.
d
,
1
)),
"ba"
:
np
.
zeros
((
cls
.
d
,
1
)),
"AaI"
:
np
.
identity
(
cls
.
d
)
}
pickle_data
=
pickle
.
dumps
(
init_dict
)
redis_client
.
hset
(
redis_name_linucb_matrix
,
diary_id
,
pickle_data
)
user_feature_index
=
random
.
randint
(
0
,
9
)
user_feature
=
user_features_list
[
user_feature_index
]
cls
.
update_linucb_info
(
user_feature
,
1
,
diary_id
,
redis_name_linucb_matrix
,
redis_client
)
print
(
str
(
user_feature
)
+
"
\t
"
+
str
(
diary_id
))
except
:
logging
.
error
(
"catch exception,err_msg:
%
s"
%
traceback
.
format_exc
())
return
False
@classmethod
def
init_device_id_linucb_info
(
cls
,
redis_cli
,
redis_name_linucb_matrix
,
tag_list
):
try
:
user_tag_linucb_dict
=
dict
()
for
tag_id
in
tag_list
:
init_dict
=
{
"Aa"
:
np
.
identity
(
cls
.
d
),
"theta"
:
np
.
zeros
((
cls
.
d
,
1
)),
"ba"
:
np
.
zeros
((
cls
.
d
,
1
)),
"AaI"
:
np
.
identity
(
cls
.
d
)
}
pickle_data
=
pickle
.
dumps
(
init_dict
)
user_tag_linucb_dict
[
tag_id
]
=
pickle_data
redis_cli
.
hmset
(
redis_name_linucb_matrix
,
user_tag_linucb_dict
)
return
True
except
:
logging_exception
()
logging
.
error
(
"catch exception,err_msg:
%
s"
%
traceback
.
format_exc
())
return
False
@classmethod
def
update_linucb_info
(
cls
,
user_features
,
reward
,
content_id
,
redis_name_linucb_matrix
,
redis_cli
):
try
:
if
reward
==
-
1
:
logging
.
warning
(
"reward val error!"
)
elif
reward
==
1
or
reward
==
0
:
if
reward
==
1
:
r
=
cls
.
r1
else
:
r
=
cls
.
r0
xaT
=
np
.
array
([
user_features
])
xa
=
np
.
transpose
(
xaT
)
ori_redis_tag_data
=
redis_cli
.
hget
(
redis_name_linucb_matrix
,
content_id
)
if
not
ori_redis_tag_data
:
LinUCB
.
init_device_id_linucb_info
(
redis_client
,
redis_name_linucb_matrix
,[
content_id
])
else
:
ori_redis_tag_dict
=
pickle
.
loads
(
ori_redis_tag_data
)
new_Aa_matrix
=
ori_redis_tag_dict
[
"Aa"
]
+
np
.
dot
(
xa
,
xaT
)
new_AaI_matrix
=
np
.
linalg
.
solve
(
new_Aa_matrix
,
np
.
identity
(
cls
.
d
))
new_ba_matrix
=
ori_redis_tag_dict
[
"ba"
]
+
r
*
xa
user_tag_dict
=
{
"Aa"
:
new_Aa_matrix
,
"ba"
:
new_ba_matrix
,
"AaI"
:
new_AaI_matrix
,
"theta"
:
np
.
dot
(
new_AaI_matrix
,
new_ba_matrix
)
}
redis_cli
.
hset
(
redis_name_linucb_matrix
,
content_id
,
pickle
.
dumps
(
user_tag_dict
))
else
:
logging
.
warning
(
"not standard linucb reward"
)
return
True
except
:
logging_exception
()
logging
.
error
(
"catch exception,err_msg:
%
s"
%
traceback
.
format_exc
())
return
False
if
__name__
==
"__main__"
:
user_features
=
[
[
1
,
2
,
1
,
1
,
3
,
1
],
[
1
,
4
,
3
,
1
,
3
,
1
],
[
3
,
2
,
1
,
5
,
5
,
1
],
[
1
,
2
,
4
,
2
,
3
,
2
],
[
1
,
5
,
7
,
1
,
4
,
4
],
[
3
,
4
,
1
,
1
,
3
,
1
],
[
5
,
2
,
1
,
6
,
3
,
1
],
[
1
,
2
,
3
,
2
,
3
,
5
],
[
1
,
2
,
1
,
1
,
2
,
4
],
[
1
,
2
,
6
,
4
,
2
,
1
],
]
LinUCB
.
init_all_arm_by_card_content
(
user_features_list
=
user_features
)
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment