Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: parallelize cli with multithreaded task manager #1580

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 15 additions & 23 deletions datashare-app/src/main/java/org/icij/datashare/CliApp.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.nio.file.Path;
import java.util.List;
import java.util.Properties;

Expand Down Expand Up @@ -110,47 +111,38 @@ private static void runTaskWorker(CommonMode mode, Properties properties) throws
PipelineHelper pipeline = new PipelineHelper(new PropertiesProvider(properties));
logger.info("executing {}", pipeline);
if (pipeline.has(Stage.DEDUPLICATE)) {
Long result = taskFactory.createDeduplicateTask(
new Task<>(DeduplicateTask.class.getName(), nullUser(), propertiesToMap(properties)),
(percentage) -> {logger.info("percentage: {}% done", percentage);return null;}).call();
logger.info("removed {} duplicates", result);
taskManager.startTask(
new Task<>(DeduplicateTask.class.getName(), nullUser(), propertiesToMap(properties)));
}

if (pipeline.has(Stage.SCANIDX)) {
Long result = taskFactory.createScanIndexTask(
new Task<>(ScanIndexTask.class.getName(), nullUser(), propertiesToMap(properties)),
(percentage) -> {logger.info("percentage: {}% done", percentage);return null;}).call();
logger.info("scanned {}", result);
taskManager.startTask(
new Task<>(ScanIndexTask.class.getName(), nullUser(), propertiesToMap(properties)));
}

if (pipeline.has(Stage.SCAN)) {
taskFactory.createScanTask(
new Task<>(ScanTask.class.getName(), nullUser(), propertiesToMap(properties)),
(percentage) -> {logger.info("percentage: {}% done", percentage); return null;}).call();
taskManager.startTask(
new Task<>(ScanTask.class.getName(), nullUser(), propertiesToMap(properties)));
}

if (pipeline.has(Stage.INDEX)) {
taskFactory.createIndexTask(
new Task<>(IndexTask.class.getName(), nullUser(), propertiesToMap(properties)),
(percentage) -> {logger.info("percentage: {}% done", percentage); return null;}).call();
taskManager.startTask(
new Task<>(IndexTask.class.getName(), nullUser(), propertiesToMap(properties)));
}

if (pipeline.has(Stage.ENQUEUEIDX)) {
taskFactory.createEnqueueFromIndexTask(
new Task<>(EnqueueFromIndexTask.class.getName(), nullUser(), propertiesToMap(properties)),
(percentage) -> {logger.info("percentage: {}% done", percentage); return null;}).call();
taskManager.startTask(
new Task<>(EnqueueFromIndexTask.class.getName(), nullUser(), propertiesToMap(properties)));
}

if (pipeline.has(Stage.NLP)) {
taskFactory.createExtractNlpTask(
new Task<>(ExtractNlpTask.class.getName(), nullUser(), propertiesToMap(properties)),
(percentage) -> {logger.info("percentage: {}% done", percentage); return null;}).call();
taskManager.startTask(
new Task<>(ExtractNlpTask.class.getName(), nullUser(), propertiesToMap(properties)));
}

if (pipeline.has(Stage.ARTIFACT)) {
taskFactory.createArtifactTask(
new Task<>(ArtifactTask.class.getName(), nullUser(), propertiesToMap(properties)),
(percentage) -> {logger.info("percentage: {}% done", percentage); return null;}).call();
taskManager.startTask(
new Task<>(ArtifactTask.class.getName(), nullUser(), propertiesToMap(properties)));
}
taskManager.shutdownAndAwaitTermination(Integer.MAX_VALUE, SECONDS);
indexer.close();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,23 @@

import com.google.inject.Inject;
import com.google.inject.Singleton;

import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CountDownLatch;

import org.icij.datashare.PropertiesProvider;
import org.icij.datashare.asynctasks.Task;


/**
 * Singleton, injectable variant of the in-memory task manager. It declares no state or
 * behaviour of its own: both constructors only forward their arguments to
 * {@link org.icij.datashare.asynctasks.TaskManagerMemory}.
 *
 * NOTE(review): this text is a diff rendering, so each constructor shows two signature lines
 * and two body lines. Presumably the first of each pair is the removed version and the second
 * (the one taking a PropertiesProvider) is the added version; confirm against the real file.
 */
