aboutsummaryrefslogtreecommitdiff
path: root/net/mptcp
diff options
context:
space:
mode:
Diffstat (limited to 'net/mptcp')
-rw-r--r--net/mptcp/mptcp_pm_gen.c1
-rw-r--r--net/mptcp/pm_netlink.c3
-rw-r--r--net/mptcp/pm_userspace.c18
-rw-r--r--net/mptcp/protocol.c18
4 files changed, 32 insertions, 8 deletions
diff --git a/net/mptcp/mptcp_pm_gen.c b/net/mptcp/mptcp_pm_gen.c
index c30a2a90a192..bfb37c5a88c4 100644
--- a/net/mptcp/mptcp_pm_gen.c
+++ b/net/mptcp/mptcp_pm_gen.c
@@ -112,7 +112,6 @@ const struct genl_ops mptcp_pm_nl_ops[11] = {
.dumpit = mptcp_pm_nl_get_addr_dumpit,
.policy = mptcp_pm_get_addr_nl_policy,
.maxattr = MPTCP_PM_ATTR_TOKEN,
- .flags = GENL_UNS_ADMIN_PERM,
},
{
.cmd = MPTCP_PM_CMD_FLUSH_ADDRS,
diff --git a/net/mptcp/pm_netlink.c b/net/mptcp/pm_netlink.c
index db586a5b3866..45a2b5f05d38 100644
--- a/net/mptcp/pm_netlink.c
+++ b/net/mptcp/pm_netlink.c
@@ -524,7 +524,8 @@ __lookup_addr(struct pm_nl_pernet *pernet, const struct mptcp_addr_info *info)
{
struct mptcp_pm_addr_entry *entry;
- list_for_each_entry(entry, &pernet->local_addr_list, list) {
+ list_for_each_entry_rcu(entry, &pernet->local_addr_list, list,
+ lockdep_is_held(&pernet->lock)) {
if (mptcp_addresses_equal(&entry->addr, info, entry->addr.port))
return entry;
}
diff --git a/net/mptcp/pm_userspace.c b/net/mptcp/pm_userspace.c
index 2cceded3a83a..e35178f5205f 100644
--- a/net/mptcp/pm_userspace.c
+++ b/net/mptcp/pm_userspace.c
@@ -91,6 +91,7 @@ static int mptcp_userspace_pm_delete_local_addr(struct mptcp_sock *msk,
struct mptcp_pm_addr_entry *addr)
{
struct mptcp_pm_addr_entry *entry, *tmp;
+ struct sock *sk = (struct sock *)msk;
list_for_each_entry_safe(entry, tmp, &msk->pm.userspace_pm_local_addr_list, list) {
if (mptcp_addresses_equal(&entry->addr, &addr->addr, false)) {
@@ -98,7 +99,7 @@ static int mptcp_userspace_pm_delete_local_addr(struct mptcp_sock *msk,
* be used multiple times (e.g. fullmesh mode).
*/
list_del_rcu(&entry->list);
- kfree(entry);
+ sock_kfree_s(sk, entry, sizeof(*entry));
msk->pm.local_addr_used--;
return 0;
}
@@ -307,14 +308,17 @@ int mptcp_pm_nl_remove_doit(struct sk_buff *skb, struct genl_info *info)
lock_sock(sk);
+ spin_lock_bh(&msk->pm.lock);
match = mptcp_userspace_pm_lookup_addr_by_id(msk, id_val);
if (!match) {
GENL_SET_ERR_MSG(info, "address with specified id not found");
+ spin_unlock_bh(&msk->pm.lock);
release_sock(sk);
goto out;
}
list_move(&match->list, &free_list);
+ spin_unlock_bh(&msk->pm.lock);
mptcp_pm_remove_addrs(msk, &free_list);
@@ -559,6 +563,7 @@ int mptcp_userspace_pm_set_flags(struct sk_buff *skb, struct genl_info *info)
struct nlattr *token = info->attrs[MPTCP_PM_ATTR_TOKEN];
struct nlattr *attr = info->attrs[MPTCP_PM_ATTR_ADDR];
struct net *net = sock_net(skb->sk);
+ struct mptcp_pm_addr_entry *entry;
struct mptcp_sock *msk;
int ret = -EINVAL;
struct sock *sk;
@@ -600,6 +605,17 @@ int mptcp_userspace_pm_set_flags(struct sk_buff *skb, struct genl_info *info)
if (loc.flags & MPTCP_PM_ADDR_FLAG_BACKUP)
bkup = 1;
+ spin_lock_bh(&msk->pm.lock);
+ list_for_each_entry(entry, &msk->pm.userspace_pm_local_addr_list, list) {
+ if (mptcp_addresses_equal(&entry->addr, &loc.addr, false)) {
+ if (bkup)
+ entry->flags |= MPTCP_PM_ADDR_FLAG_BACKUP;
+ else
+ entry->flags &= ~MPTCP_PM_ADDR_FLAG_BACKUP;
+ }
+ }
+ spin_unlock_bh(&msk->pm.lock);
+
lock_sock(sk);
ret = mptcp_pm_nl_mp_prio_send_ack(msk, &loc.addr, &rem.addr, bkup);
release_sock(sk);
diff --git a/net/mptcp/protocol.c b/net/mptcp/protocol.c
index 6d0e201c3eb2..48d480982b78 100644
--- a/net/mptcp/protocol.c
+++ b/net/mptcp/protocol.c
@@ -2082,7 +2082,8 @@ static void mptcp_rcv_space_adjust(struct mptcp_sock *msk, int copied)
slow = lock_sock_fast(ssk);
WRITE_ONCE(ssk->sk_rcvbuf, rcvbuf);
WRITE_ONCE(tcp_sk(ssk)->window_clamp, window_clamp);
- tcp_cleanup_rbuf(ssk, 1);
+ if (tcp_can_send_ack(ssk))
+ tcp_cleanup_rbuf(ssk, 1);
unlock_sock_fast(ssk, slow);
}
}
@@ -2205,7 +2206,7 @@ static int mptcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
cmsg_flags = MPTCP_CMSG_INQ;
while (copied < len) {
- int bytes_read;
+ int err, bytes_read;
bytes_read = __mptcp_recvmsg_mskq(msk, msg, len - copied, flags, &tss, &cmsg_flags);
if (unlikely(bytes_read < 0)) {
@@ -2267,9 +2268,16 @@ static int mptcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
}
pr_debug("block timeout %ld\n", timeo);
- sk_wait_data(sk, &timeo, NULL);
+ mptcp_rcv_space_adjust(msk, copied);
+ err = sk_wait_data(sk, &timeo, NULL);
+ if (err < 0) {
+ err = copied ? : err;
+ goto out_err;
+ }
}
+ mptcp_rcv_space_adjust(msk, copied);
+
out_err:
if (cmsg_flags && copied >= 0) {
if (cmsg_flags & MPTCP_CMSG_TS)
@@ -2285,8 +2293,6 @@ out_err:
pr_debug("msk=%p rx queue empty=%d:%d copied=%d\n",
msk, skb_queue_empty_lockless(&sk->sk_receive_queue),
skb_queue_empty(&msk->receive_queue), copied);
- if (!(flags & MSG_PEEK))
- mptcp_rcv_space_adjust(msk, copied);
release_sock(sk);
return copied;
@@ -2864,8 +2870,10 @@ static int mptcp_init_sock(struct sock *sk)
if (unlikely(!net->mib.mptcp_statistics) && !mptcp_mib_alloc(net))
return -ENOMEM;
+ rcu_read_lock();
ret = mptcp_init_sched(mptcp_sk(sk),
mptcp_sched_find(mptcp_get_scheduler(net)));
+ rcu_read_unlock();
if (ret)
return ret;