/** * @file ServerConnection.c * @author Ambroz Bizjak * * @section LICENSE * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * 1. Redistributions of source code must retain the above copyright * notice, this list of conditions and the following disclaimer. * 2. Redistributions in binary form must reproduce the above copyright * notice, this list of conditions and the following disclaimer in the * documentation and/or other materials provided with the distribution. * 3. Neither the name of the author nor the * names of its contributors may be used to endorse or promote products * derived from this software without specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE * DISCLAIMED. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. */ #include #include #include #include #include #include #include #include #define STATE_CONNECTING 1 #define STATE_WAITINIT 2 #define STATE_COMPLETE 3 static void report_error (ServerConnection *o); static void connector_handler (ServerConnection *o, int is_error); static void pending_handler (ServerConnection *o); static SECStatus client_auth_data_callback (ServerConnection *o, PRFileDesc *fd, CERTDistNames *caNames, CERTCertificate **pRetCert, SECKEYPrivateKey **pRetKey); static void connection_handler (ServerConnection *o, int event); static void sslcon_handler (ServerConnection *o, int event); static void decoder_handler_error (ServerConnection *o); static void input_handler_send (ServerConnection *o, uint8_t *data, int data_len); static void packet_hello (ServerConnection *o, uint8_t *data, int data_len); static void packet_newclient (ServerConnection *o, uint8_t *data, int data_len); static void packet_endclient (ServerConnection *o, uint8_t *data, int data_len); static void packet_inmsg (ServerConnection *o, uint8_t *data, int data_len); static int start_packet (ServerConnection *o, void **data, int len); static void end_packet (ServerConnection *o, uint8_t type); static void newclient_job_handler (ServerConnection *o); void report_error (ServerConnection *o) { DEBUGERROR(&o->d_err, o->handler_error(o->user)) } void connector_handler (ServerConnection *o, int is_error) { DebugObject_Access(&o->d_obj); ASSERT(o->state == STATE_CONNECTING) ASSERT(!o->buffers_released) // check connection attempt result if (is_error) { BLog(BLOG_ERROR, "connection failed"); goto fail0; } BLog(BLOG_NOTICE, "connected"); // init connection if (!BConnection_Init(&o->con, BConnection_source_connector(&o->connector), o->reactor, o, (BConnection_handler)connection_handler)) { BLog(BLOG_ERROR, "BConnection_Init failed"); goto fail0; } // init connection interfaces BConnection_SendAsync_Init(&o->con); BConnection_RecvAsync_Init(&o->con); StreamPassInterface *send_iface = BConnection_SendAsync_GetIf(&o->con); StreamRecvInterface *recv_iface = BConnection_RecvAsync_GetIf(&o->con); if (o->have_ssl) { // create bottom NSPR file descriptor if (!BSSLConnection_MakeBackend(&o->bottom_prfd, send_iface, recv_iface, o->twd, o->ssl_flags)) { BLog(BLOG_ERROR, "BSSLConnection_MakeBackend failed"); goto fail0a; } // create SSL file descriptor from the bottom NSPR file descriptor if (!(o->ssl_prfd = SSL_ImportFD(NULL, &o->bottom_prfd))) { BLog(BLOG_ERROR, "SSL_ImportFD failed"); ASSERT_FORCE(PR_Close(&o->bottom_prfd) == PR_SUCCESS) goto fail0a; } // set client mode if (SSL_ResetHandshake(o->ssl_prfd, PR_FALSE) != SECSuccess) { BLog(BLOG_ERROR, "SSL_ResetHandshake failed"); goto fail1; } // set server name if (SSL_SetURL(o->ssl_prfd, o->server_name) != SECSuccess) { BLog(BLOG_ERROR, "SSL_SetURL failed"); goto fail1; } // set client certificate callback if (SSL_GetClientAuthDataHook(o->ssl_prfd, (SSLGetClientAuthData)client_auth_data_callback, o) != SECSuccess) { BLog(BLOG_ERROR, "SSL_GetClientAuthDataHook failed"); goto fail1; } // init BSSLConnection BSSLConnection_Init(&o->sslcon, o->ssl_prfd, 0, BReactor_PendingGroup(o->reactor), o, (BSSLConnection_handler)sslcon_handler); send_iface = BSSLConnection_GetSendIf(&o->sslcon); recv_iface = BSSLConnection_GetRecvIf(&o->sslcon); } // init input chain PacketPassInterface_Init(&o->input_interface, SC_MAX_ENC, (PacketPassInterface_handler_send)input_handler_send, o, BReactor_PendingGroup(o->reactor)); if (!PacketProtoDecoder_Init(&o->input_decoder, recv_iface, &o->input_interface, BReactor_PendingGroup(o->reactor), o, (PacketProtoDecoder_handler_error)decoder_handler_error)) { BLog(BLOG_ERROR, "PacketProtoDecoder_Init failed"); goto fail2; } // set job to send hello // this needs to be in here because hello sending must be done after sending started (so we can write into the send buffer), // but before receiving started (so we don't get into conflict with the user sending packets) BPending_Init(&o->start_job, BReactor_PendingGroup(o->reactor), (BPending_handler)pending_handler, o); BPending_Set(&o->start_job); // init keepalive output branch SCKeepaliveSource_Init(&o->output_ka_zero, BReactor_PendingGroup(o->reactor)); PacketProtoEncoder_Init(&o->output_ka_encoder, SCKeepaliveSource_GetOutput(&o->output_ka_zero), BReactor_PendingGroup(o->reactor)); // init output common // init sender PacketStreamSender_Init(&o->output_sender, send_iface, PACKETPROTO_ENCLEN(SC_MAX_ENC), BReactor_PendingGroup(o->reactor)); // init keepalives if (!KeepaliveIO_Init(&o->output_keepaliveio, o->reactor, PacketStreamSender_GetInput(&o->output_sender), PacketProtoEncoder_GetOutput(&o->output_ka_encoder), o->keepalive_interval)) { BLog(BLOG_ERROR, "KeepaliveIO_Init failed"); goto fail3; } // init queue PacketPassPriorityQueue_Init(&o->output_queue, KeepaliveIO_GetInput(&o->output_keepaliveio), BReactor_PendingGroup(o->reactor), 0); // init output local flow // init queue flow PacketPassPriorityQueueFlow_Init(&o->output_local_qflow, &o->output_queue, 0); // init PacketProtoFlow if (!PacketProtoFlow_Init(&o->output_local_oflow, SC_MAX_ENC, o->buffer_size, PacketPassPriorityQueueFlow_GetInput(&o->output_local_qflow), BReactor_PendingGroup(o->reactor))) { BLog(BLOG_ERROR, "PacketProtoFlow_Init failed"); goto fail4; } o->output_local_if = PacketProtoFlow_GetInput(&o->output_local_oflow); // have no output packet o->output_local_packet_len = -1; // init output user flow PacketPassPriorityQueueFlow_Init(&o->output_user_qflow, &o->output_queue, 1); // update state o->state = STATE_WAITINIT; return; fail4: PacketPassPriorityQueueFlow_Free(&o->output_local_qflow); PacketPassPriorityQueue_Free(&o->output_queue); KeepaliveIO_Free(&o->output_keepaliveio); fail3: PacketStreamSender_Free(&o->output_sender); PacketProtoEncoder_Free(&o->output_ka_encoder); SCKeepaliveSource_Free(&o->output_ka_zero); BPending_Free(&o->start_job); PacketProtoDecoder_Free(&o->input_decoder); fail2: PacketPassInterface_Free(&o->input_interface); if (o->have_ssl) { BSSLConnection_Free(&o->sslcon); fail1: ASSERT_FORCE(PR_Close(o->ssl_prfd) == PR_SUCCESS) } fail0a: BConnection_RecvAsync_Free(&o->con); BConnection_SendAsync_Free(&o->con); BConnection_Free(&o->con); fail0: // report error report_error(o); } void pending_handler (ServerConnection *o) { ASSERT(o->state == STATE_WAITINIT) ASSERT(!o->buffers_released) DebugObject_Access(&o->d_obj); // send hello struct sc_client_hello omsg; void *packet; if (!start_packet(o, &packet, sizeof(omsg))) { BLog(BLOG_ERROR, "no buffer for hello"); report_error(o); return; } omsg.version = htol16(SC_VERSION); memcpy(packet, &omsg, sizeof(omsg)); end_packet(o, SCID_CLIENTHELLO); } SECStatus client_auth_data_callback (ServerConnection *o, PRFileDesc *fd, CERTDistNames *caNames, CERTCertificate **pRetCert, SECKEYPrivateKey **pRetKey) { ASSERT(o->have_ssl) DebugObject_Access(&o->d_obj); CERTCertificate *newcert; if (!(newcert = CERT_DupCertificate(o->client_cert))) { return SECFailure; } SECKEYPrivateKey *newkey; if (!(newkey = SECKEY_CopyPrivateKey(o->client_key))) { CERT_DestroyCertificate(newcert); return SECFailure; } *pRetCert = newcert; *pRetKey = newkey; return SECSuccess; } void connection_handler (ServerConnection *o, int event) { DebugObject_Access(&o->d_obj); ASSERT(o->state >= STATE_WAITINIT) ASSERT(!o->buffers_released) if (event == BCONNECTION_EVENT_RECVCLOSED) { BLog(BLOG_INFO, "connection closed"); } else { BLog(BLOG_INFO, "connection error"); } report_error(o); return; } void sslcon_handler (ServerConnection *o, int event) { DebugObject_Access(&o->d_obj); ASSERT(o->have_ssl) ASSERT(o->state >= STATE_WAITINIT) ASSERT(!o->buffers_released) ASSERT(event == BSSLCONNECTION_EVENT_ERROR) BLog(BLOG_ERROR, "SSL error"); report_error(o); return; } void decoder_handler_error (ServerConnection *o) { DebugObject_Access(&o->d_obj); ASSERT(o->state >= STATE_WAITINIT) ASSERT(!o->buffers_released) BLog(BLOG_ERROR, "decoder error"); report_error(o); return; } void input_handler_send (ServerConnection *o, uint8_t *data, int data_len) { ASSERT(o->state >= STATE_WAITINIT) ASSERT(!o->buffers_released) ASSERT(data_len >= 0) ASSERT(data_len <= SC_MAX_ENC) DebugObject_Access(&o->d_obj); // accept packet PacketPassInterface_Done(&o->input_interface); // parse header if (data_len < sizeof(struct sc_header)) { BLog(BLOG_ERROR, "packet too short (no sc header)"); report_error(o); return; } struct sc_header header; memcpy(&header, data, sizeof(header)); data += sizeof(header); data_len -= sizeof(header); uint8_t type = ltoh8(header.type); // call appropriate handler based on packet type switch (type) { case SCID_SERVERHELLO: packet_hello(o, data, data_len); return; case SCID_NEWCLIENT: packet_newclient(o, data, data_len); return; case SCID_ENDCLIENT: packet_endclient(o, data, data_len); return; case SCID_INMSG: packet_inmsg(o, data, data_len); return; default: BLog(BLOG_ERROR, "unknown packet type %d", (int)type); report_error(o); return; } } void packet_hello (ServerConnection *o, uint8_t *data, int data_len) { if (o->state != STATE_WAITINIT) { BLog(BLOG_ERROR, "hello: not expected"); report_error(o); return; } if (data_len != sizeof(struct sc_server_hello)) { BLog(BLOG_ERROR, "hello: invalid length"); report_error(o); return; } struct sc_server_hello msg; memcpy(&msg, data, sizeof(msg)); peerid_t id = ltoh16(msg.id); // change state o->state = STATE_COMPLETE; // report o->handler_ready(o->user, id, msg.clientAddr); return; } void packet_newclient (ServerConnection *o, uint8_t *data, int data_len) { if (o->state != STATE_COMPLETE) { BLog(BLOG_ERROR, "newclient: not expected"); report_error(o); return; } if (data_len < sizeof(struct sc_server_newclient) || data_len > sizeof(struct sc_server_newclient) + SCID_NEWCLIENT_MAX_CERT_LEN) { BLog(BLOG_ERROR, "newclient: invalid length"); report_error(o); return; } struct sc_server_newclient msg; memcpy(&msg, data, sizeof(msg)); peerid_t id = ltoh16(msg.id); // schedule reporting new client o->newclient_data = data; o->newclient_data_len = data_len; BPending_Set(&o->newclient_job); // send acceptpeer struct sc_client_acceptpeer omsg; void *packet; if (!start_packet(o, &packet, sizeof(omsg))) { BLog(BLOG_ERROR, "newclient: out of buffer for acceptpeer"); report_error(o); return; } omsg.clientid = htol16(id); memcpy(packet, &omsg, sizeof(omsg)); end_packet(o, SCID_ACCEPTPEER); } void packet_endclient (ServerConnection *o, uint8_t *data, int data_len) { if (o->state != STATE_COMPLETE) { BLog(BLOG_ERROR, "endclient: not expected"); report_error(o); return; } if (data_len != sizeof(struct sc_server_endclient)) { BLog(BLOG_ERROR, "endclient: invalid length"); report_error(o); return; } struct sc_server_endclient msg; memcpy(&msg, data, sizeof(msg)); peerid_t id = ltoh16(msg.id); // report o->handler_endclient(o->user, id); return; } void packet_inmsg (ServerConnection *o, uint8_t *data, int data_len) { if (o->state != STATE_COMPLETE) { BLog(BLOG_ERROR, "inmsg: not expected"); report_error(o); return; } if (data_len < sizeof(struct sc_server_inmsg)) { BLog(BLOG_ERROR, "inmsg: missing header"); report_error(o); return; } if (data_len > sizeof(struct sc_server_inmsg) + SC_MAX_MSGLEN) { BLog(BLOG_ERROR, "inmsg: too long"); report_error(o); return; } struct sc_server_inmsg msg; memcpy(&msg, data, sizeof(msg)); peerid_t peer_id = ltoh16(msg.clientid); uint8_t *payload = data + sizeof(struct sc_server_inmsg); int payload_len = data_len - sizeof(struct sc_server_inmsg); // report o->handler_message(o->user, peer_id, payload, payload_len); return; } int start_packet (ServerConnection *o, void **data, int len) { ASSERT(o->state >= STATE_WAITINIT) ASSERT(o->output_local_packet_len == -1) ASSERT(len >= 0) ASSERT(len <= SC_MAX_PAYLOAD) ASSERT(data || len == 0) // obtain memory location if (!BufferWriter_StartPacket(o->output_local_if, &o->output_local_packet)) { BLog(BLOG_ERROR, "out of buffer"); return 0; } o->output_local_packet_len = len; if (data) { *data = o->output_local_packet + sizeof(struct sc_header); } return 1; } void end_packet (ServerConnection *o, uint8_t type) { ASSERT(o->state >= STATE_WAITINIT) ASSERT(o->output_local_packet_len >= 0) ASSERT(o->output_local_packet_len <= SC_MAX_PAYLOAD) // write header struct sc_header header; header.type = htol8(type); memcpy(o->output_local_packet, &header, sizeof(header)); // finish writing packet BufferWriter_EndPacket(o->output_local_if, sizeof(struct sc_header) + o->output_local_packet_len); o->output_local_packet_len = -1; } int ServerConnection_Init ( ServerConnection *o, BReactor *reactor, BThreadWorkDispatcher *twd, BAddr addr, int keepalive_interval, int buffer_size, int have_ssl, int ssl_flags, CERTCertificate *client_cert, SECKEYPrivateKey *client_key, const char *server_name, void *user, ServerConnection_handler_error handler_error, ServerConnection_handler_ready handler_ready, ServerConnection_handler_newclient handler_newclient, ServerConnection_handler_endclient handler_endclient, ServerConnection_handler_message handler_message ) { ASSERT(keepalive_interval > 0) ASSERT(buffer_size > 0) ASSERT(have_ssl == 0 || have_ssl == 1) ASSERT(!have_ssl || server_name) // init arguments o->reactor = reactor; o->twd = twd; o->keepalive_interval = keepalive_interval; o->buffer_size = buffer_size; o->have_ssl = have_ssl; if (have_ssl) { o->ssl_flags = ssl_flags; o->client_cert = client_cert; o->client_key = client_key; } o->user = user; o->handler_error = handler_error; o->handler_ready = handler_ready; o->handler_newclient = handler_newclient; o->handler_endclient = handler_endclient; o->handler_message = handler_message; o->server_name = NULL; if (have_ssl && !(o->server_name = b_strdup(server_name))) { BLog(BLOG_ERROR, "malloc failed"); goto fail0; } if (!BConnection_AddressSupported(addr)) { BLog(BLOG_ERROR, "BConnection_AddressSupported failed"); goto fail1; } // init connector if (!BConnector_Init(&o->connector, addr, o->reactor, o, (BConnector_handler)connector_handler)) { BLog(BLOG_ERROR, "BConnector_Init failed"); goto fail1; } // init newclient job BPending_Init(&o->newclient_job, BReactor_PendingGroup(o->reactor), (BPending_handler)newclient_job_handler, o); // set state o->state = STATE_CONNECTING; o->buffers_released = 0; DebugError_Init(&o->d_err, BReactor_PendingGroup(o->reactor)); DebugObject_Init(&o->d_obj); return 1; fail1: free(o->server_name); fail0: return 0; } void ServerConnection_Free (ServerConnection *o) { DebugObject_Free(&o->d_obj); DebugError_Free(&o->d_err); if (o->state > STATE_CONNECTING) { // allow freeing queue flows PacketPassPriorityQueue_PrepareFree(&o->output_queue); // stop using any buffers before they get freed if (o->have_ssl && !o->buffers_released) { BSSLConnection_ReleaseBuffers(&o->sslcon); } // free output user flow PacketPassPriorityQueueFlow_Free(&o->output_user_qflow); // free output local flow PacketProtoFlow_Free(&o->output_local_oflow); PacketPassPriorityQueueFlow_Free(&o->output_local_qflow); // free output common PacketPassPriorityQueue_Free(&o->output_queue); KeepaliveIO_Free(&o->output_keepaliveio); PacketStreamSender_Free(&o->output_sender); // free output keep-alive branch PacketProtoEncoder_Free(&o->output_ka_encoder); SCKeepaliveSource_Free(&o->output_ka_zero); // free job BPending_Free(&o->start_job); // free input chain PacketProtoDecoder_Free(&o->input_decoder); PacketPassInterface_Free(&o->input_interface); // free SSL if (o->have_ssl) { BSSLConnection_Free(&o->sslcon); ASSERT_FORCE(PR_Close(o->ssl_prfd) == PR_SUCCESS) } // free connection interfaces BConnection_RecvAsync_Free(&o->con); BConnection_SendAsync_Free(&o->con); // free connection BConnection_Free(&o->con); } // free newclient job BPending_Free(&o->newclient_job); // free connector BConnector_Free(&o->connector); // free server name free(o->server_name); } void ServerConnection_ReleaseBuffers (ServerConnection *o) { DebugObject_Access(&o->d_obj); ASSERT(!o->buffers_released) if (o->state > STATE_CONNECTING && o->have_ssl) { BSSLConnection_ReleaseBuffers(&o->sslcon); } o->buffers_released = 1; } PacketPassInterface * ServerConnection_GetSendInterface (ServerConnection *o) { ASSERT(o->state == STATE_COMPLETE) DebugError_AssertNoError(&o->d_err); DebugObject_Access(&o->d_obj); return PacketPassPriorityQueueFlow_GetInput(&o->output_user_qflow); } void newclient_job_handler (ServerConnection *o) { DebugObject_Access(&o->d_obj); ASSERT(o->state == STATE_COMPLETE) struct sc_server_newclient msg; memcpy(&msg, o->newclient_data, sizeof(msg)); peerid_t id = ltoh16(msg.id); int flags = ltoh16(msg.flags); uint8_t *cert_data = o->newclient_data + sizeof(msg); int cert_len = o->newclient_data_len - sizeof(msg); // report new client o->handler_newclient(o->user, id, flags, cert_data, cert_len); return; }