Commit a6eb48e4 authored by Philipp Schafft's avatar Philipp Schafft 🦁

Update: Abstracted TLS support in tls.c

parent 64b16f82
......@@ -695,11 +695,7 @@ static inline xmlNodePtr __add_listener(client_t *client,
if (client->role)
xmlNewTextChild(node, NULL, XMLSTR("role"), XMLSTR(client->role));
#ifdef HAVE_OPENSSL
xmlNewTextChild(node, NULL, XMLSTR("tls"), XMLSTR(client->con->ssl ? "true" : "false"));
#else
xmlNewTextChild(node, NULL, XMLSTR("tls"), XMLSTR("false"));
#endif
xmlNewTextChild(node, NULL, XMLSTR("tls"), XMLSTR(client->con->tls ? "true" : "false"));
return node;
}
......
......@@ -105,19 +105,17 @@ static inline void client_reuseconnection(client_t *client) {
client->con->sock = -1; /* TODO: do not use magic */
/* handle to keep the TLS connection */
#ifdef HAVE_OPENSSL
if (client->con->ssl) {
if (client->con->tls) {
/* AHhhggrr.. That pain....
* stealing SSL state...
* stealing TLS state...
*/
con->ssl = client->con->ssl;
con->tls = client->con->tls;
con->read = client->con->read;
con->send = client->con->send;
client->con->ssl = NULL;
client->con->tls = NULL;
client->con->read = NULL;
client->con->send = NULL;
}
#endif
client->reuse = ICECAST_REUSE_CLOSE;
......
......@@ -175,14 +175,11 @@ static void get_ssl_certificate(ice_config_t *config)
*/
static int connection_read_ssl(connection_t *con, void *buf, size_t len)
{
int bytes = SSL_read(con->ssl, buf, len);
ssize_t bytes = tls_read(con->tls, buf, len);
if (bytes < 0) {
switch (SSL_get_error(con->ssl, bytes)) {
case SSL_ERROR_WANT_READ:
case SSL_ERROR_WANT_WRITE:
if (tls_want_io(con->tls) > 0)
return -1;
}
con->error = 1;
}
return bytes;
......@@ -190,14 +187,11 @@ static int connection_read_ssl(connection_t *con, void *buf, size_t len)
static int connection_send_ssl(connection_t *con, const void *buf, size_t len)
{
int bytes = SSL_write (con->ssl, buf, len);
ssize_t bytes = tls_write(con->tls, buf, len);
if (bytes < 0) {
switch (SSL_get_error(con->ssl, bytes)){
case SSL_ERROR_WANT_READ:
case SSL_ERROR_WANT_WRITE:
return -1;
}
if (tls_want_io(con->tls) > 0)
return -1;
con->error = 1;
} else {
con->sent_bytes += bytes;
......@@ -263,14 +257,14 @@ connection_t *connection_create (sock_t sock, sock_t serversock, char *ip)
void connection_uses_ssl(connection_t *con)
{
#ifdef HAVE_OPENSSL
if (con->ssl)
if (con->tls)
return;
con->read = connection_read_ssl;
con->send = connection_send_ssl;
con->ssl = tls_ctx_SSL_new(tls_ctx);
SSL_set_accept_state(con->ssl);
SSL_set_fd(con->ssl, con->sock);
con->tls = tls_new(tls_ctx);
tls_set_incoming(con->tls);
tls_set_socket(con->tls, con->sock);
#endif
}
......@@ -1462,8 +1456,6 @@ void connection_close(connection_t *con)
sock_close(con->sock);
if (con->ip)
free(con->ip);
#ifdef HAVE_OPENSSL
if (con->ssl) { SSL_shutdown(con->ssl); SSL_free(con->ssl); }
#endif
tls_unref(con->tls);
free(con);
}
......@@ -16,10 +16,8 @@
#include <sys/types.h>
#include <time.h>
#ifdef HAVE_OPENSSL
#include <openssl/ssl.h>
#include <openssl/err.h>
#endif
#include "tls.h"
#include "compat.h"
#include "common/httpp/httpp.h"
......@@ -42,9 +40,7 @@ typedef struct connection_tag
sock_t serversock;
int error;
#ifdef HAVE_OPENSSL
SSL *ssl; /* SSL handler */
#endif
tls_t *tls;
int (*send)(struct connection_tag *handle, const void *buf, size_t len);
int (*read)(struct connection_tag *handle, void *buf, size_t len);
......
......@@ -514,10 +514,8 @@ static refbuf_t *get_next_buffer (source_t *source)
}
source->last_read = current;
refbuf = source->format->get_buffer (source);
#ifdef HAVE_OPENSSL
if (source->client->con->ssl && (SSL_get_shutdown(source->client->con->ssl) & SSL_RECEIVED_SHUTDOWN))
if (source->client->con->tls && tls_got_shutdown(source->client->con->tls) > 1)
source->client->con->error = 1;
#endif
if (source->client->con && source->client->con->error)
{
ICECAST_LOG_INFO("End of Stream %s", source->mount);
......
......@@ -27,6 +27,12 @@ struct tls_ctx_tag {
SSL_CTX *ctx;
};
struct tls_tag {
size_t refc;
SSL *ssl;
tls_ctx_t *ctx;
};
void tls_initialize(void)
{
SSL_load_error_strings(); /* readable error messages */
......@@ -111,11 +117,119 @@ void tls_ctx_unref(tls_ctx_t *ctx)
free(ctx);
}
SSL *tls_ctx_SSL_new(tls_ctx_t *ctx)
tls_t *tls_new(tls_ctx_t *ctx)
{
tls_t *tls;
SSL *ssl;
if (!ctx)
return NULL;
return SSL_new(ctx->ctx);
ssl = SSL_new(ctx->ctx);
if (!ssl)
return NULL;
tls = calloc(1, sizeof(*tls));
if (!tls) {
SSL_free(ssl);
return NULL;
}
tls_ctx_ref(ctx);
tls->refc = 1;
tls->ssl = ssl;
tls->ctx = ctx;
return tls;
}
void tls_ref(tls_t *tls)
{
if (!tls)
return;
tls->refc++;
}
void tls_unref(tls_t *tls)
{
if (!tls)
return;
tls->refc--;
if (tls->refc)
return;
SSL_shutdown(tls->ssl);
SSL_free(tls->ssl);
if (tls->ctx)
tls_ctx_unref(tls->ctx);
free(tls);
}
void tls_set_incoming(tls_t *tls)
{
if (!tls)
return;
SSL_set_accept_state(tls->ssl);
}
void tls_set_socket(tls_t *tls, sock_t sock)
{
if (!tls)
return;
SSL_set_fd(tls->ssl, sock);
}
int tls_want_io(tls_t *tls)
{
int what;
if (!tls)
return -1;
what = SSL_want(tls->ssl);
switch (what) {
case SSL_WRITING:
case SSL_READING:
return 1;
break;
case SSL_NOTHING:
default:
return 0;
break;
}
}
int tls_got_shutdown(tls_t *tls)
{
if (!tls)
return -1;
if (SSL_get_shutdown(tls->ssl) & SSL_RECEIVED_SHUTDOWN) {
return 1;
} else {
return 0;
}
}
ssize_t tls_read(tls_t *tls, void *buffer, size_t len)
{
if (!tls)
return -1;
return SSL_read(tls->ssl, buffer, len);
}
ssize_t tls_write(tls_t *tls, const void *buffer, size_t len)
{
if (!tls)
return -1;
return SSL_write(tls->ssl, buffer, len);
}
#else
void tls_initialize(void)
......@@ -135,4 +249,42 @@ void tls_ctx_ref(tls_ctx_t *ctx)
void tls_ctx_unref(tls_ctx_t *ctx)
{
}
tls_t *tls_new(tls_ctx_t *ctx)
{
return NULL;
}
void tls_ref(tls_t *tls)
{
}
void tls_unref(tls_t *tls)
{
}
void tls_set_incoming(tls_t *tls)
{
}
void tls_set_socket(tls_t *tls, sock_t sock)
{
}
int tls_want_io(tls_t *tls)
{
return -1;
}
int tls_got_shutdown(tls_t *tls)
{
return -1;
}
ssize_t tls_read(tls_t *tls, void *buffer, size_t len)
{
return -1;
}
ssize_t tls_write(tls_t *tls, const void *buffer, size_t len)
{
return -1;
}
#endif
......@@ -14,7 +14,10 @@
#include <openssl/err.h>
#endif
#include "common/net/sock.h"
typedef struct tls_ctx_tag tls_ctx_t;
typedef struct tls_tag tls_t;
void tls_initialize(void);
void tls_shutdown(void);
......@@ -23,8 +26,18 @@ tls_ctx_t *tls_ctx_new(const char *cert_file, const char *key_file, const char *
void tls_ctx_ref(tls_ctx_t *ctx);
void tls_ctx_unref(tls_ctx_t *ctx);
#ifdef HAVE_OPENSSL
SSL *tls_ctx_SSL_new(tls_ctx_t *ctx);
#endif
tls_t *tls_new(tls_ctx_t *ctx);
void tls_ref(tls_t *tls);
void tls_unref(tls_t *tls);
void tls_set_incoming(tls_t *tls);
void tls_set_socket(tls_t *tls, sock_t sock);
int tls_want_io(tls_t *tls);
int tls_got_shutdown(tls_t *tls);
ssize_t tls_read(tls_t *tls, void *buffer, size_t len);
ssize_t tls_write(tls_t *tls, const void *buffer, size_t len);
#endif
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment