Skip to content

Commit

Permalink
refactor(core-jdk8): Agent State Management
Browse files Browse the repository at this point in the history
- AgentState from interface to concrete class
- AppendableValue a readonly interface
- Create internal AppendableValueRW to update state
  • Loading branch information
bsorrentino committed May 13, 2024
1 parent 0d7d09f commit 7e19f1e
Show file tree
Hide file tree
Showing 7 changed files with 126 additions and 76 deletions.
26 changes: 1 addition & 25 deletions core-jdk8/src/main/java/org/bsc/langgraph4j/GraphState.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,9 @@
import org.bsc.langgraph4j.action.AsyncNodeAction;
import org.bsc.langgraph4j.state.AgentState;
import org.bsc.langgraph4j.state.AgentStateFactory;
import org.bsc.langgraph4j.state.AppendableValue;

import java.util.*;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;

import static java.lang.String.format;
Expand Down Expand Up @@ -80,27 +77,6 @@ public class Runnable {
);
}

private Object mergeFunction(Object currentValue, Object newValue) {
if (currentValue instanceof AppendableValue<?> ) {
((AppendableValue<?>) currentValue).append( newValue );
return currentValue;
}
return newValue;
}
private State mergeState( State currentState, Map<String,Object> partialState) {
Objects.requireNonNull(currentState, "currentState");

if( partialState == null || partialState.isEmpty() ) {
return currentState;
}
var mergedMap = Stream.concat(currentState.data().entrySet().stream(), partialState.entrySet().stream())
.collect(Collectors.toMap(
Map.Entry::getKey,
Map.Entry::getValue,
this::mergeFunction));

return stateFactory.apply(mergedMap);
}

