Skip to content

Commit

Permalink
fix(mqtt3/5): deduplicate subscription callbacks using deprecated API (
Browse files Browse the repository at this point in the history
  • Loading branch information
MikeDombo committed Dec 5, 2023
1 parent 6af6aeb commit 0e4cfec
Show file tree
Hide file tree
Showing 5 changed files with 92 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -212,21 +212,19 @@ public CompletableFuture<SubscribeToIoTCoreResponse> handleRequestAsync(Subscrib
.log("Unable to subscribe to topic");
throw new ServiceError(String.format("Subscribe to topic %s failed with error %s", topic, t));
}).thenApply((i) -> {
if (i != null) {
int rc = i.getReasonCode();
if (rc > 2) {
String rcString = SubAckPacket.SubAckReasonCode.UNSPECIFIED_ERROR.name();
try {
rcString = SubAckPacket.SubAckReasonCode.getEnumValueFromInteger(rc).name();
} catch (RuntimeException ignored) {
}

throw new ServiceError(
String.format("Subscribe to topic %s failed with error %s", topic,
rcString))
.withContext(Utils.immutableMap("reasonString", i.getReasonString(),
"reasonCode", i.getReasonCode()));
if (i != null && !i.isSuccessful()) {
String rcString = SubAckPacket.SubAckReasonCode.UNSPECIFIED_ERROR.name();
try {
rcString =
SubAckPacket.SubAckReasonCode.getEnumValueFromInteger(i.getReasonCode()).name();
} catch (RuntimeException ignored) {
}

throw new ServiceError(
String.format("Subscribe to topic %s failed with error %s", topic,
rcString))
.withContext(Utils.immutableMap("reasonString", i.getReasonString(),
"reasonCode", i.getReasonCode()));
}
return new SubscribeToIoTCoreResponse();
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ public class IotJobsHelper implements InjectionActions {
@Setter // For tests
private IotJobsClientWrapper iotJobsClientWrapper;

private AtomicBoolean isSubscribedToIotJobsTopics = new AtomicBoolean(false);
private final AtomicBoolean isSubscribedToIotJobsTopics = new AtomicBoolean(false);
private Future<?> subscriptionFuture;
private volatile String thingName;

Expand Down Expand Up @@ -179,7 +179,7 @@ public class IotJobsHelper implements InjectionActions {
/**
* Handler that gets invoked when a job description is received.
* Next pending job description is requested when an mqtt message
* is published using {@Code requestNextPendingJobDocument} in {@link IotJobsHelper}
* is published using {@code requestNextPendingJobDocument} in {@link IotJobsHelper}
*/
private final Consumer<DescribeJobExecutionResponse> describeJobExecutionResponseConsumer = response -> {
if (response.execution == null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -432,7 +432,7 @@ private void resubscribeDroppedTopicsTask() {
List<CompletableFuture<SubscribeResponse>> subFutures = new ArrayList<>();
for (Subscribe sub : droppedSubscriptionTopics) {
subFutures.add(subscribe(sub).whenComplete((result, error) -> {
if (error == null) {
if (error == null && (result == null || result.isSuccessful())) {
droppedSubscriptionTopics.remove(sub);
} else {
logger.atError().event(RESUB_LOG_EVENT).cause(error).kv(TOPIC_KEY, sub.getTopic())
Expand Down
42 changes: 33 additions & 9 deletions src/main/java/com/aws/greengrass/mqttclient/MqttClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
import software.amazon.awssdk.crt.mqtt.MqttMessage;
import software.amazon.awssdk.crt.mqtt.QualityOfService;
import software.amazon.awssdk.crt.mqtt5.packets.PubAckPacket;
import software.amazon.awssdk.crt.mqtt5.packets.SubAckPacket;
import software.amazon.awssdk.iot.AwsIotMqttConnectionBuilder;

import java.io.Closeable;
Expand Down Expand Up @@ -471,11 +472,13 @@ public synchronized CompletableFuture<SubscribeResponse> subscribe(Subscribe req
IndividualMqttClient finalConnection = connection;
return connection.subscribe(request).whenComplete((i, t) -> {
try (LockScope scope = LockScope.lock(connectionLock.readLock())) {
if (t == null) {
if (t == null && (i == null || i.isSuccessful())) {
subscriptionTopics.put(new MqttTopic(request.getTopic()), finalConnection);
} else {
subscriptions.remove(request);
logger.atError().kv(TOPIC_KEY, request.getTopic()).log("Error subscribing", t);
if (t != null) {
logger.atError().kv(TOPIC_KEY, request.getTopic()).log("Error subscribing", t);
}
}
}
});
Expand All @@ -490,18 +493,39 @@ public synchronized CompletableFuture<SubscribeResponse> subscribe(Subscribe req
* @throws ExecutionException if an error occurs
* @throws InterruptedException if the thread is interrupted while subscribing
* @throws TimeoutException if the request times out
* @throws MqttException if the request fails
* @deprecated Use {@code subscribe(Subscribe request)} instead
*/
@Deprecated
@SuppressWarnings("PMD.AvoidCatchingGenericException")
public void subscribe(SubscribeRequest request)
throws ExecutionException, InterruptedException, TimeoutException {
try {
Consumer<Publish> cb = (Publish m) -> request.getCallback()
.accept(new MqttMessage(m.getTopic(), m.getPayload(),
QualityOfService.getEnumValueFromInteger(m.getQos().getValue()), m.isRetain()));
Subscribe newReq =
Subscribe.builder().qos(QOS.fromInt(request.getQos().getValue())).topic(request.getTopic())
.callback(cb).build();
// Deduplicate subscription callbacks so that retries do not result in getting called multiple times
Subscribe newReq = cbMapping.computeIfAbsent(new Pair<>(request.getTopic(), request.getCallback()), (p) -> {
Consumer<Publish> cb = (Publish m) -> request.getCallback()
.accept(new MqttMessage(m.getTopic(), m.getPayload(),
QualityOfService.getEnumValueFromInteger(m.getQos().getValue()), m.isRetain()));

return Subscribe.builder().qos(QOS.fromInt(request.getQos().getValue())).topic(request.getTopic())
.callback(cb).build();
});

subscribe(newReq)
.thenAccept((v) -> cbMapping.put(new Pair<>(request.getTopic(), request.getCallback()), newReq))
.thenApply((v) -> {
// null is a success because subscribe returns null if the subscription already existed
if (v == null || v.isSuccessful()) {
return v;
}
String rcString = SubAckPacket.SubAckReasonCode.UNSPECIFIED_ERROR.name();
try {
rcString = SubAckPacket.SubAckReasonCode.getEnumValueFromInteger(v.getReasonCode()).name();
} catch (RuntimeException ignored) {
}
// Consumers of this deprecated API expect to receive an MqttException if subscribing fails
throw new MqttException(
"Error subscribing. Reason: " + rcString);
})
.get(getMqttOperationTimeoutMillis(), TimeUnit.MILLISECONDS);
} catch (MqttRequestException e) {
throw new ExecutionException(e);
Expand Down
46 changes: 44 additions & 2 deletions src/test/java/com/aws/greengrass/mqttclient/MqttClientTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
import com.aws.greengrass.mqttclient.v5.PubAck;
import com.aws.greengrass.mqttclient.v5.Publish;
import com.aws.greengrass.mqttclient.v5.QOS;
import com.aws.greengrass.mqttclient.v5.Subscribe;
import com.aws.greengrass.mqttclient.v5.SubscribeResponse;
import com.aws.greengrass.security.SecurityService;
import com.aws.greengrass.testcommons.testutilities.GGExtension;
import com.aws.greengrass.testcommons.testutilities.TestUtils;
Expand All @@ -39,6 +41,7 @@
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.junit.jupiter.MockitoExtension;
import software.amazon.awssdk.crt.CrtRuntimeException;
import software.amazon.awssdk.crt.mqtt.MqttClientConnection;
import software.amazon.awssdk.crt.mqtt.MqttException;
import software.amazon.awssdk.crt.mqtt.MqttMessage;
Expand Down Expand Up @@ -66,6 +69,7 @@
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.function.Consumer;
import java.util.stream.Collectors;

import static com.aws.greengrass.deployment.DeviceConfiguration.DEVICE_MQTT_NAMESPACE;
import static com.aws.greengrass.deployment.DeviceConfiguration.DEVICE_PARAM_AWS_REGION;
Expand All @@ -80,6 +84,7 @@
import static com.aws.greengrass.mqttclient.MqttClient.MAX_NUMBER_OF_FORWARD_SLASHES;
import static com.aws.greengrass.mqttclient.MqttClient.MQTT_MAX_LIMIT_OF_MESSAGE_SIZE_IN_BYTES;
import static com.aws.greengrass.testcommons.testutilities.ExceptionLogProtector.ignoreExceptionOfType;
import static com.aws.greengrass.testcommons.testutilities.ExceptionLogProtector.ignoreExceptionUltimateCauseOfType;
import static com.aws.greengrass.testcommons.testutilities.ExceptionLogProtector.ignoreExceptionWithMessage;
import static com.aws.greengrass.testcommons.testutilities.TestUtils.asyncAssertOnConsumer;
import static org.hamcrest.MatcherAssert.assertThat;
Expand All @@ -98,9 +103,12 @@
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.atMostOnce;
import static org.mockito.Mockito.doNothing;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.lenient;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.timeout;
import static org.mockito.Mockito.times;
Expand Down Expand Up @@ -220,8 +228,8 @@ void GIVEN_device_not_configured_to_talk_to_cloud_WHEN_publish_THEN_throws_excep

@Test
void GIVEN_multiple_subset_subscriptions_WHEN_subscribe_or_unsubscribe_THEN_only_subscribes_and_unsubscribes_once()
throws ExecutionException, InterruptedException, TimeoutException {
MqttClient client = new MqttClient(deviceConfiguration, spool, false, (c) -> builder, executorService);
throws ExecutionException, InterruptedException, TimeoutException, MqttRequestException {
MqttClient client = spy(new MqttClient(deviceConfiguration, spool, false, (c) -> builder, executorService));
assertFalse(client.connected());

client.subscribe(SubscribeRequest.builder().topic("A/B/+").callback(cb).build());
Expand All @@ -241,10 +249,19 @@ void GIVEN_multiple_subset_subscriptions_WHEN_subscribe_or_unsubscribe_THEN_only

// This subscription shouldn't actually subscribe through the cloud because it is a subset of the previous sub
client.subscribe(SubscribeRequest.builder().topic("A/B/C").callback(cb).build());
// "retry" request to verify that we deduplicate callbacks
client.subscribe(SubscribeRequest.builder().topic("A/B/C").callback(cb).build());

if (mqtt5) {
// verify we've still only called subscribe once
verify(mockMqtt5Client, atMostOnce()).subscribe(any());

// Verify that if someone retries, then we will deduplicate their callback. If we did this improperly,
// then we'd have 3 unique values for callback instead of only 2.
ArgumentCaptor<Subscribe> captor = ArgumentCaptor.forClass(Subscribe.class);
verify(client, times(3)).subscribe(captor.capture());
assertEquals(2,
captor.getAllValues().stream().map(Subscribe::getCallback).collect(Collectors.toSet()).size());
} else {
verify(mockConnection, times(0)).subscribe(eq("A/B/C"), eq(QualityOfService.AT_LEAST_ONCE));
}
Expand Down Expand Up @@ -1118,4 +1135,29 @@ void unreserved_topic_have_8_forward_slashes_WHEN_subscribe_THEN_throw_exception
verify(client).isValidRequestTopic(topic);
verify(mockConnection, never()).subscribe(any(), any());
}

@Test
void GIVEN_subscribe_fails_THEN_deprecated_subscribe_throws(ExtensionContext context)
throws ExecutionException, InterruptedException, TimeoutException, MqttRequestException {
ignoreExceptionUltimateCauseOfType(context, CrtRuntimeException.class);
MqttClient client = spy(new MqttClient(deviceConfiguration, spool, false, (c) -> builder, executorService));
SubscribeRequest request = SubscribeRequest.builder().topic("A").callback(cb).build();

// Handles exceptions thrown by mqtt client
doThrow(CrtRuntimeException.class).when(mockMqtt5Client).subscribe(any());

ExecutionException ee = assertThrows(ExecutionException.class, () -> client.subscribe(request));
assertThat(ee.getCause(), instanceOf(CrtRuntimeException.class));

// Does not throw on null response
doReturn(CompletableFuture.completedFuture(null)).when(client).subscribe(any(Subscribe.class));
client.subscribe(request);
reset(client);

// Throws if Subscription fails with a reason code
doReturn(CompletableFuture.completedFuture(new SubscribeResponse(null,
SubAckPacket.SubAckReasonCode.UNSPECIFIED_ERROR.getValue(), null))).when(client).subscribe(any(Subscribe.class));
ee = assertThrows(ExecutionException.class, () -> client.subscribe(request));
assertThat(ee.getCause(), instanceOf(MqttException.class));
}
}

0 comments on commit 0e4cfec

Please sign in to comment.