Spring AI 结构化输出

1、简介

本文将带你了解如何格式化 Spring AI 的输出结构,使其更易于使用且更加直观。

2、聊天模型简介

ChatModel 接口是向 AI 模型发出提示的基本结构:

public interface ChatModel extends Model<Prompt, ChatResponse> {
    default String call(String message) {
        // 忽略实现。。。
    }

    @Override
    ChatResponse call(Prompt prompt);
}

call() 方法的作用是向模型发送消息并接收响应,仅此而已。

自然而然地,我们期望提示和响应是 String 类型。然而,现代模型的实现通常具有更复杂的结构,可以进行更精细的调整,提高模型的可预测性。例如,虽然可用的默认 call() 方法接受 String 参数,但更实用的做法是使用 PromptPrompt 可以包含多个消息或包括诸如 “温度” 之类的选项,以调节模型的表现力。

我们可以自动装配 ChatModel 并直接调用它。例如,如果我们的依赖中有用于 OpenAI API 的 spring-ai-openai-spring-boot-starter,那么就会自动注入 OpenAI 的实现 OpenAiChatModel

3、结构化输出 API

要获得数据结构化的输出,Spring AI 提供了使用结构化输出 API 封装 ChatModel 调用的工具。此 API 的核心接口是 StructuredOutputConverter(结构化输出转换器):

public interface StructuredOutputConverter<T> extends Converter<String, T>, FormatProvider {}

它结合了另外两个接口,第一个是 FormatProvider

public interface FormatProvider {
    String getFormat();
}

ChatModelcall() 调用之前,getFormat() 会准备好 Prompt,用所需的数据模式填充,并具体描述数据应如何格式化,以避免响应中的不一致。

例如,要获取 JSON 格式的响应,就会使用此 Prompt:

public String getFormat() {
    String template = "Your response should be in JSON format.\n"
      + "Do not include any explanations, only provide a RFC8259 compliant JSON response following this format without deviation.\n"
      + "Do not include markdown code blocks in your response.\n
      + "Remove the ```json markdown from the output.\nHere is the JSON Schema instance your output must adhere to:\n```%s```\n";
    return String.format(template, this.jsonSchema);
}

这些指令通常附加在用户输入之后。

第二个接口是 Converter(转换器):

@FunctionalInterface
public interface Converter<S, T> {
    @Nullable
    T convert(S source);
 
    // 默认方法
}

call() 返回响应后,Converter 会将其解析为所需的 T(泛型)类型数据结构。以下是 StructuredOutputConverter 工作原理的简单示意图:

StructuredOutputConverter 工作原理

4、可用的 Converter

在本节中,我们将通过示例来介绍 StructuredOutputConverter 的可用实现。

我们将通过为 “龙与地下城游戏”(Dungeons & Dragons)生成角色来进行演示:

public class Character {
    private String name;
    private int age;
    private String race;
    private String characterClass;
    private String cityOfOrigin;
    private String favoriteWeapon;
    private String bio;
    
    // 构造函数、Getter、Setter 方法省略
}

注意,由于底层使用了 Jackson 的 ObjectMapper,需要为 Bean 类提供空的构造函数

5、BeanOutputConverter

BeanOutputConverter 会从模型的响应中生成指定类的实例。它构建了一个提示(Prompt),指示模型生成符合 RFC8259 标准的 JSON。

来看看如何通过 ChatClient API 来使用它:

@Override
public Character generateCharacterChatClient(String race) {
    return ChatClient.create(chatModel).prompt()
      .user(spec -> spec.text("Generate a D&D character with race {race}")
        .param("race", race))
        .call()
        .entity(Character.class); // <-------- 实际上,是在这里调用 ChatModel.call(),而不是在前面一行。
}

在此方法中,ChatClient.create(chatModel) 会实例化一个 ChatClientprompt() 方法通过 Request(ChatClientRequest)启动 Builder Chain。在本例中,我们只添加了用户文本。创建请求后,调用 call() 方法,返回一个新的 CallResponseSpec,其中包含 ChatModelChatClientRequest。然后,entity() 方法会根据提供的类型创建一个 Converter,完成提示(Prompt)并调用 AI 模型。

你可能会注意到,我们没有直接使用 BeanOutputConverter。这是因为我们使用了一个类作为 entity() 方法的参数,这意味着 BeanOutputConverter 将处理 Prompt 和转换。

为了更加灵活的控制,我们可以编写一个低级别的方法。如下,直接使用之前自动装配的 ChatModel.call()方法:

