From a09336f00beb5e4c3537adc02259bacab68e1657 Mon Sep 17 00:00:00 2001
From: pycook <pycook@126.com>
Date: Thu, 17 Oct 2024 19:46:39 +0800
Subject: [PATCH] feat(api): relation path search

---
 cmdb-api/api/lib/cmdb/search/ci/db/search.py  |  21 ++--
 .../api/lib/cmdb/search/ci_relation/search.py | 106 ++++++++++++------
 cmdb-api/api/views/cmdb/ci_relation.py        |   9 +-
 3 files changed, 88 insertions(+), 48 deletions(-)

diff --git a/cmdb-api/api/lib/cmdb/search/ci/db/search.py b/cmdb-api/api/lib/cmdb/search/ci/db/search.py
index 42b99ed..bb92aed 100644
--- a/cmdb-api/api/lib/cmdb/search/ci/db/search.py
+++ b/cmdb-api/api/lib/cmdb/search/ci/db/search.py
@@ -66,6 +66,7 @@ class Search(object):
         self.use_id_filter = use_id_filter
         self.use_ci_filter = use_ci_filter
         self.only_ids = only_ids
+        self.multi_type_has_ci_filter = False
 
         self.valid_type_names = []
         self.type2filter_perms = dict()
@@ -140,9 +141,10 @@ class Search(object):
                                         self.type_id_list.remove(str(ci_type.id))
                                     type_id_list.remove(str(ci_type.id))
                                     sub.extend([i for i in queries[1:] if isinstance(i, six.string_types)])
+
                                     sub.insert(0, "_type:{}".format(ci_type.id))
                                     queries.append(dict(operator="|", queries=sub))
-
+                                    self.multi_type_has_ci_filter = True
                         if self.type2filter_perms[ci_type.id].get('attr_filter'):
                             if type_num == 1:
                                 if not self.fl:
@@ -172,9 +174,9 @@ class Search(object):
         if type_id_list:
             type_ids = ",".join(type_id_list)
             _query_sql = QUERY_CI_BY_TYPE.format(type_ids)
-            if self.only_type_query:
+            if self.only_type_query or self.multi_type_has_ci_filter:
                 return _query_sql
-        elif type_num > 1:
+        elif type_num > 1:  # there must be instance-level access control
             return "select c_cis.id as ci_id from c_cis where c_cis.id=0"
 
         return ""
@@ -253,7 +255,7 @@ class Search(object):
             return ret_sql.format(query_sql, "ORDER BY B.ci_id {1} LIMIT {0:d}, {2};".format(
                 (self.page - 1) * self.count, sort_type, self.count))
 
