Coverage for CIResults/filtering.py: 100%

402 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-04-23 13:11 +0000

1from arpeggio import Optional, ZeroOrMore, OneOrMore, EOF, ParserPython 

2from arpeggio import PTNodeVisitor, visit_parse_tree, NoMatch 

3from arpeggio import RegExMatch as _ 

4 

5from dateutil import parser as datetimeparser 

6from django.contrib import messages 

7from django.db.models import Q 

8from django.http.request import QueryDict 

9from django.utils import timezone 

10from django.utils.dateparse import parse_duration 

11from django.utils.functional import cached_property 

12 

13from shortener.models import Shortener 

14 

15import traceback 

16import pytz 

17import re 

18 

19 

20# Arpeggio's parser 

def val_none():
    """Grammar rule matching the ``NONE`` keyword."""
    return _(r'NONE')

22 

23 

def val_int():
    """Grammar rule matching an optionally negative run of decimal digits."""
    return _(r'-?\d+')

25 

26 

def val_str():
    """Grammar rule for a quoted string literal.

    Two alternatives are offered, double-quoted first and single-quoted
    second. Inside either form a backslash consumes the character after it,
    which is how an embedded quote is written.
    """
    double_quoted = ('"', _(r'([^"\\]|\\.)*'), '"')
    single_quoted = ("'", _(r"([^'\\]|\\.)*"), "'")
    return [double_quoted, single_quoted]

29 

30 

def val_bool():
    """Grammar rule offering the ``TRUE`` and ``FALSE`` keywords, in that order."""
    return ['TRUE', 'FALSE']

32 

33 

def val_datetime():
    """Grammar rule for ``datetime(<text>)``; the text is anything up to the first ``)``."""
    payload = [_(r'[^\)]+')]
    return ('datetime(', payload, ')')

35 

36 

def val_duration():
    """Grammar rule for ``duration(<text>)``; the text is anything up to the first ``)``."""
    payload = [_(r'[^\)]+')]
    return ('duration(', payload, ')')

38 

39 

def val_ago():
    """Grammar rule for ``ago(<text>)``; the text is anything up to the first ``)``."""
    payload = [_(r'[^\)]+')]
    return ('ago(', payload, ')')

41 

42 

def val_array():
    """Grammar rule for a bracketed array of one or more literal values.

    Each element may be followed by an optional comma.
    """
    element = [val_none, val_str, val_int, val_bool, val_datetime, val_duration, val_ago]
    return ("[", OneOrMore((element, Optional(','))), "]")

45 

46 

def nested_expression():
    """Grammar rule for free text containing balanced parentheses.

    The rule refers to itself between each ``(`` / ``)`` pair, so nesting of
    any depth is accepted.
    """
    return ZeroOrMore(
        ZeroOrMore(_(r'[^()]+')),
        ZeroOrMore("(", nested_expression, ")"),
        ZeroOrMore(_(r'[^()]+'))
    )

49 

50 

def val_subquery():
    """Grammar rule for a parenthesised sub-query whose body is a ``nested_expression``."""
    parenthesised = ('(', nested_expression, ')')
    return [parenthesised]

52 

53 

def filter_field():
    """Grammar rule for a key name made of ASCII letters, digits, ``_`` and ``-``."""
    return _(r'[a-zA-Z\d_-]+')

55 

56 

def filter_object():
    """Grammar rule for an object name, optionally followed by ``.<filter_field>``."""
    object_name = _(r'\w+')
    return (object_name, Optional(".", filter_field))

58 

59 

def basic_filter():
    """Grammar rule for a single ``<object> <operator> <value>`` comparison.

    The alternatives are listed in the order they are tried; each one pairs a
    family of operators with the value kinds it accepts.
    """
    membership = (filter_object, ['IS IN', 'NOT IN'], val_array)
    ordering = (filter_object, ['<=', r'<', r'>=', r'>'],
                [val_int, val_datetime, val_duration, val_ago])
    equality = (filter_object, ['=', '!='],
                [val_duration, val_datetime, val_int, val_bool, val_str, val_none])
    text_match = (filter_object, [r'~=', r'MATCHES', r'ICONTAINS'], val_str)
    containment = (filter_object, [r'CONTAINS'], [val_str, val_array])
    subquery_match = (filter_object, [r'MATCHES'], val_subquery)

    return [membership, ordering, equality, text_match, containment, subquery_match]

