diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java index 0f021125e..aa35ec47f 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java @@ -559,6 +559,8 @@ public void run(WhisperServerConfiguration config, Environment environment) thro .scheduledExecutorService(name(getClass(), "subscriptionProcessorRetry-%d")).threads(1).build(); ScheduledExecutorService cloudflareTurnRetryExecutor = environment.lifecycle() .scheduledExecutorService(name(getClass(), "cloudflareTurnRetry-%d")).threads(1).build(); + ScheduledExecutorService messagePollExecutor = environment.lifecycle() + .scheduledExecutorService(name(getClass(), "messagePollExecutor-%d")).threads(1).build(); final ManagedNioEventLoopGroup dnsResolutionEventLoopGroup = new ManagedNioEventLoopGroup(); final DnsNameResolver cloudflareDnsResolver = new DnsNameResolverBuilder(dnsResolutionEventLoopGroup.next()) @@ -620,7 +622,7 @@ public void run(WhisperServerConfiguration config, Environment environment) thro AccountsManager accountsManager = new AccountsManager(accounts, phoneNumberIdentifiers, cacheCluster, pubsubClient, accountLockManager, keysManager, messagesManager, profilesManager, secureStorageClient, secureValueRecovery2Client, disconnectionRequestManager, webSocketConnectionEventManager, - registrationRecoveryPasswordsManager, clientPublicKeysManager, accountLockExecutor, + registrationRecoveryPasswordsManager, clientPublicKeysManager, accountLockExecutor, messagePollExecutor, clock, config.getLinkDeviceSecretConfiguration().secret().value(), dynamicConfigurationManager); RemoteConfigsManager remoteConfigsManager = new RemoteConfigsManager(remoteConfigs); APNSender apnSender = new APNSender(apnSenderExecutor, config.getApnConfiguration()); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java index 179d6ac0c..40369b72e 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java @@ -362,9 +362,9 @@ The amount of time (in seconds) to wait for a response. If the expected device i linkedDeviceListenerCounter.incrementAndGet(); final Timer.Sample sample = Timer.start(); - try { - return accounts.waitForNewLinkedDevice(tokenIdentifier, Duration.ofSeconds(timeoutSeconds)) + return accounts.waitForNewLinkedDevice(authenticatedDevice.getAccount().getUuid(), + authenticatedDevice.getAuthenticatedDevice(), tokenIdentifier, Duration.ofSeconds(timeoutSeconds)) .thenApply(maybeDeviceInfo -> maybeDeviceInfo .map(deviceInfo -> Response.status(Response.Status.OK).entity(deviceInfo).build()) .orElseGet(() -> Response.status(Response.Status.NO_CONTENT).build())) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java index d28b4296d..ed2a72266 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java @@ -47,6 +47,7 @@ import java.util.concurrent.CompletionStage; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.Executor; +import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; import java.util.function.BiConsumer; @@ -130,6 +131,7 @@ public class AccountsManager extends RedisPubSubAdapter implemen private final RegistrationRecoveryPasswordsManager registrationRecoveryPasswordsManager; private final ClientPublicKeysManager clientPublicKeysManager; private final Executor accountLockExecutor; + private final ScheduledExecutorService messagesPollExecutor; private final Clock clock; private final DynamicConfigurationManager dynamicConfigurationManager; @@ -163,6 +165,9 @@ public class AccountsManager extends RedisPubSubAdapter implemen private static final ObjectWriter ACCOUNT_REDIS_JSON_WRITER = SystemMapper.jsonMapper() .writer(SystemMapper.excludingField(Account.class, List.of("uuid"))); + private static Duration MESSAGE_POLL_INTERVAL = Duration.ofSeconds(1); + private static Duration MAX_SERVER_CLOCK_DRIFT = Duration.ofSeconds(5); + // An account that's used at least daily will get reset in the cache at least once per day when its "last seen" // timestamp updates; expiring entries after two days will help clear out "zombie" cache entries that are read // frequently (e.g. the account is in an active group and receives messages frequently), but aren't actively used by @@ -209,6 +214,7 @@ public AccountsManager(final Accounts accounts, final RegistrationRecoveryPasswordsManager registrationRecoveryPasswordsManager, final ClientPublicKeysManager clientPublicKeysManager, final Executor accountLockExecutor, + final ScheduledExecutorService messagesPollExecutor, final Clock clock, final byte[] linkDeviceSecret, final DynamicConfigurationManager dynamicConfigurationManager) { @@ -227,6 +233,7 @@ public AccountsManager(final Accounts accounts, this.registrationRecoveryPasswordsManager = requireNonNull(registrationRecoveryPasswordsManager); this.clientPublicKeysManager = clientPublicKeysManager; this.accountLockExecutor = accountLockExecutor; + this.messagesPollExecutor = messagesPollExecutor; this.clock = requireNonNull(clock); this.dynamicConfigurationManager = dynamicConfigurationManager; @@ -1428,20 +1435,90 @@ private CompletableFuture redisDeleteAsync(final Account account) { .thenRun(Util.NOOP); } - public CompletableFuture> waitForNewLinkedDevice(final String linkDeviceTokenIdentifier, final Duration timeout) { + public CompletableFuture> waitForNewLinkedDevice( + final UUID accountIdentifier, + final Device linkingDevice, + final String linkDeviceTokenIdentifier, + final Duration timeout) { + if (!linkingDevice.isPrimary()) { + throw new IllegalArgumentException("Only primary devices can link devices"); + } + // Unbeknownst to callers but beknownst to us, the "link device token identifier" is the base64/url-encoded SHA256 // hash of a device-linking token. Before we use the string anywhere, make sure it's the right "shape" for a hash. if (Base64.getUrlDecoder().decode(linkDeviceTokenIdentifier).length != SHA256_HASH_LENGTH) { return CompletableFuture.failedFuture(new IllegalArgumentException("Invalid token identifier")); } - return waitForPubSubKey(waitForDeviceFuturesByTokenIdentifier, - linkDeviceTokenIdentifier, - getLinkedDeviceKey(linkDeviceTokenIdentifier), - timeout, - this::handleDeviceAdded); + final Instant deadline = clock.instant().plus(timeout); + final CompletableFuture> deviceAdded = waitForPubSubKey(waitForDeviceFuturesByTokenIdentifier, + linkDeviceTokenIdentifier, getLinkedDeviceKey(linkDeviceTokenIdentifier), timeout, this::handleDeviceAdded); + + return deviceAdded.thenCompose(maybeDeviceInfo -> maybeDeviceInfo.map(deviceInfo -> { + // The device finished linking, we now want to make sure the client has fetched messages that could + // have come in before the device's mailbox was set up. + + // A worst case estimate of the wall clock time at which the linked device was added to the account record + Instant deviceLinked = Instant.ofEpochMilli(deviceInfo.created()).plus(MAX_SERVER_CLOCK_DRIFT); + + Instant now = clock.instant(); + + // We know at `now` the device finished linking, so if we waited for all the messages before now it would be + // sufficient. However, now might be much later that the device was linked, so we don't want to force + // the client to wait for messages that are past our worst case estimate of when the device was linked + Instant messageEpoch = Collections.min(List.of(deviceLinked, now)); + + // We assume that any message with a timestamp after the messageEpoch made it into the linked device's queues + return waitForPreLinkMessagesToBeFetched(accountIdentifier, linkingDevice, deviceInfo, messageEpoch, deadline); + }) + .orElseGet(() -> CompletableFuture.completedFuture(maybeDeviceInfo))); + } + + /** + * Wait until there are no pending messages for the authenticatedDevice that have a timestamp lower than the provided + * messageEpoch. + * + * @param aci The account identifier of the device doing the linking + * @param linkingDevice The device doing the linking + * @param linkedDeviceInfo Information about the newly linked device + * @param messageEpoch A time at which the device was linked + * @param deadline The time at which the method will stop waiting + * @return A future that completes when there are no pending messages for the linking device with a timestamp earlier + * the provided messageEpoch, or after the deadline is reached. If the deadline was exceeded, the future will be empty. + */ + private CompletableFuture> waitForPreLinkMessagesToBeFetched( + final UUID aci, + final Device linkingDevice, + final DeviceInfo linkedDeviceInfo, + final Instant messageEpoch, + final Instant deadline) { + return messagesManager.getEarliestUndeliveredTimestampForDevice(aci, linkingDevice) + .thenCompose(maybeEarliestTimestamp -> { + + final boolean clientHasOldMessages = maybeEarliestTimestamp + .map(earliestTimestamp -> earliestTimestamp.isBefore(messageEpoch)) + .orElse(false); + + if (!clientHasOldMessages) { + // The client has fetched all messages before the messageEpoch + return CompletableFuture.completedFuture(Optional.of(linkedDeviceInfo)); + } + + final Instant now = clock.instant(); + if (now.plus(MESSAGE_POLL_INTERVAL).isAfter(deadline)) { + // Not enough time to try again before the deadline + return CompletableFuture.completedFuture(Optional.empty()); + } + + // Schedule a retry + return CompletableFuture.supplyAsync( + () -> waitForPreLinkMessagesToBeFetched(aci, linkingDevice, linkedDeviceInfo, messageEpoch, deadline), + r -> messagesPollExecutor.schedule(r, MESSAGE_POLL_INTERVAL.toMillis(), TimeUnit.MILLISECONDS)) + .thenCompose(Function.identity()); + }); } + private void handleDeviceAdded(final CompletableFuture> future, final String deviceInfoJson) { try { future.complete(Optional.of(SystemMapper.jsonMapper().readValue(deviceInfoJson, DeviceInfo.class))); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java index 36e877b2d..095333382 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java @@ -290,7 +290,7 @@ public Publisher get(final UUID destinationUuid, final b clock.millis() - MAX_EPHEMERAL_MESSAGE_DELAY.toMillis(); final Flux allMessages = getAllMessages(destinationUuid, destinationDevice, - earliestAllowableEphemeralTimestamp) + earliestAllowableEphemeralTimestamp, PAGE_SIZE) .publish() // We expect exactly two subscribers to this base flux: // 1. the websocket that delivers messages to clients @@ -311,6 +311,12 @@ public Publisher get(final UUID destinationUuid, final b .tap(Micrometer.metrics(Metrics.globalRegistry)); } + public Mono getEarliestUndeliveredTimestamp(final UUID destinationUuid, final byte destinationDevice) { + return getAllMessages(destinationUuid, destinationDevice, -1, 1) + .next() + .map(MessageProtos.Envelope::getServerTimestamp); + } + private static boolean isStaleEphemeralMessage(final MessageProtos.Envelope message, long earliestAllowableTimestamp) { return message.getEphemeral() && message.getClientTimestamp() < earliestAllowableTimestamp; @@ -330,17 +336,17 @@ private void discardStaleEphemeralMessages(final UUID destinationUuid, final byt @VisibleForTesting Flux getAllMessages(final UUID destinationUuid, final byte destinationDevice, - final long earliestAllowableEphemeralTimestamp) { + final long earliestAllowableEphemeralTimestamp, final int pageSize) { // fetch messages by page - return getNextMessagePage(destinationUuid, destinationDevice, -1) + return getNextMessagePage(destinationUuid, destinationDevice, -1, pageSize) .expand(queueItemsAndLastMessageId -> { // expand() is breadth-first, so each page will be published in order if (queueItemsAndLastMessageId.first().isEmpty()) { return Mono.empty(); } - return getNextMessagePage(destinationUuid, destinationDevice, queueItemsAndLastMessageId.second()); + return getNextMessagePage(destinationUuid, destinationDevice, queueItemsAndLastMessageId.second(), pageSize); }) .limitRate(1) // we want to ensure we don’t accidentally block the Lettuce/netty i/o executors @@ -478,9 +484,9 @@ void removeRecipientViewFromMrmData(final List sharedMrmKeys, final Serv } private Mono, Long>> getNextMessagePage(final UUID destinationUuid, final byte destinationDevice, - long messageId) { + long messageId, int pageSize) { - return getItemsScript.execute(destinationUuid, destinationDevice, PAGE_SIZE, messageId) + return getItemsScript.execute(destinationUuid, destinationDevice, pageSize, messageId) .map(queueItems -> { logger.trace("Processing page: {}", messageId); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesManager.java index 23872ac08..838e22a21 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesManager.java @@ -8,6 +8,7 @@ import io.micrometer.core.instrument.Counter; import io.micrometer.core.instrument.Metrics; +import java.time.Instant; import java.util.List; import java.util.Optional; import java.util.UUID; @@ -200,6 +201,16 @@ public int persistMessages( return messagesRemovedFromCache; } + public CompletableFuture> getEarliestUndeliveredTimestampForDevice(UUID destinationUuid, Device destinationDevice) { + // If there's any message in the persisted layer, return the oldest + return Mono.from(messagesDynamoDb.load(destinationUuid, destinationDevice, 1)).map(Envelope::getServerTimestamp) + // If not, return the oldest message in the cache + .switchIfEmpty(messagesCache.getEarliestUndeliveredTimestamp(destinationUuid, destinationDevice.getId())) + .map(epochMilli -> Optional.of(Instant.ofEpochMilli(epochMilli))) + .switchIfEmpty(Mono.just(Optional.empty())) + .toFuture(); + } + /** * Inserts the shared multi-recipient message payload to storage. * diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/workers/CommandDependencies.java b/service/src/main/java/org/whispersystems/textsecuregcm/workers/CommandDependencies.java index e8fa1c9dc..774cbf912 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/workers/CommandDependencies.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/workers/CommandDependencies.java @@ -146,6 +146,8 @@ static CommandDependencies build( .scheduledExecutorService(name(name, "remoteStorageRetry-%d")).threads(1).build(); ScheduledExecutorService storageServiceRetryExecutor = environment.lifecycle() .scheduledExecutorService(name(name, "storageServiceRetry-%d")).threads(1).build(); + ScheduledExecutorService messagePollExecutor = environment.lifecycle() + .scheduledExecutorService(name(name, "messagePollExecutor-%d")).threads(1).build(); ExternalServiceCredentialsGenerator storageCredentialsGenerator = SecureStorageController.credentialsGenerator( configuration.getSecureStorageServiceConfiguration()); @@ -227,7 +229,7 @@ static CommandDependencies build( AccountsManager accountsManager = new AccountsManager(accounts, phoneNumberIdentifiers, cacheCluster, pubsubClient, accountLockManager, keys, messagesManager, profilesManager, secureStorageClient, secureValueRecovery2Client, disconnectionRequestManager, webSocketConnectionEventManager, - registrationRecoveryPasswordsManager, clientPublicKeysManager, accountLockExecutor, + registrationRecoveryPasswordsManager, clientPublicKeysManager, accountLockExecutor, messagePollExecutor, clock, configuration.getLinkDeviceSecretConfiguration().secret().value(), dynamicConfigurationManager); RateLimiters rateLimiters = RateLimiters.createAndValidate(configuration.getLimitsConfiguration(), dynamicConfigurationManager, rateLimitersCluster); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/DeviceControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/DeviceControllerTest.java index 28a925de6..ddb691dc5 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/DeviceControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/DeviceControllerTest.java @@ -919,7 +919,8 @@ void waitForLinkedDevice() { final String tokenIdentifier = Base64.getUrlEncoder().withoutPadding().encodeToString(new byte[32]); - when(accountsManager.waitForNewLinkedDevice(eq(tokenIdentifier), any())) + when(accountsManager + .waitForNewLinkedDevice(eq(AuthHelper.VALID_UUID), eq(AuthHelper.VALID_DEVICE), eq(tokenIdentifier), any())) .thenReturn(CompletableFuture.completedFuture(Optional.of(deviceInfo))); try (final Response response = resources.getJerseyTest() @@ -942,7 +943,8 @@ void waitForLinkedDevice() { void waitForLinkedDeviceNoDeviceLinked() { final String tokenIdentifier = Base64.getUrlEncoder().withoutPadding().encodeToString(new byte[32]); - when(accountsManager.waitForNewLinkedDevice(eq(tokenIdentifier), any())) + when(accountsManager + .waitForNewLinkedDevice(eq(AuthHelper.VALID_UUID), eq(AuthHelper.VALID_DEVICE), eq(tokenIdentifier), any())) .thenReturn(CompletableFuture.completedFuture(Optional.empty())); try (final Response response = resources.getJerseyTest() @@ -959,7 +961,8 @@ void waitForLinkedDeviceNoDeviceLinked() { void waitForLinkedDeviceBadTokenIdentifier() { final String tokenIdentifier = Base64.getUrlEncoder().withoutPadding().encodeToString(new byte[32]); - when(accountsManager.waitForNewLinkedDevice(eq(tokenIdentifier), any())) + when(accountsManager + .waitForNewLinkedDevice(eq(AuthHelper.VALID_UUID), eq(AuthHelper.VALID_DEVICE), eq(tokenIdentifier), any())) .thenReturn(CompletableFuture.failedFuture(new IllegalArgumentException())); try (final Response response = resources.getJerseyTest() diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountCreationDeletionIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountCreationDeletionIntegrationTest.java index 8a4c76cfe..dfaf4d304 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountCreationDeletionIntegrationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountCreationDeletionIntegrationTest.java @@ -25,6 +25,7 @@ import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ThreadLocalRandom; import java.util.concurrent.TimeUnit; import org.apache.commons.lang3.RandomStringUtils; @@ -74,7 +75,7 @@ public class AccountCreationDeletionIntegrationTest { private static final Clock CLOCK = Clock.fixed(Instant.now(), ZoneId.systemDefault()); - private ExecutorService accountLockExecutor; + private ScheduledExecutorService executor; private AccountsManager accountsManager; private KeysManager keysManager; @@ -113,12 +114,12 @@ void setUp() { DynamoDbExtensionSchema.Tables.DELETED_ACCOUNTS.tableName(), DynamoDbExtensionSchema.Tables.USED_LINK_DEVICE_TOKENS.tableName()); - accountLockExecutor = Executors.newSingleThreadExecutor(); + executor = Executors.newSingleThreadScheduledExecutor(); final AccountLockManager accountLockManager = new AccountLockManager(DYNAMO_DB_EXTENSION.getDynamoDbClient(), DynamoDbExtensionSchema.Tables.DELETED_ACCOUNTS_LOCK.tableName()); - clientPublicKeysManager = new ClientPublicKeysManager(clientPublicKeys, accountLockManager, accountLockExecutor); + clientPublicKeysManager = new ClientPublicKeysManager(clientPublicKeys, accountLockManager, executor); final SecureStorageClient secureStorageClient = mock(SecureStorageClient.class); when(secureStorageClient.deleteStoredData(any())).thenReturn(CompletableFuture.completedFuture(null)); @@ -164,7 +165,8 @@ void setUp() { webSocketConnectionEventManager, registrationRecoveryPasswordsManager, clientPublicKeysManager, - accountLockExecutor, + executor, + executor, CLOCK, "link-device-secret".getBytes(StandardCharsets.UTF_8), dynamicConfigurationManager); @@ -172,10 +174,10 @@ void setUp() { @AfterEach void tearDown() throws InterruptedException { - accountLockExecutor.shutdown(); + executor.shutdown(); //noinspection ResultOfMethodCallIgnored - accountLockExecutor.awaitTermination(1, TimeUnit.SECONDS); + executor.awaitTermination(1, TimeUnit.SECONDS); } @CartesianTest diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerChangeNumberIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerChangeNumberIntegrationTest.java index 335139240..a6c753963 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerChangeNumberIntegrationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerChangeNumberIntegrationTest.java @@ -23,6 +23,7 @@ import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; @@ -69,7 +70,7 @@ class AccountsManagerChangeNumberIntegrationTest { private KeysManager keysManager; private DisconnectionRequestManager disconnectionRequestManager; private WebSocketConnectionEventManager webSocketConnectionEventManager; - private ExecutorService accountLockExecutor; + private ScheduledExecutorService executor; private AccountsManager accountsManager; @@ -104,13 +105,13 @@ void setup() throws InterruptedException { Tables.DELETED_ACCOUNTS.tableName(), Tables.USED_LINK_DEVICE_TOKENS.tableName()); - accountLockExecutor = Executors.newSingleThreadExecutor(); + executor = Executors.newSingleThreadScheduledExecutor(); final AccountLockManager accountLockManager = new AccountLockManager(DYNAMO_DB_EXTENSION.getDynamoDbClient(), Tables.DELETED_ACCOUNTS_LOCK.tableName()); final ClientPublicKeysManager clientPublicKeysManager = - new ClientPublicKeysManager(clientPublicKeys, accountLockManager, accountLockExecutor); + new ClientPublicKeysManager(clientPublicKeys, accountLockManager, executor); final SecureStorageClient secureStorageClient = mock(SecureStorageClient.class); when(secureStorageClient.deleteStoredData(any())).thenReturn(CompletableFuture.completedFuture(null)); @@ -151,7 +152,8 @@ void setup() throws InterruptedException { webSocketConnectionEventManager, registrationRecoveryPasswordsManager, clientPublicKeysManager, - accountLockExecutor, + executor, + executor, mock(Clock.class), "link-device-secret".getBytes(StandardCharsets.UTF_8), dynamicConfigurationManager); @@ -160,10 +162,10 @@ void setup() throws InterruptedException { @AfterEach void tearDown() throws InterruptedException { - accountLockExecutor.shutdown(); + executor.shutdown(); //noinspection ResultOfMethodCallIgnored - accountLockExecutor.awaitTermination(1, TimeUnit.SECONDS); + executor.awaitTermination(1, TimeUnit.SECONDS); } @Test diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerConcurrentModificationIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerConcurrentModificationIntegrationTest.java index 7260db735..d64ec325a 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerConcurrentModificationIntegrationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerConcurrentModificationIntegrationTest.java @@ -30,6 +30,7 @@ import java.util.concurrent.CompletableFuture; import java.util.concurrent.Executor; import java.util.concurrent.LinkedBlockingDeque; +import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ThreadPoolExecutor; import java.util.concurrent.TimeUnit; import java.util.function.Consumer; @@ -139,6 +140,7 @@ void setup() throws InterruptedException { mock(RegistrationRecoveryPasswordsManager.class), mock(ClientPublicKeysManager.class), mock(Executor.class), + mock(ScheduledExecutorService.class), mock(Clock.class), "link-device-secret".getBytes(StandardCharsets.UTF_8), dynamicConfigurationManager diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerDeviceTransferIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerDeviceTransferIntegrationTest.java index eb2b70d86..227b23b00 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerDeviceTransferIntegrationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerDeviceTransferIntegrationTest.java @@ -29,6 +29,7 @@ import java.util.UUID; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutorService; +import java.util.concurrent.ScheduledExecutorService; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.mockito.Mockito.mock; @@ -68,6 +69,7 @@ void setUp() { mock(RegistrationRecoveryPasswordsManager.class), mock(ClientPublicKeysManager.class), mock(ExecutorService.class), + mock(ScheduledExecutorService.class), Clock.systemUTC(), "link-device-secret".getBytes(StandardCharsets.UTF_8), mock(DynamicConfigurationManager.class)); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java index 41ac05f0a..442064d38 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java @@ -54,6 +54,7 @@ import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionException; import java.util.concurrent.Executor; +import java.util.concurrent.ScheduledExecutorService; import java.util.function.Consumer; import java.util.function.Supplier; import java.util.stream.Collectors; @@ -262,6 +263,7 @@ void setup() throws InterruptedException { registrationRecoveryPasswordsManager, clientPublicKeysManager, mock(Executor.class), + mock(ScheduledExecutorService.class), CLOCK, LINK_DEVICE_SECRET, dynamicConfigurationManager); @@ -1537,6 +1539,21 @@ void testSetUsernameViaUpdate() { assertThrows(AssertionError.class, () -> accountsManager.update(account, a -> a.setUsernameHash(USERNAME_HASH_1))); } + @Test + void testOnlyPrimaryCanWaitForDeviceLinked() { + final Device primaryDevice = new Device(); + primaryDevice.setId(Device.PRIMARY_ID); + + final Device linkedDevice = new Device(); + linkedDevice.setId((byte) (Device.PRIMARY_ID + 1)); + + final Account account = AccountsHelper.generateTestAccount("+14152222222", List.of(primaryDevice, linkedDevice)); + + assertThrows(IllegalArgumentException.class, + () -> accountsManager.waitForNewLinkedDevice(account.getUuid(), linkedDevice, "", Duration.ofSeconds(1))); + + } + @Test void testJsonRoundTripSerialization() throws Exception { String originalJson; diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerUsernameIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerUsernameIntegrationTest.java index 306d618bf..30b8eb90e 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerUsernameIntegrationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerUsernameIntegrationTest.java @@ -157,6 +157,7 @@ private void buildAccountsManager(final int initialWidth, int discriminatorMaxWi mock(RegistrationRecoveryPasswordsManager.class), mock(ClientPublicKeysManager.class), Executors.newSingleThreadExecutor(), + Executors.newSingleThreadScheduledExecutor(), mock(Clock.class), "link-device-secret".getBytes(StandardCharsets.UTF_8), dynamicConfigurationManager); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AddRemoveDeviceIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AddRemoveDeviceIntegrationTest.java index 9b0991bbf..a4d2ef779 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AddRemoveDeviceIntegrationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AddRemoveDeviceIntegrationTest.java @@ -7,15 +7,14 @@ import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyByte; +import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; import com.google.i18n.phonenumbers.PhoneNumberUtil; import java.nio.charset.StandardCharsets; -import java.time.Clock; import java.time.Duration; import java.time.Instant; -import java.time.ZoneId; import java.util.Optional; import java.util.Set; import java.util.UUID; @@ -23,11 +22,15 @@ import java.util.concurrent.CompletionException; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; import org.junit.jupiter.api.extension.RegisterExtension; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.CsvSource; import org.signal.libsignal.protocol.ecc.Curve; import org.signal.libsignal.protocol.ecc.ECKeyPair; import org.whispersystems.textsecuregcm.auth.DisconnectionRequestManager; @@ -42,6 +45,7 @@ import org.whispersystems.textsecuregcm.tests.util.AccountsHelper; import org.whispersystems.textsecuregcm.tests.util.KeysHelper; import org.whispersystems.textsecuregcm.util.Pair; +import org.whispersystems.textsecuregcm.util.TestClock; public class AddRemoveDeviceIntegrationTest { @@ -67,14 +71,14 @@ public class AddRemoveDeviceIntegrationTest { @RegisterExtension static final RedisServerExtension PUBSUB_SERVER_EXTENSION = RedisServerExtension.builder().build(); - private static final Clock CLOCK = Clock.fixed(Instant.now(), ZoneId.systemDefault()); - private ExecutorService accountLockExecutor; + private ScheduledExecutorService messagePollExecutor; private KeysManager keysManager; private ClientPublicKeysManager clientPublicKeysManager; private MessagesManager messagesManager; private AccountsManager accountsManager; + private TestClock clock; @BeforeEach void setUp() { @@ -84,6 +88,8 @@ void setUp() { final DynamicConfiguration dynamicConfiguration = mock(DynamicConfiguration.class); when(dynamicConfigurationManager.getConfiguration()).thenReturn(dynamicConfiguration); + clock = TestClock.pinned(Instant.now()); + keysManager = new KeysManager( DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(), DynamoDbExtensionSchema.Tables.EC_KEYS.tableName(), @@ -106,6 +112,7 @@ void setUp() { DynamoDbExtensionSchema.Tables.USED_LINK_DEVICE_TOKENS.tableName()); accountLockExecutor = Executors.newSingleThreadExecutor(); + messagePollExecutor = mock(ScheduledExecutorService.class); final AccountLockManager accountLockManager = new AccountLockManager(DYNAMO_DB_EXTENSION.getDynamoDbClient(), DynamoDbExtensionSchema.Tables.DELETED_ACCOUNTS_LOCK.tableName()); @@ -155,7 +162,8 @@ void setUp() { registrationRecoveryPasswordsManager, clientPublicKeysManager, accountLockExecutor, - CLOCK, + messagePollExecutor, + clock, "link-device-secret".getBytes(StandardCharsets.UTF_8), dynamicConfigurationManager); @@ -210,10 +218,15 @@ void addDevice() throws InterruptedException { final byte addedDeviceId = updatedAccountAndDevice.second().getId(); - assertTrue(keysManager.getEcSignedPreKey(updatedAccountAndDevice.first().getUuid(), addedDeviceId).join().isPresent()); - assertTrue(keysManager.getEcSignedPreKey(updatedAccountAndDevice.first().getPhoneNumberIdentifier(), addedDeviceId).join().isPresent()); + assertTrue( + keysManager.getEcSignedPreKey(updatedAccountAndDevice.first().getUuid(), addedDeviceId).join().isPresent()); + assertTrue( + keysManager.getEcSignedPreKey(updatedAccountAndDevice.first().getPhoneNumberIdentifier(), addedDeviceId).join() + .isPresent()); assertTrue(keysManager.getLastResort(updatedAccountAndDevice.first().getUuid(), addedDeviceId).join().isPresent()); - assertTrue(keysManager.getLastResort(updatedAccountAndDevice.first().getPhoneNumberIdentifier(), addedDeviceId).join().isPresent()); + assertTrue( + keysManager.getLastResort(updatedAccountAndDevice.first().getPhoneNumberIdentifier(), addedDeviceId).join() + .isPresent()); } @Test @@ -317,15 +330,18 @@ void removeDevice() throws InterruptedException { assertEquals(1, updatedAccount.getDevices().size()); assertFalse(keysManager.getEcSignedPreKey(updatedAccount.getUuid(), addedDeviceId).join().isPresent()); - assertFalse(keysManager.getEcSignedPreKey(updatedAccount.getPhoneNumberIdentifier(), addedDeviceId).join().isPresent()); + assertFalse( + keysManager.getEcSignedPreKey(updatedAccount.getPhoneNumberIdentifier(), addedDeviceId).join().isPresent()); assertFalse(keysManager.getLastResort(updatedAccount.getUuid(), addedDeviceId).join().isPresent()); assertFalse(keysManager.getLastResort(updatedAccount.getPhoneNumberIdentifier(), addedDeviceId).join().isPresent()); assertFalse(clientPublicKeysManager.findPublicKey(updatedAccount.getUuid(), addedDeviceId).join().isPresent()); assertTrue(keysManager.getEcSignedPreKey(updatedAccount.getUuid(), Device.PRIMARY_ID).join().isPresent()); - assertTrue(keysManager.getEcSignedPreKey(updatedAccount.getPhoneNumberIdentifier(), Device.PRIMARY_ID).join().isPresent()); + assertTrue( + keysManager.getEcSignedPreKey(updatedAccount.getPhoneNumberIdentifier(), Device.PRIMARY_ID).join().isPresent()); assertTrue(keysManager.getLastResort(updatedAccount.getUuid(), Device.PRIMARY_ID).join().isPresent()); - assertTrue(keysManager.getLastResort(updatedAccount.getPhoneNumberIdentifier(), Device.PRIMARY_ID).join().isPresent()); + assertTrue( + keysManager.getLastResort(updatedAccount.getPhoneNumberIdentifier(), Device.PRIMARY_ID).join().isPresent()); assertTrue(clientPublicKeysManager.findPublicKey(updatedAccount.getUuid(), Device.PRIMARY_ID).join().isPresent()); } @@ -371,21 +387,27 @@ void removeDevicePartialFailure() throws InterruptedException { final Account retrievedAccount = accountsManager.getByAccountIdentifierAsync(aci).join().orElseThrow(); - clientPublicKeysManager.setPublicKey(retrievedAccount, Device.PRIMARY_ID, Curve.generateKeyPair().getPublicKey()).join(); - clientPublicKeysManager.setPublicKey(retrievedAccount, addedDeviceId, Curve.generateKeyPair().getPublicKey()).join(); + clientPublicKeysManager.setPublicKey(retrievedAccount, Device.PRIMARY_ID, Curve.generateKeyPair().getPublicKey()) + .join(); + clientPublicKeysManager.setPublicKey(retrievedAccount, addedDeviceId, Curve.generateKeyPair().getPublicKey()) + .join(); assertEquals(2, retrievedAccount.getDevices().size()); assertTrue(keysManager.getEcSignedPreKey(retrievedAccount.getUuid(), addedDeviceId).join().isPresent()); - assertTrue(keysManager.getEcSignedPreKey(retrievedAccount.getPhoneNumberIdentifier(), addedDeviceId).join().isPresent()); + assertTrue( + keysManager.getEcSignedPreKey(retrievedAccount.getPhoneNumberIdentifier(), addedDeviceId).join().isPresent()); assertTrue(keysManager.getLastResort(retrievedAccount.getUuid(), addedDeviceId).join().isPresent()); - assertTrue(keysManager.getLastResort(retrievedAccount.getPhoneNumberIdentifier(), addedDeviceId).join().isPresent()); + assertTrue( + keysManager.getLastResort(retrievedAccount.getPhoneNumberIdentifier(), addedDeviceId).join().isPresent()); assertTrue(clientPublicKeysManager.findPublicKey(retrievedAccount.getUuid(), addedDeviceId).join().isPresent()); assertTrue(keysManager.getEcSignedPreKey(retrievedAccount.getUuid(), Device.PRIMARY_ID).join().isPresent()); - assertTrue(keysManager.getEcSignedPreKey(retrievedAccount.getPhoneNumberIdentifier(), Device.PRIMARY_ID).join().isPresent()); + assertTrue(keysManager.getEcSignedPreKey(retrievedAccount.getPhoneNumberIdentifier(), Device.PRIMARY_ID).join() + .isPresent()); assertTrue(keysManager.getLastResort(retrievedAccount.getUuid(), Device.PRIMARY_ID).join().isPresent()); - assertTrue(keysManager.getLastResort(retrievedAccount.getPhoneNumberIdentifier(), Device.PRIMARY_ID).join().isPresent()); + assertTrue( + keysManager.getLastResort(retrievedAccount.getPhoneNumberIdentifier(), Device.PRIMARY_ID).join().isPresent()); assertTrue(clientPublicKeysManager.findPublicKey(retrievedAccount.getUuid(), Device.PRIMARY_ID).join().isPresent()); } @@ -403,11 +425,15 @@ void waitForNewLinkedDevice() throws InterruptedException { final String linkDeviceToken = accountsManager.generateLinkDeviceToken(account.getIdentifier(IdentityType.ACI)); final String linkDeviceTokenIdentifier = AccountsManager.getLinkDeviceTokenIdentifier(linkDeviceToken); - final CompletableFuture> displacedFuture = - accountsManager.waitForNewLinkedDevice(linkDeviceTokenIdentifier, Duration.ofSeconds(5)); + final CompletableFuture> displacedFuture = accountsManager.waitForNewLinkedDevice( + account.getUuid(), account.getPrimaryDevice(), + linkDeviceTokenIdentifier, Duration.ofSeconds(5)); + when(messagesManager.getEarliestUndeliveredTimestampForDevice(account.getUuid(), account.getPrimaryDevice())) + .thenReturn(CompletableFuture.completedFuture(Optional.empty())); final CompletableFuture> activeFuture = - accountsManager.waitForNewLinkedDevice(linkDeviceTokenIdentifier, Duration.ofSeconds(5)); + accountsManager.waitForNewLinkedDevice(account.getUuid(), account.getPrimaryDevice(), linkDeviceTokenIdentifier, + Duration.ofSeconds(5)); assertEquals(Optional.empty(), displacedFuture.join()); @@ -470,8 +496,11 @@ void waitForNewLinkedDeviceAlreadyAdded() throws InterruptedException { linkDeviceToken) .join(); - final CompletableFuture> linkedDeviceFuture = - accountsManager.waitForNewLinkedDevice(linkDeviceTokenIdentifier, Duration.ofMinutes(1)); + when(messagesManager.getEarliestUndeliveredTimestampForDevice(account.getUuid(), account.getPrimaryDevice())) + .thenReturn(CompletableFuture.completedFuture(Optional.empty())); + + final CompletableFuture> linkedDeviceFuture = accountsManager.waitForNewLinkedDevice( + account.getUuid(), account.getPrimaryDevice(), linkDeviceTokenIdentifier, Duration.ofMinutes(1)); final Optional maybeDeviceInfo = linkedDeviceFuture.join(); @@ -483,15 +512,121 @@ void waitForNewLinkedDeviceAlreadyAdded() throws InterruptedException { } @Test - void waitForNewLinkedDeviceTimeout() { + void waitForNewLinkedDeviceTimeout() throws InterruptedException { + final String number = PhoneNumberUtil.getInstance().format( + PhoneNumberUtil.getInstance().getExampleNumber("US"), + PhoneNumberUtil.PhoneNumberFormat.E164); + final Account account = AccountsHelper.createAccount(accountsManager, number); + final String linkDeviceToken = accountsManager.generateLinkDeviceToken(UUID.randomUUID()); final String linkDeviceTokenIdentifier = AccountsManager.getLinkDeviceTokenIdentifier(linkDeviceToken); - final CompletableFuture> linkedDeviceFuture = - accountsManager.waitForNewLinkedDevice(linkDeviceTokenIdentifier, Duration.ofMillis(1)); + final CompletableFuture> linkedDeviceFuture = accountsManager.waitForNewLinkedDevice( + account.getUuid(), account.getPrimaryDevice(), linkDeviceTokenIdentifier, Duration.ofMillis(1)); final Optional maybeDeviceInfo = linkedDeviceFuture.join(); assertTrue(maybeDeviceInfo.isEmpty()); } + + @ParameterizedTest + @CsvSource({ + "10_000,1000,,false", // no pending messages + "10_000,1000,1000,true", // pending message at device creation + "10_000,1000,5999,true", // pending message right before device creation + fudge factor + "10_000,1000,6000,false", // pending message at device creation + fudge factor + "3000,5000,4000,false", // pending message after current time but before device creation + }) + void waitForMessageFetch(long currentTime, long deviceCreation, Long oldestMessage, boolean shouldWait) + throws InterruptedException { + final String number = PhoneNumberUtil.getInstance().format( + PhoneNumberUtil.getInstance().getExampleNumber("US"), + PhoneNumberUtil.PhoneNumberFormat.E164); + final ECKeyPair aciKeyPair = Curve.generateKeyPair(); + final ECKeyPair pniKeyPair = Curve.generateKeyPair(); + final Account account = AccountsHelper.createAccount(accountsManager, number); + + final String linkDeviceToken = accountsManager.generateLinkDeviceToken(UUID.randomUUID()); + final String linkDeviceTokenIdentifier = AccountsManager.getLinkDeviceTokenIdentifier(linkDeviceToken); + + clock.pin(Instant.ofEpochMilli(deviceCreation)); + final Pair updatedAccountAndDevice = + accountsManager.addDevice(account, new DeviceSpec( + "device-name".getBytes(StandardCharsets.UTF_8), + "password", + "OWT", + Set.of(), + 1, + 2, + true, + Optional.empty(), + Optional.empty(), + KeysHelper.signedECPreKey(1, aciKeyPair), + KeysHelper.signedECPreKey(2, pniKeyPair), + KeysHelper.signedKEMPreKey(3, aciKeyPair), + KeysHelper.signedKEMPreKey(4, pniKeyPair)), + linkDeviceToken) + .join(); + + assertEquals(updatedAccountAndDevice.second().getCreated(), deviceCreation); + + when(messagesManager.getEarliestUndeliveredTimestampForDevice(account.getUuid(), account.getPrimaryDevice())) + .thenReturn(CompletableFuture.completedFuture(Optional.ofNullable(oldestMessage).map(Instant::ofEpochMilli))); + + clock.pin(Instant.ofEpochMilli(currentTime)); + Duration timeout = shouldWait ? Duration.ofMillis(5) : Duration.ofMillis(1000); + Optional result = accountsManager.waitForNewLinkedDevice(account.getUuid(), + account.getPrimaryDevice(), linkDeviceTokenIdentifier, timeout).join(); + assertEquals(result.isEmpty(), shouldWait); + } + + // ThreadMode.SEPARATE_THREAD protects against hangs in the async calls, as this mode allows the test code to be + // preempted by the timeout check + @Timeout(value = 10, threadMode = Timeout.ThreadMode.SEPARATE_THREAD) + @Test + void waitForMessageFetchRetries() + throws InterruptedException { + final String number = PhoneNumberUtil.getInstance().format( + PhoneNumberUtil.getInstance().getExampleNumber("US"), + PhoneNumberUtil.PhoneNumberFormat.E164); + final ECKeyPair aciKeyPair = Curve.generateKeyPair(); + final ECKeyPair pniKeyPair = Curve.generateKeyPair(); + final Account account = AccountsHelper.createAccount(accountsManager, number); + + final String linkDeviceToken = accountsManager.generateLinkDeviceToken(UUID.randomUUID()); + final String linkDeviceTokenIdentifier = AccountsManager.getLinkDeviceTokenIdentifier(linkDeviceToken); + + clock.pin(Instant.ofEpochMilli(0)); + accountsManager.addDevice(account, new DeviceSpec( + "device-name".getBytes(StandardCharsets.UTF_8), + "password", + "OWT", + Set.of(), + 1, + 2, + true, + Optional.empty(), + Optional.empty(), + KeysHelper.signedECPreKey(1, aciKeyPair), + KeysHelper.signedECPreKey(2, pniKeyPair), + KeysHelper.signedKEMPreKey(3, aciKeyPair), + KeysHelper.signedKEMPreKey(4, pniKeyPair)), + linkDeviceToken) + .join(); + + when(messagesManager.getEarliestUndeliveredTimestampForDevice(account.getUuid(), account.getPrimaryDevice())) + // Has a message older than the message epoch + .thenReturn(CompletableFuture.completedFuture(Optional.of(Instant.ofEpochMilli(1000)))) + // The message was fetched + .thenReturn(CompletableFuture.completedFuture(Optional.empty())); + clock.pin(Instant.ofEpochMilli(10_000)); + // Run any scheduled job right away + when(messagePollExecutor.schedule(any(Runnable.class), anyLong(), any())).thenAnswer(x -> { + x.getArgument(0, Runnable.class).run(); + return null; + }); + Optional result = accountsManager.waitForNewLinkedDevice(account.getUuid(), + account.getPrimaryDevice(), linkDeviceTokenIdentifier, Duration.ofSeconds(10)).join(); + assertTrue(result.isPresent()); + } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheTest.java index d5d1817ed..ef785bf35 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheTest.java @@ -9,6 +9,7 @@ import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertTimeoutPreemptively; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.ArgumentMatchers.any; @@ -145,7 +146,7 @@ void testDoubleInsertGuid() { messagesCache.insert(duplicateGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, duplicateMessage); messagesCache.insert(duplicateGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, duplicateMessage); - assertEquals(1, messagesCache.getAllMessages(DESTINATION_UUID, DESTINATION_DEVICE_ID, 0) + assertEquals(1, messagesCache.getAllMessages(DESTINATION_UUID, DESTINATION_DEVICE_ID, 0, 10) .count() .blockOptional() .orElse(0L)); @@ -225,6 +226,31 @@ void testHasMessagesAsync() { assertTrue(messagesCache.hasMessagesAsync(DESTINATION_UUID, DESTINATION_DEVICE_ID).join()); } + @Test + void getOldestTimestamp() { + final int messageCount = 100; + + final List expectedMessages = new ArrayList<>(messageCount); + + long expectedOldestTimestamp = serialTimestamp; + for (int i = 0; i < messageCount; i++) { + final UUID messageGuid = UUID.randomUUID(); + final MessageProtos.Envelope message = generateRandomMessage(messageGuid, i % 2 == 0); + messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message); + assertEquals(expectedOldestTimestamp, + messagesCache.getEarliestUndeliveredTimestamp(DESTINATION_UUID, DESTINATION_DEVICE_ID).block()); + expectedMessages.add(message); + } + + for (final MessageProtos.Envelope message : expectedMessages) { + assertEquals(expectedOldestTimestamp, + messagesCache.getEarliestUndeliveredTimestamp(DESTINATION_UUID, DESTINATION_DEVICE_ID).block()); + messagesCache.remove(DESTINATION_UUID, DESTINATION_DEVICE_ID, UUID.fromString(message.getServerGuid())).join(); + expectedOldestTimestamp += 1; + } + assertNull(messagesCache.getEarliestUndeliveredTimestamp(DESTINATION_UUID, DESTINATION_DEVICE_ID).block()); + } + @ParameterizedTest @ValueSource(booleans = {true, false}) void testGetMessages(final boolean sealedSender) throws Exception { @@ -236,7 +262,6 @@ void testGetMessages(final boolean sealedSender) throws Exception { final UUID messageGuid = UUID.randomUUID(); final MessageProtos.Envelope message = generateRandomMessage(messageGuid, sealedSender); messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message); - expectedMessages.add(message); } @@ -322,7 +347,7 @@ void testGetMessagesPublisher(final boolean expectStale) throws Exception { .get(5, TimeUnit.SECONDS); final List messages = messagesCache.getAllMessages(DESTINATION_UUID, - DESTINATION_DEVICE_ID, 0) + DESTINATION_DEVICE_ID, 0, 10) .collectList() .toFuture().get(5, TimeUnit.SECONDS); @@ -655,7 +680,7 @@ void testGetAllMessagesLimitsAndBackpressure() { .thenReturn(Flux.from(emptyFinalPagePublisher)) .thenReturn(Flux.empty()); - final Flux allMessages = messagesCache.getAllMessages(UUID.randomUUID(), Device.PRIMARY_ID, 0); + final Flux allMessages = messagesCache.getAllMessages(UUID.randomUUID(), Device.PRIMARY_ID, 0, 10); // Why initialValue = 3? // 1. messagesCache.getAllMessages() above produces the first call diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesManagerTest.java index ef4e66157..a58cd1963 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesManagerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesManagerTest.java @@ -14,6 +14,8 @@ import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.when; +import java.time.Instant; +import java.util.Optional; import java.util.UUID; import java.util.concurrent.CompletableFuture; import java.util.concurrent.Executors; @@ -21,6 +23,7 @@ import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.CsvSource; import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; +import reactor.core.publisher.Mono; class MessagesManagerTest { @@ -77,4 +80,28 @@ void mayHaveMessages(final boolean hasCachedMessages, final boolean hasPersisted assertEquals(expectMayHaveMessages, messagesManager.mayHaveMessages(accountIdentifier, device).join()); } + + @ParameterizedTest + @CsvSource({ + ",,", + "1,,1", + ",1,1", + "2,1,1", + "1,2,2" + }) + public void oldestMessageTimestamp(Long oldestCached, Long oldestPersisted, Long expected) { + final UUID accountIdentifier = UUID.randomUUID(); + final Device device = mock(Device.class); + when(device.getId()).thenReturn(Device.PRIMARY_ID); + + when(messagesCache.getEarliestUndeliveredTimestamp(accountIdentifier, Device.PRIMARY_ID)) + .thenReturn(oldestCached == null ? Mono.empty() : Mono.just(oldestCached)); + when(messagesDynamoDb.load(accountIdentifier, device, 1)) + .thenReturn(oldestPersisted == null + ? Mono.empty() + : Mono.just(Envelope.newBuilder().setServerTimestamp(oldestPersisted).build())); + final Optional earliest = + messagesManager.getEarliestUndeliveredTimestampForDevice(accountIdentifier, device).join(); + assertEquals(Optional.ofNullable(expected).map(Instant::ofEpochMilli), earliest); + } }