smb3: Avoid Mid pending list corruption
authorRohith Surabattula <rohiths@microsoft.com>
Thu, 29 Oct 2020 05:03:10 +0000 (05:03 +0000)
committerSteve French <stfrench@microsoft.com>
Mon, 16 Nov 2020 05:05:33 +0000 (23:05 -0600)
When reconnect happens Mid queue can be corrupted when both
demultiplex and offload thread try to dequeue the MID from the
pending list.

These patches address a problem found during decryption offload:
         CIFS: VFS: trying to dequeue a deleted mid
that could cause a refcount use after free:
         Workqueue: smb3decryptd smb2_decrypt_offload [cifs]

Signed-off-by: Rohith Surabattula <rohiths@microsoft.com>
Reviewed-by: Pavel Shilovsky <pshilov@microsoft.com>
CC: Stable <stable@vger.kernel.org> #5.4+
Signed-off-by: Steve French <stfrench@microsoft.com>
fs/cifs/smb2ops.c

index efedec2..b3b2abb 100644 (file)
@@ -264,7 +264,7 @@ smb2_revert_current_mid(struct TCP_Server_Info *server, const unsigned int val)
 }
 
 static struct mid_q_entry *
-smb2_find_mid(struct TCP_Server_Info *server, char *buf)
+__smb2_find_mid(struct TCP_Server_Info *server, char *buf, bool dequeue)
 {
        struct mid_q_entry *mid;
        struct smb2_sync_hdr *shdr = (struct smb2_sync_hdr *)buf;
@@ -281,6 +281,10 @@ smb2_find_mid(struct TCP_Server_Info *server, char *buf)
                    (mid->mid_state == MID_REQUEST_SUBMITTED) &&
                    (mid->command == shdr->Command)) {
                        kref_get(&mid->refcount);
+                       if (dequeue) {
+                               list_del_init(&mid->qhead);
+                               mid->mid_flags |= MID_DELETED;
+                       }
                        spin_unlock(&GlobalMid_Lock);
                        return mid;
                }
@@ -289,6 +293,18 @@ smb2_find_mid(struct TCP_Server_Info *server, char *buf)
        return NULL;
 }
 
+static struct mid_q_entry *
+smb2_find_mid(struct TCP_Server_Info *server, char *buf)
+{
+       return __smb2_find_mid(server, buf, false);
+}
+
+static struct mid_q_entry *
+smb2_find_dequeue_mid(struct TCP_Server_Info *server, char *buf)
+{
+       return __smb2_find_mid(server, buf, true);
+}
+
 static void
 smb2_dump_detail(void *buf, struct TCP_Server_Info *server)
 {
@@ -4404,7 +4420,10 @@ handle_read_data(struct TCP_Server_Info *server, struct mid_q_entry *mid,
                cifs_dbg(FYI, "%s: server returned error %d\n",
                         __func__, rdata->result);
                /* normal error on read response */
-               dequeue_mid(mid, false);
+               if (is_offloaded)
+                       mid->mid_state = MID_RESPONSE_RECEIVED;
+               else
+                       dequeue_mid(mid, false);
                return 0;
        }
 
@@ -4428,7 +4447,10 @@ handle_read_data(struct TCP_Server_Info *server, struct mid_q_entry *mid,
                cifs_dbg(FYI, "%s: data offset (%u) beyond end of smallbuf\n",
                         __func__, data_offset);
                rdata->result = -EIO;
-               dequeue_mid(mid, rdata->result);
+               if (is_offloaded)
+                       mid->mid_state = MID_RESPONSE_MALFORMED;
+               else
+                       dequeue_mid(mid, rdata->result);
                return 0;
        }
 
@@ -4444,21 +4466,30 @@ handle_read_data(struct TCP_Server_Info *server, struct mid_q_entry *mid,
                        cifs_dbg(FYI, "%s: data offset (%u) beyond 1st page of response\n",
                                 __func__, data_offset);
                        rdata->result = -EIO;
-                       dequeue_mid(mid, rdata->result);
+                       if (is_offloaded)
+                               mid->mid_state = MID_RESPONSE_MALFORMED;
+                       else
+                               dequeue_mid(mid, rdata->result);
                        return 0;
                }
 
                if (data_len > page_data_size - pad_len) {
                        /* data_len is corrupt -- discard frame */
                        rdata->result = -EIO;
-                       dequeue_mid(mid, rdata->result);
+                       if (is_offloaded)
+                               mid->mid_state = MID_RESPONSE_MALFORMED;
+                       else
+                               dequeue_mid(mid, rdata->result);
                        return 0;
                }
 
                rdata->result = init_read_bvec(pages, npages, page_data_size,
                                               cur_off, &bvec);
                if (rdata->result != 0) {
-                       dequeue_mid(mid, rdata->result);
+                       if (is_offloaded)
+                               mid->mid_state = MID_RESPONSE_MALFORMED;
+                       else
+                               dequeue_mid(mid, rdata->result);
                        return 0;
                }
 
@@ -4473,7 +4504,10 @@ handle_read_data(struct TCP_Server_Info *server, struct mid_q_entry *mid,
                /* read response payload cannot be in both buf and pages */
                WARN_ONCE(1, "buf can not contain only a part of read data");
                rdata->result = -EIO;
-               dequeue_mid(mid, rdata->result);
+               if (is_offloaded)
+                       mid->mid_state = MID_RESPONSE_MALFORMED;
+               else
+                       dequeue_mid(mid, rdata->result);
                return 0;
        }
 
@@ -4484,7 +4518,10 @@ handle_read_data(struct TCP_Server_Info *server, struct mid_q_entry *mid,
        if (length < 0)
                return length;
 
-       dequeue_mid(mid, false);
+       if (is_offloaded)
+               mid->mid_state = MID_RESPONSE_RECEIVED;
+       else
+               dequeue_mid(mid, false);
        return length;
 }
 
@@ -4513,7 +4550,7 @@ static void smb2_decrypt_offload(struct work_struct *work)
        }
 
        dw->server->lstrp = jiffies;
-       mid = smb2_find_mid(dw->server, dw->buf);
+       mid = smb2_find_dequeue_mid(dw->server, dw->buf);
        if (mid == NULL)
                cifs_dbg(FYI, "mid not found\n");
        else {