在 WebFlux 中拦截请求并添加自定义请求头

1、概览

Filter(拦截器/过滤器)是 Spring 提供的一个机制,可以在 Controller 处理请求或向客户端返回响应之前拦截并处理请求。

本文将带你了解如何使用 WebFlux 拦截客户端请求以及如何添加自定义 Header。

2、Maven 依赖

添加 spring-boot-starter-webflux 响应式 Web 依赖:

<dependency>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-starter-webflux</artifactId>
    <version>3.1.5</version>
</dependency>

3、拦截请求

Spring WebFlux Filter 可分为 WebFilterHandlerFilterFunction 两类。可使用这些过滤器拦截 Web 请求,并添加新的自定义 Header 或修改现有 Header。

3.1、使用 WebFilter

WebFilter 以链式拦截方式处理 Web 请求。WebFilter 在全局范围内生效,一旦启用,将拦截所有的请求和响应。

首先,定义基于注解的 Controller:

@GetMapping(value= "/trace-annotated")
public Mono<String> trace(@RequestHeader(name = "traceId") final String traceId) {
    return Mono.just("TraceId: ".concat(traceId));
}

然后,拦截 Web 请求,使用 TraceWebFilter 实现添加一个新的 Header traceId

@Component
public class TraceWebFilter implements WebFilter {
    @Override
    public Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) {
        exchange.getRequest().mutate()
          .header("traceId", "ANNOTATED-TRACE-ID");
        return chain.filter(exchange);
    }
}

现在,可以使用 WebTestClient 向跟端点发起 GET 请求,并验证响应中是否包含我们添加的 traceId Header 值,即 TraceId: ANNOTATED-TRACE-ID

@Test
void whenCallAnnotatedTraceEndpoint_thenResponseContainsTraceId() {
    EntityExchangeResult<String> result = webTestClient.get()
      .uri("/trace-annotated")
      .exchange()
      .expectStatus()
      .isOk()
      .expectBody(String.class)
      .returnResult();

    String body = "TraceId: ANNOTATED-TRACE-ID";
    assertEquals(result.getResponseBody(), body);
}

这里需要注意的是,我们不能像修改响应 Header 那样直接修改请求 Header,因为请求头 Map 是只读的:

@Override
public Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) {
    if (exchange.getRequest().getPath().toString().equals("/trace-exceptional")) {
        exchange.getRequest().getHeaders().add("traceId", "TRACE-ID");
    }
    return chain.filter(exchange);
 }

该实现会抛出 UnsupportedOperationException 异常

使用 WebTestClient 来验证 Filter 是否会抛出异常:

@GetMapping(value = "/trace-exceptional")
public Mono<String> traceExceptional() {
    return Mono.just("Traced");
}
@Test
void whenCallTraceExceptionalEndpoint_thenThrowsException() {
    EntityExchangeResult<Map> result = webTestClient.get()
      .uri("/trace-exceptional")
      .exchange()
      .expectStatus()
      .is5xxServerError()
      .expectBody(Map.class)
      .returnResult();

    assertNotNull(result.getResponseBody());
}

3.2、使用 HandlerFilterFunction

在函数式风格中,Router 函数拦截请求并调用相应的处理函数。

我们可以启用零个或多个 HandlerFilterFunction,作为过滤 HandlerFunction 的函数。HandlerFilterFunction 仅适用于基于路由器的实现。

对于函数式端点,必须先创建一个 Handler:

@Component
public class TraceRouterHandler {
    public Mono<ServerResponse> handle(final ServerRequest serverRequest) {
        String traceId = serverRequest.headers().firstHeader("traceId");
      
        assert traceId != null;
        Mono<String> body = Mono.just("TraceId: ".concat(traceId));
        return ok().body(body, String.class);
    }
}

使用 Router Configuration 配置 Handler 后,我们拦截 Web 请求,并使用 TraceHandlerFilterFunction 实现添加新的 Header traceId

public RouterFunction<ServerResponse> routes(TraceRouterHandler routerHandler) {
    return RouterFunctions
      .route(GET("/trace-functional-filter"), routerHandler::handle)
      .filter(new TraceHandlerFilterFunction());
}
public class TraceHandlerFilterFunction implements HandlerFilterFunction<ServerResponse, ServerResponse> {
    @Override
    public Mono<ServerResponse> filter(ServerRequest request, HandlerFunction<ServerResponse> handlerFunction) {
        ServerRequest serverRequest = ServerRequest.from(request)
          .header("traceId", "FUNCTIONAL-TRACE-ID")
          .build();
        return handlerFunction.handle(serverRequest);
    }
}

