Implement str_to_addr() and unit test.
[darkstat] / acct.c
diff --git a/acct.c b/acct.c
index ef6d514..65e2bc1 100644 (file)
--- a/acct.c
+++ b/acct.c
 #include <arpa/inet.h> /* for inet_aton() */
 #define __FAVOR_BSD
 #include <netinet/tcp.h>
+#include <sys/socket.h>
 #include <stdlib.h> /* for free */
 #include <string.h> /* for memcpy */
+#include <ctype.h>  /* isdigit() */
 
 uint64_t total_packets = 0, total_bytes = 0;
 
 static int using_localnet = 0;
+static int using_localnet6 = 0;
 static in_addr_t localnet, localmask;
+static struct in6_addr localnet6, localmask6;
 
 /* Parse the net/mask specification into two IPs or die trying. */
 void
 acct_init_localnet(const char *spec)
 {
-   char **tokens;
-   int num_tokens;
+   char **tokens, *p;
+   int num_tokens, isnum, j;
+   int build_ipv6;  /* Zero for IPv4, one for IPv6.  */
+   int pfxlen, octets, remainder;
    struct in_addr addr;
+   struct in6_addr addr6;
 
    tokens = split('/', spec, &num_tokens);
    if (num_tokens != 2)
       errx(1, "expecting network/netmask, got \"%s\"", spec);
 
-   if (inet_aton(tokens[0], &addr) != 1)
-      errx(1, "invalid network address \"%s\"", tokens[0]);
-   localnet = addr.s_addr;
+   /* Presence of a colon distinguishes address families.  */
+   if (strchr(tokens[0], ':')) {
+      build_ipv6 = 1;
+      if (inet_pton(AF_INET6, tokens[0], &addr6) != 1)
+         errx(1, "invalid IPv6 network address \"%s\"", tokens[0]);
+      memcpy(&localnet6, &addr6, sizeof(localnet6));
+   } else {
+      build_ipv6 = 0;
+      if (inet_pton(AF_INET, tokens[0], &addr) != 1)
+         errx(1, "invalid network address \"%s\"", tokens[0]);
+      localnet = addr.s_addr;
+   }
+
+   /* Detect a purely numeric argument.  */
+   isnum = 0;
+   p = tokens[1];
+   while (*p != '\0') {
+      if (isdigit(*p)) {
+         isnum = 1;
+         ++p;
+         continue;
+      } else {
+         isnum = 0;
+         break;
+      }
+   }
 
-   if (inet_aton(tokens[1], &addr) != 1)
-      errx(1, "invalid network mask \"%s\"", tokens[1]);
-   localmask = addr.s_addr;
-   /* FIXME: improve so we can accept masks like /24 for 255.255.255.0 */
+   if (!isnum) {
+      if (build_ipv6) {
+         if (inet_pton(AF_INET6, tokens[1], &addr6) != 1)
+            errx(1, "invalid IPv6 network mask \"%s\"", tokens[1]);
+         memcpy(&localmask6, &addr6, sizeof(localmask6));
+      } else {
+         if (inet_pton(AF_INET, tokens[1], &addr) != 1)
+            errx(1, "invalid network mask \"%s\"", tokens[1]);
+         localmask = addr.s_addr;
+      }
+   } else {
+      uint8_t frac, *p;
+
+      /* Compute the prefix length.  */
+      pfxlen = strtonum(tokens[1], 1, build_ipv6 ? 128 : 32, NULL);
+      if (pfxlen == 0)
+         errx(1, "invalid network prefix length \"%s\"", tokens[1]);
+
+      /* Construct the network mask.  */
+      octets = pfxlen / 8;
+      remainder = pfxlen % 8;
+      p = build_ipv6 ? (uint8_t *) localmask6.s6_addr : (uint8_t *) &localmask;
+
+      if (build_ipv6)
+         memset(&localmask6, 0, sizeof(localmask6));
+      else
+         memset(&localmask, 0, sizeof(localmask));
+
+      for (j = 0; j < octets; ++j)
+         p[j] = 0xff;
+
+      frac = 0xff << (8 - remainder);
+      if (frac)
+         p[j] = frac;   /* Have contribution for next position.  */
+   }
+
+   /* Register the correct netmask and calculate the correct net.  */
+   if (build_ipv6) {
+      using_localnet6 = 1;
+      for (j = 0; j < 16; ++j)
+         localnet6.s6_addr[j] &= localmask6.s6_addr[j];
+   } else {
+      using_localnet = 1;
+      localnet &= localmask;
+   }
 
-   using_localnet = 1;
    free(tokens[0]);
    free(tokens[1]);
    free(tokens);
 
-   verbosef("local network address: %s", ip_to_str(localnet));
-   verbosef("   local network mask: %s", ip_to_str(localmask));
+   if (build_ipv6) {
+      verbosef("local network address: %s", ip_to_str_af(&localnet6, AF_INET6));
+      verbosef("   local network mask: %s", ip_to_str_af(&localmask6, AF_INET6));
+   } else {
+      verbosef("local network address: %s", ip_to_str_af(&localnet, AF_INET));
+      verbosef("   local network mask: %s", ip_to_str_af(&localmask, AF_INET));
+   }
 
-   if ((localnet & localmask) != localnet)
-      errx(1, "this is an invalid combination of address and mask!\n"
-      "it cannot match any address!");
 }
 
 /* Account for the given packet summary. */
