fix(api): search for multiple CIType

This commit is contained in:
pycook 2024-09-24 17:46:27 +08:00
parent eb69029a51
commit 78da728105
2 changed files with 39 additions and 19 deletions

View File

@ -719,7 +719,7 @@ class CIManager(object):
if unique_required: if unique_required:
_d[d.get('unique')] = d.get(d.get('unique')) _d[d.get('unique')] = d.get(d.get('unique'))
_fields = list(fields.get(_d['_type']) or []) if isinstance(fields, dict) else fields _fields = list(fields.get(_d['_type']) or [] if isinstance(fields, dict) else fields)
for field in _fields + ['ci_type_alias', 'unique', 'unique_alias']: for field in _fields + ['ci_type_alias', 'unique', 'unique_alias']:
_d[field] = d.get(field) _d[field] = d.get(field)
_res.append(_d) _res.append(_d)

View File

@ -4,8 +4,8 @@
from __future__ import unicode_literals from __future__ import unicode_literals
import copy import copy
import six
import time import time
from flask import current_app from flask import current_app
from flask_login import current_user from flask_login import current_user
from jinja2 import Template from jinja2 import Template
@ -104,9 +104,10 @@ class Search(object):
else: else:
raise SearchError(ErrFormat.attribute_not_found.format(key)) raise SearchError(ErrFormat.attribute_not_found.format(key))
def _type_query_handler(self, v, queries): def _type_query_handler(self, v, queries, is_sub=False):
new_v = v[1:-1].split(";") if v.startswith("(") and v.endswith(")") else [v] new_v = v[1:-1].split(";") if v.startswith("(") and v.endswith(")") else [v]
type_num = len(new_v) type_num = len(new_v)
type_id_list = []
for _v in new_v: for _v in new_v:
ci_type = CITypeCache.get(_v) ci_type = CITypeCache.get(_v)
@ -115,19 +116,32 @@ class Search(object):
if ci_type is not None: if ci_type is not None:
if self.valid_type_names == "ALL" or ci_type.name in self.valid_type_names: if self.valid_type_names == "ALL" or ci_type.name in self.valid_type_names:
if not is_sub:
self.type_id_list.append(str(ci_type.id)) self.type_id_list.append(str(ci_type.id))
if ci_type.id in self.type2filter_perms: type_id_list.append(str(ci_type.id))
if ci_type.id in self.type2filter_perms and not is_sub:
ci_filter = self.type2filter_perms[ci_type.id].get('ci_filter') ci_filter = self.type2filter_perms[ci_type.id].get('ci_filter')
if ci_filter and self.use_ci_filter and not self.use_id_filter: if ci_filter and self.use_ci_filter and not self.use_id_filter:
sub = [] sub = []
ci_filter = Template(ci_filter).render(user=current_user) ci_filter = Template(ci_filter).render(user=current_user)
for i in ci_filter.split(','): for i in ci_filter.split(','):
if type_num == 1:
if i.startswith("~") and not sub: if i.startswith("~") and not sub:
queries.append(i) queries.append(i)
else: else:
sub.append(i) sub.append(i)
else:
sub.append(i)
if sub: if sub:
if type_num == 1:
queries.append(dict(operator="&", queries=sub)) queries.append(dict(operator="&", queries=sub))
else:
if str(ci_type.id) in self.type_id_list:
self.type_id_list.remove(str(ci_type.id))
type_id_list.remove(str(ci_type.id))
sub.extend([i for i in queries[1:] if isinstance(i, six.string_types)])
sub.insert(0, "_type:{}".format(ci_type.id))
queries.append(dict(operator="|", queries=sub))
if self.type2filter_perms[ci_type.id].get('attr_filter'): if self.type2filter_perms[ci_type.id].get('attr_filter'):
if type_num == 1: if type_num == 1:
@ -152,13 +166,17 @@ class Search(object):
else: else:
raise SearchError(ErrFormat.ci_type_not_found2.format(_v)) raise SearchError(ErrFormat.ci_type_not_found2.format(_v))
if self.type_id_list: if type_num != len(self.type_id_list) and queries and queries[0].startswith('_type') and not is_sub:
type_ids = ",".join(self.type_id_list) queries[0] = "_type:({})".format(";".join(self.type_id_list))
if type_id_list:
type_ids = ",".join(type_id_list)
_query_sql = QUERY_CI_BY_TYPE.format(type_ids) _query_sql = QUERY_CI_BY_TYPE.format(type_ids)
if self.only_type_query: if self.only_type_query:
return _query_sql return _query_sql
else: elif type_num > 1:
return "" return "select c_cis.id as ci_id from c_cis where c_cis.id=0"
return "" return ""
@staticmethod @staticmethod
@ -331,7 +349,9 @@ class Search(object):
INNER JOIN ({2}) as {3} USING(ci_id)""".format(query_sql, alias, _query_sql, alias + "A") INNER JOIN ({2}) as {3} USING(ci_id)""".format(query_sql, alias, _query_sql, alias + "A")
elif operator == "|" or operator == "|~": elif operator == "|" or operator == "|~":
query_sql = "SELECT * FROM ({0}) as {1} UNION ALL ({2})".format(query_sql, alias, _query_sql) query_sql = "SELECT * FROM ({0}) as {1} UNION ALL SELECT * FROM ({2}) as {3}".format(query_sql, alias,
_query_sql,
alias + "A")
elif operator == "~": elif operator == "~":
query_sql = """SELECT * FROM ({0}) as {1} LEFT JOIN ({2}) as {3} USING(ci_id) query_sql = """SELECT * FROM ({0}) as {1} LEFT JOIN ({2}) as {3} USING(ci_id)
@ -436,14 +456,14 @@ class Search(object):
return result return result
def __query_by_attr(self, q, queries, alias): def __query_by_attr(self, q, queries, alias, is_sub=False):
k = q.split(":")[0].strip() k = q.split(":")[0].strip()
v = "\:".join(q.split(":")[1:]).strip() v = "\:".join(q.split(":")[1:]).strip()
v = v.replace("'", "\\'") v = v.replace("'", "\\'")
v = v.replace('"', '\\"') v = v.replace('"', '\\"')
field, field_type, operator, attr = self._attr_name_proc(k) field, field_type, operator, attr = self._attr_name_proc(k)
if field == "_type": if field == "_type":
_query_sql = self._type_query_handler(v, queries) _query_sql = self._type_query_handler(v, queries, is_sub)
elif field == "_id": elif field == "_id":
_query_sql = self._id_query_handler(v) _query_sql = self._id_query_handler(v)
@ -490,19 +510,20 @@ class Search(object):
return alias, _query_sql, operator return alias, _query_sql, operator
def __query_build_by_field(self, queries, is_first=True, only_type_query_special=True, alias='A', operator='&'): def __query_build_by_field(self, queries, is_first=True, only_type_query_special=True, alias='A', operator='&',
is_sub=False):
query_sql = "" query_sql = ""
for q in queries: for q in queries:
_query_sql = "" _query_sql = ""
if isinstance(q, dict): if isinstance(q, dict):
alias, _query_sql, operator = self.__query_build_by_field(q['queries'], True, True, alias) alias, _query_sql, operator = self.__query_build_by_field(q['queries'], True, True, alias, is_sub=True)
current_app.logger.info(_query_sql) current_app.logger.info(_query_sql)
current_app.logger.info((operator, is_first, alias)) current_app.logger.info((operator, is_first, alias))
operator = q['operator'] operator = q['operator']
elif ":" in q and not q.startswith("*"): elif ":" in q and not q.startswith("*"):
alias, _query_sql, operator = self.__query_by_attr(q, queries, alias) alias, _query_sql, operator = self.__query_by_attr(q, queries, alias, is_sub)
elif q == "*": elif q == "*":
continue continue
elif q: elif q:
@ -553,7 +574,6 @@ class Search(object):
queries = handle_arg_list(self.orig_query) queries = handle_arg_list(self.orig_query)
queries = self._extra_handle_query_expr(queries) queries = self._extra_handle_query_expr(queries)
queries = self.__confirm_type_first(queries) queries = self.__confirm_type_first(queries)
current_app.logger.debug(queries)
_, query_sql, _ = self.__query_build_by_field(queries) _, query_sql, _ = self.__query_build_by_field(queries)