diff --git a/subsys/net/lib/sockets/socketpair.c b/subsys/net/lib/sockets/socketpair.c index 5cec71bdd43..4d7c739f7fb 100644 --- a/subsys/net/lib/sockets/socketpair.c +++ b/subsys/net/lib/sockets/socketpair.c @@ -369,13 +369,13 @@ out: static ssize_t spair_write(void *obj, const void *buffer, size_t count) { int res; - bool is_connected; + int key; size_t avail; bool is_nonblock; - bool will_block; size_t bytes_written; bool have_local_sem = false; bool have_remote_sem = false; + bool will_block = false; struct spair *const spair = (struct spair *)obj; struct spair *remote = NULL; @@ -385,9 +385,10 @@ static ssize_t spair_write(void *obj, const void *buffer, size_t count) goto out; } + key = irq_lock(); is_nonblock = sock_is_nonblock(spair); - res = k_sem_take(&spair->sem, K_NO_WAIT); + irq_unlock(key); if (res < 0) { if (is_nonblock) { errno = EAGAIN; @@ -401,6 +402,7 @@ static ssize_t spair_write(void *obj, const void *buffer, size_t count) res = -1; goto out; } + is_nonblock = sock_is_nonblock(spair); } have_local_sem = true; @@ -408,10 +410,7 @@ static ssize_t spair_write(void *obj, const void *buffer, size_t count) remote = z_get_fd_obj(spair->remote, (const struct fd_op_vtable *)&spair_fd_op_vtable, 0); - is_connected = sock_is_connected(spair); - is_nonblock = sock_is_nonblock(spair); - - if (!is_connected) { + if (remote == NULL) { errno = EPIPE; res = -1; goto out; @@ -434,14 +433,17 @@ static ssize_t spair_write(void *obj, const void *buffer, size_t count) have_remote_sem = true; - avail = is_connected ? spair_write_avail(spair) : 0; - if (avail == 0 && is_nonblock) { - errno = EAGAIN; - res = -1; - goto out; + avail = spair_write_avail(spair); + + if (avail == 0) { + if (is_nonblock) { + errno = EAGAIN; + res = -1; + goto out; + } + will_block = true; } - will_block = (count > avail) && !is_nonblock; if (will_block) { for (int signaled = false, result = -1; !signaled; @@ -464,6 +466,16 @@ static ssize_t spair_write(void *obj, const void *buffer, size_t count) goto out; } + remote = z_get_fd_obj(spair->remote, + (const struct fd_op_vtable *) + &spair_fd_op_vtable, 0); + + if (remote == NULL) { + errno = EPIPE; + res = -1; + goto out; + } + res = k_sem_take(&remote->sem, K_NO_WAIT); if (res < 0) { if (is_nonblock) { @@ -569,14 +581,13 @@ out: static ssize_t spair_read(void *obj, void *buffer, size_t count) { int res; - + int key; bool is_connected; size_t avail; bool is_nonblock; - bool will_block; size_t bytes_read; - bool have_local_sem = false; + bool will_block = false; struct spair *const spair = (struct spair *)obj; if (obj == NULL || buffer == NULL || count == 0) { @@ -585,9 +596,10 @@ static ssize_t spair_read(void *obj, void *buffer, size_t count) goto out; } + key = irq_lock(); is_nonblock = sock_is_nonblock(spair); - res = k_sem_take(&spair->sem, K_NO_WAIT); + irq_unlock(key); if (res < 0) { if (is_nonblock) { errno = EAGAIN; @@ -601,24 +613,28 @@ static ssize_t spair_read(void *obj, void *buffer, size_t count) res = -1; goto out; } + is_nonblock = sock_is_nonblock(spair); } have_local_sem = true; is_connected = sock_is_connected(spair); avail = spair_read_avail(spair); - will_block = (avail == 0) && !is_nonblock; - if (avail == 0 && !is_connected) { - /* signal EOF */ - res = 0; - goto out; - } + if (avail == 0) { + if (!is_connected) { + /* signal EOF */ + res = 0; + goto out; + } - if (avail == 0 && is_nonblock) { - errno = EAGAIN; - res = -1; - goto out; + if (is_nonblock) { + errno = EAGAIN; + res = -1; + goto out; + } + + will_block = true; } if (will_block) {