Plumbed in the ContainerManager to the state. #949

Merged (2 commits) on Jan 14, 2025
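At a glance, the change moves module code from calling the state directly (self.state.StoreContainer, self.state.GetContainers) onto the wrappers on the module base class, which forward the module's runtime name to the state. A minimal, self-contained sketch of that call pattern, using stand-in FakeState and FakeModule classes rather than the real dftimewolf ones:

```python
# Illustration only: FakeState and FakeModule are stand-ins, not dftimewolf
# classes. They mirror the call pattern this PR moves modules onto.
from typing import Any, Dict, List, Tuple, Type


class FakeState:
  """Stand-in for the state: stores containers with source attribution."""

  def __init__(self) -> None:
    self.store: Dict[str, List[Tuple[str, Any]]] = {}

  def StoreContainer(self, container: Any, source_module: str) -> None:
    # The state records which module produced the container.
    self.store.setdefault(type(container).__name__, []).append(
        (source_module, container))

  def GetContainers(
      self, requesting_module: str, container_class: Type[Any]) -> List[Any]:
    # requesting_module lets a container manager scope results per module.
    del requesting_module  # Unused in this toy version.
    return [c for _, c in self.store.get(container_class.__name__, [])]


class FakeModule:
  """Stand-in for the module base class: thin wrappers around the state."""

  def __init__(self, state: FakeState, name: str) -> None:
    self.state = state
    self.name = name

  # Module code now calls these instead of self.state.* directly, so the
  # module's own name travels with every store and retrieval.
  def StoreContainer(self, container: Any) -> None:
    self.state.StoreContainer(container, source_module=self.name)

  def GetContainers(self, container_class: Type[Any]) -> List[Any]:
    return self.state.GetContainers(self.name, container_class)
```

The diffs in grr_hosts.py and df_to_filesystem.py below are the call-site side of this switch; module.py and state.py carry the plumbing.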
6 changes: 3 additions & 3 deletions dftimewolf/lib/collectors/grr_hosts.py
@@ -1356,7 +1356,7 @@ def _ProcessQuery(
data_frame=pd.DataFrame(),
flow_identifier=flow_identifier,
client_identifier=client_identifier)
self.state.StoreContainer(results_container)
self.StoreContainer(results_container)
return

merged_results = pd.concat(results)
@@ -1372,7 +1372,7 @@ def _ProcessQuery(
flow_identifier=flow_identifier,
client_identifier=client_identifier)

self.state.StoreContainer(dataframe_container)
self.StoreContainer(dataframe_container)

def Process(self, container: containers.Host
) -> None: # pytype: disable=signature-mismatch
@@ -1383,7 +1383,7 @@ def Process(self, container: containers.Host
"""
client = self._GetClientBySelector(container.hostname)

osquery_containers = self.state.GetContainers(containers.OsqueryQuery)
osquery_containers = self.GetContainers(containers.OsqueryQuery)

host_osquery_futures = []
with ThreadPoolExecutor(self.GetQueryThreadPoolSize()) as executor:
4 changes: 2 additions & 2 deletions dftimewolf/lib/exporters/df_to_filesystem.py
@@ -89,7 +89,7 @@ def SetUp(self, output_formats: str, output_directory: str) -> None:

def Process(self) -> None:
"""Perform the exports."""
to_export = self.state.GetContainers(containers.DataFrame)
to_export = self.GetContainers(containers.DataFrame)

for df in to_export:
self._ExportSingleContainer(df)
@@ -133,7 +133,7 @@ def _ExportSingleContainer(self, container: containers.DataFrame) -> None:
output_format=f,
output_path=output_path)

self.state.StoreContainer(container=containers.File(
self.StoreContainer(container=containers.File(
name=os.path.basename(output_path),
path=output_path,
description=container.description))
3 changes: 2 additions & 1 deletion dftimewolf/lib/module.py
@@ -208,7 +208,8 @@ def GetContainers(self,
Raises:
RuntimeError: If only one metadata filter parameter is specified.
"""
containers = self.state.GetContainers(container_class,
containers = self.state.GetContainers(self.name,
container_class,
pop,
metadata_filter_key,
metadata_filter_value)
4 changes: 2 additions & 2 deletions dftimewolf/lib/processors/openrelik.py
@@ -128,8 +128,8 @@ def DownloadWorkflowOutput(self, file_id: int, filename: str) -> str:
file.close()
return local_path

# pytype: disable=signature-mismatch
def Process(self, container: containers.File) -> None:
def Process(self, container: containers.File
) -> None: # pytype: disable=signature-mismatch
file_ids = []
folder_id = self.folder_id
if not folder_id or not self.openrelik_folder_client.folder_exists(
87 changes: 77 additions & 10 deletions dftimewolf/lib/state.py
@@ -5,6 +5,7 @@
"""

from concurrent.futures import ThreadPoolExecutor, Future
import json
import importlib
import logging
import time
@@ -17,6 +18,7 @@
from dftimewolf.lib import errors, utils
from dftimewolf.lib import telemetry
from dftimewolf.lib.containers import interface
from dftimewolf.lib.containers import manager as container_manager
from dftimewolf.lib.containers.interface import AttributeContainer
from dftimewolf.lib.errors import DFTimewolfError
from dftimewolf.lib.modules import manager as modules_manager
@@ -65,6 +67,7 @@ def __init__(self, config: Type[Config]) -> None:
self.global_errors = [] # type: List[DFTimewolfError]
self.recipe = {} # type: Dict[str, Any]
self.store = {} # type: Dict[str, List[interface.AttributeContainer]]
self._container_manager = container_manager.ContainerManager()  # Run simultaneously with the legacy store while ensuring this change does not break anything  # pylint: disable=line-too-long
self.streaming_callbacks = {} # type: Dict[Type[interface.AttributeContainer], List[Callable[[Any], Any]]] # pylint: disable=line-too-long
self._abort_execution = False
self.stdout_log = True
@@ -137,6 +140,7 @@ def LoadRecipe(
module_definitions = recipe.get('modules', [])
preflight_definitions = recipe.get('preflights', [])
self.ImportRecipeModules(module_locations)
self._container_manager.ParseRecipe(recipe)

for module_definition in module_definitions + preflight_definitions:
# Combine CLI args with args from the recipe description
@@ -222,7 +226,7 @@ def GetFromCache(self, name: str, default_value: Any = None) -> Any:
def StoreContainer(
self,
container: "interface.AttributeContainer",
source_module: str = "") -> None:
source_module: str) -> None:
"""Thread-safe method to store data in the state's store.

Args:
@@ -233,6 +237,9 @@ def StoreContainer(
container.SetMetadata(interface.METADATA_KEY_SOURCE_MODULE, source_module)
self.store.setdefault(container.CONTAINER_TYPE, []).append(container)

self._container_manager.StoreContainer(source_module=source_module,
container=container)

def LogTelemetry(
self, telemetry_entry: telemetry.TelemetryCollection) -> None:
"""Method to store telemetry in the state's telemetry store.
@@ -246,6 +253,59 @@ def LogTelemetry(
key, value, telemetry_entry.module_name, telemetry_entry.recipe)

def GetContainers(
self,
requesting_module: str,
container_class: Type[T],
pop: bool = False,
metadata_filter_key: Optional[str] = None,
metadata_filter_value: Optional[Any] = None) -> Sequence[T]:
"""Retrieve previously stored containers.

Args:
requesting_module: The name of the module making the retrieval.
container_class (type): AttributeContainer class used to filter data.
pop (Optional[bool]): Whether to remove the containers from the state when
they are retrieved.
metadata_filter_key (Optional[str]): Metadata key to filter on.
metadata_filter_value (Optional[Any]): Metadata value to filter on.

Returns:
Collection[AttributeContainer]: attribute container objects provided in
the store that correspond to the container type.

Raises:
RuntimeError: If only one metadata filter parameter is specified.
"""
# We're plumbing both methods in for a while, and reporting discrepancies.
containers_orig = self._DeprecatedGetContainers(
container_class=container_class,
pop=pop,
metadata_filter_key=metadata_filter_key,
metadata_filter_value=metadata_filter_value)
containers_cm = self._container_manager.GetContainers(requesting_module,
container_class,
metadata_filter_key,
metadata_filter_value)

if (sorted([str(c) for c in containers_orig]) !=
sorted([str(c) for c in containers_cm])):
# Log some telemetry on the difference.
telem = {
'deprecated_implementation_results':
[str(c) for c in containers_orig],
'container_manager_results':
[str(c) for c in containers_cm]}

self.LogTelemetry(telemetry_entry=telemetry.TelemetryCollection(
requesting_module,
requesting_module,
self.recipe.get('name', ''),
{'GetContainer_discrepency': json.dumps(telem)}))
logger.debug('GetContainer_discrepency: %s', json.dumps(telem))

return containers_orig

def _DeprecatedGetContainers(
self,
container_class: Type[T],
pop: bool = False,
@@ -379,17 +439,19 @@ def _RunModuleProcessThreaded(
Returns:
List of futures for the threads that were started.
"""
cont_count = len(self.GetContainers(module.GetThreadOnContainerType()))
containers = self.GetContainers(
requesting_module=module.name,
container_class=module.GetThreadOnContainerType(),
pop=not module.KeepThreadedContainersInState())
logger.info(
f'Running {cont_count} threads, max {module.GetThreadPoolSize()} '
f'Running {len(containers)} threads, max {module.GetThreadPoolSize()} '
f'simultaneous for module {module.name}')

futures = []

with ThreadPoolExecutor(max_workers=module.GetThreadPoolSize()) \
as executor:
pop = not module.KeepThreadedContainersInState()
for c in self.GetContainers(module.GetThreadOnContainerType(), pop):
for c in containers:
logger.debug(
f'Launching {module.name}.Process thread with {str(c)} from '
f'{c.metadata.get(interface.METADATA_KEY_SOURCE_MODULE, "Unknown")}'
@@ -501,6 +563,9 @@ def _RunModuleThread(self, module_definition: Dict[str, str]) -> None:

logger.info('Module {0:s} finished execution'.format(runtime_name))
self._threading_event_per_module[runtime_name].set()

self._container_manager.CompleteModule(runtime_name)

self.CleanUp()

def RunPreflights(self) -> None:
@@ -735,20 +800,22 @@ def _RunModuleProcessThreaded(
Returns:
List of futures for the threads that were started.
"""
cont_count = len(self.GetContainers(module.GetThreadOnContainerType()))
containers = self.GetContainers(
requesting_module=module.name,
container_class=module.GetThreadOnContainerType(),
pop=not module.KeepThreadedContainersInState())
logger.info(
f'Running {cont_count} threads, max {module.GetThreadPoolSize()} '
f'Running {len(containers)} threads, max {module.GetThreadPoolSize()} '
f'simultaneous for module {module.name}')

self.cursesdm.SetThreadedModuleContainerCount(module.name, cont_count)
self.cursesdm.SetThreadedModuleContainerCount(module.name, len(containers))
self.cursesdm.UpdateModuleStatus(module.name, cdm.Status.PROCESSING)

futures = []

with ThreadPoolExecutor(max_workers=module.GetThreadPoolSize()) \
as executor:
pop = not module.KeepThreadedContainersInState()
for c in self.GetContainers(module.GetThreadOnContainerType(), pop):
for c in containers:
futures.append(
executor.submit(self._WrapThreads, module.Process, c, module.name))

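The GetContainers change above follows a shadow-read pattern: the legacy store stays authoritative, the new ContainerManager is queried in parallel, and any mismatch is reported via telemetry ahead of an eventual cutover. A generic sketch of that pattern, with hypothetical ShadowRead and report names rather than dftimewolf APIs:

```python
# Hypothetical helper illustrating the shadow-read pattern used by
# GetContainers above; not part of dftimewolf.
import json
import logging
from typing import Any, Callable, List

logger = logging.getLogger(__name__)


def ShadowRead(
    legacy_read: Callable[[], List[Any]],
    new_read: Callable[[], List[Any]],
    report: Callable[[str], None]) -> List[Any]:
  """Returns legacy results; reports when the new path disagrees."""
  legacy_results = legacy_read()
  new_results = new_read()
  # Compare order-insensitively on string form, as the diff above does.
  if sorted(map(str, legacy_results)) != sorted(map(str, new_results)):
    report(json.dumps({
        'legacy': [str(c) for c in legacy_results],
        'new': [str(c) for c in new_results],
    }))
  return legacy_results  # Legacy stays authoritative during the migration.


# Example: the same contents in a different order is not a discrepancy.
results = ShadowRead(
    legacy_read=lambda: ['a', 'b'],
    new_read=lambda: ['b', 'a'],
    report=lambda payload: logger.warning('discrepancy: %s', payload))
```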
25 changes: 13 additions & 12 deletions tests/cli/recipes.py
@@ -53,18 +53,19 @@ def testRecipeModulesAllPresent(self):
"""Tests that a recipe's modules depend only on modules present in the
recipe."""
for recipe in self._recipes_manager.GetRecipes():
declared_modules = set()
wanted_modules = set()
for module in recipe.contents['modules']:
module_name = module['name']
runtime_name = module.get('runtime_name', module_name)
declared_modules.add(runtime_name)
for wanted in module['wants']:
wanted_modules.add(wanted)

for wanted_module in wanted_modules:
self.assertIn(wanted_module, declared_modules,
msg='recipe: {0:s}'.format(recipe.contents['name']))
with self.subTest(recipe.contents['name']):
declared_modules = set()
wanted_modules = set()
for module in (recipe.contents['modules'] +
recipe.contents.get('preflights', [])):
module_name = module['name']
runtime_name = module.get('runtime_name', module_name)
declared_modules.add(runtime_name)
for wanted in module['wants']:
wanted_modules.add(wanted)

self.assertTrue(wanted_modules.issubset(declared_modules),
msg='recipe: {0:s}'.format(recipe.contents['name']))

def testNoDeadlockInRecipe(self):
"""Tests that a recipe will not deadlock."""
34 changes: 24 additions & 10 deletions tests/e2e/aws_disk_forensics.py
@@ -197,15 +197,19 @@ def testRunRecipe(self):
self.test_state.RunModules()

# AWS Volume in count should equal GCE Disk out count, and be at least 1
self.assertGreaterEqual(
len(self.test_state.GetContainers(containers.AWSVolume)), 1)
self.assertEqual(len(self.test_state.GetContainers(containers.AWSVolume)),
len(self.test_state.GetContainers(containers.GCEDisk)))
aws_volumes = self.test_state.GetContainers(
container_class=containers.AWSVolume,
requesting_module='AWSVolumeSnapshotCollector')
gce_disks = self.test_state.GetContainers(
container_class=containers.GCEDisk,
requesting_module='GCEDiskFromImage')
self.assertGreaterEqual(len(aws_volumes), 1)
self.assertEqual(len(aws_volumes), len(gce_disks))

disks = compute.GoogleCloudCompute(self.gcp_project_id).Disks()
real_gce_disk_names = [disks[k].name for k in disks.keys()]

for d in self.test_state.GetContainers(containers.GCEDisk):
for d in gce_disks:
self.assertIn(d.name, real_gce_disk_names)
real_disk = compute.GoogleComputeDisk(
self.gcp_project_id, self.gcp_zone, d.name)
@@ -217,15 +221,25 @@ def tearDown(self):
log.warning("Cleaning up after test...")
# All of the following artefacts are created: AWSSnapshot, AWSS3Object,
# GCSObject, GCEImage, GCEDisk
for c in self.test_state.GetContainers(containers.AWSSnapshot):
for c in self.test_state.GetContainers(
container_class=containers.AWSSnapshot,
requesting_module='AWSVolumeSnapshotCollector'):
self._removeAWSSnapshot(c.id)
for c in self.test_state.GetContainers(containers.AWSS3Object):
for c in self.test_state.GetContainers(
container_class=containers.AWSS3Object,
requesting_module='AWSSnapshotS3CopyCollector'):
self._removeAWSS3Object(c.path)
for c in self.test_state.GetContainers(containers.GCSObject):
for c in self.test_state.GetContainers(
container_class=containers.GCSObject,
requesting_module='S3ToGCSCopy'):
self._removeGCSObject(c.path)
for c in self.test_state.GetContainers(containers.GCEImage):
for c in self.test_state.GetContainers(
container_class=containers.GCEImage,
requesting_module='GCSToGCEImage'):
self._removeGCEImage(c.name)
for c in self.test_state.GetContainers(containers.GCEDisk):
for c in self.test_state.GetContainers(
container_class=containers.GCEDisk,
requesting_module='GCEDiskFromImage'):
self._removeGCEDisk(c.name)

def _removeAWSSnapshot(self, snap_id: str):
28 changes: 19 additions & 9 deletions tests/e2e/gcp_disk_forensics.py
@@ -152,7 +152,9 @@ def setUp(self):

def tearDown(self):
log.info("Cleaning up after test...")
for vm in self.test_state.GetContainers(containers.ForensicsVM):
for vm in self.test_state.GetContainers(
container_class=containers.ForensicsVM,
requesting_module='GCEForensicsVM'):
CleanUp(self.project_id, self.zone, vm.name)

self._recipes_manager.DeregisterRecipe(self._recipe)
@@ -178,9 +180,11 @@ def testBootDiskCopy(self):
self.test_state.RunModules()

# Get the forensics VM name, confirm it exists
self.assertEqual(1,
len(self.test_state.GetContainers(containers.ForensicsVM)))
for_vm = self.test_state.GetContainers(containers.ForensicsVM)[0]
vm_containers = self.test_state.GetContainers(
container_class=containers.ForensicsVM,
requesting_module='GCEForensicsVM')
self.assertEqual(1, len(vm_containers))
for_vm = vm_containers[0]

gce_instances_client = self.gcp_client.GceApi().instances()
request = gce_instances_client.get(
@@ -194,7 +198,9 @@ def testBootDiskCopy(self):
actual_disks = compute.GoogleComputeInstance(
self.project_id, self.zone, for_vm.name).ListDisks().keys()
# The source disk will be the first in the container list, so exclude it.
expected_disks = self.test_state.GetContainers(containers.GCEDisk)[1:]
expected_disks = self.test_state.GetContainers(
container_class=containers.GCEDisk,
requesting_module='GCEDiskCopy')[1:]

# Length should differ by 1 for the boot disk
self.assertEqual(len(actual_disks), len(expected_disks) + 1)
@@ -225,9 +231,11 @@ def testOtherDiskCopy(self):
self.test_state.RunModules()

# Get the forensics VM name, confirm it exists
self.assertEqual(1,
len(self.test_state.GetContainers(containers.ForensicsVM)))
for_vm = self.test_state.GetContainers(containers.ForensicsVM)[0]
vm_containers = self.test_state.GetContainers(
container_class=containers.ForensicsVM,
requesting_module='GCEForensicsVM')
self.assertEqual(1, len(vm_containers))
for_vm = vm_containers[0]

gce_instances_client = self.gcp_client.GceApi().instances()
request = gce_instances_client.get(
@@ -241,7 +249,9 @@ def testOtherDiskCopy(self):
actual_disks = compute.GoogleComputeInstance(
self.project_id, self.zone, for_vm.name).ListDisks().keys()
# The source disk will be the first in the container list, so exclude it.
expected_disks = self.test_state.GetContainers(containers.GCEDisk)[1:]
expected_disks = self.test_state.GetContainers(
container_class=containers.GCEDisk,
requesting_module='GCEDiskCopy')[1:]

# Length should differ by 1 for the boot disk
self.assertEqual(len(actual_disks), len(expected_disks) + 1)