现在,我们可以通过触发对 trace-functional-filter 端点的 GET 调用来验证响应中是否包含我们添加的 traceId Header 值,即 TraceId: FUNCTIONAL-TRACE-ID

@Test
void whenCallTraceFunctionalEndpoint_thenResponseContainsTraceId() {
    EntityExchangeResult<String> result = webTestClient.get()
      .uri("/trace-functional-filter")
      .exchange()
      .expectStatus()
      .isOk()
      .expectBody(String.class)
      .returnResult();

    String body = "TraceId: FUNCTIONAL-TRACE-ID";
    assertEquals(result.getResponseBody(), body);
}

3.3、使用自定义 Processor Function

处理器函数(Processor Function)类似于路由器函数,它拦截请求并调用相应的处理函数。

通过函数式路由 API,我们可以添加零个或多个自定义 Function 实例,这些实例会在 HandlerFunction 之前应用。

该过滤器函数会拦截 Builder 创建的 Web 请求,并添加一个新的 Header traceId

public RouterFunction<ServerResponse> routes(TraceRouterHandler routerHandler) {
    return route()
      .GET("/trace-functional-before", routerHandler::handle)
      .before(request -> ServerRequest.from(request)
        .header("traceId", "FUNCTIONAL-TRACE-ID")
        .build())
      .build());
}

trace-functional-before 端点发起 GET 请,验证响应中是否包含我们添加的 traceId Header 值,即 TraceId: FUNCTIONAL-TRACE-ID

@Test
void whenCallTraceFunctionalBeforeEndpoint_thenResponseContainsTraceId() {
    EntityExchangeResult<String> result = webTestClient.get()
      .uri("/trace-functional-before")
      .exchange()
      .expectStatus()
      .isOk()
      .expectBody(String.class)
      .returnResult();

    String body = "TraceId: FUNCTIONAL-TRACE-ID";
    assertEquals(result.getResponseBody(), body);
}

4、在客户端拦截请求

ExchangeFilterFunctions 可以和 Spring WebClient 一起使用,以拦截客户端的请求。

4.1、使用 ExchangeFilterFunction

ExchangeFilterFunction 是一个与 Spring WebClient 相关的术语。我们使用它来拦截 WebFlux WebClient 的客户端请求。ExchangeFilterFunction 用于在发送或接收请求之前或之后转换请求或响应。

定义 exchange filter function,拦截 WebClient 请求并添加新的 Header traceId

跟踪所有请求头,以验证 ExchangeFilterFunction

public ExchangeFilterFunction modifyRequestHeaders(MultiValueMap<String, String> changedMap) {
    return (request, next) -> {
        ClientRequest clientRequest = ClientRequest.from(request)
          .header("traceId", "TRACE-ID")
          .build();
        changedMap.addAll(clientRequest.headers());
        return next.exchange(clientRequest);
    };
}

定义了过滤器函数后,就可以将其添加到 WebClient 实例。这只能在创建 WebClient 时添加:

public WebClient webclient() {
    return WebClient.builder()
      .filter(modifyRequestHeaders(new LinkedMultiValueMap<>()))
      .build();
}

现在,可以使用 Wiremock 测试自定义 ExchangeFilterFunction

@RegisterExtension
static WireMockExtension extension = WireMockExtension.newInstance()
  .options(wireMockConfig().dynamicPort().dynamicHttpsPort())
  .build();
@Test
void whenCallEndpoint_thenRequestHeadersModified() {
    extension.stubFor(get("/test").willReturn(aResponse().withStatus(200)
      .withBody("SUCCESS")));

    MultiValueMap<String, String> map = new LinkedMultiValueMap<>();

    WebClient webClient = WebClient.builder()
      .filter(modifyRequestHeaders(map))
      .build();
    String receivedResponse = triggerGetRequest(webClient);

    String body = "SUCCESS";
    Assertions.assertEquals(receivedResponse, body);
    Assertions.assertEquals("TRACE-ID", map.getFirst("traceId"));
}

如上,使用 Wiremock 验证了 ExchangeFilterFunction,检查了 MultivalueMap 实例中是否有新的 Header traceId

5、总结

本文介绍了为 Web 请求和 WebClient 请求添加自定义 Header 的不同方法。

首先介绍了如何使用 WebFilterHandlerFilterFunction 在服务端为 Web 请求添加自定义 Header,最后介绍了如何使用 ExchangeFilterFunctionWebClient 客户端实现相同的功能。


Ref:https://www.baeldung.com/spring-webflux-intercept-request-add-headers