@Override
public Character generateCharacterChatModel(String race) {
    BeanOutputConverter<Character> beanOutputConverter = new BeanOutputConverter<>(Character.class);

    String format = beanOutputConverter.getFormat();

    String template = """
                Generate a D&D character with race {race}
                {format}
                """;

    PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("race", race, "format", format));
    Prompt prompt = new Prompt(promptTemplate.createMessage());
    Generation generation = chatModel.call(prompt).getResult();

    return beanOutputConverter.convert(generation.getOutput().getContent());
}

如上,创建了 BeanOutputConverter,提取了模型的格式化指南,然后将这些指南添加到自定义提示(Prompt)中。使用 PromptTemplate 制作了最终的提示。PromptTemplate 是 Spring AI 的核心提示(Prompt)模板组件,它底层使用 StringTemplate 引擎。然后,调用模型,得到生成结果。Generation 表示模型的响应:提取其内容,然后使用 Converter 将其转换为 Java 对象。

下面是使用 Converter 从 OpenAI 得到的真实响应示例:

{
    name: "Thoren Ironbeard",
    age: 150,
    race: "Dwarf",
    characterClass: "Wizard",
    cityOfOrigin: "Sundabar",
    favoriteWeapon: "Magic Staff",
    bio: "Born and raised in the city of Sundabar, he is known for his skills in crafting and magic."
}

矮人巫师,真是难得一见!

6、MapOutputConverter 和 ListOutputConverter

通过 MapOutputConverterListOutputConverter,可以分别创建结构为 MapList 的响应。以下是使用 MapOutputConverter 的高级和低级代码示例:

@Override
public Map<String, Object> generateMapOfCharactersChatClient(int amount) {
    return ChatClient.create(chatModel).prompt()
      .user(u -> u.text("Generate {amount} D&D characters, where key is a character's name")
        .param("amount", String.valueOf(amount)))
        .call()
        .entity(new ParameterizedTypeReference<Map<String, Object>>() {});
}
    
@Override
public Map<String, Object> generateMapOfCharactersChatModel(int amount) {
    MapOutputConverter outputConverter = new MapOutputConverter();
    String format = outputConverter.getFormat();
    String template = """
            "Generate {amount} of key-value pairs, where key is a "Dungeons and Dragons" character name and value (String) is his bio.
            {format}
            """;
    Prompt prompt = new Prompt(new PromptTemplate(template, Map.of("amount", String.valueOf(amount), "format", format)).createMessage());
    Generation generation = chatModel.call(prompt).getResult();

    return outputConverter.convert(generation.getOutput().getContent());
}

之所以在 Map<String, Object> 中使用 Object,是因为 MapOutputConverter 目前还不支持泛型值。不过不用担心,稍后我们将创建自定义 Converter 来支持泛型值。

现在,来看看 ListOutputConverter 的示例,可以使用泛型:

@Override
public List<String> generateListOfCharacterNamesChatClient(int amount) {
    return ChatClient.create(chatModel).prompt()
      .user(u -> u.text("List {amount} D&D character names")
        .param("amount", String.valueOf(amount)))
        .call()
        .entity(new ListOutputConverter(new DefaultConversionService()));
}

@Override
public List<String> generateListOfCharacterNamesChatModel(int amount) {
    ListOutputConverter listOutputConverter = new ListOutputConverter(new DefaultConversionService());
    String format = listOutputConverter.getFormat();
    String userInputTemplate = """
            List {amount} D&D character names
            {format}
            """;
    PromptTemplate promptTemplate = new PromptTemplate(userInputTemplate,
      Map.of("amount", amount, "format", format));
    Prompt prompt = new Prompt(promptTemplate.createMessage());
    Generation generation = chatModel.call(prompt).getResult();
    return listOutputConverter.convert(generation.getOutput().getContent());
}

7、自定义 Converter

创建一个 Converter,将 AI 模型中的数据转换为 Map<String, V> 格式,其中 V 是泛型类型。与 Spring 提供的 Converter 一样,我们的容器将实现 StructuredOutputConverter<T>,这要求我们添加 convert()getFormat() 方法:

public class GenericMapOutputConverter<V> implements StructuredOutputConverter<Map<String, V>> {
    private final ObjectMapper objectMapper; // 转换响应
    private final String jsonSchema; // getFormat() 中指令的 Schema
    private final TypeReference<Map<String, V>> typeRef; // 对象 mapper 的 type reference 

