Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 118 additions & 47 deletions tests/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,22 +152,41 @@ def container_id(self) -> str:
def loginuid(self) -> int:
return self._loginuid

@override
def __eq__(self, other: Any) -> bool:
if isinstance(other, ProcessSignal):
if self.pid is not None and self.pid != other.pid:
return False

return (
self.uid == other.uid and
self.gid == other.gid and
self.exe_path == other.exec_file_path and
self.args == other.args and
self.name == other.name and
self.container_id == other.container_id and
self.loginuid == other.login_uid
)
raise NotImplementedError
def diff(self, other: ProcessSignal) -> dict | None:
"""
Compare this Process with a ProcessSignal protobuf message.

Args:
other: ProcessSignal protobuf message to compare against

Returns:
None if identical, dict of differences if not matching

Raises:
NotImplementedError: If other is not a ProcessSignal
"""
if not isinstance(other, ProcessSignal):
raise NotImplementedError(
f'Cannot compare Process with {type(other)}')

diff = {}

# Compare each field
if self.pid is not None:
Event._diff_field(diff, 'pid', self.pid, other.pid)

Event._diff_field(diff, 'uid', self.uid, other.uid)
Event._diff_field(diff, 'gid', self.gid, other.gid)
Event._diff_field(diff, 'exe_path',
self.exe_path, other.exec_file_path)
Event._diff_field(diff, 'args', self.args, other.args)
Event._diff_field(diff, 'name', self.name, other.name)
Event._diff_field(diff, 'container_id',
self.container_id, other.container_id)
Event._diff_field(diff, 'loginuid',
self.loginuid, other.login_uid)

return diff if diff else None

@override
def __str__(self) -> str:
Expand All @@ -177,11 +196,25 @@ def __str__(self) -> str:
f'loginuid={self.loginuid})')


def cmp_path(p1: str | Pattern[str], p2: str) -> bool:
if isinstance(p1, Pattern):
return bool(p1.match(p2))
else:
return p1 == p2
def _diff_path(expected: str | Pattern[str], actual: str) -> dict | None:
"""
Compare paths with regex pattern support.

Returns:
(field_name, diff_dict) if paths don't match, None if they match
"""
if isinstance(expected, Pattern):
if not expected.match(actual):
return {
'expected': f'{expected}',
'actual': actual
}
elif expected != actual:
return {
'expected': expected,
'actual': actual
}
return None


class Event:
Expand Down Expand Up @@ -234,32 +267,70 @@ def owner_uid(self) -> int | None:
def owner_gid(self) -> int | None:
return self._owner_gid

@override
def __eq__(self, other: Any) -> bool:
if isinstance(other, FileActivity):
if self.process != other.process or self.event_type.name.lower() != other.WhichOneof('file'):
return False

if self.event_type == EventType.CREATION:
return cmp_path(self.file, other.creation.activity.path) and \
cmp_path(self.host_path, other.creation.activity.host_path)
elif self.event_type == EventType.OPEN:
return cmp_path(self.file, other.open.activity.path) and \
cmp_path(self.host_path, other.open.activity.host_path)
elif self.event_type == EventType.UNLINK:
return cmp_path(self.file, other.unlink.activity.path) and \
cmp_path(self.host_path, other.unlink.activity.host_path)
elif self.event_type == EventType.PERMISSION:
return cmp_path(self.file, other.permission.activity.path) and \
cmp_path(self.host_path, other.permission.activity.host_path) and \
self.mode == other.permission.mode
elif self.event_type == EventType.OWNERSHIP:
return cmp_path(self.file, other.ownership.activity.path) and \
cmp_path(self.host_path, other.ownership.activity.host_path) and \
self.owner_uid == other.ownership.uid and \
self.owner_gid == other.ownership.gid
return False
raise NotImplementedError
@classmethod
def _diff_field(cls, diff, name, expected, actual):
if expected != actual:
diff[name] = {
'expected': expected,
'actual': actual,
}

def diff(self, other: FileActivity) -> dict | None:
"""
Compare this Event with a FileActivity protobuf message.

Args:
other: FileActivity protobuf message to compare against

Returns:
None if identical, dict of differences if not matching

Raises:
NotImplementedError: If other is not a FileActivity
"""
if not isinstance(other, FileActivity):
raise NotImplementedError(
f'Cannot compare Event with {type(other)}')

diff = {}

# Check process differences first
process_diff = self.process.diff(other.process)
if process_diff is not None:
diff['process'] = process_diff

# Check event type
event_type_expected = self.event_type.name.lower()
event_type_actual = other.WhichOneof('file')

Event._diff_field(diff, 'event_type',
event_type_expected, event_type_actual)
if event_type_expected != event_type_actual:
return diff

# Get the appropriate event field based on type
event_field = getattr(other, event_type_expected)

# Compare file and host_path (common to all event types)
# All event types have .activity.path and .activity.host_path except they're accessed differently
file_diff = _diff_path(self.file, event_field.activity.path)
if file_diff is not None:
diff['file'] = file_diff

host_path_diff = _diff_path(self.host_path,
event_field.activity.host_path)
if host_path_diff is not None:
diff['host_path'] = host_path_diff

if self.event_type == EventType.PERMISSION:
Event._diff_field(diff, 'mode', self.mode, event_field.mode)
elif self.event_type == EventType.OWNERSHIP:
Event._diff_field(diff, 'owner_uid',
self.owner_uid, event_field.uid)
Event._diff_field(diff, 'owner_gid',
self.owner_gid, event_field.gid)

return diff if diff else None

@override
def __str__(self) -> str:
Expand Down
13 changes: 8 additions & 5 deletions tests/server.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from concurrent import futures
from collections import deque
import json
from threading import Event
from time import sleep

from google.protobuf.json_format import MessageToJson
import grpc

from internalapi.sensor import sfa_iservice_pb2_grpc
Expand Down Expand Up @@ -89,12 +89,15 @@ def _wait_events(self, events: list[Event], strict: bool):
continue

print(f'Got event: {msg}')
if msg in events:
events.remove(msg)

# Check if msg matches the next expected event
diff = events[0].diff(msg)
if diff is None:
events.pop(0)
if len(events) == 0:
break
return
elif strict:
raise ValueError(f'Encountered unexpected event: {msg}')
raise ValueError(json.dumps(diff, indent=4))

def wait_events(self, events: list[Event], strict: bool = True):
"""
Expand Down