zephyr/samples/net/http_server/src/ssl_utils.c
Paul Sokolovsky 25307d5331 net: net_pkt_append: Refactor to return length of data actually added
For stream-based protocols (TCP), adding less data than requested
("short write") is generally not a problem - the rest of the data can
be sent in the next packet. So, make net_pkt_append() return the length
of the data actually written instead of just a bool flag, which makes it
closer to the behavior of the POSIX send()/write() calls.

There are many users of the older net_pkt_append() in the codebase,
however, so a net_pkt_append_all() convenience function is added, which
keeps returning a boolean flag. All current users were converted to
this function, except for two:

samples/net/http_server/src/ssl_utils.c
samples/net/mbedtls_sslclient/src/tcp.c

Both are related to TLS and implement the mbedTLS "tx callback", which
follows POSIX short-write semantics. Both cases also had code to work
around the previous boolean-only behavior of net_pkt_append() - after
calling it, they measured the length of the data actually added (but
only when net_pkt_append() returned success, so that didn't really
help). So, these 2 cases are improved directly by this change.

Jira: ZEP-1984

Change-Id: Ibaf7c029b15e91b516d73dab3612eed190ee982b
Signed-off-by: Paul Sokolovsky <paul.sokolovsky@linaro.org>
2017-04-28 15:01:09 +03:00

290 lines
6.0 KiB
C

/*
* Copyright (c) 2017 Intel Corporation
*
* SPDX-License-Identifier: Apache-2.0
*/
#include <zephyr.h>
#include <net/net_core.h>
#include <net/net_context.h>
#include <net/net_pkt.h>
#include <net/net_if.h>
#include <string.h>
#include <errno.h>
#include <misc/printk.h>
#if !defined(CONFIG_MBEDTLS_CFG_FILE)
#include "mbedtls/config.h"
#else
#include CONFIG_MBEDTLS_CFG_FILE
#endif
#include "mbedtls/ssl.h"
#include "config.h"
#include "ssl_utils.h"
/* Number of received packets that can be queued for ssl_rx() at once;
 * used below as the block count of the rx_pkts pool.
 */
#define RX_FIFO_DEPTH 4
/* Pool backing the struct rx_fifo_block records that ssl_received()
 * allocates (one per queued packet) and ssl_rx() frees again.
 * NOTE(review): assumes Zephyr's K_MEM_POOL_DEFINE(name, min_size,
 * max_size, n_max, align) argument order, i.e. 4..64-byte blocks, 4-byte
 * aligned; sizeof(struct rx_fifo_block) must not exceed 64 — confirm.
 */
K_MEM_POOL_DEFINE(rx_pkts, 4, 64, RX_FIFO_DEPTH, 4);
/*
 * Receive callback registered with net_context_recv() by ssl_accepted().
 *
 * Each incoming packet that carries application data is wrapped in a
 * struct rx_fifo_block (allocated from rx_pkts) and queued on
 * ctx->rx_fifo, where ssl_rx() picks it up. Ownership of @buf passes to
 * the queued entry. Packets without application data are released here.
 *
 * The net_context receive callback signals end-of-stream with a NULL
 * buffer; that case is ignored here instead of being dereferenced.
 * NOTE(review): ssl_rx() has no way to learn about the close and keeps
 * waiting on the FIFO.
 */
static void ssl_received(struct net_context *context,
			 struct net_buf *buf, int status, void *user_data)
{
	struct ssl_context *ctx = user_data;
	struct rx_fifo_block *rx_data;
	struct k_mem_block block;

	ARG_UNUSED(context);
	ARG_UNUSED(status);

	if (!buf) {
		return;
	}

	if (!net_pkt_appdatalen(buf)) {
		net_pkt_unref(buf);
		return;
	}

	if (k_mem_pool_alloc(&rx_pkts, &block,
			     sizeof(struct rx_fifo_block), K_FOREVER) < 0) {
		/* Cannot queue the packet: drop it rather than leak it. */
		net_pkt_unref(buf);
		return;
	}

	rx_data = block.data;
	rx_data->buf = buf;
	/* Keep the block descriptor so ssl_rx() can free the memory later. */
	rx_data->block = block;

