Skip to content

Commit

Permalink
[Pal/lib] Add spinlocks to mbedTLS-specific SSL recv/send
Browse files Browse the repository at this point in the history
Linux-SGX PAL wraps all pipe/UNIX domain socket communication in
TLS sessions. Previously, Graphene assumed that only one thread
at a time accesses one TLS session (i.e., no multi-threading
support). This commit adds spinlocks to `lib_SSL*` functions to
support such (rare) multi-threading scenarios. Note that we cannot
rely on pthreads and/or mutexes so we use simple spinlocks.

This commit adds a corresponding LibOS test.
  • Loading branch information
Dmitrii Kuvaiskii authored and llly committed Mar 5, 2021
1 parent b3872c6 commit 469a3cd
Show file tree
Hide file tree
Showing 5 changed files with 119 additions and 9 deletions.
1 change: 1 addition & 0 deletions LibOS/shim/test/regression/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@
/multi_pthread_exitless
/openmp
/pipe
/pipe_multithread
/pipe_nonblocking
/pipe_ocloexec
/poll
Expand Down
20 changes: 11 additions & 9 deletions LibOS/shim/test/regression/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ c_executables = \
multi_pthread \
openmp \
pipe \
pipe_multithread \
pipe_nonblocking \
pipe_ocloexec \
poll \
Expand Down Expand Up @@ -129,26 +130,27 @@ extra_rules = \
include ../../../../Scripts/manifest.mk
include ../../../../Scripts/Makefile.Test

CFLAGS-bootstrap_static = -static
CFLAGS-abort_multithread = -pthread
CFLAGS-bootstrap_pie = -fPIC -pie
CFLAGS-bootstrap_static = -static
CFLAGS-debug = -g3
CFLAGS-debug_regs-x86_64 = -g3
CFLAGS-eventfd = -pthread
CFLAGS-exec_same = -pthread
CFLAGS-shared_object = -fPIC -pie
CFLAGS-syscall += -I$(PALDIR)/../include -I$(PALDIR)/host/$(PAL_HOST) -I$(PALDIR)/../include/arch/$(ARCH)/Linux
CFLAGS-openmp = -fopenmp
CFLAGS-multi_pthread = -pthread
CFLAGS-exit_group = -pthread
CFLAGS-abort_multithread = -pthread
CFLAGS-eventfd = -pthread
CFLAGS-futex_bitset = -pthread
CFLAGS-futex_requeue = -pthread
CFLAGS-futex_wake_op = -pthread
CFLAGS-multi_pthread = -pthread
CFLAGS-openmp = -fopenmp
CFLAGS-pipe_multithread = -pthread
CFLAGS-proc_common = -pthread
CFLAGS-spinlock += -I$(PALDIR)/../include/lib -I$(PALDIR)/../include/arch/$(ARCH) -pthread
CFLAGS-pthread_set_get_affinity += -pthread
CFLAGS-shared_object = -fPIC -pie
CFLAGS-sigaction_per_process += -pthread
CFLAGS-signal_multithread += -pthread
CFLAGS-pthread_set_get_affinity += -pthread
CFLAGS-spinlock += -I$(PALDIR)/../include/lib -I$(PALDIR)/../include/arch/$(ARCH) -pthread
CFLAGS-syscall += -I$(PALDIR)/../include -I$(PALDIR)/host/$(PAL_HOST) -I$(PALDIR)/../include/arch/$(ARCH)/Linux

CFLAGS-attestation += -I$(PALDIR)/../lib/crypto/mbedtls/crypto/include \
-I$(PALDIR)/host/Linux-SGX \
Expand Down
81 changes: 81 additions & 0 deletions LibOS/shim/test/regression/pipe_multithread.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
/* test creates two threads simulteneously writing on the same pipe */

#include <err.h>
#include <errno.h>
#include <pthread.h>
#include <stdint.h>
#include <stdio.h>
#include <sys/socket.h>
#include <sys/types.h>

#define ITERATIONS 100000

int fds[2];

static void* thread_run(void* arg) {
char c = (char)(uintptr_t)arg;
for (int i = 0; i < ITERATIONS; i++) {
ssize_t bytes = 0;
while (bytes < sizeof(c)) {
bytes = send(fds[1], &c, sizeof(c), /*flags=*/0);
if (bytes < 0) {
if (errno == EAGAIN || errno == EINTR)
continue;
err(1, "send");
}
}
}
return NULL;
}

int main(int argc, char** argv) {
int ret;
pthread_t threads[2];
char thread_ids[2] = {42, 24};
int thread_bytes[2] = {0, 0};

ret = socketpair(AF_UNIX, SOCK_STREAM, 0, fds);
if (ret) {
err(1, "socketpair");
}

ret = pthread_create(&threads[0], NULL, &thread_run, (void*)(uintptr_t)thread_ids[0]);
if (ret) {
errno = ret;
err(1, "pthread_create");
}

ret = pthread_create(&threads[1], NULL, &thread_run, (void*)(uintptr_t)thread_ids[1]);
if (ret) {
errno = ret;
err(1, "pthread_create");
}

for (int i = 0; i < 2 * ITERATIONS; i++) {
char c = 0;
ssize_t bytes = 0;
while (bytes < sizeof(c)) {
bytes = recv(fds[0], &c, sizeof(c), /*flags=*/0);
if (bytes < 0) {
if (errno == EAGAIN || errno == EINTR)
continue;
err(1, "recv");
}
}

if (c == thread_ids[0])
thread_bytes[0] += bytes;
else if (c == thread_ids[1])
thread_bytes[1] += bytes;
else
errx(1, "received unrecognized thread ID");
}

printf("received total bytes from threads: %d and %d\n", thread_bytes[0], thread_bytes[1]);

if (thread_bytes[0] != ITERATIONS || thread_bytes[1] != ITERATIONS)
errx(1, "received wrong number of bytes from threads");

puts("TEST OK");
return 0;
}
3 changes: 3 additions & 0 deletions Pal/include/lib/pal_crypto.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
#include <stdint.h>
#include <unistd.h>

