import datetime

from pyspark.conf import SparkConf
from pyspark.context import SparkContext
from pyspark.sql import SQLContext, HiveContext

def get_data(day):
    """Build labeled click/exposure samples for the last `day` days (through yesterday)."""
    # Reuse an existing SparkContext if one is already running, otherwise create one.
    sc = SparkContext.getOrCreate(SparkConf().setAppName("multi_task"))
    sc.setLogLevel("WARN")
    ctx = SQLContext(sc)
    end_date = (datetime.date.today() - datetime.timedelta(days=1)).strftime("%Y-%m-%d")
    start_date = (datetime.date.today() - datetime.timedelta(days=day)).strftime("%Y-%m-%d")
    # Click events in the [start_date, end_date] window, pushed down as a JDBC subquery.
    dbtable = "(select device_id,cid_id,stat_date from data_feed_click " \
              "where stat_date >= '{}' and stat_date <= '{}')tmp".format(start_date, end_date)

    click = ctx.read.format("jdbc").options(url="jdbc:mysql://10.66.157.22:4000/jerry_prod",
                                             driver="com.mysql.jdbc.Driver",
                                             dbtable=dbtable,
                                             user="root",
                                             password="3SYz54LS9#^9sBvC").load()
    click.show(6)
    click = click.rdd.map(lambda x: (x[0], x[1], x[2]))
    # Collect the clicking devices so exposures can be restricted to the same users.
    device_id = tuple(click.map(lambda x: x[0]).collect())
    print(device_id[0:2])
    # Note: formatting a Python tuple into the IN (...) clause only produces valid SQL
    # when the tuple has at least two elements (a one-element tuple renders as "('x',)").
    dbtable = "(select device_id,cid_id,stat_date from data_feed_exposure " \
              "where stat_date >= '{}' and stat_date <= '{}' and device_id in {})tmp".format(start_date, end_date, device_id)
    exp = ctx.read.format("jdbc").options(url="jdbc:mysql://10.66.157.22:4000/jerry_prod",
                                            driver="com.mysql.jdbc.Driver",
                                            dbtable=dbtable,
                                            user="root",
                                            password="3SYz54LS9#^9sBvC").load()

    exp.show(6)
    # Negative samples: exposures that never led to a click and occurred at least 3 times
    # for the same (device_id, cid_id, stat_date), labeled 0.
    exp = exp.rdd.map(lambda x: (x[0], x[1], x[2])).subtract(click).map(lambda x: ((x[0], x[1], x[2]), 1)) \
        .reduceByKey(lambda x, y: x + y).filter(lambda x: x[1] >= 3).map(lambda x: (x[0][0], x[0][1], x[0][2], 0))
    # Positive samples: every click, labeled 1.
    click = click.map(lambda x: (x[0], x[1], x[2], 1))

    # Collected stat_date values; currently not used beyond this point.
    date = click.map(lambda x: x[2]).collect()

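# A minimal sketch (not part of the original pipeline) of how the labeled positives and
# negatives produced in get_data() could be combined into a single training RDD.
# The function name and the output path are illustrative assumptions.
def build_samples_sketch(click, exp, output_path="/tmp/labeled_samples"):
    # click: (device_id, cid_id, stat_date, 1); exp: (device_id, cid_id, stat_date, 0)
    samples = click.union(exp)
    # Persist as tab-separated text for downstream feature engineering.
    samples.map(lambda x: "\t".join(str(f) for f in x)).saveAsTextFile(output_path)
    return samples
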
def test():
    sc = SparkContext.getOrCreate(SparkConf().setAppName("multi_task"))
    sc.setLogLevel("WARN")
    ctx = SQLContext(sc)
    end_date = "2018-09-10"
    start_date = "2018-09-09"
    # Small fixed sample for a quick end-to-end check; the date bounds above are not
    # applied here (start_date is only used by the commented-out exposure query below).
    dbtable = "(select device_id,cid_id,stat_date from data_feed_click limit 80)tmp"

    click = ctx.read.format("jdbc").options(url="jdbc:mysql://192.168.15.12:4000/jerry_prod",
                                          driver="com.mysql.jdbc.Driver",
                                          dbtable=dbtable,
                                          user="root",
                                          password="").load()
    click.show(6)
    click = click.rdd.map(lambda x: (x[0], x[1], x[2]))

    # Collect cid_id and stat_date values to the driver and use the position of each
    # value's first occurrence (.index) as a crude categorical feature index.
    date = click.map(lambda x: x[2]).collect()
    cid = click.map(lambda x: x[1]).collect()
    # Emit libSVM-style rows: "<label> <cid_index>:1 <date_index>:1" with a constant label of 1.
    click = click.map(lambda x: str(1) + " " + str(cid.index(x[1])) + ":" + str(1) + " " + str(date.index(x[2])) + ":" + str(1))
    print(click.take(6))

    # device_id = tuple(click.map(lambda x: x[0]).collect())
    # print(device_id[0:2])
    # dbtable = "(select device_id,cid_id,stat_date from data_feed_exposure " \
    #           "where stat_date = '{}' and device_id in {})tmp".format(start_date,device_id)
    # exp = ctx.read.format("jdbc").options(url="jdbc:mysql://192.168.15.12:4000/jerry_prod",
    #                                       driver="com.mysql.jdbc.Driver",
    #                                       dbtable=dbtable,
    #                                       user="root",
    #                                       password="").load()
    # exp.show(6)
    # exp = exp.rdd.map(lambda x: (x[0], x[1], x[2])).subtract(click).map(lambda x: ((x[0], x[1], x[2]), 1)) \
    #     .reduceByKey(lambda x, y: x + y).filter(lambda x: x[1] >= 3).map(lambda x: (x[0][0], x[0][1], x[0][2], 0))
    # click = click.map(lambda x: (x[0], x[1], x[2], 1))

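# A hedged sketch of how the libSVM-style strings built in test() could be parsed into
# MLlib LabeledPoint records. The helper name and the num_features default are
# illustrative assumptions, not part of the original code. The encoding in test() starts
# both the cid and date indices at 0, so colliding feature indices are merged here.
def parse_libsvm_line_sketch(line, num_features=1000):
    from pyspark.mllib.linalg import SparseVector
    from pyspark.mllib.regression import LabeledPoint

    parts = line.split(" ")
    label = float(parts[0])
    # Accumulate index -> value, merging any duplicate feature indices.
    feats = {}
    for p in parts[1:]:
        i, v = p.split(":")
        feats[int(i)] = feats.get(int(i), 0.0) + float(v)
    idx = sorted(feats)
    return LabeledPoint(label, SparseVector(num_features, idx, [feats[i] for i in idx]))

# Usage sketch: points = click.map(parse_libsvm_line_sketch)
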
def hive():
    conf = SparkConf().setMaster("spark://10.30.181.88:7077").setAppName("My app")
    sc = SparkContext(conf=conf)
    sc.setLogLevel("WARN")
    sqlContext = HiveContext(sc)
    sql = "select partition_date from online.tl_hdfs_maidian_view limit 10"
    my_dataframe = sqlContext.sql(sql)
    my_dataframe.show(6)


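# A minimal Spark 2.x+ alternative to hive() above, using SparkSession with Hive support
# enabled instead of the older HiveContext. The master URL and query are reused from
# hive(); the function itself is an illustrative sketch, not part of the original job.
def hive_with_sparksession_sketch():
    from pyspark.sql import SparkSession
    spark = SparkSession.builder \
        .master("spark://10.30.181.88:7077") \
        .appName("My app") \
        .enableHiveSupport() \
        .getOrCreate()
    spark.sparkContext.setLogLevel("WARN")
    df = spark.sql("select partition_date from online.tl_hdfs_maidian_view limit 10")
    df.show(6)
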
if __name__ == '__main__':
    hive()