Skip to content

Commit

Permalink
Call socket close() on Polaris_Disconnect().
Browse files Browse the repository at this point in the history
Merge pull request #33.
  • Loading branch information
adamshapiro0 authored Apr 6, 2021
2 parents 9ccd094 + 6fcfd42 commit d6b0a60
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 20 deletions.
3 changes: 3 additions & 0 deletions c/examples/simple_polaris_client.c
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ int main(int argc, const char* argv[]) {
P1_printf("Opened Polaris context. Authenticating...\n");

if (Polaris_Authenticate(&context, api_key, unique_id) != POLARIS_SUCCESS) {
Polaris_Free(&context);
return 3;
}

Expand All @@ -50,6 +51,7 @@ int main(int argc, const char* argv[]) {
Polaris_SetRTCMCallback(&context, HandleData, NULL);

if (Polaris_Connect(&context) != POLARIS_SUCCESS) {
Polaris_Free(&context);
return 3;
}

Expand All @@ -61,6 +63,7 @@ int main(int argc, const char* argv[]) {
if (Polaris_SendLLAPosition(&context, 37.773971, -122.430996, -0.02) !=
POLARIS_SUCCESS) {
Polaris_Disconnect(&context);
Polaris_Free(&context);
return 4;
}

Expand Down
46 changes: 26 additions & 20 deletions c/src/point_one/polaris/polaris.c
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ static int SendPOSTRequest(PolarisContext_t* context, const char* endpoint_url,

static int GetHTTPResponse(PolarisContext_t* context);

static void CloseSocket(PolarisContext_t* context);
static void CloseSocket(PolarisContext_t* context, int destroy_context);

#ifdef POLARIS_USE_TLS
static void ShowCerts(SSL* ssl);
Expand Down Expand Up @@ -104,7 +104,7 @@ int Polaris_Init(PolarisContext_t* context) {

/******************************************************************************/
void Polaris_Free(PolarisContext_t* context) {
CloseSocket(context);
CloseSocket(context, 1);
}

/******************************************************************************/
Expand Down Expand Up @@ -260,7 +260,7 @@ int Polaris_ConnectTo(PolarisContext_t* context, const char* endpoint_url,
#endif
if (ret != message_size) {
P1_PrintError("Error sending authentication token", ret);
CloseSocket(context);
CloseSocket(context, 1);
return POLARIS_SEND_ERROR;
}

Expand All @@ -276,6 +276,12 @@ void Polaris_Disconnect(PolarisContext_t* context) {
SSL_shutdown(context->ssl);
#endif
shutdown(context->socket, SHUT_RDWR);
// Note: We intentionally close the socket here but do _not_ destroy the SSL
// context since Polaris_Work() may be suspended on it and will segfault if
// the memory is freed. Polaris_Work() and Polaris_Run() will free it when
// they return. If they are not currently being called, the user should call
// Polaris_Free() to free the context.
CloseSocket(context, 0);
}
}

Expand Down Expand Up @@ -433,7 +439,7 @@ int Polaris_Work(PolarisContext_t* context) {
"Connection terminated. [ret=%d, errno=%d, "
"disconnected=%d]\n",
(int)bytes_read, errno, context->disconnected);
CloseSocket(context);
CloseSocket(context, 1);
if (context->disconnected) {
return 0;
} else {
Expand All @@ -447,7 +453,7 @@ int Polaris_Work(PolarisContext_t* context) {
P1_Print(
"Warning: Polaris connection closed and no data received. Is your "
"authentication token valid? Did you send a position?\n");
CloseSocket(context);
CloseSocket(context, 1);
return POLARIS_FORBIDDEN;
}
// Otherwise, there may just not be new data available (e.g., user hasn't
Expand Down Expand Up @@ -492,14 +498,14 @@ int Polaris_Run(PolarisContext_t* context, int connection_timeout_ms) {
if (ret < 0) {
// Connection closed remotely or another error occurred.
P1_DebugPrint("Connection terminated. [ret=%d]\n", ret);
CloseSocket(context);
CloseSocket(context, 1);
break;
} else if (ret == 0) {
// Did the user call disconnect?
if (context->disconnected) {
ret = POLARIS_SUCCESS;
P1_DebugPrint("Connection terminated.\n");
CloseSocket(context);
CloseSocket(context, 1);
break;
}
// Read timed out - see if we've hit the connection timeout. Otherwise,
Expand All @@ -510,7 +516,7 @@ int Polaris_Run(PolarisContext_t* context, int connection_timeout_ms) {
int elapsed_ms = P1_GetElapsedMS(&last_read_time, &current_time);
if (elapsed_ms >= connection_timeout_ms) {
P1_Print("Warning: Connection timed out after %d ms.\n", elapsed_ms);
CloseSocket(context);
CloseSocket(context, 1);
ret = POLARIS_TIMED_OUT;
break;
}
Expand All @@ -522,7 +528,7 @@ int Polaris_Run(PolarisContext_t* context, int connection_timeout_ms) {

if (context->disconnected) {
P1_DebugPrint("Connection terminated.\n");
CloseSocket(context);
CloseSocket(context, 1);
break;
}
}
Expand Down Expand Up @@ -563,7 +569,7 @@ static int OpenSocket(PolarisContext_t* context, const char* endpoint_url,
context->socket = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
if (context->socket < 0) {
P1_Print("Error opening socket.\n");
CloseSocket(context);
CloseSocket(context, 1);
return POLARIS_SOCKET_ERROR;
}

Expand All @@ -578,7 +584,7 @@ static int OpenSocket(PolarisContext_t* context, const char* endpoint_url,
P1_SocketAddrV4_t address;
if (P1_SetAddress(endpoint_url, endpoint_port, &address) < 0) {
P1_Print("Error locating address '%s'.\n", endpoint_url);
CloseSocket(context);
CloseSocket(context, 1);
return POLARIS_SOCKET_ERROR;
}

Expand All @@ -588,7 +594,7 @@ static int OpenSocket(PolarisContext_t* context, const char* endpoint_url,
connect(context->socket, (P1_SocketAddr_t*)&address, sizeof(address));
if (ret < 0) {
P1_PrintError("Error connecting to endpoint", ret);
CloseSocket(context);
CloseSocket(context, 1);
return POLARIS_SOCKET_ERROR;
}

Expand All @@ -604,7 +610,7 @@ static int OpenSocket(PolarisContext_t* context, const char* endpoint_url,
// Perform SSL handhshake.
if (SSL_connect(context->ssl) == -1) {
P1_Print("SSL handshake failed to %s:%d.\n", endpoint_url, endpoint_port);
CloseSocket(context);
CloseSocket(context, 1);
return POLARIS_ERROR;
}

Expand All @@ -617,9 +623,9 @@ static int OpenSocket(PolarisContext_t* context, const char* endpoint_url,
}

/******************************************************************************/
void CloseSocket(PolarisContext_t* context) {
void CloseSocket(PolarisContext_t* context, int destroy_context) {
#ifdef POLARIS_USE_TLS
if (context->ssl != NULL) {
if (destroy_context && context->ssl != NULL) {
if (SSL_get_shutdown(context->ssl) == 0) {
SSL_shutdown(context->ssl);
}
Expand All @@ -635,7 +641,7 @@ void CloseSocket(PolarisContext_t* context) {
}

#ifdef POLARIS_USE_TLS
if (context->ssl_ctx != NULL) {
if (destroy_context && context->ssl_ctx != NULL) {
SSL_CTX_free(context->ssl_ctx);
context->ssl_ctx = NULL;
}
Expand Down Expand Up @@ -683,7 +689,7 @@ static int SendPOSTRequest(PolarisContext_t* context, const char* endpoint_url,
// authentication, before data is coming in.
if (POLARIS_RECV_BUFFER_SIZE < header_size + content_length + 1) {
P1_Print("Error populating POST request: buffer too small.\n");
CloseSocket(context);
CloseSocket(context, 1);
return POLARIS_NOT_ENOUGH_SPACE;
}

Expand All @@ -699,7 +705,7 @@ static int SendPOSTRequest(PolarisContext_t* context, const char* endpoint_url,
if (header_size < 0) {
// This shouldn't happen.
P1_Print("Error populating POST request.\n");
CloseSocket(context);
CloseSocket(context, 1);
return POLARIS_ERROR;
}

Expand All @@ -722,7 +728,7 @@ static int SendPOSTRequest(PolarisContext_t* context, const char* endpoint_url,

if (ret != message_size) {
P1_PrintError("Error sending POST request", ret);
CloseSocket(context);
CloseSocket(context, 1);
return POLARIS_SEND_ERROR;
}

Expand Down Expand Up @@ -751,7 +757,7 @@ static int GetHTTPResponse(PolarisContext_t* context) {
}
}

CloseSocket(context);
CloseSocket(context, 1);
P1_DebugPrint("Received HTTP request. [size=%u B]\n", (unsigned)total_bytes);

// Append a null terminator to the response.
Expand Down

0 comments on commit d6b0a60

Please sign in to comment.