Skip to content

Commit

Permalink
CORE-18913 extract the group allocator logic in the mediator to its o…
Browse files Browse the repository at this point in the history
…wn service (#5257)

This allows for better unit testing of the allocator logic and cleaner code in the mediator.
  • Loading branch information
LWogan authored Dec 18, 2023
1 parent 63b69f7 commit 462ab74
Show file tree
Hide file tree
Showing 13 changed files with 221 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -520,6 +520,10 @@ class FlowMapperServiceIntegrationTest {
producer {
close.timeout = 6000
}
mediator {
poolSize = 1
minPoolRecordCount = 20
}
pollTimeout = 100
}
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ class TestFlowEventMediatorFactoryImpl @Activate constructor(
.threads(1)
.threadName("flow-event-mediator")
.stateManager(stateManagerFactory.create(stateManagerConfig))
.minGroupSize(20)
.build()

private fun createMessageRouterFactory() = MessageRouterFactory { clientFinder ->
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ import net.corda.schema.Schemas.Flow.FLOW_MAPPER_START
import net.corda.schema.Schemas.Flow.FLOW_SESSION
import net.corda.schema.Schemas.Flow.FLOW_START
import net.corda.schema.Schemas.P2P.P2P_OUT_TOPIC
import net.corda.schema.configuration.FlowConfig
import net.corda.schema.configuration.MessagingConfig.Subscription.PROCESSING_MIN_POOL_RECORD_COUNT
import net.corda.schema.configuration.MessagingConfig.Subscription.PROCESSING_THREAD_POOL_SIZE
import net.corda.session.mapper.service.executor.FlowMapperMessageProcessor
import org.osgi.service.component.annotations.Activate
import org.osgi.service.component.annotations.Component
Expand Down Expand Up @@ -51,15 +52,13 @@ class FlowMapperEventMediatorFactoryImpl @Activate constructor(
stateManager: StateManager,
) = eventMediatorFactory.create(
createEventMediatorConfig(
flowConfig,
messagingConfig,
FlowMapperMessageProcessor(flowMapperEventExecutorFactory, flowConfig),
stateManager,
)
)

private fun createEventMediatorConfig(
flowConfig: SmartConfig,
messagingConfig: SmartConfig,
messageProcessor: StateAndEventProcessor<String, FlowMapperState, FlowMapperEvent>,
stateManager: StateManager,
Expand All @@ -84,9 +83,10 @@ class FlowMapperEventMediatorFactoryImpl @Activate constructor(
)
.messageProcessor(messageProcessor)
.messageRouterFactory(createMessageRouterFactory())
.threads(flowConfig.getInt(FlowConfig.PROCESSING_THREAD_POOL_SIZE))
.threads(messagingConfig.getInt(PROCESSING_THREAD_POOL_SIZE))
.threadName("flow-mapper-event-mediator")
.stateManager(stateManager)
.minGroupSize(messagingConfig.getInt(PROCESSING_MIN_POOL_RECORD_COUNT))
.build()

private fun createMessageRouterFactory() = MessageRouterFactory { clientFinder ->
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import net.corda.messaging.api.mediator.config.EventMediatorConfig
import net.corda.messaging.api.mediator.factory.MediatorConsumerFactoryFactory
import net.corda.messaging.api.mediator.factory.MessagingClientFactoryFactory
import net.corda.messaging.api.mediator.factory.MultiSourceEventMediatorFactory
import net.corda.schema.configuration.FlowConfig
import net.corda.schema.configuration.MessagingConfig
import net.corda.session.mapper.messaging.mediator.FlowMapperEventMediatorFactory
import net.corda.session.mapper.messaging.mediator.FlowMapperEventMediatorFactoryImpl
import org.junit.jupiter.api.Assertions.assertNotNull
Expand All @@ -24,13 +24,13 @@ class FlowMapperEventMediatorFactoryImplTest {
private val mediatorConsumerFactoryFactory = mock<MediatorConsumerFactoryFactory>()
private val messagingClientFactoryFactory = mock<MessagingClientFactoryFactory>()
private val multiSourceEventMediatorFactory = mock<MultiSourceEventMediatorFactory>()
private val flowConfig = mock<SmartConfig>()
private val config = mock<SmartConfig>()

@BeforeEach
fun beforeEach() {
`when`(multiSourceEventMediatorFactory.create(any<EventMediatorConfig<String, FlowMapperState, FlowMapperEvent>>()))
.thenReturn(mock())
`when`(flowConfig.getInt(FlowConfig.PROCESSING_THREAD_POOL_SIZE)).thenReturn(10)
`when`(config.getInt(MessagingConfig.Subscription.PROCESSING_THREAD_POOL_SIZE)).thenReturn(10)

flowMapperEventMediatorFactory = FlowMapperEventMediatorFactoryImpl(
flowMapperEventExecutorFactory,
Expand All @@ -42,7 +42,7 @@ class FlowMapperEventMediatorFactoryImplTest {

@Test
fun `successfully creates event mediator`() {
val mediator = flowMapperEventMediatorFactory.create(flowConfig, mock(), mock())
val mediator = flowMapperEventMediatorFactory.create(mock(), config, mock())

assertNotNull(mediator)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ import net.corda.data.uniqueness.UniquenessCheckRequestAvro
import net.corda.flow.pipeline.factory.FlowEventProcessorFactory
import net.corda.ledger.utxo.verification.TransactionVerificationRequest
import net.corda.libs.configuration.SmartConfig
import net.corda.libs.configuration.helper.getConfig
import net.corda.libs.platform.PlatformInfoProvider
import net.corda.libs.statemanager.api.StateManager
import net.corda.messaging.api.constants.WorkerRPCPaths.CRYPTO_PATH
Expand Down Expand Up @@ -43,8 +42,8 @@ import net.corda.schema.configuration.BootConfig.PERSISTENCE_WORKER_REST_ENDPOIN
import net.corda.schema.configuration.BootConfig.TOKEN_SELECTION_WORKER_REST_ENDPOINT
import net.corda.schema.configuration.BootConfig.UNIQUENESS_WORKER_REST_ENDPOINT
import net.corda.schema.configuration.BootConfig.VERIFICATION_WORKER_REST_ENDPOINT
import net.corda.schema.configuration.ConfigKeys
import net.corda.schema.configuration.FlowConfig
import net.corda.schema.configuration.MessagingConfig.Subscription.PROCESSING_MIN_POOL_RECORD_COUNT
import net.corda.schema.configuration.MessagingConfig.Subscription.PROCESSING_THREAD_POOL_SIZE
import org.osgi.service.component.annotations.Activate
import org.osgi.service.component.annotations.Component
import org.osgi.service.component.annotations.Reference
Expand Down Expand Up @@ -79,15 +78,13 @@ class FlowEventMediatorFactoryImpl @Activate constructor(
stateManager: StateManager,
) = eventMediatorFactory.create(
createEventMediatorConfig(
configs,
messagingConfig,
flowEventProcessorFactory.create(configs),
stateManager,
)
)

private fun createEventMediatorConfig(
configs: Map<String, SmartConfig>,
messagingConfig: SmartConfig,
messageProcessor: StateAndEventProcessor<String, Checkpoint, FlowEvent>,
stateManager: StateManager,
Expand Down Expand Up @@ -115,9 +112,10 @@ class FlowEventMediatorFactoryImpl @Activate constructor(
)
.messageProcessor(messageProcessor)
.messageRouterFactory(createMessageRouterFactory(messagingConfig))
.threads(configs.getConfig(ConfigKeys.FLOW_CONFIG).getInt(FlowConfig.PROCESSING_THREAD_POOL_SIZE))
.threads(messagingConfig.getInt(PROCESSING_THREAD_POOL_SIZE))
.threadName("flow-event-mediator")
.stateManager(stateManager)
.minGroupSize(messagingConfig.getInt(PROCESSING_MIN_POOL_RECORD_COUNT))
.build()

private fun createMessageRouterFactory(messagingConfig: SmartConfig) = MessageRouterFactory { clientFinder ->
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ import net.corda.schema.Schemas.Flow.FLOW_EVENT_TOPIC
import net.corda.schema.Schemas.Flow.FLOW_MAPPER_SESSION_OUT
import net.corda.schema.Schemas.Flow.FLOW_STATUS_TOPIC
import net.corda.schema.configuration.ConfigKeys
import net.corda.schema.configuration.FlowConfig
import net.corda.schema.configuration.MessagingConfig
import org.assertj.core.api.Assertions.assertThat
import org.junit.jupiter.api.Assertions.assertNotNull
import org.junit.jupiter.api.BeforeEach
Expand All @@ -51,7 +51,7 @@ class FlowEventMediatorFactoryImplTest {
private val multiSourceEventMediatorFactory = mock<MultiSourceEventMediatorFactory>()
private val cordaAvroSerializationFactory = mock<CordaAvroSerializationFactory>()
private val platformInfoProvider = mock<PlatformInfoProvider>()
private val flowConfig = mock<SmartConfig>()
private val config = mock<SmartConfig>()

val captor = argumentCaptor<EventMediatorConfig<String, Checkpoint, FlowEvent>>()

Expand All @@ -63,7 +63,7 @@ class FlowEventMediatorFactoryImplTest {
`when`(multiSourceEventMediatorFactory.create(captor.capture()))
.thenReturn(mock())

`when`(flowConfig.getInt(FlowConfig.PROCESSING_THREAD_POOL_SIZE)).thenReturn(10)
`when`(config.getInt(MessagingConfig.Subscription.PROCESSING_THREAD_POOL_SIZE)).thenReturn(10)

flowEventMediatorFactory = FlowEventMediatorFactoryImpl(
flowEventProcessorFactory,
Expand All @@ -83,7 +83,7 @@ class FlowEventMediatorFactoryImplTest {
@Test
fun `successfully creates event mediator with expected routes`() {
val mediator = flowEventMediatorFactory.create(
mapOf(ConfigKeys.FLOW_CONFIG to flowConfig),
mapOf(ConfigKeys.MESSAGING_CONFIG to config),
mock(),
mock(),
)
Expand Down
2 changes: 1 addition & 1 deletion gradle.properties
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ commonsLangVersion = 3.12.0
commonsTextVersion = 1.10.0
# Corda API libs revision (change in 4th digit indicates a breaking change)
# Change to 5.2.0.xx-SNAPSHOT to pick up maven local published copy
cordaApiVersion=5.2.0.16-beta+
cordaApiVersion=5.2.0.17-beta+

disruptorVersion=3.4.4
felixConfigAdminVersion=1.9.26
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package net.corda.messaging.mediator

import net.corda.messaging.api.mediator.config.EventMediatorConfig
import net.corda.messaging.api.records.Record
import kotlin.math.ceil
import kotlin.math.min

/**
* Helper class to use in the mediator to divide polled records into groups for processing.
*/
class GroupAllocator {

/**
* Allocate events into groups based on their keys, a configured minimum group size and thread count.
* This allows for more efficient multi-threaded processing.
* The threshold record count to establish a new group is [config.minGroupSize].
* If the number of groups exceeds the number of threads then the group count is set to the number of [config.threads]
* Records of the same key are always placed into the same group regardless of group size and count.
* @param events Events to allocate to groups
* @param config Mediator config
* @return Records allocated to groups.
*/
fun <K : Any, S : Any, E : Any> allocateGroups(
events: List<Record<K, E>>,
config: EventMediatorConfig<K, S, E>
): List<Map<K, List<Record<K, E>>>> {
val groups = setUpGroups(config, events)
val buckets = events
.groupBy { it.key }.toList()
.sortedByDescending { it.second.size }

buckets.forEach { (key, records) ->
val leastFilledGroup = groups.minByOrNull { it.values.flatten().size }
leastFilledGroup?.put(key, records)
}

return groups.filter { it.values.isNotEmpty() }
}

private fun <E : Any, S: Any, K : Any> setUpGroups(
config: EventMediatorConfig<K, S, E>,
events: List<Record<K, E>>
): MutableList<MutableMap<K, List<Record<K, E>>>> {
val numGroups = min(
ceil(events.size.toDouble() / config.minGroupSize).toInt(),
config.threads
)

return MutableList(numGroups) { mutableMapOf() }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ class MultiSourceEventMediatorImpl<K : Any, S : Any, E : Any>(
private val taskManagerHelper = TaskManagerHelper(
taskManager, stateManagerHelper, metrics
)
private val groupAllocator = GroupAllocator()
private val uniqueId = UUID.randomUUID().toString()
private val lifecycleCoordinatorName = LifecycleCoordinatorName(
"MultiSourceEventMediator--${config.name}", uniqueId
Expand Down Expand Up @@ -164,7 +165,7 @@ class MultiSourceEventMediatorImpl<K : Any, S : Any, E : Any>(
val messages = consumer.poll(pollTimeout)
val startTimestamp = System.nanoTime()
if (messages.isNotEmpty()) {
var groups = allocateGroups(messages.map { it.toRecord() })
var groups = groupAllocator.allocateGroups(messages.map { it.toRecord() }, config)
var states = stateManager.get(messages.map { it.key.toString() }.distinct())

while (groups.isNotEmpty()) {
Expand Down Expand Up @@ -230,7 +231,7 @@ class MultiSourceEventMediatorImpl<K : Any, S : Any, E : Any>(
states = failedToCreate + failedToDelete + failedToUpdateOptimisticLockFailure

groups = if (states.isNotEmpty()) {
allocateGroups(flowEvents.filterKeys { states.containsKey(it) }.values.flatten())
groupAllocator.allocateGroups(flowEvents.filterKeys { states.containsKey(it) }.values.flatten(), config)
} else {
listOf()
}
Expand Down Expand Up @@ -319,23 +320,4 @@ class MultiSourceEventMediatorImpl<K : Any, S : Any, E : Any>(
}
}
}

private fun allocateGroups(events: List<Record<K, E>>): List<Map<K, List<Record<K, E>>>> {
val groups = mutableListOf<MutableMap<K, List<Record<K, E>>>>()
val groupCountBasedOnEvents = (events.size / 20).coerceAtLeast(1)
val groupsCount = if (groupCountBasedOnEvents < config.threads) groupCountBasedOnEvents else config.threads
for (i in 0 until groupsCount) {
groups.add(mutableMapOf())
}
val buckets = events.groupBy { it.key }
val bucketSizes = buckets.keys.sortedByDescending { buckets[it]?.size }
for (i in buckets.size - 1 downTo 0 step 1) {
val group = groups.minBy { it.values.flatten().size }
val key = bucketSizes[i]
val records = buckets[key]!!
group[key] = records
}

return groups
}
}
Loading

0 comments on commit 462ab74

Please sign in to comment.