Skip to content

Commit

Permalink
Polish HissObjectEncryptor
Browse files Browse the repository at this point in the history
  • Loading branch information
mkay1375 committed Aug 15, 2024
1 parent 31443ef commit fb64b14
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 103 deletions.
161 changes: 68 additions & 93 deletions src/main/java/io/github/tap30/hiss/HissObjectEncryptor.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import org.intellij.lang.annotations.Language;

import java.lang.reflect.Field;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.*;
import java.util.function.BiConsumer;
Expand All @@ -15,12 +16,11 @@
class HissObjectEncryptor {

private static final Logger logger = Logger.getLogger(HissObjectEncryptor.class.getName());
private final static Map<Class<?>, ClassDescription> CLASSES_DESCRIPTION_CACHE = new HashMap<>();
private static final Map<Class<?>, ClassDescription> CLASSES_DESCRIPTION_CACHE = new HashMap<>();

private final HissEncryptor hissEncryptor;
private final HissHasher hissHasher;


public HissObjectEncryptor(HissEncryptor hissEncryptor,
HissHasher hissHasher) {
this.hissEncryptor = Objects.requireNonNull(hissEncryptor);
Expand All @@ -35,31 +35,31 @@ public void decryptObject(Object domainObject) {
this.decryptFields(domainObject);
}

void encryptFields(Object object) {
private void encryptFields(Object object) {
this.processFields(object, this::encryptField);
}

void decryptFields(Object object) {
private void decryptFields(Object object) {
this.processFields(object, this::decryptField);
}

private void processFields(Object object,
BiConsumer<FieldAnnotatedWithEncrypted, Object> processor) {
BiConsumer<Object, FieldAnnotatedWithEncrypted> processor) {
if (object == null) return;

var classDescription = getClassDescription(object.getClass());
for (var field : classDescription.getFieldsAnnotatedWithEncrypted()) {
processor.accept(field, object);
processor.accept(object, field);
}
for (var field : classDescription.getFieldsAnnotatedWithEncryptedInside()) {
this.processFieldsAnnotatedWithEncryptedInside(field, object, processor);
this.processFieldsAnnotatedWithEncryptedInside(object, field, processor);
}
}

private void processFieldsAnnotatedWithEncryptedInside(FieldAnnotatedWithEncryptedInside field,
Object object,
BiConsumer<FieldAnnotatedWithEncrypted, Object> processor) {
var fieldContent = ReflectionUtils.invokeGetter(field.getGetter(), object, Object.class);
private void processFieldsAnnotatedWithEncryptedInside(Object object,
FieldAnnotatedWithEncryptedInside fieldAnnotatedWithEncryptedInside,
BiConsumer<Object, FieldAnnotatedWithEncrypted> processor) {
var fieldContent = fieldAnnotatedWithEncryptedInside.getField().getContent(object);
if (fieldContent instanceof Iterable<?>) {
((Iterable<?>) fieldContent).forEach(item -> this.processFields(item, processor));
} else if (fieldContent instanceof Map<?, ?>) {
Expand All @@ -69,30 +69,30 @@ private void processFieldsAnnotatedWithEncryptedInside(FieldAnnotatedWithEncrypt
}
}

private void encryptField(FieldAnnotatedWithEncrypted fieldAnnotatedWithEncrypted, Object object) {
private void encryptField(Object object, FieldAnnotatedWithEncrypted fieldAnnotatedWithEncrypted) {
try {
var content = getContent(fieldAnnotatedWithEncrypted, object);
var content = fieldAnnotatedWithEncrypted.getContentField().getContent(object);
if (content == null) {
return;
}
@Language("regexp")
var pattern = fieldAnnotatedWithEncrypted.getEncryptedAnnotation().pattern();
var encryptedContent = this.hissEncryptor.encrypt(content, pattern);
fieldAnnotatedWithEncrypted.getContentField().getSetter().invoke(object, encryptedContent);
fieldAnnotatedWithEncrypted.getContentField().setContent(object, encryptedContent);
if (fieldAnnotatedWithEncrypted.getEncryptedAnnotation().hashingEnabled()) {
var hashedContent = this.hissHasher.hash(content, pattern);
fieldAnnotatedWithEncrypted.getHashField().getSetter().invoke(object, hashedContent);
fieldAnnotatedWithEncrypted.getHashField().setContent(object, hashedContent);
}
} catch (Exception e) {
throw new RuntimeException(e);
}
}

private void decryptField(FieldAnnotatedWithEncrypted fieldAnnotatedWithEncrypted, Object object) {
private void decryptField(Object object, FieldAnnotatedWithEncrypted fieldAnnotatedWithEncrypted) {
try {
var content = getContent(fieldAnnotatedWithEncrypted, object);
var content = fieldAnnotatedWithEncrypted.getContentField().getContent(object);
var decryptedContent = this.hissEncryptor.decrypt(content);
fieldAnnotatedWithEncrypted.getContentField().getSetter().invoke(object, decryptedContent);
fieldAnnotatedWithEncrypted.getContentField().setContent(object, decryptedContent);
} catch (Exception e) {
throw new RuntimeException(e);
}
Expand All @@ -106,23 +106,23 @@ private static ClassDescription getClassDescription(Class<?> clazz) {
var fieldsAnnotatedWithEncrypted = new ArrayList<FieldAnnotatedWithEncrypted>();
var fieldsAnnotatedWithEncryptedInside = new ArrayList<FieldAnnotatedWithEncryptedInside>();
for (var field : ReflectionUtils.getAllFields(clazz)) {
getFieldAnnotatedWithEncrypted(field, clazz).ifPresent(fieldsAnnotatedWithEncrypted::add);
getFieldAnnotatedWithEncryptedInside(field, clazz).ifPresent(fieldsAnnotatedWithEncryptedInside::add);
getFieldAnnotatedWithEncrypted(clazz, field).ifPresent(fieldsAnnotatedWithEncrypted::add);
getFieldAnnotatedWithEncryptedInside(clazz, field).ifPresent(fieldsAnnotatedWithEncryptedInside::add);
}

var classDescription = new ClassDescription(fieldsAnnotatedWithEncrypted, fieldsAnnotatedWithEncryptedInside);
CLASSES_DESCRIPTION_CACHE.put(clazz, classDescription);
if (CLASSES_DESCRIPTION_CACHE.size() > 1000) {
if (CLASSES_DESCRIPTION_CACHE.size() > 10000) {
logger.log(Level.WARNING, "{0} classes are cached", CLASSES_DESCRIPTION_CACHE.size());
}
return classDescription;
}

private static Optional<FieldAnnotatedWithEncrypted>
getFieldAnnotatedWithEncrypted(Field field, Class<?> clazz) {
getFieldAnnotatedWithEncrypted(Class<?> clazz, Field field) {
var encryptedAnnotation = field.getDeclaredAnnotation(Encrypted.class);
if (encryptedAnnotation != null) {
var contentField = getDataField(field, clazz);
var contentField = new StringField(clazz, field.getName());
var hashField = getHashField(clazz, field, encryptedAnnotation);
return Optional.of(new FieldAnnotatedWithEncrypted(encryptedAnnotation, contentField, hashField));
} else {
Expand All @@ -131,79 +131,24 @@ private static ClassDescription getClassDescription(Class<?> clazz) {
}

private static Optional<FieldAnnotatedWithEncryptedInside>
getFieldAnnotatedWithEncryptedInside(Field field, Class<?> clazz) {
getFieldAnnotatedWithEncryptedInside(Class<?> clazz, Field field) {
var encryptedInsideAnnotation = field.getDeclaredAnnotation(EncryptedInside.class);
if (encryptedInsideAnnotation != null) {
return Optional.of(new FieldAnnotatedWithEncryptedInside(field, getGetter(field, clazz)));
return Optional.of(new FieldAnnotatedWithEncryptedInside(new ReadOnlyObjectField(clazz, field.getName())));
} else {
return Optional.empty();
}
}

private static DataField getHashField(Class<?> clazz, Field field, Encrypted encryptedAnnotation) {
DataField hashField = null;
private static StringField getHashField(Class<?> clazz, Field field, Encrypted encryptedAnnotation) {
if (encryptedAnnotation.hashingEnabled()) {
if (StringUtils.hasText(encryptedAnnotation.hashFieldName())) {
hashField = getDataField(encryptedAnnotation.hashFieldName(), clazz);
return new StringField(clazz, encryptedAnnotation.hashFieldName());
} else {
var fieldName = "hashed" + capitalizeFirstLetter(field.getName());
hashField = getDataField(fieldName, clazz);
}
}
return hashField;
}

private static DataField getDataField(String fieldName, Class<?> clazz) {
return getDataField(getField(fieldName, clazz), clazz);
}

private static Field getField(String fieldName, Class<?> clazz) {
try {
return clazz.getDeclaredField(fieldName);
} catch (NoSuchFieldException e) {
if (clazz.getSuperclass() != null) {
return getField(fieldName, clazz.getSuperclass());
}
throw new RuntimeException(e);
}
}

private static String getContent(FieldAnnotatedWithEncrypted fieldAnnotatedWithEncrypted, Object object) {
return ReflectionUtils.invokeGetter(fieldAnnotatedWithEncrypted.getContentField().getGetter(), object, String.class);
}

private static DataField getDataField(Field field, Class<?> clazz) {
try {
var firstLetterCapitalizedFieldName = capitalizeFirstLetter(field.getName());
var getter = clazz.getDeclaredMethod("get" + firstLetterCapitalizedFieldName);
var setter = clazz.getDeclaredMethod("set" + firstLetterCapitalizedFieldName, String.class);
return new DataField(field, getter, setter);
} catch (NoSuchMethodException e) {
if (clazz.getSuperclass() != null) {
return getDataField(field, clazz.getSuperclass());
}
throw new RuntimeException(e);
} catch (Exception e) {
throw new RuntimeException(e);
}
}

private static Method getGetter(Field field, Class<?> clazz) {
try {
var firstLetterCapitalizedFieldName = capitalizeFirstLetter(field.getName());
return clazz.getDeclaredMethod("get" + firstLetterCapitalizedFieldName);
} catch (NoSuchMethodException e) {
if (clazz.getSuperclass() != null) {
return getGetter(field, clazz.getSuperclass());
return new StringField(clazz, "hashed" + StringUtils.capitalizeFirstLetter(field.getName()));
}
throw new RuntimeException(e);
} catch (Exception e) {
throw new RuntimeException(e);
}
}

private static String capitalizeFirstLetter(String text) {
return text.substring(0, 1).toUpperCase() + text.substring(1);
return null;
}

@Value
Expand All @@ -214,22 +159,52 @@ private static class ClassDescription {

@Value
private static class FieldAnnotatedWithEncryptedInside {
Field field;
Method getter;
ReadOnlyObjectField field;
}

@Value
private static class FieldAnnotatedWithEncrypted {
Encrypted encryptedAnnotation;
DataField contentField;
DataField hashField;
StringField contentField;
StringField hashField;
}

@Value
private static class DataField {
Field field;
Method getter;
Method setter;
private static class StringField {
private final Method getter;
private final Method setter;

public StringField(Class<?> clazz, String fieldName) {
this.getter = ReflectionUtils.getMethod(clazz,
"get" + StringUtils.capitalizeFirstLetter(fieldName));
this.setter = ReflectionUtils.getMethod(clazz,
"set" + StringUtils.capitalizeFirstLetter(fieldName), String.class);
}

public String getContent(Object object) {
return ReflectionUtils.invokeSupplier(object, getter, String.class);
}

public void setContent(Object object, String content) {
try {
setter.invoke(object, content);
} catch (IllegalAccessException | InvocationTargetException e) {
throw new RuntimeException(e);
}
}
}

private static class ReadOnlyObjectField {
private final Method getter;

public ReadOnlyObjectField(Class<?> clazz, String fieldName) {
this.getter = ReflectionUtils.getMethod(clazz,
"get" + StringUtils.capitalizeFirstLetter(fieldName));
}

public Object getContent(Object object) {
return ReflectionUtils.invokeSupplier(object, getter, Object.class);
}

}

}
23 changes: 16 additions & 7 deletions src/main/java/io/github/tap30/hiss/utils/ReflectionUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,10 @@

public class ReflectionUtils {


public static <T> T invokeGetter(Method getter, Object object, Class<T> targetType) {
public static <T> T invokeSupplier(Object object, Method supplier, Class<T> targetType) {
Object content;
try {
content = getter.invoke(object);
content = supplier.invoke(object);
} catch (IllegalAccessException | InvocationTargetException e) {
throw new RuntimeException(e);
}
Expand All @@ -27,19 +26,29 @@ public static <T> T invokeGetter(Method getter, Object object, Class<T> targetTy
return castedContent;
} else {
throw new ClassCastException(String.format(
"Cast error for content of method %s: wanted %s but got %s", getter.getName(), targetType.getName(), content.getClass().getName()
"Cast error for content of method %s: wanted %s but got %s",
supplier.getName(), targetType.getName(), content.getClass().getName()
));
}
}

public static List<Field> getAllFields(Class<?> clazz) {
var objectFields = new ArrayList<Field>();
Class<?> objectClass = clazz;
while (objectClass != null) {
for (var objectClass = clazz; objectClass != null; objectClass = objectClass.getSuperclass()) {
objectFields.addAll(Arrays.asList(objectClass.getDeclaredFields()));
objectClass = objectClass.getSuperclass();
}
return Collections.unmodifiableList(objectFields);
}

public static Method getMethod(Class<?> clazz, String methodName, Class<?>... parameterTypes) {
try {
return clazz.getDeclaredMethod(methodName, parameterTypes);
} catch (NoSuchMethodException e) {
if (clazz.getSuperclass() != null) {
return getMethod(clazz.getSuperclass(), methodName, parameterTypes);
}
throw new RuntimeException(e);
}
}

}
4 changes: 4 additions & 0 deletions src/main/java/io/github/tap30/hiss/utils/StringUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,8 @@ public static String toLowerCase(String text) {
return text.toLowerCase();
}

public static String capitalizeFirstLetter(String text) {
return text.substring(0, 1).toUpperCase() + text.substring(1);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@

import java.lang.reflect.InvocationTargetException;

import static org.junit.jupiter.api.Assertions.*;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;

class ReflectionUtilsTest {

Expand All @@ -15,7 +16,7 @@ void testInvokeGetter() throws NoSuchMethodException, InvocationTargetException,
var nameGetterMethod = aClassInstance.getClass().getDeclaredMethod("getName");

// When
var content = ReflectionUtils.invokeGetter(nameGetterMethod, aClassInstance, String.class);
var content = ReflectionUtils.invokeSupplier(aClassInstance, nameGetterMethod, String.class);

// Then
assertEquals("Mamad", content);
Expand All @@ -28,7 +29,7 @@ void testInvokeGetter_whenTypeNotMatches() throws NoSuchMethodException, Invocat
var nameGetterMethod = aClassInstance.getClass().getDeclaredMethod("getName");

// When & Then
assertThrows(ClassCastException.class, () -> ReflectionUtils.invokeGetter(nameGetterMethod, aClassInstance, Integer.class));
assertThrows(ClassCastException.class, () -> ReflectionUtils.invokeSupplier(aClassInstance, nameGetterMethod, Integer.class));
}

public static class AClass {
Expand Down

0 comments on commit fb64b14

Please sign in to comment.