net: Adapt checksum calculation to new net_pkt API

Let's just use the packet cursor relevantly.

Signed-off-by: Tomasz Bursztyka <tomasz.bursztyka@linux.intel.com>
This commit is contained in:
Tomasz Bursztyka 2018-12-03 20:35:16 +01:00 committed by Jukka Rissanen
parent 593f957618
commit 2caa0b4e43

View File

@ -390,24 +390,25 @@ int net_addr_pton(sa_family_t family, const char *src,
return 0;
}
static u16_t calc_chksum(u16_t sum, const u8_t *ptr, u16_t len)
static u16_t calc_chksum(u16_t sum, const u8_t *data, size_t len)
{
u16_t tmp;
const u8_t *end;
u16_t tmp;
end = ptr + len - 1;
end = data + len - 1;
while (ptr < end) {
tmp = (ptr[0] << 8) + ptr[1];
while (data < end) {
tmp = (data[0] << 8) + data[1];
sum += tmp;
if (sum < tmp) {
sum++;
}
ptr += 2;
data += 2;
}
if (ptr == end) {
tmp = ptr[0] << 8;
if (data == end) {
tmp = data[0] << 8;
sum += tmp;
if (sum < tmp) {
sum++;
@ -417,50 +418,37 @@ static u16_t calc_chksum(u16_t sum, const u8_t *ptr, u16_t len)
return sum;
}
static inline u16_t calc_chksum_pkt(u16_t sum, struct net_pkt *pkt,
u16_t upper_layer_len)
static inline u16_t pkt_calc_chksum(struct net_pkt *pkt, u16_t sum)
{
u16_t proto_len = net_pkt_ip_hdr_len(pkt) +
net_pkt_ipv6_ext_len(pkt);
struct net_buf *frag;
u16_t offset;
s16_t len;
u8_t *ptr;
struct net_pkt_cursor *cur = &pkt->cursor;
size_t len;
ARG_UNUSED(upper_layer_len);
frag = net_frag_skip(pkt->frags, proto_len, &offset, 0);
if (!frag) {
NET_DBG("Trying to read past pkt len (proto len %d)",
proto_len);
return 0;
if (!cur->buf || !cur->pos) {
return sum;
}
NET_ASSERT(offset <= frag->len);
len = cur->buf->len - (cur->pos - cur->buf->data);
ptr = frag->data + offset;
len = frag->len - offset;
while (cur->buf) {
sum = calc_chksum(sum, cur->pos, len);
while (frag) {
sum = calc_chksum(sum, ptr, len);
frag = frag->frags;
if (!frag) {
cur->buf = cur->buf->frags;
if (!cur->buf || !cur->buf->len) {
break;
}
ptr = frag->data;
cur->pos = cur->buf->data;
/* Do we need to take first byte from next fragment */
if (len % 2) {
u16_t tmp = *ptr;
sum += tmp;
if (sum < tmp) {
sum += *cur->pos;
if (sum < *cur->pos) {
sum++;
}
len = frag->len - 1;
ptr++;
cur->pos++;
len = cur->buf->len - 1;
} else {
len = frag->len;
len = cur->buf->len;
}
}
@ -469,41 +457,45 @@ static inline u16_t calc_chksum_pkt(u16_t sum, struct net_pkt *pkt,
u16_t net_calc_chksum(struct net_pkt *pkt, u8_t proto)
{
u16_t upper_layer_len;
size_t len = 0U;
u16_t sum = 0U;
struct net_pkt_cursor backup;
switch (net_pkt_family(pkt)) {
#if defined(CONFIG_NET_IPV4)
case AF_INET:
upper_layer_len = ntohs(NET_IPV4_HDR(pkt)->len) -
net_pkt_ipv6_ext_len(pkt) -
net_pkt_ip_hdr_len(pkt);
net_pkt_cursor_backup(pkt, &backup);
net_pkt_cursor_init(pkt);
net_pkt_set_overwrite(pkt, true);
if (IS_ENABLED(CONFIG_NET_IPV4) &&
net_pkt_family(pkt) == AF_INET) {
if (proto != IPPROTO_ICMP) {
sum = calc_chksum(upper_layer_len + proto,
(u8_t *)&NET_IPV4_HDR(pkt)->src,
2 * sizeof(struct in_addr));
len = 2 * sizeof(struct in_addr);
sum = net_pkt_get_len(pkt) -
net_pkt_ip_hdr_len(pkt) + proto;
}
break;
#endif
#if defined(CONFIG_NET_IPV6)
case AF_INET6:
upper_layer_len = ntohs(NET_IPV6_HDR(pkt)->len) -
net_pkt_ipv6_ext_len(pkt);
sum = calc_chksum(upper_layer_len + proto,
(u8_t *)&NET_IPV6_HDR(pkt)->src,
2 * sizeof(struct in6_addr));
break;
#endif
default:
} else if (IS_ENABLED(CONFIG_NET_IPV6) &&
net_pkt_family(pkt) == AF_INET6) {
len = 2 * sizeof(struct in6_addr);
sum = net_pkt_get_len(pkt) -
net_pkt_ip_hdr_len(pkt) -
net_pkt_ipv6_ext_len(pkt) + proto;
} else {
NET_DBG("Unknown protocol family %d", net_pkt_family(pkt));
return 0;
}
sum = calc_chksum_pkt(sum, pkt, upper_layer_len);
net_pkt_skip(pkt, net_pkt_ip_hdr_len(pkt) - len);
sum = calc_chksum(sum, pkt->cursor.pos, len);
net_pkt_skip(pkt, len + net_pkt_ipv6_ext_len(pkt));
sum = pkt_calc_chksum(pkt, sum);
sum = (sum == 0) ? 0xffff : htons(sum);
net_pkt_cursor_restore(pkt, &backup);
return ~sum;
}
@ -512,7 +504,7 @@ u16_t net_calc_chksum_ipv4(struct net_pkt *pkt)
{
u16_t sum;
sum = calc_chksum(0, (u8_t *)NET_IPV4_HDR(pkt), NET_IPV4H_LEN);
sum = calc_chksum(0, pkt->buffer->data, NET_IPV4H_LEN);
sum = (sum == 0) ? 0xffff : htons(sum);