diff --git a/tests/event.py b/tests/event.py index 7fee2cb..c946500 100644 --- a/tests/event.py +++ b/tests/event.py @@ -152,22 +152,41 @@ def container_id(self) -> str: def loginuid(self) -> int: return self._loginuid - @override - def __eq__(self, other: Any) -> bool: - if isinstance(other, ProcessSignal): - if self.pid is not None and self.pid != other.pid: - return False - - return ( - self.uid == other.uid and - self.gid == other.gid and - self.exe_path == other.exec_file_path and - self.args == other.args and - self.name == other.name and - self.container_id == other.container_id and - self.loginuid == other.login_uid - ) - raise NotImplementedError + def diff(self, other: ProcessSignal) -> dict | None: + """ + Compare this Process with a ProcessSignal protobuf message. + + Args: + other: ProcessSignal protobuf message to compare against + + Returns: + None if identical, dict of differences if not matching + + Raises: + NotImplementedError: If other is not a ProcessSignal + """ + if not isinstance(other, ProcessSignal): + raise NotImplementedError( + f'Cannot compare Process with {type(other)}') + + diff = {} + + # Compare each field + if self.pid is not None: + Event._diff_field(diff, 'pid', self.pid, other.pid) + + Event._diff_field(diff, 'uid', self.uid, other.uid) + Event._diff_field(diff, 'gid', self.gid, other.gid) + Event._diff_field(diff, 'exe_path', + self.exe_path, other.exec_file_path) + Event._diff_field(diff, 'args', self.args, other.args) + Event._diff_field(diff, 'name', self.name, other.name) + Event._diff_field(diff, 'container_id', + self.container_id, other.container_id) + Event._diff_field(diff, 'loginuid', + self.loginuid, other.login_uid) + + return diff if diff else None @override def __str__(self) -> str: @@ -177,11 +196,25 @@ def __str__(self) -> str: f'loginuid={self.loginuid})') -def cmp_path(p1: str | Pattern[str], p2: str) -> bool: - if isinstance(p1, Pattern): - return bool(p1.match(p2)) - else: - return p1 == p2 +def _diff_path(expected: str | Pattern[str], actual: str) -> dict | None: + """ + Compare paths with regex pattern support. + + Returns: + (field_name, diff_dict) if paths don't match, None if they match + """ + if isinstance(expected, Pattern): + if not expected.match(actual): + return { + 'expected': f'{expected}', + 'actual': actual + } + elif expected != actual: + return { + 'expected': expected, + 'actual': actual + } + return None class Event: @@ -234,32 +267,70 @@ def owner_uid(self) -> int | None: def owner_gid(self) -> int | None: return self._owner_gid - @override - def __eq__(self, other: Any) -> bool: - if isinstance(other, FileActivity): - if self.process != other.process or self.event_type.name.lower() != other.WhichOneof('file'): - return False - - if self.event_type == EventType.CREATION: - return cmp_path(self.file, other.creation.activity.path) and \ - cmp_path(self.host_path, other.creation.activity.host_path) - elif self.event_type == EventType.OPEN: - return cmp_path(self.file, other.open.activity.path) and \ - cmp_path(self.host_path, other.open.activity.host_path) - elif self.event_type == EventType.UNLINK: - return cmp_path(self.file, other.unlink.activity.path) and \ - cmp_path(self.host_path, other.unlink.activity.host_path) - elif self.event_type == EventType.PERMISSION: - return cmp_path(self.file, other.permission.activity.path) and \ - cmp_path(self.host_path, other.permission.activity.host_path) and \ - self.mode == other.permission.mode - elif self.event_type == EventType.OWNERSHIP: - return cmp_path(self.file, other.ownership.activity.path) and \ - cmp_path(self.host_path, other.ownership.activity.host_path) and \ - self.owner_uid == other.ownership.uid and \ - self.owner_gid == other.ownership.gid - return False - raise NotImplementedError + @classmethod + def _diff_field(cls, diff, name, expected, actual): + if expected != actual: + diff[name] = { + 'expected': expected, + 'actual': actual, + } + + def diff(self, other: FileActivity) -> dict | None: + """ + Compare this Event with a FileActivity protobuf message. + + Args: + other: FileActivity protobuf message to compare against + + Returns: + None if identical, dict of differences if not matching + + Raises: + NotImplementedError: If other is not a FileActivity + """ + if not isinstance(other, FileActivity): + raise NotImplementedError( + f'Cannot compare Event with {type(other)}') + + diff = {} + + # Check process differences first + process_diff = self.process.diff(other.process) + if process_diff is not None: + diff['process'] = process_diff + + # Check event type + event_type_expected = self.event_type.name.lower() + event_type_actual = other.WhichOneof('file') + + Event._diff_field(diff, 'event_type', + event_type_expected, event_type_actual) + if event_type_expected != event_type_actual: + return diff + + # Get the appropriate event field based on type + event_field = getattr(other, event_type_expected) + + # Compare file and host_path (common to all event types) + # All event types have .activity.path and .activity.host_path except they're accessed differently + file_diff = _diff_path(self.file, event_field.activity.path) + if file_diff is not None: + diff['file'] = file_diff + + host_path_diff = _diff_path(self.host_path, + event_field.activity.host_path) + if host_path_diff is not None: + diff['host_path'] = host_path_diff + + if self.event_type == EventType.PERMISSION: + Event._diff_field(diff, 'mode', self.mode, event_field.mode) + elif self.event_type == EventType.OWNERSHIP: + Event._diff_field(diff, 'owner_uid', + self.owner_uid, event_field.uid) + Event._diff_field(diff, 'owner_gid', + self.owner_gid, event_field.gid) + + return diff if diff else None @override def __str__(self) -> str: diff --git a/tests/server.py b/tests/server.py index d3aa8d1..004fa8e 100644 --- a/tests/server.py +++ b/tests/server.py @@ -1,9 +1,9 @@ from concurrent import futures from collections import deque +import json from threading import Event from time import sleep -from google.protobuf.json_format import MessageToJson import grpc from internalapi.sensor import sfa_iservice_pb2_grpc @@ -89,12 +89,15 @@ def _wait_events(self, events: list[Event], strict: bool): continue print(f'Got event: {msg}') - if msg in events: - events.remove(msg) + + # Check if msg matches the next expected event + diff = events[0].diff(msg) + if diff is None: + events.pop(0) if len(events) == 0: - break + return elif strict: - raise ValueError(f'Encountered unexpected event: {msg}') + raise ValueError(json.dumps(diff, indent=4)) def wait_events(self, events: list[Event], strict: bool = True): """