Drivers: hv: vmbus: Fix race condition with new ring_buffer_info mutex
authorKimberly Brown <kimbrownkd@gmail.com>
Thu, 14 Mar 2019 20:05:15 +0000 (16:05 -0400)
committerSasha Levin <sashal@kernel.org>
Wed, 10 Apr 2019 22:58:56 +0000 (18:58 -0400)
Fix a race condition that can result in a ring buffer pointer being set
to null while a "_show" function is reading the ring buffer's data. This
problem was discussed here: https://lkml.org/lkml/2018/10/18/779

To fix the race condition, add a new mutex lock to the
"hv_ring_buffer_info" struct. Add a new function,
"hv_ringbuffer_pre_init()", where a channel's inbound and outbound
ring_buffer_info mutex locks are initialized.

Acquire/release the locks in the "hv_ringbuffer_cleanup()" function,
which is where the ring buffer pointers are set to null.

Acquire/release the locks in the four channel-level "_show" functions
that access ring buffer data. Remove the "const" qualifier from the
"vmbus_channel" parameter and the "rbi" variable of the channel-level
"_show" functions so that the locks can be acquired/released in these
functions.

Acquire/release the locks in hv_ringbuffer_get_debuginfo(). Remove the
"const" qualifier from the "hv_ring_buffer_info" parameter so that the
locks can be acquired/released in this function.

Signed-off-by: Kimberly Brown <kimbrownkd@gmail.com>
Reviewed-by: Michael Kelley <mikelley@microsoft.com>
Signed-off-by: Sasha Levin <sashal@kernel.org>
drivers/hv/channel_mgmt.c
drivers/hv/hyperv_vmbus.h
drivers/hv/ring_buffer.c
drivers/hv/vmbus_drv.c
include/linux/hyperv.h

index d32cac5..3fc0b24 100644 (file)
@@ -336,6 +336,8 @@ static struct vmbus_channel *alloc_channel(void)
        tasklet_init(&channel->callback_event,
                     vmbus_on_event, (unsigned long)channel);
 
+       hv_ringbuffer_pre_init(channel);
+
        return channel;
 }
 
index a94aab9..e5467b8 100644 (file)
@@ -193,6 +193,7 @@ extern void hv_synic_clockevents_cleanup(void);
 
 /* Interface */
 
+void hv_ringbuffer_pre_init(struct vmbus_channel *channel);
 
 int hv_ringbuffer_init(struct hv_ring_buffer_info *ring_info,
                       struct page *pages, u32 pagecnt);
index 0386ff4..121a01c 100644 (file)
@@ -166,14 +166,18 @@ hv_get_ringbuffer_availbytes(const struct hv_ring_buffer_info *rbi,
 }
 
 /* Get various debug metrics for the specified ring buffer. */
