# -*- coding: utf-8 -*-
import logging
import six

from django.core.exceptions import ImproperlyConfigured
from django.conf import settings
from django.db.models.sql import compiler
from django.db.models.sql.constants import MULTI

from sqlalchemy import exc
from sqlalchemy import event
from sqlalchemy.pool import manage
from sqlalchemy.pool import Pool


# Log through Django's database-backend logger so pool events show up
# alongside the regular SQL logging configuration.
logger = logging.getLogger('django.db.backends')

# When true, every checkout pings the connection before handing it out.
POOL_PESSIMISTIC_MODE = getattr(settings, "DJORM_POOL_PESSIMISTIC", False)

# Keyword arguments forwarded to the pool; "recycle" falls back to 3600
# when the project does not configure it.
POOL_SETTINGS = getattr(settings, 'DJORM_POOL_OPTIONS', {})
if "recycle" not in POOL_SETTINGS:
    POOL_SETTINGS["recycle"] = 3600


@event.listens_for(Pool, "checkout")
def _on_checkout(dbapi_connection, connection_record, connection_proxy):
    """Log a pool checkout and, in pessimistic mode, ping the connection.

    Registered as a listener for the ``Pool`` "checkout" event.  When
    ``POOL_PESSIMISTIC_MODE`` is enabled, a ``SELECT 1`` is issued on the
    raw DB-API connection before it is handed out.

    :param dbapi_connection: raw DB-API connection being checked out.
    :param connection_record: pool record for the connection (unused).
    :param connection_proxy: pool proxy for the connection (unused).
    :raises sqlalchemy.exc.DisconnectionError: if the ping fails at any
        step (creating the cursor, executing, or closing it).
    """
    logger.info("connection retrieved from pool")

    if not POOL_PESSIMISTIC_MODE:
        return

    try:
        # Creating the cursor is part of the liveness check: a dead
        # connection may already fail here, before any SQL is sent.
        cursor = dbapi_connection.cursor()
        try:
            cursor.execute("SELECT 1")
        finally:
            cursor.close()
    except Exception:
        # raise DisconnectionError - pool will try
        # connecting again up to three times before raising.
        # ``Exception`` (not a bare except) lets KeyboardInterrupt and
        # SystemExit propagate untouched.
        raise exc.DisconnectionError()


@event.listens_for(Pool, "checkin")
def _on_checkin(*args, **kwargs):
    """Log that a connection went back to the pool.

    Registered as a listener for the ``Pool`` "checkin" event; the event
    arguments are accepted but not used.
    """
    logger.info("connection returned to pool")


@event.listens_for(Pool, "connect")
def _on_connect(*args, **kwargs):
    """Log that a connection was created.

    Registered as a listener for the ``Pool`` "connect" event; the event
    arguments are accepted but not used.
    """
    logger.info("connection created")


def patch_mysql():
    """Replace Django's MySQL DB-API module with a pooled proxy.

    The original module is kept as ``mysql_base._Database`` and
    ``mysql_base.Database`` is replaced by a proxy around
    ``manage(mysql_base._Database, **POOL_SETTINGS)``.  The function does
    nothing when the MySQL backend cannot be imported, and it is
    idempotent: once ``_Database`` exists the module is left alone.
    """

    class hashabledict(dict):
        """A dict usable as a hashable connect argument."""

        def __hash__(self):
            # A frozenset of the items is independent of insertion order,
            # so equal dicts always hash equally, and it needs no sorting
            # (keys of mixed types cannot be ordered on Python 3).
            return hash(frozenset(self.items()))

    class hashablelist(list):
        """A list usable as a hashable connect argument."""

        def __hash__(self):
            return hash(tuple(self))

    def _freeze(mapping):
        """Copy *mapping* into a hashabledict, wrapping list values."""
        return hashabledict(
            (key, hashablelist(value) if isinstance(value, list) else value)
            for key, value in mapping.items()
        )

    class ManagerProxy(object):
        """Forward everything to *manager*, sanitising ``connect`` kwargs."""

        def __init__(self, manager):
            self.manager = manager

        def __getattr__(self, key):
            # Only reached for attributes not found on the proxy itself.
            return getattr(self.manager, key)

        def connect(self, *args, **kwargs):
            """Call ``manager.connect`` with hashable ``conv``/``ssl``.

            Dict-valued ``conv`` and ``ssl`` options are replaced by
            hashable copies; presumably the pool manager uses the connect
            arguments as a lookup key -- confirm against SQLAlchemy docs.
            """
            for option in ('conv', 'ssl'):
                value = kwargs.get(option)
                if isinstance(value, dict):
                    kwargs[option] = _freeze(value)
            return self.manager.connect(*args, **kwargs)

    try:
        from django.db.backends.mysql import base as mysql_base
    except (ImproperlyConfigured, ImportError):
        # MySQL backend (or its driver) is unavailable: nothing to patch.
        return

    if not hasattr(mysql_base, "_Database"):
        mysql_base._Database = mysql_base.Database
        mysql_base.Database = ManagerProxy(manage(mysql_base._Database, **POOL_SETTINGS))


def install_django_orm_patch():
    """Make ORM compilers close the connection after each query.

    ``SQLCompiler.execute_sql`` and ``SQLInsertCompiler.execute_sql`` are
    wrapped so that, once the original call returns and the connection is
    not inside an atomic block, ``self.connection.close()`` is called to
    hand the connection back to the pool.

    The wrappers pass all positional and keyword arguments through
    unchanged, so they stay compatible with whichever ``execute_sql``
    signature the installed Django version uses.  Calling this function
    more than once does not stack wrappers.
    """

    def _close_after(original):
        """Return a wrapper around *original* that closes when idle."""

        def patched_execute_sql(self, *args, **kwargs):
            result = original(self, *args, **kwargs)
            if not self.connection.in_atomic_block:
                self.connection.close()  # return connection to pool by db_pool_patch
            return result

        # Marker consulted below to keep installation idempotent.
        patched_execute_sql._djorm_pool_patched = True
        return patched_execute_sql

    for compiler_cls in (compiler.SQLCompiler, compiler.SQLInsertCompiler):
        current = compiler_cls.execute_sql
        if getattr(current, '_djorm_pool_patched', False):
            # Already wrapped (directly or through inheritance).
            continue
        compiler_cls.execute_sql = _close_after(current)


def install_patch():
    """Install both patches: the pooled MySQL module and the ORM hooks."""
    patch_mysql()
    install_django_orm_patch()