68 

69 

def orderby_object():
    """Grammar rule for the ``ORDER_BY`` operand: an object name with an optional leading ``-``."""
    return _(r'-?\w+')

71 

72 

def orderby():
    """Grammar rule for the ``ORDER_BY <object>`` clause."""
    clause = ("ORDER_BY", orderby_object)
    return clause

74 

75 

def limit():
    """Grammar rule for the ``LIMIT <integer>`` clause."""
    clause = ("LIMIT", val_int)
    return clause

77 

78 

def factor():
    """Grammar rule for an optionally negated filter or parenthesised expression."""
    operand = [basic_filter, ("(", expression, ")")]
    return (Optional("NOT"), operand)

80 

81 

def expression():
    """Grammar rule for factors chained with ``AND`` / ``OR``.

    The chain may be followed by an optional ``ORDER_BY`` clause and then an
    optional ``LIMIT`` clause.
    """
    chained_factors = ZeroOrMore(["AND", "OR"], factor)
    return (factor, chained_factors, Optional(orderby), Optional(limit))

83 

84 

def query():
    """Top-level grammar rule: an optional expression followed by end of input."""
    return (Optional(expression), EOF)

86 

87 

class QueryVisitor(PTNodeVisitor):
    """Convert a parse tree of the query grammar into a Django ``Q`` object.

    Every ``visit_<rule>`` method receives the already-visited children of the
    matching grammar rule. ``ORDER_BY`` and ``LIMIT`` clauses are not folded
    into the returned ``Q``; they are recorded on the visitor as ``orderby``
    and ``limit``.
    """

    class NoneObject:
        """Stand-in for the ``NONE`` literal (see ``visit_val_none``)."""
        pass

    # Django lookup suffix appended to the database path for each operator.
    _LOOKUP_SUFFIXES = {
        '<=': '__lte',
        '>=': '__gte',
        '<': '__lt',
        '>': '__gt',
        'CONTAINS': '__contains',
        'ICONTAINS': '__icontains',
        'IS IN': '__in',
        'NOT IN': '__in',
        'MATCHES': '__regex',
        '~=': '__regex',
        '=': '__exact',
        '!=': '__exact',
    }

    def __init__(self, model, ignore_fields=None, *arg, **kwargs):
        """
        Args:
            model: Class exposing ``filter_objects_to_db`` (object name -> filter object).
            ignore_fields (list): List of fields whose filter conditions will be ignored during parsing.
                ``None`` (the default) means nothing is ignored.
        """
        self.model = model
        self.orderby = None
        self.limit = None
        self.ignore_db_paths = []
        # `None` replaces the former mutable `[]` default; both mean "ignore nothing"
        for field in ignore_fields or ():
            if obj := self.model.filter_objects_to_db.get(field, {}):
                self.ignore_db_paths.append(obj.db_path)

        super().__init__(*arg, **kwargs)

    def visit_val_none(self, node, children):
        # HACK: I would have rather returned None, but Arpeggio interprets this as
        # a <no match>... Instead, return a NoneObject that will later be converted
        return QueryVisitor.NoneObject()

    def visit_val_int(self, node, children):
        return FilterObjectInteger.parse_value(node.value)

    def visit_val_str(self, node, children):
        if len(children) == 0:
            return ""
        if len(children) > 1:
            raise ValueError("val_str cannot have more than one child")  # pragma: no cover
        return FilterObjectStr.parse_value(children[0])

    def visit_val_bool(self, node, children):
        return FilterObjectBool.parse_value(node.value)

    def visit_val_datetime(self, node, children):
        if len(children) > 1:
            raise ValueError("val_datetime cannot have more than one child")  # pragma: no cover
        return FilterObjectDateTime.parse_value(children[0])

    def visit_val_duration(self, node, children):
        if len(children) > 1:
            raise ValueError("val_duration cannot have more than one child")  # pragma: no cover
        return FilterObjectDuration.parse_value(children[0])

    def visit_val_ago(self, node, children):
        """Return the current time minus the parsed duration."""
        if len(children) > 1:
            raise ValueError("val_ago cannot have more than one child")  # pragma: no cover
        duration = FilterObjectDuration.parse_value(children[0])
        return timezone.now() - duration

    def visit_filter_field(self, node, children):
        # '__' is the separator used in database lookup paths, so it is refused in keys
        if '__' in node.value:
            raise ValueError("Dict object keys cannot contain the substring '__'")

        return node.value

    def visit_filter_object(self, node, children):
        """Resolve an object name (plus optional key) to its filter object.

        Raises:
            ValueError: unknown object, JSON object without a key, or a key
                given to a non-JSON object.
        """
        filter_obj = self.model.filter_objects_to_db.get(children[0])
        if filter_obj is None:
            raise ValueError("The object '{}' does not exist".format(children[0]))

        if isinstance(filter_obj, FilterObjectJSON):
            if len(children) != 2:
                raise ValueError("The dict object '{}' requires a key to access its data".format(children[0]))
            # Build a fresh object bound to the requested key rather than mutating the shared one
            filter_obj = FilterObjectJSON(filter_obj._db_path, filter_obj.description, children[1])
        elif len(children) != 1:
            raise ValueError("The object '{}' cannot have an associated key".format(children[0]))

        return filter_obj

    def visit_val_array(self, node, children):
        return [c for c in children if c != ',']

    def visit_val_subquery(self, node, children):
        """Return the raw text of the sub-query, rebuilt from the node's children."""
        return "".join(str(x.flat_str()) for x in node)

    def get_related_model(self, path: str):
        """Follow a ``__``-separated relation path from ``self.model`` and return the final model."""
        model = self.model
        for field_name in path.split("__"):
            model = getattr(model, field_name).field.remote_field.model
        return model

    def _contains_all(self, db_path, values):
        """Build the ``(key, queryset)`` pair for ``CONTAINS [v1, v2, ...]``.

        Key has to be split into two separate ones to filter by multiple nested
        values in relation "many-to-many" (example: ts_run__machine__tags__name):
         * parent path: path to the key that is queried over (ts_run__machine)
         * subquery key: relative path to child value key (tags__name)

        When the path has no parent part (for example ``tags__name``), the
        visitor's own model is filtered and matched through ``pk__in``. The
        previous code called ``getattr(model, "")`` in that case and failed
        with an AttributeError.
        """
        key_parts = db_path.split("__")
        parent_path = "__".join(key_parts[:-2])
        subquery_key = "__".join(key_parts[-2:])

        if parent_path:
            objects = self.get_related_model(parent_path).objects
            key = parent_path + '__in'
        else:
            objects = self.model.objects
            key = 'pk__in'

        # One chained filter() per value
        for value in values:
            objects = objects.filter(Q(**{f"{subquery_key}__exact": value}))
        return key, objects

    def visit_basic_filter(self, node, children):
        """Translate one ``<object> <operator> <value>`` comparison into a ``Q``.

        Returns an empty ``Q`` when the object is one of the ignored fields.
        ``!=`` and ``NOT IN`` are expressed by negating the resulting ``Q``.
        """
        if len(children) != 3:
            raise ValueError("basic_filter: Invalid amount of operands")  # pragma: no cover

        filter_obj, lookup, item = children

        if filter_obj.db_path in self.ignore_db_paths:
            return Q()

        key = filter_obj.db_path

        if isinstance(filter_obj, FilterObjectModel):
            # The value is a sub-query: match against the objects it selects
            key += '__in'
            item = filter_obj.parse_value(item)
        elif lookup == 'CONTAINS' and isinstance(item, list):
            key, item = self._contains_all(key, item)
        else:
            suffix = self._LOOKUP_SUFFIXES.get(lookup)
            if suffix is None:  # pragma: no cover
                raise ValueError("Unknown lookup '{}'".format(lookup))  # pragma: no cover
            key += suffix

        # HACK: see visit_val_none()
        if isinstance(item, QueryVisitor.NoneObject):
            item = None

        obj = Q(**{key: item})
        if lookup in ['!=', 'NOT IN']:
            return ~obj
        return obj

    def visit_factor(self, node, children):
        if len(children) > 1:
            if children[0] == "NOT":
                return ~children[-1]
        return children[-1]

    def visit_orderby_object(self, node, children):
        """Map ``[-]<object>`` to ``[-]<db path>``, keeping the descending marker."""
        reverse = node.value[0] == '-'

        obj_name = node.value if not reverse else node.value[1:]

        filter_obj = self.model.filter_objects_to_db.get(obj_name)
        if filter_obj is None:
            raise ValueError("The object '{}' does not exist".format(obj_name))
        return "{}{}".format("-" if reverse else "", filter_obj.db_path)

    def visit_orderby(self, node, children):
        if len(children) == 1:
            self.orderby = children[0]
        else:
            raise ValueError("orderby: Invalid amount of operands")  # pragma: no cover

    def visit_limit(self, node, children):
        if len(children) == 1:
            if children[0] < 0:
                raise ValueError("Negative limits are not supported")

            self.limit = children[0]
        else:
            raise ValueError("limit: Invalid amount of operands")  # pragma: no cover

    def visit_expression(self, node, children):
        """Fold ``factor (AND|OR factor)*`` left to right into a single ``Q``."""
        if len(children) >= 1:
            qResult = children[0]
            # Operators sit at the odd indexes, operands at the even ones
            for i in range(2, len(children), 2):
                if children[i-1] == "AND":
                    qResult &= children[i]
                elif children[i-1] == "OR":
                    qResult |= children[i]
            return qResult

    def visit_query(self, node, children):
        if len(children) > 1:
            raise ValueError("query cannot have more than one child")  # pragma: no cover
        elif len(children) == 1:
            return children[0]
        else:
            return Q()

284 

285 

class QueryParser:
    """Parse a user query string for ``model`` and expose the matching objects.

    Parsing happens in the constructor. On failure the message is stored in
    ``error`` and the parser is marked invalid; no exception escapes for
    grammar mismatches (``NoMatch``) or ``ValueError`` raised while visiting.
    """

    def __init__(self, model, user_query, ignore_fields: "list[str] | None" = None):
        """
        Args:
            model: Class queried; passed on to ``QueryVisitor``.
            user_query: Query text written in the query language.
            ignore_fields: Names of fields whose conditions are dropped while
                parsing. ``None`` (the default) means nothing is ignored.
        """
        self.model = model
        self.user_query = user_query

        self.error = None
        self.q_objects = Q()
        self.orderby = None
        self.limit = None

        # `None` replaces the former mutable `[]` default argument
        if ignore_fields is None:
            ignore_fields = []

        try:
            parser = ParserPython(query)
            parse_tree = parser.parse(self.user_query)
            query_visitor = QueryVisitor(self.model, ignore_fields=ignore_fields)

            self.q_objects = visit_parse_tree(parse_tree, query_visitor)
            self.orderby = query_visitor.orderby
            self.limit = query_visitor.limit
        except (ValueError, NoMatch) as e:
            self.error = str(e)

    @property
    def query_key(self):
        """Shorthand returned by ``Shortener.get_or_create`` for the query text."""
        return Shortener.get_or_create(self.user_query).shorthand

    @property
    def is_valid(self):
        """True when parsing recorded no error."""
        return self.error is None

    @property
    def is_empty(self):
        """True when the query is invalid or its text is empty."""
        return not self.is_valid or len(self.user_query) == 0

    @cached_property
    def objects(self):
        """Distinct objects matching the query, or an empty queryset when it is invalid.

        The ``ORDER_BY`` and ``LIMIT`` clauses are applied when present.
        """
        if not self.is_valid:
            return self.model.objects.none()

        query = self.model.objects.filter(self.q_objects).distinct()
        if self.orderby is not None:
            query = query.order_by(self.orderby)
        if self.limit is not None:
            return query[:self.limit]
        return query

329 

330 

class LegacyParser:
    """Translate legacy keyword filters into a query-language string.

    Keys of the form ``(only|exclude)__<object>__<lookup>`` are converted into
    basic filters. The ``only`` filters are joined with ``AND``; the
    ``exclude`` filters are joined with ``AND`` inside one ``NOT (...)`` group.
    The resulting text is stored in ``self.query``.
    """

    # Legacy lookup name -> operator of the query language
    userfilters_allowed_lookups = {'exact': '=', 'in': 'IS IN', 'regex': '~=', 'contains': 'CONTAINS',
                                   'icontains': 'ICONTAINS', 'gt': '>', 'gte': '>=', 'lt': '<', 'lte': '<='}
    # Type wrappers recognised in values, as in `int(5)`
    userfilters_allowed_types = ['str', 'int', 'bool', 'datetime', 'duration']

    def __init__(self, model, **user_filters):
        """
        Args:
            model: Class exposing ``filter_objects_to_db``; filters naming an
                object that is not in it are skipped.
            **user_filters: Legacy filters. Keys that do not follow the
                expected format are ignored.
        """
        # Filters should all be of the following format:
        # (only|exclude)__(object)__(in|regex|gt|lt) = str or format(value)
        lookups = "|".join(self.userfilters_allowed_lookups.keys())
        format_re = re.compile((r'(?P<action>(only|exclude))__(?P<object>\w+)__'
                                '(?P<lookup>({lookups}))'.format(lookups=lookups)))

        # Iterate through the user filters, match them to our format regex,
        # then construct the right ORM call
        only = []
        exclude = []
        for key, item in user_filters.items():
            # fullmatch() is required: with match(), the alternation stops at the
            # first lookup that is a prefix of the key, so `gte` was read as `gt`
            # and `lte` as `lt`.
            match = format_re.fullmatch(key)
            if not match:
                continue
            fields = match.groupdict()

            db_object = model.filter_objects_to_db.get(fields['object'])
            if db_object is None:
                continue

            # aggregate all regular expressions into one request
            if fields['lookup'] == 'regex' and isinstance(item, list) and len(item) > 1:
                item = r'('+'|'.join(item)+')'

            # Try converting the item to the right unit
            item = self._convert_user_values(item)

            bfilter = "{} {} {}".format(fields['object'],
                                        self.userfilters_allowed_lookups.get(fields['lookup']),
                                        item)

            if fields['action'] == 'only':
                only.append(bfilter)
            else:
                exclude.append(bfilter)

        self.query = " AND ".join(only)
        if len(exclude) > 0:
            if len(only) > 0:
                self.query += ' AND '
            self.query += "NOT ({})".format(" AND ".join(exclude))

    @classmethod
    def _convert_user_value(cls, value):
        """Convert one legacy value, such as ``int(5)`` or ``str(a)``, to a literal.

        Values without a recognised ``type(...)`` wrapper are treated as
        strings and wrapped in single quotes.
        """
        # Will automatically be cached by python
        types = "|".join(cls.userfilters_allowed_types)
        item_re = re.compile(r'(?P<type>({types}))\((?P<value>.*)\)'.format(types=types))

        match = item_re.match(value)
        if match:
            fields = match.groupdict()

            try:
                if fields['type'] == 'str':
                    return "'{}'".format(fields['value'])
                elif fields['type'] == 'bool':
                    return "TRUE" if FilterObjectBool.parse_value(fields['value']) else "FALSE"
                elif fields['type'] == 'int':
                    return fields['value']
                elif fields['type'] == 'datetime' or fields['type'] == 'duration':
                    # Already written the way the query language expects
                    return value
            except Exception:                # pragma: no cover
                traceback.print_exc()        # pragma: no cover

        # Default to the variable being a string
        return "'" + value + "'"

    @classmethod
    def _convert_user_values(cls, items):
        """Convert a single value, or a list of values, to query-language text.

        A list of several values becomes an array literal; a one-element list
        is converted like a single value.
        """
        # detect whether we have a singular value or a list
        if not isinstance(items, list):
            return cls._convert_user_value(items)
        if len(items) > 1:
            return "[" + ", ".join(cls._convert_user_value(item) for item in items) + "]"
        return cls._convert_user_value(items[0])

416 

417 

class UserFiltrableMixin:
    """Mixin adding ``from_user_filters`` to build a ``QueryParser`` from request-style parameters."""

    @classmethod
    def _get_value_from_params(cls, user_filters, key):
        """Return ``user_filters[key]``, unwrapping a list that holds exactly one element.

        Missing keys yield ``None``.
        """
        value = user_filters.get(key)
        is_single_item_list = isinstance(value, list) and len(value) == 1
        return value[0] if is_single_item_list else value

    @classmethod
    def from_user_filters(cls, prefix=None, **user_filters):
        """Build a ``QueryParser`` for this class from user-supplied filters.

        The query text is looked up in this order:
          1. the ``query`` parameter (``<prefix>_query`` when a prefix is given);
          2. the full text of the ``Shortener`` whose shorthand equals the
             ``<query parameter>_key`` parameter;
          3. the text generated by ``LegacyParser`` from the filters.
        """
        query_param_name = 'query' if prefix is None else f'{prefix}_query'

        query = cls._get_value_from_params(user_filters, query_param_name)
        if query is None:
            query_key = cls._get_value_from_params(user_filters, f'{query_param_name}_key')
            short = Shortener.objects.filter(shorthand=query_key).first()
            if short is not None:
                query = short.full

        if query is None:
            query = LegacyParser(cls, **user_filters).query
        return QueryParser(cls, query)

441 

442 

class FilterObject:
    """Base description of a filterable object: a database path plus optional description."""

    def __init__(self, db_path, description=None):
        self._db_path, self._description = db_path, description

    @property
    def db_path(self):
        """Database path given at construction."""
        return self._db_path

    @property
    def description(self):
        """Description text, or a placeholder when none was provided."""
        return "<no description yet>" if self._description is None else self._description

458 

459 

class FilterObjectJSON(FilterObject):
    """Filter object for a dict-like field that is addressed as ``<field>.<key>``."""

    data_type = "anything"
    documentation = "Expected format: <JSON field>.<key>"
    test_value = "test"

    def __init__(self, db_path, description=None, key=None):
        self.key = key
        super().__init__(db_path, description)

    @property
    def db_path(self):
        """Database path with the key appended as ``<path>__<key>``.

        Raises:
            ValueError: no key has been set.
        """
        if self.key is not None:
            return "{}__{}".format(self._db_path, self.key)
        raise ValueError("Dict field require a key to be accessed")  # pragma: no cover

474 

475 

class FilterObjectStr(FilterObject):
    """Filter object whose values are strings."""

    data_type = "string"
    documentation = (
        "Expected format: anything. Use quotes for the new query language (\"\" or ''). "
        "Escape quotes by placing '\\' before quote character."
    )
    test_value = "str_test"

    def __init__(self, db_path, description=None):
        super().__init__(db_path, description)

    @classmethod
    def parse_value(cls, value):
        """Return ``value`` converted with ``str()``."""
        return str(value)

488 

489 

class FilterObjectDateTime(FilterObject):
    """Filter object whose values are timezone-aware datetimes."""

    data_type = "datetime"
    documentation = "Expected format: datetime(YYYY-MM-DD HH:MM[:ss[.uuuuuu]][TZ])"
    test_value = "2019-01-01"

    def __init__(self, db_path, description=None):
        super().__init__(db_path, description)

    @classmethod
    def parse_value(cls, value):
        """Parse ``value`` with dateutil and return an aware datetime.

        A value without timezone information is interpreted as UTC. A value
        that carries its own timezone (the ``[TZ]`` part of the documented
        format) is returned as parsed. ``make_aware`` rejects datetimes that
        are already aware, so every such value used to fail with a ValueError.
        """
        parsed = datetimeparser.parse(value)
        if timezone.is_aware(parsed):
            return parsed
        return timezone.make_aware(parsed, pytz.utc)

501 

502 

class FilterObjectDuration(FilterObject):
    """Filter object whose values are durations."""

    data_type = "duration"
    documentation = (
        'Expected format: "duration(DD HH:MM:SS.uuuuuu)", or "duration(P4DT1H15M20S)" (ISO 8601), '
        'or "duration(3 days 04:05:06)" (PostgreSQL).'
    )
    test_value = "123.456 seconds"

    def __init__(self, db_path, description=None):
        super().__init__(db_path, description)

    @classmethod
    def parse_value(cls, value):
        """Parse ``value`` with Django's ``parse_duration``.

        Raises:
            ValueError: ``parse_duration`` returned ``None`` for the value.
        """
        duration = parse_duration(value)
        if duration is not None:
            return duration
        raise ValueError("The value '{}' does not represent a duration. {}".format(value, cls.documentation))

518 

519 

class FilterObjectBool(FilterObject):
    """Filter object whose values are booleans."""

    data_type = "boolean"
    documentation = (
        "Supported values: bool(false)/bool(0) or bool(true)/bool(1). "
        "Use TRUE or FALSE for the new query language."
    )
    test_value = "True"

    def __init__(self, db_path, description=None):
        super().__init__(db_path, description)

    @classmethod
    def parse_value(cls, value):
        """Return True when ``str(value)``, lower-cased, is ``"1"`` or ``"true"``; False otherwise."""
        normalised = str(value).lower()
        return normalised == "1" or normalised == "true"

532 

533 

