diff options
Diffstat (limited to 'net/xfrm/xfrm_state.c')
-rw-r--r-- | net/xfrm/xfrm_state.c | 171 |
1 files changed, 160 insertions, 11 deletions
diff --git a/net/xfrm/xfrm_state.c b/net/xfrm/xfrm_state.c index 37478d36a8df..67ca7ac955a3 100644 --- a/net/xfrm/xfrm_state.c +++ b/net/xfrm/xfrm_state.c @@ -665,6 +665,7 @@ struct xfrm_state *xfrm_state_alloc(struct net *net) refcount_set(&x->refcnt, 1); atomic_set(&x->tunnel_users, 0); INIT_LIST_HEAD(&x->km.all); + INIT_HLIST_NODE(&x->state_cache); INIT_HLIST_NODE(&x->bydst); INIT_HLIST_NODE(&x->bysrc); INIT_HLIST_NODE(&x->byspi); @@ -679,6 +680,7 @@ struct xfrm_state *xfrm_state_alloc(struct net *net) x->lft.hard_packet_limit = XFRM_INF; x->replay_maxage = 0; x->replay_maxdiff = 0; + x->pcpu_num = UINT_MAX; spin_lock_init(&x->lock); } return x; @@ -743,12 +745,18 @@ int __xfrm_state_delete(struct xfrm_state *x) if (x->km.state != XFRM_STATE_DEAD) { x->km.state = XFRM_STATE_DEAD; + spin_lock(&net->xfrm.xfrm_state_lock); list_del(&x->km.all); hlist_del_rcu(&x->bydst); hlist_del_rcu(&x->bysrc); if (x->km.seq) hlist_del_rcu(&x->byseq); + if (!hlist_unhashed(&x->state_cache)) + hlist_del_rcu(&x->state_cache); + if (!hlist_unhashed(&x->state_cache_input)) + hlist_del_rcu(&x->state_cache_input); + if (x->id.spi) hlist_del_rcu(&x->byspi); net->xfrm.state_num--; @@ -1101,6 +1109,52 @@ static struct xfrm_state *__xfrm_state_lookup(struct net *net, u32 mark, return NULL; } +struct xfrm_state *xfrm_input_state_lookup(struct net *net, u32 mark, + const xfrm_address_t *daddr, + __be32 spi, u8 proto, + unsigned short family) +{ + struct hlist_head *state_cache_input; + struct xfrm_state *x = NULL; + int cpu = get_cpu(); + + state_cache_input = per_cpu_ptr(net->xfrm.state_cache_input, cpu); + + rcu_read_lock(); + hlist_for_each_entry_rcu(x, state_cache_input, state_cache_input) { + if (x->props.family != family || + x->id.spi != spi || + x->id.proto != proto || + !xfrm_addr_equal(&x->id.daddr, daddr, family)) + continue; + + if ((mark & x->mark.m) != x->mark.v) + continue; + if (!xfrm_state_hold_rcu(x)) + continue; + goto out; + } + + x = __xfrm_state_lookup(net, mark, daddr, spi, proto, family); + + if (x && x->km.state == XFRM_STATE_VALID) { + spin_lock_bh(&net->xfrm.xfrm_state_lock); + if (hlist_unhashed(&x->state_cache_input)) { + hlist_add_head_rcu(&x->state_cache_input, state_cache_input); + } else { + hlist_del_rcu(&x->state_cache_input); + hlist_add_head_rcu(&x->state_cache_input, state_cache_input); + } + spin_unlock_bh(&net->xfrm.xfrm_state_lock); + } + +out: + rcu_read_unlock(); + put_cpu(); + return x; +} +EXPORT_SYMBOL(xfrm_input_state_lookup); + static struct xfrm_state *__xfrm_state_lookup_byaddr(struct net *net, u32 mark, const xfrm_address_t *daddr, const xfrm_address_t *saddr, @@ -1155,6 +1209,12 @@ static void xfrm_state_look_at(struct xfrm_policy *pol, struct xfrm_state *x, struct xfrm_state **best, int *acq_in_progress, int *error) { + /* We need the cpu id just as a lookup key, + * we don't require it to be stable. + */ + unsigned int pcpu_id = get_cpu(); + put_cpu(); + /* Resolution logic: * 1. There is a valid state with matching selector. Done. * 2. Valid state with inappropriate selector. Skip. @@ -1174,13 +1234,18 @@ static void xfrm_state_look_at(struct xfrm_policy *pol, struct xfrm_state *x, &fl->u.__fl_common)) return; + if (x->pcpu_num != UINT_MAX && x->pcpu_num != pcpu_id) + return; + if (!*best || + ((*best)->pcpu_num == UINT_MAX && x->pcpu_num == pcpu_id) || (*best)->km.dying > x->km.dying || ((*best)->km.dying == x->km.dying && (*best)->curlft.add_time < x->curlft.add_time)) *best = x; } else if (x->km.state == XFRM_STATE_ACQ) { - *acq_in_progress = 1; + if (!*best || x->pcpu_num == pcpu_id) + *acq_in_progress = 1; } else if (x->km.state == XFRM_STATE_ERROR || x->km.state == XFRM_STATE_EXPIRED) { if ((!x->sel.family || @@ -1209,12 +1274,60 @@ xfrm_state_find(const xfrm_address_t *daddr, const xfrm_address_t *saddr, unsigned short encap_family = tmpl->encap_family; unsigned int sequence; struct km_event c; + unsigned int pcpu_id; + bool cached = false; + + /* We need the cpu id just as a lookup key, + * we don't require it to be stable. + */ + pcpu_id = get_cpu(); + put_cpu(); to_put = NULL; sequence = read_seqcount_begin(&net->xfrm.xfrm_state_hash_generation); rcu_read_lock(); + hlist_for_each_entry_rcu(x, &pol->state_cache_list, state_cache) { + if (x->props.family == encap_family && + x->props.reqid == tmpl->reqid && + (mark & x->mark.m) == x->mark.v && + x->if_id == if_id && + !(x->props.flags & XFRM_STATE_WILDRECV) && + xfrm_state_addr_check(x, daddr, saddr, encap_family) && + tmpl->mode == x->props.mode && + tmpl->id.proto == x->id.proto && + (tmpl->id.spi == x->id.spi || !tmpl->id.spi)) + xfrm_state_look_at(pol, x, fl, encap_family, + &best, &acquire_in_progress, &error); + } + + if (best) + goto cached; + + hlist_for_each_entry_rcu(x, &pol->state_cache_list, state_cache) { + if (x->props.family == encap_family && + x->props.reqid == tmpl->reqid && + (mark & x->mark.m) == x->mark.v && + x->if_id == if_id && + !(x->props.flags & XFRM_STATE_WILDRECV) && + xfrm_addr_equal(&x->id.daddr, daddr, encap_family) && + tmpl->mode == x->props.mode && + tmpl->id.proto == x->id.proto && + (tmpl->id.spi == x->id.spi || !tmpl->id.spi)) + xfrm_state_look_at(pol, x, fl, family, + &best, &acquire_in_progress, &error); + } + +cached: + cached = true; + if (best) + goto found; + else if (error) + best = NULL; + else if (acquire_in_progress) /* XXX: acquire_in_progress should not happen */ + WARN_ON(1); + h = xfrm_dst_hash(net, daddr, saddr, tmpl->reqid, encap_family); hlist_for_each_entry_rcu(x, net->xfrm.state_bydst + h, bydst) { #ifdef CONFIG_XFRM_OFFLOAD @@ -1282,7 +1395,10 @@ xfrm_state_find(const xfrm_address_t *daddr, const xfrm_address_t *saddr, } found: - x = best; + if (!(pol->flags & XFRM_POLICY_CPU_ACQUIRE) || + (best && (best->pcpu_num == pcpu_id))) + x = best; + if (!x && !error && !acquire_in_progress) { if (tmpl->id.spi && (x0 = __xfrm_state_lookup_all(net, mark, daddr, @@ -1314,6 +1430,8 @@ found: xfrm_init_tempstate(x, fl, tmpl, daddr, saddr, family); memcpy(&x->mark, &pol->mark, sizeof(x->mark)); x->if_id = if_id; + if ((pol->flags & XFRM_POLICY_CPU_ACQUIRE) && best) + x->pcpu_num = pcpu_id; error = security_xfrm_state_alloc_acquire(x, pol->security, fl->flowi_secid); if (error) { @@ -1352,6 +1470,7 @@ found: x->km.state = XFRM_STATE_ACQ; x->dir = XFRM_SA_DIR_OUT; list_add(&x->km.all, &net->xfrm.state_all); + h = xfrm_dst_hash(net, daddr, saddr, tmpl->reqid, encap_family); XFRM_STATE_INSERT(bydst, &x->bydst, net->xfrm.state_bydst + h, x->xso.type); @@ -1359,6 +1478,7 @@ found: XFRM_STATE_INSERT(bysrc, &x->bysrc, net->xfrm.state_bysrc + h, x->xso.type); + INIT_HLIST_NODE(&x->state_cache); if (x->id.spi) { h = xfrm_spi_hash(net, &x->id.daddr, x->id.spi, x->id.proto, encap_family); XFRM_STATE_INSERT(byspi, &x->byspi, @@ -1392,6 +1512,11 @@ found: x = NULL; error = -ESRCH; } + + /* Use the already installed 'fallback' while the CPU-specific + * SA acquire is handled*/ + if (best) + x = best; } out: if (x) { @@ -1402,6 +1527,15 @@ out: } else { *err = acquire_in_progress ? -EAGAIN : error; } + + if (x && x->km.state == XFRM_STATE_VALID && !cached && + (!(pol->flags & XFRM_POLICY_CPU_ACQUIRE) || x->pcpu_num == pcpu_id)) { + spin_lock_bh(&net->xfrm.xfrm_state_lock); + if (hlist_unhashed(&x->state_cache)) + hlist_add_head_rcu(&x->state_cache, &pol->state_cache_list); + spin_unlock_bh(&net->xfrm.xfrm_state_lock); + } + rcu_read_unlock(); if (to_put) xfrm_state_put(to_put); @@ -1524,12 +1658,14 @@ static void __xfrm_state_bump_genids(struct xfrm_state *xnew) unsigned int h; u32 mark = xnew->mark.v & xnew->mark.m; u32 if_id = xnew->if_id; + u32 cpu_id = xnew->pcpu_num; h = xfrm_dst_hash(net, &xnew->id.daddr, &xnew->props.saddr, reqid, family); hlist_for_each_entry(x, net->xfrm.state_bydst+h, bydst) { if (x->props.family == family && x->props.reqid == reqid && x->if_id == if_id && + x->pcpu_num == cpu_id && (mark & x->mark.m) == x->mark.v && xfrm_addr_equal(&x->id.daddr, &xnew->id.daddr, family) && xfrm_addr_equal(&x->props.saddr, &xnew->props.saddr, family)) @@ -1552,7 +1688,7 @@ EXPORT_SYMBOL(xfrm_state_insert); static struct xfrm_state *__find_acq_core(struct net *net, const struct xfrm_mark *m, unsigned short family, u8 mode, - u32 reqid, u32 if_id, u8 proto, + u32 reqid, u32 if_id, u32 pcpu_num, u8 proto, const xfrm_address_t *daddr, const xfrm_address_t *saddr, int create) @@ -1569,6 +1705,7 @@ static struct xfrm_state *__find_acq_core(struct net *net, x->id.spi != 0 || x->id.proto != proto || (mark & x->mark.m) != x->mark.v || + x->pcpu_num != pcpu_num || !xfrm_addr_equal(&x->id.daddr, daddr, family) || !xfrm_addr_equal(&x->props.saddr, saddr, family)) continue; @@ -1602,6 +1739,7 @@ static struct xfrm_state *__find_acq_core(struct net *net, break; } + x->pcpu_num = pcpu_num; x->km.state = XFRM_STATE_ACQ; x->id.proto = proto; x->props.family = family; @@ -1630,7 +1768,7 @@ static struct xfrm_state *__find_acq_core(struct net *net, return x; } -static struct xfrm_state *__xfrm_find_acq_byseq(struct net *net, u32 mark, u32 seq); +static struct xfrm_state *__xfrm_find_acq_byseq(struct net *net, u32 mark, u32 seq, u32 pcpu_num); int xfrm_state_add(struct xfrm_state *x) { @@ -1656,7 +1794,7 @@ int xfrm_state_add(struct xfrm_state *x) } if (use_spi && x->km.seq) { - x1 = __xfrm_find_acq_byseq(net, mark, x->km.seq); + x1 = __xfrm_find_acq_byseq(net, mark, x->km.seq, x->pcpu_num); if (x1 && ((x1->id.proto != x->id.proto) || !xfrm_addr_equal(&x1->id.daddr, &x->id.daddr, family))) { to_put = x1; @@ -1666,7 +1804,7 @@ int xfrm_state_add(struct xfrm_state *x) if (use_spi && !x1) x1 = __find_acq_core(net, &x->mark, family, x->props.mode, - x->props.reqid, x->if_id, x->id.proto, + x->props.reqid, x->if_id, x->pcpu_num, x->id.proto, &x->id.daddr, &x->props.saddr, 0); __xfrm_state_bump_genids(x); @@ -1791,6 +1929,7 @@ static struct xfrm_state *xfrm_state_clone(struct xfrm_state *orig, x->props.flags = orig->props.flags; x->props.extra_flags = orig->props.extra_flags; + x->pcpu_num = orig->pcpu_num; x->if_id = orig->if_id; x->tfcpad = orig->tfcpad; x->replay_maxdiff = orig->replay_maxdiff; @@ -2066,13 +2205,14 @@ EXPORT_SYMBOL(xfrm_state_lookup_byaddr); struct xfrm_state * xfrm_find_acq(struct net *net, const struct xfrm_mark *mark, u8 mode, u32 reqid, - u32 if_id, u8 proto, const xfrm_address_t *daddr, + u32 if_id, u32 pcpu_num, u8 proto, const xfrm_address_t *daddr, const xfrm_address_t *saddr, int create, unsigned short family) { struct xfrm_state *x; spin_lock_bh(&net->xfrm.xfrm_state_lock); - x = __find_acq_core(net, mark, family, mode, reqid, if_id, proto, daddr, saddr, create); + x = __find_acq_core(net, mark, family, mode, reqid, if_id, pcpu_num, + proto, daddr, saddr, create); spin_unlock_bh(&net->xfrm.xfrm_state_lock); return x; @@ -2207,7 +2347,7 @@ xfrm_state_sort(struct xfrm_state **dst, struct xfrm_state **src, int n, /* Silly enough, but I'm lazy to build resolution list */ -static struct xfrm_state *__xfrm_find_acq_byseq(struct net *net, u32 mark, u32 seq) +static struct xfrm_state *__xfrm_find_acq_byseq(struct net *net, u32 mark, u32 seq, u32 pcpu_num) { unsigned int h = xfrm_seq_hash(net, seq); struct xfrm_state *x; @@ -2215,6 +2355,7 @@ static struct xfrm_state *__xfrm_find_acq_byseq(struct net *net, u32 mark, u32 s hlist_for_each_entry_rcu(x, net->xfrm.state_byseq + h, byseq) { if (x->km.seq == seq && (mark & x->mark.m) == x->mark.v && + x->pcpu_num == pcpu_num && x->km.state == XFRM_STATE_ACQ) { xfrm_state_hold(x); return x; @@ -2224,12 +2365,12 @@ static struct xfrm_state *__xfrm_find_acq_byseq(struct net *net, u32 mark, u32 s return NULL; } -struct xfrm_state *xfrm_find_acq_byseq(struct net *net, u32 mark, u32 seq) +struct xfrm_state *xfrm_find_acq_byseq(struct net *net, u32 mark, u32 seq, u32 pcpu_num) { struct xfrm_state *x; spin_lock_bh(&net->xfrm.xfrm_state_lock); - x = __xfrm_find_acq_byseq(net, mark, seq); + x = __xfrm_find_acq_byseq(net, mark, seq, pcpu_num); spin_unlock_bh(&net->xfrm.xfrm_state_lock); return x; } @@ -2988,6 +3129,11 @@ int __net_init xfrm_state_init(struct net *net) net->xfrm.state_byseq = xfrm_hash_alloc(sz); if (!net->xfrm.state_byseq) goto out_byseq; + + net->xfrm.state_cache_input = alloc_percpu(struct hlist_head); + if (!net->xfrm.state_cache_input) + goto out_state_cache_input; + net->xfrm.state_hmask = ((sz / sizeof(struct hlist_head)) - 1); net->xfrm.state_num = 0; @@ -2997,6 +3143,8 @@ int __net_init xfrm_state_init(struct net *net) &net->xfrm.xfrm_state_lock); return 0; +out_state_cache_input: + xfrm_hash_free(net->xfrm.state_byseq, sz); out_byseq: xfrm_hash_free(net->xfrm.state_byspi, sz); out_byspi: @@ -3026,6 +3174,7 @@ void xfrm_state_fini(struct net *net) xfrm_hash_free(net->xfrm.state_bysrc, sz); WARN_ON(!hlist_empty(net->xfrm.state_bydst)); xfrm_hash_free(net->xfrm.state_bydst, sz); + free_percpu(net->xfrm.state_cache_input); } #ifdef CONFIG_AUDITSYSCALL |