aboutsummaryrefslogtreecommitdiff
path: root/net/ipv4/inet_hashtables.c
diff options
context:
space:
mode:
Diffstat (limited to 'net/ipv4/inet_hashtables.c')
-rw-r--r--net/ipv4/inet_hashtables.c106
1 files changed, 82 insertions, 24 deletions
diff --git a/net/ipv4/inet_hashtables.c b/net/ipv4/inet_hashtables.c
index 60d77e234a68..dc1c5629cd0d 100644
--- a/net/ipv4/inet_hashtables.c
+++ b/net/ipv4/inet_hashtables.c
@@ -109,6 +109,7 @@ static void inet_bind2_bucket_init(struct inet_bind2_bucket *tb,
tb->l3mdev = l3mdev;
tb->port = port;
#if IS_ENABLED(CONFIG_IPV6)
+ tb->family = sk->sk_family;
if (sk->sk_family == AF_INET6)
tb->v6_rcv_saddr = sk->sk_v6_rcv_saddr;
else
@@ -146,6 +147,9 @@ static bool inet_bind2_bucket_addr_match(const struct inet_bind2_bucket *tb2,
const struct sock *sk)
{
#if IS_ENABLED(CONFIG_IPV6)
+ if (sk->sk_family != tb2->family)
+ return false;
+
if (sk->sk_family == AF_INET6)
return ipv6_addr_equal(&tb2->v6_rcv_saddr,
&sk->sk_v6_rcv_saddr);
@@ -168,14 +172,15 @@ void inet_bind_hash(struct sock *sk, struct inet_bind_bucket *tb,
*/
static void __inet_put_port(struct sock *sk)
{
- struct inet_hashinfo *hashinfo = sk->sk_prot->h.hashinfo;
- const int bhash = inet_bhashfn(sock_net(sk), inet_sk(sk)->inet_num,
- hashinfo->bhash_size);
- struct inet_bind_hashbucket *head = &hashinfo->bhash[bhash];
- struct inet_bind_hashbucket *head2 =
- inet_bhashfn_portaddr(hashinfo, sk, sock_net(sk),
- inet_sk(sk)->inet_num);
+ struct inet_hashinfo *hashinfo = tcp_or_dccp_get_hashinfo(sk);
+ struct inet_bind_hashbucket *head, *head2;
+ struct net *net = sock_net(sk);
struct inet_bind_bucket *tb;
+ int bhash;
+
+ bhash = inet_bhashfn(net, inet_sk(sk)->inet_num, hashinfo->bhash_size);
+ head = &hashinfo->bhash[bhash];
+ head2 = inet_bhashfn_portaddr(hashinfo, sk, net, inet_sk(sk)->inet_num);
spin_lock(&head->lock);
tb = inet_csk(sk)->icsk_bind_hash;
@@ -207,19 +212,19 @@ EXPORT_SYMBOL(inet_put_port);
int __inet_inherit_port(const struct sock *sk, struct sock *child)
{
- struct inet_hashinfo *table = sk->sk_prot->h.hashinfo;
+ struct inet_hashinfo *table = tcp_or_dccp_get_hashinfo(sk);
unsigned short port = inet_sk(child)->inet_num;
- const int bhash = inet_bhashfn(sock_net(sk), port,
- table->bhash_size);
- struct inet_bind_hashbucket *head = &table->bhash[bhash];
- struct inet_bind_hashbucket *head2 =
- inet_bhashfn_portaddr(table, child, sock_net(sk), port);
+ struct inet_bind_hashbucket *head, *head2;
bool created_inet_bind_bucket = false;
- bool update_fastreuse = false;
struct net *net = sock_net(sk);
+ bool update_fastreuse = false;
struct inet_bind2_bucket *tb2;
struct inet_bind_bucket *tb;
- int l3mdev;
+ int bhash, l3mdev;
+
+ bhash = inet_bhashfn(net, port, table->bhash_size);
+ head = &table->bhash[bhash];
+ head2 = inet_bhashfn_portaddr(table, child, net, port);
spin_lock(&head->lock);
spin_lock(&head2->lock);
@@ -385,7 +390,7 @@ static inline struct sock *inet_lookup_run_bpf(struct net *net,
struct sock *sk, *reuse_sk;
bool no_reuseport;
- if (hashinfo != &tcp_hashinfo)
+ if (hashinfo != net->ipv4.tcp_death_row.hashinfo)
return NULL; /* only TCP is supported */
no_reuseport = bpf_sk_lookup_run_v4(net, IPPROTO_TCP, saddr, sport,
@@ -628,9 +633,9 @@ static bool inet_ehash_lookup_by_sk(struct sock *sk,
*/
bool inet_ehash_insert(struct sock *sk, struct sock *osk, bool *found_dup_sk)
{
- struct inet_hashinfo *hashinfo = sk->sk_prot->h.hashinfo;
- struct hlist_nulls_head *list;
+ struct inet_hashinfo *hashinfo = tcp_or_dccp_get_hashinfo(sk);
struct inet_ehash_bucket *head;
+ struct hlist_nulls_head *list;
spinlock_t *lock;
bool ret = true;
@@ -700,7 +705,7 @@ static int inet_reuseport_add_sock(struct sock *sk,
int __inet_hash(struct sock *sk, struct sock *osk)
{
- struct inet_hashinfo *hashinfo = sk->sk_prot->h.hashinfo;
+ struct inet_hashinfo *hashinfo = tcp_or_dccp_get_hashinfo(sk);
struct inet_listen_hashbucket *ilb2;
int err = 0;
@@ -746,7 +751,7 @@ EXPORT_SYMBOL_GPL(inet_hash);
void inet_unhash(struct sock *sk)
{
- struct inet_hashinfo *hashinfo = sk->sk_prot->h.hashinfo;
+ struct inet_hashinfo *hashinfo = tcp_or_dccp_get_hashinfo(sk);
if (sk_unhashed(sk))
return;
@@ -790,6 +795,9 @@ static bool inet_bind2_bucket_match(const struct inet_bind2_bucket *tb,
int l3mdev, const struct sock *sk)
{
#if IS_ENABLED(CONFIG_IPV6)
+ if (sk->sk_family != tb->family)
+ return false;
+
if (sk->sk_family == AF_INET6)
return net_eq(ib2_net(tb), net) && tb->port == port &&
tb->l3mdev == l3mdev &&
@@ -806,6 +814,9 @@ bool inet_bind2_bucket_match_addr_any(const struct inet_bind2_bucket *tb, const
#if IS_ENABLED(CONFIG_IPV6)
struct in6_addr addr_any = {};
+ if (sk->sk_family != tb->family)
+ return false;
+
if (sk->sk_family == AF_INET6)
return net_eq(ib2_net(tb), net) && tb->port == port &&
tb->l3mdev == l3mdev &&
@@ -833,7 +844,7 @@ inet_bind2_bucket_find(const struct inet_bind_hashbucket *head, const struct net
struct inet_bind_hashbucket *
inet_bhash2_addr_any_hashbucket(const struct sock *sk, const struct net *net, int port)
{
- struct inet_hashinfo *hinfo = sk->sk_prot->h.hashinfo;
+ struct inet_hashinfo *hinfo = tcp_or_dccp_get_hashinfo(sk);
u32 hash;
#if IS_ENABLED(CONFIG_IPV6)
struct in6_addr addr_any = {};
@@ -849,7 +860,7 @@ inet_bhash2_addr_any_hashbucket(const struct sock *sk, const struct net *net, in
int inet_bhash2_update_saddr(struct inet_bind_hashbucket *prev_saddr, struct sock *sk)
{
- struct inet_hashinfo *hinfo = sk->sk_prot->h.hashinfo;
+ struct inet_hashinfo *hinfo = tcp_or_dccp_get_hashinfo(sk);
struct inet_bind2_bucket *tb2, *new_tb2;
int l3mdev = inet_sk_bound_l3mdev(sk);
struct inet_bind_hashbucket *head2;
@@ -947,8 +958,8 @@ int __inet_hash_connect(struct inet_timewait_death_row *death_row,
if (likely(remaining > 1))
remaining &= ~1U;
- net_get_random_once(table_perturb,
- INET_TABLE_PERTURB_SIZE * sizeof(*table_perturb));
+ get_random_slow_once(table_perturb,
+ INET_TABLE_PERTURB_SIZE * sizeof(*table_perturb));
index = port_offset & (INET_TABLE_PERTURB_SIZE - 1);
offset = READ_ONCE(table_perturb[index]) + (port_offset >> 32);
@@ -1144,3 +1155,50 @@ int inet_ehash_locks_alloc(struct inet_hashinfo *hashinfo)
return 0;
}
EXPORT_SYMBOL_GPL(inet_ehash_locks_alloc);
+
+struct inet_hashinfo *inet_pernet_hashinfo_alloc(struct inet_hashinfo *hashinfo,
+ unsigned int ehash_entries)
+{
+ struct inet_hashinfo *new_hashinfo;
+ int i;
+
+ new_hashinfo = kmemdup(hashinfo, sizeof(*hashinfo), GFP_KERNEL);
+ if (!new_hashinfo)
+ goto err;
+
+ new_hashinfo->ehash = vmalloc_huge(ehash_entries * sizeof(struct inet_ehash_bucket),
+ GFP_KERNEL_ACCOUNT);
+ if (!new_hashinfo->ehash)
+ goto free_hashinfo;
+
+ new_hashinfo->ehash_mask = ehash_entries - 1;
+
+ if (inet_ehash_locks_alloc(new_hashinfo))
+ goto free_ehash;
+
+ for (i = 0; i < ehash_entries; i++)
+ INIT_HLIST_NULLS_HEAD(&new_hashinfo->ehash[i].chain, i);
+
+ new_hashinfo->pernet = true;
+
+ return new_hashinfo;
+
+free_ehash:
+ vfree(new_hashinfo->ehash);
+free_hashinfo:
+ kfree(new_hashinfo);
+err:
+ return NULL;
+}
+EXPORT_SYMBOL_GPL(inet_pernet_hashinfo_alloc);
+
+void inet_pernet_hashinfo_free(struct inet_hashinfo *hashinfo)
+{
+ if (!hashinfo->pernet)
+ return;
+
+ inet_ehash_locks_free(hashinfo);
+ vfree(hashinfo->ehash);
+ kfree(hashinfo);
+}
+EXPORT_SYMBOL_GPL(inet_pernet_hashinfo_free);