aboutsummaryrefslogtreecommitdiff
path: root/net/tls/tls_sw.c
diff options
context:
space:
mode:
Diffstat (limited to 'net/tls/tls_sw.c')
-rw-r--r--net/tls/tls_sw.c430
1 files changed, 302 insertions, 128 deletions
diff --git a/net/tls/tls_sw.c b/net/tls/tls_sw.c
index 86b9527c4826..1cc830582fa8 100644
--- a/net/tls/tls_sw.c
+++ b/net/tls/tls_sw.c
@@ -120,6 +120,34 @@ static int skb_nsg(struct sk_buff *skb, int offset, int len)
return __skb_nsg(skb, offset, len, 0);
}
+static int padding_length(struct tls_sw_context_rx *ctx,
+ struct tls_context *tls_ctx, struct sk_buff *skb)
+{
+ struct strp_msg *rxm = strp_msg(skb);
+ int sub = 0;
+
+ /* Determine zero-padding length */
+ if (tls_ctx->prot_info.version == TLS_1_3_VERSION) {
+ char content_type = 0;
+ int err;
+ int back = 17;
+
+ while (content_type == 0) {
+ if (back > rxm->full_len)
+ return -EBADMSG;
+ err = skb_copy_bits(skb,
+ rxm->offset + rxm->full_len - back,
+ &content_type, 1);
+ if (content_type)
+ break;
+ sub++;
+ back++;
+ }
+ ctx->control = content_type;
+ }
+ return sub;
+}
+
static void tls_decrypt_done(struct crypto_async_request *req, int err)
{
struct aead_request *aead_req = (struct aead_request *)req;
@@ -127,6 +155,7 @@ static void tls_decrypt_done(struct crypto_async_request *req, int err)
struct scatterlist *sgin = aead_req->src;
struct tls_sw_context_rx *ctx;
struct tls_context *tls_ctx;
+ struct tls_prot_info *prot;
struct scatterlist *sg;
struct sk_buff *skb;
unsigned int pages;
@@ -135,6 +164,7 @@ static void tls_decrypt_done(struct crypto_async_request *req, int err)
skb = (struct sk_buff *)req->data;
tls_ctx = tls_get_ctx(skb->sk);
ctx = tls_sw_ctx_rx(tls_ctx);
+ prot = &tls_ctx->prot_info;
/* Propagate if there was an err */
if (err) {
@@ -142,9 +172,9 @@ static void tls_decrypt_done(struct crypto_async_request *req, int err)
tls_err_abort(skb->sk, err);
} else {
struct strp_msg *rxm = strp_msg(skb);
-
- rxm->offset += tls_ctx->rx.prepend_size;
- rxm->full_len -= tls_ctx->rx.overhead_size;
+ rxm->full_len -= padding_length(ctx, tls_ctx, skb);
+ rxm->offset += prot->prepend_size;
+ rxm->full_len -= prot->overhead_size;
}
/* After using skb->sk to propagate sk through crypto async callback
@@ -181,13 +211,14 @@ static int tls_do_decryption(struct sock *sk,
bool async)
{
struct tls_context *tls_ctx = tls_get_ctx(sk);
+ struct tls_prot_info *prot = &tls_ctx->prot_info;
struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
int ret;
aead_request_set_tfm(aead_req, ctx->aead_recv);
- aead_request_set_ad(aead_req, TLS_AAD_SPACE_SIZE);
+ aead_request_set_ad(aead_req, prot->aad_size);
aead_request_set_crypt(aead_req, sgin, sgout,
- data_len + tls_ctx->rx.tag_size,
+ data_len + prot->tag_size,
(u8 *)iv_recv);
if (async) {
@@ -225,12 +256,13 @@ static int tls_do_decryption(struct sock *sk,
static void tls_trim_both_msgs(struct sock *sk, int target_size)
{
struct tls_context *tls_ctx = tls_get_ctx(sk);
+ struct tls_prot_info *prot = &tls_ctx->prot_info;
struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
struct tls_rec *rec = ctx->open_rec;
sk_msg_trim(sk, &rec->msg_plaintext, target_size);
if (target_size > 0)
- target_size += tls_ctx->tx.overhead_size;
+ target_size += prot->overhead_size;
sk_msg_trim(sk, &rec->msg_encrypted, target_size);
}
@@ -247,6 +279,7 @@ static int tls_alloc_encrypted_msg(struct sock *sk, int len)
static int tls_clone_plaintext_msg(struct sock *sk, int required)
{
struct tls_context *tls_ctx = tls_get_ctx(sk);
+ struct tls_prot_info *prot = &tls_ctx->prot_info;
struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
struct tls_rec *rec = ctx->open_rec;
struct sk_msg *msg_pl = &rec->msg_plaintext;
@@ -262,7 +295,7 @@ static int tls_clone_plaintext_msg(struct sock *sk, int required)
/* Skip initial bytes in msg_en's data to be able to use
* same offset of both plain and encrypted data.
*/
- skip = tls_ctx->tx.prepend_size + msg_pl->sg.size;
+ skip = prot->prepend_size + msg_pl->sg.size;
return sk_msg_clone(sk, msg_pl, msg_en, skip, len);
}
@@ -270,6 +303,7 @@ static int tls_clone_plaintext_msg(struct sock *sk, int required)
static struct tls_rec *tls_get_rec(struct sock *sk)
{
struct tls_context *tls_ctx = tls_get_ctx(sk);
+ struct tls_prot_info *prot = &tls_ctx->prot_info;
struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
struct sk_msg *msg_pl, *msg_en;
struct tls_rec *rec;
@@ -288,13 +322,11 @@ static struct tls_rec *tls_get_rec(struct sock *sk)
sk_msg_init(msg_en);
sg_init_table(rec->sg_aead_in, 2);
- sg_set_buf(&rec->sg_aead_in[0], rec->aad_space,
- sizeof(rec->aad_space));
+ sg_set_buf(&rec->sg_aead_in[0], rec->aad_space, prot->aad_size);
sg_unmark_end(&rec->sg_aead_in[1]);
sg_init_table(rec->sg_aead_out, 2);
- sg_set_buf(&rec->sg_aead_out[0], rec->aad_space,
- sizeof(rec->aad_space));
+ sg_set_buf(&rec->sg_aead_out[0], rec->aad_space, prot->aad_size);
sg_unmark_end(&rec->sg_aead_out[1]);
return rec;
@@ -383,6 +415,7 @@ static void tls_encrypt_done(struct crypto_async_request *req, int err)
struct aead_request *aead_req = (struct aead_request *)req;
struct sock *sk = req->data;
struct tls_context *tls_ctx = tls_get_ctx(sk);
+ struct tls_prot_info *prot = &tls_ctx->prot_info;
struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
struct scatterlist *sge;
struct sk_msg *msg_en;
@@ -394,8 +427,8 @@ static void tls_encrypt_done(struct crypto_async_request *req, int err)
msg_en = &rec->msg_encrypted;
sge = sk_msg_elem(msg_en, msg_en->sg.curr);
- sge->offset -= tls_ctx->tx.prepend_size;
- sge->length += tls_ctx->tx.prepend_size;
+ sge->offset -= prot->prepend_size;
+ sge->length += prot->prepend_size;
/* Check if error is previously set on socket */
if (err || sk->sk_err) {
@@ -442,21 +475,26 @@ static int tls_do_encryption(struct sock *sk,
struct aead_request *aead_req,
size_t data_len, u32 start)
{
+ struct tls_prot_info *prot = &tls_ctx->prot_info;
struct tls_rec *rec = ctx->open_rec;
struct sk_msg *msg_en = &rec->msg_encrypted;
struct scatterlist *sge = sk_msg_elem(msg_en, start);
int rc;
- sge->offset += tls_ctx->tx.prepend_size;
- sge->length -= tls_ctx->tx.prepend_size;
+ memcpy(rec->iv_data, tls_ctx->tx.iv, sizeof(rec->iv_data));
+ xor_iv_with_seq(prot->version, rec->iv_data,
+ tls_ctx->tx.rec_seq);
+
+ sge->offset += prot->prepend_size;
+ sge->length -= prot->prepend_size;
msg_en->sg.curr = start;
aead_request_set_tfm(aead_req, ctx->aead_send);
- aead_request_set_ad(aead_req, TLS_AAD_SPACE_SIZE);
+ aead_request_set_ad(aead_req, prot->aad_size);
aead_request_set_crypt(aead_req, rec->sg_aead_in,
rec->sg_aead_out,
- data_len, tls_ctx->tx.iv);
+ data_len, rec->iv_data);
aead_request_set_callback(aead_req, CRYPTO_TFM_REQ_MAY_BACKLOG,
tls_encrypt_done, sk);
@@ -468,8 +506,8 @@ static int tls_do_encryption(struct sock *sk,
rc = crypto_aead_encrypt(aead_req);
if (!rc || rc != -EINPROGRESS) {
atomic_dec(&ctx->encrypt_pending);
- sge->offset -= tls_ctx->tx.prepend_size;
- sge->length += tls_ctx->tx.prepend_size;
+ sge->offset -= prot->prepend_size;
+ sge->length += prot->prepend_size;
}
if (!rc) {
@@ -481,7 +519,7 @@ static int tls_do_encryption(struct sock *sk,
/* Unhook the record from context if encryption is not failure */
ctx->open_rec = NULL;
- tls_advance_record_sn(sk, &tls_ctx->tx);
+ tls_advance_record_sn(sk, &tls_ctx->tx, prot->version);
return rc;
}
@@ -607,6 +645,7 @@ static int tls_push_record(struct sock *sk, int flags,
unsigned char record_type)
{
struct tls_context *tls_ctx = tls_get_ctx(sk);
+ struct tls_prot_info *prot = &tls_ctx->prot_info;
struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
struct tls_rec *rec = ctx->open_rec, *tmp = NULL;
u32 i, split_point, uninitialized_var(orig_end);
@@ -625,12 +664,12 @@ static int tls_push_record(struct sock *sk, int flags,
split = split_point && split_point < msg_pl->sg.size;
if (split) {
rc = tls_split_open_record(sk, rec, &tmp, msg_pl, msg_en,
- split_point, tls_ctx->tx.overhead_size,
+ split_point, prot->overhead_size,
&orig_end);
if (rc < 0)
return rc;
sk_msg_trim(sk, msg_en, msg_pl->sg.size +
- tls_ctx->tx.overhead_size);
+ prot->overhead_size);
}
rec->tx_flags = flags;
@@ -638,7 +677,17 @@ static int tls_push_record(struct sock *sk, int flags,
i = msg_pl->sg.end;
sk_msg_iter_var_prev(i);
- sg_mark_end(sk_msg_elem(msg_pl, i));
+
+ rec->content_type = record_type;
+ if (prot->version == TLS_1_3_VERSION) {
+ /* Add content type to end of message. No padding added */
+ sg_set_buf(&rec->sg_content_type, &rec->content_type, 1);
+ sg_mark_end(&rec->sg_content_type);
+ sg_chain(msg_pl->sg.data, msg_pl->sg.end + 1,
+ &rec->sg_content_type);
+ } else {
+ sg_mark_end(sk_msg_elem(msg_pl, i));
+ }
i = msg_pl->sg.start;
sg_chain(rec->sg_aead_in, 2, rec->inplace_crypto ?
@@ -651,18 +700,20 @@ static int tls_push_record(struct sock *sk, int flags,
i = msg_en->sg.start;
sg_chain(rec->sg_aead_out, 2, &msg_en->sg.data[i]);
- tls_make_aad(rec->aad_space, msg_pl->sg.size,
- tls_ctx->tx.rec_seq, tls_ctx->tx.rec_seq_size,
- record_type);
+ tls_make_aad(rec->aad_space, msg_pl->sg.size + prot->tail_size,
+ tls_ctx->tx.rec_seq, prot->rec_seq_size,
+ record_type, prot->version);
tls_fill_prepend(tls_ctx,
page_address(sg_page(&msg_en->sg.data[i])) +
- msg_en->sg.data[i].offset, msg_pl->sg.size,
- record_type);
+ msg_en->sg.data[i].offset,
+ msg_pl->sg.size + prot->tail_size,
+ record_type, prot->version);
tls_ctx->pending_open_record_frags = false;
- rc = tls_do_encryption(sk, tls_ctx, ctx, req, msg_pl->sg.size, i);
+ rc = tls_do_encryption(sk, tls_ctx, ctx, req,
+ msg_pl->sg.size + prot->tail_size, i);
if (rc < 0) {
if (rc != -EINPROGRESS) {
tls_err_abort(sk, EBADMSG);
@@ -671,12 +722,12 @@ static int tls_push_record(struct sock *sk, int flags,
tls_merge_open_record(sk, rec, tmp, orig_end);
}
}
+ ctx->async_capable = 1;
return rc;
} else if (split) {
msg_pl = &tmp->msg_plaintext;
msg_en = &tmp->msg_encrypted;
- sk_msg_trim(sk, msg_en, msg_pl->sg.size +
- tls_ctx->tx.overhead_size);
+ sk_msg_trim(sk, msg_en, msg_pl->sg.size + prot->overhead_size);
tls_ctx->pending_open_record_frags = true;
ctx->open_rec = tmp;
}
@@ -811,9 +862,9 @@ int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
{
long timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
struct tls_context *tls_ctx = tls_get_ctx(sk);
+ struct tls_prot_info *prot = &tls_ctx->prot_info;
struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
- struct crypto_tfm *tfm = crypto_aead_tfm(ctx->aead_send);
- bool async_capable = tfm->__crt_alg->cra_flags & CRYPTO_ALG_ASYNC;
+ bool async_capable = ctx->async_capable;
unsigned char record_type = TLS_RECORD_TYPE_DATA;
bool is_kvec = iov_iter_is_kvec(&msg->msg_iter);
bool eor = !(msg->msg_flags & MSG_MORE);
@@ -878,7 +929,7 @@ int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
}
required_size = msg_pl->sg.size + try_to_copy +
- tls_ctx->tx.overhead_size;
+ prot->overhead_size;
if (!sk_stream_memory_free(sk))
goto wait_for_sndbuf;
@@ -947,8 +998,8 @@ fallback_to_reg_send:
*/
try_to_copy -= required_size - msg_pl->sg.size;
full_record = true;
- sk_msg_trim(sk, msg_en, msg_pl->sg.size +
- tls_ctx->tx.overhead_size);
+ sk_msg_trim(sk, msg_en,
+ msg_pl->sg.size + prot->overhead_size);
}
if (try_to_copy) {
@@ -1034,6 +1085,7 @@ static int tls_sw_do_sendpage(struct sock *sk, struct page *page,
long timeo = sock_sndtimeo(sk, flags & MSG_DONTWAIT);
struct tls_context *tls_ctx = tls_get_ctx(sk);
struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
+ struct tls_prot_info *prot = &tls_ctx->prot_info;
unsigned char record_type = TLS_RECORD_TYPE_DATA;
struct sk_msg *msg_pl;
struct tls_rec *rec;
@@ -1083,8 +1135,7 @@ static int tls_sw_do_sendpage(struct sock *sk, struct page *page,
full_record = true;
}
- required_size = msg_pl->sg.size + copy +
- tls_ctx->tx.overhead_size;
+ required_size = msg_pl->sg.size + copy + prot->overhead_size;
if (!sk_stream_memory_free(sk))
goto wait_for_sndbuf;
@@ -1283,6 +1334,7 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
{
struct tls_context *tls_ctx = tls_get_ctx(sk);
struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
+ struct tls_prot_info *prot = &tls_ctx->prot_info;
struct strp_msg *rxm = strp_msg(skb);
int n_sgin, n_sgout, nsg, mem_size, aead_size, err, pages = 0;
struct aead_request *aead_req;
@@ -1290,15 +1342,16 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
u8 *aad, *iv, *mem = NULL;
struct scatterlist *sgin = NULL;
struct scatterlist *sgout = NULL;
- const int data_len = rxm->full_len - tls_ctx->rx.overhead_size;
+ const int data_len = rxm->full_len - prot->overhead_size +
+ prot->tail_size;
if (*zc && (out_iov || out_sg)) {
if (out_iov)
n_sgout = iov_iter_npages(out_iov, INT_MAX) + 1;
else
n_sgout = sg_nents(out_sg);
- n_sgin = skb_nsg(skb, rxm->offset + tls_ctx->rx.prepend_size,
- rxm->full_len - tls_ctx->rx.prepend_size);
+ n_sgin = skb_nsg(skb, rxm->offset + prot->prepend_size,
+ rxm->full_len - prot->prepend_size);
} else {
n_sgout = 0;
*zc = false;
@@ -1315,7 +1368,7 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
aead_size = sizeof(*aead_req) + crypto_aead_reqsize(ctx->aead_recv);
mem_size = aead_size + (nsg * sizeof(struct scatterlist));
- mem_size = mem_size + TLS_AAD_SPACE_SIZE;
+ mem_size = mem_size + prot->aad_size;
mem_size = mem_size + crypto_aead_ivsize(ctx->aead_recv);
/* Allocate a single block of memory which contains
@@ -1331,29 +1384,35 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
sgin = (struct scatterlist *)(mem + aead_size);
sgout = sgin + n_sgin;
aad = (u8 *)(sgout + n_sgout);
- iv = aad + TLS_AAD_SPACE_SIZE;
+ iv = aad + prot->aad_size;
/* Prepare IV */
err = skb_copy_bits(skb, rxm->offset + TLS_HEADER_SIZE,
iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
- tls_ctx->rx.iv_size);
+ prot->iv_size);
if (err < 0) {
kfree(mem);
return err;
}
- memcpy(iv, tls_ctx->rx.iv, TLS_CIPHER_AES_GCM_128_SALT_SIZE);
+ if (prot->version == TLS_1_3_VERSION)
+ memcpy(iv, tls_ctx->rx.iv, crypto_aead_ivsize(ctx->aead_recv));
+ else
+ memcpy(iv, tls_ctx->rx.iv, TLS_CIPHER_AES_GCM_128_SALT_SIZE);
+
+ xor_iv_with_seq(prot->version, iv, tls_ctx->rx.rec_seq);
/* Prepare AAD */
- tls_make_aad(aad, rxm->full_len - tls_ctx->rx.overhead_size,
- tls_ctx->rx.rec_seq, tls_ctx->rx.rec_seq_size,
- ctx->control);
+ tls_make_aad(aad, rxm->full_len - prot->overhead_size +
+ prot->tail_size,
+ tls_ctx->rx.rec_seq, prot->rec_seq_size,
+ ctx->control, prot->version);
/* Prepare sgin */
sg_init_table(sgin, n_sgin);
- sg_set_buf(&sgin[0], aad, TLS_AAD_SPACE_SIZE);
+ sg_set_buf(&sgin[0], aad, prot->aad_size);
err = skb_to_sgvec(skb, &sgin[1],
- rxm->offset + tls_ctx->rx.prepend_size,
- rxm->full_len - tls_ctx->rx.prepend_size);
+ rxm->offset + prot->prepend_size,
+ rxm->full_len - prot->prepend_size);
if (err < 0) {
kfree(mem);
return err;
@@ -1362,7 +1421,7 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
if (n_sgout) {
if (out_iov) {
sg_init_table(sgout, n_sgout);
- sg_set_buf(&sgout[0], aad, TLS_AAD_SPACE_SIZE);
+ sg_set_buf(&sgout[0], aad, prot->aad_size);
*chunk = 0;
err = tls_setup_from_iter(sk, out_iov, data_len,
@@ -1403,6 +1462,8 @@ static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb,
{
struct tls_context *tls_ctx = tls_get_ctx(sk);
struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
+ struct tls_prot_info *prot = &tls_ctx->prot_info;
+ int version = prot->version;
struct strp_msg *rxm = strp_msg(skb);
int err = 0;
@@ -1415,20 +1476,23 @@ static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb,
err = decrypt_internal(sk, skb, dest, NULL, chunk, zc, async);
if (err < 0) {
if (err == -EINPROGRESS)
- tls_advance_record_sn(sk, &tls_ctx->rx);
+ tls_advance_record_sn(sk, &tls_ctx->rx,
+ version);
return err;
}
+
+ rxm->full_len -= padding_length(ctx, tls_ctx, skb);
+
+ rxm->offset += prot->prepend_size;
+ rxm->full_len -= prot->overhead_size;
+ tls_advance_record_sn(sk, &tls_ctx->rx, version);
+ ctx->decrypted = true;
+ ctx->saved_data_ready(sk);
} else {
*zc = false;
}
- rxm->offset += tls_ctx->rx.prepend_size;
- rxm->full_len -= tls_ctx->rx.overhead_size;
- tls_advance_record_sn(sk, &tls_ctx->rx);
- ctx->decrypted = true;
- ctx->saved_data_ready(sk);
-
return err;
}
@@ -1466,22 +1530,38 @@ static bool tls_sw_advance_skb(struct sock *sk, struct sk_buff *skb,
}
/* This function traverses the rx_list in tls receive context to copies the
- * decrypted data records into the buffer provided by caller zero copy is not
+ * decrypted records into the buffer provided by caller zero copy is not
* true. Further, the records are removed from the rx_list if it is not a peek
* case and the record has been consumed completely.
*/
static int process_rx_list(struct tls_sw_context_rx *ctx,
struct msghdr *msg,
+ u8 *control,
+ bool *cmsg,
size_t skip,
size_t len,
bool zc,
bool is_peek)
{
struct sk_buff *skb = skb_peek(&ctx->rx_list);
+ u8 ctrl = *control;
+ u8 msgc = *cmsg;
+ struct tls_msg *tlm;
ssize_t copied = 0;
+ /* Set the record type in 'control' if caller didn't pass it */
+ if (!ctrl && skb) {
+ tlm = tls_msg(skb);
+ ctrl = tlm->control;
+ }
+
while (skip && skb) {
struct strp_msg *rxm = strp_msg(skb);
+ tlm = tls_msg(skb);
+
+ /* Cannot process a record of different type */
+ if (ctrl != tlm->control)
+ return 0;
if (skip < rxm->full_len)
break;
@@ -1495,6 +1575,27 @@ static int process_rx_list(struct tls_sw_context_rx *ctx,
struct strp_msg *rxm = strp_msg(skb);
int chunk = min_t(unsigned int, rxm->full_len - skip, len);
+ tlm = tls_msg(skb);
+
+ /* Cannot process a record of different type */
+ if (ctrl != tlm->control)
+ return 0;
+
+ /* Set record type if not already done. For a non-data record,
+ * do not proceed if record type could not be copied.
+ */
+ if (!msgc) {
+ int cerr = put_cmsg(msg, SOL_TLS, TLS_GET_RECORD_TYPE,
+ sizeof(ctrl), &ctrl);
+ msgc = true;
+ if (ctrl != TLS_RECORD_TYPE_DATA) {
+ if (cerr || msg->msg_flags & MSG_CTRUNC)
+ return -EIO;
+
+ *cmsg = msgc;
+ }
+ }
+
if (!zc || (rxm->full_len - skip) > len) {
int err = skb_copy_datagram_msg(skb, rxm->offset + skip,
msg, chunk);
@@ -1533,6 +1634,7 @@ static int process_rx_list(struct tls_sw_context_rx *ctx,
skb = next_skb;
}
+ *control = ctrl;
return copied;
}
@@ -1545,10 +1647,12 @@ int tls_sw_recvmsg(struct sock *sk,
{
struct tls_context *tls_ctx = tls_get_ctx(sk);
struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
+ struct tls_prot_info *prot = &tls_ctx->prot_info;
struct sk_psock *psock;
unsigned char control = 0;
ssize_t decrypted = 0;
struct strp_msg *rxm;
+ struct tls_msg *tlm;
struct sk_buff *skb;
ssize_t copied = 0;
bool cmsg = false;
@@ -1567,7 +1671,8 @@ int tls_sw_recvmsg(struct sock *sk,
lock_sock(sk);
/* Process pending decrypted records. It must be non-zero-copy */
- err = process_rx_list(ctx, msg, 0, len, false, is_peek);
+ err = process_rx_list(ctx, msg, &control, &cmsg, 0, len, false,
+ is_peek);
if (err < 0) {
tls_err_abort(sk, err);
goto end;
@@ -1585,10 +1690,10 @@ int tls_sw_recvmsg(struct sock *sk,
do {
bool retain_skb = false;
- bool async = false;
bool zc = false;
int to_decrypt;
int chunk = 0;
+ bool async;
skb = tls_wait_data(sk, psock, flags, timeo, &err);
if (!skb) {
@@ -1603,61 +1708,87 @@ int tls_sw_recvmsg(struct sock *sk,
}
}
goto recv_end;
+ } else {
+ tlm = tls_msg(skb);
+ if (prot->version == TLS_1_3_VERSION)
+ tlm->control = 0;
+ else
+ tlm->control = ctx->control;
}
rxm = strp_msg(skb);
+ to_decrypt = rxm->full_len - prot->overhead_size;
+
+ if (to_decrypt <= len && !is_kvec && !is_peek &&
+ ctx->control == TLS_RECORD_TYPE_DATA &&
+ prot->version != TLS_1_3_VERSION)
+ zc = true;
+
+ /* Do not use async mode if record is non-data */
+ if (ctx->control == TLS_RECORD_TYPE_DATA)
+ async = ctx->async_capable;
+ else
+ async = false;
+
+ err = decrypt_skb_update(sk, skb, &msg->msg_iter,
+ &chunk, &zc, async);
+ if (err < 0 && err != -EINPROGRESS) {
+ tls_err_abort(sk, EBADMSG);
+ goto recv_end;
+ }
+
+ if (err == -EINPROGRESS)
+ num_async++;
+ else if (prot->version == TLS_1_3_VERSION)
+ tlm->control = ctx->control;
+
+ /* If the type of records being processed is not known yet,
+ * set it to record type just dequeued. If it is already known,
+ * but does not match the record type just dequeued, go to end.
+ * We always get record type here since for tls1.2, record type
+ * is known just after record is dequeued from stream parser.
+ * For tls1.3, we disable async.
+ */
+
+ if (!control)
+ control = tlm->control;
+ else if (control != tlm->control)
+ goto recv_end;
+
if (!cmsg) {
int cerr;
cerr = put_cmsg(msg, SOL_TLS, TLS_GET_RECORD_TYPE,
- sizeof(ctx->control), &ctx->control);
+ sizeof(control), &control);
cmsg = true;
- control = ctx->control;
- if (ctx->control != TLS_RECORD_TYPE_DATA) {
+ if (control != TLS_RECORD_TYPE_DATA) {
if (cerr || msg->msg_flags & MSG_CTRUNC) {
err = -EIO;
goto recv_end;
}
}
- } else if (control != ctx->control) {
- goto recv_end;
- }
-
- to_decrypt = rxm->full_len - tls_ctx->rx.overhead_size;
-
- if (to_decrypt <= len && !is_kvec && !is_peek)
- zc = true;
-
- err = decrypt_skb_update(sk, skb, &msg->msg_iter,
- &chunk, &zc, ctx->async_capable);
- if (err < 0 && err != -EINPROGRESS) {
- tls_err_abort(sk, EBADMSG);
- goto recv_end;
}
- if (err == -EINPROGRESS) {
- async = true;
- num_async++;
+ if (async)
goto pick_next_record;
- } else {
- if (!zc) {
- if (rxm->full_len > len) {
- retain_skb = true;
- chunk = len;
- } else {
- chunk = rxm->full_len;
- }
- err = skb_copy_datagram_msg(skb, rxm->offset,
- msg, chunk);
- if (err < 0)
- goto recv_end;
+ if (!zc) {
+ if (rxm->full_len > len) {
+ retain_skb = true;
+ chunk = len;
+ } else {
+ chunk = rxm->full_len;
+ }
- if (!is_peek) {
- rxm->offset = rxm->offset + chunk;
- rxm->full_len = rxm->full_len - chunk;
- }
+ err = skb_copy_datagram_msg(skb, rxm->offset,
+ msg, chunk);
+ if (err < 0)
+ goto recv_end;
+
+ if (!is_peek) {
+ rxm->offset = rxm->offset + chunk;
+ rxm->full_len = rxm->full_len - chunk;
}
}
@@ -1711,18 +1842,16 @@ recv_end:
/* Drain records from the rx_list & copy if required */
if (is_peek || is_kvec)
- err = process_rx_list(ctx, msg, copied,
+ err = process_rx_list(ctx, msg, &control, &cmsg, copied,
decrypted, false, is_peek);
else
- err = process_rx_list(ctx, msg, 0,
+ err = process_rx_list(ctx, msg, &control, &cmsg, 0,
decrypted, true, is_peek);
if (err < 0) {
tls_err_abort(sk, err);
copied = 0;
goto end;
}
-
- WARN_ON(decrypted != err);
}
copied += decrypted;
@@ -1757,15 +1886,15 @@ ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos,
if (!skb)
goto splice_read_end;
- /* splice does not support reading control messages */
- if (ctx->control != TLS_RECORD_TYPE_DATA) {
- err = -ENOTSUPP;
- goto splice_read_end;
- }
-
if (!ctx->decrypted) {
err = decrypt_skb_update(sk, skb, NULL, &chunk, &zc, false);
+ /* splice does not support reading control messages */
+ if (ctx->control != TLS_RECORD_TYPE_DATA) {
+ err = -ENOTSUPP;
+ goto splice_read_end;
+ }
+
if (err < 0) {
tls_err_abort(sk, EBADMSG);
goto splice_read_end;
@@ -1807,6 +1936,7 @@ static int tls_read_size(struct strparser *strp, struct sk_buff *skb)
{
struct tls_context *tls_ctx = tls_get_ctx(strp->sk);
struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
+ struct tls_prot_info *prot = &tls_ctx->prot_info;
char header[TLS_HEADER_SIZE + MAX_IV_SIZE];
struct strp_msg *rxm = strp_msg(skb);
size_t cipher_overhead;
@@ -1814,17 +1944,17 @@ static int tls_read_size(struct strparser *strp, struct sk_buff *skb)
int ret;
/* Verify that we have a full TLS header, or wait for more data */
- if (rxm->offset + tls_ctx->rx.prepend_size > skb->len)
+ if (rxm->offset + prot->prepend_size > skb->len)
return 0;
/* Sanity-check size of on-stack buffer. */
- if (WARN_ON(tls_ctx->rx.prepend_size > sizeof(header))) {
+ if (WARN_ON(prot->prepend_size > sizeof(header))) {
ret = -EINVAL;
goto read_failure;
}
/* Linearize header to local buffer */
- ret = skb_copy_bits(skb, rxm->offset, header, tls_ctx->rx.prepend_size);
+ ret = skb_copy_bits(skb, rxm->offset, header, prot->prepend_size);
if (ret < 0)
goto read_failure;
@@ -1833,9 +1963,12 @@ static int tls_read_size(struct strparser *strp, struct sk_buff *skb)
data_len = ((header[4] & 0xFF) | (header[3] << 8));
- cipher_overhead = tls_ctx->rx.tag_size + tls_ctx->rx.iv_size;
+ cipher_overhead = prot->tag_size;
+ if (prot->version != TLS_1_3_VERSION)
+ cipher_overhead += prot->iv_size;
- if (data_len > TLS_MAX_PAYLOAD_SIZE + cipher_overhead) {
+ if (data_len > TLS_MAX_PAYLOAD_SIZE + cipher_overhead +
+ prot->tail_size) {
ret = -EMSGSIZE;
goto read_failure;
}
@@ -1844,12 +1977,12 @@ static int tls_read_size(struct strparser *strp, struct sk_buff *skb)
goto read_failure;
}
- if (header[1] != TLS_VERSION_MINOR(tls_ctx->crypto_recv.info.version) ||
- header[2] != TLS_VERSION_MAJOR(tls_ctx->crypto_recv.info.version)) {
+ /* Note that both TLS1.3 and TLS1.2 use TLS_1_2 version here */
+ if (header[1] != TLS_1_2_VERSION_MINOR ||
+ header[2] != TLS_1_2_VERSION_MAJOR) {
ret = -EINVAL;
goto read_failure;
}
-
#ifdef CONFIG_TLS_DEVICE
handle_device_resync(strp->sk, TCP_SKB_CB(skb)->seq + rxm->offset,
*(u64*)tls_ctx->rx.rec_seq);
@@ -1901,7 +2034,9 @@ void tls_sw_free_resources_tx(struct sock *sk)
if (atomic_read(&ctx->encrypt_pending))
crypto_wait_req(-EINPROGRESS, &ctx->async_wait);
+ release_sock(sk);
cancel_delayed_work_sync(&ctx->tx_work.work);
+ lock_sock(sk);
/* Tx whatever records we can transmit and abandon the rest */
tls_tx_records(sk, -1);
@@ -1993,8 +2128,11 @@ static void tx_work_handler(struct work_struct *work)
int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
{
+ struct tls_context *tls_ctx = tls_get_ctx(sk);
+ struct tls_prot_info *prot = &tls_ctx->prot_info;
struct tls_crypto_info *crypto_info;
struct tls12_crypto_info_aes_gcm_128 *gcm_128_info;
+ struct tls12_crypto_info_aes_gcm_256 *gcm_256_info;
struct tls_sw_context_tx *sw_ctx_tx = NULL;
struct tls_sw_context_rx *sw_ctx_rx = NULL;
struct cipher_context *cctx;
@@ -2002,7 +2140,8 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
struct strp_callbacks cb;
u16 nonce_size, tag_size, iv_size, rec_seq_size;
struct crypto_tfm *tfm;
- char *iv, *rec_seq;
+ char *iv, *rec_seq, *key, *salt;
+ size_t keysize;
int rc = 0;
if (!ctx) {
@@ -2063,6 +2202,24 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->rec_seq;
gcm_128_info =
(struct tls12_crypto_info_aes_gcm_128 *)crypto_info;
+ keysize = TLS_CIPHER_AES_GCM_128_KEY_SIZE;
+ key = gcm_128_info->key;
+ salt = gcm_128_info->salt;
+ break;
+ }
+ case TLS_CIPHER_AES_GCM_256: {
+ nonce_size = TLS_CIPHER_AES_GCM_256_IV_SIZE;
+ tag_size = TLS_CIPHER_AES_GCM_256_TAG_SIZE;
+ iv_size = TLS_CIPHER_AES_GCM_256_IV_SIZE;
+ iv = ((struct tls12_crypto_info_aes_gcm_256 *)crypto_info)->iv;
+ rec_seq_size = TLS_CIPHER_AES_GCM_256_REC_SEQ_SIZE;
+ rec_seq =
+ ((struct tls12_crypto_info_aes_gcm_256 *)crypto_info)->rec_seq;
+ gcm_256_info =
+ (struct tls12_crypto_info_aes_gcm_256 *)crypto_info;
+ keysize = TLS_CIPHER_AES_GCM_256_KEY_SIZE;
+ key = gcm_256_info->key;
+ salt = gcm_256_info->salt;
break;
}
default:
@@ -2076,19 +2233,32 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
goto free_priv;
}
- cctx->prepend_size = TLS_HEADER_SIZE + nonce_size;
- cctx->tag_size = tag_size;
- cctx->overhead_size = cctx->prepend_size + cctx->tag_size;
- cctx->iv_size = iv_size;
+ if (crypto_info->version == TLS_1_3_VERSION) {
+ nonce_size = 0;
+ prot->aad_size = TLS_HEADER_SIZE;
+ prot->tail_size = 1;
+ } else {
+ prot->aad_size = TLS_AAD_SPACE_SIZE;
+ prot->tail_size = 0;
+ }
+
+ prot->version = crypto_info->version;
+ prot->cipher_type = crypto_info->cipher_type;
+ prot->prepend_size = TLS_HEADER_SIZE + nonce_size;
+ prot->tag_size = tag_size;
+ prot->overhead_size = prot->prepend_size +
+ prot->tag_size + prot->tail_size;
+ prot->iv_size = iv_size;
cctx->iv = kmalloc(iv_size + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
GFP_KERNEL);
if (!cctx->iv) {
rc = -ENOMEM;
goto free_priv;
}
- memcpy(cctx->iv, gcm_128_info->salt, TLS_CIPHER_AES_GCM_128_SALT_SIZE);
+ /* Note: 128 & 256 bit salt are the same size */
+ memcpy(cctx->iv, salt, TLS_CIPHER_AES_GCM_128_SALT_SIZE);
memcpy(cctx->iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, iv, iv_size);
- cctx->rec_seq_size = rec_seq_size;
+ prot->rec_seq_size = rec_seq_size;
cctx->rec_seq = kmemdup(rec_seq, rec_seq_size, GFP_KERNEL);
if (!cctx->rec_seq) {
rc = -ENOMEM;
@@ -2106,19 +2276,23 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
ctx->push_pending_record = tls_sw_push_pending_record;
- rc = crypto_aead_setkey(*aead, gcm_128_info->key,
- TLS_CIPHER_AES_GCM_128_KEY_SIZE);
+ rc = crypto_aead_setkey(*aead, key, keysize);
+
if (rc)
goto free_aead;
- rc = crypto_aead_setauthsize(*aead, cctx->tag_size);
+ rc = crypto_aead_setauthsize(*aead, prot->tag_size);
if (rc)
goto free_aead;
if (sw_ctx_rx) {
tfm = crypto_aead_tfm(sw_ctx_rx->aead_recv);
- sw_ctx_rx->async_capable =
- tfm->__crt_alg->cra_flags & CRYPTO_ALG_ASYNC;
+
+ if (crypto_info->version == TLS_1_3_VERSION)
+ sw_ctx_rx->async_capable = false;
+ else
+ sw_ctx_rx->async_capable =
+ tfm->__crt_alg->cra_flags & CRYPTO_ALG_ASYNC;
/* Set up strparser */
memset(&cb, 0, sizeof(cb));