    public GenericMapOutputConverter(Class<V> valueType) {
        this.objectMapper = this.getObjectMapper();
        this.typeRef = new TypeReference<>() {};
        this.jsonSchema = generateJsonSchemaForValueType(valueType);
    }

    public Map<String, V> convert(@NonNull String text) {
        try {
            text = trimMarkdown(text);
            return objectMapper.readValue(text, typeRef);
        } catch (JsonProcessingException e) {
            throw new RuntimeException("Failed to convert JSON to Map<String, V>", e);
        }
    }

    public String getFormat() {
        String raw = "Your response should be in JSON format.\nThe data structure for the JSON should match this Java class: %s\n" +
                "For the map values, here is the JSON Schema instance your output must adhere to:\n```%s```\n" +
                "Do not include any explanations, only provide a RFC8259 compliant JSON response following this format without deviation.\n";
        return String.format(raw, HashMap.class.getName(), this.jsonSchema);
    }

    private ObjectMapper getObjectMapper() {
        return JsonMapper.builder()
          .configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false)
          .build();
    }

    private String trimMarkdown(String text) {
        if (text.startsWith("```json") && text.endsWith("```")) {
            text = text.substring(7, text.length() - 3);
        }
        return text;
    }

    private String generateJsonSchemaForValueType(Class<V> valueType) {
        try {
            JacksonModule jacksonModule = new JacksonModule();
            SchemaGeneratorConfig config = new SchemaGeneratorConfigBuilder(SchemaVersion.DRAFT_2020_12, OptionPreset.PLAIN_JSON)
              .with(jacksonModule)
              .build();
            SchemaGenerator generator = new SchemaGenerator(config);

            JsonNode jsonNode = generator.generateSchema(valueType);
            ObjectWriter objectWriter = new ObjectMapper().writer(new DefaultPrettyPrinter()
              .withObjectIndenter(new DefaultIndenter().withLinefeed(System.lineSeparator())));

            return objectWriter.writeValueAsString(jsonNode);
        } catch (JsonProcessingException e) {
            throw new RuntimeException("Could not generate JSON schema for value type: " + valueType.getName(), e);
        }
    }
}

getFormat() 为 AI 模型提供了一个指令,它根据用户的提示向 AI 模型发出最终请求。这条指令指定了一个映射结构,并提供了我们自定义对象的 Schema 值。我们使用 com.github.victools.jsonschema 库生成了一个 Schema。Spring AI 内部将此库用于其 Converter,因此无需显式导入。

由于请求的是 JSON 格式的响应,因此在 convert() 中,我们使用 Jackson 的 ObjectMapper 进行解析。需要像 Spring 的 BeanOutputConverter 实现那样删除 Markdown 中的特殊标记符号(trimMarkdown)。AI 模型通常使用 markdown 来封装代码片段,删除它,以避免 ObjectMapper 出现异常。

之后,就可以像下面这样使用我们的实现:

@Override
public Map<String, Character> generateMapOfCharactersCustomConverter(int amount) {
    GenericMapOutputConverter<Character> outputConverter = new GenericMapOutputConverter<>(Character.class);
    String format = outputConverter.getFormat();
    String template = """
            "Generate {amount} of key-value pairs, where key is a "Dungeons and Dragons" character name and value is character object.
            {format}
            """;
    Prompt prompt = new Prompt(new PromptTemplate(template, Map.of("amount", String.valueOf(amount), "format", format)).createMessage());
    Generation generation = chatModel.call(prompt).getResult();

    return outputConverter.convert(generation.getOutput().getContent());
}

@Override
public Map<String, Character> generateMapOfCharactersCustomConverterChatClient(int amount) {
    return ChatClient.create(chatModel).prompt()
      .user(u -> u.text("Generate {amount} D&D characters, where key is a character's name")
        .param("amount", String.valueOf(amount)))
        .call()
        .entity(new GenericMapOutputConverter<>(Character.class));
}

8、总结

本文介绍了如何使用 Spring AI 中 StructuredOutputConverter 高效地把大型语言模型(LLM)的响应格式化为指定的输出结构。分别介绍了 BeanOutputConverterMapOutputConverterListOutputConverter 的用法,还提供了实际示例。此外,还介绍了如何创建自定义 Converter 来处理更复杂的数据类型。通过这些 Converter,我们可以更容易地格式化 AI 的输出,从而提高 LLM 响应的可靠性和可预测性。


Ref:https://www.baeldung.com/spring-artificial-intelligence-structure-output