mirror of https://github.com/veops/cmdb.git
119 lines
3.5 KiB
Python
119 lines
3.5 KiB
Python
# -*- coding:utf-8 -*-
|
|
|
|
|
|
from functools import wraps
|
|
|
|
from flask import abort
|
|
from flask import current_app
|
|
from flask import request
|
|
from sqlalchemy.exc import InvalidRequestError
|
|
from sqlalchemy.exc import OperationalError
|
|
from sqlalchemy.exc import PendingRollbackError
|
|
from sqlalchemy.exc import StatementError
|
|
|
|
from api.extensions import db
|
|
from api.lib.resp_format import CommonErrFormat
|
|
|
|
|
|
def kwargs_required(*required_args):
    """Build a decorator that aborts with HTTP 400 when keyword arguments are missing.

    Every name in *required_args* must appear among the keyword arguments the
    decorated function is called with; positional arguments are not inspected.
    Names are checked in the order given, and the first missing one is reported
    via ``CommonErrFormat.argument_required``.

    :param required_args: names of keyword arguments that must be present.
    :return: a decorator that wraps the target function.
    """
    def decorator(view):
        @wraps(view)
        def checked(*args, **kwargs):
            for name in required_args:
                if name in kwargs:
                    continue
                return abort(400, CommonErrFormat.argument_required.format(name))
            return view(*args, **kwargs)

        return checked

    return decorator
|
|
|
|
|
|
def args_required(*required_args, **value_required):
    """Build a decorator that aborts with HTTP 400 unless request parameters are supplied.

    Each name in *required_args* must be a key of ``request.values``. Unless the
    decorator is created with a falsy ``value_required`` keyword (e.g.
    ``value_required=False``), the parameter's value must also be truthy
    (non-empty). Names are checked in order. The first failure aborts with
    ``CommonErrFormat.argument_required`` when the key is absent, or with
    ``CommonErrFormat.argument_value_required`` when the key is present but
    the value is empty.

    :param required_args: request parameter names that must be present.
    :param value_required: only the ``value_required`` key is read (default
        True); any other keyword arguments are ignored.
    :return: a decorator that wraps the target view.
    """
    enforce_value = value_required.get('value_required', True)

    def decorator(view):
        @wraps(view)
        def checked(*args, **kwargs):
            for name in required_args:
                params = request.values
                if name not in params:
                    return abort(400, CommonErrFormat.argument_required.format(name))
                if enforce_value and not params.get(name):
                    return abort(400, CommonErrFormat.argument_value_required.format(name))
            return view(*args, **kwargs)

        return checked

    return decorator
|
|
|
|
|
|
def args_validate(model_cls, exclude_args=None):
    """Build a decorator that validates request parameters against a model's column types.

    For every key in ``request.values`` that is not listed in *exclude_args* and
    names an attribute of *model_cls* exposing a ``type``, the submitted (first)
    value is checked:

    * when the type's ``python_type`` is ``str`` and it declares a ``length``,
      a value longer than that length aborts with HTTP 400
      (``CommonErrFormat.argument_str_length_limit``);
    * when the type's ``python_type`` is ``int`` or ``float``, a non-empty value
      that ``int(float(value))`` cannot convert (not numeric, NaN, or infinite)
      aborts with HTTP 400 (``CommonErrFormat.argument_invalid``).

    Keys that do not match such an attribute are ignored. So are column types
    whose ``python_type`` is not implemented.

    :param model_cls: class whose attributes describe the expected parameters.
    :param exclude_args: optional collection of parameter names to skip.
    :return: a decorator that wraps the target view.
    """
    def decorate(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            values = request.values
            for arg in values:
                # Cheap membership test first: excluded names need no introspection.
                if exclude_args and arg in exclude_args:
                    continue

                attr = getattr(model_cls, arg, None)
                column_type = getattr(attr, "type", None)
                if column_type is None:
                    continue

                try:
                    python_type = column_type.python_type
                except NotImplementedError:
                    # SQLAlchemy raises this for types with no known Python
                    # equivalent; there is nothing to validate against.
                    continue

                value = values[arg]
                if python_type is str:
                    max_length = getattr(column_type, "length", None)
                    if max_length and len(value or '') > max_length:
                        return abort(400, CommonErrFormat.argument_str_length_limit.format(arg, max_length))
                elif python_type in (int, float) and value:
                    try:
                        int(float(value))
                    except (TypeError, ValueError, OverflowError):
                        # OverflowError: float() accepts "inf"/"1e999", but
                        # int() of an infinite float overflows; this is a
                        # client error (400), not a server error.
                        return abort(400, CommonErrFormat.argument_invalid.format(arg))

            return func(*args, **kwargs)

        return wrapper

    return decorate
|
|
|
|
|
|
def reconnect_db(func):
|
|
@wraps(func)
|
|
def wrapper(*args, **kwargs):
|
|
try:
|
|
return func(*args, **kwargs)
|
|
except (StatementError, OperationalError, InvalidRequestError) as e:
|
|
error_msg = str(e)
|
|
if 'Lost connection' in error_msg or 'reconnect until invalid transaction' in error_msg or \
|
|
'can be emitted within this transaction' in error_msg:
|
|
current_app.logger.info('[reconnect_db] lost connect rollback then retry')
|
|
db.session.rollback()
|
|
return func(*args, **kwargs)
|
|
else:
|
|
raise e
|
|
except Exception as e:
|
|
raise e
|
|
|
|
return wrapper
|
|
|
|
|
|
def _flush_db():
    """Commit the shared ``db.session``, rolling it back if the commit fails.

    Only the SQLAlchemy errors in ``recoverable_errors`` are absorbed. For
    those, the rollback is the whole recovery and nothing is re-raised. Any
    other exception from ``commit()`` propagates to the caller, as does any
    error raised by ``rollback()`` itself.
    """
    recoverable_errors = (
        StatementError,
        OperationalError,
        InvalidRequestError,
        PendingRollbackError,
    )
    try:
        db.session.commit()
    except recoverable_errors:
        db.session.rollback()
|
|
|
|
|
|
def flush_db(func):
    """Decorator that runs ``_flush_db()`` before every call to *func*.

    The session is committed (or rolled back on a recoverable error) first.
    Then *func* is invoked with the original arguments and its result is
    returned.
    """
    def wrapper(*args, **kwargs):
        _flush_db()
        return func(*args, **kwargs)

    return wraps(func)(wrapper)
|
|
|
|
|
|
def run_flush_db():
    """Commit (or, on a recoverable DB error, roll back) the shared session once.

    This is the plain-call counterpart of the ``flush_db`` decorator. It
    delegates to ``_flush_db()`` and returns None.
    """
    return _flush_db()
|