From cf85367b4ba6a0ee25dd3b99d2ac0580d867dd45 Mon Sep 17 00:00:00 2001
From: pycook <pycook@126.com>
Date: Thu, 19 Oct 2023 11:49:56 +0800
Subject: [PATCH] fix(celery worker): db server has gone away

---
 cmdb-api/api/lib/decorator.py | 42 +++++++++++++++++++++++++++++++++++
 cmdb-api/api/tasks/acl.py     | 14 +++++++-----
 cmdb-api/api/tasks/cmdb.py    | 31 +++++++++++++++++---------
 3 files changed, 72 insertions(+), 15 deletions(-)

diff --git a/cmdb-api/api/lib/decorator.py b/cmdb-api/api/lib/decorator.py
index 6284f65..185bd3c 100644
--- a/cmdb-api/api/lib/decorator.py
+++ b/cmdb-api/api/lib/decorator.py
@@ -4,8 +4,13 @@
 from functools import wraps
 
 from flask import abort
+from flask import current_app
 from flask import request
+from sqlalchemy.exc import InvalidRequestError
+from sqlalchemy.exc import OperationalError
+from sqlalchemy.exc import StatementError
 
+from api.extensions import db
 from api.lib.resp_format import CommonErrFormat
 
 
@@ -70,3 +75,40 @@ def args_validate(model_cls, exclude_args=None):
         return wrapper
 
     return decorate
+
+
+def reconnect_db(func):
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        try:
+            return func(*args, **kwargs)
+        except (StatementError, OperationalError, InvalidRequestError) as e:
+            error_msg = str(e)
+            if 'Lost connection' in error_msg or 'reconnect until invalid transaction' in error_msg or \
+                    'can be emitted within this transaction' in error_msg:
+                current_app.logger.info('[reconnect_db] lost connect rollback then retry')
+                db.session.rollback()
+                return func(*args, **kwargs)
+            else:
+                raise e
+        except Exception as e:
+            raise e
+
+    return wrapper
+
+
+def _flush_db():
+    db.session.commit()
+
+
+def flush_db(func):
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        _flush_db()
+        return func(*args, **kwargs)
+
+    return wrapper
+
+
+def run_flush_db():
+    _flush_db()
diff --git a/cmdb-api/api/tasks/acl.py b/cmdb-api/api/tasks/acl.py
index bfddd65..750eb2e 100644
--- a/cmdb-api/api/tasks/acl.py
+++ b/cmdb-api/api/tasks/acl.py
@@ -9,7 +9,8 @@ from werkzeug.exceptions import BadRequest
 from werkzeug.exceptions import NotFound
 
 from api.extensions import celery
-from api.extensions import db
+from api.lib.decorator import flush_db
+from api.lib.decorator import reconnect_db
 from api.lib.perm.acl.audit import AuditCRUD
 from api.lib.perm.acl.audit import AuditOperateSource
 from api.lib.perm.acl.audit import AuditOperateType
@@ -28,6 +29,7 @@ from api.models.acl import Trigger
              name="acl.role_rebuild",
              queue=ACL_QUEUE,
              once={"graceful": True, "unlock_before_run": True})
+@reconnect_db
 def role_rebuild(rids, app_id):
     rids = rids if isinstance(rids, list) else [rids]
     for rid in rids:
@@ -37,6 +39,7 @@ def role_rebuild(rids, app_id):
 
 
 @celery.task(name="acl.update_resource_to_build_role", queue=ACL_QUEUE)
+@reconnect_db
 def update_resource_to_build_role(resource_id, app_id, group_id=None):
     rids = [i.id for i in Role.get_by(__func_isnot__key_uid=None, fl='id', to_dict=False)]
     rids += [i.id for i in Role.get_by(app_id=app_id, fl='id', to_dict=False)]
@@ -52,9 +55,9 @@ def update_resource_to_build_role(resource_id, app_id, group_id=None):
 
 
 @celery.task(name="acl.apply_trigger", queue=ACL_QUEUE)
+@flush_db
+@reconnect_db
 def apply_trigger(_id, resource_id=None, operator_uid=None):
-    db.session.remove()
-
     from api.lib.perm.acl.permission import PermissionCRUD
 
     trigger = Trigger.get_by_id(_id)
@@ -118,9 +121,9 @@ def apply_trigger(_id, resource_id=None, operator_uid=None):
 
 
 @celery.task(name="acl.cancel_trigger", queue=ACL_QUEUE)
+@flush_db
+@reconnect_db
 def cancel_trigger(_id, resource_id=None, operator_uid=None):
-    db.session.remove()
-
     from api.lib.perm.acl.permission import PermissionCRUD
 
     trigger = Trigger.get_by_id(_id)
@@ -186,6 +189,7 @@ def cancel_trigger(_id, resource_id=None, operator_uid=None):
 
 
 @celery.task(name="acl.op_record", queue=ACL_QUEUE)
+@reconnect_db
 def op_record(app, rolename, operate_type, obj):
     if isinstance(app, int):
         app = AppCache.get(app)
