Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Connection filter interface for IP Bans #747

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions broker/src/main/java/io/moquette/BrokerConstants.java
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ public final class BrokerConstants {
public static final String REAUTHORIZE_SUBSCRIPTIONS_ON_CONNECT = "reauthorize_subscriptions_on_connect";
public static final String ALLOW_ZERO_BYTE_CLIENT_ID_PROPERTY_NAME = "allow_zero_byte_client_id";
public static final String ACL_FILE_PROPERTY_NAME = "acl_file";

public static final String CONNECTION_FILTER_CLASS_NAME = "connection_filter_class";
public static final String AUTHORIZATOR_CLASS_NAME = "authorizator_class";
public static final String AUTHENTICATOR_CLASS_NAME = "authenticator_class";
public static final String DB_AUTHENTICATOR_DRIVER = "authenticator.db.driver";
Expand Down
26 changes: 25 additions & 1 deletion broker/src/main/java/io/moquette/broker/MQTTConnection.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package io.moquette.broker;

import io.moquette.BrokerConstants;
import io.moquette.broker.security.IConnectionFilter;
import io.moquette.broker.subscriptions.Topic;
import io.moquette.broker.security.IAuthenticator;
import io.netty.buffer.ByteBuf;
Expand Down Expand Up @@ -50,17 +51,19 @@ final class MQTTConnection {
final Channel channel;
private final BrokerConfiguration brokerConfig;
private final IAuthenticator authenticator;
private final IConnectionFilter connectionFilter;
private final SessionRegistry sessionRegistry;
private final PostOffice postOffice;
private volatile boolean connected;
private final AtomicInteger lastPacketId = new AtomicInteger(0);
private Session bindedSession;

MQTTConnection(Channel channel, BrokerConfiguration brokerConfig, IAuthenticator authenticator,
SessionRegistry sessionRegistry, PostOffice postOffice) {
IConnectionFilter connectionFilter, SessionRegistry sessionRegistry, PostOffice postOffice) {
this.channel = channel;
this.brokerConfig = brokerConfig;
this.authenticator = authenticator;
this.connectionFilter = connectionFilter;
this.sessionRegistry = sessionRegistry;
this.postOffice = postOffice;
this.connected = false;
Expand Down Expand Up @@ -180,6 +183,12 @@ PostOffice.RouteResult processConnect(MqttConnectMessage msg) {
return PostOffice.RouteResult.failed(clientId);
}

if (!connectionAllowed(clientId)) {
abortConnection(CONNECTION_REFUSED_BANNED);
channel.close().addListener(CLOSE_ON_FAILURE);
return PostOffice.RouteResult.failed(clientId);
}

final String sessionId = clientId;
return postOffice.routeCommand(clientId, "CONN", () -> {
checkMatchSessionLoop(sessionId);
Expand All @@ -188,6 +197,10 @@ PostOffice.RouteResult processConnect(MqttConnectMessage msg) {
});
}

private boolean connectionAllowed(String clientId) {
return connectionFilter == null || connectionFilter.allowConnection(clientDescriptor(clientId));
}

private void checkMatchSessionLoop(String clientId) {
if (!sessionLoopDebug) {
return;
Expand Down Expand Up @@ -632,10 +645,21 @@ public String toString() {
return "MQTTConnection{channel=" + channel + ", connected=" + connected + '}';
}

// TODO : Unsafe cast, this is something else during testing (EmbeddedSocketAddress)
InetSocketAddress remoteAddress() {
return (InetSocketAddress) channel.remoteAddress();
}

ClientDescriptor clientDescriptor(String clientId) {
if (channel.remoteAddress() instanceof InetSocketAddress) {
return new ClientDescriptor(
clientId,
((InetSocketAddress) channel.remoteAddress()).getHostString(),
((InetSocketAddress) channel.remoteAddress()).getPort());
}
return new ClientDescriptor(clientId, "unknown", -1);
}

public void readCompleted() {
LOG.debug("readCompleted client CId: {}", getClientId());
if (getClientId() != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,24 +16,27 @@
package io.moquette.broker;

import io.moquette.broker.security.IAuthenticator;
import io.moquette.broker.security.IConnectionFilter;
import io.netty.channel.Channel;

class MQTTConnectionFactory {

private final BrokerConfiguration brokerConfig;
private final IAuthenticator authenticator;
private final IConnectionFilter connectionFilter;
private final SessionRegistry sessionRegistry;
private final PostOffice postOffice;

MQTTConnectionFactory(BrokerConfiguration brokerConfig, IAuthenticator authenticator,
SessionRegistry sessionRegistry, PostOffice postOffice) {
IConnectionFilter connectionFilter, SessionRegistry sessionRegistry, PostOffice postOffice) {
this.brokerConfig = brokerConfig;
this.authenticator = authenticator;
this.connectionFilter = connectionFilter;
this.sessionRegistry = sessionRegistry;
this.postOffice = postOffice;
}

MQTTConnection create(Channel channel) {
return new MQTTConnection(channel, brokerConfig, authenticator, sessionRegistry, postOffice);
return new MQTTConnection(channel, brokerConfig, authenticator, connectionFilter, sessionRegistry, postOffice);
}
}
19 changes: 16 additions & 3 deletions broker/src/main/java/io/moquette/broker/Server.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import io.moquette.broker.security.DenyAllAuthorizatorPolicy;
import io.moquette.broker.security.IAuthenticator;
import io.moquette.broker.security.IAuthorizatorPolicy;
import io.moquette.broker.security.IConnectionFilter;
import io.moquette.broker.security.PermitAllAuthorizatorPolicy;
import io.moquette.broker.security.ResourceAuthenticator;
import io.moquette.broker.unsafequeues.QueueException;
Expand Down Expand Up @@ -169,11 +170,11 @@ public void startServer(IConfig config) throws IOException {
*/
public void startServer(IConfig config, List<? extends InterceptHandler> handlers) throws IOException {
LOG.debug("Starting moquette integration using IConfig instance and intercept handlers");
startServer(config, handlers, null, null, null);
startServer(config, handlers, null, null, null, null);
}

public void startServer(IConfig config, List<? extends InterceptHandler> handlers, ISslContextCreator sslCtxCreator,
IAuthenticator authenticator, IAuthorizatorPolicy authorizatorPolicy) throws IOException {
IConnectionFilter connectionFilter, IAuthenticator authenticator, IAuthorizatorPolicy authorizatorPolicy) throws IOException {
final long start = System.currentTimeMillis();
if (handlers == null) {
handlers = Collections.emptyList();
Expand All @@ -192,6 +193,7 @@ public void startServer(IConfig config, List<? extends InterceptHandler> handler
LOG.info("Using default SSL context creator");
sslCtxCreator = new DefaultMoquetteSslContextCreator(config);
}
connectionFilter = initializeConnectionFilter(connectionFilter, config);
authenticator = initializeAuthenticator(authenticator, config);
authorizatorPolicy = initializeAuthorizatorPolicy(authorizatorPolicy, config);

Expand Down Expand Up @@ -258,7 +260,7 @@ public void startServer(IConfig config, List<? extends InterceptHandler> handler
dispatcher = new PostOffice(subscriptions, retainedRepository, sessions, interceptor, authorizator,
loopsGroup);
final BrokerConfiguration brokerConfig = new BrokerConfiguration(config);
MQTTConnectionFactory connectionFactory = new MQTTConnectionFactory(brokerConfig, authenticator, sessions,
MQTTConnectionFactory connectionFactory = new MQTTConnectionFactory(brokerConfig, authenticator, connectionFilter, sessions,
dispatcher);

final NewNettyMQTTHandler mqttHandler = new NewNettyMQTTHandler(connectionFactory);
Expand All @@ -275,6 +277,7 @@ public void startServer(IConfig config, List<? extends InterceptHandler> handler
initialized = true;
}


Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change

private static IQueueRepository initQueuesRepository(IConfig config, Path dataPath, H2Builder h2Builder) throws IOException {
final IQueueRepository queueRepository;
final String queueType = config.getProperty(BrokerConstants.PERSISTENT_QUEUE_TYPE_PROPERTY_NAME);
Expand Down Expand Up @@ -472,6 +475,16 @@ private IAuthorizatorPolicy initializeAuthorizatorPolicy(IAuthorizatorPolicy aut
return authorizatorPolicy;
}

private IConnectionFilter initializeConnectionFilter(IConnectionFilter connectionFilter, IConfig props) {
LOG.debug("Configuring MQTT Connection Filter");
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The Debug log should carry some useful information like, the connectionFilterClassName

String connectionFilterClassName = props.getProperty(BrokerConstants.CONNECTION_FILTER_CLASS_NAME, "");

if (connectionFilter == null && !connectionFilterClassName.isEmpty()) {
connectionFilter = loadClass(connectionFilterClassName, IConnectionFilter.class, IConfig.class, props);
}
return connectionFilter;
}

Comment on lines +480 to +487
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This code is a copy and paste of the initial part of initializeAuthenticator and initializeAuthorizatorPolicy, so maybe could be extracted in a separate method and reused.

private IAuthenticator initializeAuthenticator(IAuthenticator authenticator, IConfig props) {
LOG.debug("Configuring MQTT authenticator");
String authenticatorClassName = props.getProperty(BrokerConstants.AUTHENTICATOR_CLASS_NAME, "");
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package io.moquette.broker.security;

import io.moquette.broker.ClientDescriptor;

public interface IConnectionFilter {
boolean allowConnection(ClientDescriptor clientDescriptor);
}
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ public class MQTTConnectionConnectTest {
private SessionRegistry sessionRegistry;
private MqttMessageBuilders.ConnectBuilder connMsg;
private static final BrokerConfiguration CONFIG = new BrokerConfiguration(true, true, false, NO_BUFFER_FLUSH);
private MockConnectionFilter connectionFilter = new MockConnectionFilter();
private IAuthenticator mockAuthenticator;
private PostOffice postOffice;
private MemoryQueueRepository queueRepository;
Expand Down Expand Up @@ -102,7 +103,7 @@ private MQTTConnection createMQTTConnectionWithPostOffice(BrokerConfiguration co
}

private MQTTConnection createMQTTConnection(BrokerConfiguration config, Channel channel, PostOffice postOffice) {
return new MQTTConnection(channel, config, mockAuthenticator, sessionRegistry, postOffice);
return new MQTTConnection(channel, config, mockAuthenticator, connectionFilter, sessionRegistry, postOffice);
}

@Test
Expand Down Expand Up @@ -204,6 +205,23 @@ public void validAuthentication() throws ExecutionException, InterruptedExceptio
assertTrue(channel.isOpen(), "Connection is accepted and therefore must remain open");
}


@Test
public void validAuthenticationBannedClient() throws ExecutionException, InterruptedException {
MqttConnectMessage msg = connMsg.clientId(FAKE_CLIENT_ID)
.username(TEST_USER).password(TEST_PWD).build();

connectionFilter.banClientId(FAKE_CLIENT_ID);

// Exercise
PostOffice.RouteResult result = sut.processConnect(msg);
assertFalse(result.isSuccess());

// Verify
assertEqualsConnAck(CONNECTION_REFUSED_BANNED, channel.readOutbound());
assertFalse(channel.isOpen(), "Connection is refused/baned and therefore must not remain open");
}

@Test
public void noPasswdAuthentication() {
MqttConnectMessage msg = connMsg.clientId(FAKE_CLIENT_ID)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/
package io.moquette.broker;

import io.moquette.broker.security.IConnectionFilter;
import io.moquette.broker.security.PermitAllAuthorizatorPolicy;
import io.moquette.broker.subscriptions.CTrieSubscriptionDirectory;
import io.moquette.broker.subscriptions.ISubscriptionsDirectory;
Expand Down Expand Up @@ -82,6 +83,7 @@ private void createMQTTConnection(BrokerConfiguration config) {
private MQTTConnection createMQTTConnection(BrokerConfiguration config, Channel channel) {
IAuthenticator mockAuthenticator = new MockAuthenticator(singleton(FAKE_CLIENT_ID),
singletonMap(TEST_USER, TEST_PWD));
IConnectionFilter connectionFilter = new MockConnectionFilter();

ISubscriptionsDirectory subscriptions = new CTrieSubscriptionDirectory();
ISubscriptionsRepository subscriptionsRepository = new MemorySubscriptionsRepository();
Expand All @@ -94,7 +96,7 @@ private MQTTConnection createMQTTConnection(BrokerConfiguration config, Channel
sessionRegistry = new SessionRegistry(subscriptions, memorySessionsRepository(), queueRepository, permitAll, scheduler, loopsGroup);
final PostOffice postOffice = new PostOffice(subscriptions,
new MemoryRetainedRepository(), sessionRegistry, ConnectionTestUtils.NO_OBSERVERS_INTERCEPTOR, permitAll, loopsGroup);
return new MQTTConnection(channel, config, mockAuthenticator, sessionRegistry, postOffice);
return new MQTTConnection(channel, config, mockAuthenticator, connectionFilter, sessionRegistry, postOffice);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given that in this test we don't use the filter we could inline it, without creating an used variable:

return new MQTTConnection(channel, config, mockAuthenticator, new MockConnectionFilter(), sessionRegistry, postOffice);

Or better, given that it's a pass-all filter, a more explicit implementation could be used for such cases:

public class AcceptAllFilter implements IConnectionFilter {
    @Override
    public boolean allowConnection(ClientDescriptor clientDescriptor) {
        return true;
    }
}

In all test places where special filtering logic is not needed we can use:

return new MQTTConnection(channel, config, mockAuthenticator, new AcceptAllFilter(), sessionRegistry, postOffice);

This comment is valid also for other places where the MockConnectionFilter is used with the same intention.

}

// @NotNull
Expand Down
33 changes: 33 additions & 0 deletions broker/src/test/java/io/moquette/broker/MockConnectionFilter.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package io.moquette.broker;

import io.moquette.broker.security.IConnectionFilter;

import java.util.Set;
import java.util.HashSet;
import java.util.stream.Stream;

public class MockConnectionFilter implements IConnectionFilter {
private Set<String> bannedClientIds = new HashSet<>();
private Set<String> bannedAddresses = new HashSet<>();
@Override
public boolean allowConnection(ClientDescriptor clientDescriptor) {
return !bannedClientIds.contains(clientDescriptor.getClientID())
&& !bannedAddresses.contains(clientDescriptor.getAddress());
}

public MockConnectionFilter banClientId(String clientId) {
bannedClientIds.add(clientId);
return this;
}

public MockConnectionFilter banAddress(String address) {
bannedAddresses.add(address);
return this;
}

public MockConnectionFilter reset() {
bannedClientIds = new HashSet<>();
bannedAddresses = new HashSet<>();
return this;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ public class PostOfficeInternalPublishTest {
private ISubscriptionsDirectory subscriptions;
private MqttConnectMessage connectMessage;
private SessionRegistry sessionRegistry;
private MockConnectionFilter connectionFilter = new MockConnectionFilter();
private MockAuthenticator mockAuthenticator;
private static final BrokerConfiguration ALLOW_ANONYMOUS_AND_ZERO_BYTES_CLID =
new BrokerConfiguration(true, true, false, NO_BUFFER_FLUSH);
Expand Down Expand Up @@ -88,7 +89,7 @@ private MQTTConnection createMQTTConnection(BrokerConfiguration config) {
}

private MQTTConnection createMQTTConnection(BrokerConfiguration config, Channel channel) {
return new MQTTConnection(channel, config, mockAuthenticator, sessionRegistry, sut);
return new MQTTConnection(channel, config, mockAuthenticator, connectionFilter, sessionRegistry, sut);
}

private void initPostOfficeAndSubsystems() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ public class PostOfficePublishTest {
public static final String FAKE_USER_NAME = "UnAuthUser";
private MqttConnectMessage connectMessage;
private SessionRegistry sessionRegistry;
private MockConnectionFilter connectionFilter = new MockConnectionFilter();
private MockAuthenticator mockAuthenticator;
static final BrokerConfiguration ALLOW_ANONYMOUS_AND_ZERO_BYTES_CLID =
new BrokerConfiguration(true, true, false, NO_BUFFER_FLUSH);
Expand Down Expand Up @@ -93,7 +94,7 @@ private MQTTConnection createMQTTConnection(BrokerConfiguration config) {
}

private MQTTConnection createMQTTConnection(BrokerConfiguration config, Channel channel) {
return new MQTTConnection(channel, config, mockAuthenticator, sessionRegistry, sut);
return new MQTTConnection(channel, config, mockAuthenticator, connectionFilter, sessionRegistry, sut);
}

private void initPostOfficeAndSubsystems() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ public class PostOfficeSubscribeTest {
private ISubscriptionsDirectory subscriptions;
public static final String FAKE_USER_NAME = "UnAuthUser";
private MqttConnectMessage connectMessage;
private MockConnectionFilter connectionFilter = new MockConnectionFilter();
private IAuthenticator mockAuthenticator;
private SessionRegistry sessionRegistry;
public static final BrokerConfiguration CONFIG = new BrokerConfiguration(true, true, false, NO_BUFFER_FLUSH);
Expand Down Expand Up @@ -110,7 +111,7 @@ private void prepareSUT() {
}

private MQTTConnection createMQTTConnection(BrokerConfiguration config, Channel channel) {
return new MQTTConnection(channel, config, mockAuthenticator, sessionRegistry, sut);
return new MQTTConnection(channel, config, mockAuthenticator, connectionFilter, sessionRegistry, sut);
}

protected void connect() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ public class PostOfficeUnsubscribeTest {
private PostOffice sut;
private ISubscriptionsDirectory subscriptions;
private MqttConnectMessage connectMessage;
private MockConnectionFilter connectionFilter = new MockConnectionFilter();
private IAuthenticator mockAuthenticator;
private SessionRegistry sessionRegistry;
public static final BrokerConfiguration CONFIG = new BrokerConfiguration(true, true, false, NO_BUFFER_FLUSH);
Expand Down Expand Up @@ -101,7 +102,7 @@ private void prepareSUT() {
}

private MQTTConnection createMQTTConnection(BrokerConfiguration config, Channel channel) {
return new MQTTConnection(channel, config, mockAuthenticator, sessionRegistry, sut);
return new MQTTConnection(channel, config, mockAuthenticator, connectionFilter, sessionRegistry, sut);
}

protected static void connect(MQTTConnection connection, String clientId) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ private void createMQTTConnection(BrokerConfiguration config) {
private MQTTConnection createMQTTConnection(BrokerConfiguration config, Channel channel) {
IAuthenticator mockAuthenticator = new MockAuthenticator(singleton(FAKE_CLIENT_ID),
singletonMap(TEST_USER, TEST_PWD));
MockConnectionFilter connectionFilter = new MockConnectionFilter();

ISubscriptionsDirectory subscriptions = new CTrieSubscriptionDirectory();
ISubscriptionsRepository subscriptionsRepository = new MemorySubscriptionsRepository();
Expand All @@ -121,7 +122,7 @@ private MQTTConnection createMQTTConnection(BrokerConfiguration config, Channel
sut = new SessionRegistry(subscriptions, sessionRepository, queueRepository, permitAll, scheduler, slidingClock, GLOBAL_SESSION_EXPIRY_SECONDS, loopsGroup);
final PostOffice postOffice = new PostOffice(subscriptions,
new MemoryRetainedRepository(), sut, ConnectionTestUtils.NO_OBSERVERS_INTERCEPTOR, permitAll, loopsGroup);
return new MQTTConnection(channel, config, mockAuthenticator, sut, postOffice);
return new MQTTConnection(channel, config, mockAuthenticator, connectionFilter, sut, postOffice);
}

@Test
Expand Down
3 changes: 1 addition & 2 deletions broker/src/test/java/io/moquette/broker/SessionTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static io.moquette.BrokerConstants.NO_BUFFER_FLUSH;
import static org.junit.jupiter.api.Assertions.*;

public class SessionTest {

Expand Down Expand Up @@ -125,7 +124,7 @@ public void testRemoveSubscription() {

private void createConnection(Session client) {
BrokerConfiguration brokerConfiguration = new BrokerConfiguration(true, false, false, NO_BUFFER_FLUSH);
MQTTConnection mqttConnection = new MQTTConnection(testChannel, brokerConfiguration, null, null, null);
MQTTConnection mqttConnection = new MQTTConnection(testChannel, brokerConfiguration, null, null, null, null);
client.markConnecting();
client.bind(mqttConnection);
client.completeConnection();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ public boolean canRead(Topic topic, String user, String client) {
}
};

m_server.startServer(m_config, EMPTY_OBSERVERS, null, new AcceptAllAuthenticator(), switchingAuthorizator);
m_server.startServer(m_config, EMPTY_OBSERVERS, null, null, new AcceptAllAuthenticator(), switchingAuthorizator);
}

@BeforeEach
Expand Down