Skip to content

Commit

Permalink
Add support for Read with Metadata in MqttIO (#32668)
Browse files Browse the repository at this point in the history
* add support for read with metadata in MqttIO

* Update CHANGES.md

* update javadoc

* update javadoc

* refactor : change to use SchemaCoder in MqttIO
- remove MqttRecordCoder
- refactor MqttRecord to use AutoValueSchema
- change related test
  • Loading branch information
twosom authored Oct 10, 2024
1 parent 9ceb14e commit 95bf983
Show file tree
Hide file tree
Showing 4 changed files with 257 additions and 34 deletions.
1 change: 1 addition & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@

## New Features / Improvements

* Added support for read with metadata in MqttIO (Java) ([#32195](https://github.com/apache/beam/issues/32195))
* X feature added (Java/Python) ([#X](https://github.com/apache/beam/issues/X)).
* Added support for processing events which use a global sequence to "ordered" extension (Java) [#32540](https://github.com/apache/beam/pull/32540)

Expand Down
163 changes: 130 additions & 33 deletions sdks/java/io/mqtt/src/main/java/org/apache/beam/sdk/io/mqtt/MqttIO.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.beam.sdk.io.mqtt;

import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument;
import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull;

import com.google.auto.value.AutoValue;
import java.io.IOException;
Expand All @@ -36,6 +37,7 @@
import org.apache.beam.sdk.coders.SerializableCoder;
import org.apache.beam.sdk.io.UnboundedSource;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.schemas.NoSuchSchemaException;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.ParDo;
Expand Down Expand Up @@ -80,6 +82,48 @@
*
* }</pre>
*
* <h3>Reading with Metadata from a MQTT broker</h3>
*
* <p>The {@code readWithMetadata} method extends the basic {@code read} method by returning a
* {@link PCollection} of {@link MqttRecord} instead of raw {@code byte[]} payloads. Each {@link
* MqttRecord} carries both the topic name and the payload of the received message. This allows
* you to implement business logic that can differ depending on the topic from which the message
* was received.
*
* <pre>{@code
* PCollection<MqttRecord> records = pipeline.apply(
* MqttIO.readWithMetadata()
* .withConnectionConfiguration(MqttIO.ConnectionConfiguration.create(
* "tcp://host:11883",
* "my_topic_pattern")));
*
* }</pre>
*
* <p>By using the topic information, you can apply different processing logic depending on the
* source topic, enhancing the flexibility of message processing.
*
* <h4>Example</h4>
*
* <pre>{@code
* pipeline
* .apply(MqttIO.readWithMetadata()
* .withConnectionConfiguration(MqttIO.ConnectionConfiguration.create(
* "tcp://host:1883", "my_topic_pattern")))
* .apply(ParDo.of(new DoFn<MqttRecord, Void>() {
* @ProcessElement
* public void processElement(ProcessContext c) {
* MqttRecord record = c.element();
* String topic = record.getTopic();
* byte[] payload = record.getPayload();
* // Apply business logic based on the topic
* if (topic.equals("important_topic")) {
* // Special processing for important_topic
* }
* }
* }));
*
* }</pre>
*
* <h3>Writing to a MQTT broker</h3>
*
* <p>MqttIO sink supports writing {@code byte[]} to a topic on a MQTT broker.
Expand Down Expand Up @@ -130,9 +174,18 @@ public class MqttIO {
private static final Logger LOG = LoggerFactory.getLogger(MqttIO.class);
private static final int MQTT_3_1_MAX_CLIENT_ID_LENGTH = 23;

public static Read read() {
return new AutoValue_MqttIO_Read.Builder()
public static Read<byte[]> read() {
return new AutoValue_MqttIO_Read.Builder<byte[]>()
.setMaxReadTime(null)
.setWithMetadata(false)
.setMaxNumRecords(Long.MAX_VALUE)
.build();
}

/**
 * Creates a {@link Read} transform that emits {@link MqttRecord} elements (topic name plus
 * payload) rather than the bare {@code byte[]} payload.
 *
 * <p>The returned transform starts with no max read time and a max number of records of {@code
 * Long.MAX_VALUE}. A connection configuration must still be supplied through {@link
 * Read#withConnectionConfiguration} before the transform is expanded.
 *
 * @return a {@link Read} configured to output {@link MqttRecord}s
 */
public static Read<MqttRecord> readWithMetadata() {
  final Read.Builder<MqttRecord> metadataReadBuilder =
      new AutoValue_MqttIO_Read.Builder<MqttRecord>();
  // The setters are independent of each other; withMetadata is what selects MqttRecord output.
  return metadataReadBuilder
      .setWithMetadata(true)
      .setMaxNumRecords(Long.MAX_VALUE)
      .setMaxReadTime(null)
      .build();
}
Expand Down Expand Up @@ -267,29 +320,37 @@ private MQTT createClient() throws Exception {

/** A {@link PTransform} to read from a MQTT broker. */
@AutoValue
public abstract static class Read extends PTransform<PBegin, PCollection<byte[]>> {
public abstract static class Read<T> extends PTransform<PBegin, PCollection<T>> {

abstract @Nullable ConnectionConfiguration connectionConfiguration();

abstract long maxNumRecords();

abstract @Nullable Duration maxReadTime();

abstract Builder builder();
abstract Builder<T> builder();

abstract boolean withMetadata();

abstract @Nullable Coder<T> coder();

@AutoValue.Builder
abstract static class Builder {
abstract Builder setConnectionConfiguration(ConnectionConfiguration config);
abstract static class Builder<T> {
abstract Builder<T> setConnectionConfiguration(ConnectionConfiguration config);

abstract Builder<T> setMaxNumRecords(long maxNumRecords);

abstract Builder setMaxNumRecords(long maxNumRecords);
abstract Builder<T> setMaxReadTime(Duration maxReadTime);

abstract Builder setMaxReadTime(Duration maxReadTime);
abstract Builder<T> setWithMetadata(boolean withMetadata);

abstract Read build();
abstract Builder<T> setCoder(Coder<T> coder);

abstract Read<T> build();
}

/** Define the MQTT connection configuration used to connect to the MQTT broker. */
public Read withConnectionConfiguration(ConnectionConfiguration configuration) {
public Read<T> withConnectionConfiguration(ConnectionConfiguration configuration) {
checkArgument(configuration != null, "configuration can not be null");
return builder().setConnectionConfiguration(configuration).build();
}
Expand All @@ -299,27 +360,41 @@ public Read withConnectionConfiguration(ConnectionConfiguration configuration) {
* records is lower than {@code Long.MAX_VALUE}, the {@link Read} will provide a bounded {@link
* PCollection}.
*/
public Read withMaxNumRecords(long maxNumRecords) {
public Read<T> withMaxNumRecords(long maxNumRecords) {
return builder().setMaxNumRecords(maxNumRecords).build();
}

/**
* Define the max read time (duration) while the {@link Read} will receive messages. When this
* max read time is not null, the {@link Read} will provide a bounded {@link PCollection}.
*/
public Read withMaxReadTime(Duration maxReadTime) {
public Read<T> withMaxReadTime(Duration maxReadTime) {
return builder().setMaxReadTime(maxReadTime).build();
}

@Override
public PCollection<byte[]> expand(PBegin input) {
@SuppressWarnings("unchecked")
public PCollection<T> expand(PBegin input) {
checkArgument(connectionConfiguration() != null, "connectionConfiguration can not be null");
checkArgument(connectionConfiguration().getTopic() != null, "topic can not be null");

org.apache.beam.sdk.io.Read.Unbounded<byte[]> unbounded =
org.apache.beam.sdk.io.Read.from(new UnboundedMqttSource(this));
Coder<T> coder;
if (withMetadata()) {
try {
coder =
(Coder<T>) input.getPipeline().getSchemaRegistry().getSchemaCoder(MqttRecord.class);
} catch (NoSuchSchemaException e) {
throw new RuntimeException(e.getMessage());
}
} else {
coder = (Coder<T>) ByteArrayCoder.of();
}

org.apache.beam.sdk.io.Read.Unbounded<T> unbounded =
org.apache.beam.sdk.io.Read.from(
new UnboundedMqttSource<>(this.builder().setCoder(coder).build()));

PTransform<PBegin, PCollection<byte[]>> transform = unbounded;
PTransform<PBegin, PCollection<T>> transform = unbounded;

if (maxNumRecords() < Long.MAX_VALUE || maxReadTime() != null) {
transform = unbounded.withMaxReadTime(maxReadTime()).withMaxNumRecords(maxNumRecords());
Expand Down Expand Up @@ -403,27 +478,39 @@ public int hashCode() {
}

@VisibleForTesting
static class UnboundedMqttSource extends UnboundedSource<byte[], MqttCheckpointMark> {
static class UnboundedMqttSource<T> extends UnboundedSource<T, MqttCheckpointMark> {

private final Read spec;
private final Read<T> spec;

public UnboundedMqttSource(Read spec) {
public UnboundedMqttSource(Read<T> spec) {
this.spec = spec;
}

@Override
public UnboundedReader<byte[]> createReader(
@SuppressWarnings("unchecked")
public UnboundedReader<T> createReader(
PipelineOptions options, MqttCheckpointMark checkpointMark) {
return new UnboundedMqttReader(this, checkpointMark);
final UnboundedMqttReader<T> unboundedMqttReader;
if (spec.withMetadata()) {
unboundedMqttReader =
new UnboundedMqttReader<>(
this,
checkpointMark,
message -> (T) MqttRecord.of(message.getTopic(), message.getPayload()));
} else {
unboundedMqttReader = new UnboundedMqttReader<>(this, checkpointMark);
}

return unboundedMqttReader;
}

@Override
public List<UnboundedMqttSource> split(int desiredNumSplits, PipelineOptions options) {
public List<UnboundedMqttSource<T>> split(int desiredNumSplits, PipelineOptions options) {
// MQTT is based on a pub/sub pattern
// so, if we create several subscribers on the same topic, they all will receive the same
// message, resulting in duplicate messages in the PCollection.
// So, for MQTT, we limit the number of splits to 1 (unique source).
return Collections.singletonList(new UnboundedMqttSource(spec));
return Collections.singletonList(new UnboundedMqttSource<>(spec));
}

@Override
Expand All @@ -437,36 +524,46 @@ public Coder<MqttCheckpointMark> getCheckpointMarkCoder() {
}

@Override
public Coder<byte[]> getOutputCoder() {
return ByteArrayCoder.of();
public Coder<T> getOutputCoder() {
return checkNotNull(this.spec.coder(), "coder can not be null");
}
}

@VisibleForTesting
static class UnboundedMqttReader extends UnboundedSource.UnboundedReader<byte[]> {
static class UnboundedMqttReader<T> extends UnboundedSource.UnboundedReader<T> {

private final UnboundedMqttSource source;
private final UnboundedMqttSource<T> source;

private MQTT client;
private BlockingConnection connection;
private byte[] current;
private T current;
private Instant currentTimestamp;
private MqttCheckpointMark checkpointMark;
private SerializableFunction<Message, T> extractFn;

public UnboundedMqttReader(UnboundedMqttSource source, MqttCheckpointMark checkpointMark) {
public UnboundedMqttReader(UnboundedMqttSource<T> source, MqttCheckpointMark checkpointMark) {
this.source = source;
this.current = null;
if (checkpointMark != null) {
this.checkpointMark = checkpointMark;
} else {
this.checkpointMark = new MqttCheckpointMark();
}
this.extractFn = message -> (T) message.getPayload();
}

/**
 * Creates a reader that turns every received MQTT {@link Message} into an output element by
 * applying {@code extractFn}.
 *
 * @param source the source this reader belongs to
 * @param checkpointMark the checkpoint mark to resume from; when {@code null}, the delegated
 *     constructor creates a fresh {@link MqttCheckpointMark}
 * @param extractFn function applied to each received message to build the element of type
 *     {@code T}
 */
public UnboundedMqttReader(
UnboundedMqttSource<T> source,
MqttCheckpointMark checkpointMark,
SerializableFunction<Message, T> extractFn) {
// The two-argument constructor installs the payload-only extractor; it is overridden below.
this(source, checkpointMark);
this.extractFn = extractFn;
}

@Override
public boolean start() throws IOException {
LOG.debug("Starting MQTT reader ...");
Read spec = source.spec;
Read<T> spec = source.spec;
try {
client = spec.connectionConfiguration().createClient();
LOG.debug("Reader client ID is {}", client.getClientId());
Expand All @@ -488,7 +585,7 @@ public boolean advance() throws IOException {
if (message == null) {
return false;
}
current = message.getPayload();
current = this.extractFn.apply(message);
currentTimestamp = Instant.now();
checkpointMark.add(message, currentTimestamp);
} catch (Exception e) {
Expand Down Expand Up @@ -520,7 +617,7 @@ public UnboundedSource.CheckpointMark getCheckpointMark() {
}

@Override
public byte[] getCurrent() {
public T getCurrent() {
if (current == null) {
throw new NoSuchElementException();
}
Expand All @@ -536,7 +633,7 @@ public Instant getCurrentTimestamp() {
}

@Override
public UnboundedMqttSource getCurrentSource() {
public UnboundedMqttSource<T> getCurrentSource() {
return source;
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.sdk.io.mqtt;

import com.google.auto.value.AutoValue;
import org.apache.beam.sdk.schemas.AutoValueSchema;
import org.apache.beam.sdk.schemas.annotations.DefaultSchema;

/**
 * A container class for MQTT message metadata, including the topic name and payload.
 *
 * <p>Instances are produced by {@code MqttIO.readWithMetadata()}, which builds one record per
 * received message from the message's topic and payload. The {@link DefaultSchema} annotation
 * with {@link AutoValueSchema} is what allows {@code MqttIO} to look up a schema coder for this
 * class in the pipeline's schema registry.
 */
@DefaultSchema(AutoValueSchema.class)
@AutoValue
public abstract class MqttRecord {
/** Returns the name of the topic the message was received from. */
public abstract String getTopic();

/**
 * Returns the message payload.
 *
 * <p>NOTE(review): the "mutable" suppression indicates AutoValue flags this array-valued
 * property; the array is presumably stored and returned without a defensive copy, so callers
 * should treat it as read-only - confirm.
 */
@SuppressWarnings("mutable")
public abstract byte[] getPayload();

/** Returns a new, empty builder; package-private. */
static Builder builder() {
return new AutoValue_MqttRecord.Builder();
}

/**
 * Package-private convenience factory that builds a record from a topic and a payload.
 *
 * @param topic the topic name of the message
 * @param payload the payload of the message
 * @return a record holding the given topic and payload
 */
static MqttRecord of(String topic, byte[] payload) {
return builder().setTopic(topic).setPayload(payload).build();
}

/** AutoValue builder for {@link MqttRecord}; the implementation is generated. */
@AutoValue.Builder
abstract static class Builder {
/** Sets the topic name of the record. */
abstract Builder setTopic(String topic);

/** Sets the payload of the record. */
abstract Builder setPayload(byte[] payload);

/** Builds the record from the values set so far. */
abstract MqttRecord build();
}
}
Loading

0 comments on commit 95bf983

Please sign in to comment.