# -*- coding: utf-8 -*-
"""
Created on Tue May 15 13:59:43 2018

@author: hanye
"""

import json
import datetime
import random
import time
import re
from typing import Dict, List
import pymysql
import requests
from elasticsearch.exceptions import TransportError
from crawler_sys.framework.redis_interact import feed_url_into_redis
from crawler_sys.framework.redis_interact import redis_path
from crawler_sys.framework.es_ccr_index_defination import es_framework as es_site_crawler
from crawler_sys.framework.es_ccr_index_defination import index_url_register
from crawler_sys.framework.es_ccr_index_defination import doc_type_url_register
from crawler_sys.framework.es_ccr_index_defination import fields_url_register
from write_data_into_es.func_cal_doc_id import cal_doc_id
from crawler_sys.utils.write_into_file import write_str_into_file
from crawler.crawler_sys.proxy_pool.func_get_proxy_form_kuaidaili import get_proxy
from lxml import html
from lxml.html.clean import Cleaner
from crawler.gm_upload.gm_upload import upload, upload_file

index_site_crawler = 'crawler-data-raw'
doc_type_site_crawler = 'doc'


# MySQL connection wrapper: picks a connection config by name and exposes
# both the connection and a cursor.
class mysql_conn():
    def __init__(self, mysql_name):
        if mysql_name == "test":
            self.conn = pymysql.connect(host='bj-cdb-6slgqwlc.sql.tencentcdb.com', port=62120, user='work',
                                        passwd='Gengmei1',
                                        db='mimas_test', charset='utf8')
        elif mysql_name == "mimas":
            self.conn = pymysql.connect(host='172.16.30.138', port=3306, user='mimas',
                                        passwd='GJL3UJe1Ck9ggL6aKnZCq4cRvM',
                                        db='mimas_prod', charset='utf8mb4')
        else:
            raise ValueError("unknown mysql_name: %s" % mysql_name)
        self.cur = self.conn.cursor()
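

# Hedged usage sketch (not part of the original logic): how mysql_conn is meant
# to be used by the write path below. The SELECT statement is illustrative only.
def _example_mysql_conn_usage():
    mysql = mysql_conn("test")
    mysql.cur.execute("SELECT 1")
    print(mysql.cur.fetchone())
    mysql.conn.close()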


def form_data_Lst_for_url_register(data_Lst_ori):
    ts = int(datetime.datetime.now().timestamp() * 1e3)
    data_Lst_reg = []
    for line in data_Lst_ori:
        try:
            fields_ori = set(line.keys())
            fields_rem = set(fields_url_register)
            fields_to_remove = fields_ori.difference(fields_rem)
            for field in fields_to_remove:
                line.pop(field)
            line['timestamp'] = ts
            data_Lst_reg.append(line)
        except AttributeError:
            print('AttributeError at index %s' % data_Lst_ori.index(line))
    return data_Lst_reg
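

# Hedged usage sketch: shows the intended filtering behaviour. The input field
# names below are placeholders; which keys survive depends on the imported
# fields_url_register definition.
def _example_form_data_for_url_register():
    raw = [{"url": "http://example.com/1", "platform": "example_platform",
            "not_a_register_field": "dropped if not whitelisted"}]
    registered = form_data_Lst_for_url_register(raw)
    print(registered)  # each surviving record also gains a millisecond 'timestamp'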


def hot_words_output_result(result_Lst, output_index="short-video-hotwords"):
    bulk_all_body = ""
    for count, result in enumerate(result_Lst):
        doc_id = result["platform"] + "_" + result["title"]
        bulk_head = '{"index": {"_id":"%s"}}' % doc_id
        data_str = json.dumps(result, ensure_ascii=False)
        bulk_one_body = bulk_head + '\n' + data_str + '\n'
        bulk_all_body += bulk_one_body
        # flush every 500 docs to keep the bulk body small
        if count % 500 == 0 and count > 0:
            eror_dic = es_site_crawler.bulk(index=output_index,
                                            body=bulk_all_body, request_timeout=200)
            if eror_dic['errors'] is True:
                print(eror_dic['items'])
                print(bulk_all_body)
            bulk_all_body = ''
            print(count)

    if bulk_all_body != '':
        eror_dic = es_site_crawler.bulk(body=bulk_all_body,
                                        index=output_index,
                                        request_timeout=200)
        if eror_dic['errors'] is True:
            print(eror_dic)


WHITE_TAGS = {
    "basic": ["div", "p", "span", "img", "br", "video", "a"],  # tentative set for the mini-program and crawled data
    "all": [
        "div", "p", "span", "img", "br", "video", "audio", "a", "b", "strong", "i", "ul", "ol", "li", "em", "h1",
        "h2", "h3", "h4", "h5", "h6", "iframe",
    ]  # all whitelisted tags that may be rendered
}

# Image type codes used by api_tractate_images
img_type = {
    "OTHER": 1,   # other image types
    "GIF": 2,     # GIF animation
    "JPG": 3,     # JPG image
    "JPEG": 4,    # JPEG image
    "PNG": 5,     # PNG image
    "BMP": 6,     # BMP bitmap
    "WEBP": 7,    # WEBP image
    "TIFF": 8,    # TIFF image
}

def gm_convert_html_tags(rich_text, all_tags=False, remove_tags=None):
    """
    富文本内容重新清洗,剔除不需要的样式
    :param rich_text:  富文本
    :param all_tags:  是否需要匹配所有白名单中的标签
    :param remove_tags: 需要剔除的,白名单标签 []
    :return:
    """
    if not rich_text:
        return ""

    # rich_text = _get_rich_text(rich_text)

    # tag cleaning / completion parameters
    tags = WHITE_TAGS["all"] if all_tags else WHITE_TAGS["basic"]
    if remove_tags:
        tags = [tag for tag in tags if tag not in remove_tags]

    kw = {
        "remove_unknown_tags": False,
        "allow_tags": tags,
        "safe_attrs": ["src", ],
    }

    if "a" in tags:
        kw["safe_attrs"].append("href")

    elif all_tags:
        kw["safe_attrs"].extend(["class", "style"])

    if "iframe" in kw["allow_tags"]:
        kw["embedded"] = False

    clear = Cleaner(**kw)
    rich_text = clear.clean_html(rich_text)

    # add extra display styles
    element_obj = html.fromstring(rich_text)
    for element in element_obj.xpath(u"//img|//video"):

        if not all_tags:  # mini-program, ordinary users, crawled data
            element.attrib["width"] = "100%"  # force images/videos to 100% width

        if element.tag == "video" and all_tags:
            element.attrib["class"] = "js_richtext_video"

    # drop <a> tags whose href does not start with gengmei://
    for item in element_obj.xpath("//a[not(starts-with(@href, 'gengmei://'))]"):
        item.getparent().remove(item)

    # append an inline style to the remaining <a> tags
    for item in element_obj.xpath("//a"):
        item.attrib["style"] = 'color:#3FB5AF'  # link color

    rich_text = html.tostring(element_obj, encoding="unicode")

    return rich_text
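

# Hedged usage sketch: demonstrates the cleaning rules implemented above on a
# made-up HTML fragment (script removed by the cleaner, unsafe attributes
# stripped, images forced to 100% width, non-gengmei links removed, gengmei
# links restyled).
def _example_gm_convert_html_tags():
    sample = ('<div><script>alert(1)</script>'
              '<img src="http://example.com/a.jpg" onclick="x()">'
              '<a href="gengmei://page">in-app link</a>'
              '<a href="http://example.com">external link</a></div>')
    print(gm_convert_html_tags(sample, all_tags=False))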


def push_data_to_user(res_data: Dict) -> Dict:
    """
    处理数据为可以入库的格式
    :param res_data:
    :return:
    """
    qiniu_img_list = []
    if res_data["img_list"]:
        for img_url in res_data["img_list"]:
            try:
                img_wb = retry_get_url(img_url).content
                res = upload(img_wb)
                print(res)
                img_info = retry_get_url(res + "-imageinfo")
                img_info_json = img_info.json()
                qiniu_img_list.append((res + "-w", img_info_json))
            except Exception as e:
                print("download img error %s" % e)
                return {}

        # replace original image URLs with the uploaded (qiniu) ones
        if res_data["platform"] == "weibo":
            res_data["qiniu_img_list"] = qiniu_img_list
            if "http://t.cn/" in res_data["title"]:
                res_data["title"] = res_data["title"].split("http://t.cn/")[0]
                res_data["content"] = res_data["title"]
        elif res_data["platform"] == "douban":
            content = res_data.get("content")
            if content:
                for count, img_url in enumerate(res_data["img_list"]):
                    # print(qiniu_img_list[count][0])
                    content = content.replace(img_url, qiniu_img_list[count][0])
                res_data["qiniu_img_list"] = qiniu_img_list
                res_data["content"] = content
    if res_data["platform"] == "weibo":
        res_data["content"] = gm_convert_html_tags(res_data["title"], all_tags=True)
        res_data["title"] = ""
    elif res_data["platform"] == "douban":
        res_data["content"] = gm_convert_html_tags(res_data["content"], all_tags=True)
    return res_data
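

# Hedged usage sketch: the minimal record shape push_data_to_user expects.
# Calling it downloads every image in img_list and uploads it via gm_upload,
# so this example only builds the input; the field values are illustrative.
def _example_push_data_to_user_input():
    res_data = {
        "platform": "weibo",
        "title": "some crawled text http://t.cn/xxxx",
        "content": "",
        "img_list": ["http://example.com/a.jpg"],
    }
    # cleaned = push_data_to_user(res_data)  # performs network IO
    return res_data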


def write_data_into_mysql(res_data, mysql):
    now_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    # clean the data into a format that can be written to the database
    data = push_data_to_user(res_data)
    if not data.get("content"):
        return None
    if not data.get("qiniu_img_list"):
        return None
    tractate_id = 0
    try:
        # NOTE: user_id_list is not defined in this module and is expected to be
        # supplied by the importing module; values are interpolated via
        # str.format, so quotes inside content/title can break the statement.
        sql_query = """insert into api_tractate
            (user_id,content,is_online,status,platform,content_level,is_excellent,create_time,last_modified,user_del,low_quality,low_quality_deal,platform_id,pgc_type,title) 
            values ({user_id},'{content}',{is_online},{status},{platform},{content_level},{is_excellent},'{create_time}','{last_modified}',{user_del},{low_quality},{low_quality_deal},'{platform_id}',{pgc_type},'{title}');""".format(
            user_id=random.choice(user_id_list), content=data["content"], is_online=0, status=2, platform=15,
            content_level=data["level"],
            is_excellent=0, create_time=now_str,
            last_modified=now_str, user_del=0,
            low_quality=0, low_quality_deal=0, platform_id=data["doc_id"], pgc_type=0, title=data["title"])
        res = mysql.cur.execute(sql_query)
        tractate_id = int(mysql.conn.insert_id())
        if res:
            mysql.conn.commit()
    except Exception as e:
        print("commit error %s" % e)
        print(data)
        mysql.conn.rollback()

    if data.get("qiniu_img_list"):
        for img_info in data.get("qiniu_img_list"):
            if img_info[0] in data.get("content"):
                image_url_source = 2
            else:
                image_url_source = 3
            try:
                # fall back to 1 ("OTHER") when the format is missing or unknown
                image_type = img_type.get(img_info[1]["format"].upper(), 1)
            except Exception:
                image_type = 1
            try:
                width = img_info[1]["width"]
                height = img_info[1]["height"]
            except Exception:
                width = 0
                height = 0
            try:
                if image_type == 7:  # WEBP: also store image_webp
                    sql_query = """
                    insert into api_tractate_images (tractate_id,image_url,width,height,image_url_source,image_type,image_webp,create_time,update_time)
                    values ({tractate_id},'{image_url}',{width},{height},{image_url_source},{image_type},'{image_webp}','{create_time}','{update_time}')
                    """.format(tractate_id=tractate_id, image_url=img_info[0], width=width,
                               height=height, image_url_source=image_url_source,
                               image_type=image_type, image_webp=img_info[0],
                               create_time=now_str, update_time=now_str)
                else:
                    sql_query = """
                    insert into api_tractate_images (tractate_id,image_url,width,height,image_url_source,image_type,create_time,update_time)
                    values ({tractate_id},'{image_url}',{width},{height},{image_url_source},{image_type},'{create_time}','{update_time}')
                    """.format(tractate_id=tractate_id, image_url=img_info[0], width=width,
                               height=height, image_url_source=image_url_source, image_type=image_type,
                               create_time=now_str, update_time=now_str)
                res = mysql.cur.execute(sql_query)

                if res:
                    mysql.conn.commit()
            except Exception as e:
                print("commit error %s" % e)
                mysql.conn.rollback()
    if tractate_id:
        return tractate_id
    else:
        return None
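

# Hedged usage sketch: wiring push_data_to_user / write_data_into_mysql
# together. It assumes a reachable MySQL instance and a module-level
# user_id_list; the caller supplies the crawled record.
def _example_write_data_into_mysql(res_data):
    mysql = mysql_conn("test")
    tractate_id = write_data_into_mysql(res_data, mysql)
    if tractate_id:
        print("inserted api_tractate id:", tractate_id)
    mysql.conn.close()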


def output_result(result_Lst, platform,
                  output_to_file=False, filepath=None,
                  output_to_es_raw=False,
                  output_to_es_register=False,
                  push_to_redis=False,
                  output_to_test_mysql=False,
                  output_to_mimas_mysql=False,
                  es_index=index_site_crawler,
                  rds_path="on_line",
                  **kwargs):
    # write data into es crawler-raw index
    if output_to_es_raw:
        bulk_write_into_es(result_Lst, es_index)

    # write data into es crawler-url-register index
    if output_to_es_register:
        # data_Lst_reg = form_data_Lst_for_url_register(result_Lst)
        bulk_write_into_es(result_Lst,
                           index=es_index,
                           construct_id=True,
                           platform=platform,

                           )

    if output_to_test_mysql:
        pass
    # feed url into redis
    if push_to_redis:
        rds = redis_path(rds_path)
        feed_url_into_redis(
            result_Lst, expire=kwargs.get("expire"), rds=rds)

    # output into file according to passed in parameters
    if output_to_file is True and filepath is not None:
        output_fn = ('crawler_%s_on_%s_json'
                     % (platform, datetime.datetime.now().isoformat()[:10]))
        with open(filepath + '/' + output_fn, 'a', encoding='utf-8') as output_f:
            write_into_file(result_Lst, output_f)
    else:
        return result_Lst
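

# Hedged usage sketch: writing a small crawled batch to a local file instead of
# es/redis. The record and path are illustrative.
def _example_output_result_to_file():
    result_Lst = [{"url": "http://example.com/1", "platform": "example_platform",
                   "title": "demo"}]
    output_result(result_Lst, platform="example_platform",
                  output_to_file=True, filepath="/tmp")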


def retry_get_url(url, retrys=3, proxies=None, timeout=10, **kwargs):
    retry_c = 0
    while retry_c < retrys:
        try:
            if proxies:
                proxies_dic = get_proxy(proxies)
                if not proxies_dic:
                    get_resp = requests.get(url, timeout=timeout, **kwargs)
                else:
                    get_resp = requests.get(url, proxies=proxies_dic, timeout=timeout, **kwargs)
            else:
                get_resp = requests.get(url, timeout=timeout, **kwargs)
            return get_resp
        except Exception as e:
            retry_c += 1
            time.sleep(1)
            print(e)
    print('Failed to get page %s after %d retries, %s'
          % (url, retrys, datetime.datetime.now()))
    return None
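

# Hedged usage sketch: a plain GET with retries; passing a proxies value would
# draw a proxy from the kuaidaili pool via get_proxy. The URL is illustrative.
def _example_retry_get_url():
    resp = retry_get_url("http://example.com", retrys=2, timeout=5)
    if resp is not None:
        print(resp.status_code)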


def get_ill_encoded_str_posi(UnicodeEncodeError_msg):
    posi_nums = []
    get_err_posi = re.findall(r'\s+[0-9]+-[0-9]+:', UnicodeEncodeError_msg)
    if get_err_posi:
        posi = get_err_posi[0].replace(' ', '').replace(':', '').split('-')
        for pp in posi:
            try:
                posi_nums.append(int(pp))
            except ValueError:
                pass
    return posi_nums


def bulk_write_into_es(dict_Lst,
                       index,
                       construct_id=False,
                       platform=None):
    bulk_write_body = ''
    write_counter = 0

    def bulk_write_with_retry_UnicodeEncodeError(index, bulk_write_body,
                                                 retry_counter_for_UnicodeEncodeError):
        if bulk_write_body != '':
            try:
                bulk_write_resp = es_site_crawler.bulk(index=index,
                                                       body=bulk_write_body,
                                                       request_timeout=100)
                if bulk_write_resp['errors'] is True:
                    print(bulk_write_resp)
                bulk_write_body = ''
                print('Writing into es done')
            except UnicodeEncodeError as ue:
                print('Got UnicodeEncodeError, will remove ill formed string and retry.')
                retry_counter_for_UnicodeEncodeError += 1
                UnicodeEncodeError_msg = ue.__str__()
                ill_str_idxs = get_ill_encoded_str_posi(UnicodeEncodeError_msg)
                if len(ill_str_idxs) == 2:
                    ill_str = bulk_write_body[ill_str_idxs[0]: ill_str_idxs[1] + 1]
                    bulk_write_body = bulk_write_body.replace(ill_str, '')
                    bulk_write_with_retry_UnicodeEncodeError(index, bulk_write_body,
                                                             retry_counter_for_UnicodeEncodeError
                                                             )
            except TransportError:
                print("output to es register error")
                write_str_into_file(file_path='/home/',
                                    file_name='debug',
                                    var=bulk_write_body)
        return retry_counter_for_UnicodeEncodeError

    for line in dict_Lst:
        write_counter += 1
        if construct_id and platform is not None:
            doc_id = cal_doc_id(platform, url=line["url"], doc_id_type='all-time-url', data_dict=line)
            action_str = ('{ "index" : { "_index" : "%s", "_id" : "%s" } }'
                          % (index, doc_id))
        else:
            action_str = ('{ "index" : { "_index" : "%s" }'
                          % (index))
        data_str = json.dumps(line, ensure_ascii=False)
        line_body = action_str + '\n' + data_str + '\n'
        bulk_write_body += line_body
        if write_counter % 1000 == 0 or write_counter == len(dict_Lst):
            print('Writing into es %s %d/%d' % (index,
                                                write_counter,
                                                len(dict_Lst)))
            if bulk_write_body != '':
                retry_counter_for_UnicodeEncodeError = 0
                retry_counter_for_UnicodeEncodeError = bulk_write_with_retry_UnicodeEncodeError(index,
                                                                                                bulk_write_body,
                                                                                                retry_counter_for_UnicodeEncodeError)
                bulk_write_body = ''
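

# Hedged usage sketch: bulk-indexing a tiny batch. With construct_id=True the
# _id is derived from platform and url via cal_doc_id; the records below are
# illustrative.
def _example_bulk_write_into_es():
    docs = [{"url": "http://example.com/1", "platform": "example_platform",
             "title": "demo"}]
    bulk_write_into_es(docs, index=index_site_crawler,
                       construct_id=True, platform="example_platform")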


def write_into_file(dict_Lst, file_obj):
    for data_dict in dict_Lst:
        json_str = json.dumps(data_dict)
        file_obj.write(json_str)
        file_obj.write('\n')
        file_obj.flush()


def load_json_file_into_dict_Lst(filename, path):
    if path[-1] != '/':
        path += '/'
    data_Lst = []
    with open(path + filename, 'r', encoding='utf-8') as f:
        for line in f:
            line_d = json.loads(line)
            if 'data_provider' not in line_d:
                line_d['data_provider'] = 'CCR'
            if 'releaser_id' in line_d:
                try:
                    line_d['releaser_id'] = int(line_d['releaser_id'])
                except:
                    line_d.pop('releaser_id')
            data_Lst.append(line_d)
    return data_Lst
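

# Hedged usage sketch: reading a newline-delimited JSON dump (as produced by
# write_into_file) back into memory. The file name and path are illustrative.
def _example_load_json_file():
    data_Lst = load_json_file_into_dict_Lst("crawler_example_on_2018-05-15_json", "/tmp")
    print(len(data_Lst))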


def crawl_a_url_and_update_redis(url, platform, urlhash, rds, processID=-1):
    # intended pipeline: find the crawler, crawl the url, write the data to es
    # or to files, then update redis; only the redis update is implemented here
    ts = int(datetime.datetime.now().timestamp() * 1e3)
    redis_hmset_dict = {'push_time': ts, 'is_fetched': 1,
                        'url': url, 'platform': platform}
    rds.hmset(urlhash, redis_hmset_dict)


def crawl_batch_task(rds, url_Lst):
    for url_info in url_Lst:
        crawl_a_url_and_update_redis(url_info['url'],
                                     url_info['platform'],
                                     url_info['urlhash'],
                                     rds)


def scan_redis_to_crawl(rds):
    batch_size = 1000
    cur = 0
    task_batchs = []
    scan_counter = 0
    while True:
        scan_counter += 1
        if scan_counter % 5 == 0:
            print(scan_counter, 'cur:', cur, datetime.datetime.now())
        cur, hash_keys = rds.scan(cur)
        for urlhash in hash_keys:
            if len(urlhash) == 40:
                url_d = rds.hgetall(urlhash)
                url = url_d[b'url'].decode()
                platform = url_d[b'platform'].decode()
                is_fetched = int(url_d[b'is_fetched'].decode())
                if is_fetched == 0:
                    task_batchs.append({'url': url,
                                        'platform': platform,
                                        'urlhash': urlhash})
                    if len(task_batchs) == batch_size:
                        # multi-processing here
                        crawl_batch_task(rds, task_batchs)
                        task_batchs.clear()
        if cur == 0:
            break


def remove_fetched_url_from_redis(rds, remove_interval=10):
    time.sleep(remove_interval)
    cur = 0
    delete_counter = 0
    while True:
        cur, hash_keys = rds.scan(cur)
        for urlhash in hash_keys:
            if len(urlhash) == 40:
                url_d = rds.hgetall(urlhash)
                try:
                    is_fetched = int(url_d[b'is_fetched'].decode())
                    if is_fetched == 1:
                        rds.delete(urlhash)
                        delete_counter += 1
                except:
                    pass
        if cur == 0:
            break
    print('delete_counter', delete_counter)
    return delete_counter
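

# Hedged usage sketch of the redis round trip: scan and crawl the unfetched
# urls, then purge the fetched hashes. It assumes redis_path() returns a ready
# redis client, as in output_result above.
def _example_redis_crawl_cycle():
    rds = redis_path("on_line")
    scan_redis_to_crawl(rds)
    remove_fetched_url_from_redis(rds, remove_interval=10)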