Skip to content

Commit 4e92b51

Browse files
authored
[JobInfo] Fix the retrieval of job info by making the SSM command store its output in CloudWatch Logs to prevent truncation. (#406)
This change fixes #376
1 parent de496fe commit 4e92b51

File tree

4 files changed

+216
-2
lines changed

4 files changed

+216
-2
lines changed

api/PclusterApiHandler.py

+15-2
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from api.pcm_globals import set_auth_cookies_in_context, logger, auth_cookies
2626
from api.security.csrf.constants import CSRF_COOKIE_NAME
2727
from api.security.csrf.csrf import csrf_needed
28-
from api.utils import disable_auth
28+
from api.utils import disable_auth, read_and_delete_ssm_output_from_cloudwatch
2929
from api.validation import validated
3030
from api.validation.schemas import PCProxyArgs, PCProxyBody
3131

@@ -47,6 +47,7 @@
4747
JWKS_URL = os.getenv("JWKS_URL")
4848
AUDIENCE = os.getenv("AUDIENCE")
4949
USER_ROLES_CLAIM = os.getenv("USER_ROLES_CLAIM", "cognito:groups")
50+
SSM_LOG_GROUP_NAME = os.getenv("SSM_LOG_GROUP_NAME")
5051

5152
try:
5253
if (not USER_POOL_ID or USER_POOL_ID == "") and SECRET_ID:
@@ -264,10 +265,16 @@ def ssm_command(region, instance_id, user, run_command):
264265
DocumentName="AWS-RunShellScript",
265266
Comment=f"Run ssm command.",
266267
Parameters={"commands": [command]},
268+
CloudWatchOutputConfig={
269+
'CloudWatchLogGroupName': SSM_LOG_GROUP_NAME,
270+
'CloudWatchOutputEnabled': True
271+
},
267272
)
268273

269274
command_id = ssm_resp["Command"]["CommandId"]
270275

276+
logger.info(f"Submitted SSM command {command_id}")
277+
271278
# Wait for command to complete
272279
time.sleep(0.75)
273280
while time.time() - start < 60:
@@ -282,7 +289,13 @@ def ssm_command(region, instance_id, user, run_command):
282289
if status["Status"] != "Success":
283290
raise Exception(status["StandardErrorContent"])
284291

285-
output = status["StandardOutputContent"]
292+
output = read_and_delete_ssm_output_from_cloudwatch(
293+
region=region,
294+
log_group_name=SSM_LOG_GROUP_NAME,
295+
command_id=command_id,
296+
instance_id=instance_id,
297+
)
298+
286299
return output
287300

288301

api/tests/test_utils.py

+104
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
import pytest
2+
from unittest.mock import Mock, patch
3+
from api.utils import read_and_delete_ssm_output_from_cloudwatch, normalize_logs_token
4+
5+
6+
@pytest.fixture
def mock_boto3_client():
    """Patch ``boto3.client`` while a test runs and yield the replacement mock."""
    patcher = patch('boto3.client')
    client_factory = patcher.start()
    try:
        yield client_factory
    finally:
        # Always restore the real boto3.client once the test is done.
        patcher.stop()
10+
11+
def _log_events_page(messages, forward_token, backward_token):
    """Build a fake ``get_log_events`` response page holding the given messages."""
    return {
        'events': [{'message': text} for text in messages],
        'nextForwardToken': forward_token,
        'nextBackwardToken': backward_token,
    }


@pytest.mark.skip("this test is temporarily disabled because it requires refactoring of the logging utilities")
@pytest.mark.parametrize(
    "responses, expected_result, expected_call_count",
    [
        pytest.param(
            [_log_events_page(['line1', 'line2'], 'token1', 'token1')],
            "line1\nline2",
            1,
            id="logs_on_single_page",
        ),
        pytest.param(
            [
                _log_events_page(['line1', 'line2'], 'token1', 'token2'),
                _log_events_page(['line3'], 'token2', 'token2'),
            ],
            "line1\nline2\nline3",
            2,
            id="logs_on_multiple_pages",
        ),
        pytest.param(
            [_log_events_page([], 'token1', 'token1')],
            "",
            1,
            id="empty_logs",
        ),
    ],
)
def test_read_and_delete_ssm_output_from_cloudwatch_success(
    mock_boto3_client, responses, expected_result, expected_call_count
):
    """Check the joined output, the number of pages fetched and the stream cleanup."""
    # Arrange: the patched boto3.client hands back a stub that replays the pages.
    logs_stub = Mock()
    logs_stub.get_log_events.side_effect = responses
    mock_boto3_client.return_value = logs_stub

    # Act
    actual = read_and_delete_ssm_output_from_cloudwatch(
        region='us-east-1',
        log_group_name='/aws/ssm/test',
        command_id='cmd-123',
        instance_id='i-123',
    )

    # Assert
    assert actual == expected_result
    mock_boto3_client.assert_called_once_with('logs', region_name='us-east-1')
    assert logs_stub.get_log_events.call_count == expected_call_count
    assert logs_stub.delete_log_stream.call_count == 1
