OpenAI SDK Development (1)
This installment sets up the basic framework. The project structure is shown in the figure below:

common
Constants
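The common package defines the Constants class. For now it contains a single enum, Role, which is used as a field of Message. Because Message appears in both the request and the response, Role lives in the common package so both sides can use it.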

```java
public class Constants {

    /**
     * Role of a chat message; getCode() returns the value sent to the API.
     */
    public enum Role {

        SYSTEM("system"),
        USER("user"),
        ASSISTANT("assistant"),
        ;

        private String code;

        Role(String code) {
            this.code = code;
        }

        public String getCode() {
            return code;
        }

    }

}
```
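When a response is parsed, the role string (for example "assistant") sometimes needs to be mapped back to the enum. The following is a minimal sketch of that reverse lookup. The fromCode helper is hypothetical and is not part of the original Constants class.

```java
import java.util.HashMap;
import java.util.Map;

public class RoleDemo {

    /** Same shape as Constants.Role, plus a hypothetical reverse lookup. */
    public enum Role {
        SYSTEM("system"),
        USER("user"),
        ASSISTANT("assistant");

        // Filled once, after all enum constants have been created.
        private static final Map<String, Role> BY_CODE = new HashMap<>();

        static {
            for (Role role : values()) {
                BY_CODE.put(role.code, role);
            }
        }

        private final String code;

        Role(String code) {
            this.code = code;
        }

        public String getCode() {
            return code;
        }

        /** Hypothetical helper: maps a wire value such as "assistant" back to the enum. */
        public static Role fromCode(String code) {
            Role role = BY_CODE.get(code);
            if (role == null) {
                throw new IllegalArgumentException("Unknown role: " + code);
            }
            return role;
        }
    }

    public static void main(String[] args) {
        System.out.println(Role.USER.getCode());        // user
        System.out.println(Role.fromCode("assistant")); // ASSISTANT
    }
}
```

A static map gives a constant-time lookup and fails loudly on values the enum does not know about.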
 
domain
chat (chat model)
ChatChoice
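```java
import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.Data;

import java.io.Serializable;

@Data
public class ChatChoice implements Serializable {

    private long index;

    @JsonProperty("message")
    private Message message;

    /** Why generation stopped, e.g. "stop" or "length". */
    @JsonProperty("finish_reason")
    private String finishReason;

}
```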
 
These fields describe one entry of the choices array, which is part of the response.
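Of the choice fields, finish_reason is the one callers most often need to check. In the OpenAI chat API, "stop" means the model ended its answer naturally, and "length" means the answer was cut off by max_tokens. The following is a minimal sketch that uses a stripped-down stand-in for ChatChoice and a hypothetical isTruncated helper. Neither is part of the SDK.

```java
public class FinishReasonDemo {

    /** Minimal stand-in for ChatChoice: only the fields this example reads. */
    static class Choice {
        final long index;
        final String finishReason;

        Choice(long index, String finishReason) {
            this.index = index;
            this.finishReason = finishReason;
        }
    }

    /** Hypothetical helper: true when the answer was cut off by the max_tokens limit. */
    static boolean isTruncated(Choice choice) {
        return "length".equals(choice.finishReason);
    }

    public static void main(String[] args) {
        Choice complete = new Choice(0, "stop");
        Choice cutOff = new Choice(1, "length");
        System.out.println(isTruncated(complete)); // false
        System.out.println(isTruncated(cutOff));   // true
    }
}
```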

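ChatCompletionRequest (chat completion request)
The model is defined in a separate enum, Model, and the class declares all the request parameters.
Only model and messages are required; the rest are optional, so the object is built with the Builder pattern. Lombok's @Builder ignores a field's initializer unless the field is annotated with @Builder.Default, so every field that should keep a default value when built through the builder needs that annotation.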

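```java
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.*;
import lombok.extern.slf4j.Slf4j;

import java.io.Serializable;
import java.util.List;
import java.util.Map;

@Data
@Builder
@Slf4j
@JsonInclude(JsonInclude.Include.NON_NULL)
@NoArgsConstructor
@AllArgsConstructor
public class ChatCompletionRequest implements Serializable {

    /** Model ID (required); defaults to gpt-3.5-turbo. */
    @Builder.Default
    private String model = Model.GPT_3_5_TURBO.getCode();
    /** Conversation so far (required). */
    private List<Message> messages;
    /** Sampling temperature; lower values give more deterministic output. */
    @Builder.Default
    private double temperature = 0.2;
    /** Nucleus sampling; adjust this or temperature, not both. */
    @JsonProperty("top_p")
    @Builder.Default
    private Double topP = 1d;
    /** Number of choices to generate. */
    @Builder.Default
    private Integer n = 1;
    /** Whether to stream partial results; the blocking session expects one JSON response. */
    @Builder.Default
    private boolean stream = false;
    /** Sequences at which generation stops. */
    private List<String> stop;
    /** Upper bound on generated tokens. */
    @JsonProperty("max_tokens")
    @Builder.Default
    private Integer maxTokens = 2048;
    @JsonProperty("frequency_penalty")
    @Builder.Default
    private double frequencyPenalty = 0;
    @JsonProperty("presence_penalty")
    @Builder.Default
    private double presencePenalty = 0;
    @JsonProperty("logit_bias")
    private Map logitBias;
    /** End-user identifier. */
    private String user;

    @Getter
    @AllArgsConstructor
    public enum Model {
        GPT_3_5_TURBO("gpt-3.5-turbo"),
        GPT_4("gpt-4"),
        GPT_4_32K("gpt-4-32k"),
        ;
        private String code;
    }

}
```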
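The following hand-written builder is a plain-Java sketch of the behavior that @Builder.Default provides. It is not the code Lombok generates. Optional fields start from their defaults, and build() rejects a request without messages. The class names are illustrative, and messages are simplified to strings.

```java
import java.util.Collections;
import java.util.List;

public class ChatRequestBuilderDemo {

    static final class ChatRequest {
        final String model;
        final List<String> messages;
        final double temperature;
        final Integer maxTokens;

        private ChatRequest(Builder b) {
            this.model = b.model;
            this.messages = b.messages;
            this.temperature = b.temperature;
            this.maxTokens = b.maxTokens;
        }

        static Builder builder() {
            return new Builder();
        }
    }

    static final class Builder {
        // Defaults live in the builder, which is what @Builder.Default arranges in Lombok.
        private String model = "gpt-3.5-turbo";
        private List<String> messages;
        private double temperature = 0.2;
        private Integer maxTokens = 2048;

        Builder model(String model) { this.model = model; return this; }
        Builder messages(List<String> messages) { this.messages = messages; return this; }
        Builder temperature(double temperature) { this.temperature = temperature; return this; }
        Builder maxTokens(Integer maxTokens) { this.maxTokens = maxTokens; return this; }

        ChatRequest build() {
            // messages is the one required field without a sensible default.
            if (messages == null || messages.isEmpty()) {
                throw new IllegalStateException("messages is required");
            }
            return new ChatRequest(this);
        }
    }

    public static void main(String[] args) {
        ChatRequest request = ChatRequest.builder()
                .messages(Collections.singletonList("Write a bubble sort in Java"))
                .build();
        // Unset optional fields keep their defaults.
        System.out.println(request.model + " " + request.temperature + " " + request.maxTokens); // gpt-3.5-turbo 0.2 2048
    }
}
```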
 
ChatCompletionResponse (chat completion response)
Defines the fields of the response.

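```java
import lombok.Data;

import java.io.Serializable;
import java.util.List;

@Data
public class ChatCompletionResponse implements Serializable {

    /** Completion ID. */
    private String id;
    /** Object type, "chat.completion" for this endpoint. */
    private String object;
    /** Model that produced the answer. */
    private String model;
    /** Generated candidates; n = 1 yields a single choice. */
    private List<ChatChoice> choices;
    /** Creation time as a Unix timestamp (seconds). */
    private long created;
    /** Token usage for this call. */
    private Usage usage;

}
```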
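Most callers only want the text of the first choice, but choices can be missing or empty, for example when the call fails. The following is a minimal sketch of a null-safe accessor. It uses simplified stand-ins for the response classes and a hypothetical firstReply helper.

```java
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Optional;

public class FirstReplyDemo {

    /** Stand-ins for Message / ChatChoice / ChatCompletionResponse, reduced to what is read here. */
    static class Message {
        final String role;
        final String content;
        Message(String role, String content) { this.role = role; this.content = content; }
    }

    static class Choice {
        final Message message;
        Choice(Message message) { this.message = message; }
    }

    static class Response {
        final List<Choice> choices;
        Response(List<Choice> choices) { this.choices = choices; }
    }

    /** Hypothetical helper: content of the first choice, or empty if there is none. */
    static Optional<String> firstReply(Response response) {
        if (response == null || response.choices == null || response.choices.isEmpty()) {
            return Optional.empty();
        }
        Choice first = response.choices.get(0);
        return Optional.ofNullable(first.message).map(m -> m.content);
    }

    public static void main(String[] args) {
        Response ok = new Response(Arrays.asList(new Choice(new Message("assistant", "Hello!"))));
        Response empty = new Response(Collections.emptyList());
        System.out.println(firstReply(ok).orElse("<no reply>"));    // Hello!
        System.out.println(firstReply(empty).orElse("<no reply>")); // <no reply>
    }
}
```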
 
Message
The chat message object, holding the message role, content, and name.
```java
import com.fasterxml.jackson.annotation.JsonInclude;
import lombok.Data;

import java.io.Serializable;

@Data
@JsonInclude(JsonInclude.Include.NON_NULL)
public class Message implements Serializable {

    /** Role code: "system", "user" or "assistant". */
    private String role;
    /** Message text. */
    private String content;
    /** Optional name of the message author. */
    private String name;

    public Message() {
    }

    private Message(Builder builder) {
        this.role = builder.role;
        this.content = builder.content;
        this.name = builder.name;
    }

    public static Builder builder() {
        return new Builder();
    }

    /**
     * Hand-written builder; role() takes the Constants.Role enum and stores its code.
     */
    public static final class Builder {

        private String role;
        private String content;
        private String name;

        public Builder() {
        }

        public Builder role(Constants.Role role) {
            this.role = role.getCode();
            return this;
        }

        public Builder content(String content) {
            this.content = content;
            return this;
        }

        public Builder name(String name) {
            this.name = name;
            return this;
        }

        public Message build() {
            return new Message(this);
        }
    }

}
```
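The chat API does not remember earlier calls, so a multi-turn conversation has to resend the whole history as messages on every request. The following sketch keeps that history on the client side. The Conversation class and its methods are hypothetical, and Message is reduced to a role code plus content.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

public class ConversationDemo {

    /** Minimal stand-in for Message: role code plus content. */
    static final class Message {
        final String role;
        final String content;
        Message(String role, String content) { this.role = role; this.content = content; }
        @Override public String toString() { return role + ": " + content; }
    }

    /** Hypothetical helper that keeps the history the chat API needs on every call. */
    static final class Conversation {
        private final List<Message> history = new ArrayList<>();

        Conversation(String systemPrompt) {
            history.add(new Message("system", systemPrompt));
        }

        /** Adds the user's turn and returns a snapshot to send as "messages". */
        List<Message> ask(String question) {
            history.add(new Message("user", question));
            return Collections.unmodifiableList(new ArrayList<>(history));
        }

        /** Records the assistant's reply so the next request includes it. */
        void recordReply(String answer) {
            history.add(new Message("assistant", answer));
        }
    }

    public static void main(String[] args) {
        Conversation chat = new Conversation("You are a helpful Java tutor.");
        List<Message> first = chat.ask("What is bubble sort?");
        chat.recordReply("A simple comparison-based sorting algorithm.");
        List<Message> second = chat.ask("Show it in Java.");
        System.out.println(first.size() + " " + second.size()); // 2 4
        second.forEach(System.out::println);
    }
}
```

Returning a copy means a request that has already been sent is not changed when later turns are added.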
 
other
Usage (token usage)
A field of the response that records how many tokens the call consumed.
```java
import com.fasterxml.jackson.annotation.JsonProperty;

import java.io.Serializable;

public class Usage implements Serializable {

    /** Tokens in the prompt (input messages). */
    @JsonProperty("prompt_tokens")
    private long promptTokens;
    /** Tokens in the generated answer. */
    @JsonProperty("completion_tokens")
    private long completionTokens;
    /** Sum of prompt and completion tokens. */
    @JsonProperty("total_tokens")
    private long totalTokens;

    public long getPromptTokens() {
        return promptTokens;
    }

    public void setPromptTokens(long promptTokens) {
        this.promptTokens = promptTokens;
    }

    public long getCompletionTokens() {
        return completionTokens;
    }

    public void setCompletionTokens(long completionTokens) {
        this.completionTokens = completionTokens;
    }

    public long getTotalTokens() {
        return totalTokens;
    }

    public void setTotalTokens(long totalTokens) {
        this.totalTokens = totalTokens;
    }

}
```
 
OpenAiResponse
```java
import lombok.Data;

import java.io.Serializable;
import java.util.List;

@Data
public class OpenAiResponse<T> implements Serializable {

    private String object;
    private List<T> data;
    private Error error;

    /** Error payload; static so Jackson can create it without an enclosing instance. */
    @Data
    public static class Error {
        private String message;
        private String type;
        private String param;
        private String code;
    }

}
```
 
qa (Q&A / text completion model)

This endpoint will soon be retired, and it works much like the chat model, so its code is not shown here.
QAChoice

QACompletionRequest
model and prompt are required; the other parameters are optional.

QACompletionResponse

interceptor
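OpenAiInterceptor (custom interceptor)
The auth method appends the token query parameter to the URL, sets the Authorization and Content-Type headers, and returns a new request. intercept applies auth to the incoming request and passes the result to the next interceptor in the chain, or, at the end of the chain, to the actual network call.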
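```java
import cn.hutool.http.ContentType;
import cn.hutool.http.Header;
import okhttp3.HttpUrl;
import okhttp3.Interceptor;
import okhttp3.Request;
import okhttp3.Response;
import org.jetbrains.annotations.NotNull;

import java.io.IOException;

public class OpenAiInterceptor implements Interceptor {

    /** OpenAI API key, sent as a Bearer token. */
    private String apiKey;
    /** Extra token, sent as the "token" query parameter. */
    private String authToken;

    public OpenAiInterceptor(String apiKey, String authToken) {
        this.apiKey = apiKey;
        this.authToken = authToken;
    }

    @NotNull
    @Override
    public Response intercept(Chain chain) throws IOException {
        // Rewrite the request, then hand it to the next interceptor in the chain.
        return chain.proceed(this.auth(apiKey, chain.request()));
    }

    private Request auth(String apiKey, Request original) {
        // Append ?token=<authToken> to the original URL.
        HttpUrl url = original.url().newBuilder()
                .addQueryParameter("token", authToken)
                .build();

        // Copy the request with the new URL and the auth / content-type headers.
        return original.newBuilder()
                .url(url)
                .header(Header.AUTHORIZATION.getValue(), "Bearer " + apiKey)
                .header(Header.CONTENT_TYPE.getValue(), ContentType.JSON.getValue())
                .method(original.method(), original.body())
                .build();
    }

}
```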
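The following plain-Java sketch models the "hand the request to the next interceptor" behavior. It is loosely modeled on OkHttp's Interceptor/Chain pair, but all types here are simplified stand-ins and not the OkHttp API. Each interceptor receives a chain, reads chain.request(), and calls chain.proceed() with a modified copy. The end of the chain returns the request that would have been sent, which stands in for the real HTTP call.

```java
import java.net.URLEncoder;
import java.nio.charset.StandardCharsets;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

public class InterceptorChainDemo {

    /** Simplified immutable request: a URL plus headers. */
    static final class Request {
        final String url;
        final Map<String, String> headers;

        Request(String url, Map<String, String> headers) {
            this.url = url;
            this.headers = Collections.unmodifiableMap(new LinkedHashMap<>(headers));
        }

        /** Returns a copy with the header set (the original is untouched). */
        Request withHeader(String name, String value) {
            Map<String, String> copy = new LinkedHashMap<>(headers);
            copy.put(name, value);
            return new Request(url, copy);
        }

        /** Returns a copy with name=value appended to the query string (URL-encoded, Java 10+). */
        Request withQueryParameter(String name, String value) {
            String separator = url.contains("?") ? "&" : "?";
            String encoded = URLEncoder.encode(value, StandardCharsets.UTF_8);
            return new Request(url + separator + name + "=" + encoded, headers);
        }
    }

    interface Chain {
        Request request();
        Request proceed(Request request);
    }

    interface Interceptor {
        Request intercept(Chain chain);
    }

    /** Calls interceptors in order; past the last one it returns the request as "sent". */
    static final class RealChain implements Chain {
        private final List<Interceptor> interceptors;
        private final int index;
        private final Request request;

        RealChain(List<Interceptor> interceptors, int index, Request request) {
            this.interceptors = interceptors;
            this.index = index;
            this.request = request;
        }

        @Override
        public Request request() {
            return request;
        }

        @Override
        public Request proceed(Request request) {
            if (index == interceptors.size()) {
                return request; // stand-in for the real HTTP call
            }
            Chain next = new RealChain(interceptors, index + 1, request);
            return interceptors.get(index).intercept(next);
        }
    }

    /** Same idea as OpenAiInterceptor.auth(): token query parameter plus Bearer header. */
    static Interceptor auth(String apiKey, String authToken) {
        return chain -> chain.proceed(chain.request()
                .withQueryParameter("token", authToken)
                .withHeader("Authorization", "Bearer " + apiKey)
                .withHeader("Content-Type", "application/json"));
    }

    static Request send(List<Interceptor> interceptors, Request original) {
        return new RealChain(interceptors, 0, original).proceed(original);
    }

    public static void main(String[] args) {
        Request original = new Request("https://api.example.com/v1/chat/completions", Collections.emptyMap());
        Request sent = send(Collections.singletonList(auth("sk-test", "abc")), original);
        System.out.println(sent.url);
        System.out.println(sent.headers);
    }
}
```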
 
session
IOpenAiApi
Defines the API endpoints; each method takes a request object, which is sent as the JSON body.
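```java
import io.reactivex.Single;
import retrofit2.http.Body;
import retrofit2.http.POST;

public interface IOpenAiApi {

    /**
     * Text completion (qa) endpoint.
     *
     * @param qaCompletionRequest request body
     * @return completion result
     */
    @POST("v1/completions")
    Single<QACompletionResponse> completions(@Body QACompletionRequest qaCompletionRequest);

    /**
     * Chat completion endpoint.
     *
     * @param chatCompletionRequest request body
     * @return chat completion result
     */
    @POST("v1/chat/completions")
    Single<ChatCompletionResponse> completions(@Body ChatCompletionRequest chatCompletionRequest);

}
```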
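For comparison, the following sketch builds roughly the HTTP request that Retrofit produces for completions(ChatCompletionRequest): a POST to v1/chat/completions, resolved against the base URL, with the request serialized as the JSON body. It swaps in the JDK's java.net.http client (JDK 11+) instead of Retrofit, only builds the request without sending it, and writes the JSON by hand. The base URL and key are placeholders.

```java
import java.net.URI;
import java.net.http.HttpRequest;

public class RawChatRequestDemo {

    /** Builds (but does not send) a chat completion request against the given base URL. */
    static HttpRequest chatCompletionRequest(String baseUrl, String apiKey, String jsonBody) {
        // Resolving a relative path against a base URL ending in "/" appends it to that path.
        URI endpoint = URI.create(baseUrl).resolve("v1/chat/completions");
        return HttpRequest.newBuilder(endpoint)
                .header("Authorization", "Bearer " + apiKey)
                .header("Content-Type", "application/json")
                .POST(HttpRequest.BodyPublishers.ofString(jsonBody))
                .build();
    }

    public static void main(String[] args) {
        String body = "{\"model\":\"gpt-3.5-turbo\","
                + "\"messages\":[{\"role\":\"user\",\"content\":\"Write a bubble sort in Java\"}]}";
        HttpRequest request = chatCompletionRequest("https://api.example.com/", "sk-placeholder", body);
        System.out.println(request.method() + " " + request.uri());
        // Sending it would be: HttpClient.newHttpClient().send(request, HttpResponse.BodyHandlers.ofString())
    }
}
```

Retrofit adds the call adapter (Single) and the Jackson conversion on top of this, so the interface only has to declare the path and the body type.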
 
OpenAiSession
Session interface
```java
public interface OpenAiSession {

    /** Text completion with a full request object. */
    QACompletionResponse completions(QACompletionRequest qaCompletionRequest);

    /** Text completion from a plain question, using default parameters. */
    QACompletionResponse completions(String question);

    /** Chat completion. */
    ChatCompletionResponse completions(ChatCompletionRequest chatCompletionRequest);

}
```
 
OpenAiSessionFactory
Session factory interface
```java
public interface OpenAiSessionFactory {

    OpenAiSession openSession();

}
```
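Placing the HTTP client behind OpenAiSession and OpenAiSessionFactory (similar to MyBatis's SqlSessionFactory/SqlSession pair) means calling code depends only on the interfaces. One practical benefit is that a test can plug in a fake instead of Retrofit and OkHttp. The following minimal sketch uses simplified, hypothetical types with a string-in, string-out API to show the shape.

```java
public class SessionFactoryDemo {

    /** Simplified stand-in for IOpenAiApi: question in, answer out. */
    interface Api {
        String chat(String question);
    }

    /** Simplified stand-in for OpenAiSession. */
    interface Session {
        String completions(String question);
    }

    /** Simplified stand-in for OpenAiSessionFactory. */
    interface SessionFactory {
        Session openSession();
    }

    /** Like DefaultOpenAiSession: delegates to whatever Api it was given. */
    static final class DefaultSession implements Session {
        private final Api api;
        DefaultSession(Api api) { this.api = api; }
        @Override public String completions(String question) { return api.chat(question); }
    }

    /** Hypothetical test factory that wires a canned Api instead of Retrofit/OkHttp. */
    static final class FakeSessionFactory implements SessionFactory {
        private final Api fake;
        FakeSessionFactory(Api fake) { this.fake = fake; }
        @Override public Session openSession() { return new DefaultSession(fake); }
    }

    public static void main(String[] args) {
        SessionFactory factory = new FakeSessionFactory(q -> "echo: " + q);
        Session session = factory.openSession();
        System.out.println(session.completions("hello")); // echo: hello
    }
}
```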
 
Configuration
Configuration class
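```java
@Slf4j
@Data
@NoArgsConstructor
@AllArgsConstructor
public class Configuration {

    /** OpenAI API key. */
    @Getter
    @NotNull
    private String apiKey;

    /** Base URL of the API (or proxy); Retrofit requires it to end with "/". */
    @Getter
    private String apiHost;

    /** Extra token sent as the "token" query parameter. */
    @Getter
    private String authToken;

}
```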
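Retrofit's baseUrl(...) throws IllegalArgumentException when the base URL's path does not end in "/". The trailing slash also matters for relative paths. With a base of https://host/proxy, a relative path such as v1/chat/completions would replace the last segment instead of being appended. A small helper could normalize apiHost before it reaches the factory. The following is a sketch, and normalizeApiHost is hypothetical.

```java
public class ApiHostDemo {

    /** Hypothetical helper: trims whitespace and guarantees a trailing slash for Retrofit's baseUrl. */
    static String normalizeApiHost(String apiHost) {
        if (apiHost == null || apiHost.trim().isEmpty()) {
            throw new IllegalArgumentException("apiHost must not be empty");
        }
        String host = apiHost.trim();
        return host.endsWith("/") ? host : host + "/";
    }

    public static void main(String[] args) {
        System.out.println(normalizeApiHost("https://api.openai-proxy.com")); // https://api.openai-proxy.com/
        System.out.println(normalizeApiHost("https://example.com/proxy/"));   // https://example.com/proxy/
    }
}
```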
 
DefaultOpenAiSession
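Implements the OpenAiSession interface.
blockingGet() is a method of RxJava's Single. It subscribes to the Single and blocks the calling thread until the Single emits a value, which it returns, or an error, which it rethrows. Here it turns the asynchronous Retrofit call into a synchronous result: the calling thread waits until the HTTP call has completed before continuing.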
```java
import io.reactivex.Single;

public class DefaultOpenAiSession implements OpenAiSession {

    /** API client generated by Retrofit. */
    private IOpenAiApi openAiApi;

    public DefaultOpenAiSession(IOpenAiApi openAiApi) {
        this.openAiApi = openAiApi;
    }

    @Override
    public QACompletionResponse completions(QACompletionRequest qaCompletionRequest) {
        return this.openAiApi.completions(qaCompletionRequest).blockingGet();
    }

    @Override
    public QACompletionResponse completions(String question) {
        // Only the prompt is set here, so QACompletionRequest must supply a default model.
        QACompletionRequest request = QACompletionRequest
                .builder()
                .prompt(question)
                .build();
        Single<QACompletionResponse> completions = this.openAiApi.completions(request);
        // Block until the asynchronous call completes.
        return completions.blockingGet();
    }

    @Override
    public ChatCompletionResponse completions(ChatCompletionRequest chatCompletionRequest) {
        return this.openAiApi.completions(chatCompletionRequest).blockingGet();
    }

}
```
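blockingGet() is RxJava-specific, but the same async-to-sync conversion can be shown with the JDK's CompletableFuture, which is swapped in here so the sketch runs without RxJava. supplyAsync starts the work on another thread, and join() blocks the caller until the result is ready. The sketch also adds orTimeout (Java 9+) so the caller fails instead of waiting indefinitely. The slow API call is simulated.

```java
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;

public class BlockingGetDemo {

    /** Simulates an asynchronous API call that finishes after a short delay. */
    static CompletableFuture<String> callApiAsync(String question) {
        return CompletableFuture.supplyAsync(() -> {
            try {
                Thread.sleep(100); // stand-in for network latency
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new IllegalStateException(e);
            }
            return "answer to: " + question;
        });
    }

    /** Synchronous facade: the calling thread waits here until the async result is ready. */
    static String completions(String question) {
        return callApiAsync(question)
                .orTimeout(5, TimeUnit.SECONDS) // complete exceptionally after 5 s
                .join();                        // blocks, like Single.blockingGet()
    }

    public static void main(String[] args) {
        long start = System.nanoTime();
        String answer = completions("hi");
        long elapsedMs = (System.nanoTime() - start) / 1_000_000;
        System.out.println(answer + " after ~" + elapsedMs + " ms");
    }
}
```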
 
DefaultOpenAiSessionFactory
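Implements the OpenAiSessionFactory interface. openSession() builds an OkHttp client with the logging and auth interceptors and the timeouts. It then lets Retrofit generate an implementation of the IOpenAiApi interface (Retrofit.create returns a dynamic proxy) and wraps it in a DefaultOpenAiSession.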
```java
import okhttp3.OkHttpClient;
import okhttp3.logging.HttpLoggingInterceptor;
import retrofit2.Retrofit;
import retrofit2.adapter.rxjava2.RxJava2CallAdapterFactory;
import retrofit2.converter.jackson.JacksonConverterFactory;

import java.util.concurrent.TimeUnit;

public class DefaultOpenAiSessionFactory implements OpenAiSessionFactory {

    private final Configuration configuration;

    public DefaultOpenAiSessionFactory(Configuration configuration) {
        this.configuration = configuration;
    }

    @Override
    public OpenAiSession openSession() {
        // 1. Header logging; added first, so it logs the request before the API key header is attached.
        HttpLoggingInterceptor httpLoggingInterceptor = new HttpLoggingInterceptor();
        httpLoggingInterceptor.setLevel(HttpLoggingInterceptor.Level.HEADERS);

        // 2. OkHttp client with the logging and auth interceptors and generous timeouts
        OkHttpClient okHttpClient = new OkHttpClient
                .Builder()
                .addInterceptor(httpLoggingInterceptor)
                .addInterceptor(new OpenAiInterceptor(configuration.getApiKey(), configuration.getAuthToken()))
                .connectTimeout(450, TimeUnit.SECONDS)
                .writeTimeout(450, TimeUnit.SECONDS)
                .readTimeout(450, TimeUnit.SECONDS)
                .build();

        // 3. Retrofit: RxJava 2 call adapter (methods return Single) and Jackson JSON converter
        IOpenAiApi openAiApi = new Retrofit.Builder()
                .baseUrl(configuration.getApiHost())
                .client(okHttpClient)
                .addCallAdapterFactory(RxJava2CallAdapterFactory.create())
                .addConverterFactory(JacksonConverterFactory.create())
                .build().create(IOpenAiApi.class);

        // 4. Wrap the generated API in a session
        return new DefaultOpenAiSession(openAiApi);
    }

}
```
 
Unit test
The test only needs the API host URL, the API key, the token, and a request.

```java
import lombok.extern.slf4j.Slf4j;
import org.junit.Before;
import org.junit.Test;

import java.util.Collections;

@Slf4j
public class ApiTest {

    private OpenAiSession openAiSession;

    @Before
    public void test_OpenAiSessionFactory() {
        // 1. Configuration
        Configuration configuration = new Configuration();
        configuration.setApiHost("https://api.openai-proxy.com/");
        configuration.setApiKey("xxx");
        configuration.setAuthToken("xxx");
        // 2. Session factory
        OpenAiSessionFactory factory = new DefaultOpenAiSessionFactory(configuration);
        // 3. Open a session
        this.openAiSession = factory.openSession();
    }

    @Test
    public void test_chat_completions() {
        // 1. Build the request
        ChatCompletionRequest chatCompletion = ChatCompletionRequest
                .builder()
                .messages(Collections.singletonList(Message.builder().role(Constants.Role.USER).content("Write a bubble sort in Java").build()))
                .model(ChatCompletionRequest.Model.GPT_3_5_TURBO.getCode())
                .build();
        // 2. Send it
        ChatCompletionResponse chatCompletionResponse = openAiSession.completions(chatCompletion);
        // 3. Print the result
        chatCompletionResponse.getChoices().forEach(e -> {
            log.info("Test result: {}", e.getMessage());
        });
    }

}
```
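Hard-coding the key in the test makes it easy to commit it by accident. One alternative is to read the values from environment variables. The variable names below (OPENAI_API_KEY, OPENAI_API_HOST, OPENAI_AUTH_TOKEN) are assumptions, not names the SDK reads on its own. The loader takes a Map so it can be tested without touching the real environment; real code would pass System.getenv() and copy the values into Configuration.

```java
import java.util.HashMap;
import java.util.Map;

public class EnvConfigDemo {

    /** Simplified stand-in for Configuration. */
    static final class Config {
        final String apiHost;
        final String apiKey;
        final String authToken;

        Config(String apiHost, String apiKey, String authToken) {
            this.apiHost = apiHost;
            this.apiKey = apiKey;
            this.authToken = authToken;
        }
    }

    /** Reads a required variable or fails with a clear message. */
    static String require(Map<String, String> env, String name) {
        String value = env.get(name);
        if (value == null || value.isEmpty()) {
            throw new IllegalStateException("Missing environment variable " + name);
        }
        return value;
    }

    /** Hypothetical loader; pass System.getenv() in real code. */
    static Config fromEnv(Map<String, String> env) {
        return new Config(
                env.getOrDefault("OPENAI_API_HOST", "https://api.openai.com/"),
                require(env, "OPENAI_API_KEY"),
                env.get("OPENAI_AUTH_TOKEN")); // optional, may be null
    }

    public static void main(String[] args) {
        Map<String, String> fake = new HashMap<>();
        fake.put("OPENAI_API_KEY", "sk-test");
        Config config = fromEnv(fake);
        System.out.println(config.apiHost + " " + config.apiKey + " " + config.authToken);
    }
}
```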