-int hv_ringbuffer_get_debuginfo(const struct hv_ring_buffer_info *ring_info,
+int hv_ringbuffer_get_debuginfo(struct hv_ring_buffer_info *ring_info,
                                struct hv_ring_buffer_debug_info *debug_info)
 {
        u32 bytes_avail_towrite;
        u32 bytes_avail_toread;
 
-       if (!ring_info->ring_buffer)
+       mutex_lock(&ring_info->ring_buffer_mutex);
+
+       if (!ring_info->ring_buffer) {
+               mutex_unlock(&ring_info->ring_buffer_mutex);
                return -EINVAL;
+       }
 
        hv_get_ringbuffer_availbytes(ring_info,
                                     &bytes_avail_toread,
@@ -184,10 +188,19 @@ int hv_ringbuffer_get_debuginfo(const struct hv_ring_buffer_info *ring_info,
        debug_info->current_write_index = ring_info->ring_buffer->write_index;
        debug_info->current_interrupt_mask
                = ring_info->ring_buffer->interrupt_mask;
+       mutex_unlock(&ring_info->ring_buffer_mutex);
+
        return 0;
 }
 EXPORT_SYMBOL_GPL(hv_ringbuffer_get_debuginfo);
 
+/* Initialize a channel's ring buffer info mutex locks */
+void hv_ringbuffer_pre_init(struct vmbus_channel *channel)
+{
+       mutex_init(&channel->inbound.ring_buffer_mutex);
+       mutex_init(&channel->outbound.ring_buffer_mutex);
+}
+
 /* Initialize the ring buffer. */
 int hv_ringbuffer_init(struct hv_ring_buffer_info *ring_info,
                       struct page *pages, u32 page_cnt)
@@ -240,8 +253,10 @@ int hv_ringbuffer_init(struct hv_ring_buffer_info *ring_info,
 /* Cleanup the ring buffer. */
 void hv_ringbuffer_cleanup(struct hv_ring_buffer_info *ring_info)
 {
+       mutex_lock(&ring_info->ring_buffer_mutex);
        vunmap(ring_info->ring_buffer);
        ring_info->ring_buffer = NULL;
+       mutex_unlock(&ring_info->ring_buffer_mutex);
 }
 
 /* Write to the ring buffer. */
index 6aa79b6..aa25f3b 100644 (file)
@@ -1410,7 +1410,7 @@ static void vmbus_chan_release(struct kobject *kobj)
 
 struct vmbus_chan_attribute {
        struct attribute attr;
-       ssize_t (*show)(const struct vmbus_channel *chan, char *buf);
+       ssize_t (*show)(struct vmbus_channel *chan, char *buf);
        ssize_t (*store)(struct vmbus_channel *chan,
                         const char *buf, size_t count);
 };
@@ -1429,7 +1429,7 @@ static ssize_t vmbus_chan_attr_show(struct kobject *kobj,
 {
        const struct vmbus_chan_attribute *attribute
                = container_of(attr, struct vmbus_chan_attribute, attr);
-       const struct vmbus_channel *chan
+       struct vmbus_channel *chan
                = container_of(kobj, struct vmbus_channel, kobj);
 
        if (!attribute->show)
@@ -1442,57 +1442,81 @@ static const struct sysfs_ops vmbus_chan_sysfs_ops = {
        .show = vmbus_chan_attr_show,
 };
 
-static ssize_t out_mask_show(const struct vmbus_channel *channel, char *buf)
+static ssize_t out_mask_show(struct vmbus_channel *channel, char *buf)
 {
-       const struct hv_ring_buffer_info *rbi = &channel->outbound;
+       struct hv_ring_buffer_info *rbi = &channel->outbound;
+       ssize_t ret;
 
-       if (!rbi->ring_buffer)
+       mutex_lock(&rbi->ring_buffer_mutex);
+       if (!rbi->ring_buffer) {
+               mutex_unlock(&rbi->ring_buffer_mutex);
                return -EINVAL;
+       }
 
-       return sprintf(buf, "%u\n", rbi->ring_buffer->interrupt_mask);
+       ret = sprintf(buf, "%u\n", rbi->ring_buffer->interrupt_mask);
+       mutex_unlock(&rbi->ring_buffer_mutex);
+       return ret;
 }
 static VMBUS_CHAN_ATTR_RO(out_mask);
 
-static ssize_t in_mask_show(const struct vmbus_channel *channel, char *buf)
+static ssize_t in_mask_show(struct vmbus_channel *channel, char *buf)
 {
-       const struct hv_ring_buffer_info *rbi = &channel->inbound;
+       struct hv_ring_buffer_info *rbi = &channel->inbound;
+       ssize_t ret;
 
-       if (!rbi->ring_buffer)
+       mutex_lock(&rbi->ring_buffer_mutex);
+       if (!rbi->ring_buffer) {
+               mutex_unlock(&rbi->ring_buffer_mutex);
                return -EINVAL;
+       }
 
-       return sprintf(buf, "%u\n", rbi->ring_buffer->interrupt_mask);
+       ret = sprintf(buf, "%u\n", rbi->ring_buffer->interrupt_mask);
+       mutex_unlock(&rbi->ring_buffer_mutex);
+       return ret;
 }
 static VMBUS_CHAN_ATTR_RO(in_mask);
 
-static ssize_t read_avail_show(const struct vmbus_channel *channel, char *buf)
+static ssize_t read_avail_show(struct vmbus_channel *channel, char *buf)
 {
-       const struct hv_ring_buffer_info *rbi = &channel->inbound;
+       struct hv_ring_buffer_info *rbi = &channel->inbound;
+       ssize_t ret;
 
-       if (!rbi->ring_buffer)
+       mutex_lock(&rbi->ring_buffer_mutex);
+       if (!rbi->ring_buffer) {
+               mutex_unlock(&rbi->ring_buffer_mutex);
                return -EINVAL;
+       }
 
-       return sprintf(buf, "%u\n", hv_get_bytes_to_read(rbi));
+       ret = sprintf(buf, "%u\n", hv_get_bytes_to_read(rbi));
+       mutex_unlock(&rbi->ring_buffer_mutex);
+       return ret;
 }
 static VMBUS_CHAN_ATTR_RO(read_avail);
 
-static ssize_t write_avail_show(const struct vmbus_channel *channel, char *buf)
+static ssize_t write_avail_show(struct vmbus_channel *channel, char *buf)
 {
-       const struct hv_ring_buffer_info *rbi = &channel->outbound;
+       struct hv_ring_buffer_info *rbi = &channel->outbound;
+       ssize_t ret;
 
-       if (!rbi->ring_buffer)
+       mutex_lock(&rbi->ring_buffer_mutex);
+       if (!rbi->ring_buffer) {
+               mutex_unlock(&rbi->ring_buffer_mutex);
                return -EINVAL;
+       }
 
-       return sprintf(buf, "%u\n", hv_get_bytes_to_write(rbi));
+       ret = sprintf(buf, "%u\n", hv_get_bytes_to_write(rbi));
+       mutex_unlock(&rbi->ring_buffer_mutex);
+       return ret;
 }
 static VMBUS_CHAN_ATTR_RO(write_avail);
 
-static ssize_t show_target_cpu(const struct vmbus_channel *channel, char *buf)
+static ssize_t show_target_cpu(struct vmbus_channel *channel, char *buf)
 {
        return sprintf(buf, "%u\n", channel->target_cpu);
 }
 static VMBUS_CHAN_ATTR(cpu, S_IRUGO, show_target_cpu, NULL);
 
-static ssize_t channel_pending_show(const struct vmbus_channel *channel,
+static ssize_t channel_pending_show(struct vmbus_channel *channel,
                                    char *buf)
 {
        return sprintf(buf, "%d\n",
@@ -1501,7 +1525,7 @@ static ssize_t channel_pending_show(const struct vmbus_channel *channel,
 }
 static VMBUS_CHAN_ATTR(pending, S_IRUGO, channel_pending_show, NULL);
 
-static ssize_t channel_latency_show(const struct vmbus_channel *channel,
+static ssize_t channel_latency_show(struct vmbus_channel *channel,
                                    char *buf)
 {
        return sprintf(buf, "%d\n",
@@ -1510,19 +1534,19 @@ static ssize_t channel_latency_show(const struct vmbus_channel *channel,
 }
 static VMBUS_CHAN_ATTR(latency, S_IRUGO, channel_latency_show, NULL);
 
-static ssize_t channel_interrupts_show(const struct vmbus_channel *channel, char *buf)
+static ssize_t channel_interrupts_show(struct vmbus_channel *channel, char *buf)
 {
        return sprintf(buf, "%llu\n", channel->interrupts);
 }
 static VMBUS_CHAN_ATTR(interrupts, S_IRUGO, channel_interrupts_show, NULL);
 
-static ssize_t channel_events_show(const struct vmbus_channel *channel, char *buf)
+static ssize_t channel_events_show(struct vmbus_channel *channel, char *buf)
 {
        return sprintf(buf, "%llu\n", channel->sig_events);
 }
 static VMBUS_CHAN_ATTR(events, S_IRUGO, channel_events_show, NULL);
 
-static ssize_t channel_intr_in_full_show(const struct vmbus_channel *channel,
+static ssize_t channel_intr_in_full_show(struct vmbus_channel *channel,
                                         char *buf)
 {
        return sprintf(buf, "%llu\n",
@@ -1530,7 +1554,7 @@ static ssize_t channel_intr_in_full_show(const struct vmbus_channel *channel,
 }
 static VMBUS_CHAN_ATTR(intr_in_full, 0444, channel_intr_in_full_show, NULL);
 
-static ssize_t channel_intr_out_empty_show(const struct vmbus_channel *channel,
+static ssize_t channel_intr_out_empty_show(struct vmbus_channel *channel,
                                           char *buf)
 {
        return sprintf(buf, "%llu\n",
@@ -1538,7 +1562,7 @@ static ssize_t channel_intr_out_empty_show(const struct vmbus_channel *channel,
 }
 static VMBUS_CHAN_ATTR(intr_out_empty, 0444, channel_intr_out_empty_show, NULL);
 
-static ssize_t channel_out_full_first_show(const struct vmbus_channel *channel,
+static ssize_t channel_out_full_first_show(struct vmbus_channel *channel,
                                           char *buf)
 {
        return sprintf(buf, "%llu\n",
@@ -1546,7 +1570,7 @@ static ssize_t channel_out_full_first_show(const struct vmbus_channel *channel,
 }
 static VMBUS_CHAN_ATTR(out_full_first, 0444, channel_out_full_first_show, NULL);
 
-static ssize_t channel_out_full_total_show(const struct vmbus_channel *channel,
+static ssize_t channel_out_full_total_show(struct vmbus_channel *channel,
                                           char *buf)
 {
        return sprintf(buf, "%llu\n",
@@ -1554,14 +1578,14 @@ static ssize_t channel_out_full_total_show(const struct vmbus_channel *channel,
 }
 static VMBUS_CHAN_ATTR(out_full_total, 0444, channel_out_full_total_show, NULL);
 
-static ssize_t subchannel_monitor_id_show(const struct vmbus_channel *channel,
+static ssize_t subchannel_monitor_id_show(struct vmbus_channel *channel,
                                          char *buf)
 {
        return sprintf(buf, "%u\n", channel->offermsg.monitorid);
 }
 static VMBUS_CHAN_ATTR(monitor_id, S_IRUGO, subchannel_monitor_id_show, NULL);
 
-static ssize_t subchannel_id_show(const struct vmbus_channel *channel,
+static ssize_t subchannel_id_show(struct vmbus_channel *channel,
                                  char *buf)
 {
        return sprintf(buf, "%u\n",
index 64698ec..8b9a93c 100644 (file)
@@ -141,6 +141,11 @@ struct hv_ring_buffer_info {
 
        u32 ring_datasize;              /* < ring_size */
        u32 priv_read_index;
+       /*
+        * The ring buffer mutex lock. This lock prevents the ring buffer from
+        * being freed while the ring buffer is being accessed.
+        */
+       struct mutex ring_buffer_mutex;
 };
 
 
@@ -1206,7 +1211,7 @@ struct hv_ring_buffer_debug_info {
 };
 
 
-int hv_ringbuffer_get_debuginfo(const struct hv_ring_buffer_info *ring_info,
+int hv_ringbuffer_get_debuginfo(struct hv_ring_buffer_info *ring_info,
                                struct hv_ring_buffer_debug_info *debug_info);
 
 /* Vmbus interface */