private String nextNodeId( String nodeId , State state ) throws Exception {

Expand Down Expand Up @@ -145,7 +121,7 @@ public AsyncGenerator<NodeOutput<State>> stream(Map<String,Object> inputs ) thro

partialState = action.apply(currentState).get();

currentState = mergeState(currentState, partialState);
currentState = currentState.mergeWith(partialState, stateFactory);

var data = new NodeOutput<>(currentNodeId, currentState);

Expand Down
65 changes: 57 additions & 8 deletions core-jdk8/src/main/java/org/bsc/langgraph4j/state/AgentState.java
Original file line number Diff line number Diff line change
@@ -1,21 +1,70 @@
package org.bsc.langgraph4j.state;

import java.util.List;
import java.util.Optional;
import lombok.var;

import java.util.*;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static java.util.Collections.unmodifiableMap;
import static java.util.Optional.ofNullable;

public interface AgentState {
public class AgentState {

private final java.util.Map<String,Object> data;

java.util.Map<String,Object> data();
public AgentState( Map<String,Object> initData ) {
this.data = new HashMap<>(initData);
}
public final java.util.Map<String,Object> data() {
return unmodifiableMap(data);
}

default <T> Optional<T> value(String key) {
public final <T> Optional<T> value(String key) {
return ofNullable((T) data().get(key));
};

default <T> Optional<List<T>> appendableValue(String key ) {
return ofNullable( ((AppendableValue<T>)data().get(key)))
.map(AppendableValue::values);
public final <T> AppendableValue<T> appendableValue(String key ) {
Object value = this.data.get(key);

if( value instanceof AppendableValue ) {
return (AppendableValue<T>) value;
}
if( value instanceof Collection) {
return new AppendableValueRW<>((Collection<T>)value);
}
AppendableValueRW<T> rw = new AppendableValueRW<>();
if ( value != null ) {
rw.append(value);
}
this.data.put(key, rw);
return rw;

}

private Object mergeFunction(Object currentValue, Object newValue) {
if (currentValue instanceof AppendableValueRW<?>) {
((AppendableValueRW<?>) currentValue).append( newValue );
return currentValue;
}
return newValue;
}
public <State extends AgentState> State mergeWith(Map<String,Object> partialState, AgentStateFactory<State> factory) {

if( partialState == null || partialState.isEmpty() ) {
return factory.apply(data());
}
var mergedMap = Stream.concat(data().entrySet().stream(), partialState.entrySet().stream())
.collect(Collectors.toMap(
Map.Entry::getKey,
Map.Entry::getValue,
this::mergeFunction));

return factory.apply(mergedMap);
}

@Override
public String toString() {
return data.toString();
}
}
Original file line number Diff line number Diff line change
@@ -1,34 +1,16 @@
package org.bsc.langgraph4j.state;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.*;

import static java.util.Collections.unmodifiableList;

public class AppendableValue<T> {
private final List<T> values;
public AppendableValue( List<T> values) {
this.values = new ArrayList<>(values);
}
public AppendableValue() {
this(Collections.emptyList());
}
public interface AppendableValue<T> {

public List<T> values() {
return unmodifiableList(values);
}
public void append(Object value) {
if (value instanceof Collection ) {
this.values.addAll((Collection<? extends T>) value);
}
else {
this.values.add((T)value);
}
}
List<T> values();

public String toString() {
return String.valueOf(values);
}
boolean isEmpty() ;
int size() ;

Optional<T> last() ;
Optional<T> lastMinus( int n ) ;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package org.bsc.langgraph4j.state;

import java.util.*;

import static java.util.Collections.unmodifiableList;

public class AppendableValueRW<T> implements AppendableValue<T> {
private final List<T> values;

public AppendableValueRW( Collection<T> values) {
this.values = new ArrayList<>(values);
}
public AppendableValueRW() {
this(Collections.emptyList());
}
public void append(Object value) {
if (value instanceof Collection) {
this.values.addAll((Collection<? extends T>) value);
}
else {
this.values.add((T)value);
}
}

public List<T> values() {
return unmodifiableList(values);
}

public boolean isEmpty() {
return values().isEmpty();
}
public int size() {
return values().size();
}
public Optional<T> last() {
List<T> values = values();
return ( values == null || values.isEmpty() ) ? Optional.empty() : Optional.of(values.get(values.size()-1));
}
public Optional<T> lastMinus( int n ) {
if( values == null || values.isEmpty() ) return Optional.empty();
if( n < 0 ) return Optional.empty();
if( values.size() - n - 1 < 0 ) return Optional.empty();
return Optional.of(values.get(values.size()-n-1));
}

public String toString() {
return String.valueOf(values);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,11 @@ public static <K,V> Map<K,V> mapOf( K k1, V v1, K k2, V v2 ) {
result.put(k2,v2);
return unmodifiableMap(result);
}
public static <K,V> Map<K,V> mapOf( K k1, V v1, K k2, V v2, K k3, V v3 ) {
Map<K,V> result = new HashMap<K,V>();
result.put(k1,v1);
result.put(k2,v2);
result.put(k3,v3);
return unmodifiableMap(result);
}
}
15 changes: 0 additions & 15 deletions core-jdk8/src/test/java/org/bsc/langgraph4j/BaseAgentState.java

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package org.bsc.langgraph4j;

import lombok.var;
import org.bsc.langgraph4j.state.AgentState;
import org.junit.jupiter.api.Test;

import java.util.List;
Expand All @@ -26,7 +27,7 @@ public static <T> List<Map.Entry<String,T>> sortMap(Map<String,T> map ) {
@Test
void testValidation() throws Exception {

var workflow = new GraphState<>(BaseAgentState::new);
var workflow = new GraphState<>(AgentState::new);
var exception = assertThrows(GraphStateException.class, workflow::compile);
System.out.println(exception.getMessage());
assertEquals( "missing Entry Point", exception.getMessage());
Expand Down Expand Up @@ -80,7 +81,7 @@ void testValidation() throws Exception {
@Test
public void testRunningOneNode() throws Exception {

var workflow = new GraphState<>(BaseAgentState::new);
var workflow = new GraphState<>(AgentState::new);
workflow.setEntryPoint("agent_1");

workflow.addNode("agent_1", node_async( state -> {
Expand Down

0 comments on commit 7e19f1e

Please sign in to comment.