Commit f5cab4c7 authored by 段英荣's avatar 段英荣

Merge branch 'linUCB' into 'master'

modify

See merge request !130
parents 19f84959 deac1811
...@@ -8,6 +8,7 @@ from trans2es.models.tag import Tag ...@@ -8,6 +8,7 @@ from trans2es.models.tag import Tag
import logging import logging
import traceback import traceback
import json import json
import pickle
class LinUCB: class LinUCB:
d = 2 d = 2
...@@ -45,7 +46,7 @@ class LinUCB: ...@@ -45,7 +46,7 @@ class LinUCB:
theta_list = list() theta_list = list()
for tag_id in redis_linucb_tag_data_dict: for tag_id in redis_linucb_tag_data_dict:
tag_dict = json.loads(redis_linucb_tag_data_dict[tag_id]) tag_dict = pickle.loads(redis_linucb_tag_data_dict[tag_id])
Aa_list.append(tag_dict["Aa"]) Aa_list.append(tag_dict["Aa"])
theta_list.append(tag_dict["theta"]) theta_list.append(tag_dict["theta"])
...@@ -78,8 +79,8 @@ class LinUCB: ...@@ -78,8 +79,8 @@ class LinUCB:
"ba": np.zeros((cls.d, 1)), "ba": np.zeros((cls.d, 1)),
"AaI": np.identity(cls.d) "AaI": np.identity(cls.d)
} }
json_data = json.dumps(init_dict) pickle_data = pickle.dumps(init_dict)
user_tag_linucb_dict[tag_id] = json_data user_tag_linucb_dict[tag_id] = pickle_data
redis_cli.hmset(redis_key, user_tag_linucb_dict) redis_cli.hmset(redis_key, user_tag_linucb_dict)
...@@ -108,7 +109,7 @@ class LinUCB: ...@@ -108,7 +109,7 @@ class LinUCB:
if not ori_redis_tag_data: if not ori_redis_tag_data:
LinUCB.init_device_id_linucb_info(redis_client, redis_prefix, device_id,[tag_id]) LinUCB.init_device_id_linucb_info(redis_client, redis_prefix, device_id,[tag_id])
else: else:
ori_redis_tag_dict = json.loads(ori_redis_tag_data) ori_redis_tag_dict = pickle.loads(ori_redis_tag_data)
new_Aa_matrix = ori_redis_tag_dict["Aa"] + np.dot(xa, xaT) new_Aa_matrix = ori_redis_tag_dict["Aa"] + np.dot(xa, xaT)
new_AaI_matrix = np.linalg.solve(new_Aa_matrix, np.identity(cls.d)) new_AaI_matrix = np.linalg.solve(new_Aa_matrix, np.identity(cls.d))
new_ba_matrix = ori_redis_tag_dict["ba"] + r*xa new_ba_matrix = ori_redis_tag_dict["ba"] + r*xa
...@@ -120,7 +121,7 @@ class LinUCB: ...@@ -120,7 +121,7 @@ class LinUCB:
"theta": np.dot(new_AaI_matrix, new_ba_matrix) "theta": np.dot(new_AaI_matrix, new_ba_matrix)
} }
redis_cli.hset(redis_key, tag_id, json.dumps(user_tag_dict)) redis_cli.hset(redis_key, tag_id, pickle.dumps(user_tag_dict))
return True return True
except: except:
logging.error("catch exception,err_msg:%s" % traceback.format_exc()) logging.error("catch exception,err_msg:%s" % traceback.format_exc())
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment