Skip to content

Commit

Permalink
Merge pull request #42 from the-gigi/main
Browse files Browse the repository at this point in the history
Add support for Azure OpenAI
  • Loading branch information
sashirestela authored Feb 13, 2024
2 parents 9ca56a2 + 7eaadac commit 443dea1
Show file tree
Hide file tree
Showing 8 changed files with 125 additions and 32 deletions.
2 changes: 1 addition & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
<maven.compiler.release>11</maven.compiler.release>
<!-- Dependencies Versions -->
<slf4j.version>[2.0.9,3.0.0)</slf4j.version>
<cleverclient.version>0.13.0</cleverclient.version>
<cleverclient.version>1.1.0</cleverclient.version>
<lombok.version>[1.18.30,2.0.0)</lombok.version>
<jackson.version>[2.15.2,3.0.0)</jackson.version>
<json.schema.version>[4.31.1,5.0.0)</json.schema.version>
Expand Down
16 changes: 14 additions & 2 deletions src/demo/java/io/github/sashirestela/openai/demo/AbstractDemo.java
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
package io.github.sashirestela.openai.demo;

import io.github.sashirestela.cleverclient.http.HttpRequestData;
import io.github.sashirestela.openai.SimpleOpenAI;
import java.util.ArrayList;
import java.util.List;

import io.github.sashirestela.openai.SimpleOpenAI;
import java.util.function.UnaryOperator;
import lombok.NonNull;

public abstract class AbstractDemo {

Expand All @@ -23,6 +25,16 @@ protected AbstractDemo() {
.build();
}

protected AbstractDemo(@NonNull String baseUrl,
@NonNull String apiKey,
@NonNull UnaryOperator<HttpRequestData> requestInterceptor) {
openAI = SimpleOpenAI.builder()
.apiKey(apiKey)
.baseUrl(baseUrl)
.requestInterceptor(requestInterceptor)
.build();
}

public void addTitleAction(String title, Action action) {
titleActions.add(new TitleAction(title, action));
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
package io.github.sashirestela.openai.demo;


import io.github.sashirestela.cleverclient.support.ContentType;
import io.github.sashirestela.openai.domain.chat.ChatRequest;
import io.github.sashirestela.openai.domain.chat.message.ChatMsgSystem;
import io.github.sashirestela.openai.domain.chat.message.ChatMsgUser;
import java.util.Map;
import java.util.Optional;

public class AzureOpenAIChatServiceDemo extends AbstractDemo {
private static final String AZURE_OPENAI_API_KEY_HEADER = "api-key";
private final ChatRequest chatRequest;

@SuppressWarnings("unchecked")
public AzureOpenAIChatServiceDemo(String baseUrl, String apiKey, String model) {
super(baseUrl, apiKey, request -> {
var url = request.getUrl();
var contentType = request.getContentType();
var body = request.getBody();

// add a header to the request
var headers = request.getHeaders();
headers.put(AZURE_OPENAI_API_KEY_HEADER, apiKey);
request.setHeaders(headers);

// add a query parameter to url
url += (url.contains("?") ? "&" : "?") + "api-version=2023-05-15";
// remove '/vN' or '/vN.M' from url
url = url.replaceFirst("(\\/v\\d+\\.*\\d*)", "");
request.setUrl(url);

if (contentType != null) {
if (contentType.equals(ContentType.APPLICATION_JSON)) {
var bodyJson = (String) request.getBody();
// remove a field from body (as Json)
bodyJson = bodyJson.replaceFirst(",?\"model\":\"[^\"]*\",?", "");
bodyJson = bodyJson.replaceFirst("\"\"", "\",\"");
body = bodyJson;
}
if (contentType.equals(ContentType.MULTIPART_FORMDATA)) {
Map<String, Object> bodyMap = (Map<String, Object>) request.getBody();
// remove a field from body (as Map)
bodyMap.remove("model");
body = bodyMap;
}
request.setBody(body);
}

return request;
});

chatRequest = ChatRequest.builder()
.model(model)
.message(new ChatMsgSystem("You are an expert in AI."))
.message(
new ChatMsgUser("Write a technical article about ChatGPT, no more than 100 words."))
.temperature(0.0)
.maxTokens(300)
.build();
}

public void demoCallChatBlocking() {
var futureChat = openAI.chatCompletions().create(chatRequest);
var chatResponse = futureChat.join();
System.out.println(chatResponse.firstContent());
}

public static void main(String[] args) {
var baseUrl = System.getenv("CUSTOM_OPENAI_BASE_URL");
var apiKey = System.getenv("CUSTOM_OPENAI_API_KEY");
// Services like Azure OpenAI don't require a model (endpoints have built-in model)
var model = Optional.ofNullable(System.getenv("CUSTOM_OPENAI_MODEL"))
.orElse("N/A");
var demo = new AzureOpenAIChatServiceDemo(baseUrl, apiKey, model);

demo.addTitleAction("Call Completion (Blocking Approach)", demo::demoCallChatBlocking);

demo.run();
}
}
2 changes: 1 addition & 1 deletion src/main/java/io/github/sashirestela/openai/OpenAI.java
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ default CompletableFuture<Stream<ChatResponse>> createStream(@Body ChatRequest c

/**
* Given a prompt, the model will return one or more predicted completions. It
* is recommend most users to use the Chat Completion.
* is recommended for most users to use the Chat Completion.
*
* @see <a href=
* "https://platform.openai.com/docs/api-reference/completions">OpenAI
Expand Down
20 changes: 11 additions & 9 deletions src/main/java/io/github/sashirestela/openai/SimpleOpenAI.java
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
package io.github.sashirestela.openai;

import io.github.sashirestela.cleverclient.CleverClient;
import io.github.sashirestela.cleverclient.http.HttpRequestData;
import java.net.http.HttpClient;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Optional;

import io.github.sashirestela.cleverclient.CleverClient;
import java.util.function.UnaryOperator;
import lombok.AccessLevel;
import lombok.Builder;
import lombok.Getter;
Expand All @@ -30,6 +31,7 @@ public class SimpleOpenAI {
private final String baseUrl;
@Deprecated
private final String urlBase = null;

private HttpClient httpClient;

private CleverClient cleverClient;
Expand Down Expand Up @@ -86,26 +88,26 @@ public SimpleOpenAI(
String organizationId,
String baseUrl,
String urlBase,
HttpClient httpClient) {
HttpClient httpClient,
UnaryOperator<HttpRequestData> requestInterceptor) {
this.apiKey = apiKey;
this.organizationId = organizationId;
this.baseUrl = Optional.ofNullable(baseUrl)
.orElse(Optional.ofNullable(urlBase).orElse(OPENAI_BASE_URL));

this.httpClient = Optional.ofNullable(httpClient).orElse(HttpClient.newHttpClient());

var headers = new ArrayList<String>();
headers.add(AUTHORIZATION_HEADER);
headers.add(BEARER_AUTHORIZATION + apiKey);
var headers = new HashMap<String, String>();
headers.put(AUTHORIZATION_HEADER, BEARER_AUTHORIZATION + apiKey);
if (organizationId != null) {
headers.add(ORGANIZATION_HEADER);
headers.add(organizationId);
headers.put(ORGANIZATION_HEADER, organizationId);
}
this.cleverClient = CleverClient.builder()
.httpClient(this.httpClient)
.baseUrl(this.baseUrl)
.headers(headers)
.endOfStream(END_OF_STREAM)
.requestInterceptor(requestInterceptor)
.build();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@
import com.github.victools.jsonschema.generator.SchemaVersion;
import com.github.victools.jsonschema.module.jackson.JacksonModule;
import com.github.victools.jsonschema.module.jackson.JacksonOption;

import io.github.sashirestela.cleverclient.util.Constant;
import io.github.sashirestela.openai.SimpleUncheckedException;

public class JsonSchemaUtil {

public static final String JSON_EMPTY_CLASS = "{\"type\":\"object\",\"properties\":{}}";
private static ObjectMapper objectMapper = new ObjectMapper();

private JsonSchemaUtil() {
Expand All @@ -39,7 +39,7 @@ public static JsonNode classToJsonSchema(Class<?> clazz) {
}
} else {
try {
jsonSchema = objectMapper.readTree(Constant.JSON_EMPTY_CLASS);
jsonSchema = objectMapper.readTree(JSON_EMPTY_CLASS);
} catch (JsonProcessingException e) {
throw new SimpleUncheckedException("Cannot generate the Json Schema for the class {0}.",
clazz.getName(), e);
Expand Down
22 changes: 11 additions & 11 deletions src/test/java/io/github/sashirestela/openai/SimpleOpenAITest.java
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ void shouldNotAddOrganizationToHeadersWhenBuilderIsCalledWithoutOrganizationId()
var openAI = SimpleOpenAI.builder()
.apiKey("apiKey")
.build();
assertFalse(openAI.getCleverClient().getHeaders().contains(openAI.getOrganizationId()));
assertFalse(openAI.getCleverClient().getHeaders().containsValue(openAI.getOrganizationId()));
}

@Test
Expand All @@ -108,7 +108,7 @@ void shouldAddOrganizationToHeadersWhenBuilderIsCalledWithOrganizationId() {
.apiKey("apiKey")
.organizationId("orgId")
.build();
assertTrue(openAI.getCleverClient().getHeaders().contains(openAI.getOrganizationId()));
assertTrue(openAI.getCleverClient().getHeaders().containsValue(openAI.getOrganizationId()));
}

@Test
Expand Down Expand Up @@ -165,7 +165,7 @@ void shouldInstanceAudioServiceOnlyOnceWhenItIsCalledSeveralTimes() {
when(cleverClient.create(any()))
.thenReturn(ReflectUtil.createProxy(
OpenAI.Audios.class,
new HttpProcessor(null, null, null)));
HttpProcessor.builder().build()));
repeat(NUMBER_CALLINGS, () -> openAI.audios());
verify(cleverClient, times(NUMBER_INVOCATIONS)).create(any());
}
Expand All @@ -175,7 +175,7 @@ void shouldInstanceChatCompletionServiceOnlyOnceWhenItIsCalledSeveralTimes() {
when(cleverClient.create(any()))
.thenReturn(ReflectUtil.createProxy(
OpenAI.ChatCompletions.class,
new HttpProcessor(null, null, null)));
HttpProcessor.builder().build()));
repeat(NUMBER_CALLINGS, () -> openAI.chatCompletions());
verify(cleverClient, times(NUMBER_INVOCATIONS)).create(any());
}
Expand All @@ -185,7 +185,7 @@ void shouldInstanceCompletionServiceOnlyOnceWhenItIsCalledSeveralTimes() {
when(cleverClient.create(any()))
.thenReturn(ReflectUtil.createProxy(
OpenAI.Completions.class,
new HttpProcessor(null, null, null)));
HttpProcessor.builder().build()));
repeat(NUMBER_CALLINGS, () -> openAI.completions());
verify(cleverClient, times(NUMBER_INVOCATIONS)).create(any());
}
Expand All @@ -195,7 +195,7 @@ void shouldInstanceEmbeddingServiceOnlyOnceWhenItIsCalledSeveralTimes() {
when(cleverClient.create(any()))
.thenReturn(ReflectUtil.createProxy(
OpenAI.Embeddings.class,
new HttpProcessor(null, null, null)));
HttpProcessor.builder().build()));
repeat(NUMBER_CALLINGS, () -> openAI.embeddings());
verify(cleverClient, times(NUMBER_INVOCATIONS)).create(any());
}
Expand All @@ -205,7 +205,7 @@ void shouldInstanceFilesServiceOnlyOnceWhenItIsCalledSeveralTimes() {
when(cleverClient.create(any()))
.thenReturn(ReflectUtil.createProxy(
OpenAI.Files.class,
new HttpProcessor(null, null, null)));
HttpProcessor.builder().build()));
repeat(NUMBER_CALLINGS, () -> openAI.files());
verify(cleverClient, times(NUMBER_INVOCATIONS)).create(any());
}
Expand All @@ -215,7 +215,7 @@ void shouldInstanceFineTunningServiceOnlyOnceWhenItIsCalledSeveralTimes() {
when(cleverClient.create(any()))
.thenReturn(ReflectUtil.createProxy(
OpenAI.FineTunings.class,
new HttpProcessor(null, null, null)));
HttpProcessor.builder().build()));
repeat(NUMBER_CALLINGS, () -> openAI.fineTunings());
verify(cleverClient, times(NUMBER_INVOCATIONS)).create(any());
}
Expand All @@ -225,7 +225,7 @@ void shouldInstanceImageServiceOnlyOnceWhenItIsCalledSeveralTimes() {
when(cleverClient.create(any()))
.thenReturn(ReflectUtil.createProxy(
OpenAI.Images.class,
new HttpProcessor(null, null, null)));
HttpProcessor.builder().build()));
repeat(NUMBER_CALLINGS, () -> openAI.images());
verify(cleverClient, times(NUMBER_INVOCATIONS)).create(any());
}
Expand All @@ -235,7 +235,7 @@ void shouldInstanceModelsServiceOnlyOnceWhenItIsCalledSeveralTimes() {
when(cleverClient.create(any()))
.thenReturn(ReflectUtil.createProxy(
OpenAI.Models.class,
new HttpProcessor(null, null, null)));
HttpProcessor.builder().build()));
repeat(NUMBER_CALLINGS, () -> openAI.models());
verify(cleverClient, times(NUMBER_INVOCATIONS)).create(any());
}
Expand All @@ -245,7 +245,7 @@ void shouldInstanceModerationServiceOnlyOnceWhenItIsCalledSeveralTimes() {
when(cleverClient.create(any()))
.thenReturn(ReflectUtil.createProxy(
OpenAI.Moderations.class,
new HttpProcessor(null, null, null)));
HttpProcessor.builder().build()));
repeat(NUMBER_CALLINGS, () -> openAI.moderations());
verify(cleverClient, times(NUMBER_INVOCATIONS)).create(any());
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
package io.github.sashirestela.openai.support;

import static io.github.sashirestela.openai.support.JsonSchemaUtil.JSON_EMPTY_CLASS;
import static org.junit.jupiter.api.Assertions.assertEquals;

import org.junit.jupiter.api.Test;

import com.fasterxml.jackson.annotation.JsonProperty;

import io.github.sashirestela.cleverclient.util.Constant;
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.NoArgsConstructor;
import org.junit.jupiter.api.Test;

class JsonSchemaUtilTest {

Expand All @@ -24,7 +22,7 @@ void shouldGenerateFullJsonSchemaWhenClassHasSomeFields() {
@Test
void shouldGenerateEmptyJsonSchemaWhenClassHasNoFields() {
var actualJsonSchema = JsonSchemaUtil.classToJsonSchema(EmptyClass.class).toString();
var expectedJsonSchema = Constant.JSON_EMPTY_CLASS;
var expectedJsonSchema = JSON_EMPTY_CLASS;
assertEquals(expectedJsonSchema, actualJsonSchema);
}

Expand Down

0 comments on commit 443dea1

Please sign in to comment.