Skip to content

Commit

Permalink
fix: improved collection performance on regions that don't have aws r…
Browse files Browse the repository at this point in the history
…esources to collect

Signed-off-by: Sooyoung98 <[email protected]>
  • Loading branch information
Sooyoung98 committed May 21, 2024
1 parent 7fdf5f8 commit 117e530
Show file tree
Hide file tree
Showing 10 changed files with 45 additions and 35 deletions.
3 changes: 3 additions & 0 deletions src/plugin/connector/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,9 @@ def set_client(self, service_name):
def get_account_id(self):
return self.account_id

def load_account_id(self, account_id):
self.account_id = account_id

def set_account_id(self):
sts_client = self.session.client("sts", verify=BOTO3_HTTPS_VERIFIED)
self.account_id = sts_client.get_caller_identity()["Account"]
Expand Down
9 changes: 6 additions & 3 deletions src/plugin/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,8 @@ def collector_collect(params):
service = task_options.get("service")
region = task_options.get("region")
resource_mgrs = ResourceManager.get_manager_by_service(service)
account_id = ResourceManager.get_account_id(secret_data, region)
options["account_id"] = account_id
for resource_mgr in resource_mgrs:
results = resource_mgr().collect_resources(
region, options, secret_data, schema
Expand Down Expand Up @@ -245,7 +247,7 @@ def _add_cloud_service_type_tasks(services: list) -> list:
def _add_metric_tasks(services: list) -> list:
# Specific cloud_service_group list.
metric_services = [
"CertificateManager", # "ACM",
"CertificateManager", # "ACM",
"CloudFront",
"CloudTrail",
"DocumentDB",
Expand All @@ -258,11 +260,12 @@ def _add_metric_tasks(services: list) -> list:
"KMS",
"Lambda",
"Route53",
"S3"
"S3",
]
return [
_make_task_wrapper(
resource_type="inventory.Metric", services=metric_services
resource_type="inventory.Metric",
services=metric_services,
# resource_type="inventory.Metric", services = services # origin
)
]
Expand Down
9 changes: 8 additions & 1 deletion src/plugin/manager/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from plugin.conf.cloud_service_conf import REGION_INFO
from plugin.connector.base import ResourceConnector


_LOGGER = logging.getLogger(__name__)
CURRENT_DIR = os.path.dirname(__file__)
METRIC_DIR = os.path.join(CURRENT_DIR, "../metrics/")
Expand Down Expand Up @@ -253,6 +252,14 @@ def datetime_to_iso8601(value: datetime.datetime):
return f"{value.isoformat()}"
return None

@classmethod
def get_account_id(cls, secret_data, region):
resource_connector = ResourceConnector(
secret_data=secret_data, region_name=region
)
resource_connector.set_account_id()
return resource_connector.get_account_id()

@abc.abstractmethod
def create_cloud_service_type(self):
raise NotImplementedError(
Expand Down
4 changes: 2 additions & 2 deletions src/plugin/manager/ec2/ami_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ def create_cloud_service(self, region, options, secret_data, schema):
self.cloud_service_type = "AMI"
cloudtrail_resource_type = "AWS::EC2::Ami"
results = self.connector.get_ami_images()
self.connector.set_account_id()
account_id = self.connector.get_account_id()
account_id = options.get("account_id", "")
self.connector.load_account_id(account_id)
for image in results.get("Images", []):
try:
try:
Expand Down
35 changes: 16 additions & 19 deletions src/plugin/manager/ec2/auto_scaling_group_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,16 +58,15 @@ def create_cloud_service(self, region, options, secret_data, schema):
cloudwatch_namespace = "AWS/AutoScaling"
cloudwatch_dimension_name = "AutoScalingGroupName"
cloudtrail_resource_type = "AWS::AutoScaling::AutoScalingGroup"

account_id = options.get("account_id", "")
self.connector.load_account_id(account_id)
pre_collect_list = [
# self._create_launch_configurations,
self._create_launch_templates,
]
for pre_collect in pre_collect_list:
yield from pre_collect(region)
yield from pre_collect(region, account_id)
results = self.connector.get_auto_scaling_groups()
self.connector.set_account_id()
account_id = self.connector.get_account_id()
policies = None
notification_configurations = None

Expand Down Expand Up @@ -152,9 +151,9 @@ def create_cloud_service(self, region, options, secret_data, schema):
}
)
elif (
raw.get("MixedInstancesPolicy", {})
.get("LaunchTemplate", {})
.get("LaunchTemplateSpecification")
raw.get("MixedInstancesPolicy", {})
.get("LaunchTemplate", {})
.get("LaunchTemplateSpecification")
):
_lt_info = (
raw.get("MixedInstancesPolicy", {})
Expand Down Expand Up @@ -259,7 +258,7 @@ def get_asg_instances(self, instances):
max_count = 20
instances_from_ec2 = []
split_instances = [
instances[i : i + max_count] for i in range(0, len(instances), max_count)
instances[i: i + max_count] for i in range(0, len(instances), max_count)
]

for instances in split_instances:
Expand Down Expand Up @@ -297,7 +296,7 @@ def get_load_balancer_arns(self, target_group_arns):
max_count = 20

split_tgs_arns = [
target_group_arns[i : i + max_count]
target_group_arns[i: i + max_count]
for i in range(0, len(target_group_arns), max_count)
]

Expand All @@ -317,7 +316,7 @@ def get_load_balancer_info(self, lb_arns):
max_count = 20

split_lb_arns = [
lb_arns[i : i + max_count] for i in range(0, len(lb_arns), max_count)
lb_arns[i: i + max_count] for i in range(0, len(lb_arns), max_count)
]

load_balancer_data_list = []
Expand Down Expand Up @@ -381,9 +380,9 @@ def _match_launch_template(self, raw):
if raw.get("LaunchTemplate"):
lt_dict = raw.get("LaunchTemplate")
elif (
raw.get("MixedInstancesPolicy", {})
.get("LaunchTemplate", {})
.get("LaunchTemplateSpecification")
raw.get("MixedInstancesPolicy", {})
.get("LaunchTemplate", {})
.get("LaunchTemplateSpecification")
):
lt_dict = (
raw.get("MixedInstancesPolicy", {})
Expand All @@ -400,7 +399,7 @@ def _match_launch_template(self, raw):
launch_template
for launch_template in self._launch_templates
if launch_template.get("LaunchTemplateId")
== lt_dict.get("LaunchTemplateId")
== lt_dict.get("LaunchTemplateId")
),
None,
)
Expand Down Expand Up @@ -508,13 +507,11 @@ def _create_launch_configurations(self, region):

