Add basic-auth support.
authorRyan Jacobs <ryan@rmj.us>
Sat, 4 May 2019 00:43:42 +0000 (17:43 -0700)
committerEmil Mikulic <emikulic@gmail.com>
Wed, 1 Jul 2020 10:48:02 +0000 (20:48 +1000)
darkhttpd.c

index 76859d8..7a4b411 100644 (file)
@@ -235,7 +235,7 @@ struct connection {
     size_t request_length;
 
     /* request fields */
-    char *method, *url, *referer, *user_agent;
+    char *method, *url, *referer, *user_agent, *authorization;
     off_t range_begin, range_end;
     off_t range_begin_given, range_end_given;
 
@@ -301,6 +301,7 @@ static char *pidfile_name = NULL;   /* NULL = no pidfile */
 static int want_chroot = 0, want_daemon = 0, want_accf = 0,
            want_keepalive = 1, want_server_id = 1;
 static char *server_hdr = NULL;
+static char *auth_key = NULL;
 static uint64_t num_requests = 0, total_in = 0, total_out = 0;
 static int accepting = 1;           /* set to 0 to stop accept()ing */
 
@@ -933,6 +934,8 @@ static void usage(const char *argv0) {
     "\t\tIf a connection is idle for more than this many seconds,\n"
     "\t\tit will be closed. Set to zero to disable timeouts.\n\n",
     timeout_secs);
+    printf("\t--auth username:password\n"
+    "\t\tEnable basic authentication.\n\n");
 #ifdef HAVE_INET6
     printf("\t--ipv6\n"
     "\t\tListen on IPv6 address.\n\n");
@@ -941,6 +944,45 @@ static void usage(const char *argv0) {
 #endif
 }
 
+static char *base64_encode(char *str) {
+    const char base64_table[] = {
+        'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H',
+        'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P',
+        'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X',
+        'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f',
+        'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n',
+        'o', 'p', 'q', 'r', 's', 't', 'u', 'v',
+        'w', 'x', 'y', 'z', '0', '1', '2', '3',
+        '4', '5', '6', '7', '8', '9', '+', '/'};
+
+    int input_length = strlen(str);
+    int output_length = 4 * ((input_length + 2) / 3);
+
+    char *encoded_data = malloc(output_length+1);
+    if (encoded_data == NULL) return NULL;
+
+    for (int i = 0, j = 0; i < input_length;) {
+
+        uint32_t octet_a = i < input_length ? (unsigned char)str[i++] : 0;
+        uint32_t octet_b = i < input_length ? (unsigned char)str[i++] : 0;
+        uint32_t octet_c = i < input_length ? (unsigned char)str[i++] : 0;
+
+        uint32_t triple = (octet_a << 0x10) + (octet_b << 0x08) + octet_c;
+
+        encoded_data[j++] = base64_table[(triple >> 3 * 6) & 0x3F];
+        encoded_data[j++] = base64_table[(triple >> 2 * 6) & 0x3F];
+        encoded_data[j++] = base64_table[(triple >> 1 * 6) & 0x3F];
+        encoded_data[j++] = base64_table[(triple >> 0 * 6) & 0x3F];
+    }
+
+    const int mod_table[] = {0, 2, 1};
+    for (int i = 0; i < mod_table[input_length % 3]; i++)
+        encoded_data[output_length - 1 - i] = '=';
+    encoded_data[output_length] = '\0';
+
+    return encoded_data;
+}
+
 /* Returns 1 if string is a number, 0 otherwise.  Set num to NULL if
  * disinterested in value.
  */
@@ -1095,6 +1137,14 @@ static void parse_commandline(const int argc, char *argv[]) {
                 errx(1, "missing number after --timeout");
             timeout_secs = (int)xstr_to_num(argv[i]);
         }
+        else if (strcmp(argv[i], "--auth") == 0) {
+            if (++i >= argc || strchr(argv[i], ':') == NULL)
+                errx(1, "missing 'user:pass' after --auth");
+
+            char *key = base64_encode(argv[i]);
+            xasprintf(&auth_key, "Basic %s", key);
+            free(key);
+        }
 #ifdef HAVE_INET6
         else if (strcmp(argv[i], "--ipv6") == 0) {
             inet6 = 1;
@@ -1118,6 +1168,7 @@ static struct connection *new_connection(void) {
     conn->url = NULL;
     conn->referer = NULL;
     conn->user_agent = NULL;
+    conn->authorization = NULL;
     conn->range_begin = 0;
     conn->range_end = 0;
     conn->range_begin_given = 0;
@@ -1288,6 +1339,7 @@ static void free_connection(struct connection *conn) {
     if (conn->url != NULL) free(conn->url);
     if (conn->referer != NULL) free(conn->referer);
     if (conn->user_agent != NULL) free(conn->user_agent);
+    if (conn->authorization != NULL) free(conn->authorization);
     if (conn->header != NULL && !conn->header_dont_free) free(conn->header);
     if (conn->reply != NULL && !conn->reply_dont_free) free(conn->reply);
     if (conn->reply_fd != -1) xclose(conn->reply_fd);
@@ -1311,6 +1363,7 @@ static void recycle_connection(struct connection *conn) {
     conn->url = NULL;
     conn->referer = NULL;
     conn->user_agent = NULL;
+    conn->authorization = NULL;
     conn->range_begin = 0;
     conn->range_end = 0;
     conn->range_begin_given = 0;
@@ -1442,6 +1495,9 @@ static void default_reply(struct connection *conn,
      errcode, errname, errname, reason, generated_on(date));
     free(reason);
 
+    const char *auth_header =
+        "WWW-Authenticate: Basic realm=\"User Visible Realm\"";
+
     conn->header_length = xasprintf(&(conn->header),
      "HTTP/1.1 %d %s\r\n"
      "Date: %s\r\n"
@@ -1450,9 +1506,11 @@ static void default_reply(struct connection *conn,
      "%s" /* keep-alive */
      "Content-Length: %llu\r\n"
      "Content-Type: text/html; charset=UTF-8\r\n"
+     "%s\r\n"
      "\r\n",
      errcode, errname, date, server_hdr, keep_alive(conn),
-     llu(conn->reply_length));
+     llu(conn->reply_length),
+     (auth_key != NULL ? auth_header : ""));
 
     conn->reply_type = REPLY_GENERATED;
     conn->http_code = errcode;
@@ -1653,6 +1711,7 @@ static int parse_request(struct connection *conn) {
     /* parse important fields */
     conn->referer = parse_field(conn, "Referer: ");
     conn->user_agent = parse_field(conn, "User-Agent: ");
+    conn->authorization = parse_field(conn, "Authorization: ");
     parse_range_field(conn);
     return 1;
 }
@@ -2085,10 +2144,19 @@ static void process_get(struct connection *conn) {
 /* Process a request: build the header and reply, advance state. */
 static void process_request(struct connection *conn) {
     num_requests++;
+
     if (!parse_request(conn)) {
         default_reply(conn, 400, "Bad Request",
             "You sent a request that the server couldn't understand.");
     }
+    // fail if: (auth_enabled) AND (client supplied invalid credentials)
+    if (auth_key != NULL &&
+            (conn->authorization == NULL ||
+             strcmp(conn->authorization, auth_key)))
+    {
+        default_reply(conn, 401, "Unauthorized",
+            "Access denied due to invalid credentials.");
+    }
     else if (strcmp(conn->method, "GET") == 0) {
         process_get(conn);
     }
@@ -2701,6 +2769,7 @@ int main(int argc, char **argv) {
         free(keep_alive_field);
         free(wwwroot);
         free(server_hdr);
+        free(auth_key);
     }
 
     /* usage stats */