Commit 574107f8 authored by 段英荣's avatar 段英荣

modifu

parent 29eae22e
......@@ -45,7 +45,7 @@ class LinUCB:
Aa_list = list()
theta_list = list()
for tag_id in redis_linucb_tag_data_dict:
for tag_id in tag_list:
tag_dict = pickle.loads(redis_linucb_tag_data_dict[tag_id])
Aa_list.append(tag_dict["Aa"])
theta_list.append(tag_dict["theta"])
......@@ -60,17 +60,38 @@ class LinUCB:
theta_tmp = np.array(theta_list)
np_array = np.dot(xaT, theta_tmp) + cls.alpha * np.sqrt(np.dot(np.dot(xaT, AaI_tmp), xa))
top_tag_list_len = int(np_array.size/2)
top_np_ind = np.argpartition(np_array, -top_tag_list_len)[-top_tag_list_len:]
top_tag_list = list()
top_np_list = top_np_ind.tolist()
for tag_id in top_np_list:
top_tag_list.append(tag_id)
# top_tag_list_len = int(np_array.size/2)
# top_np_ind = np.argpartition(np_array, -top_tag_list_len)[-top_tag_list_len:]
#
# top_tag_list = list()
# top_np_list = top_np_ind.tolist()
# for tag_id in top_np_list:
# top_tag_list.append(tag_id)
#art_max = tag_list[np.argmax(np.dot(xaT, theta_tmp) + cls.alpha * np.sqrt(np.dot(np.dot(xaT, AaI_tmp), xa)))]
return top_tag_list
top_tag_set = set()
np_score_list = list()
np_score_dict = dict()
for score_index in range(0,np_array.size):
score = np_array.take(score_index)
np_score_list.append(score)
if score not in np_score_dict:
np_score_dict[score] = [score_index]
else:
np_score_dict[score].append(score_index)
sorted_np_score_list = sorted(np_score_list,reverse=True)
for top_score in sorted_np_score_list:
for top_score_index in np_score_dict[top_score]:
top_tag_set.add(tag_list[top_score_index])
if len(top_tag_set) >= 10:
break
return list(top_tag_set)
except:
logging.error("catch exception,err_msg:%s" % traceback.format_exc())
return []
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment