Commit 141ee133 authored by 段英荣's avatar 段英荣

modify

parent d7d30198
...@@ -7,6 +7,7 @@ from libs.cache import redis_client ...@@ -7,6 +7,7 @@ from libs.cache import redis_client
from trans2es.models.tag import Tag from trans2es.models.tag import Tag
import logging import logging
import traceback import traceback
import json
class LinUCB: class LinUCB:
d = 2 d = 2
...@@ -41,12 +42,12 @@ class LinUCB: ...@@ -41,12 +42,12 @@ class LinUCB:
""" """
try: try:
Aa_list = list() Aa_list = list()
for tag_id in redis_linucb_tag_data_dict:
Aa_list.append(redis_linucb_tag_data_dict[tag_id]["Aa"])
theta_list = list() theta_list = list()
for tag_id in redis_linucb_tag_data_dict: for tag_id in redis_linucb_tag_data_dict:
theta_list.append(redis_linucb_tag_data_dict[tag_id]["theta"]) tag_dict = json.loads(redis_linucb_tag_data_dict[tag_id])
Aa_list.append(tag_dict["Aa"])
theta_list.append(tag_dict["theta"])
xaT = np.array([user_features_list]) xaT = np.array([user_features_list])
xa = np.transpose(xaT) xa = np.transpose(xaT)
...@@ -71,12 +72,14 @@ class LinUCB: ...@@ -71,12 +72,14 @@ class LinUCB:
user_tag_linucb_dict = dict() user_tag_linucb_dict = dict()
for tag_id in tag_list: for tag_id in tag_list:
user_tag_linucb_dict[tag_id] = { init_dict = {
"Aa": np.identity(cls.d), "Aa": np.identity(cls.d),
"theta": np.zeros((cls.d, 1)), "theta": np.zeros((cls.d, 1)),
"ba": np.zeros((cls.d, 1)), "ba": np.zeros((cls.d, 1)),
"AaI": np.identity(cls.d) "AaI": np.identity(cls.d)
} }
json_data = json.dumps(init_dict)
user_tag_linucb_dict[tag_id] = json_data
redis_cli.hmset(redis_key, user_tag_linucb_dict) redis_cli.hmset(redis_key, user_tag_linucb_dict)
...@@ -100,11 +103,12 @@ class LinUCB: ...@@ -100,11 +103,12 @@ class LinUCB:
xa = np.transpose(xaT) xa = np.transpose(xaT)
redis_key = redis_prefix + str(device_id) redis_key = redis_prefix + str(device_id)
ori_redis_tag_dict = redis_cli.hget(redis_key, tag_id) ori_redis_tag_data = redis_cli.hget(redis_key, tag_id)
if not ori_redis_tag_dict: if not ori_redis_tag_data:
LinUCB.init_device_id_linucb_info(redis_client, redis_prefix, device_id,[tag_id]) LinUCB.init_device_id_linucb_info(redis_client, redis_prefix, device_id,[tag_id])
else: else:
ori_redis_tag_dict = json.loads(ori_redis_tag_data)
new_Aa_matrix = ori_redis_tag_dict["Aa"] + np.dot(xa, xaT) new_Aa_matrix = ori_redis_tag_dict["Aa"] + np.dot(xa, xaT)
new_AaI_matrix = np.linalg.solve(new_Aa_matrix, np.identity(cls.d)) new_AaI_matrix = np.linalg.solve(new_Aa_matrix, np.identity(cls.d))
new_ba_matrix = ori_redis_tag_dict["ba"] + r*xa new_ba_matrix = ori_redis_tag_dict["ba"] + r*xa
...@@ -116,7 +120,7 @@ class LinUCB: ...@@ -116,7 +120,7 @@ class LinUCB:
"theta": np.dot(new_AaI_matrix, new_ba_matrix) "theta": np.dot(new_AaI_matrix, new_ba_matrix)
} }
redis_cli.hset(redis_key, tag_id, user_tag_dict) redis_cli.hset(redis_key, tag_id, json.dumps(user_tag_dict))
return True return True
except: except:
logging.error("catch exception,err_msg:%s" % traceback.format_exc()) logging.error("catch exception,err_msg:%s" % traceback.format_exc())
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment