# filters.py
import logging
import types

from .common import MatchSeries, ALLOWED_NAMES, ALLOWED_GLOBALS
  3. logger = logging.getLogger(__name__)
  4. """
  5. This middleware filters the measurements by series and yields if any given field matches.
  6. """
  7. class MatchAny(MatchSeries):
  8. def __init__(self, parent, series, **kwargs) -> None:
  9. super().__init__(series)
  10. self._fields = kwargs
  11. def execute(self, values):
  12. for measurement in values:
  13. dataset = self.get_series(measurement)
  14. if not dataset:
  15. continue
  16. if not self._fields:
  17. yield measurement
  18. continue
  19. # check if any field matches
  20. for field, value in self._fields.items():
  21. v = getattr(dataset, field, None)
  22. if v == value or (isinstance(v, tuple) and value in v):
  23. yield measurement
  24. break
  25. """
  26. This middleware filters the measurements by series and yields if all given fields match.
  27. """
  28. class MatchAll(MatchSeries):
  29. def __init__(self, parent, series, **kwargs) -> None:
  30. super().__init__(series)
  31. self._fields = kwargs
  32. def execute(self, values):
  33. for measurement in values:
  34. dataset = self.get_series(measurement)
  35. if not dataset:
  36. continue
  37. # check if all fields match
  38. success = True
  39. for field, value in self._fields.items():
  40. v = getattr(dataset, field, None)
  41. if (not isinstance(v, tuple) and v != value) or (isinstance(v, tuple) and not all(x == value for x in v)):
  42. success = False
  43. break
  44. if success:
  45. yield measurement
  46. class ComplexFilter():
  47. def __init__(self, parent, predicate) -> None:
  48. self._predicate = predicate
  49. self._compiled = compile(predicate, "<string>", "eval")
  50. # Validate allowed names
  51. for name in self._compiled.co_names:
  52. if name not in ALLOWED_NAMES:
  53. raise NameError(f"The use of '{name}' is not allowed in '{predicate}'")
  54. def execute(self, values):
  55. for measurement in values:
  56. try:
  57. if eval(self._compiled, {"__builtins__": ALLOWED_GLOBALS}, measurement.__dict__):
  58. yield measurement
  59. except Exception as e:
  60. logger.error(f"Error while evaluating predicate '{self._predicate}': {e}")