83+
84+
@pytest.mark.skip("this test is temporarily disabled because it requires refactoring of the logging utilities")
@pytest.mark.parametrize(
    "input_token, expected_output",
    [
        pytest.param('f/WHATEVER/s', 'WHATEVER/s', id="forward_token"),
        pytest.param('b/WHATEVER/s', 'WHATEVER/s', id="backward_token"),
    ],
)
def test_normalize_logs_token(input_token, expected_output):
    """Both token flavours must normalize to the text after the first slash."""
    normalized = normalize_logs_token(str(input_token))
    assert normalized == expected_output, (
        f"Failed for input '{input_token}'. Expected '{expected_output}', got '{normalized}'"
    )
102+
103+
104+

api/utils.py

+66
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import datetime
1212
import os
1313

14+
import boto3
1415
import dateutil
1516
from flask import Flask, Response, request, send_from_directory
1617
import requests
@@ -110,3 +111,68 @@ def serve_frontend(app, path=""):
110111
return proxy_to("http://localhost:3000/" + path)
111112

112113
return send_from_directory(app.static_folder, "index.html")
114+
115+
def read_and_delete_ssm_output_from_cloudwatch(
    region: str,
    log_group_name: str,
    command_id: str,
    instance_id: str,
) -> str:
    """Read the stdout of an SSM command from CloudWatch Logs, then delete the stream.

    The output is read from the log stream
    ``{command_id}/{instance_id}/aws-runShellScript/stdout`` inside
    ``log_group_name``, following pagination from the head of the stream.

    This is best-effort: any error while reading is logged and whatever was
    collected so far (possibly nothing) is returned. The log stream deletion
    is attempted in every case, whether the read succeeded or not.

    Args:
        region: AWS region used to create the CloudWatch Logs client.
        log_group_name: Log group that received the SSM command output.
        command_id: Id of the SSM command whose output is read.
        instance_id: Id of the instance the command ran on.

    Returns:
        The non-empty log messages, each stripped of surrounding whitespace,
        joined with newlines. Empty string when nothing could be read.
    """
    logs_client = boto3.client('logs', region_name=region)

    log_stream_name = f"{command_id}/{instance_id}/aws-runShellScript/stdout"

    logger.info(
        f"Reading output for SSM command {command_id} from logstream {log_stream_name} in log group {log_group_name}"
    )

    output_lines = []
    request_params = dict(
        logGroupName=log_group_name,
        logStreamName=log_stream_name,
        startFromHead=True,
    )

    try:
        while True:
            response = logs_client.get_log_events(**request_params)

            for event in response.get('events', []):
                message = event.get('message', '').strip()
                if message:
                    output_lines.append(message)

            previous_token = request_params.get('nextToken')
            next_token = response.get('nextForwardToken')
            next_backward_token = response.get('nextBackwardToken')

            # Stop when there is no token, when the service hands back the
            # token we just sent (the documented end-of-stream signal, which
            # also guards against looping forever), or when the forward and
            # backward tokens point at the same position.
            if (
                not next_token
                or next_token == previous_token
                or normalize_logs_token(next_token) == normalize_logs_token(next_backward_token)
            ):
                break

            request_params['nextToken'] = next_token
    except Exception as ex:
        logger.error(
            f"Failed to read output for SSM command {command_id} "
            f"from logstream {log_stream_name} in log group {log_group_name}: {ex}"
        )
    else:
        # Only report completion when the read really completed.
        logger.info(
            f"Completed reading of output for SSM command {command_id} "
            f"from logstream {log_stream_name} in log group {log_group_name}"
        )
    finally:
        # Single cleanup point: the stream is removed on success and on failure.
        delete_log_stream(logs_client, log_group_name, log_stream_name)

    return "\n".join(output_lines)
