Skip to content

Commit

Permalink
Wait for messages in waitForNewLinkedDevice
Browse files Browse the repository at this point in the history
  • Loading branch information
ravi-signal committed Nov 11, 2024
1 parent 3288d3d commit 81f3ba1
Show file tree
Hide file tree
Showing 16 changed files with 374 additions and 60 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -559,6 +559,8 @@ public void run(WhisperServerConfiguration config, Environment environment) thro
.scheduledExecutorService(name(getClass(), "subscriptionProcessorRetry-%d")).threads(1).build();
ScheduledExecutorService cloudflareTurnRetryExecutor = environment.lifecycle()
.scheduledExecutorService(name(getClass(), "cloudflareTurnRetry-%d")).threads(1).build();
ScheduledExecutorService messagePollExecutor = environment.lifecycle()
.scheduledExecutorService(name(getClass(), "messagePollExecutor-%d")).threads(1).build();

final ManagedNioEventLoopGroup dnsResolutionEventLoopGroup = new ManagedNioEventLoopGroup();
final DnsNameResolver cloudflareDnsResolver = new DnsNameResolverBuilder(dnsResolutionEventLoopGroup.next())
Expand Down Expand Up @@ -620,7 +622,7 @@ public void run(WhisperServerConfiguration config, Environment environment) thro
AccountsManager accountsManager = new AccountsManager(accounts, phoneNumberIdentifiers, cacheCluster,
pubsubClient, accountLockManager, keysManager, messagesManager, profilesManager,
secureStorageClient, secureValueRecovery2Client, disconnectionRequestManager, webSocketConnectionEventManager,
registrationRecoveryPasswordsManager, clientPublicKeysManager, accountLockExecutor,
registrationRecoveryPasswordsManager, clientPublicKeysManager, accountLockExecutor, messagePollExecutor,
clock, config.getLinkDeviceSecretConfiguration().secret().value(), dynamicConfigurationManager);
RemoteConfigsManager remoteConfigsManager = new RemoteConfigsManager(remoteConfigs);
APNSender apnSender = new APNSender(apnSenderExecutor, config.getApnConfiguration());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -362,9 +362,9 @@ The amount of time (in seconds) to wait for a response. If the expected device i
linkedDeviceListenerCounter.incrementAndGet();

final Timer.Sample sample = Timer.start();

