import base64

import tensorflow as tf
import requests

#encoding=utf8
import requests
import numpy as np
import tensorflow.compat.v1 as tf
import time
# `tf` is the compat.v1 API here; turn off TF 2.x behaviour at import time,
# before this script creates any TensorFlow objects.
tf.disable_v2_behavior()
# Print NumPy arrays in full (never summarised with "...") and show floats
# with three digits after the decimal point.
np.set_printoptions(threshold=np.inf, precision=3)


from tensorflow_serving.apis import predict_pb2
from tensorflow_serving.apis import prediction_service_pb2_grpc
import grpc

# Command-line flag holding the TensorFlow Serving gRPC endpoint.
# Override with --server=host:port (e.g. localhost:8502 for a local instance).
tf.app.flags.DEFINE_string(
    'server',
    'tensorserving-sk.paas-develop.env:8090',
    'PredictionService host:port',
)
FLAGS = tf.app.flags.FLAGS

# Layout of one '|'-separated sample row, as (column index, feature name).
# Column 1 is not sent to the model.

# Columns holding a single string value.
_CATEGORY_COLUMNS = (
    (0, 'ITEM_CATEGORY_card_id'),
    (2, 'USER_CATEGORY_device_id'),
    (3, 'USER_CATEGORY_os'),
    (4, 'USER_CATEGORY_user_city_id'),
    (19, 'ITEM_CATEGORY_service_type'),
    (20, 'ITEM_CATEGORY_merchant_id'),
    (21, 'ITEM_CATEGORY_doctor_type'),
    (22, 'ITEM_CATEGORY_doctor_id'),
    (23, 'ITEM_CATEGORY_doctor_famous'),
    (24, 'ITEM_CATEGORY_hospital_id'),
    (25, 'ITEM_CATEGORY_hospital_city_tag_id'),
    (26, 'ITEM_CATEGORY_hospital_type'),
    (27, 'ITEM_CATEGORY_hospital_is_high_quality'),
)

# Columns holding a comma-separated list of string values.
_MULTI_CATEGORY_COLUMNS = (
    (6, 'USER_MULTI_CATEGORY_second_solutions'),
    (7, 'USER_MULTI_CATEGORY_second_demands'),
    (8, 'USER_MULTI_CATEGORY_second_positions'),
    (9, 'USER_MULTI_CATEGORY_projects'),
    (28, 'ITEM_MULTI_CATEGORY_second_demands'),
    (29, 'ITEM_MULTI_CATEGORY_second_solutions'),
    (30, 'ITEM_MULTI_CATEGORY_second_positions'),
    (31, 'ITEM_MULTI_CATEGORY_projects'),
)

# Columns holding a single float value.
_NUMERIC_COLUMNS = (
    (10, 'ITEM_NUMERIC_click_count_sum'),
    (11, 'ITEM_NUMERIC_click_count_avg'),
    (12, 'ITEM_NUMERIC_click_count_stddev'),
    (13, 'ITEM_NUMERIC_exp_count_sum'),
    (14, 'ITEM_NUMERIC_exp_count_avg'),
    (15, 'ITEM_NUMERIC_exp_count_stddev'),
    (16, 'ITEM_NUMERIC_discount'),
    (17, 'ITEM_NUMERIC_case_count'),
    (18, 'ITEM_NUMERIC_sales_count'),
    (32, 'ITEM_NUMERIC_sku_price'),
)

# Column holding the integer label.
_LABEL_COLUMN = 5


def _bytes_feature(values):
    """Return a bytes_list Feature built from an iterable of str values."""
    encoded = [value.encode() for value in values]
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=encoded))


def _float_feature(value):
    """Return a float_list Feature holding the single number parsed from value."""
    return tf.train.Feature(float_list=tf.train.FloatList(value=[float(value)]))


def _serialize_example(line):
    """Convert one '|'-separated sample row into a serialized tf.train.Example."""
    splits = line.rstrip('\n').split('|')

    features = {}
    for column, name in _CATEGORY_COLUMNS:
        features[name] = _bytes_feature([splits[column]])
    for column, name in _MULTI_CATEGORY_COLUMNS:
        features[name] = _bytes_feature(splits[column].split(','))
    for column, name in _NUMERIC_COLUMNS:
        features[name] = _float_feature(splits[column])
    features['label'] = tf.train.Feature(
        int64_list=tf.train.Int64List(value=[int(splits[_LABEL_COLUMN])]))

    tf_example = tf.train.Example(features=tf.train.Features(feature=features))
    return tf_example.SerializeToString()


def _load_examples(sample_path, max_examples):
    """Read up to max_examples rows from sample_path as serialized Examples."""
    examples = []
    with open(sample_path, 'r', encoding='utf-8') as f:
        for line in f:
            if not line.strip():
                # Skip blank lines (e.g. a trailing newline at end of file)
                # instead of failing on a missing column.
                continue
            examples.append(_serialize_example(line))
            if max_examples is not None and len(examples) >= max_examples:
                break
    return examples


def prediction(sample_path='/Users/edz/software/Recommend/train_samples.csv',
               max_examples=1, timeout=10.0):
    """Send sample rows to the TensorFlow Serving model and print the latency.

    Rows are read from ``sample_path``, converted to serialized
    ``tf.train.Example`` protos and sent in one Predict RPC to the server
    given by ``FLAGS.server``. The elapsed time in seconds (tensor
    construction plus the RPC) is printed to stdout.

    Args:
        sample_path: '|'-separated sample file, one example per line, with at
            least 33 columns.
        max_examples: Maximum number of rows to send; ``None`` sends every
            row in the file. Defaults to 1.
        timeout: RPC deadline in seconds.

    Returns:
        The PredictResponse returned by the server.

    Raises:
        ValueError: If ``max_examples`` is less than 1, or the file contains
            no non-blank rows.
        OSError: If ``sample_path`` cannot be opened.
        grpc.RpcError: If the Predict call fails or times out.
    """
    if max_examples is not None and max_examples < 1:
        raise ValueError('max_examples must be >= 1 or None, got %r' % (max_examples,))

    examples = _load_examples(sample_path, max_examples)
    if not examples:
        # An empty list would be sent as an empty float tensor, not strings.
        raise ValueError('no examples found in %r' % (sample_path,))

    request = predict_pb2.PredictRequest()
    request.model_spec.name = 'service'  # model name served by TensorFlow Serving
    request.model_spec.signature_name = 'regression'  # signature of that model to call

    options = [
        ('grpc.max_send_message_length', 1000 * 1024 * 1024),
        ('grpc.max_receive_message_length', 1000 * 1024 * 1024),
    ]
    # The context manager closes the channel even when the RPC raises.
    with grpc.insecure_channel(FLAGS.server, options=options) as channel:
        stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)

        start = time.perf_counter()
        # 'inputs' is the name of the signature's input tensor.
        request.inputs['inputs'].CopyFrom(tf.make_tensor_proto(examples))
        result = stub.Predict(request, timeout)
        print(time.perf_counter() - start)

    return result

# Script entry point: run one prediction request when executed directly.
if __name__ == "__main__":
    prediction()