Skip to content
This repository has been archived by the owner on Dec 14, 2022. It is now read-only.

Commit

Permalink
feat #275: add ThreadSafeDeserializationSchema solve thread safe (#365)
Browse files Browse the repository at this point in the history
(cherry picked from commit 68b5260)
  • Loading branch information
shibd authored and jianyun8023 committed Jun 30, 2021
1 parent d6e2737 commit a8f9665
Show file tree
Hide file tree
Showing 5 changed files with 153 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.streaming.util.serialization.FlinkSchema;
import org.apache.flink.streaming.util.serialization.PulsarDeserializationSchema;
import org.apache.flink.streaming.util.serialization.ThreadSafeDeserializationSchema;
import org.apache.flink.table.data.GenericRowData;
import org.apache.flink.table.data.RowData;
import org.apache.flink.types.DeserializationException;
Expand Down Expand Up @@ -78,8 +79,8 @@ public SimpleCollector initialValue() {
keyDeserialization != null && keyProjection.length > 0,
"Key must be set in upsert mode for deserialization schema.");
}
this.keyDeserialization = keyDeserialization;
this.valueDeserialization = valueDeserialization;
this.keyDeserialization = ThreadSafeDeserializationSchema.of(keyDeserialization);
this.valueDeserialization = ThreadSafeDeserializationSchema.of(valueDeserialization);
this.hasMetadata = hasMetadata;
this.outputCollector = new OutputProjectionCollector(
physicalArity,
Expand Down Expand Up @@ -116,22 +117,14 @@ public void deserialize(Message<RowData> message, Collector<RowData> collector)
// shortcut in case no output projection is required,
// also not for a cartesian product with the keys
if (keyDeserialization == null && !hasMetadata) {
// Because the Pulsar Source is designed to be multi-threaded,
// Flink's internal design of the Source is single-threaded,
// so, DeserializationSchema instances are oriented to single-threaded,
// and thread safety issues exist when they are accessed by multiple threads at the same time. Cause the message deserialization to fail.
synchronized (valueDeserialization) {
valueDeserialization.deserialize(message.getData(), collector);
}
valueDeserialization.deserialize(message.getData(), collector);
return;
}
BufferingCollector keyCollector = new BufferingCollector();

// buffer key(s)
if (keyDeserialization != null) {
synchronized (keyDeserialization) {
keyDeserialization.deserialize(message.getKeyBytes(), keyCollector);
}
keyDeserialization.deserialize(message.getKeyBytes(), keyCollector);
}

// project output while emitting values
Expand All @@ -142,9 +135,7 @@ public void deserialize(Message<RowData> message, Collector<RowData> collector)
// collect tombstone messages in upsert mode by hand
outputCollector.collect(null);
} else {
synchronized (valueDeserialization) {
valueDeserialization.deserialize(message.getData(), outputCollector);
}
valueDeserialization.deserialize(message.getData(), outputCollector);
}
keyCollector.buffer.clear();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ public FlinkSchema(SchemaInfo schemaInfo, SerializationSchema<T> serializer,
DeserializationSchema<T> deserializer) {
this.schemaInfo = schemaInfo;
this.serializer = serializer;
this.deserializer = deserializer;
this.deserializer = ThreadSafeDeserializationSchema.of(deserializer);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,12 @@ public class PulsarDeserializationSchemaWrapper<T> implements PulsarDeserializat

@Deprecated
public PulsarDeserializationSchemaWrapper(DeserializationSchema<T> deSerializationSchema, DataType dataType) {
this.deSerializationSchema = checkNotNull(deSerializationSchema);
this.deSerializationSchema = ThreadSafeDeserializationSchema.of(checkNotNull(deSerializationSchema));
}

@Deprecated
public PulsarDeserializationSchemaWrapper(DeserializationSchema<T> deSerializationSchema) {
this.deSerializationSchema = checkNotNull(deSerializationSchema);
this.deSerializationSchema = ThreadSafeDeserializationSchema.of(checkNotNull(deSerializationSchema));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.streaming.util.serialization;

import org.apache.flink.api.common.serialization.DeserializationSchema;
import org.apache.flink.api.common.typeinfo.TypeInformation;

import java.io.IOException;

/**
* Because the Pulsar Source is designed to be multi-threaded,
* Flink's internal design of the Source is single-threaded,
* so, DeserializationSchema instances are oriented to single-threaded,
* and thread safety issues exist when they are accessed by multiple threads at the same time. Cause the message deserialization to fail.
*/
public class ThreadSafeDeserializationSchema<T> implements DeserializationSchema<T> {

private DeserializationSchema<T> deserializationSchema;

private ThreadSafeDeserializationSchema(DeserializationSchema<T> deserializationSchema) {
this.deserializationSchema = deserializationSchema;
}

public static ThreadSafeDeserializationSchema of(DeserializationSchema deserializationSchema) {
return deserializationSchema != null ? new ThreadSafeDeserializationSchema(deserializationSchema) : null;
}

@Override
public synchronized void open(InitializationContext context) throws Exception {
deserializationSchema.open(context);
}

@Override
public synchronized T deserialize(byte[] bytes) throws IOException {
return deserializationSchema.deserialize(bytes);
}

@Override
public synchronized boolean isEndOfStream(T object) {
return deserializationSchema.isEndOfStream(object);
}

@Override
public synchronized TypeInformation getProducedType() {
return deserializationSchema.getProducedType();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.streaming.util.serialization;

import org.apache.flink.api.common.serialization.DeserializationSchema;
import org.apache.flink.api.common.typeinfo.TypeInformation;

import org.junit.Assert;
import org.junit.Test;

import java.io.IOException;

/**
 * Tests that {@link ThreadSafeDeserializationSchema} serializes concurrent access to a
 * wrapped schema that is not thread safe.
 */
public class ThreadSafeDeserializationSchemaTest {

    private static final int THREAD_COUNT = 10;
    private static final int CALLS_PER_THREAD = 100;

    /**
     * Calls the wrapper from several threads at once and verifies that no update of the
     * wrapped schema's unsynchronized counter is lost.
     */
    @Test
    public void deserialize() throws InterruptedException {
        NoThreadSafeDeserializationSchema noThreadSafeDeserializationSchema = new NoThreadSafeDeserializationSchema();
        DeserializationSchema<Object> deserializationSchema =
                ThreadSafeDeserializationSchema.of(noThreadSafeDeserializationSchema);
        Thread[] threads = new Thread[THREAD_COUNT];
        // One slot per worker; written only by that worker and read after join(),
        // so no extra synchronization is required.
        IOException[] failures = new IOException[THREAD_COUNT];
        for (int i = 0; i < THREAD_COUNT; i++) {
            final int worker = i;
            threads[i] = new Thread(() -> {
                try {
                    for (int j = 0; j < CALLS_PER_THREAD; j++) {
                        deserializationSchema.deserialize(null);
                    }
                } catch (IOException e) {
                    // Record instead of swallowing so the test reports the real cause.
                    failures[worker] = e;
                }
            });
            threads[i].start();
        }

        for (int i = 0; i < THREAD_COUNT; i++) {
            threads[i].join();
        }
        for (int i = 0; i < THREAD_COUNT; i++) {
            Assert.assertNull("worker " + i + " failed: " + failures[i], failures[i]);
        }
        // JUnit's contract is assertEquals(expected, actual).
        Assert.assertEquals(THREAD_COUNT * CALLS_PER_THREAD, noThreadSafeDeserializationSchema.getCount());
    }

    /**
     * A schema whose {@code deserialize} performs a non-atomic read-modify-write on a counter,
     * with a short sleep in between to widen the race window.
     */
    static class NoThreadSafeDeserializationSchema implements DeserializationSchema<Object> {

        private int count = 0;
        private int tmpCount = 0;

        public int getCount() {
            return count;
        }

        @Override
        public Object deserialize(byte[] bytes) throws IOException {
            tmpCount = count;
            try {
                Thread.sleep(1);
            } catch (InterruptedException e) {
                // Preserve the interrupt status for the calling thread.
                Thread.currentThread().interrupt();
            }
            tmpCount++;
            count = tmpCount;
            return null;
        }

        @Override
        public boolean isEndOfStream(Object o) {
            return false;
        }

        @Override
        public TypeInformation<Object> getProducedType() {
            return null;
        }
    }
}

0 comments on commit a8f9665

Please sign in to comment.