Skip to content

Commit

Permalink
Add persistent timer utility backed by redis
Browse files Browse the repository at this point in the history
  • Loading branch information
ravi-signal committed Jan 29, 2025
1 parent 1446d1a commit 282bcf6
Show file tree
Hide file tree
Showing 5 changed files with 257 additions and 47 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,7 @@
import org.whispersystems.textsecuregcm.storage.MessagesDynamoDb;
import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.storage.OneTimeDonationsManager;
import org.whispersystems.textsecuregcm.storage.PersistentTimer;
import org.whispersystems.textsecuregcm.storage.PhoneNumberIdentifiers;
import org.whispersystems.textsecuregcm.storage.Profiles;
import org.whispersystems.textsecuregcm.storage.ProfilesManager;
Expand Down Expand Up @@ -1097,6 +1098,7 @@ protected void configureServer(final ServerBuilder<?> serverBuilder) {
log.info("Registered spam filter: {}", filter.getClass().getName());
});

final PersistentTimer persistentTimer = new PersistentTimer(rateLimitersCluster, clock);

final PhoneVerificationTokenManager phoneVerificationTokenManager = new PhoneVerificationTokenManager(
phoneNumberIdentifiers, registrationServiceClient, registrationRecoveryPasswordsManager, registrationRecoveryChecker);
Expand All @@ -1115,7 +1117,7 @@ protected void configureServer(final ServerBuilder<?> serverBuilder) {
config.getDeliveryCertificate().ecPrivateKey(), config.getDeliveryCertificate().expiresDays()),
zkAuthOperations, callingGenericZkSecretParams, clock),
new ChallengeController(rateLimitChallengeManager, challengeConstraintChecker),
new DeviceController(accountsManager, clientPublicKeysManager, rateLimiters, config.getMaxDevices()),
new DeviceController(accountsManager, clientPublicKeysManager, rateLimiters, persistentTimer, config.getMaxDevices()),
new DeviceCheckController(clock, backupAuthManager, appleDeviceCheckManager, rateLimiters,
config.getDeviceCheck().backupRedemptionLevel(),
config.getDeviceCheck().backupRedemptionDuration()),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import com.google.common.annotations.VisibleForTesting;
import com.google.common.net.HttpHeaders;
import io.dropwizard.auth.Auth;
import io.lettuce.core.RedisException;
import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Tags;
import io.micrometer.core.instrument.Timer;
Expand Down Expand Up @@ -81,6 +80,7 @@
import org.whispersystems.textsecuregcm.storage.DeviceCapability;
import org.whispersystems.textsecuregcm.storage.DeviceSpec;
import org.whispersystems.textsecuregcm.storage.LinkDeviceTokenAlreadyUsedException;
import org.whispersystems.textsecuregcm.storage.PersistentTimer;
import org.whispersystems.textsecuregcm.util.DeviceCapabilityAdapter;
import org.whispersystems.textsecuregcm.util.EnumMapUtil;
import org.whispersystems.textsecuregcm.util.ExceptionUtils;
Expand All @@ -100,6 +100,7 @@ public class DeviceController {
private final AccountsManager accounts;
private final ClientPublicKeysManager clientPublicKeysManager;
private final RateLimiters rateLimiters;
private final PersistentTimer persistentTimer;
private final Map<String, Integer> maxDeviceConfiguration;

private final EnumMap<ClientPlatform, AtomicInteger> linkedDeviceListenersByPlatform;
Expand All @@ -108,9 +109,11 @@ public class DeviceController {
private static final String LINKED_DEVICE_LISTENER_GAUGE_NAME =
MetricsUtil.name(DeviceController.class, "linkedDeviceListeners");

private static final String WAIT_FOR_LINKED_DEVICE_TIMER_NAMESPACE = "wait_for_linked_device";
private static final String WAIT_FOR_LINKED_DEVICE_TIMER_NAME =
MetricsUtil.name(DeviceController.class, "waitForLinkedDeviceDuration");

private static final String WAIT_FOR_TRANSFER_ARCHIVE_TIMER_NAMESPACE = "wait_for_transfer_archive";
private static final String WAIT_FOR_TRANSFER_ARCHIVE_TIMER_NAME =
MetricsUtil.name(DeviceController.class, "waitForTransferArchiveDuration");

Expand All @@ -124,11 +127,13 @@ public class DeviceController {
public DeviceController(final AccountsManager accounts,
final ClientPublicKeysManager clientPublicKeysManager,
final RateLimiters rateLimiters,
final PersistentTimer persistentTimer,
final Map<String, Integer> maxDeviceConfiguration) {

this.accounts = accounts;
this.clientPublicKeysManager = clientPublicKeysManager;
this.rateLimiters = rateLimiters;
this.persistentTimer = persistentTimer;
this.maxDeviceConfiguration = maxDeviceConfiguration;

linkedDeviceListenersByPlatform =
Expand Down Expand Up @@ -366,32 +371,30 @@ The amount of time (in seconds) to wait for a response. If the expected device i
@HeaderParam(HttpHeaders.USER_AGENT) String userAgent) {
final AtomicInteger linkedDeviceListenerCounter = getCounterForLinkedDeviceListeners(userAgent);
linkedDeviceListenerCounter.incrementAndGet();
final Timer.Sample sample = Timer.start();

return rateLimiters.getWaitForLinkedDeviceLimiter()
.validateAsync(authenticatedDevice.getAccount().getIdentifier(IdentityType.ACI))
.thenCompose(ignored -> accounts.waitForNewLinkedDevice(
authenticatedDevice.getAccount().getUuid(),
authenticatedDevice.getAuthenticatedDevice(),
tokenIdentifier,
Duration.ofSeconds(timeoutSeconds)))
.thenApply(maybeDeviceInfo -> maybeDeviceInfo
.map(deviceInfo -> Response.status(Response.Status.OK).entity(deviceInfo).build())
.orElseGet(() -> Response.status(Response.Status.NO_CONTENT).build()))
.exceptionally(ExceptionUtils.exceptionallyHandler(IllegalArgumentException.class,
e -> Response.status(Response.Status.BAD_REQUEST).build()))
.whenComplete((response, throwable) -> {
linkedDeviceListenerCounter.decrementAndGet();

if (response != null) {
sample.stop(Timer.builder(WAIT_FOR_LINKED_DEVICE_TIMER_NAME)
.publishPercentileHistogram(true)
.tags(Tags.of(UserAgentTagUtil.getPlatformTag(userAgent),
io.micrometer.core.instrument.Tag.of("deviceFound",
String.valueOf(response.getStatus() == Response.Status.OK.getStatusCode()))))
.register(Metrics.globalRegistry));
}
});
.thenCompose(ignored -> persistentTimer.start(WAIT_FOR_LINKED_DEVICE_TIMER_NAMESPACE, tokenIdentifier))
.thenCompose(sample -> accounts.waitForNewLinkedDevice(
authenticatedDevice.getAccount().getUuid(),
authenticatedDevice.getAuthenticatedDevice(),
tokenIdentifier,
Duration.ofSeconds(timeoutSeconds))
.thenApply(maybeDeviceInfo -> maybeDeviceInfo
.map(deviceInfo -> Response.status(Response.Status.OK).entity(deviceInfo).build())
.orElseGet(() -> Response.status(Response.Status.NO_CONTENT).build()))
.exceptionally(ExceptionUtils.exceptionallyHandler(IllegalArgumentException.class,
e -> Response.status(Response.Status.BAD_REQUEST).build()))
.whenComplete((response, throwable) -> {
linkedDeviceListenerCounter.decrementAndGet();

if (response != null && response.getStatus() == Response.Status.OK.getStatusCode()) {
sample.stop(Timer.builder(WAIT_FOR_LINKED_DEVICE_TIMER_NAME)
.publishPercentileHistogram(true)
.tags(Tags.of(UserAgentTagUtil.getPlatformTag(userAgent)))
.register(Metrics.globalRegistry));
}
}));
}

private AtomicInteger getCounterForLinkedDeviceListeners(final String userAgent) {
Expand Down Expand Up @@ -529,7 +532,8 @@ The amount of time (in seconds) to wait for a response. If a transfer archive fo
public CompletionStage<Void> recordTransferArchiveUploaded(@ReadOnly @Auth final AuthenticatedDevice authenticatedDevice,
@NotNull @Valid final TransferArchiveUploadedRequest transferArchiveUploadedRequest) {

return rateLimiters.getUploadTransferArchiveLimiter().validateAsync(authenticatedDevice.getAccount().getIdentifier(IdentityType.ACI))
return rateLimiters.getUploadTransferArchiveLimiter()
.validateAsync(authenticatedDevice.getAccount().getIdentifier(IdentityType.ACI))
.thenCompose(ignored -> accounts.recordTransferArchiveUpload(authenticatedDevice.getAccount(),
transferArchiveUploadedRequest.destinationDeviceId(),
Instant.ofEpochMilli(transferArchiveUploadedRequest.destinationDeviceCreated()),
Expand Down Expand Up @@ -568,30 +572,25 @@ The amount of time (in seconds) to wait for a response. If a transfer archive fo

@HeaderParam(HttpHeaders.USER_AGENT) @Nullable String userAgent) {

final Timer.Sample sample = Timer.start();

final String rateLimiterKey = authenticatedDevice.getAccount().getIdentifier(IdentityType.ACI) +
":" + authenticatedDevice.getAuthenticatedDevice().getId();

return rateLimiters.getWaitForTransferArchiveLimiter().validateAsync(rateLimiterKey)
.thenCompose(ignored -> accounts.waitForTransferArchive(authenticatedDevice.getAccount(),
authenticatedDevice.getAuthenticatedDevice(),
Duration.ofSeconds(timeoutSeconds)))
.thenApply(maybeTransferArchive -> maybeTransferArchive
.map(transferArchive -> Response.status(Response.Status.OK).entity(transferArchive).build())
.orElseGet(() -> Response.status(Response.Status.NO_CONTENT).build()))
.whenComplete((response, throwable) -> {
if (response == null) {
return;
}
sample.stop(Timer.builder(WAIT_FOR_TRANSFER_ARCHIVE_TIMER_NAME)
.publishPercentileHistogram(true)
.tags(Tags.of(
UserAgentTagUtil.getPlatformTag(userAgent),
io.micrometer.core.instrument.Tag.of(
"archiveUploaded",
String.valueOf(response.getStatus() == Response.Status.OK.getStatusCode()))))
.register(Metrics.globalRegistry));
});
.thenCompose(ignored -> persistentTimer.start(WAIT_FOR_TRANSFER_ARCHIVE_TIMER_NAMESPACE, rateLimiterKey))
.thenCompose(sample -> accounts.waitForTransferArchive(authenticatedDevice.getAccount(),
authenticatedDevice.getAuthenticatedDevice(),
Duration.ofSeconds(timeoutSeconds))
.thenApply(maybeTransferArchive -> maybeTransferArchive
.map(transferArchive -> Response.status(Response.Status.OK).entity(transferArchive).build())
.orElseGet(() -> Response.status(Response.Status.NO_CONTENT).build()))
.whenComplete((response, throwable) -> {
if (response != null && response.getStatus() == Response.Status.OK.getStatusCode()) {
sample.stop(Timer.builder(WAIT_FOR_TRANSFER_ARCHIVE_TIMER_NAME)
.publishPercentileHistogram(true)
.tags(Tags.of(UserAgentTagUtil.getPlatformTag(userAgent)))
.register(Metrics.globalRegistry));
}
}));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/

package org.whispersystems.textsecuregcm.storage;

import com.google.common.annotations.VisibleForTesting;
import io.lettuce.core.SetArgs;
import io.micrometer.core.instrument.Timer;
import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import javax.annotation.Nullable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClusterClient;
import org.whispersystems.textsecuregcm.util.Util;

/**
* Timers for operations that may span machines or requests and require a persistently stored timer start itme
*/
public class PersistentTimer {

private static final Logger logger = LoggerFactory.getLogger(PersistentTimer.class);

private static String TIMER_NAMESPACE = "persistent_timer";
@VisibleForTesting
static final Duration TIMER_TTL = Duration.ofHours(1);

private final FaultTolerantRedisClusterClient redisClient;
private final Clock clock;


public PersistentTimer(final FaultTolerantRedisClusterClient redisClient, final Clock clock) {
this.redisClient = redisClient;
this.clock = clock;
}

public class Sample {

private final Instant start;
private final String redisKey;

public Sample(final Instant start, final String redisKey) {
this.start = start;
this.redisKey = redisKey;
}

/**
* Stop the timer, recording the duration between now and the first call to start. This deletes the persistent timer.
*
* @param timer The micrometer timer to record the duration to
* @return A future that completes when the resources associated with the persistent timer have been destroyed
*/
public CompletableFuture<Void> stop(Timer timer) {
Duration duration = Duration.between(start, clock.instant());
timer.record(duration);
return redisClient.withCluster(connection -> connection.async().del(redisKey))
.toCompletableFuture()
.thenRun(Util.NOOP);
}
}

/**
* Start the timer if a timer with the provided namespaced key has not already been started, otherwise return the
* existing sample.
*
* @param namespace A namespace prefix to use for the timer
* @param key The unique key within the namespace that identifies the timer
* @return A future that completes with a {@link Sample} that can later be used to record the final duration.
*/
public CompletableFuture<Sample> start(final String namespace, final String key) {
final Instant now = clock.instant();
final String redisKey = redisKey(namespace, key);

return redisClient.withCluster(connection ->
connection.async().setGet(redisKey, String.valueOf(now.getEpochSecond()), SetArgs.Builder.nx().ex(TIMER_TTL)))
.toCompletableFuture()
.thenApply(serialized -> new Sample(parseStoredTimestamp(serialized).orElse(now), redisKey));
}

@VisibleForTesting
String redisKey(final String namespace, final String key) {
return String.format("%s::%s::%s", TIMER_NAMESPACE, namespace, key);
}

private static Optional<Instant> parseStoredTimestamp(final @Nullable String serialized) {
return Optional
.ofNullable(serialized)
.flatMap(s -> {
try {
return Optional.of(Long.parseLong(s));
} catch (NumberFormatException e) {
logger.warn("Failed to parse stored timestamp {}", s, e);
return Optional.empty();
}
})
.map(Instant::ofEpochSecond);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import static org.mockito.ArgumentMatchers.anyByte;
import static org.mockito.Mockito.anyString;
import static org.mockito.Mockito.clearInvocations;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
Expand Down Expand Up @@ -90,6 +89,7 @@
import org.whispersystems.textsecuregcm.storage.DeviceCapability;
import org.whispersystems.textsecuregcm.storage.DeviceSpec;
import org.whispersystems.textsecuregcm.storage.LinkDeviceTokenAlreadyUsedException;
import org.whispersystems.textsecuregcm.storage.PersistentTimer;
import org.whispersystems.textsecuregcm.tests.util.AccountsHelper;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.tests.util.KeysHelper;
Expand All @@ -104,6 +104,7 @@ class DeviceControllerTest {

private static final AccountsManager accountsManager = mock(AccountsManager.class);
private static final ClientPublicKeysManager clientPublicKeysManager = mock(ClientPublicKeysManager.class);
private static final PersistentTimer persistentTimer = mock(PersistentTimer.class);
private static final RateLimiters rateLimiters = mock(RateLimiters.class);
private static final RateLimiter rateLimiter = mock(RateLimiter.class);
@SuppressWarnings("unchecked")
Expand All @@ -123,6 +124,7 @@ class DeviceControllerTest {
accountsManager,
clientPublicKeysManager,
rateLimiters,
persistentTimer,
deviceConfiguration);

@RegisterExtension
Expand Down Expand Up @@ -161,6 +163,9 @@ void setup() {
when(clientPublicKeysManager.setPublicKey(any(), anyByte(), any()))
.thenReturn(CompletableFuture.completedFuture(null));

when(persistentTimer.start(anyString(), anyString()))
.thenReturn(CompletableFuture.completedFuture(mock(PersistentTimer.Sample.class)));

AccountsHelper.setupMockUpdate(accountsManager);
}

Expand Down
Loading

0 comments on commit 282bcf6

Please sign in to comment.