tcp_diag 内核相关实现

前言

tcp_diag 是一个内核模块,本文的目的是梳理调用关系,如果从用户态的socket一路调用到tcp_diag模块dump出所有socket的。

大致分层关系 总结如下:netlink层->sock_diag层->inet_diag层->tcp_diag

用户态代码

类似 ss 功能的代码可以从 https://man7.org/linux/man-pages/man7/sock_diag.7.html 中获得,但是它只是打印 unix_socket  .sdiag_family = AF_UNIX,

       #include <errno.h>
       #include <stdio.h>
       #include <string.h>
       #include <unistd.h>
       #include <sys/socket.h>
       #include <sys/un.h>
       #include <linux/netlink.h>
       #include <linux/rtnetlink.h>
       #include <linux/sock_diag.h>
       #include <linux/unix_diag.h>

       static int
       send_query(int fd)
       {
           struct sockaddr_nl nladdr = {
               .nl_family = AF_NETLINK
           };
           struct
           {
               struct nlmsghdr nlh;
               struct unix_diag_req udr;
           } req = {
               .nlh = {
                   .nlmsg_len = sizeof(req),
                   .nlmsg_type = SOCK_DIAG_BY_FAMILY,
                   .nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP
               },
               .udr = {
                   .sdiag_family = AF_UNIX,
                   .udiag_states = -1,
                   .udiag_show = UDIAG_SHOW_NAME | UDIAG_SHOW_PEER
               }
           };
           struct iovec iov = {
               .iov_base = &req,
               .iov_len = sizeof(req)
           };
           struct msghdr msg = {
               .msg_name = &nladdr,
               .msg_namelen = sizeof(nladdr),
               .msg_iov = &iov,
               .msg_iovlen = 1
           };

           for (;;) {
               if (sendmsg(fd, &msg, 0) < 0) {
                   if (errno == EINTR)
                       continue;

                   perror("sendmsg");
                   return -1;
               }

               return 0;
           }
       }

       static int
       print_diag(const struct unix_diag_msg *diag, unsigned int len)
       {
           if (len < NLMSG_LENGTH(sizeof(*diag))) {
               fputs("short response\n", stderr);
               return -1;
           }
           if (diag->udiag_family != AF_UNIX) {
               fprintf(stderr, "unexpected family %u\n", diag->udiag_family);
               return -1;
           }

           unsigned int rta_len = len - NLMSG_LENGTH(sizeof(*diag));
           unsigned int peer = 0;
           size_t path_len = 0;
           char path[sizeof(((struct sockaddr_un *) 0)->sun_path) + 1];

           for (struct rtattr *attr = (struct rtattr *) (diag + 1);
                    RTA_OK(attr, rta_len); attr = RTA_NEXT(attr, rta_len)) {
               switch (attr->rta_type) {
               case UNIX_DIAG_NAME:
                   if (!path_len) {
                       path_len = RTA_PAYLOAD(attr);
                       if (path_len > sizeof(path) - 1)
                           path_len = sizeof(path) - 1;
                       memcpy(path, RTA_DATA(attr), path_len);
                       path[path_len] = '\0';
                   }
                   break;

               case UNIX_DIAG_PEER:
                   if (RTA_PAYLOAD(attr) >= sizeof(peer))
                       peer = *(unsigned int *) RTA_DATA(attr);
                   break;
               }
           }

           printf("inode=%u", diag->udiag_ino);

           if (peer)
               printf(", peer=%u", peer);

           if (path_len)
               printf(", name=%s%s", *path ? "" : "@",
                       *path ? path : path + 1);

           putchar('\n');
           return 0;
       }

       static int
       receive_responses(int fd)
       {
           long buf[8192 / sizeof(long)];
           struct sockaddr_nl nladdr;
           struct iovec iov = {
               .iov_base = buf,
               .iov_len = sizeof(buf)
           };
           int flags = 0;

           for (;;) {
               struct msghdr msg = {
                   .msg_name = &nladdr,
                   .msg_namelen = sizeof(nladdr),
                   .msg_iov = &iov,
                   .msg_iovlen = 1
               };

               ssize_t ret = recvmsg(fd, &msg, flags);

               if (ret < 0) {
                   if (errno == EINTR)
                       continue;

                   perror("recvmsg");
                   return -1;
               }
               if (ret == 0)
                   return 0;

               if (nladdr.nl_family != AF_NETLINK) {
                   fputs("!AF_NETLINK\n", stderr);
                   return -1;
               }

               const struct nlmsghdr *h = (struct nlmsghdr *) buf;

               if (!NLMSG_OK(h, ret)) {
                   fputs("!NLMSG_OK\n", stderr);
                   return -1;
               }

               for (; NLMSG_OK(h, ret); h = NLMSG_NEXT(h, ret)) {
                   if (h->nlmsg_type == NLMSG_DONE)
                       return 0;

                   if (h->nlmsg_type == NLMSG_ERROR) {
                       const struct nlmsgerr *err = NLMSG_DATA(h);

                       if (h->nlmsg_len < NLMSG_LENGTH(sizeof(*err))) {
                           fputs("NLMSG_ERROR\n", stderr);
                       } else {
                           errno = -err->error;
                           perror("NLMSG_ERROR");
                       }

                       return -1;
                   }

                   if (h->nlmsg_type != SOCK_DIAG_BY_FAMILY) {
                       fprintf(stderr, "unexpected nlmsg_type %u\n",
                               (unsigned) h->nlmsg_type);
                       return -1;
                   }

                   if (print_diag(NLMSG_DATA(h), h->nlmsg_len))
                       return -1;
               }
           }
       }

       int
       main(void)
       {
           int fd = socket(AF_NETLINK, SOCK_RAW, NETLINK_SOCK_DIAG);

           if (fd < 0) {
               perror("socket");
               return 1;
           }

           int ret = send_query(fd) || receive_responses(fd);

           close(fd);
           return ret;
       }

用户态代码 - 创建socket

nt fd = socket(AF_NETLINK, SOCK_RAW, NETLINK_SOCK_DIAG);

对应 内核 调用路径 sys_socket -> netlink_create -> __netlink_create
因为用户态的入参是AF_NETLINK,netlink模块注册了对应的net_proto_family

static const struct net_proto_family netlink_family_ops = {
   .family = PF_NETLINK,
   .create = netlink_create,
   .owner   = THIS_MODULE, /* for consistency 8) */
};
sock_register(&netlink_family_ops);

__netlink_create只有一个我们需要关心的,就是操作函数 sock->ops,他决定了当你对netlink的fd调用send/recv等情况下,内核实际运行的函数

static const struct proto_ops netlink_ops = {
   .family =   PF_NETLINK,
   .owner = THIS_MODULE,
   .release =  netlink_release,
   .bind =     netlink_bind,
   .connect =  netlink_connect,
   .socketpair =  sock_no_socketpair,
   .accept =   sock_no_accept,
   .getname =  netlink_getname,
   .poll =     datagram_poll,
   .ioctl = netlink_ioctl,
   .listen =   sock_no_listen,
   .shutdown = sock_no_shutdown,
   .setsockopt =  netlink_setsockopt,
   .getsockopt =  netlink_getsockopt,
   .sendmsg =  netlink_sendmsg,
   .recvmsg =  netlink_recvmsg,
   .mmap =     sock_no_mmap,
   .sendpage = sock_no_sendpage,
};

static int __netlink_create(struct net *net, struct socket *sock,
             struct mutex *cb_mutex, int protocol,
             int kern)
{
   struct sock *sk;
   struct netlink_sock *nlk;

   sock->ops = &netlink_ops;
   ...

}

用户态代码 - 发送dump请求

这里,我们使用 AF_INET 替换上面例子中的AF_UNIX

struct
{
   struct nlmsghdr nlh;
   struct inet_diag_req_v2 r;
} req = {
   .nlh = {
       .nlmsg_len = sizeof(req),
       .nlmsg_type = SOCK_DIAG_BY_FAMILY,
       .nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP
   },
   .r = {
       .sdiag_family = AF_INET,
       .idiag_states = ((1 << TCP_CLOSING + 1) - 1); //states to dump
   }
};
sendmsg(fd, &msg, 0)

sendmsg对应的内核调用路径 sys_sendmsg -> netlink_sendmsg -> 找到对应的内核socket -> netlink_unicast_kernel -> nlk_sk(sk)->netlink_rcv
netlink_sendmsg 会根据 fd的类型,使用 netlink_getsockbyportid 函数,通过NETLINK_SOCK_DIAG , 找到对应的内核socket,这个内核socket负责处理用户态程序send的数据

static int __net_init diag_net_init(struct net *net)
{
   struct netlink_kernel_cfg cfg = {
      .groups  = SKNLGRP_MAX,
      .input   = sock_diag_rcv,
      .bind = sock_diag_bind,
      .flags   = NL_CFG_F_NONROOT_RECV,
   };

   // nlk_sk(sk)->netlink_rcv = cfg.input
   net->diag_nlsk = netlink_kernel_create(net, NETLINK_SOCK_DIAG, &cfg);
   return net->diag_nlsk == NULL ? -ENOMEM : 0;
}

所以用户态的数据,因为通过NETLINK_SOCK_DIAG创建的socket,所以 首先会被 sock_diag_rcv处理
调用路径  ->sock_diag_rcv->sock_diag_rcv_msg

static void sock_diag_rcv(struct sk_buff *skb)
{
   mutex_lock(&sock_diag_mutex);
   netlink_rcv_skb(skb, &sock_diag_rcv_msg);
   mutex_unlock(&sock_diag_mutex);
}

static int sock_diag_rcv_msg(struct sk_buff *skb, struct nlmsghdr *nlh,
              struct netlink_ext_ack *extack)
{
   int ret;

