Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for Azure OpenAI #42

Merged
merged 5 commits into from
Feb 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
<maven.compiler.release>11</maven.compiler.release>
<!-- Dependencies Versions -->
<slf4j.version>[2.0.9,3.0.0)</slf4j.version>
<cleverclient.version>0.13.0</cleverclient.version>
<cleverclient.version>1.1.0</cleverclient.version>
<lombok.version>[1.18.30,2.0.0)</lombok.version>
<jackson.version>[2.15.2,3.0.0)</jackson.version>
<json.schema.version>[4.31.1,5.0.0)</json.schema.version>
Expand Down
16 changes: 14 additions & 2 deletions src/demo/java/io/github/sashirestela/openai/demo/AbstractDemo.java
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
package io.github.sashirestela.openai.demo;

import io.github.sashirestela.cleverclient.http.HttpRequestData;
import io.github.sashirestela.openai.SimpleOpenAI;
import java.util.ArrayList;
import java.util.List;

import io.github.sashirestela.openai.SimpleOpenAI;
import java.util.function.UnaryOperator;
import lombok.NonNull;

public abstract class AbstractDemo {

Expand All @@ -23,6 +25,16 @@ protected AbstractDemo() {
.build();
}

protected AbstractDemo(@NonNull String baseUrl,
@NonNull String apiKey,
@NonNull UnaryOperator<HttpRequestData> requestInterceptor) {
openAI = SimpleOpenAI.builder()
.apiKey(apiKey)
.baseUrl(baseUrl)
.requestInterceptor(requestInterceptor)
.build();
}

public void addTitleAction(String title, Action action) {
titleActions.add(new TitleAction(title, action));
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
package io.github.sashirestela.openai.demo;


import io.github.sashirestela.cleverclient.support.ContentType;
import io.github.sashirestela.openai.domain.chat.ChatRequest;
import io.github.sashirestela.openai.domain.chat.message.ChatMsgSystem;
import io.github.sashirestela.openai.domain.chat.message.ChatMsgUser;
import java.util.Map;
import java.util.Optional;

public class AzureOpenAIChatServiceDemo extends AbstractDemo {
private static final String AZURE_OPENAI_API_KEY_HEADER = "api-key";
private final ChatRequest chatRequest;

@SuppressWarnings("unchecked")
public AzureOpenAIChatServiceDemo(String baseUrl, String apiKey, String model) {
super(baseUrl, apiKey, request -> {
var url = request.getUrl();
var contentType = request.getContentType();
var body = request.getBody();

// add a header to the request
var headers = request.getHeaders();
headers.put(AZURE_OPENAI_API_KEY_HEADER, apiKey);
request.setHeaders(headers);

// add a query parameter to url
url += (url.contains("?") ? "&" : "?") + "api-version=2023-05-15";
// remove '/vN' or '/vN.M' from url
url = url.replaceFirst("(\\/v\\d+\\.*\\d*)", "");
request.setUrl(url);

if (contentType != null) {
if (contentType.equals(ContentType.APPLICATION_JSON)) {
var bodyJson = (String) request.getBody();
// remove a field from body (as Json)
bodyJson = bodyJson.replaceFirst(",?\"model\":\"[^\"]*\",?", "");
bodyJson = bodyJson.replaceFirst("\"\"", "\",\"");
body = bodyJson;
}
if (contentType.equals(ContentType.MULTIPART_FORMDATA)) {
Map<String, Object> bodyMap = (Map<String, Object>) request.getBody();
the-gigi marked this conversation as resolved.
Show resolved Hide resolved
// remove a field from body (as Map)
bodyMap.remove("model");
body = bodyMap;
}
request.setBody(body);
}

return request;
});

chatRequest = ChatRequest.builder()
.model(model)
.message(new ChatMsgSystem("You are an expert in AI."))
.message(
new ChatMsgUser("Write a technical article about ChatGPT, no more than 100 words."))
.temperature(0.0)
.maxTokens(300)
.build();
}

public void demoCallChatBlocking() {
var futureChat = openAI.chatCompletions().create(chatRequest);
var chatResponse = futureChat.join();
System.out.println(chatResponse.firstContent());
}

public static void main(String[] args) {
var baseUrl = System.getenv("CUSTOM_OPENAI_BASE_URL");
var apiKey = System.getenv("CUSTOM_OPENAI_API_KEY");
// Services like Azure OpenAI don't require a model (endpoints have built-in model)
var model = Optional.ofNullable(System.getenv("CUSTOM_OPENAI_MODEL"))
.orElse("N/A");
var demo = new AzureOpenAIChatServiceDemo(baseUrl, apiKey, model);

demo.addTitleAction("Call Completion (Blocking Approach)", demo::demoCallChatBlocking);

demo.run();
}
}
2 changes: 1 addition & 1 deletion src/main/java/io/github/sashirestela/openai/OpenAI.java
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ default CompletableFuture<Stream<ChatResponse>> createStream(@Body ChatRequest c

/**
* Given a prompt, the model will return one or more predicted completions. It
* is recommend most users to use the Chat Completion.
* is recommended for most users to use the Chat Completion.
*
* @see <a href=
* "https://platform.openai.com/docs/api-reference/completions">OpenAI
Expand Down
20 changes: 11 additions & 9 deletions src/main/java/io/github/sashirestela/openai/SimpleOpenAI.java
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
package io.github.sashirestela.openai;

import io.github.sashirestela.cleverclient.CleverClient;
import io.github.sashirestela.cleverclient.http.HttpRequestData;
import java.net.http.HttpClient;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Optional;

import io.github.sashirestela.cleverclient.CleverClient;
import java.util.function.UnaryOperator;
import lombok.AccessLevel;
import lombok.Builder;
import lombok.Getter;
Expand All @@ -30,6 +31,7 @@ public class SimpleOpenAI {
private final String baseUrl;
@Deprecated
private final String urlBase = null;

private HttpClient httpClient;

private CleverClient cleverClient;
Expand Down Expand Up @@ -86,26 +88,26 @@ public SimpleOpenAI(
String organizationId,
String baseUrl,
String urlBase,
HttpClient httpClient) {
HttpClient httpClient,
UnaryOperator<HttpRequestData> requestInterceptor) {
this.apiKey = apiKey;
this.organizationId = organizationId;
this.baseUrl = Optional.ofNullable(baseUrl)
.orElse(Optional.ofNullable(urlBase).orElse(OPENAI_BASE_URL));

this.httpClient = Optional.ofNullable(httpClient).orElse(HttpClient.newHttpClient());

var headers = new ArrayList<String>();
headers.add(AUTHORIZATION_HEADER);
headers.add(BEARER_AUTHORIZATION + apiKey);
var headers = new HashMap<String, String>();
headers.put(AUTHORIZATION_HEADER, BEARER_AUTHORIZATION + apiKey);
if (organizationId != null) {
headers.add(ORGANIZATION_HEADER);
headers.add(organizationId);
headers.put(ORGANIZATION_HEADER, organizationId);
}
this.cleverClient = CleverClient.builder()
.httpClient(this.httpClient)
.baseUrl(this.baseUrl)
.headers(headers)
.endOfStream(END_OF_STREAM)
.requestInterceptor(requestInterceptor)
.build();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@
import com.github.victools.jsonschema.generator.SchemaVersion;
import com.github.victools.jsonschema.module.jackson.JacksonModule;
import com.github.victools.jsonschema.module.jackson.JacksonOption;

import io.github.sashirestela.cleverclient.util.Constant;
import io.github.sashirestela.openai.SimpleUncheckedException;

public class JsonSchemaUtil {

public static final String JSON_EMPTY_CLASS = "{\"type\":\"object\",\"properties\":{}}";
private static ObjectMapper objectMapper = new ObjectMapper();

private JsonSchemaUtil() {
Expand All @@ -39,7 +39,7 @@ public static JsonNode classToJsonSchema(Class<?> clazz) {
}
} else {
try {
jsonSchema = objectMapper.readTree(Constant.JSON_EMPTY_CLASS);
jsonSchema = objectMapper.readTree(JSON_EMPTY_CLASS);
} catch (JsonProcessingException e) {
throw new SimpleUncheckedException("Cannot generate the Json Schema for the class {0}.",
clazz.getName(), e);
Expand Down
22 changes: 11 additions & 11 deletions src/test/java/io/github/sashirestela/openai/SimpleOpenAITest.java
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ void shouldNotAddOrganizationToHeadersWhenBuilderIsCalledWithoutOrganizationId()
var openAI = SimpleOpenAI.builder()
.apiKey("apiKey")
.build();
assertFalse(openAI.getCleverClient().getHeaders().contains(openAI.getOrganizationId()));
assertFalse(openAI.getCleverClient().getHeaders().containsValue(openAI.getOrganizationId()));
}

@Test
Expand All @@ -108,7 +108,7 @@ void shouldAddOrganizationToHeadersWhenBuilderIsCalledWithOrganizationId() {
.apiKey("apiKey")
.organizationId("orgId")
.build();
assertTrue(openAI.getCleverClient().getHeaders().contains(openAI.getOrganizationId()));
assertTrue(openAI.getCleverClient().getHeaders().containsValue(openAI.getOrganizationId()));
}

@Test
Expand Down Expand Up @@ -165,7 +165,7 @@ void shouldInstanceAudioServiceOnlyOnceWhenItIsCalledSeveralTimes() {
when(cleverClient.create(any()))
.thenReturn(ReflectUtil.createProxy(
OpenAI.Audios.class,
new HttpProcessor(null, null, null)));
HttpProcessor.builder().build()));
repeat(NUMBER_CALLINGS, () -> openAI.audios());
verify(cleverClient, times(NUMBER_INVOCATIONS)).create(any());
}
Expand All @@ -175,7 +175,7 @@ void shouldInstanceChatCompletionServiceOnlyOnceWhenItIsCalledSeveralTimes() {
when(cleverClient.create(any()))
.thenReturn(ReflectUtil.createProxy(
OpenAI.ChatCompletions.class,
new HttpProcessor(null, null, null)));
HttpProcessor.builder().build()));
repeat(NUMBER_CALLINGS, () -> openAI.chatCompletions());
verify(cleverClient, times(NUMBER_INVOCATIONS)).create(any());
}
Expand All @@ -185,7 +185,7 @@ void shouldInstanceCompletionServiceOnlyOnceWhenItIsCalledSeveralTimes() {
when(cleverClient.create(any()))
.thenReturn(ReflectUtil.createProxy(
OpenAI.Completions.class,
new HttpProcessor(null, null, null)));
HttpProcessor.builder().build()));
repeat(NUMBER_CALLINGS, () -> openAI.completions());
verify(cleverClient, times(NUMBER_INVOCATIONS)).create(any());
}
Expand All @@ -195,7 +195,7 @@ void shouldInstanceEmbeddingServiceOnlyOnceWhenItIsCalledSeveralTimes() {
when(cleverClient.create(any()))
.thenReturn(ReflectUtil.createProxy(
OpenAI.Embeddings.class,
new HttpProcessor(null, null, null)));
HttpProcessor.builder().build()));
repeat(NUMBER_CALLINGS, () -> openAI.embeddings());
verify(cleverClient, times(NUMBER_INVOCATIONS)).create(any());
}
Expand All @@ -205,7 +205,7 @@ void shouldInstanceFilesServiceOnlyOnceWhenItIsCalledSeveralTimes() {
when(cleverClient.create(any()))
.thenReturn(ReflectUtil.createProxy(
OpenAI.Files.class,
new HttpProcessor(null, null, null)));
HttpProcessor.builder().build()));
repeat(NUMBER_CALLINGS, () -> openAI.files());
verify(cleverClient, times(NUMBER_INVOCATIONS)).create(any());
}
Expand All @@ -215,7 +215,7 @@ void shouldInstanceFineTunningServiceOnlyOnceWhenItIsCalledSeveralTimes() {
when(cleverClient.create(any()))
.thenReturn(ReflectUtil.createProxy(
OpenAI.FineTunings.class,
new HttpProcessor(null, null, null)));
HttpProcessor.builder().build()));
repeat(NUMBER_CALLINGS, () -> openAI.fineTunings());
verify(cleverClient, times(NUMBER_INVOCATIONS)).create(any());
}
Expand All @@ -225,7 +225,7 @@ void shouldInstanceImageServiceOnlyOnceWhenItIsCalledSeveralTimes() {
when(cleverClient.create(any()))
.thenReturn(ReflectUtil.createProxy(
OpenAI.Images.class,
new HttpProcessor(null, null, null)));
HttpProcessor.builder().build()));
repeat(NUMBER_CALLINGS, () -> openAI.images());
verify(cleverClient, times(NUMBER_INVOCATIONS)).create(any());
}
Expand All @@ -235,7 +235,7 @@ void shouldInstanceModelsServiceOnlyOnceWhenItIsCalledSeveralTimes() {
when(cleverClient.create(any()))
.thenReturn(ReflectUtil.createProxy(
OpenAI.Models.class,
new HttpProcessor(null, null, null)));
HttpProcessor.builder().build()));
repeat(NUMBER_CALLINGS, () -> openAI.models());
verify(cleverClient, times(NUMBER_INVOCATIONS)).create(any());
}
Expand All @@ -245,7 +245,7 @@ void shouldInstanceModerationServiceOnlyOnceWhenItIsCalledSeveralTimes() {
when(cleverClient.create(any()))
.thenReturn(ReflectUtil.createProxy(
OpenAI.Moderations.class,
new HttpProcessor(null, null, null)));
HttpProcessor.builder().build()));
repeat(NUMBER_CALLINGS, () -> openAI.moderations());
verify(cleverClient, times(NUMBER_INVOCATIONS)).create(any());
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
package io.github.sashirestela.openai.support;

import static io.github.sashirestela.openai.support.JsonSchemaUtil.JSON_EMPTY_CLASS;
import static org.junit.jupiter.api.Assertions.assertEquals;

import org.junit.jupiter.api.Test;

import com.fasterxml.jackson.annotation.JsonProperty;

import io.github.sashirestela.cleverclient.util.Constant;
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.NoArgsConstructor;
import org.junit.jupiter.api.Test;

class JsonSchemaUtilTest {

Expand All @@ -24,7 +22,7 @@ void shouldGenerateFullJsonSchemaWhenClassHasSomeFields() {
@Test
void shouldGenerateEmptyJsonSchemaWhenClassHasNoFields() {
var actualJsonSchema = JsonSchemaUtil.classToJsonSchema(EmptyClass.class).toString();
var expectedJsonSchema = Constant.JSON_EMPTY_CLASS;
var expectedJsonSchema = JSON_EMPTY_CLASS;
assertEquals(expectedJsonSchema, actualJsonSchema);
}

Expand Down
Loading