diff --git a/river/src/main/java/org/jboss/marshalling/river/RiverMarshaller.java b/river/src/main/java/org/jboss/marshalling/river/RiverMarshaller.java index 5e75d015..ee55b3ad 100644 --- a/river/src/main/java/org/jboss/marshalling/river/RiverMarshaller.java +++ b/river/src/main/java/org/jboss/marshalling/river/RiverMarshaller.java @@ -74,6 +74,7 @@ public class RiverMarshaller extends AbstractMarshaller { private final IdentityIntMap instanceCache; private final IdentityIntMap> classCache; + private final IdentityIntMap> serialClassCache; private final IdentityHashMap, Externalizer> externalizers; private int instanceSeq; private int classSeq; @@ -92,6 +93,7 @@ protected RiverMarshaller(final RiverMarshallerFactory marshallerFactory, final final float loadFactor = 0x0.5p0f; instanceCache = new IdentityIntMap((int) ((double)configuration.getInstanceCount() / (double)loadFactor), loadFactor); classCache = new IdentityIntMap>((int) ((double)configuration.getClassCount() / (double)loadFactor), loadFactor); + serialClassCache = new IdentityIntMap>((int) ((double)configuration.getClassCount() / (double)loadFactor), loadFactor); externalizers = new IdentityHashMap, Externalizer>(configuration.getClassCount()); } @@ -261,7 +263,7 @@ protected void doWriteObject(final Object original, final boolean unshared) thro // user type #3: serializable if (serializabilityChecker.isSerializable(objClass)) { write(unshared ? ID_NEW_OBJECT_UNSHARED : ID_NEW_OBJECT); - writeSerializableClass(objClass); + writeSerializableClass(objClass, false); instanceCache.put(obj, instanceSeq++); doWriteSerializableObject(info, obj, objClass); if (unshared) { @@ -1200,7 +1202,7 @@ protected void doWriteEmptyFields(final SerializableClass info) throws IOExcepti } protected void writeProxyClass(final Class objClass) throws IOException { - if (! writeKnownClass(objClass)) { + if (! writeKnownClass(objClass, false)) { writeNewProxyClass(objClass); } } @@ -1230,7 +1232,7 @@ protected void writeNewProxyClass(final Class objClass) throws IOException { } protected void writeEnumClass(final Class objClass) throws IOException { - if (! writeKnownClass(objClass)) { + if (! writeKnownClass(objClass, false)) { writeNewEnumClass(objClass); } } @@ -1262,11 +1264,17 @@ protected void writeObjectArrayClass(final Class objClass) throws IOException } protected void writeClass(final Class objClass) throws IOException { - if (! writeKnownClass(objClass)) { + if (! writeKnownClass(objClass, false)) { writeNewClass(objClass); } } + protected void writeSerialSuperClass(final Class objClass) throws IOException { + if (! writeKnownClass(objClass, true)) { + writeNewSerialSuperClass(objClass); + } + } + private static final IdentityIntMap> BASIC_CLASSES_V2; private static final IdentityIntMap> BASIC_CLASSES_V3; private static final IdentityIntMap> BASIC_CLASSES_V4; @@ -1423,6 +1431,24 @@ protected void writeNewClass(final Class objClass) throws IOException { } } + protected void writeNewSerialSuperClass(final Class objClass) throws IOException { + if (! objClass.isInterface() && serializabilityChecker.isSerializable(objClass)) { + writeNewSerializableClass(objClass); + } else { + ClassTable.Writer classTableWriter = classTable.getClassWriter(objClass); + if (classTableWriter != null) { + write(ID_PREDEFINED_PLAIN_CLASS); + classCache.put(objClass, classSeq++); + writeClassTableData(objClass, classTableWriter); + } else { + write(ID_PLAIN_CLASS); + writeString(classResolver.getClassName(objClass)); + classResolver.annotateClass(this, objClass); + classCache.put(objClass, classSeq++); + } + } + } + private void writeClassTableData(final Class objClass, final ClassTable.Writer classTableWriter) throws IOException { if (configuredVersion == 1) { classTableWriter.writeClass(getBlockMarshaller(), objClass); @@ -1432,14 +1458,29 @@ private void writeClassTableData(final Class objClass, final ClassTable.Write } } - protected boolean writeKnownClass(final Class objClass) throws IOException { + protected boolean writeKnownClass(final Class objClass, final boolean isSuper) throws IOException { final int configuredVersion = this.configuredVersion; - int i = getBasicClasses(configuredVersion).get(objClass, -1); - if (i != -1) { - write(i); - return true; + int i; + if (isSuper) { + // serialized superclasses may only be of certain types + i = getBasicClasses(configuredVersion).get(objClass, -1); + if (i == ID_OBJECT_CLASS) { + write(i); + return true; + } + // otherwise, we see if it's a known serialized class, ignoring other classes + i = serialClassCache.get(objClass, -1); + } else { + i = getBasicClasses(configuredVersion).get(objClass, -1); + if (i != -1) { + write(i); + return true; + } + i = classCache.get(objClass, -1); + if (i == -1) { + i = serialClassCache.get(objClass, -1); + } } - i = classCache.get(objClass, -1); if (i != -1) { final int diff = i - classSeq; if (diff >= -256) { @@ -1457,8 +1498,8 @@ protected boolean writeKnownClass(final Class objClass) throws IOException { return false; } - protected void writeSerializableClass(final Class objClass) throws IOException { - if (! writeKnownClass(objClass)) { + protected void writeSerializableClass(final Class objClass, final boolean isSuper) throws IOException { + if (! writeKnownClass(objClass, isSuper)) { writeNewSerializableClass(objClass); } } @@ -1467,7 +1508,7 @@ protected void writeNewSerializableClass(final Class objClass) throws IOExcep ClassTable.Writer classTableWriter = classTable.getClassWriter(objClass); if (classTableWriter != null) { write(ID_PREDEFINED_SERIALIZABLE_CLASS); - classCache.put(objClass, classSeq++); + serialClassCache.put(objClass, classSeq++); writeClassTableData(objClass, classTableWriter); } else { final SerializableClass info = registry.lookup(objClass); @@ -1483,7 +1524,7 @@ protected void writeNewSerializableClass(final Class objClass) throws IOExcep writeString(className); } writeLong(info.getEffectiveSerialVersionUID()); - classCache.put(objClass, classSeq++); + serialClassCache.put(objClass, classSeq++); classResolver.annotateClass(this, objClass); final SerializableField[] fields = info.getFields(); final int cnt = fields.length; @@ -1508,11 +1549,11 @@ protected void writeNewSerializableClass(final Class objClass) throws IOExcep write(ID_OBJECT_CLASS); return; } - writeClass(sc); + writeSerialSuperClass(sc); } protected void writeExternalizableClass(final Class objClass) throws IOException { - if (! writeKnownClass(objClass)) { + if (! writeKnownClass(objClass, false)) { writeNewExternalizableClass(objClass); } } @@ -1533,7 +1574,7 @@ protected void writeNewExternalizableClass(final Class objClass) throws IOExc } protected void writeExternalizerClass(final Class objClass, final Externalizer externalizer) throws IOException { - if (! writeKnownClass(objClass)) { + if (! writeKnownClass(objClass, false)) { writeNewExternalizerClass(objClass, externalizer); } } @@ -1563,6 +1604,7 @@ public void clearInstanceCache() throws IOException { public void clearClassCache() throws IOException { classCache.clear(); + serialClassCache.clear(); externalizers.clear(); classSeq = 0; instanceCache.clear(); diff --git a/tests/src/test/java/org/jboss/test/marshalling/SimpleMarshallerTests.java b/tests/src/test/java/org/jboss/test/marshalling/SimpleMarshallerTests.java index 03ce576d..d71880f6 100644 --- a/tests/src/test/java/org/jboss/test/marshalling/SimpleMarshallerTests.java +++ b/tests/src/test/java/org/jboss/test/marshalling/SimpleMarshallerTests.java @@ -39,19 +39,20 @@ import java.lang.reflect.Method; import java.lang.reflect.Proxy; import java.util.Arrays; +import java.util.Collections; import java.util.Comparator; import java.util.Date; import java.util.HashMap; import java.util.Map; import java.util.Random; import java.util.Set; +import java.util.SortedMap; import java.util.TreeMap; import java.util.TreeSet; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; -import org.jboss.marshalling.AbstractClassResolver; import org.jboss.marshalling.AnnotationClassExternalizerFactory; import org.jboss.marshalling.ByteInput; import org.jboss.marshalling.ByteOutput; @@ -92,8 +93,8 @@ public SimpleMarshallerTests(TestMarshallerProvider testMarshallerProvider, Test * Simple constructor for running one test at a time from an IDE. */ public SimpleMarshallerTests() { - super(new MarshallerFactoryTestMarshallerProvider(new RiverMarshallerFactory(), 3), - new MarshallerFactoryTestUnmarshallerProvider(new RiverMarshallerFactory(), 3), + super(new MarshallerFactoryTestMarshallerProvider(new RiverMarshallerFactory(), 4), + new MarshallerFactoryTestUnmarshallerProvider(new RiverMarshallerFactory(), 4), getOneTestMarshallingConfiguration()); } @@ -3660,4 +3661,35 @@ public void runRead(Unmarshaller unmarshaller) throws Throwable { } }); } + + static class Wrapper implements Serializable { + private static final long serialVersionUID = 1L; + Map map = Collections.unmodifiableMap(new TreeMap<>()); + Wrapped to; + } + + static class Wrapped implements Serializable { + private static final long serialVersionUID = 3L; + Date mop; + SortedMap map = Collections.emptySortedMap(); + int mip; + } + + @Test(description = "JBMAR-233") + public void testWeirdSortingRelatedIssue() throws Throwable { + runReadWriteTest(new ReadWriteTest() { + public void runWrite(final Marshaller marshaller) throws Throwable { + Wrapper smi = new Wrapper(); + smi.to = new Wrapped(); + marshaller.writeObject(smi); + } + + public void runRead(final Unmarshaller unmarshaller) throws Throwable { + Wrapper smi = unmarshaller.readObject(Wrapper.class); + assertNotNull(smi); + assertNotNull(smi.to); + assertEquals(Collections.emptySortedMap(), smi.to.map); + } + }); + } }