return result_list

def _create_launch_templates(self, region):
def _create_launch_templates(self, region, account_id):
cloud_service_type = "LaunchTemplate"
cloudtrail_resource_type = "AWS::AutoScaling::LaunchTemplate"

response = self.connector.get_launch_templates()
self.connector.set_account_id()
account_id = self.connector.get_account_id()
result_list = []
for data in response:
for raw in data.get("LaunchTemplates", []):
Expand All @@ -539,8 +536,8 @@ def _create_launch_templates(self, region):
account_id="",
resource_type="launch_template",
resource_id=raw["LaunchTemplateId"]
+ "/v"
+ str(match_lt_version.get("VersionNumber")),
+ "/v"
+ str(match_lt_version.get("VersionNumber")),
),
"cloudtrail": self.set_cloudtrail(
region,
Expand Down
4 changes: 2 additions & 2 deletions src/plugin/manager/ec2/eip_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ def create_cloud_service_type(self):
def create_cloud_service(self, region, options, secret_data, schema):
cloudtrail_resource_type = "AWS::EC2::EIP"
results = self.connector.get_addresses()
self.connector.set_account_id()
account_id = self.connector.get_account_id()
account_id = options.get("account_id", "")
self.connector.load_account_id(account_id)
nat_gateways = None
network_interfaces = None
eips = results.get("Addresses", [])
Expand Down
4 changes: 2 additions & 2 deletions src/plugin/manager/ec2/security_group_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ def create_cloud_service(self, region, options, secret_data, schema):

# Get Security Group
results = self.connector.get_security_groups()
self.connector.set_account_id()
account_id = self.connector.get_account_id()
account_id = options.get("account_id", "")
self.connector.load_account_id(account_id)

for data in results:
for raw in data.get("SecurityGroups", []):
Expand Down
4 changes: 2 additions & 2 deletions src/plugin/manager/ec2/snapshot_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ def create_cloud_service_type(self):

def create_cloud_service(self, region, options, secret_data, schema):
cloudtrail_resource_type = "AWS::EC2::Snapshot"
self.connector.set_account_id()
account_id = self.connector.get_account_id()
account_id = options.get("account_id", "")
self.connector.load_account_id(account_id)
results = self.connector.get_snapshots()

for data in results:
Expand Down
4 changes: 2 additions & 2 deletions src/plugin/manager/ec2/volume_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ def create_cloud_service_type(self):

def create_cloud_service(self, region, options, secret_data, schema):
cloudtrail_resource_type = "AWS::EC2::Volume"
self.connector.set_account_id()
account_id = self.connector.get_account_id()
account_id = options.get("account_id", "")
self.connector.load_account_id(account_id)
cloudwatch_namespace = "AWS/EBS"
cloudwatch_dimension_name = "VolumeId"
results = self.connector.get_volumes()
Expand Down
4 changes: 2 additions & 2 deletions test/api/test_collect.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
EXTERNAL_ID = os.environ.get("EXTERNAL_ID", None)
REGION_NAME = os.environ.get("REGION_NAME", None)


if AKI == None or SAK == None:
print(
"""
Expand Down Expand Up @@ -90,9 +89,10 @@ def test_full_collect(self):
task_options = task["task_options"]
filter = {}
params = {
"options": task_options,
"options": {},
"secret_data": self.secret_data,
"filter": filter,
"task_options": task_options,
}
res_stream = self.inventory.Collector.collect(params)
for res in res_stream:
Expand Down

0 comments on commit 117e530

Please sign in to comment.