@Singleton
public class TaskManagerMemory extends org.icij.datashare.asynctasks.TaskManagerMemory {

// Injection entry point: delegates to the package-private constructor below with a fresh
// CountDownLatch(1) as the latch.
@Inject
public TaskManagerMemory(BlockingQueue<Task<?>> taskQueue, DatashareTaskFactory taskFactory) {
this(taskQueue, taskFactory, new CountDownLatch(1));
public TaskManagerMemory(BlockingQueue<Task<?>> taskQueue, DatashareTaskFactory taskFactory, PropertiesProvider propertiesProvider) {
this(taskQueue, taskFactory, propertiesProvider, new CountDownLatch(1));
}

// Package-private overload that lets same-package callers supply their own latch;
// everything is passed straight through to the superclass constructor.
TaskManagerMemory(BlockingQueue<Task<?>> taskQueue, DatashareTaskFactory taskFactory, CountDownLatch latch) {
super(taskQueue, taskFactory, latch);
TaskManagerMemory(BlockingQueue<Task<?>> taskQueue, DatashareTaskFactory taskFactory, PropertiesProvider propertiesProvider, CountDownLatch latch) {
super(taskQueue, taskFactory, propertiesProvider, latch);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ public void test_main_loop_exit_with_sigterm_when_running_batch() throws Excepti
@Before
public void setUp() throws IOException {
initMocks(this);
taskManager = new TaskManagerMemory(batchSearchQueue, factory, startLoop);
taskManager = new TaskManagerMemory(batchSearchQueue, factory, new PropertiesProvider(), startLoop);
mockSearch = new MockSearch<>(indexer, Indexer.QueryBuilderSearcher.class);

Task<Object> taskView = new Task<>(testBatchSearch.uuid, BatchSearchRunner.class.getName(), local());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ public class TaskWorkerLoopForPipelineTasksTest {
@Mock Function<Double, Void> updateCallback;
@Mock ElasticsearchSpewer spewer;
private final BlockingQueue<Task<?>> taskQueue = new LinkedBlockingQueue<>();
private final TaskManagerMemory taskSupplier = new TaskManagerMemory(taskQueue, taskFactory);
private final TaskManagerMemory taskSupplier = new TaskManagerMemory(taskQueue, taskFactory, new PropertiesProvider());

@Test
public void test_scan_task() throws Exception {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -491,7 +491,7 @@ private void testTripleQuote(Boolean phraseMatch, String query, String tripleQuo
@Before
public void setUp() {
initMocks(this);
taskManager = new TaskManagerMemory(new ArrayBlockingQueue<>(5), factory);
taskManager = new TaskManagerMemory(new ArrayBlockingQueue<>(5), factory, new PropertiesProvider());
when(factory.createBatchSearchRunner(any(), any())).thenReturn(mock(BatchSearchRunner.class));
configure(routes -> routes.add(new BatchSearchResource(new PropertiesProvider(), taskManager, batchSearchRepository)).
filter(new LocalUserFilter(new PropertiesProvider(), jooqRepository)));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ public class TaskResourceTest extends AbstractProdWebServerTest {
JooqRepository jooqRepository;
private static final DatashareTaskFactoryForTest taskFactory = mock(DatashareTaskFactoryForTest.class);
private static final BlockingQueue<Task<?>> taskQueue = new ArrayBlockingQueue<>(3);
private static final TaskManagerMemory taskManager = new TaskManagerMemory(taskQueue, taskFactory);
private static final TaskManagerMemory taskManager = new TaskManagerMemory(taskQueue, taskFactory, new PropertiesProvider());

@Before
public void setUp() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ private void setupAppWith(DatashareTaskFactory taskFactory, String... userLogins
final PropertiesProvider propertiesProvider = new PropertiesProvider(new HashMap<>() {{
put("mode", "LOCAL");
}});
taskManager = new TaskManagerMemory(new ArrayBlockingQueue<>(3), taskFactory);
taskManager = new TaskManagerMemory(new ArrayBlockingQueue<>(3), taskFactory, new PropertiesProvider());
configure(new CommonMode(propertiesProvider.getProperties()) {
@Override
protected void configure() {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package org.icij.datashare.asynctasks;

import org.apache.commons.lang3.NotImplementedException;
import org.icij.datashare.PropertiesProvider;
import org.icij.datashare.asynctasks.bus.amqp.TaskError;
import org.icij.datashare.asynctasks.bus.amqp.TaskEvent;
import org.icij.datashare.user.User;
Expand All @@ -17,32 +18,41 @@
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Consumer;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;

import static java.lang.Integer.parseInt;
import static java.util.concurrent.Executors.newSingleThreadExecutor;
import static java.util.stream.Collectors.toList;
import static org.icij.datashare.asynctasks.Task.State.RUNNING;


public class TaskManagerMemory implements TaskManager, TaskSupplier {
private final Logger logger = LoggerFactory.getLogger(getClass());
private final ExecutorService executor = newSingleThreadExecutor();
private final ExecutorService executor;
private final ConcurrentMap<String, Task<?>> tasks = new ConcurrentHashMap<>();
private final BlockingQueue<Task<?>> taskQueue;
private final TaskWorkerLoop loop;
private final List<TaskWorkerLoop> loops;
private final AtomicInteger executedTasks = new AtomicInteger(0);

/**
 * Convenience constructor: delegates to the full constructor with a fresh
 * {@code CountDownLatch(1)}.
 *
 * NOTE(review): two delegating calls are shown because this is a diff rendering; presumably
 * the first is the removed line and the second the added one. The added call passes a
 * default-constructed PropertiesProvider, so whether a "parallelism" value is found depends
 * on what PropertiesProvider loads by default, which is not visible here. Confirm.
 */
public TaskManagerMemory(BlockingQueue<Task<?>> taskQueue, TaskFactory taskFactory) {
this(taskQueue, taskFactory, new CountDownLatch(1));
this(taskQueue, taskFactory, new PropertiesProvider(), new CountDownLatch(1));
}

/**
 * Main constructor: stores the task queue, then creates the worker loop(s) and submits them
 * to the executor before returning.
 *
 * NOTE(review): this is a diff rendering. The first signature line and the single
 * {@code loop = ...; executor.submit(loop);} pair are presumably the removed version; the
 * second signature and the parallelism-based lines are presumably the added version.
 * The notes below refer to the added lines.
 */
public TaskManagerMemory(BlockingQueue<Task<?>> taskQueue, TaskFactory taskFactory, CountDownLatch latch) {
public TaskManagerMemory(BlockingQueue<Task<?>> taskQueue, TaskFactory taskFactory, PropertiesProvider propertiesProvider, CountDownLatch latch) {
this.taskQueue = taskQueue;
loop = new TaskWorkerLoop(taskFactory, this, latch);
executor.submit(loop);
// Worker count comes from the "parallelism" property and falls back to "1" when absent.
// NOTE(review): the value is not validated. Integer.parseInt throws NumberFormatException on
// non-numeric text, and Executors.newFixedThreadPool throws IllegalArgumentException when the
// thread count is <= 0. Consider validating or clamping to >= 1 with a clear error message.
int parallelism = parseInt(propertiesProvider.get("parallelism").orElse("1"));
logger.info("running TaskManager with {} threads", parallelism);
executor = Executors.newFixedThreadPool(parallelism);
// One TaskWorkerLoop per pool thread; every loop receives this manager and the same latch.
// NOTE(review): `this` is handed to the loops and the loops are submitted while the
// constructor is still running, so a loop may start before construction (including any
// subclass constructor) has finished. Confirm the loops only touch fields assigned above.
// NOTE(review): all loops share one latch. How TaskWorkerLoop uses it is not visible here;
// with a count of 1, presumably it signals that the first loop has started, not all of them.
loops = IntStream.range(0, parallelism).mapToObj(i -> new TaskWorkerLoop(taskFactory, this, latch)).collect(Collectors.toList());
loops.forEach(executor::submit);
}

public <V> Task<V> getTask(final String taskId) {
Expand Down Expand Up @@ -148,7 +158,7 @@ public boolean stopTask(String taskId) {
canceled(taskView, false);
return removed;
case RUNNING:
loop.cancel(taskId, false);
loops.forEach(l -> l.cancel(taskId, false));
return true;
}
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;

import org.icij.datashare.PropertiesProvider;
import org.icij.datashare.test.LogbackCapturingRule;
import org.icij.datashare.user.User;
import org.junit.After;
Expand All @@ -28,7 +30,7 @@ public class TaskManagerMemoryTest {
// Test fixture: builds a task manager over a fresh LinkedBlockingQueue, wraps it in a
// TaskInspector, then blocks on waitForLoop until that latch is released.
// NOTE(review): presumably the worker loop counts the latch down once it has started, so
// tests begin with a running loop; confirm in TaskWorkerLoop.
// NOTE(review): two construction lines are shown because this is a diff rendering;
// presumably the first is the removed call and the second, which adds a default
// PropertiesProvider, is its replacement.
@Before
public void setUp() throws Exception {
LinkedBlockingQueue<Task<?>> taskViews = new LinkedBlockingQueue<>();
taskManager = new TaskManagerMemory(taskViews, factory, waitForLoop);
taskManager = new TaskManagerMemory(taskViews, factory, new PropertiesProvider(), waitForLoop);
taskInspector = new TaskInspector(taskManager);
waitForLoop.await();
}
Expand Down