Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -396,7 +396,7 @@ public void operationComplete(ChannelFuture future) throws Exception {
setupInflightResender(channel);
}

postOffice.dispatchConnection(msg);
postOffice.dispatchConnection(msg, channel);
LOG.trace("dispatch connection: {}", msg);
}
} else {
Expand Down
5 changes: 3 additions & 2 deletions broker/src/main/java/io/moquette/broker/PostOffice.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import io.moquette.interception.BrokerInterceptor;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.handler.codec.mqtt.MqttConnectMessage;
import io.netty.handler.codec.mqtt.MqttFixedHeader;
import io.netty.handler.codec.mqtt.MqttMessageBuilders;
Expand Down Expand Up @@ -1124,8 +1125,8 @@ private static boolean validateContentTypeAsUTF8(MqttPublishMessage msg) {
* notify MqttConnectMessage after connection established (already pass login).
* @param msg
*/
void dispatchConnection(MqttConnectMessage msg) {
interceptor.notifyClientConnected(msg);
void dispatchConnection(MqttConnectMessage msg, Channel channel) {
interceptor.notifyClientConnected(msg, channel);
}

void dispatchDisconnection(String clientId,String userName) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import io.moquette.broker.subscriptions.Subscription;
import io.moquette.metrics.MetricsProvider;
import io.netty.buffer.ByteBuf;
import io.netty.channel.Channel;
import io.netty.handler.codec.mqtt.MqttConnectMessage;
import io.netty.handler.codec.mqtt.MqttPublishMessage;
import io.netty.util.ReferenceCountUtil;
Expand Down Expand Up @@ -111,10 +112,15 @@ public void stop() {

@Override
public void notifyClientConnected(final MqttConnectMessage msg) {
notifyClientConnected(msg, (Channel) null);
}

@Override
public void notifyClientConnected(final MqttConnectMessage msg, final Channel channel) {
for (final InterceptHandler handler : this.handlers.get(InterceptConnectMessage.class)) {
LOG.debug("Sending MQTT CONNECT message to interceptor. CId={}, interceptorId={}",
msg.payload().clientIdentifier(), handler.getID());
executor.execute(() -> handler.onConnect(new InterceptConnectMessage(msg)));
executor.execute(() -> handler.onConnect(new InterceptConnectMessage(msg, channel)));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import io.moquette.interception.messages.InterceptAcknowledgedMessage;
import io.moquette.broker.subscriptions.Subscription;
import io.moquette.interception.messages.InterceptExceptionMessage;
import io.netty.channel.Channel;
import io.netty.handler.codec.mqtt.MqttConnectMessage;
import io.netty.handler.codec.mqtt.MqttPublishMessage;

Expand All @@ -36,6 +37,10 @@ public interface Interceptor {

void notifyClientConnected(MqttConnectMessage msg);

default void notifyClientConnected(MqttConnectMessage msg, Channel channel) {
notifyClientConnected(msg);
}

void notifyClientDisconnected(String clientID, String username);

void notifyClientConnectionLost(String clientID, String username);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,54 @@

package io.moquette.interception.messages;

import io.netty.channel.Channel;
import io.netty.handler.codec.mqtt.MqttConnectMessage;

import java.net.InetSocketAddress;
import java.util.Optional;

public class InterceptConnectMessage extends InterceptAbstractMessage {

private final MqttConnectMessage msg;
private final Channel channel;

public InterceptConnectMessage(MqttConnectMessage msg) {
this(msg, null);
}

public InterceptConnectMessage(MqttConnectMessage msg, Channel channel) {
super(msg);
this.msg = msg;
this.channel = channel;
}

public String getClientID() {
return msg.payload().clientIdentifier();
}

public Optional<Channel> getChannel() {
return Optional.ofNullable(channel);
}

public Optional<InetSocketAddress> getRemoteAddress() {
if (channel != null && channel.remoteAddress() instanceof InetSocketAddress) {
return Optional.of((InetSocketAddress) channel.remoteAddress());
}
return Optional.empty();
}

public String getClientAddress() {
return getRemoteAddress()
.map(InetSocketAddress::getHostString)
.orElse(null);
}

public int getClientPort() {
return getRemoteAddress()
.map(InetSocketAddress::getPort)
.orElse(-1);
}

public boolean isCleanSession() {
return msg.variableHeader().isCleanSession();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,21 +21,29 @@
import io.moquette.broker.subscriptions.Topic;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.handler.codec.mqtt.MqttMessageBuilders;
import io.netty.handler.codec.mqtt.MqttQoS;
import io.netty.handler.codec.mqtt.MqttSubscriptionOption;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;

import java.net.InetSocketAddress;
import java.util.Collections;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;

import static java.nio.charset.StandardCharsets.UTF_8;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertSame;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.refEq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

public class BrokerInterceptorTest {

Expand Down Expand Up @@ -123,6 +131,45 @@ public void testNotifyClientConnected() throws Exception {
assertEquals(40, n.get());
}

@Test
public void testNotifyClientConnectedIncludesRemoteAddress() throws Exception {
final CountDownLatch notified = new CountDownLatch(1);
final AtomicReference<InterceptConnectMessage> intercepted = new AtomicReference<>();
final BrokerInterceptor localInterceptor = new BrokerInterceptor(
Collections.<InterceptHandler>singletonList(new AbstractInterceptHandler() {
@Override
public String getID() {
return "RemoteAddressObserver";
}

@Override
public void onConnect(InterceptConnectMessage msg) {
intercepted.set(msg);
notified.countDown();
}

@Override
public void onSessionLoopError(Throwable error) {
throw new RuntimeException(error);
}
}));

try {
final InetSocketAddress remoteAddress = new InetSocketAddress("127.0.0.1", 12345);
final Channel channel = mock(Channel.class);
when(channel.remoteAddress()).thenReturn(remoteAddress);
localInterceptor.notifyClientConnected(MqttMessageBuilders.connect().build(), channel);

assertTrue(notified.await(1, TimeUnit.SECONDS));
assertSame(channel, intercepted.get().getChannel().get());
assertEquals(remoteAddress, intercepted.get().getRemoteAddress().get());
assertEquals("127.0.0.1", intercepted.get().getClientAddress());
assertEquals(12345, intercepted.get().getClientPort());
} finally {
localInterceptor.stop();
}
}

@Test
public void testNotifyClientDisconnected() throws Exception {
interceptor.notifyClientDisconnected("cli1234", "cli1234");
Expand Down