Skip to content

Commit

Permalink
Fix mTLS authorization bug (#1455)
Browse files Browse the repository at this point in the history
  • Loading branch information
Technoboy- authored Sep 11, 2024
1 parent d2c8afc commit 13db85b
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 72 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -87,16 +87,16 @@ private Map<String, AuthenticationProvider> getAuthenticationProviders(List<Stri

public AuthenticationResult authenticate(boolean fromProxy,
SSLSession session, MqttConnectMessage connectMessage) {
if (fromProxy) {
return new AuthenticationResult(true, null, null);
}
String authMethod = MqttMessageUtils.getAuthMethod(connectMessage);
if (authMethod != null) {
byte[] authData = MqttMessageUtils.getAuthData(connectMessage);
if (authData == null) {
return AuthenticationResult.FAILED;
}
if (fromProxy && AUTH_MTLS.equalsIgnoreCase(authMethod)) {
return new AuthenticationResult(true, new String(authData),
new AuthenticationDataCommand(new String(authData), null, session));
}

return authenticate(connectMessage.payload().clientIdentifier(), authMethod,
new AuthenticationDataCommand(new String(authData), null, session));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
*/
package io.streamnative.pulsar.handlers.mqtt.proxy;

import static io.streamnative.pulsar.handlers.mqtt.Constants.AUTH_MTLS;
import static io.streamnative.pulsar.handlers.mqtt.utils.MqttMessageUtils.createMqttConnectMessage;
import static io.streamnative.pulsar.handlers.mqtt.utils.MqttMessageUtils.createMqttPublishMessage;
import static io.streamnative.pulsar.handlers.mqtt.utils.MqttMessageUtils.createMqttSubscribeMessage;
Expand Down Expand Up @@ -141,11 +140,9 @@ public void doProcessConnect(MqttAdapterMessage adapter, String userRole,
.processor(this)
.build();
connection.sendConnAck();
if (proxyConfig.isMqttProxyMTlsAuthenticationEnabled()) {
MqttConnectMessage connectMessage = createMqttConnectMessage(msg, AUTH_MTLS, userRole);
msg = connectMessage;
connection.setConnectMessage(msg);
}
MqttConnectMessage connectMessage = createMqttConnectMessage(msg, userRole);
msg = connectMessage;
connection.setConnectMessage(msg);

ConnectEvent connectEvent = ConnectEvent.builder()
.clientId(connection.getClientId())
Expand All @@ -166,10 +163,8 @@ public void processPublish(MqttAdapterMessage adapter) {
proxyConfig.getDefaultTenant(), proxyConfig.getDefaultNamespace(),
TopicDomain.getEnum(proxyConfig.getDefaultTopicDomain()));
adapter.setClientId(connection.getClientId());
if (proxyConfig.isMqttProxyMTlsAuthenticationEnabled()) {
MqttPublishMessage mqttMessage = createMqttPublishMessage(msg, AUTH_MTLS, connection.getUserRole());
adapter.setMqttMessage(mqttMessage);
}
MqttPublishMessage mqttMessage = createMqttPublishMessage(msg, connection.getUserRole());
adapter.setMqttMessage(mqttMessage);
startPublish()
.thenCompose(__ -> writeToBroker(pulsarTopicName, adapter))
.whenComplete((unused, ex) -> {
Expand Down Expand Up @@ -300,10 +295,8 @@ public void processSubscribe(final MqttAdapterMessage adapter) {
log.debug("[Proxy Subscribe] [{}] msg: {}", clientId, msg);
}
registerTopicListener(adapter);
if (proxyConfig.isMqttProxyMTlsAuthenticationEnabled()) {
MqttSubscribeMessage mqttMessage = createMqttSubscribeMessage(msg, AUTH_MTLS, connection.getUserRole());
adapter.setMqttMessage(mqttMessage);
}
MqttSubscribeMessage mqttMessage = createMqttSubscribeMessage(msg, connection.getUserRole());
adapter.setMqttMessage(mqttMessage);
doSubscribe(adapter, false)
.exceptionally(ex -> {
Throwable realCause = FutureUtil.unwrapCompletionException(ex);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
package io.streamnative.pulsar.handlers.mqtt.support;

import static io.streamnative.pulsar.handlers.mqtt.utils.MqttMessageUtils.createWillMessage;
import static io.streamnative.pulsar.handlers.mqtt.utils.MqttMessageUtils.getMtlsAuthMethodAndData;
import static io.streamnative.pulsar.handlers.mqtt.utils.MqttMessageUtils.getAuthenticationRole;
import static io.streamnative.pulsar.handlers.mqtt.utils.MqttMessageUtils.pingResp;
import static io.streamnative.pulsar.handlers.mqtt.utils.MqttMessageUtils.topicSubscriptions;
import io.netty.channel.ChannelHandlerContext;
Expand Down Expand Up @@ -75,9 +75,7 @@
import org.apache.bookkeeper.mledger.Position;
import org.apache.bookkeeper.mledger.PositionFactory;
import org.apache.bookkeeper.mledger.impl.AckSetStateUtil;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.pulsar.broker.PulsarService;
import org.apache.pulsar.broker.authentication.AuthenticationDataCommand;
import org.apache.pulsar.broker.authentication.AuthenticationDataSource;
import org.apache.pulsar.broker.authorization.AuthorizationService;
import org.apache.pulsar.broker.service.BrokerServiceException;
Expand Down Expand Up @@ -199,10 +197,9 @@ public void processPublish(MqttAdapterMessage adapter) {
String userRole = connection.getUserRole();
AuthenticationDataSource authData = connection.getAuthData();
if (adapter.fromProxy()) {
final Optional<Pair<String, byte[]>> mtlsAuthMethodAndData = getMtlsAuthMethodAndData(msg);
if (mtlsAuthMethodAndData.isPresent()) {
userRole = mtlsAuthMethodAndData.get().getKey();
authData = new AuthenticationDataCommand(new String(mtlsAuthMethodAndData.get().getValue()));
final Optional<String> authenticationRole = getAuthenticationRole(msg);
if (authenticationRole.isPresent()) {
userRole = authenticationRole.get();
}
}
result = this.authorizationService.canProduceAsync(TopicName.get(msg.variableHeader().topicName()),
Expand All @@ -224,9 +221,11 @@ private CompletableFuture<Void> doUnauthorized(MqttAdapterMessage adapter) {
log.error("[Publish] not authorized to topic={}, userRole={}, CId= {}",
msg.variableHeader().topicName(), connection.getUserRole(),
connection.getClientId());
int packetId = msg.variableHeader().packetId();
packetId = packetId == -1 ? 1 : packetId;
MqttPubAck.MqttPubErrorAckBuilder pubAckBuilder = MqttPubAck
.errorBuilder(connection.getProtocolVersion())
.packetId(msg.variableHeader().packetId())
.packetId(packetId)
.reasonCode(Mqtt5PubReasonCode.NOT_AUTHORIZED);
if (connection.getClientRestrictions().isAllowReasonStrOrUserProperty()) {
pubAckBuilder.reasonString("Not Authorized!");
Expand Down Expand Up @@ -367,10 +366,9 @@ public void processSubscribe(MqttAdapterMessage adapter) {
} else {
AuthenticationDataSource authData = connection.getAuthData();
if (adapter.fromProxy()) {
final Optional<Pair<String, byte[]>> mtlsAuthMethodAndData = getMtlsAuthMethodAndData(msg);
if (mtlsAuthMethodAndData.isPresent()) {
userRole = mtlsAuthMethodAndData.get().getKey();
authData = new AuthenticationDataCommand(new String(mtlsAuthMethodAndData.get().getValue()));
final Optional<String> authenticationRole = getAuthenticationRole(msg);
if (authenticationRole.isPresent()) {
userRole = authenticationRole.get();
}
}
List<CompletableFuture<Void>> authorizationFutures = new ArrayList<>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@

import static com.google.common.base.Preconditions.checkArgument;
import static io.netty.handler.codec.mqtt.MqttQoS.AT_MOST_ONCE;
import static io.streamnative.pulsar.handlers.mqtt.Constants.AUTH_MTLS;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.handler.codec.mqtt.MqttConnectMessage;
Expand All @@ -41,7 +40,6 @@
import java.util.UUID;
import java.util.stream.Collectors;
import org.apache.commons.codec.binary.Hex;
import org.apache.commons.lang3.tuple.Pair;

/**
* Mqtt message utils.
Expand All @@ -50,6 +48,8 @@ public class MqttMessageUtils {

public static final int CLIENT_IDENTIFIER_MAX_LENGTH = 23;

public static final String AUTHENTICATE_ROLE_KEY = "__mop_auth_role";

public static void checkState(MqttMessage msg) {
if (!msg.decoderResult().isSuccess()) {
throw new IllegalStateException(msg.decoderResult().cause().getMessage());
Expand Down Expand Up @@ -190,14 +190,10 @@ public static MqttPublishMessage createMqttWillMessage(WillMessage willMessage)
}

public static MqttConnectMessage createMqttConnectMessage(MqttConnectMessage connectMessage,
String authMethod,
String authData) {
final MqttConnectVariableHeader header = connectMessage.variableHeader();
MqttProperties properties = new MqttProperties();
properties.add(new MqttProperties.StringProperty(MqttProperties.MqttPropertyType.AUTHENTICATION_METHOD.value()
, authMethod));
properties.add(new MqttProperties.BinaryProperty(MqttProperties.MqttPropertyType.AUTHENTICATION_DATA.value()
, authData.getBytes()));
properties.add(new MqttProperties.UserProperty(AUTHENTICATE_ROLE_KEY, authData));
MqttConnectVariableHeader variableHeader = new MqttConnectVariableHeader(
MqttVersion.MQTT_5.protocolName(), MqttVersion.MQTT_5.protocolLevel(), header.hasUserName(),
header.hasPassword(), header.isWillRetain(), header.willQos(), header.isWillFlag(),
Expand All @@ -209,72 +205,58 @@ public static MqttConnectMessage createMqttConnectMessage(MqttConnectMessage con
}

public static MqttPublishMessage createMqttPublishMessage(MqttPublishMessage publishMessage,
String authMethod,
String authData) {
final MqttPublishVariableHeader header = publishMessage.variableHeader();
MqttProperties properties = new MqttProperties();
properties.add(new MqttProperties.StringProperty(MqttProperties.MqttPropertyType.AUTHENTICATION_METHOD.value()
, authMethod));
properties.add(new MqttProperties.BinaryProperty(MqttProperties.MqttPropertyType.AUTHENTICATION_DATA.value()
, authData.getBytes()));
properties.add(new MqttProperties.UserProperty(AUTHENTICATE_ROLE_KEY, authData));
MqttPublishVariableHeader variableHeader = new MqttPublishVariableHeader(
header.topicName(), header.packetId(), properties);
MqttPublishMessage newPublishMessage = new MqttPublishMessage(publishMessage.fixedHeader(), variableHeader,
publishMessage.payload());
return newPublishMessage;
}

public static Optional<Pair<String, byte[]>> getMtlsAuthMethodAndData(MqttConnectMessage connectMessage) {
public static Optional<String> getAuthenticationRole(MqttConnectMessage connectMessage) {
final MqttConnectVariableHeader header = connectMessage.variableHeader();
MqttProperties properties = header.properties();
final MqttProperties.MqttProperty property = properties.getProperty(
MqttProperties.MqttPropertyType.AUTHENTICATION_METHOD.value());
if (property != null && property.value() instanceof String
&& ((String) property.value()).equalsIgnoreCase(AUTH_MTLS)) {
final MqttProperties.MqttProperty data = properties.getProperty(
MqttProperties.MqttPropertyType.AUTHENTICATION_DATA.value());
return Optional.of(Pair.of((String) property.value(), (byte[]) data.value()));
final MqttProperties.UserProperties data = (MqttProperties.UserProperties) properties.getProperty(
MqttProperties.MqttPropertyType.USER_PROPERTY.value());
if (data != null && data.value() instanceof List<MqttProperties.StringPair>) {
return data.value().stream().filter(d -> d.key.equalsIgnoreCase(AUTHENTICATE_ROLE_KEY))
.map(e -> e.value).findFirst();
}
return Optional.empty();
}

public static Optional<Pair<String, byte[]>> getMtlsAuthMethodAndData(MqttPublishMessage publishMessage) {
public static Optional<String> getAuthenticationRole(MqttPublishMessage publishMessage) {
final MqttPublishVariableHeader header = publishMessage.variableHeader();
MqttProperties properties = header.properties();
final MqttProperties.MqttProperty property = properties.getProperty(
MqttProperties.MqttPropertyType.AUTHENTICATION_METHOD.value());
if (property != null && property.value() instanceof String
&& ((String) property.value()).equalsIgnoreCase(AUTH_MTLS)) {
final MqttProperties.MqttProperty data = properties.getProperty(
MqttProperties.MqttPropertyType.AUTHENTICATION_DATA.value());
return Optional.of(Pair.of((String) property.value(), (byte[]) data.value()));
}
final MqttProperties.UserProperties data = (MqttProperties.UserProperties) properties.getProperty(
MqttProperties.MqttPropertyType.USER_PROPERTY.value());
if (data != null && data.value() instanceof List<MqttProperties.StringPair>) {
return data.value().stream().filter(d -> d.key.equalsIgnoreCase(AUTHENTICATE_ROLE_KEY))
.map(e -> e.value).findFirst();
}
return Optional.empty();
}

public static Optional<Pair<String, byte[]>> getMtlsAuthMethodAndData(MqttSubscribeMessage subscribeMessage) {
public static Optional<String> getAuthenticationRole(MqttSubscribeMessage subscribeMessage) {
final MqttMessageIdAndPropertiesVariableHeader header = subscribeMessage.idAndPropertiesVariableHeader();
MqttProperties properties = header.properties();
final MqttProperties.MqttProperty property = properties.getProperty(
MqttProperties.MqttPropertyType.AUTHENTICATION_METHOD.value());
if (property != null && property.value() instanceof String
&& ((String) property.value()).equalsIgnoreCase(AUTH_MTLS)) {
final MqttProperties.MqttProperty data = properties.getProperty(
MqttProperties.MqttPropertyType.AUTHENTICATION_DATA.value());
return Optional.of(Pair.of((String) property.value(), (byte[]) data.value()));
final MqttProperties.UserProperties data = (MqttProperties.UserProperties) properties.getProperty(
MqttProperties.MqttPropertyType.USER_PROPERTY.value());
if (data != null && data.value() instanceof List<MqttProperties.StringPair>) {
return data.value().stream().filter(d -> d.key.equalsIgnoreCase(AUTHENTICATE_ROLE_KEY))
.map(e -> e.value).findFirst();
}
return Optional.empty();
}

public static MqttSubscribeMessage createMqttSubscribeMessage(MqttSubscribeMessage subscribeMessage,
String authMethod,
String authData) {
final MqttMessageIdAndPropertiesVariableHeader header = subscribeMessage.idAndPropertiesVariableHeader();
MqttProperties properties = new MqttProperties();
properties.add(new MqttProperties.StringProperty(MqttProperties.MqttPropertyType.AUTHENTICATION_METHOD.value()
, authMethod));
properties.add(new MqttProperties.BinaryProperty(MqttProperties.MqttPropertyType.AUTHENTICATION_DATA.value()
, authData.getBytes()));
properties.add(new MqttProperties.UserProperty(AUTHENTICATE_ROLE_KEY, authData));
MqttMessageIdAndPropertiesVariableHeader variableHeader = new MqttMessageIdAndPropertiesVariableHeader(
header.messageId(), properties);
MqttSubscribeMessage newSubscribeMessage = new MqttSubscribeMessage(subscribeMessage.fixedHeader(),
Expand Down

0 comments on commit 13db85b

Please sign in to comment.