diff --git a/cmdb-api/api/tasks/cmdb.py b/cmdb-api/api/tasks/cmdb.py
index 83a767a..c41f556 100644
--- a/cmdb-api/api/tasks/cmdb.py
+++ b/cmdb-api/api/tasks/cmdb.py
@@ -16,6 +16,8 @@ from api.lib.cmdb.cache import CITypeAttributesCache
 from api.lib.cmdb.const import CMDB_QUEUE
 from api.lib.cmdb.const import REDIS_PREFIX_CI
 from api.lib.cmdb.const import REDIS_PREFIX_CI_RELATION
+from api.lib.decorator import flush_db
+from api.lib.decorator import reconnect_db
 from api.lib.perm.acl.cache import UserCache
 from api.lib.utils import Lock
 from api.lib.utils import handle_arg_list
@@ -25,11 +27,12 @@ from api.models.cmdb import CITypeAttribute
 
 
 @celery.task(name="cmdb.ci_cache", queue=CMDB_QUEUE)
+@flush_db
+@reconnect_db
 def ci_cache(ci_id, operate_type, record_id):
     from api.lib.cmdb.ci import CITriggerManager
 
     time.sleep(0.01)
-    db.session.remove()
 
     m = api.lib.cmdb.ci.CIManager()
     ci_dict = m.get_ci_by_id_from_db(ci_id, need_children=False, use_master=False)
@@ -49,9 +52,10 @@ def ci_cache(ci_id, operate_type, record_id):
 
 
 @celery.task(name="cmdb.batch_ci_cache", queue=CMDB_QUEUE)
+@flush_db
+@reconnect_db
 def batch_ci_cache(ci_ids, ):  # only for attribute change index
     time.sleep(1)
-    db.session.remove()
 
     for ci_id in ci_ids:
         m = api.lib.cmdb.ci.CIManager()
@@ -66,6 +70,7 @@ def batch_ci_cache(ci_ids, ):  # only for attribute change index
 
 
 @celery.task(name="cmdb.ci_delete", queue=CMDB_QUEUE)
+@reconnect_db
 def ci_delete(ci_id):
     current_app.logger.info(ci_id)
 
@@ -78,6 +83,7 @@ def ci_delete(ci_id):
 
 
 @celery.task(name="cmdb.ci_delete_trigger", queue=CMDB_QUEUE)
+@reconnect_db
 def ci_delete_trigger(trigger, operate_type, ci_dict):
     current_app.logger.info('delete ci {} trigger'.format(ci_dict['_id']))
     from api.lib.cmdb.ci import CITriggerManager
@@ -89,9 +95,9 @@ def ci_delete_trigger(trigger, operate_type, ci_dict):
 
 
 @celery.task(name="cmdb.ci_relation_cache", queue=CMDB_QUEUE)
+@flush_db
+@reconnect_db
 def ci_relation_cache(parent_id, child_id):
-    db.session.remove()
-
     with Lock("CIRelation_{}".format(parent_id)):
         children = rd.get([parent_id], REDIS_PREFIX_CI_RELATION)[0]
         children = json.loads(children) if children is not None else {}
@@ -106,6 +112,8 @@ def ci_relation_cache(parent_id, child_id):
 
 
 @celery.task(name="cmdb.ci_relation_add", queue=CMDB_QUEUE)
+@flush_db
+@reconnect_db
 def ci_relation_add(parent_dict, child_id, uid):
     """
     :param parent_dict: key is '$parent_model.attr_name'
@@ -121,8 +129,6 @@ def ci_relation_add(parent_dict, child_id, uid):
     current_app.test_request_context().push()
     login_user(UserCache.get(uid))
 
-    db.session.remove()
-
     for parent in parent_dict:
         parent_ci_type_name, _attr_name = parent.strip()[1:].split('.', 1)
         attr_name = CITypeAttributeManager.get_attr_name(parent_ci_type_name, _attr_name)
@@ -147,10 +153,14 @@ def ci_relation_add(parent_dict, child_id, uid):
                 except Exception as e:
                     current_app.logger.warning(e)
                 finally:
-                    db.session.remove()
+                    try:
+                        db.session.commit()
+                    except:
+                        pass
 
 
 @celery.task(name="cmdb.ci_relation_delete", queue=CMDB_QUEUE)
+@reconnect_db
 def ci_relation_delete(parent_id, child_id):
     with Lock("CIRelation_{}".format(parent_id)):
         children = rd.get([parent_id], REDIS_PREFIX_CI_RELATION)[0]
@@ -165,9 +175,10 @@ def ci_relation_delete(parent_id, child_id):
 
 
 @celery.task(name="cmdb.ci_type_attribute_order_rebuild", queue=CMDB_QUEUE)
+@flush_db
+@reconnect_db
 def ci_type_attribute_order_rebuild(type_id, uid):
     current_app.logger.info('rebuild attribute order')
-    db.session.remove()
 
     from api.lib.cmdb.ci_type import CITypeAttributeGroupManager
 
@@ -188,11 +199,11 @@ def ci_type_attribute_order_rebuild(type_id, uid):
 
 
 @celery.task(name="cmdb.calc_computed_attribute", queue=CMDB_QUEUE)
+@flush_db
+@reconnect_db
 def calc_computed_attribute(attr_id, uid):
     from api.lib.cmdb.ci import CIManager
 
-    db.session.remove()
-
     current_app.test_request_context().push()
     login_user(UserCache.get(uid))