166+
167+
def normalize_logs_token(token: str) -> str:
    """Return ``token`` without its leading segment up to the first ``/``.

    Falsy tokens and tokens that contain no ``/`` are returned unchanged.
    """
    if not token or '/' not in token:
        return token
    _prefix, _slash, remainder = token.partition('/')
    return remainder
169+
170+
def delete_log_stream(logs_client, log_group_name: str, log_stream_name: str):
    """Delete a CloudWatch log stream, logging the outcome instead of raising.

    Args:
        logs_client: Client object exposing ``delete_log_stream``.
        log_group_name: Log group that contains the stream.
        log_stream_name: Name of the stream to delete.
    """
    target = {
        'logGroupName': log_group_name,
        'logStreamName': log_stream_name,
    }
    try:
        logs_client.delete_log_stream(**target)
        logger.info(f"Deleted log stream {log_stream_name} in log group {log_group_name}")
    except Exception as err:
        # Best effort: a failed cleanup is reported but never propagated.
        logger.error(f"Failed to delete log stream {log_stream_name} in log group {log_group_name}: {err}")

infrastructure/parallelcluster-ui.yaml

+31
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,7 @@ Resources:
278278
- UseCustomDomain
279279
- !FindInMap [ ParallelClusterUI, Constants, CustomDomainBasePath ]
280280
- !Ref AWS::NoValue
281+
SSM_LOG_GROUP_NAME: !Ref SsmLogGroup
281282
FunctionName: !Sub
282283
- ParallelClusterUIFun-${StackIdSuffix}
283284
- { StackIdSuffix: !Select [2, !Split ['/', !Ref 'AWS::StackId']] }
@@ -838,6 +839,12 @@ Resources:
838839
LogGroupName: !Sub /aws/lambda/${ParallelClusterUIFun}
839840
RetentionInDays: 90
840841

842+
SsmLogGroup:
843+
Type: AWS::Logs::LogGroup
844+
Properties:
845+
LogGroupName: !Sub /aws/ssm/ParallelClusterUI-${AWS::StackName}
846+
RetentionInDays: 1
847+
LogGroupClass: STANDARD
841848

842849
ParallelClusterUIUserRole:
843850
Type: AWS::IAM::Role
@@ -861,6 +868,7 @@ Resources:
861868
- !Ref CognitoPolicy
862869
- !Ref EC2Policy
863870
- !Ref StoragePolicy
871+
- !Ref LogsPolicy
864872
- !Ref CostMonitoringAndPricingPolicy
865873
- !Ref SsmPolicy
866874
PermissionsBoundary: !If [UsePermissionBoundary, !Ref PermissionsBoundaryPolicy, !Ref 'AWS::NoValue']
@@ -1054,6 +1062,29 @@ Resources:
10541062
Effect: Allow
10551063
Sid: SsmGetCommandInvocationPolicy
10561064

1065+
LogsPolicy:
1066+
Type: AWS::IAM::ManagedPolicy
1067+
Properties:
1068+
ManagedPolicyName: !Sub
1069+
- ${IAMRoleAndPolicyPrefix}LogsPolicy-${StackIdSuffix}
1070+
- { StackIdSuffix: !Select [ 0, !Split [ '-', !Select [ 2, !Split [ '/', !Ref 'AWS::StackId' ] ] ] ] }
1071+
PolicyDocument:
1072+
Version: '2012-10-17'
1073+
Statement:
1074+
- Action:
1075+
- logs:GetLogEvents
1076+
Resource:
1077+
- !Sub "arn:${AWS::Partition}:logs:${AWS::Region}:${AWS::AccountId}:log-group:${SsmLogGroup}:*"
1078+
- !Sub "arn:${AWS::Partition}:logs:${AWS::Region}:${AWS::AccountId}:log-group:${SsmLogGroup}:log-stream:*"
1079+
Effect: Allow
1080+
Sid: CloudWatchLogsRead
1081+
- Action:
1082+
- logs:DeleteLogStream
1083+
Resource:
1084+
- !Sub "arn:${AWS::Partition}:logs:${AWS::Region}:${AWS::AccountId}:log-group:${SsmLogGroup}:log-stream:*/*/aws-runShellScript/stdout"
1085+
Effect: Allow
1086+
Sid: CloudWatchLogsDelete
1087+
10571088
ApiGatewayCustomDomain:
10581089
Condition: UseCustomDomain
10591090
Type: AWS::ApiGateway::DomainName

0 commit comments

Comments
 (0)