	k_fifo_put(&ctx->rx_fifo, rx_data);
}
/*
 * Send-completion callback passed to net_context_send() by ssl_tx().
 *
 * Releases ctx->tx_sem so the thread blocked in ssl_tx() can continue.
 * The completion status and token are not inspected.
 */
static inline void ssl_sent(struct net_context *context,
			    int status, void *token, void *user_data)
{
	struct ssl_context *ssl = (struct ssl_context *)user_data;

	ARG_UNUSED(context);
	ARG_UNUSED(status);
	ARG_UNUSED(token);

	k_sem_give(&ssl->tx_sem);
}
/*
 * mbedTLS send callback for the connection described by @context
 * (a struct ssl_context).
 *
 * Copies up to @size bytes of @buf into one freshly allocated TX packet,
 * hands it to net_context_send() and blocks on ctx->tx_sem until
 * ssl_sent() reports completion. Like POSIX send(), fewer bytes than
 * requested may be accepted ("short write"); the caller is expected to
 * resend the remainder.
 *
 * Returns the number of bytes queued (0 only when @size is 0),
 * MBEDTLS_ERR_SSL_ALLOC_FAILED if no packet or packet data could be
 * allocated, or MBEDTLS_ERR_SSL_INTERNAL_ERROR if net_context_send() fails.
 */
int ssl_tx(void *context, const unsigned char *buf, size_t size)
{
	/* net_pkt_append() takes a 16-bit length; larger requests are
	 * clamped so they become a short write instead of being truncated
	 * modulo 65536 (a request of exactly 64 KiB would append nothing).
	 * NOTE(review): confirm the parameter width for this Zephyr version.
	 */
	const size_t max_chunk = 0xFFFF;
	struct ssl_context *ctx = context;
	struct net_context *net_ctx = ctx->net_ctx;
	struct net_buf *send_buf;
	size_t chunk;
	int rc, len;

	if (size == 0) {
		return 0;
	}

	chunk = size > max_chunk ? max_chunk : size;

	send_buf = net_pkt_get_tx(net_ctx, K_NO_WAIT);
	if (!send_buf) {
		return MBEDTLS_ERR_SSL_ALLOC_FAILED;
	}

	len = net_pkt_append(send_buf, chunk, (u8_t *) buf, K_FOREVER);
	if (len == 0) {
		/* No data buffer could be obtained: do not send an empty packet
		 * and do not report 0 bytes, which is not a meaningful result
		 * for the mbedTLS send callback.
		 */
		net_pkt_unref(send_buf);
		return MBEDTLS_ERR_SSL_ALLOC_FAILED;
	}

	rc = net_context_send(send_buf, ssl_sent, K_NO_WAIT, NULL, ctx);
	if (rc < 0) {
		/* On failure the packet is still ours to release. */
		net_pkt_unref(send_buf);
		return MBEDTLS_ERR_SSL_INTERNAL_ERROR;
	}

	/* Wait for ssl_sent() before reporting the bytes as sent. */
	k_sem_take(&ctx->tx_sem, K_FOREVER);

	return len;
}
/*
 * Release the packet ssl_rx() is currently consuming and reset the read
 * cursor so that the next call fetches a new packet from the FIFO.
 */
static void ssl_rx_release(struct ssl_context *ctx)
{
	net_pkt_unref(ctx->rx_pkt);
	ctx->rx_pkt = NULL;
	ctx->frag = NULL;
	ctx->remaining = 0;
}

/*
 * mbedTLS receive callback for the connection described by @context
 * (a struct ssl_context).
 *
 * If no packet is in progress (ctx->frag == NULL), blocks until
 * ssl_received() queues one, then skips the protocol headers so the
 * cursor points at the application data. It copies at most @size bytes,
 * bounded by both the fragment lengths and ctx->remaining (the
 * application data still unread). A packet that is only partly consumed
 * stays in ctx for the next call; a fully consumed packet is released.
 *
 * Returns the number of bytes copied, or -EIO if the fragment chain ends
 * before the announced application data length has been read (the packet
 * is released in that case).
 */
