Implement TCP server mode.
This new mode works by first loading the model and then listening for TCP
connections on a port. When a connection is accepted, arguments are parsed
using a simple protocol:

- First the number of arguments is read, followed by a newline
  character.
- Then each argument is read, terminated by a 0 byte.
- From these we build an argument vector, similar to the one passed to
  the program entry point, and hand it to gpt_params_parse.

Finally `llama_main` will be executed with the input/output streams
connected to the socket.

Signed-off-by: Thiago Padilha <[email protected]>
tarruda committed Mar 19, 2023
1 parent 9712739 commit 5e65c52
Showing 9 changed files with 332 additions and 2 deletions.
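
For illustration, the protocol can be exercised by hand from a shell. A minimal sketch, assuming a server already listening on the default port 8080; "-p" and "Hello" stand in for any argument list, and the bundled chat_tcp_client.sh below is the complete version:

# Open a TCP connection using bash's /dev/tcp pseudo-device
exec 3<>/dev/tcp/127.0.0.1/8080
# Send the argument count followed by a newline, then each argument
# terminated by a 0 byte
printf '2\n' >&3
printf '%s\0' -p 'Hello' >&3
# Generated text streams back on the same socket
cat <&3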
4 changes: 4 additions & 0 deletions CMakeLists.txt
@@ -112,6 +112,10 @@ add_executable(llama
llama.cpp
utils.h)

if(NOT WIN32)
target_sources(llama PRIVATE tcp_server.cpp)
endif()

add_executable(quantize
quantize.cpp
utils.cpp
7 changes: 5 additions & 2 deletions Makefile
@@ -191,11 +191,14 @@ utils.o: utils.cpp utils.h
llama.o: llama.cpp llama.h
$(CXX) $(CXXFLAGS) -c llama.cpp -o llama.o

tcp_server.o: tcp_server.cpp tcp_server.h
$(CXX) $(CXXFLAGS) -c tcp_server.cpp -o tcp_server.o

clean:
rm -f *.o main quantize

-main: main.cpp ggml.o utils.o llama.o
-$(CXX) $(CXXFLAGS) main.cpp ggml.o utils.o llama.o -o main $(LDFLAGS)
+main: main.cpp ggml.o utils.o llama.o tcp_server.o
+$(CXX) $(CXXFLAGS) main.cpp ggml.o utils.o llama.o tcp_server.o -o main $(LDFLAGS)
./main -h

quantize: quantize.cpp ggml.o utils.o
42 changes: 42 additions & 0 deletions chat_tcp_client.sh
@@ -0,0 +1,42 @@
#!/usr/bin/env bash

PORT=${PORT:-8080}
PROMPT="${PROMPT:-"Transcript of a dialog, where the User interacts with an Assistant named Bob. Bob is helpful, kind, honest, good at writing, and never fails to answer the User's requests immediately and with precision.
User: Hello, Bob.
Bob: Hello. How may I help you today?
User: Please tell me the largest city in Europe.
Bob: Sure. The largest city in Europe is Moscow, the capital of Russia.
User:"}"
RPROMPT="${RPROMPT:-"User:"}"
N_PREDICT="${N_PREDICT:-"4096"}"
REPEAT_PENALTY="${REPEAT_PENALTY:-"1.0"}"

# Open connection to the chat server
exec 3<>/dev/tcp/127.0.0.1/${PORT}

# Pass the arguments. The protocol is really simple:
# 1. Pass the number of arguments followed by a linefeed
# 2. Pass the arguments, each terminated by a 0 byte
(
echo -en "10\n"
echo -en "-n\x00"
echo -en "$N_PREDICT\x00"
echo -en "--repeat_penalty\x00"
echo -en "$REPEAT_PENALTY\x00"
echo -en "--color\x00"
echo -en "-i\x00"
echo -en "-r\x00"
echo -en "$RPROMPT\x00"
echo -en "-p\x00"
echo -en "$PROMPT\x00"
) >&3

trap exit TERM

# When we have passed the arguments, start printing socket data to the screen.
# This is done in a background job because we also want to send data when
# running in interactive mode.
cat <&3 && echo "(disconnected, press \"enter\" twice to exit)" &
cat >&3
wait
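
A hypothetical invocation, assuming the server from chat_tcp_server.sh below is already running:

PORT=8080 ./chat_tcp_client.sh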
6 changes: 6 additions & 0 deletions chat_tcp_server.sh
@@ -0,0 +1,6 @@
#!/usr/bin/env bash

PORT=${PORT:-8080}
MODEL=${MODEL:-models/7B/ggml-model-q4_0.bin}

./main -l ${PORT} -m $MODEL
7 changes: 7 additions & 0 deletions main.cpp
@@ -1,6 +1,7 @@
#include "ggml.h"
#include "utils.h"
#include "llama.h"
#include "tcp_server.h"

#include <iostream>

@@ -65,5 +66,11 @@ int main(int argc, char ** argv) {
params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
}

#ifndef _WIN32
if (params.listen_port != "") {
return listen_tcp(params, vocab, model, t_main_start_us, t_load_us);
}
#endif

return llama_main(params, vocab, model, t_main_start_us, t_load_us, std::cin, stdout, stderr);
}
245 changes: 245 additions & 0 deletions tcp_server.cpp
@@ -0,0 +1,245 @@
#include "tcp_server.h"

#include <iostream>

#include <stdarg.h>
#include <stdio.h>
#include <stdlib.h>
#include <stdbool.h>
#include <string.h>
#include <errno.h>

#include <signal.h>
#include <unistd.h>
#include <sys/wait.h>

#include <sys/types.h>
#include <sys/socket.h>
#include <netdb.h>

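// Minimal std::istream backed by a POSIX file descriptor, so llama_main can
// read interactive input directly from the client socket.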
class PosixStream : public std::istream {
public:
PosixStream(int fd) : std::istream(&buf), buf(fd) {}
~PosixStream() { close(buf.get_fd()); }

private:
class PosixStreamBuf : public std::streambuf {
public:
PosixStreamBuf(int fd) : fd(fd) {}
int get_fd() const { return fd; }

protected:
virtual int_type underflow() {
if (gptr() < egptr()) {
return traits_type::to_int_type(*gptr());
}

ssize_t num_read = ::read(fd, buffer, BUFFER_SIZE);
if (num_read <= 0) {
return traits_type::eof();
}

setg(buffer, buffer, buffer + num_read);
return traits_type::to_int_type(*gptr());
}

private:
static const int BUFFER_SIZE = 1024;
int fd;
char buffer[BUFFER_SIZE];
};

PosixStreamBuf buf;
};

void die(const char *msg, ...)
{
va_list ap;

va_start(ap, msg);
vfprintf(stderr, msg, ap);
va_end(ap);
fputc('\n', stderr);
exit(1);
}

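// Read one NUL-terminated argument from the client stream into a growable
// buffer and return a heap-allocated copy of it.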
static char *read_argument(uint8_t **param_buf, size_t *param_buf_size, FILE *instream) {
bool done = false;
uint8_t *buf = *param_buf;
size_t bufsize = *param_buf_size;
size_t bufpos = 0;
while (!done) {
if (bufpos == bufsize) {
bufsize += 1024;
buf = (uint8_t *)realloc(buf, bufsize);
if (!buf) {
die("failed to allocate memory");
}
}

int c = fgetc(instream);
if (c == EOF) {
die("unexpected EOF client socket");
}
buf[bufpos++] = (uint8_t)c;
if (c == 0) {
// done reading argument
break;
}
}
*param_buf = buf;
*param_buf_size = bufsize;
return strdup((char *)buf);
}

static int read_arguments(int argc, char **argv, FILE *instream) {
int i = 1;
size_t param_buf_size = 0;
uint8_t *param_buf = nullptr;

for (i = 1; i < argc; i++) {
argv[i] = read_argument(&param_buf, &param_buf_size, instream);
}

free(param_buf);
return i;
}

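// Serve a single client connection: read the argument vector using the
// protocol described above, parse it as a command line, then run llama_main
// with input/output attached to the socket.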
static int serve_model(
gpt_params params,
gpt_vocab vocab,
llama_model model,
int64_t t_main_start_us,
int64_t t_load_us,
int sock_fd)
{
int argc;
char **argv;
FILE *instream = fdopen(sock_fd, "r");
FILE *outstream = fdopen(sock_fd, "w");
setvbuf(instream, NULL, _IONBF, 0);

// start by reading the parameter count
if (fscanf(instream, "%d\n", &argc) != 1) {
fprintf(outstream, "Error: First line must be character count\n");
fflush(outstream);
return 1;
}

argc += 1; // add one extra argument to emulate the program command line
argv = (char **)malloc(argc * sizeof *argv);
argv[0] = nullptr;
if (read_arguments(argc, argv, instream) != argc) {
fprintf(outstream, "Error: Failed to read arguments\n");
fflush(outstream);
return 1;
}

if (gpt_params_parse(argc, argv, params) == false) {
fprintf(outstream, "Error: Failed to parse parameters\n");
fflush(outstream);
return 1;
}

for (int i = 1; i < argc; i++) {
free(argv[i]);
}
free(argv);

PosixStream tcp_is(sock_fd);

return llama_main(params, vocab, model, t_main_start_us, t_load_us, tcp_is, outstream, outstream);
}

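// Accept loop: bind to the loopback address and fork one child process per
// incoming connection; each child serves a single session and then exits.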
int listen_tcp(
gpt_params params,
gpt_vocab vocab,
llama_model model,
int64_t t_main_start_us,
int64_t t_load_us) {
int listen_fd;
int status;
pid_t child;
struct addrinfo hints;
struct addrinfo *servinfo, *p;
int yes = 1;

memset(&hints, 0, sizeof hints);
hints.ai_family = AF_INET;
hints.ai_socktype = SOCK_STREAM;
hints.ai_flags = AI_PASSIVE;

// This should only ever listen on a loopback address. Access from outside
// should be proxied via nginx or similar software
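// (e.g. a hypothetical nginx stream proxy, if remote access were needed:
//      stream { server { listen 9090; proxy_pass 127.0.0.1:8080; } })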
status = getaddrinfo("127.0.0.1", params.listen_port.c_str(), &hints, &servinfo);
if (status) {
die("getaddrinfo error: %s", gai_strerror(status));
}

// bind to the first addrinfo we can from the getaddrinfo results
for (p = servinfo; p != NULL; p = p->ai_next) {
listen_fd = socket(p->ai_family, p->ai_socktype, p->ai_protocol);
if (listen_fd == -1) {
perror("server: socket");
continue;
}

if (setsockopt(listen_fd, SOL_SOCKET, SO_REUSEADDR, &yes, sizeof yes)) {
die("setsockopt error: %s", params.listen_port.c_str(), strerror(errno));
}

if (bind(listen_fd, p->ai_addr, p->ai_addrlen) == 0) {
break;
}

close(listen_fd);
perror("server: bind");
}

freeaddrinfo(servinfo);

if (p == NULL) {
die("failed to bind: %s", strerror(errno));
}

if (listen(listen_fd, 20)) {
die("listen error: %s", strerror(errno));
}
// Don't track child processes, so ignore SIGCHLD to prevent zombies
signal(SIGCHLD, SIG_IGN);

for (;;) {
struct sockaddr_in client_addr = {0};
socklen_t client_addr_len = 0;

int sock_fd = accept(listen_fd,
(struct sockaddr *)&client_addr,
&client_addr_len);
if (sock_fd < 0) {
fprintf(stderr, "accept error: %s\n", strerror(errno));
break;
}

child = fork();
if (child == 0) {
// close the listen_fd since we won't use it in the child
close(listen_fd);
int ret = serve_model(params, vocab, model, t_main_start_us, t_load_us, sock_fd);
close(sock_fd);
return ret;
} else {
// close the client since we won't use it in the server
close(sock_fd);
sock_fd = 0;
}
}
close(listen_fd);

// ignore SIGTERM since we'll send it to the group
signal(SIGTERM, SIG_IGN);
// tell children to exit
kill(0, SIGTERM);
// wait for children to terminate
wait(&status);
return 0;
}
11 changes: 11 additions & 0 deletions tcp_server.h
@@ -0,0 +1,11 @@
#pragma once

#include "utils.h"
#include "llama.h"

int listen_tcp(
gpt_params params,
gpt_vocab vocab,
llama_model model,
int64_t t_main_start_us,
int64_t t_load_us);
8 changes: 8 additions & 0 deletions utils.cpp
@@ -73,6 +73,10 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
params.antiprompt.push_back(argv[++i]);
} else if (arg == "--ignore-eos") {
params.ignore_eos = true;
#ifndef _WIN32
} else if (arg == "-l" || arg == "--listen") {
params.listen_port = argv[++i];
#endif
} else if (arg == "-h" || arg == "--help") {
gpt_print_usage(argc, argv, params);
exit(0);
@@ -118,6 +122,10 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
fprintf(stderr, " -b N, --batch_size N batch size for prompt processing (default: %d)\n", params.n_batch);
fprintf(stderr, " -m FNAME, --model FNAME\n");
fprintf(stderr, " model path (default: %s)\n", params.model.c_str());
#ifndef _WIN32
fprintf(stderr, " -l PORT, --listen PORT\n");
fprintf(stderr, " Run in TCP mode, listening on PORT\n");
#endif
fprintf(stderr, "\n");
}

4 changes: 4 additions & 0 deletions utils.h
@@ -40,6 +40,10 @@ struct gpt_params {
std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted
bool instruct = false; // instruction mode (used for Alpaca models)
bool ignore_eos = false; // do not stop generating after eos

#ifndef _WIN32
std::string listen_port = ""; // TCP port to listen on when running in server mode
#endif
};

bool gpt_params_parse(int argc, char ** argv, gpt_params & params);