# -*- coding: UTF-8 -*-
# !/usr/bin/env python

import numpy as np
import redis
from libs.cache import redis_client
from trans2es.models.tag import Tag
import logging
import traceback
import json
import pickle


class LinUCB:
    """Per-device LinUCB contextual bandit over tags.

    State for each (device, tag) pair is a pickled dict stored in a Redis hash
    under ``redis_prefix + str(device_id)``, field ``tag_id``, with keys:

    * ``Aa``    -- d x d design matrix (initialized to the identity),
    * ``AaI``   -- inverse of ``Aa``,
    * ``ba``    -- d x 1 reward-weighted feature sum (initialized to zeros),
    * ``theta`` -- d x 1 estimate ``AaI @ ba``.
    """

    # Dimension of the user feature vector / per-tag matrices.
    d = 2
    # Weight of the exploration (confidence-width) term in the UCB score.
    alpha = 0.25
    # Numeric reward applied when the caller reports reward == 1.
    r1 = 1
    # Numeric reward applied when the caller reports reward == 0.
    r0 = -16
    # Process-wide cache of online tag ids, filled lazily by get_default_tag_list().
    default_tag_list = list()

    @classmethod
    def get_default_tag_list(cls):
        """Return up to 20 ids of online tags.

        The full list of online tag ids is queried once and cached on the class.

        :return: list of at most 20 tag ids, or an empty list on error.
        """
        try:
            if not cls.default_tag_list:
                # Collect into a local list first so a failure part-way through the
                # query cannot leave a partially filled cache behind.
                tag_ids = [item.id for item in Tag.objects.filter(is_online=True)]
                cls.default_tag_list.extend(tag_ids)
            return cls.default_tag_list[:20]
        except Exception:
            logging.error("catch exception,err_msg:%s" % traceback.format_exc())
            return list()

    @classmethod
    def linucb_recommend_tag(cls, device_id, redis_linucb_tag_data_dict, user_features_list, tag_list, top_n=10):
        """Return the recommended tags, ranked by their LinUCB upper confidence bound.

        The score of tag a for user feature vector x is
        ``theta_a^T x + alpha * sqrt(x^T AaI_a x)``.

        :param device_id: device identifier (used only for logging).
        :param redis_linucb_tag_data_dict: mapping tag_id -> pickled per-tag state,
            as written by init_device_id_linucb_info / update_linucb_info.
        :param user_features_list: user feature vector of length ``cls.d``.
        :param tag_list: candidate tag ids (keys of redis_linucb_tag_data_dict);
            bytes ids are decoded as UTF-8, other ids are converted with str().
        :param top_n: maximum number of tags to return (default 10).
        :return: up to ``top_n`` distinct tag ids as str, highest score first;
            an empty list on error.
        """
        try:
            tag_list = list(tag_list) if tag_list else []
            if not tag_list or top_n <= 0:
                return []

            AaI_list = list()
            theta_list = list()
            for tag_id in tag_list:
                # NOTE(review): pickle.loads can execute arbitrary code for crafted
                # payloads; this is only safe while this Redis hash is written
                # exclusively by this class.
                tag_dict = pickle.loads(redis_linucb_tag_data_dict[tag_id])
                # The confidence width needs the INVERSE design matrix: it must
                # shrink as a tag accumulates observations.
                AaI_list.append(tag_dict["AaI"])
                theta_list.append(tag_dict["theta"])

            xaT = np.array([user_features_list])  # 1 x d
            xa = np.transpose(xaT)  # d x 1
            AaI_tmp = np.array(AaI_list)  # K x d x d
            theta_tmp = np.array(theta_list)  # K x d x 1
            # np.dot over the stacked arrays yields shape 1 x K x 1 for both terms;
            # ravel() gives one score per tag, in tag_list order.
            np_array = np.dot(xaT, theta_tmp) + cls.alpha * np.sqrt(np.dot(np.dot(xaT, AaI_tmp), xa))
            scores = np_array.ravel()

            # Stable sort: tags with equal scores keep their tag_list order.
            ranked_indices = np.argsort(-scores, kind="stable")
            top_tag_list = list()
            seen_tags = set()
            for score_index in ranked_indices:
                tag_id = tag_list[int(score_index)]
                tag_str = tag_id.decode("utf-8") if isinstance(tag_id, bytes) else str(tag_id)
                if tag_str in seen_tags:
                    continue
                seen_tags.add(tag_str)
                top_tag_list.append(tag_str)
                if len(top_tag_list) >= top_n:
                    break

            logging.info("duan add,device_id:%s,sorted_np_score_list:%s,top_tag_list:%s",
                         str(device_id), str(sorted(scores.tolist(), reverse=True)), str(top_tag_list))
            return top_tag_list
        except Exception:
            logging.error("catch exception,err_msg:%s" % traceback.format_exc())
            return []

    @classmethod
    def init_device_id_linucb_info(cls, redis_cli, redis_prefix, device_id, tag_list):
        """Write the initial LinUCB state for each tag of a device to Redis.

        Every tag gets Aa = AaI = identity(d) and theta = ba = zeros((d, 1)),
        stored as a pickled dict in the hash ``redis_prefix + str(device_id)``.

        :param redis_cli: Redis client used for the write.
        :param redis_prefix: prefix of the per-device hash key.
        :param device_id: device identifier appended to the prefix.
        :param tag_list: tag ids to initialize (used as hash fields).
        :return: True on success (including an empty tag_list), False on error.
        """
        try:
            if not tag_list:
                # Nothing to write; HMSET with an empty mapping would be rejected.
                return True
            redis_key = redis_prefix + str(device_id)
            init_dict = {
                "Aa": np.identity(cls.d),
                "theta": np.zeros((cls.d, 1)),
                "ba": np.zeros((cls.d, 1)),
                "AaI": np.identity(cls.d)
            }
            # The initial state is identical for every tag, so pickle it once.
            pickle_data = pickle.dumps(init_dict)
            user_tag_linucb_dict = {tag_id: pickle_data for tag_id in tag_list}
            redis_cli.hmset(redis_key, user_tag_linucb_dict)
            return True
        except Exception:
            logging.error("catch exception,err_msg:%s" % traceback.format_exc())
            return False

    @classmethod
    def update_linucb_info(cls, user_features, reward, tag_id, device_id, redis_prefix, redis_cli):
        """Apply one feedback observation to a tag's LinUCB state.

        reward == 1 maps to ``cls.r1`` and reward == 0 to ``cls.r0``.
        reward == -1 is logged as invalid, and any other value is ignored.
        If no state exists yet for the tag, it is only initialized. The
        observation itself is not applied in that case, which matches the
        previous behavior.

        :param user_features: user feature vector of length ``cls.d``.
        :param reward: 1 (positive), 0 (negative) or -1 (invalid).
        :param tag_id: tag id (hash field) to update.
        :param device_id: device identifier appended to redis_prefix.
        :param redis_prefix: prefix of the per-device hash key.
        :param redis_cli: Redis client used for both reads and writes.
        :return: True on success, False on error.
        """
        try:
            if reward == -1:
                logging.warning("reward val error!")
            elif reward == 1 or reward == 0:
                r = cls.r1 if reward == 1 else cls.r0
                redis_key = redis_prefix + str(device_id)
                ori_redis_tag_data = redis_cli.hget(redis_key, tag_id)
                if not ori_redis_tag_data:
                    # Initialize through the caller's client. Previously the
                    # module-level redis_client was used, which could target a
                    # different Redis than the one just read from.
                    if not cls.init_device_id_linucb_info(redis_cli, redis_prefix, device_id, [tag_id]):
                        return False
                else:
                    xaT = np.array([user_features])  # 1 x d
                    xa = np.transpose(xaT)  # d x 1
                    # NOTE(review): pickle.loads on data read from Redis; only safe
                    # while this hash is written exclusively by this class.
                    ori_redis_tag_dict = pickle.loads(ori_redis_tag_data)
                    new_Aa_matrix = ori_redis_tag_dict["Aa"] + np.dot(xa, xaT)
                    new_AaI_matrix = np.linalg.solve(new_Aa_matrix, np.identity(new_Aa_matrix.shape[0]))
                    new_ba_matrix = ori_redis_tag_dict["ba"] + r * xa
                    user_tag_dict = {
                        "Aa": new_Aa_matrix,
                        "ba": new_ba_matrix,
                        "AaI": new_AaI_matrix,
                        "theta": np.dot(new_AaI_matrix, new_ba_matrix)
                    }
                    redis_cli.hset(redis_key, tag_id, pickle.dumps(user_tag_dict))
            return True
        except Exception:
            logging.error("catch exception,err_msg:%s" % traceback.format_exc())
            return False