-        elif self.type_id_list:
+        elif self.type_id_list and not self.multi_type_has_ci_filter:
             self.query_sql = "SELECT B.ci_id FROM ({0}) AS B {1}".format(
                 query_sql,
                 "INNER JOIN c_cis on c_cis.id=B.ci_id WHERE c_cis.type_id IN ({0}) ".format(
@@ -278,7 +280,7 @@ class Search(object):
     def __sort_by_type(self, sort_type, query_sql):
         ret_sql = "SELECT SQL_CALC_FOUND_ROWS DISTINCT B.ci_id FROM ({0}) AS B {1}"
 
-        if self.type_id_list:
+        if self.type_id_list and not self.multi_type_has_ci_filter:
             self.query_sql = "SELECT B.ci_id FROM ({0}) AS B {1}".format(
                 query_sql,
                 "INNER JOIN c_cis on c_cis.id=B.ci_id WHERE c_cis.type_id IN ({0}) ".format(
@@ -311,7 +313,7 @@ class Search(object):
                           WHERE {1}.attr_id = {3}""".format("ALIAS", table_name, query_sql, attr_id)
         new_table = _v_query_sql
 
-        if self.only_type_query or not self.type_id_list:
+        if self.only_type_query or not self.type_id_list or self.multi_type_has_ci_filter:
             return ("SELECT SQL_CALC_FOUND_ROWS DISTINCT C.ci_id FROM ({0}) AS C ORDER BY C.value {2} "
                     "LIMIT {1:d}, {3};".format(new_table, (self.page - 1) * self.count, sort_type, self.count))
 
@@ -518,8 +520,8 @@ class Search(object):
             _query_sql = ""
             if isinstance(q, dict):
                 alias, _query_sql, operator = self.__query_build_by_field(q['queries'], True, True, alias, is_sub=True)
-                current_app.logger.info(_query_sql)
-                current_app.logger.info((operator, is_first, alias))
+                # current_app.logger.info(_query_sql)
+                # current_app.logger.info((operator, is_first, alias))
                 operator = q['operator']
 
             elif ":" in q and not q.startswith("*"):
@@ -617,6 +619,7 @@ class Search(object):
                 k, _, _, _ = self._attr_name_proc(f)
                 if k:
                     _fl.append(k)
+
             return _fl
         else:
             return self.fl
@@ -638,6 +641,8 @@ class Search(object):
         if ci_ids:
             response = CIManager.get_cis_by_ids(ci_ids, ret_key=self.ret_key, fields=_fl, excludes=self.excludes)
         for res in response:
+            if not res:
+                continue
             ci_type = res.get("ci_type")
             if ci_type not in counter.keys():
                 counter[ci_type] = 0
diff --git a/cmdb-api/api/lib/cmdb/search/ci_relation/search.py b/cmdb-api/api/lib/cmdb/search/ci_relation/search.py
index a56a41f..c96271c 100644
--- a/cmdb-api/api/lib/cmdb/search/ci_relation/search.py
+++ b/cmdb-api/api/lib/cmdb/search/ci_relation/search.py
@@ -29,6 +29,8 @@ from api.lib.cmdb.utils import ValueTypeMap
 from api.lib.perm.acl.acl import ACLManager
 from api.lib.perm.acl.acl import is_app_admin
 from api.models.cmdb import CI
+from api.models.cmdb import CITypeRelation
+from api.models.cmdb import RelationType
 
 
 class Search(object):
@@ -437,7 +439,10 @@ class Search(object):
         if not q.startswith('_type:'):
             q = "_type:({}),{}".format(";".join(map(str, type_ids)), q)
 
-        return SearchFromDB(q, ci_ids=target_ids, use_ci_filter=False, only_ids=True, count=100000).search()
+        ci_ids = SearchFromDB(q, ci_ids=target_ids, use_ci_filter=True, only_ids=True, count=100000).search()
+        cis = CI.get_by(fl=['id', 'type_id'], only_query=True).filter(CI.id.in_(ci_ids))
+
+        return [(str(i.id), i.type_id) for i in cis]
 
     @staticmethod
     def _path2level(src_type_id, target_type_ids, path):
@@ -445,21 +450,31 @@ class Search(object):
             return abort(400, ErrFormat.relation_path_search_src_target_required)
 
         graph = nx.DiGraph()
-        graph.add_edges_from([(int(s), d) for s in path for d in path[s]])
-
+        graph.add_edges_from([(n, _path[idx + 1]) for _path in path for idx, n in enumerate(_path[:-1])])
+        relation_types = defaultdict(dict)
         level2type = defaultdict(set)
-        for target_type_id in target_type_ids:
-            paths = list(nx.all_simple_paths(graph, source=src_type_id, target=target_type_id))
-            for _path in paths:
-                for idx, node in enumerate(_path[1:]):
-                    level2type[idx + 1].add(node)
+        type2show_key = dict()
+        for _path in path:
+            for idx, node in enumerate(_path[1:]):
+                level2type[idx + 1].add(node)
+
+                src = CITypeCache.get(_path[idx])
+                target = CITypeCache.get(node)
+                relation_type = RelationType.get_by(only_query=True).join(
+                    CITypeRelation, CITypeRelation.relation_type_id == RelationType.id).filter(
+                    CITypeRelation.parent_id == src.id).filter(CITypeRelation.child_id == target.id).first()
+                relation_types[src.alias].update({target.alias: relation_type.name})
+
+                if src.id not in type2show_key:
+                    type2show_key[src.id] = AttributeCache.get(src.show_id or src.unique_id).name
+                if target.id not in type2show_key:
+                    type2show_key[target.id] = AttributeCache.get(target.show_id or target.unique_id).name
+
         nodes = graph.nodes()
 
-        del graph
+        return level2type, list(nodes), relation_types, type2show_key
 
-        return level2type, list(nodes)
-
-    def _build_graph(self, source_ids, level2type, target_type_ids, acl):
+    def _build_graph(self, source_ids, source_type_id, level2type, target_type_ids, acl):
         type2filter_perms = dict()
         if not self.is_app_admin:
             res2 = acl.get_resources(ResourceTypeEnum.CI_FILTER)
@@ -469,7 +484,8 @@ class Search(object):
         target_type_ids = set(target_type_ids)
         graph = nx.DiGraph()
         target_ids = []
-        key = list(map(str, source_ids))
+        key = [(str(i), source_type_id) for i in source_ids]
+        graph.add_nodes_from(key)
         for level in level2type:
             filter_type_ids = level2type[level]
             id_filter_limit = dict()
@@ -480,10 +496,11 @@ class Search(object):
 
             has_target = filter_type_ids & target_type_ids
 
-            res = [json.loads(x).items() for x in [i or '{}' for i in rd.get(key, REDIS_PREFIX_CI_RELATION) or []]]
+            res = [json.loads(x).items() for x in [i or '{}' for i in rd.get([i[0] for i in key],
+                                                                             REDIS_PREFIX_CI_RELATION) or []]]
             _key = []
             for idx, _id in enumerate(key):
-                valid_targets = [i[0] for i in res[idx] if i[1] in filter_type_ids and
+                valid_targets = [i for i in res[idx] if i[1] in filter_type_ids and
                                  (not id_filter_limit or int(i[0]) in id_filter_limit)]
                 _key.extend(valid_targets)
                 graph.add_edges_from(zip([_id] * len(valid_targets), valid_targets))
@@ -496,31 +513,41 @@ class Search(object):
         return graph, target_ids
 
     @staticmethod
-    def _find_paths(graph, source_ids, target_ids, max_depth=6):
+    def _find_paths(graph, source_ids, source_type_id, target_ids, valid_path, max_depth=6):
         paths = []
         for source_id in source_ids:
-            _paths = nx.all_simple_paths(graph, source=source_id, target=target_ids, cutoff=max_depth)
-            paths.extend(_paths)
+            _paths = nx.all_simple_paths(graph,
+                                         source=(source_id, source_type_id),
+                                         target=target_ids,
+                                         cutoff=max_depth)
+            for __path in _paths:
+                if tuple([i[1] for i in __path]) in valid_path:
+                    paths.append([i[0] for i in __path])
 
         return paths
 
     @staticmethod
-    def _wrap_path_result(paths, types):
+    def _wrap_path_result(paths, types, valid_path, target_types, type2show_key):
         ci_ids = [j for i in paths for j in i]
 
         response, _, _, _, _, _ = SearchFromDB("_type:({})".format(";".join(map(str, types))),
                                                use_ci_filter=False,
                                                ci_ids=list(map(int, ci_ids)),
                                                count=1000000).search()
-        id2ci = {str(i.get('_id')): i for i in response}
+        id2ci = {str(i.get('_id')): i if i['_type'] in target_types else {
+            type2show_key[i['_type']]: i[type2show_key[i['_type']]],
+            "ci_type_alias": i["ci_type_alias"],
+            "_type": i["_type"],
+        } for i in response}
 
         result = defaultdict(list)
         counter = defaultdict(int)
 
         for path in paths:
             key = "-".join([id2ci.get(i, {}).get('ci_type_alias') or '' for i in path])
-            counter[key] += 1
-            result[key].append(path)
+            if tuple([id2ci.get(i, {}).get('_type') for i in path]) in valid_path:
+                counter[key] += 1
+                result[key].append(path)
 
         return result, counter, id2ci
 
@@ -529,33 +556,38 @@ class Search(object):
 
         :param source: {type_id: id, q: expr}
         :param target: {type_ids: [id], q: expr}
-        :param path: {parent_id: [child_id]}, use type id
+        :param path: [source_type_id, ..., target_type_id], use type id
         :return:
         """
         acl = ACLManager('cmdb')
         if not self.is_app_admin:
             res = {i['name'] for i in acl.get_resources(ResourceTypeEnum.CI_TYPE)}
-            for type_id in (source.get('type_id') or []) + (target.get('type_ids') or []):
+            for type_id in (source.get('type_id') and [source['type_id']] or []) + (target.get('type_ids') or []):
                 _type = CITypeCache.get(type_id)
                 if _type and _type.name not in res:
                     return abort(403, ErrFormat.no_permission.format(_type.alias, PermEnum.READ))
 
-        level2type, types = self._path2level(source.get('type_id'), target.get('type_ids'), path)
+        target['type_ids'] = [i[-1] for i in path]
+        level2type, types, relation_types, type2show_key = self._path2level(
+            source.get('type_id'), target.get('type_ids'), path)
         if not level2type:
-            return [], {}, 0, self.page, 0, {}
+            return [], {}, 0, self.page, 0, {}, {}
 
         source_ids = self._get_src_ids(source)
 
-        graph, target_ids = self._build_graph(source_ids, level2type, target['type_ids'], acl)
-        if target.get('q'):
-            target_ids = self._filter_target_ids(target_ids, target['type_ids'], target['q'])
+        graph, target_ids = self._build_graph(source_ids, source['type_id'], level2type, target['type_ids'], acl)
+        target_ids = self._filter_target_ids(target_ids, target['type_ids'], target.get('q') or '')
+        paths = self._find_paths(graph,
+                                 source_ids,
+                                 source['type_id'],
+                                 set(target_ids),
+                                 {tuple(i): 1 for i in path})
 
-        paths = self._find_paths(graph, source_ids, set(target_ids))
-        del graph
-
-        numfound = len(target_ids)
+        numfound = len(paths)
         paths = paths[(self.page - 1) * self.count:self.page * self.count]
-
-        response, counter, id2ci = self._wrap_path_result(paths, types)
-
-        return response, counter, len(paths), self.page, numfound, id2ci
+        response, counter, id2ci = self._wrap_path_result(paths,
+                                                          types,
+                                                          {tuple(i): 1 for i in path},
+                                                          set(target.get('type_ids') or []),
+                                                          type2show_key)
+        return response, counter, len(paths), self.page, numfound, id2ci, relation_types, type2show_key
diff --git a/cmdb-api/api/views/cmdb/ci_relation.py b/cmdb-api/api/views/cmdb/ci_relation.py
index ae9bc7d..d20b67c 100644
--- a/cmdb-api/api/views/cmdb/ci_relation.py
+++ b/cmdb-api/api/views/cmdb/ci_relation.py
@@ -73,7 +73,7 @@ class CIRelationSearchPathView(APIView):
                     page_size | count: page size
                     source: source CIType, e.g. {type_id: 1, q: `search expr`}
                     target: target CIType, e.g. {type_ids: [2], q: `search expr`}
-                    path: Path from the Source CIType to the Target CIType, e.g. {source_id: [target_id]}
+                    path: Path from the Source CIType to the Target CIType, e.g. [1, ..., 2]
         """
 
         page = get_page(request.values.get("page", 1))
@@ -85,7 +85,8 @@ class CIRelationSearchPathView(APIView):
 
         s = Search(page=page, count=count)
         try:
-            response, counter, total, page, numfound, id2ci = s.search_by_path(source, target, path)
+            (response, counter, total, page, numfound, id2ci,
+             relation_types, type2show_key) = s.search_by_path(source, target, path)
         except SearchError as e:
             return abort(400, str(e))
 
@@ -94,7 +95,9 @@ class CIRelationSearchPathView(APIView):
                             page=page,
                             counter=counter,
                             paths=response,
-                            id2ci=id2ci)
+                            id2ci=id2ci,
+                            relation_types=relation_types,
+                            type2show_key=type2show_key)
 
 
 class CIRelationStatisticsView(APIView):