import sys
import os

from datetime import date, timedelta
from elasticsearch import Elasticsearch
from elasticsearch.helpers import scan

import time
import redis
from pyspark import SparkContext, SparkConf
from pyspark.sql import SparkSession
import pyspark.sql as sql
from pyspark.sql.functions import when
from pyspark.sql.types import *
from pyspark.sql import functions as F

from collections import defaultdict
import json

sys.path.append(os.path.dirname(os.path.abspath(os.path.dirname(__file__))))
import utils.configUtils as configUtils
# import utils.connUtils as connUtils
import pandas as pd
import math

# os.environ["PYSPARK_PYTHON"]="/usr/bin/python3"

"""
    Feature engineering
"""
NUMBER_PRECISION = 2

VERSION = configUtils.SERVICE_VERSION
FEATURE_USER_KEY = "Strategy:rec:feature:service:" + VERSION + ":user:"
FEATURE_ITEM_KEY = "Strategy:rec:feature:service:" + VERSION + ":item:"
FEATURE_VOCAB_KEY = "Strategy:rec:vocab:service:" + VERSION
FEATURE_COLUMN_KEY = "Strategy:rec:column:service:" + VERSION

ITEM_PREFIX = "item_"
DATA_PATH_TRAIN = "/data/files/service_feature_{}_train.csv".format(VERSION)
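# Illustrative example: with VERSION = "v1" (hypothetical value), the training set
# would be written to /data/files/service_feature_v1_train.csv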


def getRedisConn():
    pool = redis.ConnectionPool(host="172.16.50.145", password="XfkMCCdWDIU%ls$h", port=6379, db=0)
    conn = redis.Redis(connection_pool=pool)
    # conn = redis.Redis(host="172.16.50.145", port=6379, password="XfkMCCdWDIU%ls$h",db=0)
    # conn = redis.Redis(host="172.18.51.10", port=6379,db=0) #test
    return conn


def parseTags(tags, i):
    tags_arr = tags.split(",")
    if len(tags_arr) >= i:
        return tags_arr[i - 1]
    else:
        return "-1"


def numberToBucket(num):
    res = 0
    if not num:
        return str(res)
    if num >= 1000:
        res = 1000 // 10
    else:
        res = int(num) // 10
    return str(res)


def priceToBucket(num):
    res = 0
    if not num:
        return str(res)
    if num >= 100000:
        res = 100000 // 1000
    else:
        res = int(num) // 1000
    return str(res)
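
# Illustrative mappings, derived directly from the two bucket functions above:
#   numberToBucket(37)  -> "3"    numberToBucket(1200)  -> "100"
#   priceToBucket(2500) -> "2"    priceToBucket(150000) -> "100"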


numberToBucketUdf = F.udf(numberToBucket, StringType())
priceToBucketUdf = F.udf(priceToBucket, StringType())


def addItemStaticFeatures(samples, itemDF, dataVocab):
    ctrUdf = F.udf(wilson_ctr, FloatType())
    # No sliding window for item stats: items persist over time, so the latest aggregates are sufficient
    print("Computing item statistical features...")
    staticFeatures = samples.groupBy('item_id').agg(F.count(F.lit(1)).alias('itemRatingCount'),
                                                    F.avg(F.col('rating')).alias('itemRatingAvg'),
                                                    F.stddev(F.col('rating')).alias('itemRatingStddev'),
                                                    F.sum(
                                                        when(F.col('label') == 1, F.lit(1)).otherwise(F.lit(0))).alias(
                                                        "itemClickCount"),
                                                    F.sum(
                                                        when(F.col('label') == 0, F.lit(1)).otherwise(F.lit(0))).alias(
                                                        "itemExpCount")
                                                    ).fillna(0) \
        .withColumn('itemRatingStddev', F.format_number(F.col('itemRatingStddev'), NUMBER_PRECISION).cast("float")) \
        .withColumn('itemRatingAvg', F.format_number(F.col('itemRatingAvg'), NUMBER_PRECISION).cast("float")) \
        .withColumn('itemCtr',
                    F.format_number(ctrUdf(F.col("itemExpCount"), F.col("itemClickCount")),
                                    NUMBER_PRECISION).cast("float"))  # wilson_ctr expects (num_pv, num_click)

    staticFeatures.show(20, truncate=False)

    staticFeatures = itemDF.join(staticFeatures, on=["item_id"], how='left')

    # Bucketize continuous features
    bucket_vocab = [str(i) for i in range(101)]
    bucket_suffix = "_Bucket"
    for col in ["itemRatingCount", "itemRatingAvg", "itemClickCount", "itemExpCount"]:
        new_col = col + bucket_suffix
        staticFeatures = staticFeatures.withColumn(new_col, numberToBucketUdf(F.col(col))) \
            .drop(col) \
            .withColumn(new_col, F.when(F.col(new_col).isNull(), "0").otherwise(F.col(new_col)))
        dataVocab[new_col] = bucket_vocab

    # Std-dev handling: map missing values to 0, otherwise use 1 / (stddev + 1)
    number_suffix = "_number"
    for col in ["itemRatingStddev"]:
        new_col = col + number_suffix
        staticFeatures = staticFeatures.withColumn(new_col,
                                                   F.when(F.col(col).isNull(), 0).otherwise(1 / (F.col(col) + 1))).drop(
            col)
    for col in ["itemCtr"]:
        new_col = col + number_suffix
        staticFeatures = staticFeatures.withColumn(col, F.when(F.col(col).isNull(), 0).otherwise(
            F.col(col))).withColumnRenamed(col, new_col)

    print("item size:", staticFeatures.count())

    staticFeatures.show(5, truncate=False)
    return staticFeatures


def addUserStaticsFeatures(samples, dataVocab):
    print("Computing user statistical features...")
    samples = samples \
        .withColumn('userRatingCount', F.format_number(
        F.sum(F.lit(1)).over(sql.Window.partitionBy('userid').orderBy('timestamp').rowsBetween(-100, -1)),
        NUMBER_PRECISION).cast("float")) \
        .withColumn("userRatingAvg", F.format_number(
        F.avg(F.col("rating")).over(sql.Window.partitionBy('userid').orderBy('timestamp').rowsBetween(-100, -1)),
        NUMBER_PRECISION).cast("float")) \
        .withColumn("userRatingStddev", F.format_number(
        F.stddev(F.col("rating")).over(sql.Window.partitionBy('userid').orderBy('timestamp').rowsBetween(-100, -1)),
        NUMBER_PRECISION).cast("float")) \
        .withColumn("userClickCount", F.format_number(
        F.sum(when(F.col('label') == 1, F.lit(1)).otherwise(F.lit(0))).over(
            sql.Window.partitionBy("userid").orderBy(F.col("timestamp")).rowsBetween(-100, -1)), NUMBER_PRECISION).cast(
        "float")) \
        .withColumn("userExpCount", F.format_number(F.sum(when(F.col('label') == 0, F.lit(1)).otherwise(F.lit(0))).over(
        sql.Window.partitionBy("userid").orderBy(F.col("timestamp")).rowsBetween(-100, -1)), NUMBER_PRECISION).cast(
        "float")) \
        .withColumn("userCtr",
                    F.format_number(F.col("userClickCount") / (F.col("userExpCount") + 1), NUMBER_PRECISION).cast(
                        "float")) \
        .filter(F.col("userRatingCount") > 1)

    samples.show(20, truncate=False)

    # Bucketize continuous features
    bucket_vocab = [str(i) for i in range(101)]
    bucket_suffix = "_Bucket"
    for col in ["userRatingCount", "userRatingAvg", "userClickCount", "userExpCount"]:
        new_col = col + bucket_suffix
        samples = samples.withColumn(new_col, numberToBucketUdf(F.col(col))) \
            .drop(col) \
            .withColumn(new_col, F.when(F.col(new_col).isNull(), "0").otherwise(F.col(new_col)))
        dataVocab[new_col] = bucket_vocab

    # Std-dev handling: map missing values to 0, otherwise use 1 / (stddev + 1)
    number_suffix = "_number"
    for col in ["userRatingStddev"]:
        new_col = col + number_suffix
        samples = samples.withColumn(new_col, F.when(F.col(col).isNull(), 0).otherwise(1 / (F.col(col) + 1))).drop(col)
    for col in ["userCtr"]:
        new_col = col + number_suffix
        samples = samples.withColumn(col, F.when(F.col(col).isNull(), 0).otherwise(F.col(col))).withColumnRenamed(col,
                                                                                                                  new_col)

    samples.printSchema()
    samples.show(20, truncate=False)
    return samples


def addItemFeatures(itemDF, dataVocab, multi_col_vocab):
    # multi_col = ['sku_tags', 'sku_show_tags','second_demands', 'second_solutions', 'second_positions']
    multi_col = ['tags_v3', 'second_demands', 'second_solutions', 'second_positions']
    onehot_col = ['id', 'service_type', 'merchant_id', 'doctor_type', 'doctor_id', 'doctor_famous', 'hospital_id',
                  'hospital_city_tag_id', 'hospital_type', 'hospital_is_high_quality']

    for col in onehot_col:
        new_c = ITEM_PREFIX + col
        dataVocab[new_c] = list(set(itemDF[col].tolist()))
        itemDF[new_c] = itemDF[col]
    itemDF = itemDF.drop(columns=onehot_col)

    for c in multi_col:
        multi_col_vocab[c] = list(set(itemDF[c].tolist()))

        for i in range(1, 6):
            new_c = ITEM_PREFIX + c + "__" + str(i)
            itemDF[new_c] = itemDF[c].map(lambda x: parseTags(x, i))
            dataVocab[new_c] = multi_col_vocab[c]

    # Bucketize continuous features
    bucket_vocab = [str(i) for i in range(101)]
    bucket_suffix = "_Bucket"
    for col in ['case_count', 'sales_count']:
        new_col = ITEM_PREFIX + col + bucket_suffix
        itemDF[new_col] = itemDF[col].map(numberToBucket)
        itemDF = itemDF.drop(columns=[col])
        dataVocab[new_col] = bucket_vocab

    for col in ['sku_price']:
        new_col = ITEM_PREFIX + col + bucket_suffix
        itemDF[new_col] = itemDF[col].map(priceToBucket)
        itemDF = itemDF.drop(columns=[col])
        dataVocab[new_col] = bucket_vocab

    # Continuous feature handling
    number_suffix = "_number"
    for col in ["discount"]:
        new_col = ITEM_PREFIX + col + number_suffix
        itemDF[new_col] = itemDF[col]
        itemDF = itemDF.drop(columns=[col])

    return itemDF


def extractTags(genres_list):
    # Weight each tag by its position in the click list (later clicks get larger weights)
    genres_dict = defaultdict(int)
    for i, genres in enumerate(genres_list):
        for genre in genres.split(','):
            genres_dict[genre] += i
    sortedGenres = sorted(genres_dict.items(), key=lambda x: x[1], reverse=True)
    return [x[0] for x in sortedGenres]
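
# Illustrative example: extractTags(["a,b", "b,c"]) accumulates positional weights
# (a: 0, b: 0 + 1, c: 1) and returns the tags sorted by weight, i.e. ["b", "c", "a"].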


# F.reverse is not available in this Spark SQL version, so reverse the array in a UDF instead
def arrayReverse(arr):
    arr.reverse()
    return arr


"""
p -- the click probability, i.e. the raw CTR
n -- the total sample size, i.e. the number of impressions
z -- in a normal distribution, mean + z * stddev corresponds to a confidence level;
     e.g. z = 1.96 gives 95% confidence.
The Wilson interval is the range the true CTR lies in at the given confidence level;
wilson_ctr below returns its lower bound.
"""


def wilson_ctr(num_pv, num_click):
    num_pv = float(num_pv)
    num_click = float(num_click)
    if num_pv * num_click == 0 or num_pv < num_click:
        return 0.0

    z = 1.96
    n = num_pv
    p = num_click / num_pv
    score = (p + z * z / (2 * n) - z * math.sqrt((p * (1.0 - p) + z * z / (4.0 * n)) / n)) / (1.0 + z * z / n)
    return float(score)
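
# Illustrative check: wilson_ctr(1000, 100) (1000 impressions, 100 clicks) returns the
# lower bound of the 95% Wilson interval, roughly 0.083, slightly below the raw CTR of 0.1.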


def addUserFeatures(samples, dataVocab, multiVocab):
    dataVocab["userid"] = collectColumnToVocab(samples, "userid")
    dataVocab["user_city_id"] = collectColumnToVocab(samples, "user_city_id")
    dataVocab["user_os"] = ["ios", "android"]

    extractTagsUdf = F.udf(extractTags, ArrayType(StringType()))
    arrayReverseUdf = F.udf(arrayReverse, ArrayType(StringType()))
    ctrUdf = F.udf(wilson_ctr, FloatType())
    print("Processing user history...")
    # User click history
    samples = samples.withColumn('userPositiveHistory', F.collect_list(
        when(F.col('label') == 1, F.col('item_id')).otherwise(F.lit(None))).over(
        sql.Window.partitionBy("userid").orderBy(F.col("timestamp")).rowsBetween(-100, -1)))

    samples = samples.withColumn("userPositiveHistory", arrayReverseUdf(F.col("userPositiveHistory")))

    for i in range(1, 11):
        samples = samples.withColumn("userRatedHistory" + str(i),
                                     F.when(F.col("userPositiveHistory")[i - 1].isNotNull(),
                                            F.col("userPositiveHistory")[i - 1]).otherwise("-1"))
        dataVocab["userRatedHistory" + str(i)] = dataVocab["item_id"]
    samples = samples.drop("userPositiveHistory")

    # User preferences
    print("Building user preference features...")
    for c, v in multiVocab.items():
        new_col = "user" + "__" + c
        samples = samples.withColumn(new_col, extractTagsUdf(
            F.collect_list(when(F.col('label') == 1, F.col(c)).otherwise(F.lit(None))).over(
                sql.Window.partitionBy('userid').orderBy('timestamp').rowsBetween(-100, -1))))
        for i in range(1, 6):
            samples = samples.withColumn(new_col + "__" + str(i),
                                         F.when(F.col(new_col)[i - 1].isNotNull(), F.col(new_col)[i - 1]).otherwise(
                                             "-1"))
            dataVocab[new_col + "__" + str(i)] = v

        samples = samples.drop(new_col).drop(c)

    print("Computing user statistical features...")
    samples = samples \
        .withColumn('userRatingCount', F.format_number(
        F.sum(F.lit(1)).over(sql.Window.partitionBy('userid').orderBy('timestamp').rowsBetween(-100, -1)),
        NUMBER_PRECISION).cast("float")) \
        .withColumn("userRatingAvg", F.format_number(
        F.avg(F.col("rating")).over(sql.Window.partitionBy('userid').orderBy('timestamp').rowsBetween(-100, -1)),
        NUMBER_PRECISION).cast("float")) \
        .withColumn("userRatingStddev", F.format_number(
        F.stddev(F.col("rating")).over(sql.Window.partitionBy('userid').orderBy('timestamp').rowsBetween(-100, -1)),
        NUMBER_PRECISION).cast("float")) \
        .withColumn("userClickCount", F.format_number(
        F.sum(when(F.col('label') == 1, F.lit(1)).otherwise(F.lit(0))).over(
            sql.Window.partitionBy("userid").orderBy(F.col("timestamp")).rowsBetween(-100, -1)), NUMBER_PRECISION).cast(
        "float")) \
        .withColumn("userExpCount", F.format_number(F.sum(when(F.col('label') == 0, F.lit(1)).otherwise(F.lit(0))).over(
        sql.Window.partitionBy("userid").orderBy(F.col("timestamp")).rowsBetween(-100, -1)), NUMBER_PRECISION).cast(
        "float")) \
        .withColumn("userCtr",
                    F.format_number(ctrUdf(F.col("userExpCount"), F.col("userClickCount")), NUMBER_PRECISION)) \
        .filter(F.col("userRatingCount") > 1)

    samples.show(10, truncate=False)

    # Bucketize continuous features
    bucket_vocab = [str(i) for i in range(101)]
    bucket_suffix = "_Bucket"
    for col in ["userRatingCount", "userRatingAvg", "userClickCount", "userExpCount"]:
        new_col = col + bucket_suffix
        samples = samples.withColumn(new_col, numberToBucketUdf(F.col(col))) \
            .drop(col) \
            .withColumn(new_col, F.when(F.col(new_col).isNull(), "0").otherwise(F.col(new_col)))
        dataVocab[new_col] = bucket_vocab

    # Std-dev handling: map missing values to 0, otherwise use 1 / (stddev + 1)
    number_suffix = "_number"
    for col in ["userRatingStddev"]:
        new_col = col + number_suffix
        samples = samples.withColumn(new_col, F.when(F.col(col).isNull(), 0).otherwise(1 / (F.col(col) + 1))).drop(col)
    for col in ["userCtr"]:
        new_col = col + number_suffix
        samples = samples.withColumn(col, F.when(F.col(col).isNull(), 0).otherwise(F.col(col))).withColumnRenamed(col,
                                                                                                                  new_col)

    samples.printSchema()
    samples.show(10, truncate=False)
    return samples


def addSampleLabel(ratingSamples):
    ratingSamples = ratingSamples.withColumn('label', when(F.col('rating') >= 1, 1).otherwise(0))
    # ratingSamples = ratingSamples.withColumn('label', when(F.col('rating') >= 5, 1).otherwise(0))
    ratingSamples.show(5, truncate=False)
    ratingSamples.printSchema()
    return ratingSamples


def samplesNegAndUnion(samplesPos, samplesNeg):
    # Downsample negatives to a positive:negative ratio of roughly 1:4
    pos_count = samplesPos.count()
    neg_count = samplesNeg.count()

    print("before filter posSize:{},negSize:{}".format(str(pos_count), str(neg_count)))

    samplesNeg = samplesNeg.sample(pos_count * 4 / neg_count)
    samples = samplesNeg.union(samplesPos)
    dataSize = samples.count()
    print("dataSize:{}".format(str(dataSize)))
    return samples


def splitAndSaveTrainingTestSamplesByTimeStamp(samples, splitTimestamp, file_path):
    samples = samples.withColumn("timestampLong", F.col("timestamp").cast(LongType()))
    # quantile = smallSamples.stat.approxQuantile("timestampLong", [0.8], 0.05)
    # splitTimestamp = quantile[0]
    train = samples.where(F.col("timestampLong") <= splitTimestamp).drop("timestampLong")
    test = samples.where(F.col("timestampLong") > splitTimestamp).drop("timestampLong")
    print("split train size:{},test size:{}".format(str(train.count()), str(test.count())))
    trainingSavePath = file_path + '_train'
    testSavePath = file_path + '_test'
    train.write.option("header", "true").option("delimiter", "|").mode('overwrite').csv(trainingSavePath)
    test.write.option("header", "true").option("delimiter", "|").mode('overwrite').csv(testSavePath)


def collectColumnToVocab(samples, column):
    datas = samples.select(column).distinct().collect()
    vocabSet = set()
    for d in datas:
        if d[column]:
            vocabSet.add(str(d[column]))
    return list(vocabSet)


def collectMutiColumnToVocab(samples, column):
    datas = samples.select(column).distinct().collect()
    tagSet = set()
    for d in datas:
        if d[column]:
            for tag in d[column].split(","):
                tagSet.add(tag)

    tagSet.add("-1")  # default token for missing values
    return list(tagSet)


def dataVocabToRedis(dataVocab):
    conn = getRedisConn()
    # redis cannot store a dict directly, so serialize the vocab to JSON first
    conn.set(FEATURE_VOCAB_KEY, json.dumps(dataVocab, ensure_ascii=False))
    conn.expire(FEATURE_VOCAB_KEY, 60 * 60 * 24 * 7)


def featureColumnsToRedis(columns):
    conn = getRedisConn()
    conn.set(FEATURE_COLUMN_KEY, json.dumps(columns))
    conn.expire(FEATURE_COLUMN_KEY, 60 * 60 * 24 * 7)


def featureToRedis(key, datas):
    conn = getRedisConn()
    for k, v in datas.items():
        newKey = key + k
        conn.set(newKey, v)
        conn.expire(newKey, 60 * 60 * 24 * 7)


def userFeaturesToRedis(samples, columns, prefix, redisKey):
    idCol = prefix + "id"
    timestampCol = idCol + "_timestamp"

    def toRedis(datas):
        conn = getRedisConn()
        for d in datas:
            k = d[idCol]
            v = json.dumps(d.asDict(), ensure_ascii=False)
            newKey = redisKey + k
            conn.set(newKey, v)
            conn.expire(newKey, 60 * 60 * 24 * 7)

    # Keep only the latest record per user, based on timestamp
    prefixSamples = samples.groupBy(idCol).agg(F.max("timestamp").alias(timestampCol))
    resDatas = prefixSamples.join(samples, on=[idCol], how='inner').where(F.col("timestamp") == F.col(timestampCol))
    resDatas = resDatas.select(*columns).distinct()
    resDatas.show(10, truncate=False)
    print(prefix, resDatas.count())
    resDatas.repartition(8).foreachPartition(toRedis)


def itemFeaturesToRedis(itemStaticDF, redisKey):
    idCol = "item_id"

    def toRedis(datas):
        conn = getRedisConn()
        for d in datas:
            k = d[idCol]
            v = json.dumps(d.asDict(), ensure_ascii=False)
            newKey = redisKey + k
            conn.set(newKey, v)
            conn.expire(newKey, 60 * 60 * 24 * 7)

    itemStaticDF.repartition(8).foreachPartition(toRedis)
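
# Each row is stored as a JSON string under redisKey + id
# (e.g. "Strategy:rec:feature:service:<VERSION>:item:<item_id>") with a 7-day TTL;
# userFeaturesToRedis above follows the same pattern with the user key prefix.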


"""
    Data loading
"""

CONTENT_TYPE = "service"
SERVICE_HOSTS = [
    {'host': "172.16.52.33", 'port': 9200},
    {'host': "172.16.52.19", 'port': 9200},
    {'host': "172.16.52.48", 'port': 9200},
    {'host': "172.16.52.27", 'port': 9200},
    {'host': "172.16.52.34", 'port': 9200}
]
ES_INDEX = "gm-dbmw-service-read"
ES_INDEX_TEST = "gm_test-service-read"

ACTION_REG = r"""^\\d+$"""


def getEsConn_test():
    host_config = [{'host': '172.18.52.14', 'port': 9200}, {'host': '172.18.52.133', 'port': 9200},
                   {'host': '172.18.52.7', 'port': 9200}]

    return Elasticsearch(host_config, http_auth=('elastic', 'gm_test'), timeout=3600)


def getEsConn():
    return Elasticsearch(SERVICE_HOSTS, http_auth=('elastic', 'gengmei!@#'), timeout=3600)


def getClickSql(start, end):
    sql = """
    SELECT DISTINCT t1.partition_date, t1.cl_id device_id, t1.card_id,t1.time_stamp,t1.page_stay,t1.cl_type as os,t1.city_id as user_city_id
      FROM
        (
            select partition_date,city_id,cl_id,business_id as card_id,time_stamp,page_stay,cl_type
            from online.bl_hdfs_maidian_updates
            where action = 'page_view'
            AND partition_date>='{startDay}' and partition_date<='{endDay}'
            AND page_name='welfare_detail'
            -- AND page_stay>=1
            AND cl_id is not null
            AND cl_id != ''
            AND business_id is not null
            AND business_id != ''
            group by partition_date,city_id,cl_id,business_id,time_stamp,page_stay,cl_type
        ) AS t1
        join
        (	-- channel filter: new/existing users
            SELECT distinct device_id
            FROM online.ml_device_day_active_status
            where partition_date>='{startDay}' and partition_date<='{endDay}'
            AND active_type in ('1','2','4')
            and first_channel_source_type not in ('yqxiu1','yqxiu2','yqxiu3','yqxiu4','yqxiu5','mxyc1','mxyc2','mxyc3'
            ,'wanpu','jinshan','jx','maimai','zhuoyi','huatian','suopingjingling','mocha','mizhe','meika','lamabang'
            ,'js-az1','js-az2','js-az3','js-az4','js-az5','jfq-az1','jfq-az2','jfq-az3','jfq-az4','jfq-az5','toufang1'
            ,'toufang2','toufang3','toufang4','toufang5','toufang6','TF-toufang1','TF-toufang2','TF-toufang3','TF-toufang4'
            ,'TF-toufang5','tf-toufang1','tf-toufang2','tf-toufang3','tf-toufang4','tf-toufang5','benzhan','promotion_aso100'
            ,'promotion_qianka','promotion_xiaoyu','promotion_dianru','promotion_malioaso','promotion_malioaso-shequ'
            ,'promotion_shike','promotion_julang_jl03','promotion_zuimei','','unknown')
            AND first_channel_source_type not like 'promotion\_jf\_%'
        ) t2
        on t1.cl_id = t2.device_id

        LEFT JOIN
        (	-- exclude blacklisted devices
            select distinct device_id
            from ML.ML_D_CT_DV_DEVICECLEAN_DIMEN_D
            where PARTITION_DAY =regexp_replace(DATE_SUB(current_date,1) ,'-','')
            AND is_abnormal_device = 'true'
        )t3 
        on t3.device_id=t2.device_id
        WHERE t3.device_id is null
         """.format(startDay=start, endDay=end)
    print(sql)
    return sql


def getExposureSql(start, end):
    sql = """
    SELECT DISTINCT t1.partition_date,t1.cl_id device_id,t1.card_id,t1.time_stamp, 0 as page_stay,cl_type as os,t1.city_id as user_city_id
    from
        (	-- new homepage card exposures
            SELECT partition_date,city_id,cl_type,cl_id,card_id,max(time_stamp) as time_stamp
            FROM online.ml_community_precise_exposure_detail
            where partition_date>='{startDay}' and partition_date<='{endDay}'
            and action in ('page_precise_exposure','home_choiceness_card_exposure')
            and cl_id IS NOT NULL
            and card_id IS NOT NULL
            and is_exposure='1'
            --and page_name='home'
            --and tab_name='精选'
            --and page_name in ('home','search_result_more')
            and ((page_name='home' and tab_name='精选') or (page_name='category' and tab_name = '商品'))
            and card_type in ('card','video')
            and card_content_type in ('service')
            and (get_json_object(exposure_card,'$.in_page_pos') is null or get_json_object(exposure_card,'$.in_page_pos') != 'seckill')
            group by partition_date,city_id,cl_type,cl_id,card_id,app_session_id

        ) t1
        join
        (	-- channel filter: new/existing users
            SELECT distinct device_id
            FROM online.ml_device_day_active_status
            where partition_date>='{startDay}' and partition_date<='{endDay}'
            AND active_type in ('1','2','4')
            and first_channel_source_type not in ('yqxiu1','yqxiu2','yqxiu3','yqxiu4','yqxiu5','mxyc1','mxyc2','mxyc3'
            ,'wanpu','jinshan','jx','maimai','zhuoyi','huatian','suopingjingling','mocha','mizhe','meika','lamabang'
            ,'js-az1','js-az2','js-az3','js-az4','js-az5','jfq-az1','jfq-az2','jfq-az3','jfq-az4','jfq-az5','toufang1'
            ,'toufang2','toufang3','toufang4','toufang5','toufang6','TF-toufang1','TF-toufang2','TF-toufang3','TF-toufang4'
            ,'TF-toufang5','tf-toufang1','tf-toufang2','tf-toufang3','tf-toufang4','tf-toufang5','benzhan','promotion_aso100'
            ,'promotion_qianka','promotion_xiaoyu','promotion_dianru','promotion_malioaso','promotion_malioaso-shequ'
            ,'promotion_shike','promotion_julang_jl03','promotion_zuimei','','unknown')
            AND first_channel_source_type not like 'promotion\_jf\_%'
        ) t2
        on t1.cl_id = t2.device_id

        LEFT JOIN
        (	-- exclude blacklisted devices
            select distinct device_id
            from ML.ML_D_CT_DV_DEVICECLEAN_DIMEN_D
            where PARTITION_DAY =regexp_replace(DATE_SUB(current_date,1) ,'-','')
            AND is_abnormal_device = 'true'
        )t3 
        on t3.device_id=t2.device_id
        WHERE t3.device_id is null
    """.format(startDay=start, endDay=end)
    print(sql)
    return sql


def getClickSql2(start, end):
    sql = """
            SELECT DISTINCT t1.partition_date, t1.cl_id device_id, t1.business_id card_id,t1.time_stamp time_stamp,t1.page_stay as page_stay 
              FROM
              (select partition_date,cl_id,business_id,action,page_name,page_stay,time_stamp,page_stay
              from online.bl_hdfs_maidian_updates
              where action = 'page_view'
                AND partition_date BETWEEN '{}' AND '{}'
                AND page_name='welfare_detail'
                AND page_stay>=1
                AND cl_id is not null
                AND cl_id != ''
                AND business_id is not null
                AND business_id != ''
                AND business_id rlike '{}'
                ) AS t1
              JOIN
              (select partition_date,active_type,first_channel_source_type,device_id
              from online.ml_device_day_active_status
              where partition_date BETWEEN '{}' AND '{}'
                AND active_type IN ('1', '2', '4')
                AND first_channel_source_type not IN ('yqxiu1','yqxiu2','yqxiu3','yqxiu4','yqxiu5','mxyc1','mxyc2','mxyc3'
                      ,'wanpu','jinshan','jx','maimai','zhuoyi','huatian','suopingjingling','mocha','mizhe','meika','lamabang'
                      ,'js-az1','js-az2','js-az3','js-az4','js-az5','jfq-az1','jfq-az2','jfq-az3','jfq-az4','jfq-az5','toufang1'
                      ,'toufang2','toufang3','toufang4','toufang5','toufang6','TF-toufang1','TF-toufang2','TF-toufang3','TF-toufang4'
                      ,'TF-toufang5','tf-toufang1','tf-toufang2','tf-toufang3','tf-toufang4','tf-toufang5','benzhan','promotion_aso100'
                      ,'promotion_qianka','promotion_xiaoyu','promotion_dianru','promotion_malioaso','promotion_malioaso-shequ'
                      ,'promotion_shike','promotion_julang_jl03','promotion_zuimei')
                AND first_channel_source_type not LIKE 'promotion\\_jf\\_%') as t2
              ON t1.cl_id = t2.device_id
              AND t1.partition_date = t2.partition_date
            LEFT JOIN
            (
                select distinct device_id
                from ML.ML_D_CT_DV_DEVICECLEAN_DIMEN_D
                where PARTITION_DAY = regexp_replace(DATE_SUB(current_date,1) ,'-','')
                AND is_abnormal_device = 'true'
            )dev
            on t1.cl_id=dev.device_id
            WHERE  dev.device_id is null 
         """.format(start, end, ACTION_REG, start, end)
    print(sql)
    return sql


def getExposureSql2(start, end):
    sql = """
        SELECT DISTINCT t1.partition_date,t1.cl_id device_id,t1.card_id,t1.time_stamp, 0 as page_stay  
        FROM
          (SELECT partition_date,cl_id,card_id,time_stamp
           FROM online.ml_community_precise_exposure_detail
           WHERE cl_id IS NOT NULL
             AND card_id IS NOT NULL
             AND card_id rlike '{}'
             AND action='page_precise_exposure'
             AND card_content_type = '{}'
             AND is_exposure = 1 ) AS t1
        LEFT JOIN online.ml_device_day_active_status AS t2 ON t1.cl_id = t2.device_id
        AND t1.partition_date = t2.partition_date
        LEFT JOIN
          ( SELECT DISTINCT device_id
           FROM ML.ML_D_CT_DV_DEVICECLEAN_DIMEN_D
           WHERE PARTITION_DAY = regexp_replace(DATE_SUB(CURRENT_DATE,1),'-','')
             AND is_abnormal_device = 'true' )dev 
             ON t1.cl_id=dev.device_id
        WHERE dev.device_id IS NULL
          AND t2.partition_date BETWEEN '{}' AND '{}'
          AND t2.active_type IN ('1',
                                 '2',
                                 '4')
          AND t2.first_channel_source_type NOT IN ('yqxiu1',
                                                   'yqxiu2',
                                                   'yqxiu3',
                                                   'yqxiu4',
                                                   'yqxiu5',
                                                   'mxyc1',
                                                   'mxyc2',
                                                   'mxyc3' ,
                                                   'wanpu',
                                                   'jinshan',
                                                   'jx',
                                                   'maimai',
                                                   'zhuoyi',
                                                   'huatian',
                                                   'suopingjingling',
                                                   'mocha',
                                                   'mizhe',
                                                   'meika',
                                                   'lamabang' ,
                                                   'js-az1',
                                                   'js-az2',
                                                   'js-az3',
                                                   'js-az4',
                                                   'js-az5',
                                                   'jfq-az1',
                                                   'jfq-az2',
                                                   'jfq-az3',
                                                   'jfq-az4',
                                                   'jfq-az5',
                                                   'toufang1' ,
                                                   'toufang2',
                                                   'toufang3',
                                                   'toufang4',
                                                   'toufang5',
                                                   'toufang6',
                                                   'TF-toufang1',
                                                   'TF-toufang2',
                                                   'TF-toufang3',
                                                   'TF-toufang4' ,
                                                   'TF-toufang5',
                                                   'tf-toufang1',
                                                   'tf-toufang2',
                                                   'tf-toufang3',
                                                   'tf-toufang4',
                                                   'tf-toufang5',
                                                   'benzhan',
                                                   'promotion_aso100' ,
                                                   'promotion_qianka',
                                                   'promotion_xiaoyu',
                                                   'promotion_dianru',
                                                   'promotion_malioaso',
                                                   'promotion_malioaso-shequ' ,
                                                   'promotion_shike',
                                                   'promotion_julang_jl03',
                                                   'promotion_zuimei')
          AND t2.first_channel_source_type NOT LIKE 'promotion\\_jf\\_%'
        """.format(ACTION_REG, CONTENT_TYPE, start, end)

    print(sql)
    return sql


def connectDoris(spark, table):
    return spark.read \
        .format("jdbc") \
        .option("driver", "com.mysql.jdbc.Driver") \
        .option("url", "jdbc:mysql://172.16.30.136:3306/doris_prod") \
        .option("dbtable", table) \
        .option("user", "doris") \
        .option("password", "o5gbA27hXHHm") \
        .load()


def get_spark(appName):
    sparkConf = SparkConf()
    sparkConf.set("spark.sql.crossJoin.enabled", True)
    sparkConf.set("spark.debug.maxToStringFields", "100")
    sparkConf.set("spark.tispark.plan.allow_index_double_read", False)
    sparkConf.set("spark.tispark.plan.allow_index_read", True)
    sparkConf.set("spark.hive.mapred.supports.subdirectories", True)
    sparkConf.set("spark.hadoop.mapreduce.input.fileinputformat.input.dir.recursive", True)
    sparkConf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
    sparkConf.set("mapreduce.output.fileoutputformat.compress", False)
    sparkConf.set("mapreduce.map.output.compress", False)
    spark = (SparkSession
             .builder
             .config(conf=sparkConf)
             .appName(appName)
             .enableHiveSupport()
             .getOrCreate())
    return spark


def init_es_query():
    q = {
        "_source": {
            "includes": []
        },
        "query": {
            "bool": {
                "must": [{"term": {"is_online": True}}],
                "must_not": [],
                "should": []
            }
        }
    }
    return q


def parseSource(_source):
    id = str(_source.setdefault("id", -1))
    discount = _source.setdefault("discount", 0)
    case_count = _source.setdefault("case_count", 0)
    sales_count = _source.setdefault("sales_count", 0)
    service_type = str(_source.setdefault("service_type", -1))
    second_demands = ','.join(_source.setdefault("second_demands", ["-1"]))
    second_solutions = ','.join(_source.setdefault("second_solutions", ["-1"]))
    second_positions = ','.join(_source.setdefault("second_positions", ["-1"]))
    tags_v3 = ','.join(_source.setdefault("tags_v3", ["-1"]))

    # sku
    sku_list = _source.setdefault("sku_list", [])
    sku_tags_list = []
    sku_show_tags_list = []
    sku_price_list = []
    for sku in sku_list:
        sku_tags_list += sku.setdefault("sku_tags", [])
        # sku_tags_list += sku.setdefault("sku_tags_id",[])
        sku_show_tags_list.append(sku.setdefault("show_project_type_name", ""))
        price = sku.setdefault("price", 0.0)
        if price > 0:
            sku_price_list.append(price)

    # sku_tags = ",".join([str(i) for i in sku_tags_list]) if len(sku_tags_list) > 0 else "-1"
    # sku_show_tags = ",".join(sku_show_tags_list) if len(sku_show_tags_list) > 0 else "-1"
    sku_price = min(sku_price_list) if len(sku_price_list) > 0 else 0.0

    # merchant_id
    merchant_id = str(_source.setdefault("merchant_id", "-1"))
    # doctor_type id famous_doctor
    doctor = _source.setdefault("doctor", {})
    doctor_type = str(doctor.setdefault("doctor_type", "-1"))
    doctor_id = str(doctor.setdefault("id", "-1"))
    doctor_famous = str(int(doctor.setdefault("famous_doctor", False)))

    # hospital id city_tag_id hospital_type is_high_quality
    hospital = doctor.setdefault("hospital", {})
    hospital_id = str(hospital.setdefault("id", "-1"))
    hospital_city_tag_id = str(hospital.setdefault("city_tag_id", -1))
    hospital_type = str(hospital.setdefault("hospital_type", "-1"))
    hospital_is_high_quality = str(int(hospital.setdefault("is_high_quality", False)))

    data = [id,
            discount,
            case_count,
            sales_count,
            service_type,
            merchant_id,
            doctor_type,
            doctor_id,
            doctor_famous,
            hospital_id,
            hospital_city_tag_id,
            hospital_type,
            hospital_is_high_quality,
            second_demands,
            second_solutions,
            second_positions,
            tags_v3,
            # sku_show_tags,
            sku_price
            ]

    return data


# Fetch item features from Elasticsearch
def get_service_feature_df():
    es_columns = ["id", "discount", "sales_count", "doctor", "case_count", "service_type", "merchant_id",
                  "second_demands", "second_solutions", "second_positions", "sku_list", "tags_v3"]
    query = init_es_query()
    query["_source"]["includes"] = es_columns
    print(json.dumps(query), flush=True)

    es_cli = getEsConn()
    scan_re = scan(client=es_cli, index=ES_INDEX, query=query, scroll='3m')
    datas = []
    for res in scan_re:
        _source = res['_source']
        data = parseSource(_source)
        datas.append(data)
    print("item size:", len(datas))

    itemColumns = ['id', 'discount', 'case_count', 'sales_count', 'service_type', 'merchant_id',
                   'doctor_type', 'doctor_id', 'doctor_famous', 'hospital_id', 'hospital_city_tag_id', 'hospital_type',
                   'hospital_is_high_quality', 'second_demands', 'second_solutions', 'second_positions',
                   'tags_v3', 'sku_price']
    # 'sku_tags','sku_show_tags','sku_price']
    df = pd.DataFrame(datas, columns=itemColumns)
    return df


def addDays(n, format="%Y%m%d"):
    return (date.today() + timedelta(days=n)).strftime(format)
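
# e.g. addDays(0) returns today's date as "YYYYMMDD"; addDays(-7) returns the date one week ago.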


if __name__ == '__main__':

    start = time.time()
    # CLI arguments
    trainDays = int(sys.argv[1])
    print('trainDays:{}'.format(trainDays), flush=True)

    endDay = addDays(0)
    startDay = addDays(-int(trainDays))

    print("train_data start:{} end:{}".format(startDay, endDay))

    spark = get_spark("service_feature_csv_export")
    spark.sparkContext.setLogLevel("ERROR")

    itemDF = get_service_feature_df()
    print(itemDF.columns)
    print(itemDF.head(10))

    # Behavior data (clicks and exposures)
    clickSql = getClickSql(startDay, endDay)
    expSql = getExposureSql(startDay, endDay)

    clickDF = spark.sql(clickSql)
    expDF = spark.sql(expSql)
    # ratingDF = samplesNegAndUnion(clickDF,expDF)
    ratingDF = clickDF.union(expDF)
    ratingDF = ratingDF.withColumnRenamed("time_stamp", "timestamp") \
        .withColumnRenamed("device_id", "userid") \
        .withColumnRenamed("card_id", "item_id") \
        .withColumnRenamed("page_stay", "rating") \
        .withColumnRenamed("os", "user_os") \
        .withColumn("user_city_id", F.when(F.col("user_city_id").isNull(), "-1").otherwise(F.col("user_city_id"))) \
        .withColumn("timestamp", F.col("timestamp").cast("long"))

    print(ratingDF.columns)
    print(ratingDF.show(10, truncate=False))

    print("Adding labels...")
    ratingSamplesWithLabel = addSampleLabel(ratingDF)
    df = ratingSamplesWithLabel.toPandas()
    df = pd.DataFrame(df)

    posCount = df.loc[df["label"] == 1]["label"].count()
    negCount = df.loc[df["label"] == 0]["label"].count()
    print("pos size:" + str(posCount), "neg size:" + str(negCount))

    itemDF = get_service_feature_df()
    print(itemDF.columns)
    print(itemDF.head(10))
    # itemDF.to_csv("/tmp/service_{}.csv".format(endDay))
    # df.to_csv("/tmp/service_train_{}.csv".format(endDay))

    # Feature vocabularies
    dataVocab = {}
    multiVocab = {}

    print("Processing item features...")
    timestmp1 = int(round(time.time()))
    itemDF = addItemFeatures(itemDF, dataVocab, multiVocab)
    timestmp2 = int(round(time.time()))
    print("Item feature processing took {}s".format(timestmp2 - timestmp1))
    print("multiVocab:")
    for k, v in multiVocab.items():
        print(k, len(v))

    print("dataVocab:")
    for k, v in dataVocab.items():
        print(k, len(v))

    itemDF_spark = spark.createDataFrame(itemDF)
    itemDF_spark.printSchema()
    itemDF_spark.show(10, truncate=False)

    # Item statistical features
    itemStaticDF = addItemStaticFeatures(ratingSamplesWithLabel, itemDF_spark, dataVocab)

    # Statistical feature processing (currently unused)
    # ratingSamplesWithLabel = addStaticsFeatures(ratingSamplesWithLabel,dataVocab)

    samples = ratingSamplesWithLabel.join(itemStaticDF, on=['item_id'], how='inner')

    print("Processing user features...")
    samplesWithUserFeatures = addUserFeatures(samples, dataVocab, multiVocab)
    timestmp3 = int(round(time.time()))
    print("User feature processing took {}s".format(timestmp3 - timestmp2))
    #
    # user columns
    user_columns = [c for c in samplesWithUserFeatures.columns if c.startswith("user")]
    print("collect feature for user:{}".format(str(user_columns)))
    # item columns
    item_columns = [c for c in itemStaticDF.columns if c.startswith("item")]
    print("collect feature for item:{}".format(str(item_columns)))
    # model columns
    print("model columns to redis...")
    model_columns = user_columns + item_columns
    featureColumnsToRedis(model_columns)

    print("Saving feature vocabulary...")
    print("dataVocab:", str(dataVocab.keys()))
    vocab_path = "../vocab/{}_vocab.json".format(VERSION)
    dataVocabStr = json.dumps(dataVocab, ensure_ascii=False)
    with open(configUtils.VOCAB_PATH, mode='w', encoding='utf-8') as vocab_file:
        vocab_file.write(dataVocabStr)

    # Write item features to redis
    itemFeaturesToRedis(itemStaticDF, FEATURE_ITEM_KEY)
    timestmp6 = int(round(time.time()))
    print("item feature to redis took {}s".format(timestmp6 - timestmp3))

    """Write feature data to redis ======================================"""
    # Write user features to redis
    userFeaturesToRedis(samplesWithUserFeatures, user_columns, "user", FEATURE_USER_KEY)
    timestmp5 = int(round(time.time()))
    print("user feature to redis took {}s".format(timestmp5 - timestmp6))

    """Save training data ======================================"""
    timestmp3 = int(round(time.time()))
    train_columns = model_columns + ["label", "timestamp", "rating"]
    trainSamples = samplesWithUserFeatures.select(*train_columns)
    train_df = trainSamples.toPandas()
    train_df = pd.DataFrame(train_df)
    train_df.to_csv(DATA_PATH_TRAIN, sep="|")
    timestmp4 = int(round(time.time()))
    print("Training data written successfully, took {}s".format(timestmp4 - timestmp3))

    print("Total elapsed time: {} min".format((timestmp4 - start) / 60))

    spark.stop()