   switch (nlh->nlmsg_type) {
   case TCPDIAG_GETSOCK:
   case DCCPDIAG_GETSOCK:
      if (inet_rcv_compat == NULL)
         sock_load_diag_module(AF_INET, 0);

      mutex_lock(&sock_diag_table_mutex);
      if (inet_rcv_compat != NULL)
         ret = inet_rcv_compat(skb, nlh);
      else
         ret = -EOPNOTSUPP;
      mutex_unlock(&sock_diag_table_mutex);

      return ret;
   case SOCK_DIAG_BY_FAMILY:
   case SOCK_DESTROY:
      return __sock_diag_cmd(skb, nlh);
   default:
      return -EINVAL;
   }
}

因为我们 sendmsg的msg入参类型是 SOCK_DIAG_BY_FAMILY,从而走到 __sock_diag_cmd 分支

static int __sock_diag_cmd(struct sk_buff *skb, struct nlmsghdr *nlh)
{
   int err;
   struct sock_diag_req *req = nlmsg_data(nlh);
   const struct sock_diag_handler *hndl;

   ...
   mutex_lock(&sock_diag_table_mutex);
   hndl = sock_diag_handlers[req->sdiag_family];
   if (hndl == NULL)
      err = -ENOENT;
   else if (nlh->nlmsg_type == SOCK_DIAG_BY_FAMILY)
      err = hndl->dump(skb, nlh);
   else if (nlh->nlmsg_type == SOCK_DESTROY && hndl->destroy)
      err = hndl->destroy(skb, nlh);
   else
      err = -EOPNOTSUPP;
   mutex_unlock(&sock_diag_table_mutex);

   return err;
}

显然还有一层,sock_diag_handlers[req->sdiag_family]对于我们来说就是,因为msg的类型是 .sdiag_family = AF_INET,

static int __init inet_diag_init(void)
{
   const int inet_diag_table_size = (IPPROTO_MAX *
                 sizeof(struct inet_diag_handler *));
   int err = -ENOMEM;

   inet_diag_table = kzalloc(inet_diag_table_size, GFP_KERNEL);
   if (!inet_diag_table)
      goto out;

   err = sock_diag_register(&inet_diag_handler);
   if (err)
      goto out_free_nl;

   err = sock_diag_register(&inet6_diag_handler);
   if (err)
      goto out_free_inet;

static const struct sock_diag_handler inet_diag_handler = {
   .family = AF_INET,
   .dump = inet_diag_handler_cmd,
   .get_info = inet_diag_handler_get_info,
   .destroy = inet_diag_handler_cmd,
};

所以  hndl->dump 对于的其实是inet_diag_handler_cmd

static int inet_diag_handler_cmd(struct sk_buff *skb, struct nlmsghdr *h)
{
   int hdrlen = sizeof(struct inet_diag_req_v2);
   struct net *net = sock_net(skb->sk);


   if (h->nlmsg_type == SOCK_DIAG_BY_FAMILY &&
       h->nlmsg_flags & NLM_F_DUMP) {
        ....
      {
         struct netlink_dump_control c = {
            .dump = inet_diag_dump,
         };
         return netlink_dump_start(net->diag_nlsk, skb, h, &c);
      }
   }

   return inet_diag_cmd_exact(h->nlmsg_type, skb, h, nlmsg_data(h));
}

可以想象的是 netlink_dump_start 中肯定是调用了入参c->dump,即 inet_diag_dump,貌似至此,还是只是靠函数指针一路的调用。

inet_diag_dump->__inet_diag_dump->handler->dump

static int __inet_diag_dump(struct sk_buff *skb, struct netlink_callback *cb,
             const struct inet_diag_req_v2 *r,
             struct nlattr *bc)
{
   const struct inet_diag_handler *handler;
   int err = 0;

   handler = inet_diag_lock_handler(r->sdiag_protocol);
   if (!IS_ERR(handler))
      handler->dump(skb, cb, r, bc);
   else
      err = PTR_ERR(handler);
   inet_diag_unlock_handler(handler);

   return err ? : skb->len;
}

可以看到,接着还是依靠注册机制,根据 sdiag_protocol 找到具体的dump函数指针

static const struct inet_diag_handler tcp_diag_handler = {
   .dump       = tcp_diag_dump,
   .dump_one      = tcp_diag_dump_one,
   .idiag_get_info      = tcp_diag_get_info,
   .idiag_get_aux    = tcp_diag_get_aux,
   .idiag_get_aux_size  = tcp_diag_get_aux_size,
   .idiag_type    = IPPROTO_TCP,
   .idiag_info_size  = sizeof(struct tcp_info),
#ifdef CONFIG_INET_DIAG_DESTROY
   .destroy    = tcp_diag_destroy,
#endif
};

static int __init tcp_diag_init(void)
{
   return inet_diag_register(&tcp_diag_handler);
}

至此,终于找到 tcp_diag 模块的dump函数了

Logo

更多推荐