# import datetime
# import time
# import pymysql
# from pyspark.sql import SparkSession
# from pyspark import SparkConf,SparkContext
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from pyspark.context import SparkContext
from pyspark.conf import SparkConf
# from pyspark.streaming import StreamingContext
from pyspark.sql import SQLContext
# from pyspark.streaming.kafka import KafkaUtils
# import argparse
# import time
# from datetime import datetime


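# Earlier experiment, kept for reference: fetch_data sketches JDBC reads against
# the click/exposure tables; the uncommented query only selects a few device_ids
# from data_feed_click as a connectivity check.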
# def fetch_data(start_date, end_date):
#     # sc = SparkSession.builder.appName("Python Spark SQL basic example") \
#     #     .config('spark.some.config,option0', 'some-value') \
#     #     .getOrCreate()
#     sc = SparkContext(conf=SparkConf().setAppName("mnist_streaming"))
#     ctx = SQLContext(sc)
#     # jdbcDf = ctx.read.format("jdbc").options(url="jdbc:mysql://192.168.15.12:4000",
#     #                                          driver="com.mysql.jdbc.Driver",
#     #                                          # dbtable="((select device_id,cid_id,time,device_type,city_id,1 as clicked from jerry_test.data_feed_click where cid_id in (select id from eagle.src_mimas_prod_api_diary where doctor_id is not null and content_level >3.5)  and  cid_type = 'diary' and stat_date = '2018-08-12') union (select device_id,cid_id,time,device_type,city_id,0 as clicked from jerry_test.data_feed_exposure where cid_id in (select id from eagle.src_mimas_prod_api_diary where doctor_id is not null and content_level >3.5) and  cid_type = 'diary' and stat_date = '2018-08-12')) tmp",user="root",
#     #                                          dbtable="(select id as diary_id,doctor_id from eagle.src_mimas_prod_api_diary where doctor_id is not null and content_level >3.5 and datediff(current_date,created_time)<90) tmp",
#     #                                          user="root",
#     #                                          password="").load()
#     # df = ctx.read.format("jdbc").options(url="jdbc:mysql://rdsmaqevmuzj6jy.mysql.rds.aliyuncs.com:3306/doris_test",
#     #                                      driver="com.mysql.jdbc.Driver",
#     #                                      dbtable="device diary_queue",
#     #                                      user="work", password="workwork").load()
#     # df = ctx.read.format("jdbc").options(url="jdbc:mysql://rm-m5e842126ng59jrv6.mysql.rds.aliyuncs.com:3306/doris_prod",
#     #                                      driver="com.mysql.jdbc.Driver",
#     #                                      dbtable="device diary_queue",
#     #                                      user="doris", password="o5gbA27hXHHm").load()
#
#     jdbcDf = ctx.read.format("jdbc").options(url="jdbc:mysql://192.168.15.12:4000",
#                                              driver="com.mysql.jdbc.Driver",
#                                              dbtable = "(select device_id from data_feed_click limit 8) tmp",
#                                              user="root",password = "3SYz54LS9#^9sBvC").load()
#     jdbcDf.show(6)
#
#     # url = "jdbc:mysql://10.66.157.22:4000/jerry_prod"
#     # table = "data_feed_click"
#     # properties = {"user": "root", "password": "3SYz54LS9#^9sBvC"}
#     # df = sqlContext.read.jdbc(url, table, properties)


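# Earlier experiment, kept for reference: hello() trains a TensorFlow GBDT
# classifier on MNIST and was meant to run on executors via the commented-out
# TFCluster block at the bottom of this file.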
# def hello(args):
#     import tensorflow as tf
#     from tensorflow.contrib.boosted_trees.estimator_batch.estimator import GradientBoostedDecisionTreeClassifier
#     from tensorflow.contrib.boosted_trees.proto import learner_pb2 as gbdt_learner
#
#     # Ignore all GPUs (current TF GBDT does not support GPU).
#     import os
#     os.environ["CUDA_VISIBLE_DEVICES"] = ""
#
#     # Import MNIST data
#     # Set verbosity to display errors only (Remove this line for showing warnings)
#     tf.logging.set_verbosity(tf.logging.ERROR)
#     from tensorflow.examples.tutorials.mnist import input_data
#     mnist = input_data.read_data_sets("/tmp/data/", one_hot=False,
#                                       source_url='http://yann.lecun.com/exdb/mnist/')
#
#     # Parameters
#     batch_size = 10000  # The number of samples per batch
#     num_classes = 10  # The 10 digits
#     num_features = 784  # Each image is 28x28 pixels
#     max_steps = 10000
#
#     # GBDT Parameters
#     learning_rate = 0.1
#     l1_regul = 0.
#     l2_regul = 1.
#     examples_per_layer = 1000
#     num_trees = 10
#     max_depth = 16
#
#     # Fill GBDT parameters into the config proto
#     learner_config = gbdt_learner.LearnerConfig()
#     learner_config.learning_rate_tuner.fixed.learning_rate = learning_rate
#     learner_config.regularization.l1 = l1_regul
#     learner_config.regularization.l2 = l2_regul / examples_per_layer
#     learner_config.constraints.max_tree_depth = max_depth
#     growing_mode = gbdt_learner.LearnerConfig.LAYER_BY_LAYER
#     learner_config.growing_mode = growing_mode
#     run_config = tf.contrib.learn.RunConfig(save_checkpoints_secs=300)
#     learner_config.multi_class_strategy = (
#         gbdt_learner.LearnerConfig.DIAGONAL_HESSIAN)
#
#     # Create a TensorFlow GBDT Estimator
#     gbdt_model = GradientBoostedDecisionTreeClassifier(
#         model_dir=None,  # No save directory specified
#         learner_config=learner_config,
#         n_classes=num_classes,
#         examples_per_layer=examples_per_layer,
#         num_trees=num_trees,
#         center_bias=False,
#         config=run_config)
#
#     # Display TF info logs
#     tf.logging.set_verbosity(tf.logging.INFO)
#
#     # Define the input function for training
#     input_fn = tf.estimator.inputs.numpy_input_fn(
#         x={'images': mnist.train.images}, y=mnist.train.labels,
#         batch_size=batch_size, num_epochs=None, shuffle=True)
#     # Train the Model
#     gbdt_model.fit(input_fn=input_fn, max_steps=max_steps)
#
#     # Evaluate the Model
#     # Define the input function for evaluating
#     input_fn = tf.estimator.inputs.numpy_input_fn(
#         x={'images': mnist.test.images}, y=mnist.test.labels,
#         batch_size=batch_size, shuffle=False)
#     # Use the Estimator 'evaluate' method
#     e = gbdt_model.evaluate(input_fn=input_fn)
#
#     print("Testing Accuracy:", e['accuracy'])


if __name__ == "__main__":
    from pyspark.sql import SQLContext
    from pyspark.context import SparkContext
    from pyspark.conf import SparkConf

    sc = SparkContext(conf=SparkConf().setAppName("mnist_streaming")).getOrCreate()
    ctx = SQLContext(sc)
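    # Pull the similarity matrix from jerry_test over JDBC. The Spark JDBC
    # source accepts either a table name or a parenthesised subquery (aliased
    # "tmp" here) as dbtable.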
    jdbcDf = ctx.read.format("jdbc").options(url="jdbc:mysql://192.168.15.12:4000/jerry_test",
                                             driver="com.mysql.jdbc.Driver",
                                             dbtable="(select * from nd_cid_similarity_matrix) tmp",
                                             user="root",
                                             password="").load()

    # printSchema() prints directly and returns None, so it is not wrapped in print().
    jdbcDf.printSchema()

    # collect() pulls the whole result set to the driver; kept here for debugging.
    print(jdbcDf.collect())
    jdbcDf.show(6)
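    # A minimal sketch of the same read through DataFrameReader.jdbc, mirroring
    # the commented-out snippet in fetch_data above; the URL and credentials are
    # the ones used in this script, not verified here.
    # url = "jdbc:mysql://192.168.15.12:4000/jerry_test"
    # properties = {"user": "root", "password": "", "driver": "com.mysql.jdbc.Driver"}
    # df = ctx.read.jdbc(url=url, table="nd_cid_similarity_matrix", properties=properties)
    # df.show(6)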
    # fetch_data("2018-11-11","2018-11-12")
    # from pyspark.context import SparkContext
    # from pyspark.conf import SparkConf
    # from tensorflowonspark import TFCluster
    # import argparse
    #
    # sc = SparkContext(conf=SparkConf().setAppName("mnist_spark"))
    # executors = sc._conf.get("spark.executor.instances")
    # num_executors = int(executors) if executors is not None else 1
    #
    # parser = argparse.ArgumentParser()
    # parser.add_argument("--batch_size", help="number of records per batch", type=int, default=100)
    # parser.add_argument("--cluster_size", help="number of nodes in the cluster", type=int, default=num_executors)
    # parser.add_argument("--data_dir", help="path to MNIST data", default="MNIST-data")
    # parser.add_argument("--model", help="path to save model/checkpoint", default="mnist_model")
    # parser.add_argument("--num_ps", help="number of PS nodes in cluster", type=int, default=1)
    # parser.add_argument("--steps", help="maximum number of steps", type=int, default=1000)
    # parser.add_argument("--tensorboard", help="launch tensorboard process", action="store_true")
    #
    # args = parser.parse_args()
    # print("args:", args)
    #
    # cluster = TFCluster.run(sc, hello, args, args.cluster_size, args.num_ps, tensorboard=args.tensorboard, input_mode=TFCluster.InputMode.TENSORFLOW, log_dir=args.model, master_node='master')
    # cluster.shutdown()