Skip to content

Commit

Permalink
feat: SSE 拦截器
Browse files Browse the repository at this point in the history
  • Loading branch information
mySingleLive committed Dec 20, 2024
1 parent 487b8f8 commit e923e16
Show file tree
Hide file tree
Showing 10 changed files with 236 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@
import java.util.function.BiConsumer;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collectors;

import static com.dtflys.forest.mapping.MappingParameter.TARGET_BODY;
Expand Down Expand Up @@ -4678,6 +4679,16 @@ public <R> R getAttachment(String name, Class<R> clazz) {
}
return clazz.cast(result);
}


public <R> R getOrAddAttachment(String name, Supplier<R> supplier) {
Object obj = getAttachment(name);
if (obj == null) {
obj = supplier.get();
addAttachment(name, obj);
}
return (R) obj;
}

/**
* 获取序列化器
Expand Down
91 changes: 65 additions & 26 deletions forest-core/src/main/java/com/dtflys/forest/http/ForestSSE.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import com.dtflys.forest.annotation.SSEMessage;
import com.dtflys.forest.annotation.SSERetryMessage;
import com.dtflys.forest.exceptions.ForestRuntimeException;
import com.dtflys.forest.interceptor.Interceptor;
import com.dtflys.forest.interceptor.SSEInterceptor;
import com.dtflys.forest.reflection.MethodLifeCycleHandler;
import com.dtflys.forest.sse.EventSource;
import com.dtflys.forest.sse.ForestSSEListener;
Expand Down Expand Up @@ -81,22 +83,38 @@ void init(final ForestRequest request) {
this.request.setLifeCycleHandler(new MethodLifeCycleHandler<InputStream>(InputStream.class, InputStream.class) {});
final Class<?> clazz = this.getClass();
final Method[] methods = ReflectUtils.getMethods(clazz);

// 批量注册 SSE 控制器类中的消息处理方法
for (final Method method : methods) {
final Annotation[] annArray = method.getAnnotations();
for (final Annotation ann : annArray) {
if (ann instanceof SSEMessage) {
registerMessageMethod(method, ann, null);
} else if (ann instanceof SSEDataMessage) {
registerMessageMethod(method, ann, "data");
} else if (ann instanceof SSEEventMessage) {
registerMessageMethod(method, ann, "event");
} else if (ann instanceof SSEIdMessage) {
registerMessageMethod(method, ann, "id");
} else if (ann instanceof SSERetryMessage) {
registerMessageMethod(method, ann, "retry");
}
final List<Interceptor> interceptors = request.getInterceptorChain().getInterceptors();
for (final Interceptor interceptor : interceptors) {
if (interceptor instanceof SSEInterceptor) {
final Class<?> interceptorClass = interceptor.getClass();
final Method[] interceptorMethods = ReflectUtils.getMethods(interceptorClass);
registerMethodArray(interceptor, interceptorMethods);
}
}
registerMethodArray(this, methods);
}
}

/**
* 批量注册 SSE 控制器类中的消息处理方法
*
* @param instance 方法所属实例
* @param methods Java 方法数组
*/
private void registerMethodArray(Object instance, final Method[] methods) {
for (final Method method : methods) {
final Annotation[] annArray = method.getAnnotations();
for (final Annotation ann : annArray) {
if (ann instanceof SSEMessage) {
registerMessageMethod(instance, method, ann, null);
} else if (ann instanceof SSEDataMessage) {
registerMessageMethod(instance, method, ann, "data");
} else if (ann instanceof SSEEventMessage) {
registerMessageMethod(instance, method, ann, "event");
} else if (ann instanceof SSEIdMessage) {
registerMessageMethod(instance, method, ann, "id");
} else if (ann instanceof SSERetryMessage) {
registerMessageMethod(instance, method, ann, "retry");
}
}
}
Expand All @@ -105,18 +123,19 @@ void init(final ForestRequest request) {
/**
* 注册 SSE 消息处理方法
*
* @param instance 方法所属实例
* @param method Java 方法对象
* @param ann 注解对象
* @param defaultName SSE 消息默认名称
* @since 1.6.0
*/
private void registerMessageMethod(Method method, Annotation ann, String defaultName) {
private void registerMessageMethod(Object instance, Method method, Annotation ann, String defaultName) {
final Map<String, Object> attrs = ReflectUtils.getAttributesFromAnnotation(ann);
final String valueRegex = String.valueOf(attrs.getOrDefault("valueRegex", ""));
final String valuePrefix = String.valueOf(attrs.getOrDefault("valuePrefix", ""));
final String valuePostfix = String.valueOf(attrs.getOrDefault("valuePostfix", ""));
final String annName = defaultName != null ? defaultName : String.valueOf(attrs.getOrDefault("name", ""));
final SSEMessageMethod sseMessageMethod = new SSEMessageMethod(this, method);
final SSEMessageMethod sseMessageMethod = new SSEMessageMethod(instance, method);
if (StringUtils.isEmpty(valueRegex) && StringUtils.isEmpty(valuePrefix) && StringUtils.isEmpty(valuePostfix)) {
addConsumer(annName, (eventSource, name, value) -> sseMessageMethod.invoke(eventSource));
} else {
Expand Down Expand Up @@ -501,6 +520,19 @@ public ForestSSE addOnRetryMatchesPrefix(String valuePrefix, SSEStringMessageCon
public ForestSSE addOnRetryMatchesPostfix(String valuePostfix, SSEStringMessageConsumer consumer) {
return addConsumerMatchesPostfix("retry", valuePostfix, consumer);
}

private void doOnOpen(final EventSource eventSource) {
final List<Interceptor> interceptors = eventSource.getRequest().getInterceptorChain().getInterceptors();
for (Interceptor interceptor : interceptors) {
if (interceptor instanceof SSEInterceptor) {
((SSEInterceptor) interceptor).onSSEOpen(eventSource);
}
}
onOpen(eventSource);
if (onOpenConsumer != null) {
onOpenConsumer.accept(eventSource);
}
}

/**
* 监听打开回调函数:在开始 SSE 数据流监听的时候调用
Expand All @@ -509,9 +541,19 @@ public ForestSSE addOnRetryMatchesPostfix(String valuePostfix, SSEStringMessageC
* @since 1.6.0
*/
protected void onOpen(EventSource eventSource) {
if (onOpenConsumer != null) {
onOpenConsumer.accept(eventSource);
}

private void doOnClose(final ForestRequest request, final ForestResponse response) {
final List<Interceptor> interceptors = request.getInterceptorChain().getInterceptors();
for (Interceptor interceptor : interceptors) {
if (interceptor instanceof SSEInterceptor) {
((SSEInterceptor) interceptor).onSSEClose(request, response);
}
}
if (onCloseConsumer != null) {
onCloseConsumer.accept(request, response);
}
onClose(request, response);
}

/**
Expand All @@ -522,9 +564,6 @@ protected void onOpen(EventSource eventSource) {
* @since 1.6.0
*/
protected void onClose(ForestRequest request, ForestResponse response) {
if (onCloseConsumer != null) {
onCloseConsumer.accept(request, response);
}
}

/**
Expand Down Expand Up @@ -600,7 +639,7 @@ public <R extends ForestSSE> R listen() {
throw new ForestRuntimeException(e);
}
} else {
response = this.request.execute(new TypeReference<ForestResponse<InputStream>>() {});
response = this.request.execute(new TypeReference<ForestResponse<InputStream>>() {});
}
if (response == null) {
return (R) this;
Expand All @@ -609,7 +648,7 @@ public <R extends ForestSSE> R listen() {
return (R) this;
}
final EventSource openEventSource = new EventSource("open", request, response);
this.onOpen(openEventSource);
this.doOnOpen(openEventSource);
if (SSEMessageResult.CLOSE.equals(openEventSource.getMessageResult())) {
onClose(request, response);
return (R) this;
Expand All @@ -634,7 +673,7 @@ public <R extends ForestSSE> R listen() {
} catch (IOException e) {
throw new ForestRuntimeException(e);
} finally {
onClose(request, response);
doOnClose(request, response);
}
});
} catch (Exception e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
import com.dtflys.forest.reflection.ForestMethod;
import com.dtflys.forest.utils.ForestProgress;

import java.util.function.Supplier;

/**
* Forest拦截器接口
* <p>拦截器在请求的初始化、发送请求前、发送成功、发送失败等生命周期中都会被调用
Expand Down Expand Up @@ -238,15 +240,15 @@ default Object getAttribute(ForestRequest request, String name) {
* @param request Forest请求对象
* @param name 属性名称
* @param clazz 属性值的类型对象
* @param <T> 属性值类型的泛型
* @param <R> 属性值类型的泛型
* @return Attribute 属性值
*/
default <T> T getAttribute(ForestRequest request, String name, Class<T> clazz) {
default <R> R getAttribute(ForestRequest request, String name, Class<R> clazz) {
Object obj = request.getInterceptorAttribute(this.getClass(), name);
if (obj == null) {
return null;
}
return (T) obj;
return clazz.cast(obj);
}

/**
Expand Down Expand Up @@ -309,4 +311,24 @@ default Double getAttributeAsDouble(ForestRequest request, String name) {
return (Double) attr;
}

/**
* 获取或添加请求在本拦截器中的 Attribute 属性
* <p>当 Attribute 属性中不存在属性名称所对应的值,则添加属性值</p>
*
* @param request Forest请求对象
* @param name 属性名称
* @param supplier 属性值回调函数
* @return 属性值
* @param <R> 属性值类型
* @since 1.6.1
*/
default <R> R getOrAddAttribute(ForestRequest request, String name, Supplier<R> supplier) {
Object obj = getAttribute(request, name);
if (obj == null && supplier != null) {
obj = supplier.get();
addAttribute(request, name, obj);
}
return (R) obj;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -138,4 +138,8 @@ public void afterExecute(ForestRequest request, ForestResponse response) {
item.afterExecute(request, response);
}
}

public LinkedList<Interceptor> getInterceptors() {
return interceptors;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package com.dtflys.forest.interceptor;

import com.dtflys.forest.http.ForestRequest;
import com.dtflys.forest.http.ForestResponse;
import com.dtflys.forest.sse.EventSource;

/**
* Forest SSE 拦截器
*
* @since 1.6.1
*/
public interface SSEInterceptor extends Interceptor {

/**
* 监听打开回调函数:在开始 SSE 数据流监听的时候调用
*
* @param eventSource SSE 事件来源
* @since 1.6.1
*/
default void onSSEOpen(EventSource eventSource) {
}

/**
* 监听关闭回调函数:在结束 SSE 数据流监听的时候调用
*
* @param request Forest 请求对象
* @param response Forest 响应对象
* @since 1.6.1
*/
default void onSSEClose(ForestRequest request, ForestResponse response) {
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import com.dtflys.forest.http.ForestResponse;

/**
* Forest SSE Event Source
* Forest SSE 事件来源
*
* @since 1.6.0
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,32 @@
import java.lang.reflect.Type;
import java.util.function.Function;

/**
* SSE 消息方法
* <p>用于包装注册好的 SSE 消息处理方法</p>
*
* @since 1.6.1
*/
public class SSEMessageMethod {

private final ForestSSE sse;
/**
* 方法所属实例
*/
private final Object instance;

/**
* Java 方法
*/
private final Method method;


/**
* 方法参数值获取函数表
*/
private Function<EventSource, ?>[] argumentFunctions;


public SSEMessageMethod(ForestSSE sse, Method method) {
this.sse = sse;
public SSEMessageMethod(Object instance, Method method) {
this.instance = instance;
this.method = method;
init();
}
Expand Down Expand Up @@ -77,7 +92,7 @@ public void invoke(final EventSource eventSource) {
final boolean accessible = method.isAccessible();
method.setAccessible(true);
try {
method.invoke(sse, args);
method.invoke(instance, args);
} catch (InvocationTargetException | IllegalAccessException e) {
throw new ForestRuntimeException(e);
} finally {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import com.dtflys.forest.test.sse.MySSEInterceptor;

@Address(host = "localhost", port = "{port}")
@BaseRequest(interceptor = MySSEInterceptor.class)
public interface SSEClient {

@Get("/sse")
Expand All @@ -17,4 +16,7 @@ public interface SSEClient {
@Get("/sse")
MySSEHandler testSSE_withCustomClass();

@Get(url = "/sse", interceptor = MySSEInterceptor.class)
ForestSSE testSSE_withInterceptor();

}
Original file line number Diff line number Diff line change
Expand Up @@ -92,4 +92,30 @@ public void testSSE_withCustomClass() {
);
}

@Test
public void testSSE_withInterceptor() {
server.enqueue(new MockResponse().setResponseCode(200).setBody(
"data:start\n" +
"data:{\"event\": \"message\", \"conversation_id\": \"aee49897-5214308b6b2d\", \"message_id\": \"9e292a7d\", \"created_at\": 1734689225 \"answer\": \"I\", \"from_variable_selector\": null}\n" +
"event:{\"name\":\"Peter\",\"age\": \"18\",\"phone\":\"12345678\"}\n" +
"event:close\n" +
"data:dont show"
));

ForestSSE sse = sseClient.testSSE_withInterceptor().listen();

System.out.println(sse.getRequest().getAttachment("text"));
assertThat(sse.getRequest().getAttachment("text").toString()).isEqualTo(
"MySSEInterceptor onSuccess\n" +
"MySSEInterceptor afterExecute\n" +
"MySSEInterceptor onSSEOpen\n" +
"Receive data: start\n" +
"Receive data: {\"event\": \"message\", \"conversation_id\": \"aee49897-5214308b6b2d\", \"message_id\": \"9e292a7d\", \"created_at\": 1734689225 \"answer\": \"I\", \"from_variable_selector\": null}\n" +
"name: Peter; age: 18; phone: 12345678\n" +
"receive close --- close\n" +
"MySSEInterceptor onSSEClose"
);
}


}
Loading

0 comments on commit e923e16

Please sign in to comment.