diff --git a/broker/src/main/java/io/moquette/broker/MQTTConnection.java b/broker/src/main/java/io/moquette/broker/MQTTConnection.java index a7930c860..063e4cdec 100644 --- a/broker/src/main/java/io/moquette/broker/MQTTConnection.java +++ b/broker/src/main/java/io/moquette/broker/MQTTConnection.java @@ -396,7 +396,7 @@ public void operationComplete(ChannelFuture future) throws Exception { setupInflightResender(channel); } - postOffice.dispatchConnection(msg); + postOffice.dispatchConnection(msg, channel); LOG.trace("dispatch connection: {}", msg); } } else { diff --git a/broker/src/main/java/io/moquette/broker/PostOffice.java b/broker/src/main/java/io/moquette/broker/PostOffice.java index 3c2771036..e3e684656 100644 --- a/broker/src/main/java/io/moquette/broker/PostOffice.java +++ b/broker/src/main/java/io/moquette/broker/PostOffice.java @@ -25,6 +25,7 @@ import io.moquette.interception.BrokerInterceptor; import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; +import io.netty.channel.Channel; import io.netty.handler.codec.mqtt.MqttConnectMessage; import io.netty.handler.codec.mqtt.MqttFixedHeader; import io.netty.handler.codec.mqtt.MqttMessageBuilders; @@ -1124,8 +1125,8 @@ private static boolean validateContentTypeAsUTF8(MqttPublishMessage msg) { * notify MqttConnectMessage after connection established (already pass login). * @param msg */ - void dispatchConnection(MqttConnectMessage msg) { - interceptor.notifyClientConnected(msg); + void dispatchConnection(MqttConnectMessage msg, Channel channel) { + interceptor.notifyClientConnected(msg, channel); } void dispatchDisconnection(String clientId,String userName) { diff --git a/broker/src/main/java/io/moquette/interception/BrokerInterceptor.java b/broker/src/main/java/io/moquette/interception/BrokerInterceptor.java index 89c0cefea..1aefff6d3 100644 --- a/broker/src/main/java/io/moquette/interception/BrokerInterceptor.java +++ b/broker/src/main/java/io/moquette/interception/BrokerInterceptor.java @@ -23,6 +23,7 @@ import io.moquette.broker.subscriptions.Subscription; import io.moquette.metrics.MetricsProvider; import io.netty.buffer.ByteBuf; +import io.netty.channel.Channel; import io.netty.handler.codec.mqtt.MqttConnectMessage; import io.netty.handler.codec.mqtt.MqttPublishMessage; import io.netty.util.ReferenceCountUtil; @@ -111,10 +112,15 @@ public void stop() { @Override public void notifyClientConnected(final MqttConnectMessage msg) { + notifyClientConnected(msg, (Channel) null); + } + + @Override + public void notifyClientConnected(final MqttConnectMessage msg, final Channel channel) { for (final InterceptHandler handler : this.handlers.get(InterceptConnectMessage.class)) { LOG.debug("Sending MQTT CONNECT message to interceptor. CId={}, interceptorId={}", msg.payload().clientIdentifier(), handler.getID()); - executor.execute(() -> handler.onConnect(new InterceptConnectMessage(msg))); + executor.execute(() -> handler.onConnect(new InterceptConnectMessage(msg, channel))); } } diff --git a/broker/src/main/java/io/moquette/interception/Interceptor.java b/broker/src/main/java/io/moquette/interception/Interceptor.java index 7055bceae..12f07e412 100644 --- a/broker/src/main/java/io/moquette/interception/Interceptor.java +++ b/broker/src/main/java/io/moquette/interception/Interceptor.java @@ -19,6 +19,7 @@ import io.moquette.interception.messages.InterceptAcknowledgedMessage; import io.moquette.broker.subscriptions.Subscription; import io.moquette.interception.messages.InterceptExceptionMessage; +import io.netty.channel.Channel; import io.netty.handler.codec.mqtt.MqttConnectMessage; import io.netty.handler.codec.mqtt.MqttPublishMessage; @@ -36,6 +37,10 @@ public interface Interceptor { void notifyClientConnected(MqttConnectMessage msg); + default void notifyClientConnected(MqttConnectMessage msg, Channel channel) { + notifyClientConnected(msg); + } + void notifyClientDisconnected(String clientID, String username); void notifyClientConnectionLost(String clientID, String username); diff --git a/broker/src/main/java/io/moquette/interception/messages/InterceptConnectMessage.java b/broker/src/main/java/io/moquette/interception/messages/InterceptConnectMessage.java index e1cde5267..12752c120 100644 --- a/broker/src/main/java/io/moquette/interception/messages/InterceptConnectMessage.java +++ b/broker/src/main/java/io/moquette/interception/messages/InterceptConnectMessage.java @@ -16,21 +16,54 @@ package io.moquette.interception.messages; +import io.netty.channel.Channel; import io.netty.handler.codec.mqtt.MqttConnectMessage; +import java.net.InetSocketAddress; +import java.util.Optional; + public class InterceptConnectMessage extends InterceptAbstractMessage { private final MqttConnectMessage msg; + private final Channel channel; public InterceptConnectMessage(MqttConnectMessage msg) { + this(msg, null); + } + + public InterceptConnectMessage(MqttConnectMessage msg, Channel channel) { super(msg); this.msg = msg; + this.channel = channel; } public String getClientID() { return msg.payload().clientIdentifier(); } + public Optional getChannel() { + return Optional.ofNullable(channel); + } + + public Optional getRemoteAddress() { + if (channel != null && channel.remoteAddress() instanceof InetSocketAddress) { + return Optional.of((InetSocketAddress) channel.remoteAddress()); + } + return Optional.empty(); + } + + public String getClientAddress() { + return getRemoteAddress() + .map(InetSocketAddress::getHostString) + .orElse(null); + } + + public int getClientPort() { + return getRemoteAddress() + .map(InetSocketAddress::getPort) + .orElse(-1); + } + public boolean isCleanSession() { return msg.variableHeader().isCleanSession(); } diff --git a/broker/src/test/java/io/moquette/interception/BrokerInterceptorTest.java b/broker/src/test/java/io/moquette/interception/BrokerInterceptorTest.java index 916f9eefb..9b391e5c5 100644 --- a/broker/src/test/java/io/moquette/interception/BrokerInterceptorTest.java +++ b/broker/src/test/java/io/moquette/interception/BrokerInterceptorTest.java @@ -21,6 +21,7 @@ import io.moquette.broker.subscriptions.Topic; import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; +import io.netty.channel.Channel; import io.netty.handler.codec.mqtt.MqttMessageBuilders; import io.netty.handler.codec.mqtt.MqttQoS; import io.netty.handler.codec.mqtt.MqttSubscriptionOption; @@ -28,14 +29,21 @@ import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; +import java.net.InetSocketAddress; import java.util.Collections; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; import static java.nio.charset.StandardCharsets.UTF_8; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.ArgumentMatchers.refEq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; public class BrokerInterceptorTest { @@ -123,6 +131,45 @@ public void testNotifyClientConnected() throws Exception { assertEquals(40, n.get()); } + @Test + public void testNotifyClientConnectedIncludesRemoteAddress() throws Exception { + final CountDownLatch notified = new CountDownLatch(1); + final AtomicReference intercepted = new AtomicReference<>(); + final BrokerInterceptor localInterceptor = new BrokerInterceptor( + Collections.singletonList(new AbstractInterceptHandler() { + @Override + public String getID() { + return "RemoteAddressObserver"; + } + + @Override + public void onConnect(InterceptConnectMessage msg) { + intercepted.set(msg); + notified.countDown(); + } + + @Override + public void onSessionLoopError(Throwable error) { + throw new RuntimeException(error); + } + })); + + try { + final InetSocketAddress remoteAddress = new InetSocketAddress("127.0.0.1", 12345); + final Channel channel = mock(Channel.class); + when(channel.remoteAddress()).thenReturn(remoteAddress); + localInterceptor.notifyClientConnected(MqttMessageBuilders.connect().build(), channel); + + assertTrue(notified.await(1, TimeUnit.SECONDS)); + assertSame(channel, intercepted.get().getChannel().get()); + assertEquals(remoteAddress, intercepted.get().getRemoteAddress().get()); + assertEquals("127.0.0.1", intercepted.get().getClientAddress()); + assertEquals(12345, intercepted.get().getClientPort()); + } finally { + localInterceptor.stop(); + } + } + @Test public void testNotifyClientDisconnected() throws Exception { interceptor.notifyClientDisconnected("cli1234", "cli1234");