Commit 574107f8 authored by 段英荣's avatar 段英荣

modifu

parent 29eae22e
...@@ -45,7 +45,7 @@ class LinUCB: ...@@ -45,7 +45,7 @@ class LinUCB:
Aa_list = list() Aa_list = list()
theta_list = list() theta_list = list()
for tag_id in redis_linucb_tag_data_dict: for tag_id in tag_list:
tag_dict = pickle.loads(redis_linucb_tag_data_dict[tag_id]) tag_dict = pickle.loads(redis_linucb_tag_data_dict[tag_id])
Aa_list.append(tag_dict["Aa"]) Aa_list.append(tag_dict["Aa"])
theta_list.append(tag_dict["theta"]) theta_list.append(tag_dict["theta"])
...@@ -60,17 +60,38 @@ class LinUCB: ...@@ -60,17 +60,38 @@ class LinUCB:
theta_tmp = np.array(theta_list) theta_tmp = np.array(theta_list)
np_array = np.dot(xaT, theta_tmp) + cls.alpha * np.sqrt(np.dot(np.dot(xaT, AaI_tmp), xa)) np_array = np.dot(xaT, theta_tmp) + cls.alpha * np.sqrt(np.dot(np.dot(xaT, AaI_tmp), xa))
top_tag_list_len = int(np_array.size/2) # top_tag_list_len = int(np_array.size/2)
top_np_ind = np.argpartition(np_array, -top_tag_list_len)[-top_tag_list_len:] # top_np_ind = np.argpartition(np_array, -top_tag_list_len)[-top_tag_list_len:]
#
top_tag_list = list() # top_tag_list = list()
top_np_list = top_np_ind.tolist() # top_np_list = top_np_ind.tolist()
for tag_id in top_np_list: # for tag_id in top_np_list:
top_tag_list.append(tag_id) # top_tag_list.append(tag_id)
#art_max = tag_list[np.argmax(np.dot(xaT, theta_tmp) + cls.alpha * np.sqrt(np.dot(np.dot(xaT, AaI_tmp), xa)))] #art_max = tag_list[np.argmax(np.dot(xaT, theta_tmp) + cls.alpha * np.sqrt(np.dot(np.dot(xaT, AaI_tmp), xa)))]
return top_tag_list
top_tag_set = set()
np_score_list = list()
np_score_dict = dict()
for score_index in range(0,np_array.size):
score = np_array.take(score_index)
np_score_list.append(score)
if score not in np_score_dict:
np_score_dict[score] = [score_index]
else:
np_score_dict[score].append(score_index)
sorted_np_score_list = sorted(np_score_list,reverse=True)
for top_score in sorted_np_score_list:
for top_score_index in np_score_dict[top_score]:
top_tag_set.add(tag_list[top_score_index])
if len(top_tag_set) >= 10:
break
return list(top_tag_set)
except: except:
logging.error("catch exception,err_msg:%s" % traceback.format_exc()) logging.error("catch exception,err_msg:%s" % traceback.format_exc())
return [] return []
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment