diff --git a/drivers/vhost/net.c b/drivers/vhost/net.c index f61f852d6cfd..e3d7ea1288c6 100644 --- a/drivers/vhost/net.c +++ b/drivers/vhost/net.c @@ -28,6 +28,8 @@ #include #include #include +#include +#include #include @@ -85,6 +87,13 @@ struct vhost_net_ubuf_ref { struct vhost_virtqueue *vq; }; +#define VHOST_RX_BATCH 64 +struct vhost_net_buf { + struct sk_buff **queue; + int tail; + int head; +}; + struct vhost_net_virtqueue { struct vhost_virtqueue vq; size_t vhost_hlen; @@ -99,6 +108,8 @@ struct vhost_net_virtqueue { /* Reference counting for outstanding ubufs. * Protected by vq mutex. Writers must also take device mutex. */ struct vhost_net_ubuf_ref *ubufs; + struct skb_array *rx_array; + struct vhost_net_buf rxq; }; struct vhost_net { @@ -117,6 +128,71 @@ struct vhost_net { static unsigned vhost_net_zcopy_mask __read_mostly; +static void *vhost_net_buf_get_ptr(struct vhost_net_buf *rxq) +{ + if (rxq->tail != rxq->head) + return rxq->queue[rxq->head]; + else + return NULL; +} + +static int vhost_net_buf_get_size(struct vhost_net_buf *rxq) +{ + return rxq->tail - rxq->head; +} + +static int vhost_net_buf_is_empty(struct vhost_net_buf *rxq) +{ + return rxq->tail == rxq->head; +} + +static void *vhost_net_buf_consume(struct vhost_net_buf *rxq) +{ + void *ret = vhost_net_buf_get_ptr(rxq); + ++rxq->head; + return ret; +} + +static int vhost_net_buf_produce(struct vhost_net_virtqueue *nvq) +{ + struct vhost_net_buf *rxq = &nvq->rxq; + + rxq->head = 0; + rxq->tail = skb_array_consume_batched(nvq->rx_array, rxq->queue, + VHOST_RX_BATCH); + return rxq->tail; +} + +static void vhost_net_buf_unproduce(struct vhost_net_virtqueue *nvq) +{ + struct vhost_net_buf *rxq = &nvq->rxq; + + if (nvq->rx_array && !vhost_net_buf_is_empty(rxq)) { + skb_array_unconsume(nvq->rx_array, rxq->queue + rxq->head, + vhost_net_buf_get_size(rxq)); + rxq->head = rxq->tail = 0; + } +} + +static int vhost_net_buf_peek(struct vhost_net_virtqueue *nvq) +{ + struct vhost_net_buf *rxq = &nvq->rxq; + + if (!vhost_net_buf_is_empty(rxq)) + goto out; + + if (!vhost_net_buf_produce(nvq)) + return 0; + +out: + return __skb_array_len_with_tag(vhost_net_buf_get_ptr(rxq)); +} + +static void vhost_net_buf_init(struct vhost_net_buf *rxq) +{ + rxq->head = rxq->tail = 0; +} + static void vhost_net_enable_zcopy(int vq) { vhost_net_zcopy_mask |= 0x1 << vq; @@ -201,6 +277,7 @@ static void vhost_net_vq_reset(struct vhost_net *n) n->vqs[i].ubufs = NULL; n->vqs[i].vhost_hlen = 0; n->vqs[i].sock_hlen = 0; + vhost_net_buf_init(&n->vqs[i].rxq); } } @@ -503,15 +580,14 @@ out: mutex_unlock(&vq->mutex); } -static int peek_head_len(struct sock *sk) +static int peek_head_len(struct vhost_net_virtqueue *rvq, struct sock *sk) { - struct socket *sock = sk->sk_socket; struct sk_buff *head; int len = 0; unsigned long flags; - if (sock->ops->peek_len) - return sock->ops->peek_len(sock); + if (rvq->rx_array) + return vhost_net_buf_peek(rvq); spin_lock_irqsave(&sk->sk_receive_queue.lock, flags); head = skb_peek(&sk->sk_receive_queue); @@ -537,10 +613,11 @@ static int sk_has_rx_data(struct sock *sk) static int vhost_net_rx_peek_head_len(struct vhost_net *net, struct sock *sk) { + struct vhost_net_virtqueue *rvq = &net->vqs[VHOST_NET_VQ_RX]; struct vhost_net_virtqueue *nvq = &net->vqs[VHOST_NET_VQ_TX]; struct vhost_virtqueue *vq = &nvq->vq; unsigned long uninitialized_var(endtime); - int len = peek_head_len(sk); + int len = peek_head_len(rvq, sk); if (!len && vq->busyloop_timeout) { /* Both tx vq and rx socket were polled here */ @@ -561,7 +638,7 @@ static int vhost_net_rx_peek_head_len(struct vhost_net *net, struct sock *sk) vhost_poll_queue(&vq->poll); mutex_unlock(&vq->mutex); - len = peek_head_len(sk); + len = peek_head_len(rvq, sk); } return len; @@ -699,6 +776,8 @@ static void handle_rx(struct vhost_net *net) /* On error, stop handling until the next kick. */ if (unlikely(headcount < 0)) goto out; + if (nvq->rx_array) + msg.msg_control = vhost_net_buf_consume(&nvq->rxq); /* On overrun, truncate and discard */ if (unlikely(headcount > UIO_MAXIOV)) { iov_iter_init(&msg.msg_iter, READ, vq->iov, 1, 1); @@ -815,6 +894,7 @@ static int vhost_net_open(struct inode *inode, struct file *f) struct vhost_net *n; struct vhost_dev *dev; struct vhost_virtqueue **vqs; + struct sk_buff **queue; int i; n = kvmalloc(sizeof *n, GFP_KERNEL | __GFP_REPEAT); @@ -826,6 +906,15 @@ static int vhost_net_open(struct inode *inode, struct file *f) return -ENOMEM; } + queue = kmalloc_array(VHOST_RX_BATCH, sizeof(struct sk_buff *), + GFP_KERNEL); + if (!queue) { + kfree(vqs); + kvfree(n); + return -ENOMEM; + } + n->vqs[VHOST_NET_VQ_RX].rxq.queue = queue; + dev = &n->dev; vqs[VHOST_NET_VQ_TX] = &n->vqs[VHOST_NET_VQ_TX].vq; vqs[VHOST_NET_VQ_RX] = &n->vqs[VHOST_NET_VQ_RX].vq; @@ -838,6 +927,7 @@ static int vhost_net_open(struct inode *inode, struct file *f) n->vqs[i].done_idx = 0; n->vqs[i].vhost_hlen = 0; n->vqs[i].sock_hlen = 0; + vhost_net_buf_init(&n->vqs[i].rxq); } vhost_dev_init(dev, vqs, VHOST_NET_VQ_MAX); @@ -853,11 +943,14 @@ static struct socket *vhost_net_stop_vq(struct vhost_net *n, struct vhost_virtqueue *vq) { struct socket *sock; + struct vhost_net_virtqueue *nvq = + container_of(vq, struct vhost_net_virtqueue, vq); mutex_lock(&vq->mutex); sock = vq->private_data; vhost_net_disable_vq(n, vq); vq->private_data = NULL; + vhost_net_buf_unproduce(nvq); mutex_unlock(&vq->mutex); return sock; } @@ -912,6 +1005,7 @@ static int vhost_net_release(struct inode *inode, struct file *f) /* We do an extra flush before freeing memory, * since jobs can re-queue themselves. */ vhost_net_flush(n); + kfree(n->vqs[VHOST_NET_VQ_RX].rxq.queue); kfree(n->dev.vqs); kvfree(n); return 0; @@ -950,6 +1044,25 @@ err: return ERR_PTR(r); } +static struct skb_array *get_tap_skb_array(int fd) +{ + struct skb_array *array; + struct file *file = fget(fd); + + if (!file) + return NULL; + array = tun_get_skb_array(file); + if (!IS_ERR(array)) + goto out; + array = tap_get_skb_array(file); + if (!IS_ERR(array)) + goto out; + array = NULL; +out: + fput(file); + return array; +} + static struct socket *get_tap_socket(int fd) { struct file *file = fget(fd); @@ -1026,6 +1139,9 @@ static long vhost_net_set_backend(struct vhost_net *n, unsigned index, int fd) vhost_net_disable_vq(n, vq); vq->private_data = sock; + vhost_net_buf_unproduce(nvq); + if (index == VHOST_NET_VQ_RX) + nvq->rx_array = get_tap_skb_array(fd); r = vhost_vq_init_access(vq); if (r) goto err_used;