int ssl_rx(void *context, unsigned char *buf, size_t size)
{
	struct ssl_context *ctx = context;
	size_t pos = 0;

	if (ctx->frag == NULL) {
		struct rx_fifo_block *rx_data;
		u8_t *ptr;

		rx_data = k_fifo_get(&ctx->rx_fifo, K_FOREVER);
		ctx->rx_pkt = rx_data->buf;
		k_mem_pool_free(&rx_data->block);

		ctx->remaining = net_pkt_appdatalen(ctx->rx_pkt);
		ctx->frag = ctx->rx_pkt->frags;

		/* Drop the header bytes in front of the application data.
		 * NOTE(review): assumes the application data starts in the
		 * first fragment — confirm.
		 */
		ptr = net_pkt_appdata(ctx->rx_pkt);
		net_buf_pull(ctx->frag, ptr - ctx->frag->data);
	}

	/* Every copy is bounded by the caller's buffer, the current fragment
	 * and the unread application data, so a fragment chain that
	 * disagrees with appdatalen can never overrun @buf.
	 */
	while (ctx->frag && ctx->remaining > 0 && pos < size) {
		size_t avail = ctx->frag->len;
		size_t chunk = size - pos;

		if (chunk > avail) {
			chunk = avail;
		}
		if (chunk > (size_t)ctx->remaining) {
			chunk = ctx->remaining;
		}

		memcpy(buf + pos, ctx->frag->data, chunk);
		pos += chunk;
		ctx->remaining -= chunk;

		if (chunk == avail) {
			/* Fragment exhausted: move to the next one. */
			ctx->frag = ctx->frag->frags;
		} else {
			/* Partially read: advance within this fragment. */
			net_buf_pull(ctx->frag, chunk);
		}
	}

	if (ctx->remaining == 0) {
		ssl_rx_release(ctx);
	} else if (ctx->frag == NULL) {
		/* Chain shorter than appdatalen: treat as a corrupt packet. */
		ssl_rx_release(ctx);
		return -EIO;
	}

	return (int)pos;
}
/*
 * Accept callback registered by ssl_init() via net_context_accept().
 *
 * On success, stores the accepted connection in ctx->net_ctx (replacing
 * the listening context stored there by ssl_init()) and registers
 * ssl_received() to queue its incoming packets. A failed accept (non-zero
 * @error or NULL @context) is reported and leaves ctx untouched.
 * NOTE(review): the listening context pointer is no longer reachable from
 * ctx after a successful accept.
 */
static void ssl_accepted(struct net_context *context,
			 struct sockaddr *addr,
			 socklen_t addrlen, int error, void *user_data)
{
	struct ssl_context *ctx = user_data;
	int ret;

	ARG_UNUSED(addr);
	ARG_UNUSED(addrlen);

	if (error || !context) {
		printk("Cannot accept TCP connection (%d)", error);
		return;
	}

	ctx->net_ctx = context;

	ret = net_context_recv(context, ssl_received, 0, user_data);
	if (ret < 0) {
		printk("Cannot receive TCP packet (family %d)",
		       net_context_get_family(context));
	}
}
#if defined(CONFIG_NET_IPV6)
/*
 * IPv6 variant: create a TCP context bound to @addr (a struct in6_addr)
 * on SERVER_PORT, start listening, and register ssl_accepted() for
 * incoming connections. Also initializes ctx's TX semaphore, RX FIFO and
 * RX cursor.
 *
 * Returns 0 on success. Returns -EIO if the context cannot be created or
 * listen/accept fails, and -EINVAL if bind fails. Once created, the context
 * is released again on every failure path.
 */
int ssl_init(struct ssl_context *ctx, void *addr)
{
	struct net_context *tcp_ctx = NULL;
	struct sockaddr_in6 my_addr = { 0 };
	struct in6_addr *server_addr = addr;
	int rc;

	k_sem_init(&ctx->tx_sem, 0, UINT_MAX);
	k_fifo_init(&ctx->rx_fifo);

	net_ipaddr_copy(&my_addr.sin6_addr, server_addr);
	my_addr.sin6_family = AF_INET6;
	my_addr.sin6_port = htons(SERVER_PORT);

	rc = net_context_get(AF_INET6, SOCK_STREAM, IPPROTO_TCP, &tcp_ctx);
	if (rc < 0) {
		printk("Cannot get network context for IPv6 TCP (%d)", rc);
		return -EIO;
	}

	rc = net_context_bind(tcp_ctx, (struct sockaddr *)&my_addr,
			      sizeof(struct sockaddr_in6));
	if (rc < 0) {
		printk("Cannot bind IPv6 TCP port %d (%d)", SERVER_PORT, rc);
		rc = -EINVAL;
		goto put_ctx;
	}

	/* ssl_rx() starts a new packet whenever ctx->frag is NULL. */
	ctx->rx_pkt = NULL;
	ctx->frag = NULL;
	ctx->remaining = 0;
	/* Set before accept is armed: ssl_accepted() may overwrite it. */
	ctx->net_ctx = tcp_ctx;

	rc = net_context_listen(ctx->net_ctx, 0);
	if (rc < 0) {
		printk("Cannot listen IPv6 TCP (%d)", rc);
		rc = -EIO;
		goto clear_ctx;
	}

	rc = net_context_accept(ctx->net_ctx, ssl_accepted, 0, ctx);
	if (rc < 0) {
		printk("Cannot accept IPv6 (%d)", rc);
		rc = -EIO;
		goto clear_ctx;
	}

	return 0;

clear_ctx:
	ctx->net_ctx = NULL;
put_ctx:
	net_context_put(tcp_ctx);
	return rc;
}
#else
/*
 * IPv4 variant: create a TCP context bound to @addr (a struct in_addr)
 * on SERVER_PORT, start listening, and register ssl_accepted() for
 * incoming connections. Also initializes ctx's TX semaphore, RX FIFO and
 * RX cursor.
 *
 * Returns 0 on success. Returns -EIO if the context cannot be created or
 * listen/accept fails, and -EINVAL if bind fails. Once created, the context
 * is released again on every failure path.
 */
int ssl_init(struct ssl_context *ctx, void *addr)
{
	struct net_context *tcp_ctx = NULL;
	struct sockaddr_in my_addr4 = { 0 };
	struct in_addr *server_addr = addr;
	int rc;

	k_sem_init(&ctx->tx_sem, 0, UINT_MAX);
	k_fifo_init(&ctx->rx_fifo);

	net_ipaddr_copy(&my_addr4.sin_addr, server_addr);
	my_addr4.sin_family = AF_INET;
	my_addr4.sin_port = htons(SERVER_PORT);

	rc = net_context_get(AF_INET, SOCK_STREAM, IPPROTO_TCP, &tcp_ctx);
	if (rc < 0) {
		printk("Cannot get network context for IPv4 TCP (%d)", rc);
		return -EIO;
	}

	rc = net_context_bind(tcp_ctx, (struct sockaddr *)&my_addr4,
			      sizeof(struct sockaddr_in));
	if (rc < 0) {
		printk("Cannot bind IPv4 TCP port %d (%d)", SERVER_PORT, rc);
		rc = -EINVAL;
		goto put_ctx;
	}

	/* ssl_rx() starts a new packet whenever ctx->frag is NULL. */
	ctx->rx_pkt = NULL;
	ctx->frag = NULL;
	ctx->remaining = 0;
	/* Set before accept is armed: ssl_accepted() may overwrite it. */
	ctx->net_ctx = tcp_ctx;

	rc = net_context_listen(ctx->net_ctx, 0);
	if (rc < 0) {
		printk("Cannot listen IPv4 (%d)", rc);
		rc = -EIO;
		goto clear_ctx;
	}

	rc = net_context_accept(ctx->net_ctx, ssl_accepted, 0, ctx);
	if (rc < 0) {
		printk("Cannot accept IPv4 (%d)", rc);
		rc = -EIO;
		goto clear_ctx;
	}

	return 0;

clear_ctx:
	ctx->net_ctx = NULL;
put_ctx:
	net_context_put(tcp_ctx);
	return rc;
}
#endif