Skip to content

Commit

Permalink
refactor(server): enable use of StateGraph
Browse files Browse the repository at this point in the history
work on #24
  • Loading branch information
bsorrentino committed Sep 11, 2024
1 parent 39da1f4 commit 16aefea
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 57 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import static java.util.Optional.ofNullable;
import static org.bsc.langgraph4j.StateGraph.END;
import static org.bsc.langgraph4j.action.AsyncEdgeAction.edge_async;
import static org.bsc.langgraph4j.utils.CollectionsUtils.last;
import static org.bsc.langgraph4j.utils.CollectionsUtils.mapOf;

@Slf4j( topic="DiagramCorrectionProcess" )
Expand Down Expand Up @@ -44,7 +45,7 @@ CompletableFuture<Map<String,Object>> reviewResult(State state) {
CompletableFuture<Map<String,Object>> future = new CompletableFuture<>();
try {

var diagramCode = state.diagramCode().last()
var diagramCode = last( state.diagramCode() )
.orElseThrow(() -> new IllegalArgumentException("no diagram code provided!"));

var error = state.evaluationError()
Expand All @@ -71,7 +72,7 @@ CompletableFuture<Map<String,Object>> reviewResult(State state) {

private CompletableFuture<Map<String,Object>> evaluateResult(State state) {

var diagramCode = state.diagramCode().last()
var diagramCode = last( state.diagramCode() )
.orElseThrow(() -> new IllegalArgumentException("no diagram code provided!"));

return PlantUMLAction.validate( diagramCode )
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,29 +6,31 @@
import org.bsc.async.AsyncGenerator;
import org.bsc.langgraph4j.NodeOutput;
import org.bsc.langgraph4j.state.AgentState;
import org.bsc.langgraph4j.state.AppendableValue;
import org.bsc.langgraph4j.state.AppenderChannel;
import org.bsc.langgraph4j.state.Channel;

import java.io.BufferedReader;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.util.Map;
import java.util.Optional;
import java.util.*;

import static java.util.Optional.ofNullable;
import static org.bsc.langgraph4j.utils.CollectionsUtils.*;

public interface ImageToDiagram {

class State extends AgentState {

static Map<String, Channel<?>> SCHEMA = mapOf(
"messages", AppenderChannel.<String>of(ArrayList::new)
);
public State(Map<String, Object> initData) {
super(initData);
}

public Optional<Diagram.Element> diagram() {
return value("diagram");
}
public AppendableValue<String> diagramCode() {
return appendableValue("diagramCode");
public List<String> diagramCode() {
return this.<List<String>>value("diagramCode").orElseGet(Collections::emptyList);
}
public Optional<ImageToDiagramProcess.EvaluationResult> evaluationResult() {
return value("evaluationResult" );
Expand All @@ -49,10 +51,10 @@ public boolean isExecutionError() {
public boolean lastTwoDiagramsAreEqual() {
if( diagramCode().size() < 2 ) return false;

String last = diagramCode().last()
String last = last( diagramCode() )
.map(String::trim)
.orElseThrow( () -> new IllegalStateException( "last() is null!" ) );
String prev = diagramCode().lastMinus(1)
String prev = lastMinus( diagramCode(), 1)
.map(String::trim)
.orElseThrow( () -> new IllegalStateException( "last(-1) is null!" ) );

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

import static java.lang.String.format;
import static java.util.Optional.ofNullable;
import static org.bsc.langgraph4j.utils.CollectionsUtils.last;
import static org.bsc.langgraph4j.utils.CollectionsUtils.mapOf;
import static org.junit.jupiter.api.Assertions.*;
import static org.junit.jupiter.api.Assertions.assertEquals;
Expand Down Expand Up @@ -136,7 +137,7 @@ public void imageToDiagram() throws Exception {
}

System.out.println( ofNullable(state)
.flatMap( s -> s.diagramCode().last() ).orElse("NO DIAGRAM CODE") );
.flatMap( s -> last( s.diagramCode() ) ).orElse("NO DIAGRAM CODE") );

}

Expand All @@ -160,7 +161,7 @@ public String reviewDiagram( String diagramId ) throws Exception {
})
.join();

var code = result.diagramCode().last();
var code = last( result.diagramCode() );
assertTrue( code.isPresent() );
assertEquals( expectedCode, code.get().trim() );

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
* of LangGraph.
* Implementations of this interface can be used to create a web server
* that exposes an API for interacting with compiled language graphs.
*/
*/
public interface LangGraphStreamingServer {

Logger log = LoggerFactory.getLogger(LangGraphStreamingServer.class);
Expand All @@ -45,7 +45,7 @@ static Builder builder() {

class Builder {
private int port = 8080;
private Map<String, ArgumentMetadata> inputArgs = new HashMap<>();
private final Map<String, ArgumentMetadata> inputArgs = new HashMap<>();
private String title = null;
private ObjectMapper objectMapper;

Expand Down Expand Up @@ -74,36 +74,39 @@ public Builder addInputStringArg(String name) {
return this;
}

public <State extends AgentState> LangGraphStreamingServer build(CompiledGraph<State> compiledGraph) throws Exception {
public <State extends AgentState> LangGraphStreamingServer build(StateGraph<State> stateGraph) throws Exception {

Server server = new Server();

ServerConnector connector = new ServerConnector(server);
connector.setPort(port);
server.addConnector(connector);

ResourceHandler resourceHandler = new ResourceHandler();
var resourceHandler = new ResourceHandler();

// Path publicResourcesPath = Paths.get("jetty", "src", "main", "webapp");
// Resource baseResource = ResourceFactory.of(resourceHandler).newResource(publicResourcesPath));
Resource baseResource = ResourceFactory.of(resourceHandler).newClassLoaderResource("webapp");
var baseResource = ResourceFactory.of(resourceHandler).newClassLoaderResource("webapp");
resourceHandler.setBaseResource(baseResource);

resourceHandler.setDirAllowed(true);

ServletContextHandler context = new ServletContextHandler(ServletContextHandler.SESSIONS);
var context = new ServletContextHandler(ServletContextHandler.SESSIONS);

if (objectMapper == null) {
objectMapper = new ObjectMapper();
}

context.setSessionHandler(new org.eclipse.jetty.ee10.servlet.SessionHandler());

var initData = new InitData(title, inputArgs);
context.addServlet(new ServletHolder(new GraphInitServlet<>(stateGraph, initData)), "/init");

// context.setContextPath("/");
// Add the streaming servlet
context.addServlet(new ServletHolder(new GraphExecutionServlet<State>(compiledGraph, objectMapper)), "/stream");
context.addServlet(new ServletHolder(new GraphExecutionServlet<State>(stateGraph, objectMapper)), "/stream");

InitData initData = new InitData(title, inputArgs);
context.addServlet(new ServletHolder(new GraphInitServlet<State>(compiledGraph, initData)), "/init");

Handler.Sequence handlerList = new Handler.Sequence(resourceHandler, context);
var handlerList = new Handler.Sequence( resourceHandler, context);

server.setHandler(handlerList);

Expand All @@ -127,12 +130,12 @@ public CompletableFuture<Void> start() throws Exception {


class GraphExecutionServlet<State extends AgentState> extends HttpServlet {
final CompiledGraph<State> compiledGraph;
final StateGraph<State> stateGraph;
final ObjectMapper objectMapper;

public GraphExecutionServlet(CompiledGraph<State> compiledGraph, ObjectMapper objectMapper) {
Objects.requireNonNull(compiledGraph, "compiledGraph cannot be null");
this.compiledGraph = compiledGraph;
public GraphExecutionServlet(StateGraph<State> stateGraph, ObjectMapper objectMapper) {
Objects.requireNonNull(stateGraph, "stateGraph cannot be null");
this.stateGraph = stateGraph;
this.objectMapper = objectMapper;
}

Expand All @@ -151,6 +154,10 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response)
var asyncContext = request.startAsync();

try {
var config = CompileConfig.builder().build();

var compiledGraph = stateGraph.compile(config);

compiledGraph.stream(dataMap)
.forEachAsync(s -> {
try {
Expand Down Expand Up @@ -197,7 +204,7 @@ record InitData(
*/
class GraphInitServlet<State extends AgentState> extends HttpServlet {

final CompiledGraph<State> compiledGraph;
final StateGraph<State> stateGraph;
final ObjectMapper objectMapper = new ObjectMapper();
final InitData initData;

Expand All @@ -212,9 +219,9 @@ public Result(GraphRepresentation graph, InitData initData) {
}
}

public GraphInitServlet(CompiledGraph<State> compiledGraph, InitData initData) {
Objects.requireNonNull(compiledGraph, "compiledGraph cannot be null");
this.compiledGraph = compiledGraph;
public GraphInitServlet(StateGraph<State> stateGraph, InitData initData) {
Objects.requireNonNull(stateGraph, "stateGraph cannot be null");
this.stateGraph = stateGraph;
this.initData = initData;
}

Expand All @@ -223,7 +230,7 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response) t
response.setContentType("application/json");
response.setCharacterEncoding("UTF-8");

GraphRepresentation graph = compiledGraph.getGraph(GraphRepresentation.Type.MERMAID, initData.title(), false);
GraphRepresentation graph = stateGraph.getGraph(GraphRepresentation.Type.MERMAID, initData.title(), false);

final Result result = new Result(graph, initData);
String resultJson = objectMapper.writeValueAsString(result);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import org.bsc.langgraph4j.state.AgentState;

import static org.bsc.langgraph4j.StateGraph.END;
import static org.bsc.langgraph4j.StateGraph.START;
import static org.bsc.langgraph4j.action.AsyncEdgeAction.edge_async;
import static org.bsc.langgraph4j.action.AsyncNodeAction.node_async;
import static org.bsc.langgraph4j.utils.CollectionsUtils.mapOf;
Expand All @@ -12,26 +13,6 @@ public class LangGraphStreamingServerTest {


public static void main(String[] args) throws Exception {
StateGraph<AgentState> workflow = new StateGraph<>(AgentState::new);

workflow.setEntryPoint("agent_1");

workflow.addNode("agent_1", node_async((state ) -> {
System.out.println("agent_1 ");
System.out.println(state);
return mapOf("prop1", "value1");
}) ) ;

workflow.addNode("agent_2", node_async( state -> {

System.out.print( "agent_2: ");
System.out.println( state );
return mapOf("prop2", "value2");
}));

workflow.addEdge("agent_2", "agent_1" );


EdgeAction<AgentState> conditionalAge = new EdgeAction<>() {
int steps= 0;
@Override
Expand All @@ -44,16 +25,29 @@ public String apply(AgentState state) {
}
};

workflow.addConditionalEdges("agent_1",
edge_async(conditionalAge), mapOf( "a2", "agent_2", "end", END ) );

CompiledGraph<AgentState> app = workflow.compile();
StateGraph<AgentState> workflow = new StateGraph<>(AgentState::new)
.addNode("agent_1", node_async((state ) -> {
System.out.println("agent_1 ");
System.out.println(state);
return mapOf("prop1", "value1");
}) )
.addNode("agent_2", node_async( state -> {
System.out.print( "agent_2: ");
System.out.println( state );
return mapOf("prop2", "value2");
}))
.addEdge(START, "agent_1")
.addEdge("agent_2", "agent_1" )
.addConditionalEdges("agent_1",
edge_async(conditionalAge), mapOf( "a2", "agent_2", "end", END ) )
;

LangGraphStreamingServer server = LangGraphStreamingServer.builder()
.port(8080)
.title("LANGGRAPH4j - TEST")
.addInputStringArg("input")
.build(app);
.build(workflow);

server.start().join();

Expand Down

0 comments on commit 16aefea

Please sign in to comment.