class FilterObjectInteger(FilterObject):
    """Filter object whose values are integers."""

    data_type = "integer"
    documentation = "Supported values: int(12345). Use 12345 for the new query language."
    test_value = 12345

    def __init__(self, db_path, description=None):
        super().__init__(db_path, description)

    @classmethod
    def parse_value(cls, value):
        """Convert ``value`` to an ``int``.

        Plain integers are converted exactly. The previous
        ``int(float(value))`` round-trip silently changed integers larger than
        2**53. Anything ``int()`` rejects (for example ``"12.7"`` or ``"1e3"``)
        still goes through ``float`` and is truncated, as before.

        Raises:
            ValueError: the value is neither an integer nor a float.
        """
        try:
            return int(value)
        except ValueError:
            return int(float(value))

545 

546 

class FilterObjectModel(FilterObject):
    """Filter object whose value is a sub-query evaluated against another model."""

    data_type = "subquery"
    documentation = "Expected format: Any query compatible with the model selected"

    def __init__(self, model, db_path, description=None):
        self.model = model
        super().__init__(db_path, description)

    def parse_value(self, value):
        """Parse ``value`` as a query on ``self.model`` and return the matching objects.

        Raises:
            ValueError: the sub-query is invalid; the message is the parser's error.
        """
        subquery = QueryParser(self.model, value)
        if subquery.is_valid:
            return subquery.objects
        raise ValueError(subquery.error)

560 

561 

class QueryCreator:
    """Build ``Model`` queries from an HTTP request or from a query string."""

    def __init__(self, request, Model, prefix=None, default_query_parameters=None):
        """
        Args:
            request: Object exposing ``GET`` and ``POST`` parameter mappings.
                It also receives the error messages for invalid queries.
            Model: Class exposing ``from_user_filters`` and ``filter_objects_to_db``.
            prefix: Optional prefix of the query parameter names.
            default_query_parameters: Filters used when the request holds no
                query. ``None`` (the default) means no filters.
        """
        self.request = request
        self.Model = Model
        self.prefix = prefix
        # A fresh dict per instance replaces the former shared mutable `{}` default
        self.default_query_parameters = {} if default_query_parameters is None else default_query_parameters

    def __create_query_from_filters(self, **requested_filters):
        """Return the query built from the filters, or ``None`` when its text is empty.

        An invalid query is still returned; its error is also reported through
        ``messages`` when a request is available.
        """
        query = self.Model.from_user_filters(self.prefix, **requested_filters)
        if len(query.user_query) == 0:
            return None

        if not query.is_valid and self.request:
            messages.error(self.request, "Filtering error: " + query.error)
        return query

    def __build_query_string(self):
        """Join one ``<object> MATCHES '<value>'`` clause per GET parameter with ``AND``.

        Only parameters named after a filter object (with the prefix, if any)
        and holding a non-empty value are used. Objects whose data type has no
        operator mapping are skipped; indexing the mapping directly used to
        raise a KeyError for them.
        """
        op_mappings = {
            'string': 'MATCHES',
            'datetime': 'MATCHES',
            'integer': 'MATCHES',
        }
        clauses = []
        for obj in self.Model.filter_objects_to_db:
            param_value = self.request.GET.get(f'{self.prefix}_{obj}' if self.prefix else obj)
            if not param_value:
                continue

            data_type = self.Model.filter_objects_to_db[obj].data_type
            operator = op_mappings.get(data_type)
            if operator is None:
                continue
            clauses.append(f"{obj} {operator} '{param_value}'")
        return " AND ".join(clauses)

    def string_to_query(self, query_string):
        """Build a query from ``query_string``, falling back to the default parameters when it is empty."""
        query_param_name = f'{self.prefix}_query' if self.prefix else 'query'
        query_dict = QueryDict('', mutable=True)
        query_dict.update({f'{query_param_name}': query_string})
        query = self.__create_query_from_filters(**query_dict)
        if query:
            return query
        return self.Model.from_user_filters(**self.default_query_parameters)

    def request_to_query(self):
        """Build a query from the request, checking POST before GET.

        Falls back to the default parameters when neither holds a query.
        """
        for params in [self.request.POST, self.request.GET]:
            # convert the user filters to a normal dictionary to prevent issues when
            # inserting new values
            requested_filters = params.copy()
            query = self.__create_query_from_filters(**requested_filters)
            if query:
                return query

        return self.Model.from_user_filters(**self.default_query_parameters)

    def multiple_request_params_to_query(self):
        """Build a query from individual GET parameters named after filter objects."""
        query_str = self.__build_query_string()
        return self.string_to_query(query_str)