#include "spinlock.h"

#define SHA256_DIGEST_LEN 32

#ifdef CRYPTO_USE_MBEDTLS
Expand Down Expand Up @@ -51,6 +53,7 @@ typedef struct {
ssize_t (*pal_recv_cb)(int fd, void* buf, size_t buf_size);
ssize_t (*pal_send_cb)(int fd, const void* buf, size_t buf_size);
int stream_fd;
spinlock_t lock;
} LIB_SSL_CONTEXT;

#endif /* CRYPTO_USE_MBEDTLS */
Expand Down
23 changes: 23 additions & 0 deletions Pal/lib/crypto/adapters/mbedtls_adapter.c
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "pal_debug.h"
#include "pal_error.h"
#include "rng-arch.h"
#include "spinlock.h"

int mbedtls_to_pal_error(int error) {
switch (error) {
Expand Down Expand Up @@ -380,6 +381,11 @@ static int recv_cb(void* ctx, uint8_t* buf, size_t buf_size) {
/* pal_recv_cb cannot receive more than 32-bit limit, trim buf_size to fit in 32-bit */
buf_size = INT_MAX;
}

/* NOTE: If two threads recv on the same SSL context simultaneously, one of them may block on
* recv() and the other will spin and burn CPU cycles. We consider "shared SSL context"
* a rare case and use simple spinlocks instead of mutexes. */
assert(spinlock_is_locked(&ssl_ctx->lock));
ssize_t ret = ssl_ctx->pal_recv_cb(fd, buf, buf_size);

if (ret < 0) {
Expand All @@ -403,7 +409,13 @@ static int send_cb(void* ctx, uint8_t const* buf, size_t buf_size) {
/* pal_send_cb cannot send more than 32-bit limit, trim buf_size to fit in 32-bit */
buf_size = INT_MAX;
}

/* NOTE: If two threads send on the same SSL context simultaneously, one of them may block on
* send() and the other will spin and burn CPU cycles. We consider "shared SSL context"
* a rare case and use simple spinlocks instead of mutexes. */
assert(spinlock_is_locked(&ssl_ctx->lock));
ssize_t ret = ssl_ctx->pal_send_cb(fd, buf, buf_size);

if (ret < 0) {
if (ret == -EINTR || ret == -EAGAIN || ret == -EWOULDBLOCK)
return MBEDTLS_ERR_SSL_WANT_WRITE;
Expand All @@ -430,6 +442,7 @@ int lib_SSLInit(LIB_SSL_CONTEXT* ssl_ctx, int stream_fd, bool is_server, const u
ssl_ctx->pal_recv_cb = pal_recv_cb;
ssl_ctx->pal_send_cb = pal_send_cb;
ssl_ctx->stream_fd = stream_fd;
spinlock_init(&ssl_ctx->lock);

mbedtls_entropy_init(&ssl_ctx->entropy);
mbedtls_ctr_drbg_init(&ssl_ctx->ctr_drbg);
Expand Down Expand Up @@ -482,32 +495,42 @@ int lib_SSLFree(LIB_SSL_CONTEXT* ssl_ctx) {

int lib_SSLHandshake(LIB_SSL_CONTEXT* ssl_ctx) {
int ret;

spinlock_lock(&ssl_ctx->lock);
while ((ret = mbedtls_ssl_handshake(&ssl_ctx->ssl)) != 0) {
if (ret != MBEDTLS_ERR_SSL_WANT_READ && ret != MBEDTLS_ERR_SSL_WANT_WRITE)
break;
}
spinlock_unlock(&ssl_ctx->lock);

if (ret != 0)
return mbedtls_to_pal_error(ret);

return 0;
}

int lib_SSLRead(LIB_SSL_CONTEXT* ssl_ctx, uint8_t* buf, size_t buf_size) {
spinlock_lock(&ssl_ctx->lock);
int ret = mbedtls_ssl_read(&ssl_ctx->ssl, buf, buf_size);
spinlock_unlock(&ssl_ctx->lock);
if (ret < 0)
return mbedtls_to_pal_error(ret);
return ret;
}

int lib_SSLWrite(LIB_SSL_CONTEXT* ssl_ctx, const uint8_t* buf, size_t buf_size) {
spinlock_lock(&ssl_ctx->lock);
int ret = mbedtls_ssl_write(&ssl_ctx->ssl, buf, buf_size);
spinlock_unlock(&ssl_ctx->lock);
if (ret <= 0)
return mbedtls_to_pal_error(ret);
return ret;
}

int lib_SSLSave(LIB_SSL_CONTEXT* ssl_ctx, uint8_t* buf, size_t buf_size, size_t* out_size) {
spinlock_lock(&ssl_ctx->lock);
int ret = mbedtls_ssl_context_save(&ssl_ctx->ssl, buf, buf_size, out_size);
spinlock_unlock(&ssl_ctx->lock);
if (ret == MBEDTLS_ERR_SSL_BUFFER_TOO_SMALL) {
return -PAL_ERROR_NOMEM;
} else if (ret < 0) {
Expand Down

0 comments on commit 469a3cd

Please sign in to comment.