Commit 367d7a94 authored by 王志伟
parents 2a883860 1773c210
@@ -6,8 +6,7 @@ import pytispark.pytispark as pti
from pyspark.sql import SparkSession
import datetime
import pandas as pd
-import hdfs
-import avro
def app_list_func(x,l):
b = x.split(",")
@@ -112,11 +111,11 @@ def feature_engineer():
spark.createDataFrame(test).toDF("app_list","level2_ids","level3_ids","stat_date","ucity_id", "ccity_name", "device_type", "manufacturer",
"channel", "top", "time", "hospital_id","treatment_method", "price_min",
"price_max", "treatment_time","maintain_time", "recover_time","y","z")\
.write.format("avro").save(path="/recommend/va", mode="overwrite")
.write.format("avro").save(path=path+"va", mode="overwrite")
spark.createDataFrame(train).toDF("app_list","level2_ids","level3_ids","stat_date","ucity_id", "ccity_name", "device_type", "manufacturer",
"channel", "top", "time", "hospital_id","treatment_method", "price_min",
"price_max", "treatment_time","maintain_time", "recover_time","y","z")\
.write.format("avro").save(path="/recommend/tr", mode="overwrite")
.write.format("avro").save(path=path+"tr", mode="overwrite")
print("done")
rdd.unpersist()
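# A minimal read-back sketch (not part of the commit; assumes the spark-avro
# data source is on the classpath and path = "/strategy/esmm/" as set in
# __main__):
train_df = spark.read.format("avro").load(path + "tr")
va_df = spark.read.format("avro").load(path + "va")
print(train_df.count(), va_df.count())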
@@ -168,7 +167,7 @@ def get_predict(date,value_map,app_list_map,level2_map,level3_map):
print("native")
print(native_pre.count())
native_pre.write.format("avro").save(path="/recommend/pre_native", mode="overwrite")
native_pre.write.format("avro").save(path=path+"pre_native", mode="overwrite")
spark.createDataFrame(rdd.filter(lambda x: x[6] == 0)
.map(lambda x: (x[0], x[1], x[2],x[7],x[8],x[9],x[10],x[11],x[12],
@@ -177,13 +176,13 @@ def get_predict(date,value_map,app_list_map,level2_map,level3_map):
.toDF("app_list", "level2_ids", "level3_ids","y","z","ucity_id",
"ccity_name", "device_type","manufacturer", "channel", "time", "hospital_id",
"treatment_method", "price_min", "price_max", "treatment_time", "maintain_time",
"recover_time", "top","stat_date").write.format("avro").save(path="/recommend/native", mode="overwrite")
"recover_time", "top","stat_date").write.format("avro").save(path=path+"native", mode="overwrite")
nearby_pre = spark.createDataFrame(rdd.filter(lambda x: x[6] == 1).map(lambda x: (x[3], x[4], x[5]))) \
.toDF("city", "uid", "cid_id")
print("nearby")
print(nearby_pre.count())
nearby_pre.write.format("avro").save(path="/recommend/pre_nearby", mode="overwrite")
nearby_pre.write.format("avro").save(path=path+"pre_nearby", mode="overwrite")
spark.createDataFrame(rdd.filter(lambda x: x[6] == 1)
.map(lambda x: (x[0], x[1], x[2], x[7], x[8], x[9], x[10], x[11], x[12],
@@ -192,7 +191,7 @@ def get_predict(date,value_map,app_list_map,level2_map,level3_map):
.toDF("app_list", "level2_ids", "level3_ids","y","z", "ucity_id",
"ccity_name", "device_type", "manufacturer", "channel", "time", "hospital_id",
"treatment_method", "price_min", "price_max", "treatment_time", "maintain_time",
"recover_time","top","stat_date").write.format("avro").save(path="/recommend/nearby", mode="overwrite")
"recover_time","top","stat_date").write.format("avro").save(path=path+"nearby", mode="overwrite")
rdd.unpersist()
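# Reading guide for the hunks above (illustrative sketch using the same
# lambdas as the original code): x[6] is the queue-type flag in the cached rdd.
native_rows = rdd.filter(lambda x: x[6] == 0)  # written to pre_native / native
nearby_rows = rdd.filter(lambda x: x[6] == 1)  # written to pre_nearby / nearby
# The pre_* tables hold a slim id preview (city, uid, cid_id in the nearby
# case); native / nearby hold the full feature columns listed in the toDF calls.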
@@ -233,6 +232,7 @@ if __name__ == '__main__':
ti = pti.TiContext(spark)
ti.tidbMapDatabase("jerry_test")
spark.sparkContext.setLogLevel("WARN")
path = "/strategy/esmm/"
validate_date, value_map, app_list_map, leve2_map, leve3_map = feature_engineer()
get_predict(validate_date, value_map, app_list_map, leve2_map, leve3_map)
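# Note: feature_engineer() and get_predict() read `path` as a module-level
# global, so it must be assigned before either call, as done above. Every
# output of the job now shares the prefix (names taken from the save calls):
for name in ("tr", "va", "pre_native", "native", "pre_nearby", "nearby"):
    print(path + name)  # e.g. /strategy/esmm/tr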
......
# -*- coding: utf-8 -*-
from pyspark.sql import HiveContext
from pyspark.context import SparkContext
from pyspark.conf import SparkConf
# import pytispark.pytispark as pti
# from pyspark.sql import SQLContext
import pytispark.pytispark as pti
from pyspark.sql import SparkSession
# from py4j.java_gateway import java_import
# import pytispark.pytispark as pti
# import pandas as pd
#
# def con_sql(db,sql):
# cursor = db.cursor()
# try:
# cursor.execute(sql)
# result = cursor.fetchall()
# df = pd.DataFrame(list(result))
# except Exception as e:
# print("An exception occurred:", e)
# df = pd.DataFrame()
# finally:
# db.close()
# return df
import numpy as np
def test():
conf = SparkConf().setAppName("My App").set("spark.io.compression.codec", "lzf")
sc = SparkContext(conf=conf)
hive_context = HiveContext(sc)
spark = SparkSession.builder.enableHiveSupport().getOrCreate()
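# getOrCreate() picks up the SparkContext created above, so the lzf codec
# setting on `conf` also applies to this SparkSession.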
# ti = pti.TiContext(spark)
# ti.tidbMapDatabase("jerry_test")
@@ -46,34 +28,26 @@ def test():
spark.sql(sql).show(6)
# def esmm_pre():
# yesterday = (datetime.date.today() - datetime.timedelta(days=1)).strftime("%Y-%m-%d")
# print(yesterday)
#
# spark = SparkSession.builder.enableHiveSupport().getOrCreate()
# # gw = SparkContext._gateway
# #
# # # Import TiExtensions
# # java_import(gw.jvm, "org.apache.spark.sql.TiContext")
#
# # Inject TiExtensions, and get a TiContext
# # ti = gw.jvm.TiExtensions.getInstance(spark._jsparkSession).getOrCreateTiContext(spark._jsparkSession)
# ti = pti.TiContext(spark)
#
# ti.tidbMapDatabase("jerry_test")
#
# # sql("use tpch_test")
# spark.sql("select count(*) from esmm_pre_data").show(6)
#
# # conf = SparkConf().setAppName("esmm_pre").set("spark.io.compression.codec", "lzf")
#
# spark.sql("""
# select concat(tmp1.device_id,",",tmp1.city_id) as device_city, tmp1.merge_queue from (select device_id,if(city_id='world','worldwide',city_id) city_id,similarity_cid as merge_queue from nd_device_cid_similarity_matrix
# union select device_id,if(city_id='world','worldwide',city_id) city_id,native_queue as merge_queue from ffm_diary_queue
# union select device_id,city_id,search_queue as merge_queue from search_queue) as tmp1 where tmp1.device_id in (select distinct device_id from data_feed_click where stat_date='{}'
# """.format(yesterday)).show(6)
def some_function(x):
# Example computation using the numpy import above
return np.sin(x)**2 + 2
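# Quick sanity check (illustrative, not in the commit): sin(x)**2 lies in
# [0, 1], so some_function returns values in [2, 3].
print(some_function(0.0))        # 2.0
print(some_function(np.pi / 2))  # 3.0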
if __name__ == '__main__':
test()
sparkConf = SparkConf().set("spark.hive.mapred.supports.subdirectories", "true") \
.set("spark.hadoop.mapreduce.input.fileinputformat.input.dir.recursive", "true") \
.set("spark.tispark.plan.allow_index_double_read", "false") \
.set("spark.tispark.plan.allow_index_read", "true") \
.set("spark.sql.extensions", "org.apache.spark.sql.TiExtensions") \
.set("spark.tispark.pd.addresses", "172.16.40.158:2379").set("spark.io.compression.codec", "lzf") \
.set("spark.driver.maxResultSize", "8g")
spark = SparkSession.builder.config(conf=sparkConf).enableHiveSupport().getOrCreate()
# ti = pti.TiContext(spark)
# ti.tidbMapDatabase("jerry_test")
# spark.sparkContext.setLogLevel("WARN")
# sql = "select stat_date,cid_id,y,ccity_name from esmm_train_data limit 60"
# spark.sql(sql).show(6)
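# With spark.sql.extensions = org.apache.spark.sql.TiExtensions and
# spark.tispark.pd.addresses configured above, TiSpark registers itself when
# the session starts, which is presumably why the explicit TiContext calls
# remain commented out. A minimal query sketch under that assumption (table
# name taken from the commented-out sql above):
# spark.sql("use jerry_test")
# spark.sql("select count(*) from esmm_train_data").show()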