From 7fbae4a8af00d518a72fd928159ab48bf4c64878 Mon Sep 17 00:00:00 2001 From: jesperpedersen Date: Mon, 20 Apr 2020 07:20:45 -0400 Subject: [PATCH] [#70] TLS support: pgagroal - PostgreSQL --- doc/ARCHITECTURE.md | 1 + doc/CONFIGURATION.md | 4 + doc/man/pgagroal.conf.5.rst | 13 + src/include/message.h | 8 + src/include/pgagroal.h | 22 +- src/include/pool.h | 10 +- src/include/security.h | 26 +- src/include/server.h | 4 +- src/include/worker.h | 1 + src/libpgagroal/configuration.c | 42 +++ src/libpgagroal/management.c | 3 +- src/libpgagroal/message.c | 24 ++ src/libpgagroal/pipeline_session.c | 19 +- src/libpgagroal/pool.c | 94 +++++-- src/libpgagroal/security.c | 434 ++++++++++++++++++++++++++--- src/libpgagroal/server.c | 6 +- src/libpgagroal/worker.c | 11 +- 17 files changed, 633 insertions(+), 89 deletions(-) diff --git a/doc/ARCHITECTURE.md b/doc/ARCHITECTURE.md index 52f80235..50f4b527 100644 --- a/doc/ARCHITECTURE.md +++ b/doc/ARCHITECTURE.md @@ -139,6 +139,7 @@ struct worker_io int server_fd; /* The server descriptor */ int slot; /* The slot */ SSL* client_ssl; /* The client SSL context */ + SSL* server_ssl; /* The server SSL context */ void* shmem; /* The shared memory segment */ void* pipeline_shmem; /* The shared memory segment for the pipeline */ }; diff --git a/doc/CONFIGURATION.md b/doc/CONFIGURATION.md index 51fd8547..74f383bd 100644 --- a/doc/CONFIGURATION.md +++ b/doc/CONFIGURATION.md @@ -64,6 +64,10 @@ __Danger zone__ | host | | String | Yes | The address of the PostgreSQL instance | | port | | Int | Yes | The port of the PostgreSQL instance | | primary | | Bool | No | Identify the instance as primary (hint) | +| tls | `off` | Bool | No | Enable Transport Layer Security (TLS) towards PostgreSQL | +| tls_cert_file | | String | No | Certificate file for TLS | +| tls_key_file | | String | No | Private key file for TLS | +| tls_ca_file | | String | No | Certificate Authority (CA) file for TLS | # pgagroal_hba configuration diff --git a/doc/man/pgagroal.conf.5.rst b/doc/man/pgagroal.conf.5.rst index 0ad39284..48a4cb66 100644 --- a/doc/man/pgagroal.conf.5.rst +++ b/doc/man/pgagroal.conf.5.rst @@ -127,6 +127,19 @@ port primary Identify the instance as the primary instance (hint) +tls + Enable Transport Layer Security (TLS). Default is false + +tls_cert_file + Certificate file for TLS + +tls_key_file + Private key file for TLS + +tls_ca_file + Certificate Authority (CA) file for TLS + + REPORTING BUGS ============== diff --git a/src/include/message.h b/src/include/message.h index 9856a7e4..4829bbbc 100644 --- a/src/include/message.h +++ b/src/include/message.h @@ -335,6 +335,14 @@ pgagroal_create_auth_scram256_final(char* ss, struct message** msg); int pgagroal_write_auth_success(SSL* ssl, int socket); +/** + * Create a SSL message + * @param msg The resulting message + * @return 0 upon success, otherwise 1 + */ +int +pgagroal_create_ssl_message(struct message** msg); + /** * Create a startup message * @param username The user name diff --git a/src/include/pgagroal.h b/src/include/pgagroal.h index ede82090..420583c4 100644 --- a/src/include/pgagroal.h +++ b/src/include/pgagroal.h @@ -53,9 +53,10 @@ extern "C" { #define ZF_LOG_LEVEL ZF_LOG_DEBUG #endif -#define MAX_BUFFER_SIZE 65535 -#define DEFAULT_BUFFER_SIZE 65535 -#define SECURITY_BUFFER_SIZE 512 +#define MAX_BUFFER_SIZE 65535 +#define DEFAULT_BUFFER_SIZE 65535 +#define SECURITY_BUFFER_SIZE 512 +#define SSL_SESSION_BUFFER_SIZE 1024 #define IDENTIFIER_LENGTH 64 #define MISC_LENGTH 128 @@ -120,10 +121,14 @@ extern "C" { */ struct server { - char name[MISC_LENGTH]; /**< The name of the server */ - char host[MISC_LENGTH]; /**< The host name of the server */ - int port; /**< The port of the server */ - int primary; /**< The status of the server */ + char name[MISC_LENGTH]; /**< The name of the server */ + char host[MISC_LENGTH]; /**< The host name of the server */ + int port; /**< The port of the server */ + int primary; /**< The status of the server */ + bool tls; /**< Is TLS enabled */ + char tls_cert_file[MISC_LENGTH]; /**< TLS certificate path */ + char tls_key_file[MISC_LENGTH]; /**< TLS key path */ + char tls_ca_file[MISC_LENGTH]; /**< TLS CA certificate path */ } __attribute__ ((aligned (64))); /** @struct @@ -141,6 +146,9 @@ struct connection ssize_t security_lengths[NUMBER_OF_SECURITY_MESSAGES]; /**< The lengths of the security messages */ char security_messages[NUMBER_OF_SECURITY_MESSAGES][SECURITY_BUFFER_SIZE]; /**< The security messages */ + int ssl_session_length; /**< The length of the SSL session */ + char ssl_session[SSL_SESSION_BUFFER_SIZE]; /**< The SSL session (ASN.1) */ + signed char limit_rule; /**< The limit rule used */ time_t timestamp; /**< The last used timestamp */ pid_t pid; /**< The associated process id */ diff --git a/src/include/pool.h b/src/include/pool.h index 2f448bc8..09f9cbdd 100644 --- a/src/include/pool.h +++ b/src/include/pool.h @@ -37,6 +37,7 @@ extern "C" { #include #include +#include /** * Get a connection @@ -45,28 +46,31 @@ extern "C" { * @param database The database * @param reuse Should a slot be reused * @param slot The resulting slot + * @param ssl The resulting SSL (can be NULL) * @return 0 upon success, 1 if pool is full, otherwise 2 */ int -pgagroal_get_connection(void* shmem, char* username, char* database, bool reuse, int* slot); +pgagroal_get_connection(void* shmem, char* username, char* database, bool reuse, int* slot, SSL** ssl); /** * Return a connection * @param shmem The shared memory segment * @param slot The slot + * @param ssl The SSL connection * @return 0 upon success, otherwise 1 */ int -pgagroal_return_connection(void* shmem, int slot); +pgagroal_return_connection(void* shmem, int slot, SSL* ssl); /** * Kill a connection * @param shmem The shared memory segment * @param slot The slot + * @param ssl The SSL connection * @return 0 upon success, otherwise 1 */ int -pgagroal_kill_connection(void* shmem, int slot); +pgagroal_kill_connection(void* shmem, int slot, SSL* ssl); /** * Perform idle timeout diff --git a/src/include/security.h b/src/include/security.h index 4732d21b..19e1b0ce 100644 --- a/src/include/security.h +++ b/src/include/security.h @@ -46,10 +46,11 @@ extern "C" { * @param shmem The shared memory segment * @param slot The resulting slot * @param client_ssl The client SSL context + * @param server_ssl The server SSL context * @return 0 upon success, otherwise 1 */ int -pgagroal_authenticate(int client_fd, char* address, void* shmem, int* slot, SSL** client_ssl); +pgagroal_authenticate(int client_fd, char* address, void* shmem, int* slot, SSL** client_ssl, SSL** server_ssl); /** * Authenticate a prefill connection @@ -58,10 +59,11 @@ pgagroal_authenticate(int client_fd, char* address, void* shmem, int* slot, SSL* * @param database The database * @param shmem The shared memory segment * @param slot The resulting slot + * @param server_ssl The resulting SSL context * @return 0 upon success, otherwise 1 */ int -pgagroal_prefill_auth(char* username, char* password, char* database, void* shmem, int* slot); +pgagroal_prefill_auth(char* username, char* password, char* database, void* shmem, int* slot, SSL** server_ssl); /** * Get the master key @@ -119,6 +121,26 @@ pgagroal_user_known(char* user, void* shmem); int pgagroal_tls_valid(void* shmem); +/** + * Load a SSL connection from a slot + * @param slot The slot + * @param shmem The shared memory segment + * @param ssl The resulting SSL connection (can be NULL) + * @return 0 upon success, otherwise 1 + */ +int +pgagroal_load_tls_connection(int slot, void* shmem, SSL** ssl); + +/** + * Save a TLS connection to a slot + * @param ssl The SSL connection + * @param slot The slot + * @param shmem The shared memory segment + * @return 0 upon success, otherwise 1 + */ +int +pgagroal_save_tls_connection(SSL* ssl, int slot, void* shmem); + #ifdef __cplusplus } #endif diff --git a/src/include/server.h b/src/include/server.h index e788354e..09828982 100644 --- a/src/include/server.h +++ b/src/include/server.h @@ -36,6 +36,7 @@ extern "C" { #include #include +#include /** * Get the primary server @@ -51,10 +52,11 @@ pgagroal_get_primary(void* shmem, int* server); * @param shmem The shared memory segment * @param slot The slot * @param socket The descriptor + * @param ssl The SSL connection * @return 0 upon success, otherwise 1 */ int -pgagroal_update_server_state(void* shmem, int slot, int socket); +pgagroal_update_server_state(void* shmem, int slot, int socket, SSL* ssl); /** * Print the state of the servers diff --git a/src/include/worker.h b/src/include/worker.h index ee8040e7..c2bbeddb 100644 --- a/src/include/worker.h +++ b/src/include/worker.h @@ -55,6 +55,7 @@ struct worker_io int server_fd; /**< The server descriptor */ int slot; /**< The slot */ SSL* client_ssl; /**< The client SSL context */ + SSL* server_ssl; /**< The server SSL context */ void* shmem; /**< The shared memory segment */ void* pipeline_shmem; /**< The shared memory segment for the pipeline */ }; diff --git a/src/libpgagroal/configuration.c b/src/libpgagroal/configuration.c index 99897aad..ef08485a 100644 --- a/src/libpgagroal/configuration.c +++ b/src/libpgagroal/configuration.c @@ -267,6 +267,10 @@ pgagroal_read_configuration(char* filename, void* shmem) { config->tls = as_bool(value); } + else if (strlen(section) > 0) + { + srv.tls = as_bool(value); + } else { unknown = true; @@ -281,6 +285,17 @@ pgagroal_read_configuration(char* filename, void* shmem) max = MISC_LENGTH - 1; memcpy(config->tls_ca_file, value, max); } + else if (strlen(section) > 0) + { + max = strlen(section); + if (max > MISC_LENGTH - 1) + max = MISC_LENGTH - 1; + memcpy(&srv.name, section, max); + max = strlen(value); + if (max > MISC_LENGTH - 1) + max = MISC_LENGTH - 1; + memcpy(&srv.tls_ca_file, value, max); + } else { unknown = true; @@ -295,6 +310,17 @@ pgagroal_read_configuration(char* filename, void* shmem) max = MISC_LENGTH - 1; memcpy(config->tls_cert_file, value, max); } + else if (strlen(section) > 0) + { + max = strlen(section); + if (max > MISC_LENGTH - 1) + max = MISC_LENGTH - 1; + memcpy(&srv.name, section, max); + max = strlen(value); + if (max > MISC_LENGTH - 1) + max = MISC_LENGTH - 1; + memcpy(&srv.tls_cert_file, value, max); + } else { unknown = true; @@ -309,6 +335,17 @@ pgagroal_read_configuration(char* filename, void* shmem) max = MISC_LENGTH - 1; memcpy(config->tls_key_file, value, max); } + else if (strlen(section) > 0) + { + max = strlen(section); + if (max > MISC_LENGTH - 1) + max = MISC_LENGTH - 1; + memcpy(&srv.name, section, max); + max = strlen(value); + if (max > MISC_LENGTH - 1) + max = MISC_LENGTH - 1; + memcpy(&srv.tls_key_file, value, max); + } else { unknown = true; @@ -660,6 +697,11 @@ pgagroal_validate_configuration(void* shmem) ZF_LOGF("pgagroal: No port defined for %s", config->servers[i].name); return 1; } + + if (config->servers[i].tls && (strlen(config->servers[i].tls_cert_file) > 0 || strlen(config->servers[i].tls_key_file) > 0)) + { + tls = true; + } } if (config->pipeline == PIPELINE_AUTO) diff --git a/src/libpgagroal/management.c b/src/libpgagroal/management.c index 213256f3..68ec5ac7 100644 --- a/src/libpgagroal/management.c +++ b/src/libpgagroal/management.c @@ -248,7 +248,8 @@ pgagroal_management_transfer_connection(void* shmem, int32_t slot) error: free(cmptr); pgagroal_disconnect(fd); - pgagroal_kill_connection(shmem, slot); + /* TODO */ + pgagroal_kill_connection(shmem, slot, NULL); return 1; } diff --git a/src/libpgagroal/message.c b/src/libpgagroal/message.c index 3cab7f28..8cb6e157 100644 --- a/src/libpgagroal/message.c +++ b/src/libpgagroal/message.c @@ -850,6 +850,30 @@ pgagroal_write_auth_success(SSL* ssl, int socket) return ssl_write_message(ssl, true, &msg); } +int +pgagroal_create_ssl_message(struct message** msg) +{ + struct message* m = NULL; + size_t size; + + size = 8; + + m = (struct message*)malloc(sizeof(struct message)); + m->data = malloc(size); + + memset(m->data, 0, size); + + m->kind = 0; + m->length = size; + + pgagroal_write_int32(m->data, size); + pgagroal_write_int32(m->data + 4, 80877103); + + *msg = m; + + return MESSAGE_STATUS_OK; +} + int pgagroal_create_startup_message(char* username, char* database, struct message** msg) { diff --git a/src/libpgagroal/pipeline_session.c b/src/libpgagroal/pipeline_session.c index f3172f75..9bef0997 100644 --- a/src/libpgagroal/pipeline_session.c +++ b/src/libpgagroal/pipeline_session.c @@ -212,7 +212,14 @@ session_client(struct ev_loop *loop, struct ev_io *watcher, int revents) { if (likely(msg->kind != 'X')) { - status = pgagroal_write_socket_message(wi->server_fd, msg); + if (wi->server_ssl == NULL) + { + status = pgagroal_write_socket_message(wi->server_fd, msg); + } + else + { + status = pgagroal_write_ssl_message(wi->server_ssl, msg); + } if (unlikely(status != MESSAGE_STATUS_OK)) { goto server_error; @@ -267,7 +274,15 @@ session_server(struct ev_loop *loop, struct ev_io *watcher, int revents) client_active(wi->slot, wi->pipeline_shmem); - status = pgagroal_read_socket_message(wi->server_fd, &msg); + if (wi->server_ssl == NULL) + { + status = pgagroal_read_socket_message(wi->server_fd, &msg); + } + else + { + status = pgagroal_read_ssl_message(wi->server_ssl, &msg); + } + if (likely(status == MESSAGE_STATUS_OK)) { if (wi->client_ssl == NULL) diff --git a/src/libpgagroal/pool.c b/src/libpgagroal/pool.c index 3c9a2a47..e8d92a3c 100644 --- a/src/libpgagroal/pool.c +++ b/src/libpgagroal/pool.c @@ -56,7 +56,7 @@ static void connection_details(void* shmem, int slot); static bool do_prefill(void* shmem, char* username, char* database, int size); int -pgagroal_get_connection(void* shmem, char* username, char* database, bool reuse, int* slot) +pgagroal_get_connection(void* shmem, char* username, char* database, bool reuse, int* slot, SSL** ssl) { bool prefill; bool do_init; @@ -83,6 +83,7 @@ pgagroal_get_connection(void* shmem, char* username, char* database, bool reuse, start: *slot = -1; + *ssl = NULL; do_init = false; has_lock = false; @@ -185,12 +186,18 @@ pgagroal_get_connection(void* shmem, char* username, char* database, bool reuse, } else { + SSL* s = NULL; bool kill = false; config->connections[*slot].limit_rule = best_rule; config->connections[*slot].pid = getpid(); config->connections[*slot].timestamp = time(NULL); + if (pgagroal_load_tls_connection(*slot, shmem, &s)) + { + kill = true; + } + /* Verify the socket for the slot */ if (!pgagroal_socket_isvalid(config->connections[*slot].fd)) { @@ -199,7 +206,7 @@ pgagroal_get_connection(void* shmem, char* username, char* database, bool reuse, if (!kill && config->validation == VALIDATION_FOREGROUND) { - kill = !pgagroal_connection_isvalid(config->connections[*slot].fd); + kill = !pgagroal_connection_isvalid(config->connections[*slot].fd); /* TODO */ } if (kill) @@ -207,7 +214,8 @@ pgagroal_get_connection(void* shmem, char* username, char* database, bool reuse, int status; ZF_LOGD("pgagroal_get_connection: Slot %d FD %d - Error", *slot, config->connections[*slot].fd); - status = pgagroal_kill_connection(shmem, *slot); + status = pgagroal_kill_connection(shmem, *slot, s); + s = NULL; prefill = true; if (status == 0) { @@ -218,6 +226,8 @@ pgagroal_get_connection(void* shmem, char* username, char* database, bool reuse, goto timeout; } } + + *ssl = s; } if (prefill) @@ -286,13 +296,15 @@ pgagroal_get_connection(void* shmem, char* username, char* database, bool reuse, } int -pgagroal_return_connection(void* shmem, int slot) +pgagroal_return_connection(void* shmem, int slot, SSL* ssl) { int state; struct configuration* config; config = (struct configuration*)shmem; + ZF_LOGI("pgagroal_return_connection: Slot %d", slot); + /* Verify the socket for the slot */ if (!pgagroal_socket_isvalid(config->connections[slot].fd)) { @@ -311,10 +323,27 @@ pgagroal_return_connection(void* shmem, int slot) /* Return the connection, if not GRACEFULLY */ if (state == STATE_IN_USE) { + SSL_CTX* ctx; + ZF_LOGD("pgagroal_return_connection: Slot %d FD %d", slot, config->connections[slot].fd); - pgagroal_write_deallocate_all(NULL, config->connections[slot].fd); - pgagroal_write_reset_all(NULL, config->connections[slot].fd); + ZF_LOGI("BEFORE CONNECTION RESET: Slot %d", slot); + pgagroal_write_deallocate_all(ssl, config->connections[slot].fd); + pgagroal_write_reset_all(ssl, config->connections[slot].fd); + ZF_LOGI("AFTER CONNECTION RESET: Slot %d", slot); + + if (pgagroal_save_tls_connection(ssl, slot, shmem)) + { + goto kill_connection; + } + + /* TODO - SSL_shutdown */ + if (ssl != NULL) + { + ctx = SSL_get_SSL_CTX(ssl); + SSL_free(ssl); + SSL_CTX_free(ctx); + } config->connections[slot].timestamp = time(NULL); @@ -339,22 +368,28 @@ pgagroal_return_connection(void* shmem, int slot) } else if (state == STATE_GRACEFULLY) { - pgagroal_write_terminate(NULL, config->connections[slot].fd); + pgagroal_write_terminate(ssl, config->connections[slot].fd); } } - return pgagroal_kill_connection(shmem, slot); +kill_connection: + + return pgagroal_kill_connection(shmem, slot, ssl); } int -pgagroal_kill_connection(void* shmem, int slot) +pgagroal_kill_connection(void* shmem, int slot, SSL* ssl) { + SSL_CTX* ctx; + int ssl_shutdown; int result = 0; int fd; struct configuration* config; config = (struct configuration*)shmem; + ZF_LOGI("pgagroal_kill_connection: Slot %d", slot); + ZF_LOGD("pgagroal_kill_connection: Slot %d FD %d State %d PID %d", slot, config->connections[slot].fd, atomic_load(&config->states[slot]), config->connections[slot].pid); @@ -363,6 +398,19 @@ pgagroal_kill_connection(void* shmem, int slot) if (fd != -1) { pgagroal_management_kill_connection(shmem, slot, fd); + + if (ssl != NULL) + { + ctx = SSL_get_SSL_CTX(ssl); + ssl_shutdown = SSL_shutdown(ssl); + if (ssl_shutdown == 0) + { + SSL_shutdown(ssl); + } + SSL_free(ssl); + SSL_CTX_free(ctx); + } + pgagroal_disconnect(fd); } else @@ -430,7 +478,7 @@ pgagroal_idle_timeout(void* shmem) double diff = difftime(now, config->connections[i].timestamp); if (diff >= (double)config->idle_timeout) { - pgagroal_kill_connection(shmem, i); + pgagroal_kill_connection(shmem, i, NULL); prefill = true; } else @@ -505,7 +553,7 @@ pgagroal_validation(void* shmem) if (kill) { - pgagroal_kill_connection(shmem, i); + pgagroal_kill_connection(shmem, i, NULL); /* TODO */ prefill = true; } else @@ -557,7 +605,7 @@ pgagroal_flush(void* shmem, int mode) { pgagroal_write_terminate(NULL, config->connections[i].fd); } - pgagroal_kill_connection(shmem, i); + pgagroal_kill_connection(shmem, i, NULL);/* TODO */ prefill = true; } else if (mode == FLUSH_ALL || mode == FLUSH_GRACEFULLY) @@ -567,7 +615,7 @@ pgagroal_flush(void* shmem, int mode) if (mode == FLUSH_ALL) { kill(config->connections[i].pid, SIGQUIT); - pgagroal_kill_connection(shmem, i); + pgagroal_kill_connection(shmem, i, NULL);/* TODO */ prefill = true; } else if (mode == FLUSH_GRACEFULLY) @@ -623,6 +671,7 @@ pgagroal_prefill(void* shmem, bool initial) if (strcmp("all", config->limits[i].database) && strcmp("all", config->limits[i].username)) { int user = -1; + SSL* server_ssl = NULL; for (int j = 0; j < config->number_of_users && user == -1; j++) { @@ -639,7 +688,7 @@ pgagroal_prefill(void* shmem, bool initial) int32_t slot = -1; if (pgagroal_prefill_auth(config->users[user].username, config->users[user].password, - config->limits[i].database, shmem, &slot) != AUTH_SUCCESS) + config->limits[i].database, shmem, &slot, &server_ssl) != AUTH_SUCCESS) { ZF_LOGW("Invalid data for user '%s' using limit entry (%d)", config->limits[i].username, i); @@ -649,10 +698,10 @@ pgagroal_prefill(void* shmem, bool initial) { if (pgagroal_socket_isvalid(config->connections[slot].fd)) { - pgagroal_write_terminate(NULL, config->connections[slot].fd); + pgagroal_write_terminate(server_ssl, config->connections[slot].fd); } } - pgagroal_kill_connection(shmem, slot); + pgagroal_kill_connection(shmem, slot, server_ssl); } break; @@ -662,7 +711,7 @@ pgagroal_prefill(void* shmem, bool initial) { if (config->connections[slot].has_security != SECURITY_INVALID) { - pgagroal_return_connection(shmem, slot); + pgagroal_return_connection(shmem, slot, server_ssl); } else { @@ -671,13 +720,15 @@ pgagroal_prefill(void* shmem, bool initial) { if (pgagroal_socket_isvalid(config->connections[slot].fd)) { - pgagroal_write_terminate(NULL, config->connections[slot].fd); + pgagroal_write_terminate(server_ssl, config->connections[slot].fd); } } - pgagroal_kill_connection(shmem, slot); + pgagroal_kill_connection(shmem, slot, server_ssl); break; } } + + server_ssl = NULL; } } else @@ -850,7 +901,7 @@ remove_connection(void* shmem, char* username, char* database) } else { - pgagroal_kill_connection(shmem, i); + pgagroal_kill_connection(shmem, i, NULL); } return true; @@ -906,6 +957,9 @@ connection_details(void* shmem, int slot) ZF_LOGV_MEM(&connection.security_messages[i], connection.security_lengths[i], " Message %p:", (const void *)&connection.security_messages[i]); } + ZF_LOGV(" Session length: %d", connection.ssl_session_length); + ZF_LOGV_MEM(&connection.ssl_session, connection.ssl_session_length, + " Session %p:", (const void *)&connection.ssl_session); break; case STATE_IN_USE: ZF_LOGD("pgagroal_pool_status: State: IN_USE"); diff --git a/src/libpgagroal/security.c b/src/libpgagroal/security.c index 1d8767a0..5bb94937 100644 --- a/src/libpgagroal/security.c +++ b/src/libpgagroal/security.c @@ -58,21 +58,21 @@ static int get_auth_type(struct message* msg, int* auth_type); static int compare_auth_response(struct message* orig, struct message* response, int auth_type); -static int use_pooled_connection(SSL* c_ssl, int client_fd, int slot, char* username, int hba_method, void* shmem); +static int use_pooled_connection(SSL* c_ssl, int client_fd, int slot, char* username, int hba_method, void* shmem, SSL** server_ssl); static int use_unpooled_connection(struct message* msg, SSL* c_ssl, int client_fd, int slot, - char* username, int hba_method, void* shmem); + char* username, int hba_method, void* shmem, SSL** server_ssl); static int client_trust(SSL* c_ssl, int client_fd, char* username, char* password, int slot, void* shmem); static int client_password(SSL* c_ssl, int client_fd, char* username, char* password, int slot, void* shmem); static int client_md5(SSL* c_ssl, int client_fd, char* username, char* password, int slot, void* shmem); static int client_scram256(SSL* c_ssl, int client_fd, char* username, char* password, int slot, void* shmem); static int client_ok(SSL* c_ssl, int client_fd, int slot, void* shmem); -static int server_passthrough(struct message* msg, int auth_type, SSL* c_ssl, int client_fd, int slot, void* shmem); +static int server_passthrough(struct message* msg, int auth_type, SSL* c_ssl, int client_fd, int slot, SSL* s_ssl, void* shmem); static int server_authenticate(struct message* msg, int auth_type, char* username, char* password, - int slot, void* shmem); + SSL* server_ssl, int slot, void* shmem); static int server_trust(int slot, void* shmem); -static int server_password(char* username, char* password, int slot, void* shmem); -static int server_md5(char* username, char* password, int slot, void* shmem); -static int server_scram256(char* username, char* password, int slot, void* shmem); +static int server_password(char* username, char* password, SSL* server_ssl, int slot, void* shmem); +static int server_md5(char* username, char* password, SSL* server_ssl, int slot, void* shmem); +static int server_scram256(char* username, char* password, SSL* server_ssl, int slot, void* shmem); static bool is_allowed(char* username, char* database, char* address, void* shmem, int* hba_method); static bool is_allowed_username(char* username, char* entry); @@ -105,10 +105,13 @@ static int server_signature(char* password, char* salt, int salt_length, int it static bool is_tls_user(char* username, char* database, void* shmem); static int create_ssl_ctx(bool client, SSL_CTX** ctx); +static int create_ssl_client(SSL_CTX* ctx, int socket, void* shmem, unsigned char server, SSL** ssl); static int create_ssl_server(SSL_CTX* ctx, int socket, void* shmem, SSL** ssl); +static int establish_client_tls_connection(int slot, void* shmem, SSL** ssl); +static int create_client_tls_connection(int slot, void* shmem, SSL** ssl); int -pgagroal_authenticate(int client_fd, char* address, void* shmem, int* slot, SSL** client_ssl) +pgagroal_authenticate(int client_fd, char* address, void* shmem, int* slot, SSL** client_ssl, SSL** server_ssl) { int status = MESSAGE_STATUS_ERROR; int ret; @@ -127,6 +130,7 @@ pgagroal_authenticate(int client_fd, char* address, void* shmem, int* slot, SSL* *slot = -1; *client_ssl = NULL; + *server_ssl = NULL; /* Receive client calls - at any point if client exits return AUTH_ERROR */ status = pgagroal_read_timeout_message(NULL, client_fd, config->authentication_timeout, &msg); @@ -302,7 +306,7 @@ pgagroal_authenticate(int client_fd, char* address, void* shmem, int* slot, SSL* } /* Get connection */ - ret = pgagroal_get_connection(shmem, username, database, true, slot); + ret = pgagroal_get_connection(shmem, username, database, true, slot, server_ssl); if (ret != 0) { if (ret == 1) @@ -321,12 +325,14 @@ pgagroal_authenticate(int client_fd, char* address, void* shmem, int* slot, SSL* goto bad_password; } + ZF_LOGI("Slot %d SECURITY %d", *slot, config->connections[*slot].has_security); + if (config->connections[*slot].has_security != SECURITY_INVALID) { ZF_LOGD("authenticate: getting pooled connection"); pgagroal_free_message(msg); - ret = use_pooled_connection(c_ssl, client_fd, *slot, username, hba_method, shmem); + ret = use_pooled_connection(c_ssl, client_fd, *slot, username, hba_method, shmem, server_ssl); if (ret == AUTH_BAD_PASSWORD) { goto bad_password; @@ -342,7 +348,7 @@ pgagroal_authenticate(int client_fd, char* address, void* shmem, int* slot, SSL* { ZF_LOGD("authenticate: creating pooled connection"); - ret = use_unpooled_connection(request_msg, c_ssl, client_fd, *slot, username, hba_method, shmem); + ret = use_unpooled_connection(request_msg, c_ssl, client_fd, *slot, username, hba_method, shmem, server_ssl); if (ret == AUTH_BAD_PASSWORD) { goto bad_password; @@ -396,8 +402,9 @@ pgagroal_authenticate(int client_fd, char* address, void* shmem, int* slot, SSL* } int -pgagroal_prefill_auth(char* username, char* password, char* database, void* shmem, int* slot) +pgagroal_prefill_auth(char* username, char* password, char* database, void* shmem, int* slot, SSL** server_ssl) { + SSL* s_ssl = NULL; int server_fd = -1; int auth_type = -1; struct configuration* config = NULL; @@ -409,26 +416,35 @@ pgagroal_prefill_auth(char* username, char* password, char* database, void* shme config = (struct configuration*)shmem; /* Get connection */ - ret = pgagroal_get_connection(shmem, username, database, false, slot); + ret = pgagroal_get_connection(shmem, username, database, false, slot, &s_ssl); if (ret != 0) { goto error; } server_fd = config->connections[*slot].fd; + /* Establish TLS if needed */ + if (config->servers[config->connections[*slot].server].tls) + { + if (establish_client_tls_connection(*slot, shmem, &s_ssl) != AUTH_SUCCESS) + { + goto error; + } + } + status = pgagroal_create_startup_message(username, database, &startup_msg); if (status != MESSAGE_STATUS_OK) { goto error; } - status = pgagroal_write_message(NULL, server_fd, startup_msg); + status = pgagroal_write_message(s_ssl, server_fd, startup_msg); if (status != MESSAGE_STATUS_OK) { goto error; } - status = pgagroal_read_block_message(NULL, server_fd, &msg); + status = pgagroal_read_block_message(s_ssl, server_fd, &msg); if (status != MESSAGE_STATUS_OK) { goto error; @@ -451,7 +467,7 @@ pgagroal_prefill_auth(char* username, char* password, char* database, void* shme goto error; } - if (server_authenticate(msg, auth_type, username, password, *slot, shmem)) + if (server_authenticate(msg, auth_type, username, password, s_ssl, *slot, shmem)) { goto error; } @@ -460,13 +476,15 @@ pgagroal_prefill_auth(char* username, char* password, char* database, void* shme config->servers[config->connections[*slot].server].primary == SERVER_NOTINIT_PRIMARY) { ZF_LOGD("Verify server mode: %d", config->connections[*slot].server); - pgagroal_update_server_state(shmem, *slot, server_fd); + pgagroal_update_server_state(shmem, *slot, server_fd, s_ssl); pgagroal_server_status(shmem); } ZF_LOGV("prefill_auth: has_security %d", config->connections[*slot].has_security); ZF_LOGD("prefill_auth: SUCCESS"); + *server_ssl = s_ssl; + pgagroal_free_copy_message(startup_msg); pgagroal_free_message(msg); @@ -478,10 +496,11 @@ pgagroal_prefill_auth(char* username, char* password, char* database, void* shme if (*slot != -1) { - pgagroal_kill_connection(shmem, *slot); + pgagroal_kill_connection(shmem, *slot, s_ssl); } *slot = -1; + *server_ssl = NULL; pgagroal_free_copy_message(startup_msg); pgagroal_free_message(msg); @@ -604,7 +623,7 @@ compare_auth_response(struct message* orig, struct message* response, int auth_t } static int -use_pooled_connection(SSL* c_ssl, int client_fd, int slot, char* username, int hba_method, void* shmem) +use_pooled_connection(SSL* c_ssl, int client_fd, int slot, char* username, int hba_method, void* shmem, SSL** server_ssl) { int status = MESSAGE_STATUS_ERROR; struct configuration* config = NULL; @@ -621,6 +640,21 @@ use_pooled_connection(SSL* c_ssl, int client_fd, int slot, char* username, int h hba_method = config->connections[slot].has_security; } + ZF_LOGI("BEGIN - use_pooled_connection: Slot %d", slot); + + /* TODO -- We got TLS information so reestablish connection */ + if (config->connections[slot].ssl_session_length > 0 && *server_ssl == NULL) + { + int status; + + ZF_LOGI("SSL session length: %d", config->connections[slot].ssl_session_length); + + status = create_client_tls_connection(slot, shmem, server_ssl); + + ZF_LOGI("status: %d", status); + ZF_LOGI("s_ssl: %p", *server_ssl); + } + if (password == NULL) { /* We can only deal with SECURITY_TRUST, SECURITY_PASSWORD and SECURITY_MD5 */ @@ -711,7 +745,9 @@ use_pooled_connection(SSL* c_ssl, int client_fd, int slot, char* username, int h else if (hba_method == SECURITY_SCRAM256) { /* R/10 */ + ZF_LOGI("SCRAM256: Slot %d", slot); status = client_scram256(c_ssl, client_fd, username, password, slot, shmem); + ZF_LOGI("SCRAM256: Slot %d Status %d", slot, status); if (status == AUTH_BAD_PASSWORD) { goto bad_password; @@ -732,16 +768,22 @@ use_pooled_connection(SSL* c_ssl, int client_fd, int slot, char* username, int h } } + ZF_LOGI("SUCCESS - use_pooled_connection: Slot %d", slot); + return AUTH_SUCCESS; bad_password: + ZF_LOGI("BAD_PASSWORD - use_pooled_connection: Slot %d", slot); + ZF_LOGV("use_pooled_connection: bad password for slot %d", slot); return AUTH_BAD_PASSWORD; error: + ZF_LOGI("ERROR - use_pooled_connection: Slot %d", slot); + ZF_LOGV("use_pooled_connection: failed for slot %d", slot); return AUTH_ERROR; @@ -749,12 +791,13 @@ use_pooled_connection(SSL* c_ssl, int client_fd, int slot, char* username, int h static int use_unpooled_connection(struct message* request_msg, SSL* c_ssl, int client_fd, int slot, - char* username, int hba_method, void* shmem) + char* username, int hba_method, void* shmem, SSL** server_ssl) { int status = MESSAGE_STATUS_ERROR; int server_fd; int auth_type = -1; char* password; + SSL* s_ssl = NULL; struct message* msg = NULL; struct message* auth_msg = NULL; struct configuration* config = NULL; @@ -773,9 +816,20 @@ use_unpooled_connection(struct message* request_msg, SSL* c_ssl, int client_fd, goto error; } + /* We may need a TLS connection to the server */ + if (config->servers[config->connections[slot].server].tls) + { + if (establish_client_tls_connection(slot, shmem, server_ssl)) + { + goto error; + } + } + + s_ssl = *server_ssl; + /* Send auth request to PostgreSQL */ ZF_LOGV("authenticate: client auth request (%d)", client_fd); - status = pgagroal_write_message(NULL, server_fd, request_msg); + status = pgagroal_write_message(s_ssl, server_fd, request_msg); if (status != MESSAGE_STATUS_OK) { goto error; @@ -784,7 +838,7 @@ use_unpooled_connection(struct message* request_msg, SSL* c_ssl, int client_fd, /* Keep response, and send response to client */ ZF_LOGV("authenticate: server auth request (%d)", server_fd); - status = pgagroal_read_block_message(NULL, server_fd, &msg); + status = pgagroal_read_block_message(s_ssl, server_fd, &msg); if (status != MESSAGE_STATUS_OK) { goto error; @@ -814,7 +868,7 @@ use_unpooled_connection(struct message* request_msg, SSL* c_ssl, int client_fd, if (password == NULL) { - if (server_passthrough(msg, auth_type, c_ssl, client_fd, slot, shmem)) + if (server_passthrough(msg, auth_type, c_ssl, client_fd, slot, s_ssl, shmem)) { goto error; } @@ -880,7 +934,7 @@ use_unpooled_connection(struct message* request_msg, SSL* c_ssl, int client_fd, goto error; } - if (server_authenticate(auth_msg, auth_type, username, password, slot, shmem)) + if (server_authenticate(auth_msg, auth_type, username, password, s_ssl, slot, shmem)) { if (pgagroal_socket_isvalid(client_fd)) { @@ -901,7 +955,7 @@ use_unpooled_connection(struct message* request_msg, SSL* c_ssl, int client_fd, config->servers[config->connections[slot].server].primary == SERVER_NOTINIT_PRIMARY) { ZF_LOGD("Verify server mode: %d", config->connections[slot].server); - pgagroal_update_server_state(shmem, slot, server_fd); + pgagroal_update_server_state(shmem, slot, server_fd, s_ssl); pgagroal_server_status(shmem); } @@ -1201,6 +1255,8 @@ client_scram256(SSL* c_ssl, int client_fd, char* username, char* password, int s } } + ZF_LOGI("SLOT: %d STATUS: %d", slot, status); + if (status != MESSAGE_STATUS_OK) { goto error; @@ -1407,7 +1463,7 @@ client_ok(SSL* c_ssl, int client_fd, int slot, void* shmem) } static int -server_passthrough(struct message* msg, int auth_type, SSL* c_ssl, int client_fd, int slot, void* shmem) +server_passthrough(struct message* msg, int auth_type, SSL* c_ssl, int client_fd, int slot, SSL* s_ssl, void* shmem) { int status = MESSAGE_STATUS_ERROR; int server_fd; @@ -1462,14 +1518,14 @@ server_passthrough(struct message* msg, int auth_type, SSL* c_ssl, int client_fd memcpy(&config->connections[slot].security_messages[auth_index], msg->data, msg->length); auth_index++; - status = pgagroal_write_message(NULL, server_fd, msg); + status = pgagroal_write_message(s_ssl, server_fd, msg); if (status != MESSAGE_STATUS_OK) { goto error; } pgagroal_free_message(msg); - status = pgagroal_read_block_message(NULL, server_fd, &msg); + status = pgagroal_read_block_message(s_ssl, server_fd, &msg); if (status != MESSAGE_STATUS_OK) { goto error; @@ -1510,14 +1566,14 @@ server_passthrough(struct message* msg, int auth_type, SSL* c_ssl, int client_fd memcpy(&config->connections[slot].security_messages[auth_index], msg->data, msg->length); auth_index++; - status = pgagroal_write_message(NULL, server_fd, msg); + status = pgagroal_write_message(s_ssl, server_fd, msg); if (status != MESSAGE_STATUS_OK) { goto error; } pgagroal_free_message(msg); - status = pgagroal_read_block_message(NULL, server_fd, &msg); + status = pgagroal_read_block_message(s_ssl, server_fd, &msg); if (status != MESSAGE_STATUS_OK) { goto error; @@ -1568,7 +1624,7 @@ server_passthrough(struct message* msg, int auth_type, SSL* c_ssl, int client_fd } static int -server_authenticate(struct message* msg, int auth_type, char* username, char* password, int slot, void* shmem) +server_authenticate(struct message* msg, int auth_type, char* username, char* password, SSL* server_ssl, int slot, void* shmem) { struct configuration* config = NULL; @@ -1594,15 +1650,15 @@ server_authenticate(struct message* msg, int auth_type, char* username, char* pa } else if (auth_type == SECURITY_PASSWORD) { - return server_password(username, password, slot, shmem); + return server_password(username, password, server_ssl, slot, shmem); } else if (auth_type == SECURITY_MD5) { - return server_md5(username, password, slot, shmem); + return server_md5(username, password, server_ssl, slot, shmem); } else if (auth_type == SECURITY_SCRAM256) { - return server_scram256(username, password, slot, shmem); + return server_scram256(username, password, server_ssl, slot, shmem); } error: @@ -1627,7 +1683,7 @@ server_trust(int slot, void* shmem) } static int -server_password(char* username, char* password, int slot, void* shmem) +server_password(char* username, char* password, SSL* server_ssl, int slot, void* shmem) { int status = MESSAGE_STATUS_ERROR; int auth_index = 1; @@ -1648,7 +1704,7 @@ server_password(char* username, char* password, int slot, void* shmem) goto error; } - status = pgagroal_write_message(NULL, server_fd, password_msg); + status = pgagroal_write_message(server_ssl, server_fd, password_msg); if (status != MESSAGE_STATUS_OK) { goto error; @@ -1658,7 +1714,7 @@ server_password(char* username, char* password, int slot, void* shmem) memcpy(&config->connections[slot].security_messages[auth_index], password_msg->data, password_msg->length); auth_index++; - status = pgagroal_read_block_message(NULL, server_fd, &auth_msg); + status = pgagroal_read_block_message(server_ssl, server_fd, &auth_msg); if (auth_msg->length > SECURITY_BUFFER_SIZE) { ZF_LOGE("Security message too large: %ld", auth_msg->length); @@ -1709,7 +1765,7 @@ server_password(char* username, char* password, int slot, void* shmem) } static int -server_md5(char* username, char* password, int slot, void* shmem) +server_md5(char* username, char* password, SSL* server_ssl, int slot, void* shmem) { int status = MESSAGE_STATUS_ERROR; int auth_index = 1; @@ -1766,7 +1822,7 @@ server_md5(char* username, char* password, int slot, void* shmem) goto error; } - status = pgagroal_write_message(NULL, server_fd, md5_msg); + status = pgagroal_write_message(server_ssl, server_fd, md5_msg); if (status != MESSAGE_STATUS_OK) { goto error; @@ -1776,7 +1832,7 @@ server_md5(char* username, char* password, int slot, void* shmem) memcpy(&config->connections[slot].security_messages[auth_index], md5_msg->data, md5_msg->length); auth_index++; - status = pgagroal_read_block_message(NULL, server_fd, &auth_msg); + status = pgagroal_read_block_message(server_ssl, server_fd, &auth_msg); if (auth_msg->length > SECURITY_BUFFER_SIZE) { ZF_LOGE("Security message too large: %ld", auth_msg->length); @@ -1845,7 +1901,7 @@ server_md5(char* username, char* password, int slot, void* shmem) } static int -server_scram256(char* username, char* password, int slot, void* shmem) +server_scram256(char* username, char* password, SSL* server_ssl, int slot, void* shmem) { int status = MESSAGE_STATUS_ERROR; int auth_index = 1; @@ -1899,13 +1955,13 @@ server_scram256(char* username, char* password, int slot, void* shmem) memcpy(&config->connections[slot].security_messages[auth_index], sasl_response->data, sasl_response->length); auth_index++; - status = pgagroal_write_message(NULL, server_fd, sasl_response); + status = pgagroal_write_message(server_ssl, server_fd, sasl_response); if (status != MESSAGE_STATUS_OK) { goto error; } - status = pgagroal_read_block_message(NULL, server_fd, &sasl_continue); + status = pgagroal_read_block_message(server_ssl, server_fd, &sasl_continue); if (sasl_continue->length > SECURITY_BUFFER_SIZE) { ZF_LOGE("Security message too large: %ld", sasl_continue->length); @@ -1959,13 +2015,13 @@ server_scram256(char* username, char* password, int slot, void* shmem) memcpy(&config->connections[slot].security_messages[auth_index], sasl_continue_response->data, sasl_continue_response->length); auth_index++; - status = pgagroal_write_message(NULL, server_fd, sasl_continue_response); + status = pgagroal_write_message(server_ssl, server_fd, sasl_continue_response); if (status != MESSAGE_STATUS_OK) { goto error; } - status = pgagroal_read_block_message(NULL, server_fd, &sasl_final); + status = pgagroal_read_block_message(server_ssl, server_fd, &sasl_final); if (sasl_final->length > SECURITY_BUFFER_SIZE) { ZF_LOGE("Security message too large: %ld", sasl_final->length); @@ -2568,6 +2624,57 @@ pgagroal_tls_valid(void* shmem) return 1; } +int +pgagroal_load_tls_connection(int slot, void* shmem, SSL** ssl) +{ + int result = 0; + struct configuration* config; + + config = (struct configuration*)shmem; + + if (config->connections[slot].ssl_session_length > 0) + { + result = create_client_tls_connection(slot, shmem, ssl); + } + + return result; +} + +int +pgagroal_save_tls_connection(SSL* ssl, int slot, void* shmem) +{ + int length; + SSL_SESSION* session = NULL; + unsigned char* p = NULL; + struct configuration* config; + + config = (struct configuration*)shmem; + + config->connections[slot].ssl_session_length = 0; + memset(&config->connections[slot].ssl_session, 0, SSL_SESSION_BUFFER_SIZE); + + if (ssl != NULL) + { + p = (unsigned char*)config->connections[slot].ssl_session; + + session = SSL_get_session(ssl); + + length = i2d_SSL_SESSION(session, NULL); + if (length > SSL_SESSION_BUFFER_SIZE) + { + goto error; + } + + config->connections[slot].ssl_session_length = i2d_SSL_SESSION(session, &p); + } + + return 0; + +error: + + return 1; +} + static int derive_key_iv(char *password, unsigned char *key, unsigned char *iv) { @@ -3392,7 +3499,8 @@ create_ssl_ctx(bool client, SSL_CTX** ctx) SSL_CTX_set_mode(c, SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER); SSL_CTX_set_options(c, SSL_OP_NO_TICKET); - SSL_CTX_set_session_cache_mode(c, SSL_SESS_CACHE_OFF); + SSL_CTX_set_session_cache_mode(c, SSL_SESS_CACHE_CLIENT | SSL_SESS_CACHE_NO_INTERNAL_STORE); + /* SSL_CTX_set_session_cache_mode(c, SSL_SESS_CACHE_OFF); */ *ctx = c; @@ -3408,6 +3516,86 @@ create_ssl_ctx(bool client, SSL_CTX** ctx) return 1; } +static int +create_ssl_client(SSL_CTX* ctx, int socket, void* shmem, unsigned char server, SSL** ssl) +{ + SSL* s = NULL; + bool have_cert = false; + bool have_rootcert = false; + struct configuration* config; + + config = (struct configuration*)shmem; + + if (strlen(config->servers[server].tls_ca_file) > 0) + { + if (SSL_CTX_load_verify_locations(ctx, config->servers[server].tls_ca_file, NULL) != 1) + { + ZF_LOGE("Couldn't load TLS CA: %s", config->servers[server].tls_ca_file); + goto error; + } + + have_rootcert = true; + } + + if (strlen(config->servers[server].tls_cert_file) > 0) + { + if (SSL_CTX_use_certificate_chain_file(ctx, config->servers[server].tls_cert_file) != 1) + { + ZF_LOGE("Couldn't load TLS certificate: %s", config->servers[server].tls_cert_file); + goto error; + } + + have_cert = true; + } + + s = SSL_new(ctx); + + if (s == NULL) + { + goto error; + } + + if (SSL_set_fd(s, socket) == 0) + { + goto error; + } + + if (have_cert && strlen(config->servers[server].tls_key_file) > 0) + { + if (SSL_use_PrivateKey_file(s, config->servers[server].tls_key_file, SSL_FILETYPE_PEM) != 1) + { + ZF_LOGE("Couldn't load TLS private key: %s", config->servers[server].tls_key_file); + goto error; + } + + if (SSL_check_private_key(s) != 1) + { + ZF_LOGE("TLS private key check failed: %s", config->servers[server].tls_key_file); + goto error; + } + } + + if (have_rootcert) + { + SSL_set_verify(s, SSL_VERIFY_PEER | SSL_VERIFY_CLIENT_ONCE, NULL); + } + + *ssl = s; + + return 0; + +error: + + if (s != NULL) + { + SSL_shutdown(s); + SSL_free(s); + } + SSL_CTX_free(ctx); + + return 1; +} + static int create_ssl_server(SSL_CTX* ctx, int socket, void* shmem, SSL** ssl) { @@ -3493,3 +3681,157 @@ create_ssl_server(SSL_CTX* ctx, int socket, void* shmem, SSL** ssl) return 1; } + +static int +establish_client_tls_connection(int slot, void* shmem, SSL** ssl) +{ + int fd = -1; + struct configuration* config = NULL; + struct message* ssl_msg = NULL; + struct message* msg = NULL; + int status = -1; + + config = (struct configuration*)shmem; + + fd = config->connections[slot].fd; + + status = pgagroal_create_ssl_message(&ssl_msg); + if (status != MESSAGE_STATUS_OK) + { + goto error; + } + + status = pgagroal_write_message(NULL, fd, ssl_msg); + if (status != MESSAGE_STATUS_OK) + { + goto error; + } + + status = pgagroal_read_block_message(NULL, fd, &msg); + if (status != MESSAGE_STATUS_OK) + { + goto error; + } + + if (msg->kind == 'S') + { + create_client_tls_connection(slot, shmem, ssl); + } + + pgagroal_free_copy_message(ssl_msg); + pgagroal_free_message(msg); + + return AUTH_SUCCESS; + +error: + + ZF_LOGD("establish_client_tls_connection: ERROR"); + + pgagroal_free_copy_message(ssl_msg); + pgagroal_free_message(msg); + + return AUTH_ERROR; +} + +static int +create_client_tls_connection(int slot, void* shmem, SSL** ssl) +{ + SSL_CTX* ctx = NULL; + SSL* s = NULL; + SSL_SESSION* session = NULL; + int fd = -1; + int status = -1; + unsigned char* p = NULL; + struct configuration* config; + + config = (struct configuration*)shmem; + + fd = config->connections[slot].fd; + + /* We are acting as a client against the server */ + if (create_ssl_ctx(true, &ctx)) + { + ZF_LOGE("CTX failed"); + goto error; + } + + /* Create SSL structure */ + if (create_ssl_client(ctx, fd, shmem, config->connections[slot].server, &s)) + { + ZF_LOGE("Client failed"); + goto error; + } + + /* If we have an existing session then load it */ + if (config->connections[slot].ssl_session_length > 0) + { + p = (unsigned char*)config->connections[slot].ssl_session; + + session = d2i_SSL_SESSION(NULL, (const unsigned char**)&p, config->connections[slot].ssl_session_length); + + ZF_LOGE("SSL: %p", s); + ZF_LOGE("SESSION: %p", session); + + if (session == NULL) + { + goto error; + } + + if (SSL_set_session(s, session) != 1) + { + goto error; + } + } + + do + { + status = SSL_connect(s); + /* status = SSL_do_handshake(s); */ + + if (status != 1) + { + int err = SSL_get_error(s, status); + switch (err) + { + case SSL_ERROR_ZERO_RETURN: + case SSL_ERROR_WANT_READ: + case SSL_ERROR_WANT_WRITE: + case SSL_ERROR_WANT_CONNECT: + case SSL_ERROR_WANT_ACCEPT: + case SSL_ERROR_WANT_X509_LOOKUP: + case SSL_ERROR_WANT_ASYNC: + case SSL_ERROR_WANT_ASYNC_JOB: + case SSL_ERROR_WANT_CLIENT_HELLO_CB: + break; + case SSL_ERROR_SYSCALL: + ZF_LOGE("SSL_ERROR_SYSCALL: %s (%d)", strerror(errno), fd); + errno = 0; + goto error; + break; + case SSL_ERROR_SSL: + ZF_LOGE("SSL_ERROR_SSL: %s (%d) Slot: %d", strerror(errno), fd, slot); + ZF_LOGE("%s", ERR_error_string(err, NULL)); + ZF_LOGE("%s", ERR_lib_error_string(err)); + ZF_LOGE("%s", ERR_reason_error_string(err)); + errno = 0; + goto error; + break; + } + ERR_clear_error(); + } + } while (status != 1); + + *ssl = s; + + ZF_LOGI("create_client_tls_connection: SUCCESS"); + + return AUTH_SUCCESS; + +error: + + ZF_LOGI("create_client_tls_connection: ERROR"); + + *ssl = s; + + return AUTH_ERROR; +} diff --git a/src/libpgagroal/server.c b/src/libpgagroal/server.c index 31b128cd..787bc69e 100644 --- a/src/libpgagroal/server.c +++ b/src/libpgagroal/server.c @@ -71,7 +71,7 @@ pgagroal_get_primary(void* shmem, int* server) } int -pgagroal_update_server_state(void* shmem, int slot, int socket) +pgagroal_update_server_state(void* shmem, int slot, int socket, SSL* ssl) { int status; int server; @@ -96,13 +96,13 @@ pgagroal_update_server_state(void* shmem, int slot, int socket) qmsg.length = size; qmsg.data = &is_recovery; - status = pgagroal_write_message(NULL, socket, &qmsg); + status = pgagroal_write_message(ssl, socket, &qmsg); if (status != MESSAGE_STATUS_OK) { goto error; } - status = pgagroal_read_block_message(NULL, socket, &tmsg); + status = pgagroal_read_block_message(ssl, socket, &tmsg); if (status != MESSAGE_STATUS_OK) { goto error; diff --git a/src/libpgagroal/worker.c b/src/libpgagroal/worker.c index b875c924..dc738835 100644 --- a/src/libpgagroal/worker.c +++ b/src/libpgagroal/worker.c @@ -67,6 +67,7 @@ pgagroal_worker(int client_fd, char* address, void* shmem, void* pipeline_shmem) struct pipeline p; int32_t slot = -1; SSL* client_ssl = NULL; + SSL* server_ssl = NULL; pgagroal_start_logging(shmem); pgagroal_memory_init(shmem); @@ -77,7 +78,7 @@ pgagroal_worker(int client_fd, char* address, void* shmem, void* pipeline_shmem) memset(&server_io, 0, sizeof(struct worker_io)); /* Authentication */ - auth_status = pgagroal_authenticate(client_fd, address, shmem, &slot, &client_ssl); + auth_status = pgagroal_authenticate(client_fd, address, shmem, &slot, &client_ssl, &server_ssl); if (auth_status == AUTH_SUCCESS) { ZF_LOGD("pgagroal_worker: Slot %d (%d -> %d)", slot, client_fd, config->connections[slot].fd); @@ -130,6 +131,7 @@ pgagroal_worker(int client_fd, char* address, void* shmem, void* pipeline_shmem) client_io.server_fd = config->connections[slot].fd; client_io.slot = slot; client_io.client_ssl = client_ssl; + client_io.server_ssl = server_ssl; client_io.shmem = shmem; client_io.pipeline_shmem = pipeline_shmem; @@ -138,6 +140,7 @@ pgagroal_worker(int client_fd, char* address, void* shmem, void* pipeline_shmem) server_io.server_fd = config->connections[slot].fd; server_io.slot = slot; server_io.client_ssl = client_ssl; + server_io.server_ssl = server_ssl; server_io.shmem = shmem; server_io.pipeline_shmem = pipeline_shmem; @@ -192,7 +195,7 @@ pgagroal_worker(int client_fd, char* address, void* shmem, void* pipeline_shmem) (exit_code == WORKER_SUCCESS || exit_code == WORKER_CLIENT_FAILURE || (exit_code == WORKER_FAILURE && config->connections[slot].has_security != SECURITY_INVALID))) { - pgagroal_return_connection(shmem, slot); + pgagroal_return_connection(shmem, slot, server_ssl); } else if (exit_code == WORKER_SERVER_FAILURE || exit_code == WORKER_SERVER_FATAL || exit_code == WORKER_SHUTDOWN || (exit_code == WORKER_FAILURE && config->connections[slot].has_security == SECURITY_INVALID)) @@ -205,11 +208,11 @@ pgagroal_worker(int client_fd, char* address, void* shmem, void* pipeline_shmem) pgagroal_connection_isvalid(config->connections[slot].fd) && config->connections[slot].has_security != SECURITY_INVALID) { - pgagroal_return_connection(shmem, slot); + pgagroal_return_connection(shmem, slot, server_ssl); } else { - pgagroal_kill_connection(shmem, slot); + pgagroal_kill_connection(shmem, slot, server_ssl); } } }