try {
return accounts.waitForNewLinkedDevice(tokenIdentifier, Duration.ofSeconds(timeoutSeconds))
return accounts.waitForNewLinkedDevice(authenticatedDevice.getAccount().getUuid(),
authenticatedDevice.getAuthenticatedDevice(), tokenIdentifier, Duration.ofSeconds(timeoutSeconds))
.thenApply(maybeDeviceInfo -> maybeDeviceInfo
.map(deviceInfo -> Response.status(Response.Status.OK).entity(deviceInfo).build())
.orElseGet(() -> Response.status(Response.Status.NO_CONTENT).build()))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
import java.util.concurrent.CompletionStage;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Executor;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.BiConsumer;
Expand Down Expand Up @@ -130,6 +131,7 @@ public class AccountsManager extends RedisPubSubAdapter<String, String> implemen
private final RegistrationRecoveryPasswordsManager registrationRecoveryPasswordsManager;
private final ClientPublicKeysManager clientPublicKeysManager;
private final Executor accountLockExecutor;
private final ScheduledExecutorService messagesPollExecutor;
private final Clock clock;
private final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager;

Expand Down Expand Up @@ -163,6 +165,9 @@ public class AccountsManager extends RedisPubSubAdapter<String, String> implemen
private static final ObjectWriter ACCOUNT_REDIS_JSON_WRITER = SystemMapper.jsonMapper()
.writer(SystemMapper.excludingField(Account.class, List.of("uuid")));

// Delay between successive checks of the linking device's pending messages
private static final Duration MESSAGE_POLL_INTERVAL = Duration.ofSeconds(1);
// Margin added to a linked device's creation timestamp to get a worst-case estimate of when it was linked
private static final Duration MAX_SERVER_CLOCK_DRIFT = Duration.ofSeconds(5);

// An account that's used at least daily will get reset in the cache at least once per day when its "last seen"
// timestamp updates; expiring entries after two days will help clear out "zombie" cache entries that are read
// frequently (e.g. the account is in an active group and receives messages frequently), but aren't actively used by
Expand Down Expand Up @@ -209,6 +214,7 @@ public AccountsManager(final Accounts accounts,
final RegistrationRecoveryPasswordsManager registrationRecoveryPasswordsManager,
final ClientPublicKeysManager clientPublicKeysManager,
final Executor accountLockExecutor,
final ScheduledExecutorService messagesPollExecutor,
final Clock clock,
final byte[] linkDeviceSecret,
final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager) {
Expand All @@ -227,6 +233,7 @@ public AccountsManager(final Accounts accounts,
this.registrationRecoveryPasswordsManager = requireNonNull(registrationRecoveryPasswordsManager);
this.clientPublicKeysManager = clientPublicKeysManager;
this.accountLockExecutor = accountLockExecutor;
this.messagesPollExecutor = messagesPollExecutor;
this.clock = requireNonNull(clock);
this.dynamicConfigurationManager = dynamicConfigurationManager;

Expand Down Expand Up @@ -1428,20 +1435,90 @@ private CompletableFuture<Void> redisDeleteAsync(final Account account) {
.thenRun(Util.NOOP);
}

/**
 * Waits for a device to be linked with the given device-linking token and then, if a device was linked, waits for
 * the linking device to fetch messages that are older than the estimated time of linking.
 *
 * @param accountIdentifier the account identifier of the device doing the linking
 * @param linkingDevice the device doing the linking; must be a primary device
 * @param linkDeviceTokenIdentifier the URL-safe base64 encoding of the SHA-256 hash of a device-linking token
 * @param timeout the total time to wait; it bounds both the wait for the link and the wait for older messages
 *
 * @return a future that yields the newly linked device's info, or empty if no device was reported by
 * {@code waitForPubSubKey} (presumably because the timeout elapsed) or if the deadline passed before the linking
 * device fetched its older messages; the future fails with an {@link IllegalArgumentException} if the decoded token
 * identifier does not have the length of a SHA-256 hash
 *
 * @throws IllegalArgumentException if {@code linkingDevice} is not a primary device
 */
public CompletableFuture<Optional<DeviceInfo>> waitForNewLinkedDevice(final String linkDeviceTokenIdentifier, final Duration timeout) {
public CompletableFuture<Optional<DeviceInfo>> waitForNewLinkedDevice(
final UUID accountIdentifier,
final Device linkingDevice,
final String linkDeviceTokenIdentifier,
final Duration timeout) {
// NOTE(review): this precondition is reported by throwing synchronously, while the token-shape check below is
// reported through a failed future; callers have to handle both styles.
if (!linkingDevice.isPrimary()) {
throw new IllegalArgumentException("Only primary devices can link devices");
}

// Unbeknownst to callers but beknownst to us, the "link device token identifier" is the base64/url-encoded SHA256
// hash of a device-linking token. Before we use the string anywhere, make sure it's the right "shape" for a hash.
// NOTE(review): Base64.Decoder#decode itself throws IllegalArgumentException for input that is not valid base64,
// so a malformed identifier surfaces as a synchronous exception rather than a failed future - confirm that callers
// expect this.
if (Base64.getUrlDecoder().decode(linkDeviceTokenIdentifier).length != SHA256_HASH_LENGTH) {
return CompletableFuture.failedFuture(new IllegalArgumentException("Invalid token identifier"));
}

return waitForPubSubKey(waitForDeviceFuturesByTokenIdentifier,
linkDeviceTokenIdentifier,
getLinkedDeviceKey(linkDeviceTokenIdentifier),
timeout,
this::handleDeviceAdded);
// Fix the overall deadline up front so that the time spent waiting for the link counts against the time available
// for waiting on messages afterwards
final Instant deadline = clock.instant().plus(timeout);
final CompletableFuture<Optional<DeviceInfo>> deviceAdded = waitForPubSubKey(waitForDeviceFuturesByTokenIdentifier,
linkDeviceTokenIdentifier, getLinkedDeviceKey(linkDeviceTokenIdentifier), timeout, this::handleDeviceAdded);

return deviceAdded.thenCompose(maybeDeviceInfo -> maybeDeviceInfo.map(deviceInfo -> {
// The device finished linking, we now want to make sure the client has fetched messages that could
// have come in before the device's mailbox was set up.

// A worst case estimate of the wall clock time at which the linked device was added to the account record
Instant deviceLinked = Instant.ofEpochMilli(deviceInfo.created()).plus(MAX_SERVER_CLOCK_DRIFT);

Instant now = clock.instant();

// We know at `now` the device finished linking, so if we waited for all the messages before now it would be
// sufficient. However, now might be much later than the device was linked, so we don't want to force
// the client to wait for messages that are past our worst case estimate of when the device was linked
Instant messageEpoch = Collections.min(List.of(deviceLinked, now));

// We assume that any message with a timestamp after the messageEpoch made it into the linked device's queues
return waitForPreLinkMessagesToBeFetched(accountIdentifier, linkingDevice, deviceInfo, messageEpoch, deadline);
})
// No device was linked; pass the empty result through unchanged
.orElseGet(() -> CompletableFuture.completedFuture(maybeDeviceInfo)));
}

/**
 * Polls the linking device's pending messages until none of them has a timestamp earlier than the provided
 * messageEpoch, or until another poll would no longer fit before the deadline.
 *
 * @param aci              The account identifier of the device doing the linking
 * @param linkingDevice    The device doing the linking, whose pending messages are inspected
 * @param linkedDeviceInfo Information about the newly linked device; this is the value returned on success
 * @param messageEpoch     The cutoff time: only pending messages with a timestamp before it are waited for
 * @param deadline         The time at which the method will stop waiting
 * @return A future that yields the linked device's info once the linking device has no pending message with a
 * timestamp earlier than the provided messageEpoch, or an empty value if waiting for one more poll interval would
 * go past the deadline.
 */
private CompletableFuture<Optional<DeviceInfo>> waitForPreLinkMessagesToBeFetched(
    final UUID aci,
    final Device linkingDevice,
    final DeviceInfo linkedDeviceInfo,
    final Instant messageEpoch,
    final Instant deadline) {
  return messagesManager.getEarliestUndeliveredTimestampForDevice(aci, linkingDevice)
      .thenCompose(maybeOldestPending -> {

        // An empty result means nothing is pending at all, which also counts as "no old messages"
        final boolean oldMessagesPending = maybeOldestPending
            .filter(oldestPending -> oldestPending.isBefore(messageEpoch))
            .isPresent();

        if (oldMessagesPending) {
          final Instant nextPollTime = clock.instant().plus(MESSAGE_POLL_INTERVAL);
          if (nextPollTime.isAfter(deadline)) {
            // Another poll would land past the deadline, so give up now
            return CompletableFuture.completedFuture(Optional.empty());
          }

          // An executor that runs whatever it is handed on the poll executor, one poll interval from now
          final Executor afterPollInterval = task ->
              messagesPollExecutor.schedule(task, MESSAGE_POLL_INTERVAL.toMillis(), TimeUnit.MILLISECONDS);

          // Re-run this check after the delay and flatten the nested future it produces
          return CompletableFuture
              .supplyAsync(
                  () -> waitForPreLinkMessagesToBeFetched(aci, linkingDevice, linkedDeviceInfo, messageEpoch, deadline),
                  afterPollInterval)
              .thenCompose(Function.identity());
        }

        // Everything older than the messageEpoch has been fetched
        return CompletableFuture.completedFuture(Optional.of(linkedDeviceInfo));
      });
}


private void handleDeviceAdded(final CompletableFuture<Optional<DeviceInfo>> future, final String deviceInfoJson) {
try {
future.complete(Optional.of(SystemMapper.jsonMapper().readValue(deviceInfoJson, DeviceInfo.class)));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ public Publisher<MessageProtos.Envelope> get(final UUID destinationUuid, final b
clock.millis() - MAX_EPHEMERAL_MESSAGE_DELAY.toMillis();

final Flux<MessageProtos.Envelope> allMessages = getAllMessages(destinationUuid, destinationDevice,
earliestAllowableEphemeralTimestamp)
earliestAllowableEphemeralTimestamp, PAGE_SIZE)
.publish()
// We expect exactly two subscribers to this base flux:
// 1. the websocket that delivers messages to clients
Expand All @@ -311,6 +311,12 @@ public Publisher<MessageProtos.Envelope> get(final UUID destinationUuid, final b
.tap(Micrometer.metrics(Metrics.globalRegistry));
}

/**
 * Looks up the server timestamp of the first message produced by {@code getAllMessages} for the given device.
 * <p>
 * Messages are requested with a page size of 1 and an ephemeral-timestamp cutoff of -1, since only the first
 * message is needed.
 *
 * @param destinationUuid   the account identifier whose queue is inspected
 * @param destinationDevice the id of the device whose queue is inspected
 * @return a Mono that emits the first message's server timestamp, or completes empty if no message is produced
 */
public Mono<Long> getEarliestUndeliveredTimestamp(final UUID destinationUuid, final byte destinationDevice) {
  final Flux<MessageProtos.Envelope> queuedMessages = getAllMessages(destinationUuid, destinationDevice, -1, 1);
  final Mono<MessageProtos.Envelope> firstMessage = queuedMessages.next();

  return firstMessage.map(envelope -> envelope.getServerTimestamp());
}

private static boolean isStaleEphemeralMessage(final MessageProtos.Envelope message,
long earliestAllowableTimestamp) {
return message.getEphemeral() && message.getClientTimestamp() < earliestAllowableTimestamp;
Expand All @@ -330,17 +336,17 @@ private void discardStaleEphemeralMessages(final UUID destinationUuid, final byt

@VisibleForTesting
Flux<MessageProtos.Envelope> getAllMessages(final UUID destinationUuid, final byte destinationDevice,
final long earliestAllowableEphemeralTimestamp) {
final long earliestAllowableEphemeralTimestamp, final int pageSize) {

// fetch messages by page
return getNextMessagePage(destinationUuid, destinationDevice, -1)
return getNextMessagePage(destinationUuid, destinationDevice, -1, pageSize)
.expand(queueItemsAndLastMessageId -> {
// expand() is breadth-first, so each page will be published in order
if (queueItemsAndLastMessageId.first().isEmpty()) {
return Mono.empty();
}

return getNextMessagePage(destinationUuid, destinationDevice, queueItemsAndLastMessageId.second());
return getNextMessagePage(destinationUuid, destinationDevice, queueItemsAndLastMessageId.second(), pageSize);
})
.limitRate(1)
// we want to ensure we don’t accidentally block the Lettuce/netty i/o executors
Expand Down Expand Up @@ -478,9 +484,9 @@ void removeRecipientViewFromMrmData(final List<byte[]> sharedMrmKeys, final Serv
}

private Mono<Pair<List<byte[]>, Long>> getNextMessagePage(final UUID destinationUuid, final byte destinationDevice,
long messageId) {
long messageId, int pageSize) {

return getItemsScript.execute(destinationUuid, destinationDevice, PAGE_SIZE, messageId)
return getItemsScript.execute(destinationUuid, destinationDevice, pageSize, messageId)
.map(queueItems -> {
logger.trace("Processing page: {}", messageId);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.Metrics;
import java.time.Instant;
import java.util.List;
import java.util.Optional;
import java.util.UUID;
Expand Down Expand Up @@ -200,6 +201,16 @@ public int persistMessages(
return messagesRemovedFromCache;
}

/**
 * Finds the timestamp of the oldest message still waiting for the given device.
 * <p>
 * The persisted store is consulted first; the cache is only used if the persisted store yields nothing.
 *
 * @param destinationUuid   the account identifier of the destination
 * @param destinationDevice the destination device
 * @return a future that yields the oldest pending message's server timestamp (interpreted as epoch milliseconds),
 * or empty if neither the persisted store nor the cache produced a message
 */
public CompletableFuture<Optional<Instant>> getEarliestUndeliveredTimestampForDevice(UUID destinationUuid, Device destinationDevice) {
  // Ask the persisted layer for a single message and keep only its timestamp
  final Mono<Long> persistedTimestamp = Mono.from(messagesDynamoDb.load(destinationUuid, destinationDevice, 1))
      .map(Envelope::getServerTimestamp);

  // Fallback: the oldest message in the cache
  final Mono<Long> cachedTimestamp =
      messagesCache.getEarliestUndeliveredTimestamp(destinationUuid, destinationDevice.getId());

  return persistedTimestamp
      .switchIfEmpty(cachedTimestamp)
      .map(Instant::ofEpochMilli)
      .map(Optional::of)
      // Neither layer had a message; report that as an empty Optional rather than an empty Mono
      .switchIfEmpty(Mono.just(Optional.empty()))
      .toFuture();
}

/**
* Inserts the shared multi-recipient message payload to storage.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,8 @@ static CommandDependencies build(
.scheduledExecutorService(name(name, "remoteStorageRetry-%d")).threads(1).build();
ScheduledExecutorService storageServiceRetryExecutor = environment.lifecycle()
.scheduledExecutorService(name(name, "storageServiceRetry-%d")).threads(1).build();
ScheduledExecutorService messagePollExecutor = environment.lifecycle()
.scheduledExecutorService(name(name, "messagePollExecutor-%d")).threads(1).build();

ExternalServiceCredentialsGenerator storageCredentialsGenerator = SecureStorageController.credentialsGenerator(
configuration.getSecureStorageServiceConfiguration());
Expand Down Expand Up @@ -227,7 +229,7 @@ static CommandDependencies build(
AccountsManager accountsManager = new AccountsManager(accounts, phoneNumberIdentifiers, cacheCluster,
pubsubClient, accountLockManager, keys, messagesManager, profilesManager,
secureStorageClient, secureValueRecovery2Client, disconnectionRequestManager, webSocketConnectionEventManager,
registrationRecoveryPasswordsManager, clientPublicKeysManager, accountLockExecutor,
registrationRecoveryPasswordsManager, clientPublicKeysManager, accountLockExecutor, messagePollExecutor,
clock, configuration.getLinkDeviceSecretConfiguration().secret().value(), dynamicConfigurationManager);
RateLimiters rateLimiters = RateLimiters.createAndValidate(configuration.getLimitsConfiguration(),
dynamicConfigurationManager, rateLimitersCluster);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -919,7 +919,8 @@ void waitForLinkedDevice() {

final String tokenIdentifier = Base64.getUrlEncoder().withoutPadding().encodeToString(new byte[32]);

when(accountsManager.waitForNewLinkedDevice(eq(tokenIdentifier), any()))
when(accountsManager
.waitForNewLinkedDevice(eq(AuthHelper.VALID_UUID), eq(AuthHelper.VALID_DEVICE), eq(tokenIdentifier), any()))
.thenReturn(CompletableFuture.completedFuture(Optional.of(deviceInfo)));

try (final Response response = resources.getJerseyTest()
Expand All @@ -942,7 +943,8 @@ void waitForLinkedDevice() {
void waitForLinkedDeviceNoDeviceLinked() {
final String tokenIdentifier = Base64.getUrlEncoder().withoutPadding().encodeToString(new byte[32]);

when(accountsManager.waitForNewLinkedDevice(eq(tokenIdentifier), any()))
when(accountsManager
.waitForNewLinkedDevice(eq(AuthHelper.VALID_UUID), eq(AuthHelper.VALID_DEVICE), eq(tokenIdentifier), any()))
.thenReturn(CompletableFuture.completedFuture(Optional.empty()));

try (final Response response = resources.getJerseyTest()
Expand All @@ -959,7 +961,8 @@ void waitForLinkedDeviceNoDeviceLinked() {
void waitForLinkedDeviceBadTokenIdentifier() {
final String tokenIdentifier = Base64.getUrlEncoder().withoutPadding().encodeToString(new byte[32]);

when(accountsManager.waitForNewLinkedDevice(eq(tokenIdentifier), any()))
when(accountsManager
.waitForNewLinkedDevice(eq(AuthHelper.VALID_UUID), eq(AuthHelper.VALID_DEVICE), eq(tokenIdentifier), any()))
.thenReturn(CompletableFuture.failedFuture(new IllegalArgumentException()));

try (final Response response = resources.getJerseyTest()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;
import org.apache.commons.lang3.RandomStringUtils;
Expand Down Expand Up @@ -74,7 +75,7 @@ public class AccountCreationDeletionIntegrationTest {

private static final Clock CLOCK = Clock.fixed(Instant.now(), ZoneId.systemDefault());

private ExecutorService accountLockExecutor;
private ScheduledExecutorService executor;

private AccountsManager accountsManager;
private KeysManager keysManager;
Expand Down Expand Up @@ -113,12 +114,12 @@ void setUp() {
DynamoDbExtensionSchema.Tables.DELETED_ACCOUNTS.tableName(),
DynamoDbExtensionSchema.Tables.USED_LINK_DEVICE_TOKENS.tableName());

accountLockExecutor = Executors.newSingleThreadExecutor();
executor = Executors.newSingleThreadScheduledExecutor();

final AccountLockManager accountLockManager = new AccountLockManager(DYNAMO_DB_EXTENSION.getDynamoDbClient(),
DynamoDbExtensionSchema.Tables.DELETED_ACCOUNTS_LOCK.tableName());

clientPublicKeysManager = new ClientPublicKeysManager(clientPublicKeys, accountLockManager, accountLockExecutor);
clientPublicKeysManager = new ClientPublicKeysManager(clientPublicKeys, accountLockManager, executor);

final SecureStorageClient secureStorageClient = mock(SecureStorageClient.class);
when(secureStorageClient.deleteStoredData(any())).thenReturn(CompletableFuture.completedFuture(null));
Expand Down Expand Up @@ -164,18 +165,19 @@ void setUp() {
webSocketConnectionEventManager,
registrationRecoveryPasswordsManager,
clientPublicKeysManager,
accountLockExecutor,
executor,
executor,
CLOCK,
"link-device-secret".getBytes(StandardCharsets.UTF_8),
dynamicConfigurationManager);
}

/**
 * Shuts down the executor created in {@code setUp} after each test.
 */
@AfterEach
void tearDown() throws InterruptedException {
accountLockExecutor.shutdown();
executor.shutdown();

// Wait up to one second for termination; the boolean result is deliberately not checked
//noinspection ResultOfMethodCallIgnored
accountLockExecutor.awaitTermination(1, TimeUnit.SECONDS);
executor.awaitTermination(1, TimeUnit.SECONDS);
}

@CartesianTest
Expand Down
Loading

0 comments on commit 81f3ba1

Please sign in to comment.