# -*- coding:utf-8 -*-


from functools import wraps

from flask import abort
from flask import current_app
from flask import request
from sqlalchemy.exc import InvalidRequestError
from sqlalchemy.exc import OperationalError
from sqlalchemy.exc import PendingRollbackError
from sqlalchemy.exc import StatementError

from api.extensions import db
from api.lib.resp_format import CommonErrFormat


def kwargs_required(*required_args):
    """Build a decorator that insists on certain keyword arguments.

    The wrapped callable runs only when every name in *required_args* is
    present among the call's keyword arguments. The first absent name (in
    the order given) triggers ``abort(400, ...)`` with the
    ``CommonErrFormat.argument_required`` message for that name.
    """
    # Sentinel so that "nothing is missing" can never collide with a real name.
    not_found = object()

    def decorate(func):
        @wraps(func)
        def guarded(*args, **kwargs):
            # Lazily scan in declaration order and stop at the first gap.
            first_missing = next(
                (name for name in required_args if name not in kwargs),
                not_found,
            )
            if first_missing is not not_found:
                return abort(400, CommonErrFormat.argument_required.format(first_missing))
            return func(*args, **kwargs)

        return guarded

    return decorate


def args_required(*required_args, **value_required):
    """Build a decorator that requires named request parameters.

    Each name in *required_args* must appear in ``request.values``; otherwise
    the request is aborted with 400 (``CommonErrFormat.argument_required``).
    Unless the keyword option ``value_required=False`` is passed, each such
    parameter must also have a truthy value, else 400
    (``CommonErrFormat.argument_value_required``).
    """
    # Only the 'value_required' key of the keyword options is consulted.
    must_have_value = value_required.get('value_required', True)

    def decorate(func):
        @wraps(func)
        def checked(*args, **kwargs):
            for name in required_args:
                # request.values is only touched when there is something to check.
                if name not in request.values:
                    return abort(400, CommonErrFormat.argument_required.format(name))
                if must_have_value and not request.values.get(name):
                    return abort(400, CommonErrFormat.argument_value_required.format(name))
            return func(*args, **kwargs)

        return checked

    return decorate


def args_validate(model_cls, exclude_args=None):
    """Build a decorator that validates request parameters against a model.

    For every name in ``request.values`` that matches an attribute of
    *model_cls* exposing a ``type`` (e.g. a mapped column), and that is not
    listed in *exclude_args*:

    * if the type's Python type is ``str`` and it declares a ``length``, a
      value longer than that length aborts with 400
      (``CommonErrFormat.argument_str_length_limit``);
    * if the Python type is ``int`` or ``float``, a non-empty value that
      cannot be parsed as a finite number aborts with 400
      (``CommonErrFormat.argument_invalid``).

    Attributes whose type has no Python equivalent are not checked.
    """
    def decorate(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            for arg in request.values:
                if exclude_args and arg in exclude_args:
                    continue

                attr = getattr(model_cls, arg, None)
                column_type = getattr(attr, "type", None)
                if column_type is None:
                    continue

                try:
                    python_type = column_type.python_type
                except NotImplementedError:
                    # SQLAlchemy types without a Python equivalent cannot be
                    # validated here; previously this surfaced as a 500.
                    continue

                value = request.values[arg]
                if python_type is str:
                    max_length = getattr(column_type, "length", None)
                    if max_length and len(value or '') > max_length:
                        return abort(400, CommonErrFormat.argument_str_length_limit.format(arg, max_length))
                elif python_type in (int, float) and value:
                    try:
                        int(float(value))
                    except (TypeError, ValueError, OverflowError):
                        # OverflowError: "inf" / "1e999" parse as float('inf'),
                        # which int() cannot convert.
                        return abort(400, CommonErrFormat.argument_invalid.format(arg))

            return func(*args, **kwargs)

        return wrapper

    return decorate


def reconnect_db(func):
    """Retry *func* once after a lost/invalid DB connection.

    If *func* raises ``StatementError``, ``OperationalError`` or
    ``InvalidRequestError`` whose message contains one of the known
    connection-loss markers, the session is rolled back and *func* is called
    a second time (its result or exception is returned/propagated as-is).
    Any other exception is re-raised unchanged.
    """
    retry_markers = (
        'Lost connection',
        'reconnect until invalid transaction',
        'can be emitted within this transaction',
    )

    @wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except (StatementError, OperationalError, InvalidRequestError) as e:
            error_msg = str(e)
            if not any(marker in error_msg for marker in retry_markers):
                # Bare raise keeps the original traceback intact.
                raise
            current_app.logger.info('[reconnect_db] lost connect rollback then retry')
            # Clear the broken transaction so the retry starts fresh.
            db.session.rollback()
            return func(*args, **kwargs)

    return wrapper


def _flush_db():
    """Commit the current DB session, rolling back on SQLAlchemy errors.

    Best-effort: a ``StatementError``, ``OperationalError``,
    ``InvalidRequestError`` or ``PendingRollbackError`` during commit is
    swallowed after a rollback; nothing is returned or re-raised.
    """
    session = db.session
    try:
        session.commit()
    except (StatementError, OperationalError, InvalidRequestError, PendingRollbackError):
        session.rollback()


def flush_db(func):
    """Decorator: run ``_flush_db()`` (commit, or roll back on error) before *func*.

    NOTE(review): presumably used to discard stale transaction state before
    handling work; confirm against call sites.
    """
    @wraps(func)
    def flushed_call(*args, **kwargs):
        _flush_db()
        result = func(*args, **kwargs)
        return result

    return flushed_call


def run_flush_db():
    """Public entry point: commit the DB session, rolling back on SQLAlchemy errors."""
    return _flush_db()