Skip to content

Commit

Permalink
Add ThreadContextPermission for markAsSystemContext and allow core to…
Browse files Browse the repository at this point in the history
… perform the method (opensearch-project#15016)

* Add RuntimePermission for markAsSystemContext and allow core to perform the method

Signed-off-by: Craig Perkins <[email protected]>

* private

Signed-off-by: Craig Perkins <[email protected]>

* Surround with doPrivileged

Signed-off-by: Craig Perkins <[email protected]>

* Create ThreadContextAccess

Signed-off-by: Craig Perkins <[email protected]>

* Create notion of ThreadContextPermission

Signed-off-by: Craig Perkins <[email protected]>

* Add to CHANGELOG

Signed-off-by: Craig Perkins <[email protected]>

* Add javadoc

Signed-off-by: Craig Perkins <[email protected]>

* Add to test-framework.policy file

Signed-off-by: Craig Perkins <[email protected]>

* Mark as internal

Signed-off-by: Craig Perkins <[email protected]>

---------

Signed-off-by: Craig Perkins <[email protected]>
  • Loading branch information
cwperks committed Jul 31, 2024
1 parent 5c19809 commit 597747d
Show file tree
Hide file tree
Showing 19 changed files with 128 additions and 16 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- [Workload Management] Add queryGroupId to Task ([14708](https://github.com/opensearch-project/OpenSearch/pull/14708))
- Add setting to ignore throttling nodes for allocation of unassigned primaries in remote restore ([#14991](https://github.com/opensearch-project/OpenSearch/pull/14991))
- Add basic aggregation support for derived fields ([#14618](https://github.com/opensearch-project/OpenSearch/pull/14618))
- Add ThreadContextPermission for markAsSystemContext and allow core to perform the method ([#15016](https://github.com/opensearch-project/OpenSearch/pull/15016))

### Dependencies
- Bump `org.apache.commons:commons-lang3` from 3.14.0 to 3.15.0 ([#14861](https://github.com/opensearch-project/OpenSearch/pull/14861))
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.secure_sm;

import java.security.BasicPermission;

/**
* Permission to utilize methods in the ThreadContext class that are normally not accessible
*
* @see ThreadGroup
* @see SecureSM
*/
public final class ThreadContextPermission extends BasicPermission {

/**
* Creates a new ThreadContextPermission object.
*
* @param name target name
*/
public ThreadContextPermission(String name) {
super(name);
}

/**
* Creates a new ThreadContextPermission object.
* This constructor exists for use by the {@code Policy} object to instantiate new Permission objects.
*
* @param name target name
* @param actions ignored
*/
public ThreadContextPermission(String name, String actions) {
super(name, actions);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
import org.opensearch.common.util.concurrent.OpenSearchExecutors;
import org.opensearch.common.util.concurrent.PrioritizedOpenSearchThreadPoolExecutor;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.common.util.concurrent.ThreadContextAccess;
import org.opensearch.core.concurrency.OpenSearchRejectedExecutionException;
import org.opensearch.telemetry.metrics.noop.NoopMetricsRegistry;
import org.opensearch.telemetry.metrics.tags.Tags;
Expand Down Expand Up @@ -396,7 +397,7 @@ private void submitStateUpdateTask(
final ThreadContext threadContext = threadPool.getThreadContext();
final Supplier<ThreadContext.StoredContext> supplier = threadContext.newRestorableContext(true);
try (ThreadContext.StoredContext ignore = threadContext.stashContext()) {
threadContext.markAsSystemContext();
ThreadContextAccess.doPrivilegedVoid(threadContext::markAsSystemContext);
final UpdateTask updateTask = new UpdateTask(
config.priority(),
source,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
import org.opensearch.common.util.concurrent.OpenSearchExecutors;
import org.opensearch.common.util.concurrent.PrioritizedOpenSearchThreadPoolExecutor;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.common.util.concurrent.ThreadContextAccess;
import org.opensearch.core.Assertions;
import org.opensearch.core.common.text.Text;
import org.opensearch.core.concurrency.OpenSearchRejectedExecutionException;
Expand Down Expand Up @@ -1022,7 +1023,7 @@ public <T> void submitStateUpdateTasks(
final ThreadContext threadContext = threadPool.getThreadContext();
final Supplier<ThreadContext.StoredContext> supplier = threadContext.newRestorableContext(true);
try (ThreadContext.StoredContext ignore = threadContext.stashContext()) {
threadContext.markAsSystemContext();
ThreadContextAccess.doPrivilegedVoid(threadContext::markAsSystemContext);

List<Batcher.UpdateTask> safeTasks = tasks.entrySet()
.stream()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,13 @@
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.common.io.stream.Writeable;
import org.opensearch.http.HttpTransportSettings;
import org.opensearch.secure_sm.ThreadContextPermission;
import org.opensearch.tasks.Task;
import org.opensearch.tasks.TaskThreadContextStatePropagator;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.security.Permission;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
Expand Down Expand Up @@ -111,6 +113,10 @@ public final class ThreadContext implements Writeable {
*/
public static final String ACTION_ORIGIN_TRANSIENT_NAME = "action.origin";

// thread context permissions

private static final Permission ACCESS_SYSTEM_THREAD_CONTEXT_PERMISSION = new ThreadContextPermission("markAsSystemContext");

private static final Logger logger = LogManager.getLogger(ThreadContext.class);
private static final ThreadContextStruct DEFAULT_CONTEXT = new ThreadContextStruct();
private final Map<String, String> defaultHeader;
Expand Down Expand Up @@ -554,8 +560,19 @@ boolean isDefaultContext() {
/**
* Marks this thread context as an internal system context. This signals that actions in this context are issued
* by the system itself rather than by a user action.
*
* Usage of markAsSystemContext is guarded by a ThreadContextPermission. In order to use
* markAsSystemContext, the codebase needs to explicitly be granted permission in the JSM policy file.
*
* Add an entry in the grant portion of the policy file like this:
*
* permission org.opensearch.secure_sm.ThreadContextPermission "markAsSystemContext";
*/
public void markAsSystemContext() {
SecurityManager sm = System.getSecurityManager();
if (sm != null) {
sm.checkPermission(ACCESS_SYSTEM_THREAD_CONTEXT_PERMISSION);
}
threadLocal.set(threadLocal.get().setSystemContext(propagators));
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.common.util.concurrent;

import org.opensearch.SpecialPermission;
import org.opensearch.common.annotation.InternalApi;

import java.security.AccessController;
import java.security.PrivilegedAction;

/**
* This class wraps the {@link ThreadContext} operations requiring access in
* {@link AccessController#doPrivileged(PrivilegedAction)} blocks.
*
* @opensearch.internal
*/
@SuppressWarnings("removal")
@InternalApi
public final class ThreadContextAccess {

private ThreadContextAccess() {}

public static <T> T doPrivileged(PrivilegedAction<T> operation) {
SpecialPermission.check();
return AccessController.doPrivileged(operation);
}

public static void doPrivilegedVoid(Runnable action) {
SpecialPermission.check();
AccessController.doPrivileged((PrivilegedAction<Void>) () -> {
action.run();
return null;
});
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
import org.opensearch.common.inject.Inject;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.common.util.concurrent.ThreadContextAccess;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.index.shard.ShardId;
Expand Down Expand Up @@ -98,7 +99,7 @@ public GlobalCheckpointSyncAction(
public void updateGlobalCheckpointForShard(final ShardId shardId) {
final ThreadContext threadContext = threadPool.getThreadContext();
try (ThreadContext.StoredContext ignore = threadContext.stashContext()) {
threadContext.markAsSystemContext();
ThreadContextAccess.doPrivilegedVoid(threadContext::markAsSystemContext);
execute(new Request(shardId), ActionListener.wrap(r -> {}, e -> {
if (ExceptionsHelper.unwrap(e, AlreadyClosedException.class, IndexShardClosedException.class) == null) {
logger.info(new ParameterizedMessage("{} global checkpoint sync failed", shardId), e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
import org.opensearch.common.inject.Inject;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.common.util.concurrent.ThreadContextAccess;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
Expand Down Expand Up @@ -122,7 +123,7 @@ final void backgroundSync(ShardId shardId, String primaryAllocationId, long prim
final ThreadContext threadContext = threadPool.getThreadContext();
try (ThreadContext.StoredContext ignore = threadContext.stashContext()) {
// we have to execute under the system context so that if security is enabled the sync is authorized
threadContext.markAsSystemContext();
ThreadContextAccess.doPrivilegedVoid(threadContext::markAsSystemContext);
final Request request = new Request(shardId, retentionLeases);
final ReplicationTask task = (ReplicationTask) taskManager.register("transport", "retention_lease_background_sync", request);
transportService.sendChildRequest(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
import org.opensearch.common.inject.Inject;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.common.util.concurrent.ThreadContextAccess;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
Expand Down Expand Up @@ -137,7 +138,7 @@ final void sync(
final ThreadContext threadContext = threadPool.getThreadContext();
try (ThreadContext.StoredContext ignore = threadContext.stashContext()) {
// we have to execute under the system context so that if security is enabled the sync is authorized
threadContext.markAsSystemContext();
ThreadContextAccess.doPrivilegedVoid(threadContext::markAsSystemContext);
final Request request = new Request(shardId, retentionLeases);
final ReplicationTask task = (ReplicationTask) taskManager.register("transport", "retention_lease_sync", request);
transportService.sendChildRequest(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.opensearch.common.inject.Inject;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.common.util.concurrent.ThreadContextAccess;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.index.IndexNotFoundException;
Expand Down Expand Up @@ -113,7 +114,7 @@ final void publish(IndexShard indexShard, ReplicationCheckpoint checkpoint) {
final ThreadContext threadContext = threadPool.getThreadContext();
try (ThreadContext.StoredContext ignore = threadContext.stashContext()) {
// we have to execute under the system context so that if security is enabled the sync is authorized
threadContext.markAsSystemContext();
ThreadContextAccess.doPrivilegedVoid(threadContext::markAsSystemContext);
PublishCheckpointRequest request = new PublishCheckpointRequest(checkpoint);
final ReplicationTask task = (ReplicationTask) taskManager.register("transport", "segrep_publish_checkpoint", request);
final ReplicationTimer timer = new ReplicationTimer();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import org.opensearch.common.settings.Settings;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.common.util.concurrent.ThreadContextAccess;
import org.opensearch.common.util.io.IOUtils;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.io.stream.StreamInput;
Expand Down Expand Up @@ -136,7 +137,7 @@ void collectNodes(ActionListener<Function<String, DiscoveryNode>> listener) {
new ContextPreservingActionListener<>(threadContext.newRestorableContext(false), listener);
try (ThreadContext.StoredContext ignore = threadContext.stashContext()) {
// we stash any context here since this is an internal execution and should not leak any existing context information
threadContext.markAsSystemContext();
ThreadContextAccess.doPrivilegedVoid(threadContext::markAsSystemContext);

final ClusterStateRequest request = new ClusterStateRequest();
request.clear();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
import org.opensearch.common.settings.Setting;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.common.util.concurrent.ThreadContextAccess;
import org.opensearch.common.util.io.IOUtils;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.Strings;
Expand Down Expand Up @@ -349,7 +350,7 @@ private void collectRemoteNodes(Iterator<Supplier<DiscoveryNode>> seedNodes, Act
try (ThreadContext.StoredContext ignore = threadContext.stashContext()) {
// we stash any context here since this is an internal execution and should not leak any
// existing context information.
threadContext.markAsSystemContext();
ThreadContextAccess.doPrivilegedVoid(threadContext::markAsSystemContext);
transportService.sendRequest(
connection,
ClusterStateAction.NAME,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ grant codeBase "${codebase.opensearch}" {
permission java.lang.RuntimePermission "setContextClassLoader";
// needed for SPI class loading
permission java.lang.RuntimePermission "accessDeclaredMembers";
permission org.opensearch.secure_sm.ThreadContextPermission "markAsSystemContext";
};

//// Very special jar permissions:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -157,4 +157,5 @@ grant {
permission java.lang.RuntimePermission "reflectionFactoryAccess";
permission java.lang.RuntimePermission "accessClassInPackage.sun.reflect";
permission java.lang.reflect.ReflectPermission "suppressAccessChecks";
permission org.opensearch.secure_sm.ThreadContextPermission "markAsSystemContext";
};
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.collect.Tuple;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.common.util.concurrent.ThreadContextAccess;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.bytes.BytesArray;
import org.opensearch.core.common.bytes.BytesReference;
Expand Down Expand Up @@ -225,7 +226,7 @@ public void testUpdateTemplates() {
service.upgradesInProgress.set(additionsCount + deletionsCount + 2); // +2 to skip tryFinishUpgrade
final ThreadContext threadContext = threadPool.getThreadContext();
try (ThreadContext.StoredContext ignore = threadContext.stashContext()) {
threadContext.markAsSystemContext();
ThreadContextAccess.doPrivilegedVoid(threadContext::markAsSystemContext);
service.upgradeTemplates(additions, deletions);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -565,7 +565,7 @@ public void testPreservesThreadsOriginalContextOnRunException() throws IOExcepti
threadContext.putHeader("foo", "bar");
boolean systemContext = randomBoolean();
if (systemContext) {
threadContext.markAsSystemContext();
ThreadContextAccess.doPrivilegedVoid(threadContext::markAsSystemContext);
}
threadContext.putTransient("foo", "bar_transient");
withContext = threadContext.preserveContext(new AbstractRunnable() {
Expand Down Expand Up @@ -736,7 +736,7 @@ public void testMarkAsSystemContext() throws IOException {
assertFalse(threadContext.isSystemContext());
try (ThreadContext.StoredContext context = threadContext.stashContext()) {
assertFalse(threadContext.isSystemContext());
threadContext.markAsSystemContext();
ThreadContextAccess.doPrivilegedVoid(threadContext::markAsSystemContext);
assertTrue(threadContext.isSystemContext());
}
assertFalse(threadContext.isSystemContext());
Expand All @@ -761,7 +761,7 @@ public void testSystemContextWithPropagator() {
assertEquals(Integer.valueOf(1), threadContext.getTransient("test_transient_propagation_key"));
assertEquals("bar", threadContext.getHeader("foo"));
try (ThreadContext.StoredContext ctx = threadContext.stashContext()) {
threadContext.markAsSystemContext();
ThreadContextAccess.doPrivilegedVoid(threadContext::markAsSystemContext);
assertNull(threadContext.getHeader("foo"));
assertNull(threadContext.getTransient("test_transient_propagation_key"));
assertEquals("1", threadContext.getHeader("default"));
Expand Down Expand Up @@ -793,7 +793,7 @@ public void testSerializeSystemContext() throws IOException {
threadContext.writeTo(out);
try (ThreadContext.StoredContext ctx = threadContext.stashContext()) {
assertEquals("test", threadContext.getTransient("test_transient_propagation_key"));
threadContext.markAsSystemContext();
ThreadContextAccess.doPrivilegedVoid(threadContext::markAsSystemContext);
threadContext.writeTo(outFromSystemContext);
assertNull(threadContext.getHeader("foo"));
assertNull(threadContext.getTransient("test_transient_propagation_key"));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.common.util.concurrent.ThreadContext.StoredContext;
import org.opensearch.common.util.concurrent.ThreadContextAccess;
import org.opensearch.telemetry.Telemetry;
import org.opensearch.telemetry.TelemetrySettings;
import org.opensearch.telemetry.metrics.MetricsTelemetry;
Expand Down Expand Up @@ -260,7 +261,7 @@ public void testSpanNotPropagatedToChildSystemThreadContext() {
try (StoredContext ignored = threadContext.stashContext()) {
assertThat(threadContext.getTransient(ThreadContextBasedTracerContextStorage.CURRENT_SPAN), is(not(nullValue())));
assertThat(threadContextStorage.get(ThreadContextBasedTracerContextStorage.CURRENT_SPAN), is(span));
threadContext.markAsSystemContext();
ThreadContextAccess.doPrivilegedVoid(threadContext::markAsSystemContext);
assertThat(threadContext.getTransient(ThreadContextBasedTracerContextStorage.CURRENT_SPAN), is(nullValue()));
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
*/

grant {
// allow to test Security policy and codebases
// allow to test Security policy and codebases
permission java.util.PropertyPermission "*", "read,write";
permission java.security.SecurityPermission "createPolicy.JavaPolicy";
};
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
import org.opensearch.common.util.concurrent.OpenSearchExecutors;
import org.opensearch.common.util.concurrent.PrioritizedOpenSearchThreadPoolExecutor;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.common.util.concurrent.ThreadContextAccess;
import org.opensearch.core.action.ActionListener;
import org.opensearch.node.Node;
import org.opensearch.telemetry.metrics.noop.NoopMetricsRegistry;
Expand Down Expand Up @@ -134,7 +135,7 @@ public void run() {
scheduledNextTask = false;
final ThreadContext threadContext = threadPool.getThreadContext();
try (ThreadContext.StoredContext ignored = threadContext.stashContext()) {
threadContext.markAsSystemContext();
ThreadContextAccess.doPrivilegedVoid(threadContext::markAsSystemContext);
task.run();
}
if (waitForPublish == false) {
Expand Down

0 comments on commit 597747d

Please sign in to comment.