@@ -76,11 +148,13 @@ acct_for(const pktsummary *sm)
 {
    struct bucket *hs = NULL, *hd = NULL;
    struct bucket *ps, *pd;
-   int dir_in, dir_out;
+   struct addr46 ipaddr;
+   struct in6_addr scribble;
+   int dir_in, dir_out, j;
 
 #if 0 /* WANT_CHATTY? */
-   printf("%15s > ", ip_to_str(sm->src_ip));
-   printf("%15s ", ip_to_str(sm->dest_ip));
+   printf("%15s > ", ip_to_str_af(&sm->src_ip, AF_INET));
+   printf("%15s ", ip_to_str_af(&sm->dest_ip, AF_INET));
    printf("len %4d proto %2d", sm->len, sm->proto);
 
    if (sm->proto == IPPROTO_TCP || sm->proto == IPPROTO_UDP)
@@ -106,25 +180,37 @@ acct_for(const pktsummary *sm)
 
    if (sm->af == AF_INET) {
       if (using_localnet) {
-         if ((sm->src_ip & localmask) == localnet)
+         if ((sm->src_ip.s_addr & localmask) == localnet)
             dir_out = 1;
-         if ((sm->dest_ip & localmask) == localnet)
+         if ((sm->dest_ip.s_addr & localmask) == localnet)
             dir_in = 1;
          if (dir_in == 1 && dir_out == 1)
             /* Traffic staying within the network isn't counted. */
             dir_in = dir_out = 0;
       } else {
-         if (sm->src_ip == localip)
+         if (memcmp(&sm->src_ip, &localip, sizeof(localip)) == 0)
             dir_out = 1;
-         if (sm->dest_ip == localip)
+         if (memcmp(&sm->dest_ip, &localip, sizeof(localip)) == 0)
             dir_in = 1;
       }
    } else if (sm->af == AF_INET6) {
-      /* Only exact address has been implemented. */
-      if (memcmp(&sm->src_ip6, &localip6, sizeof(localip6)) == 0)
-         dir_out = 1;
-      if (memcmp(&sm->dest_ip6, &localip6, sizeof(localip6)) == 0)
-         dir_in = 1;
+      if (using_localnet6) {
+         for (j = 0; j < 16; ++j)
+            scribble.s6_addr[j] = sm->src_ip6.s6_addr[j] & localmask6.s6_addr[j];
+         if (memcmp(&scribble, &localnet6, sizeof(scribble)) == 0)
+            dir_out = 1;
+         else {
+            for (j = 0; j < 16; ++j)
+               scribble.s6_addr[j] = sm->dest_ip6.s6_addr[j] & localmask6.s6_addr[j];
+            if (memcmp(&scribble, &localnet6, sizeof(scribble)) == 0)
+               dir_in = 1;
+         }
+      } else {
+         if (memcmp(&sm->src_ip6, &localip6, sizeof(localip6)) == 0)
+            dir_out = 1;
+         if (memcmp(&sm->dest_ip6, &localip6, sizeof(localip6)) == 0)
+            dir_in = 1;
+      }
    }
 
    if (dir_out) {
@@ -136,25 +222,51 @@ acct_for(const pktsummary *sm)
       graph_acct((uint64_t)sm->len, GRAPH_IN);
    }
 
-   if (sm->af == AF_INET6) return; /* Still no continuation for IPv6! */
-
    if (hosts_max == 0) return; /* skip per-host accounting */
 
    /* Hosts. */
-   hs = host_get(sm->src_ip);
+   ipaddr.af = sm->af;
+   switch (ipaddr.af) {
+      case AF_INET6:
+         memcpy(&ipaddr.addr.ip6, &sm->src_ip6, sizeof(ipaddr.addr.ip6));
+         break;
+      case AF_INET:
+      default:
+         memcpy(&ipaddr.addr.ip, &sm->src_ip, sizeof(ipaddr.addr.ip));
+         break;
+   }
+   hs = host_get(&ipaddr);
    hs->out   += sm->len;
    hs->total += sm->len;
    memcpy(hs->u.host.mac_addr, sm->src_mac, sizeof(sm->src_mac));
    hs->u.host.last_seen = now;
 
-   hd = host_get(sm->dest_ip); /* this can invalidate hs! */
+   switch (ipaddr.af) {
+      case AF_INET6:
+         memcpy(&ipaddr.addr.ip6, &sm->dest_ip6, sizeof(ipaddr.addr.ip6));
+         break;
+      case AF_INET:
+      default:
+         memcpy(&ipaddr.addr.ip, &sm->dest_ip, sizeof(ipaddr.addr.ip));
+         break;
+   }
+   hd = host_get(&ipaddr); /* this can invalidate hs! */
    hd->in    += sm->len;
    hd->total += sm->len;
    memcpy(hd->u.host.mac_addr, sm->dst_mac, sizeof(sm->dst_mac));
    hd->u.host.last_seen = now;
 
    /* Protocols. */
-   hs = host_find(sm->src_ip);
+   switch (ipaddr.af) {
+      case AF_INET6:
+         memcpy(&ipaddr.addr.ip6, &sm->src_ip6, sizeof(ipaddr.addr.ip6));
+         break;
+      case AF_INET:
+      default:
+         memcpy(&ipaddr.addr.ip, &sm->src_ip, sizeof(ipaddr.addr.ip));
+         break;
+   }
+   hs = host_find(&ipaddr);
    if (hs != NULL) {
       ps = host_get_ip_proto(hs, sm->proto);
       ps->out   += sm->len;
@@ -206,6 +318,9 @@ acct_for(const pktsummary *sm)
 
    case IPPROTO_ICMP:
    case IPPROTO_ICMPV6:
+   case IPPROTO_AH:
+   case IPPROTO_ESP:
+   case IPPROTO_OSPF:
       /* known protocol, don't complain about it */
       break;