#include <errno.h>
#include <fcntl.h>
#include <libudev.h>
#include <linux/input.h>
#include <linux/memfd.h>
#include <linux/vt.h>
#include <signal.h>
#include <stdarg.h>
#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/ioctl.h>
#include <sys/mman.h>
#include <sys/socket.h>
#include <sys/stat.h>
#include <sys/syscall.h>
#include <sys/types.h>
#include <sys/wait.h>
#include <systemd/sd-bus.h>
#include <systemd/sd-event.h>
#include <systemd/sd-login.h>
#include <unistd.h>
#include <xf86drm.h>
#include <xf86drmMode.h>

#include <shell.h>
#include <keymap.h>

#define STB_IMAGE_WRITE_IMPLEMENTATION
#include <stb_image_write.h>

//  Number of elements of an array object (not valid on a pointer).
#define array_size(a) (sizeof(a) / sizeof((a)[0]))

//  Size in bytes of one pixel in display buffers and client window buffers.
#define BYTES_PER_PIXEL 4

//  Bits of struct display's `state`: which steps of display_connect have
//  completed. display_disconnect consults DISPLAY_SAVED and
//  DISPLAY_HAS_FRAME_BUFFERS to decide what to undo.
enum {
    DISPLAY_CONTROLLER_CHOSEN = 1,
    DISPLAY_HAS_FRAME_BUFFERS = 2,
    DISPLAY_SAVED = 4,
};

//  Fixed shell layout: every workspace has the same number of windows.
enum {
    NUM_WORKSPACES = 4,
    NUM_WINDOWS_PER_WORKSPACE = 3,
};

//  One full-screen DRM "dumb" buffer, mapped into the shell's address space.
struct display_buffer {
    void *mem;          //  mmap()ed pixel memory.
    uint64_t size;      //  Size in bytes reported by DRM_IOCTL_MODE_CREATE_DUMB.
    uint32_t width;     //  Width in pixels (the chosen mode's hdisplay).
    uint32_t height;    //  Height in pixels (the chosen mode's vdisplay).
    uint32_t stride;    //  Bytes per row (the pitch reported by the driver).
    uint32_t handle;    //  Dumb buffer handle from DRM_IOCTL_MODE_CREATE_DUMB.
    uint32_t fb;        //  Framebuffer id from drmModeAddFB.
};

//  State for driving one connector/CRTC pair with two page-flipped buffers.
struct display {
    int fd;                                 //  DRM device file descriptor.
    struct display_buffer buffers[2];
    unsigned int back_buffer_index:1;       //  Index into `buffers` of the buffer drawn into next.
    unsigned int back_buffer_is_locked:1;   //  Set by draw() after queueing a flip; cleared by the page-flip handler.
    unsigned int is_draw_needed:1;          //  A draw was requested while drawing was not possible.

    uint32_t state;                         //  DISPLAY_* flags: completed steps of display_connect.

    drmModeModeInfo mode;                   //  Mode chosen for the connector (determines resolution).
    uint32_t conn;                          //  Chosen connector id.
    uint32_t crtc;                          //  Chosen CRTC id.
    drmModeCrtc *saved_crtc;                //  CRTC configuration before takeover; restored on disconnect.
};

//  A tile of a workspace, optionally occupied by a client process.
struct window {
    uint32_t x;                     //  Position and size, in display pixels.
    uint32_t y;
    uint32_t width;
    uint32_t height;
    unsigned int front_buffer:1;    //  Which of the two client buffers in `mem` is shown.
    void *mem;                      //  Shared mapping holding two width * height client buffers; NULL when empty.
    int socket_fd;                  //  Shell end of the client socket, or -1.
    pid_t pid;                      //  Client process id, or -1 when the window has no client.
};

//  A fixed set of windows shown together.
struct workspace {
    struct window windows[NUM_WINDOWS_PER_WORKSPACE];
    unsigned focus;     //  Index into `windows` of the window that receives keyboard input.
};

//  The keyboard input device and the shell's view of which keys are held.
struct keyboard {
    int fd;                                     //  Input device descriptor, or -1 before keyboard_connect.
    uint32_t major;                             //  Device number of the keyboard.
    uint32_t minor;
    uint8_t key_bits[SHELL_KEY_BITS_SIZE];      //  Bitmap of held keys, indexed by key code.
};

//  Top-level shell state.
struct root {
    bool is_active;                             //  Whether the session is active; draw() defers while it is not.
    struct workspace workspaces[NUM_WORKSPACES];
    unsigned focus;                             //  Index into `workspaces` of the visible workspace.
    sd_bus *bus;
    sd_event *event;
    struct keyboard keyboard;
    struct display display;
};

//  Print `m` followed by a newline to stderr, then terminate the process
//  with exit status 1.
_Noreturn static void
die(const char *m)
{
    fputs(m, stderr);
    fputc('\n', stderr);
    exit(1);
}

//  malloc() that never returns NULL: a zero-byte request or an allocation
//  failure terminates the process via die().
static void *
xmalloc(size_t size)
{
    if (size == 0)
        die("Memory allocation error.");

    void *block = malloc(size);
    if (!block)
        die("Memory allocation failed.");

    return block;
}

//  Reset `disp` to the not-connected state: no display_connect steps done,
//  buffer 1 is the back buffer (so buffer 0 is shown first), nothing locked
//  and no draw pending.
static void
display_init(struct display *disp)
{
    disp->back_buffer_index = 1;
    disp->back_buffer_is_locked = 0;
    disp->is_draw_needed = 0;
    disp->state = 0;
}

//  Return the buffer that is not the back buffer, i.e. the one most recently
//  handed to the display.
static struct display_buffer *
display_front_buffer(struct display *disp)
{
    unsigned int front_index = !disp->back_buffer_index;
    return disp->buffers + front_index;
}

//  Return the buffer that draw() renders into next.
static struct display_buffer *
display_back_buffer(struct display *disp)
{
    unsigned int back_index = disp->back_buffer_index;
    return disp->buffers + back_index;
}

//  Take over a display through the DRM device open on `fd`.
//
//  Picks the first connected connector that has at least one mode and an
//  encoder with a usable CRTC. It uses that connector's last mode marked
//  preferred, or its first mode when none is. It then allocates and maps two
//  dumb buffers of the mode's size, saves the CRTC's current configuration,
//  and scans out the front buffer. Every failure except failing to save the
//  CRTC configuration is fatal.
//
//  `disp->state` records which steps completed, so that display_disconnect
//  undoes exactly those.
static void
display_connect(struct display *disp, int fd)
{
    disp->fd = fd;

    //  We need "dumb buffer" support so fail if the device doesn't have that.

    {
        uint64_t has_dumb;
        int r = drmGetCap(fd, DRM_CAP_DUMB_BUFFER, &has_dumb);
        if (r < 0 || !has_dumb)
            die("Failed to connect to the display.");
    }

    //  We choose an arbitrary connector that's currently active, a mode to use
    //  on it (determines resolution), and a controller (CRTC) that's
    //  compatible with one of the connector's encoders.

    {
        drmModeRes *res = drmModeGetResources(disp->fd);
        if (!res)
            die("Failed to connect to display.");
        bool did_choose_crtc = false;
        for (int i = 0; i < res->count_connectors; i++) {
            drmModeConnector *conn =
                drmModeGetConnector(disp->fd, res->connectors[i]);
            if (!conn)
                continue;
            if (conn->connection == DRM_MODE_CONNECTED && conn->count_modes > 0) {
                drmModeModeInfo *mode;
                {
                    //  Scan backwards for a preferred mode; fall back to mode 0.
                    int j = conn->count_modes - 1;
                    while (j > 0 && !(conn->modes[j].type & DRM_MODE_TYPE_PREFERRED))
                        j--;
                    mode = &conn->modes[j];
                }
                for (int j = 0; j < conn->count_encoders; j++) {
                    drmModeEncoder *enc =
                        drmModeGetEncoder(disp->fd, conn->encoders[j]);
                    if (!enc)
                        continue;
                    //  possible_crtcs is a 32-bit mask indexed like res->crtcs;
                    //  shift an unsigned 1 and never by 32 or more (UB).
                    for (int k = 0; k < res->count_crtcs && k < 32; k++) {
                        if (enc->possible_crtcs & (UINT32_C(1) << k)) {
                            disp->crtc = res->crtcs[k];
                            disp->conn = conn->connector_id;
                            disp->mode = *mode;
                            did_choose_crtc = true;
                            break;
                        }
                    }
                    drmModeFreeEncoder(enc);
                    if (did_choose_crtc)
                        break;
                }
            }
            drmModeFreeConnector(conn);
            if (did_choose_crtc)
                break;
        }
        drmModeFreeResources(res);
        if (!did_choose_crtc)
            die("Failed to connect to display.");
    }

    disp->state |= DISPLAY_CONTROLLER_CHOSEN;

    //  Allocate and mmap the two full-screen pixel buffers.

    for (int i = 0; i < 2; i++) {
        struct display_buffer *buffer = &disp->buffers[i];
        buffer->width = disp->mode.hdisplay;
        buffer->height = disp->mode.vdisplay;
        struct drm_mode_create_dumb ioctl_create = {
            .width = buffer->width,
            .height = buffer->height,
            .bpp = 32,
        };
        int r = drmIoctl(fd, DRM_IOCTL_MODE_CREATE_DUMB, &ioctl_create);
        if (r < 0)
            die("Failed to connect to display.");
        buffer->stride = ioctl_create.pitch;
        buffer->size = ioctl_create.size;
        buffer->handle = ioctl_create.handle;
        r = drmModeAddFB(fd, buffer->width, buffer->height, 24, 32,
            buffer->stride, buffer->handle, &buffer->fb);
        if (r)
            die("Failed to connect to display.");
        struct drm_mode_map_dumb ioctl_map = {
            .handle = buffer->handle,
        };
        r = drmIoctl(fd, DRM_IOCTL_MODE_MAP_DUMB, &ioctl_map);
        if (r)
            die("Failed to connect to display.");
        buffer->mem = mmap(0, buffer->size, PROT_READ | PROT_WRITE,
            MAP_SHARED, fd, ioctl_map.offset);
        if (buffer->mem == MAP_FAILED)
            die("Failed to connect to display.");
    }

    disp->state |= DISPLAY_HAS_FRAME_BUFFERS;

    //  Save the current state for the CRTC we are about to take over. We will
    //  restore this state when we disconnect from the display.
    //
    //  Take over the connector and CRTC that we selected above.

    {
        struct display_buffer *front_buffer = display_front_buffer(disp);

        //  May be NULL; then there is nothing to restore on disconnect.
        disp->saved_crtc = drmModeGetCrtc(disp->fd, disp->crtc);

        int r = drmModeSetCrtc(disp->fd, disp->crtc, front_buffer->fb, 0, 0,
            &disp->conn, 1, &disp->mode);
        if (r)
            die("Failed to connect to display.");
    }

    //  Only claim a saved state when we really have one, otherwise
    //  display_disconnect would dereference a NULL saved_crtc.
    if (disp->saved_crtc != NULL)
        disp->state |= DISPLAY_SAVED;
    else
        fprintf(stderr, "Failed to save display state; it will not be restored.\n");
}

static void
display_disconnect(struct display *disp)
{
    if (disp->state & DISPLAY_SAVED) {
        drmModeSetCrtc(
            disp->fd,
            disp->saved_crtc->crtc_id,
            disp->saved_crtc->buffer_id,
            disp->saved_crtc->x,
            disp->saved_crtc->y,
            &disp->conn,
            1,
            &disp->saved_crtc->mode);
        drmModeFreeCrtc(disp->saved_crtc);
        disp->state &= ~DISPLAY_SAVED;
    }
    if (disp->state & DISPLAY_HAS_FRAME_BUFFERS) {
        for (int i = 0; i < 2; i++) {
            struct display_buffer *buffer = &disp->buffers[i];
            munmap(buffer->mem, buffer->size);
            drmModeRmFB(disp->fd, buffer->fb);
            drmIoctl(disp->fd, DRM_IOCTL_MODE_DESTROY_DUMB,
                &(struct drm_mode_destroy_dumb){
                    .handle = buffer->handle,
                });
        }
        disp->state &= ~DISPLAY_HAS_FRAME_BUFFERS;
    }
}

//  Set bit `i` of the little-endian bit array `bits` (bit 0 is the least
//  significant bit of bits[0]).
static void
set_bit(uint8_t *bits, unsigned i)
{
    uint8_t mask = (uint8_t)(1u << (i & 7u));
    bits[i >> 3] |= mask;
}

//  Clear bit `i` of the little-endian bit array `bits`.
static void
clear_bit(uint8_t *bits, unsigned i)
{
    uint8_t keep = (uint8_t)~(1u << (i & 7u));
    bits[i >> 3] &= keep;
}

//  Return bit `i` of the little-endian bit array `bits` as 0 or 1.
static int
fetch_bit(const uint8_t *bits, unsigned i)
{
    unsigned byte = bits[i >> 3];
    return (int)((byte >> (i & 7u)) & 1u);
}

struct keyboard *
keyboard_new(void)
{
    struct keyboard *keyboard = xmalloc(sizeof(*keyboard));
    keyboard->fd = -1;
    return keyboard;
}

//  Attach `keyboard` to the input device open on `fd` (device number
//  major:minor) and load the set of currently held keys into key_bits.
//  Failure to query the key state terminates the process; otherwise 0 is
//  returned.
int
keyboard_connect(struct keyboard *keyboard, int fd,
        unsigned int major, unsigned int minor)
{
    keyboard->major = major;
    keyboard->minor = minor;
    keyboard->fd = fd;

    if (ioctl(fd, EVIOCGKEY(SHELL_KEY_BITS_SIZE), keyboard->key_bits) == -1)
        die("Failed to read keyboard state.");

    return 0;
}

//  Copy a source_width x source_height pixel rectangle from `source` (rows
//  packed back to back) into `dest` at pixel (x, y).
//
//  Rows of `dest` are dest_width pixels long. Pointer arithmetic goes through
//  uint8_t * instead of void * (void * arithmetic is a GNU extension), and byte
//  offsets are computed in size_t so large buffers cannot overflow 32-bit
//  arithmetic. The caller must ensure the rectangle lies inside `dest`.
static void
blit(void *dest, uint32_t dest_width, uint32_t x, uint32_t y,
        const void *source, uint32_t source_width, uint32_t source_height)
{
    uint8_t *dest_bytes = dest;
    const uint8_t *source_bytes = source;
    const size_t dest_row_bytes = (size_t)BYTES_PER_PIXEL * dest_width;
    const size_t source_row_bytes = (size_t)BYTES_PER_PIXEL * source_width;
    const size_t x_offset = (size_t)BYTES_PER_PIXEL * x;

    for (uint32_t i = 0; i < source_height; i++) {
        memmove(dest_bytes + dest_row_bytes * ((size_t)y + i) + x_offset,
                source_bytes + source_row_bytes * i,
                source_row_bytes);
    }
}

//  Pack an RGB color into a 32-bit pixel whose in-memory byte order is
//  B, G, R, 0 on any host, so storing the result writes those four bytes.
static uint32_t
encode_color_xrgb(uint8_t r, uint8_t g, uint8_t b)
{
    const uint8_t bytes[4] = { b, g, r, 0 };
    uint32_t pixel;
    memcpy(&pixel, bytes, sizeof(pixel));
    return pixel;
}

//  Fill width * height consecutive pixels of `buffer` with white.
static void
clear_buffer(void *buffer, uint32_t width, uint32_t height)
{
    const uint32_t white = encode_color_xrgb(0xff, 0xff, 0xff);
    uint32_t *const pixels = buffer;
    const size_t count = (size_t)width * height;

    for (size_t n = 0; n < count; n++)
        pixels[n] = white;
}

//  Draw a one-pixel border around `window`, offset 5 pixels outside its
//  content area, into the pixel buffer `mem` whose rows are `width` pixels
//  long. Dark gray marks the focused window, light gray the others. No
//  clipping is done.
static void
draw_frame(void *mem, uint32_t width, struct window *window, bool has_focus)
{
    const int left = window->x - 5;
    const int top = window->y - 5;
    const int right = window->x + window->width + 4;
    const int bottom = window->y + window->height + 4;

    uint32_t color;
    if (has_focus)
        color = encode_color_xrgb(0x48, 0x48, 0x48);
    else
        color = encode_color_xrgb(0xcc, 0xcc, 0xcc);

    uint32_t *pixels = mem;

    //  Horizontal edges.
    for (int x = left; x <= right; x++) {
        pixels[(int)(width * top) + x] = color;
        pixels[(int)(width * bottom) + x] = color;
    }

    //  Vertical edges.
    for (int y = top; y <= bottom; y++) {
        pixels[(int)(width * y) + left] = color;
        pixels[(int)(width * y) + right] = color;
    }
}

//  Copy the client's currently shown buffer into the display buffer
//  `back_buffer` (rows disp_width pixels long) at the window's position.
//
//  Empty windows (no client or no mapping) are skipped. The client mapping
//  holds two width * height buffers back to back; `front_buffer` selects one.
//  The offset is computed through uint8_t * in size_t, since void * arithmetic
//  is a GNU extension.
static void
draw_window(void *back_buffer, uint32_t disp_width, struct window *window)
{
    if (window->pid == -1 || window->mem == NULL)
        return;
    const size_t buffer_size =
        (size_t)BYTES_PER_PIXEL * window->width * window->height;
    uint8_t *buffer = (uint8_t *)window->mem + window->front_buffer * buffer_size;
    blit(back_buffer, disp_width, window->x, window->y,
            buffer, window->width, window->height);
}

static void
draw(struct root *root)
{
    struct display *disp = &root->display;

    if (disp->back_buffer_is_locked || !root->is_active) {
        disp->is_draw_needed = 1;
        return;
    }

    struct display_buffer *back_buffer = display_back_buffer(disp);
    uint32_t disp_width = back_buffer->width;
    uint32_t disp_height = back_buffer->height;

    clear_buffer(back_buffer->mem, disp_width, disp_height);

    struct workspace *workspace = &root->workspaces[root->focus];
    for (int i = 0; i < NUM_WINDOWS_PER_WORKSPACE; i++) {
        struct window *window = &workspace->windows[i];
        draw_frame(back_buffer->mem, disp_width, window, i == workspace->focus);
        draw_window(back_buffer->mem, disp_width, window);
    }

    int r = drmModePageFlip(disp->fd, disp->crtc, back_buffer->fb,
        DRM_MODE_PAGE_FLIP_EVENT, root);
    if (r)
        fprintf(stderr, "Failed to flip back buffer onto display.");

    disp->back_buffer_index = !disp->back_buffer_index;
    disp->back_buffer_is_locked = 1;
    disp->is_draw_needed = 0;
}

//  Initialize `root` with the given bus and event loop.
//
//  The session starts active with workspace 0 focused. Every workspace gets
//  the same layout for a disp_width x disp_height display: one tall window in
//  the left half and two stacked windows in the right half. All windows start
//  empty (no client, no buffers, no socket) and window 0 has focus.
static void
root_init(struct root *root, sd_bus *bus, sd_event *event,
        uint32_t disp_width, uint32_t disp_height)
{
    const uint32_t half_width = disp_width / 2;
    const uint32_t half_height = disp_height / 2;

    root->focus = 0;
    root->is_active = true;
    root->bus = bus;
    root->event = event;

    for (int w = 0; w < NUM_WORKSPACES; w++) {
        struct workspace *ws = &root->workspaces[w];
        ws->focus = 0;

        for (int k = 0; k < NUM_WINDOWS_PER_WORKSPACE; k++) {
            struct window *win = &ws->windows[k];

            if (k == 0) {
                //  Left column, full height.
                win->x = 20;
                win->y = 20;
                win->width = half_width - 30;
                win->height = disp_height - 40;
            } else if (k == 1) {
                //  Upper half of the right column.
                win->x = half_width + 10;
                win->y = 20;
                win->width = half_width - 30;
                win->height = half_height - 30;
            } else if (k == 2) {
                //  Lower half of the right column.
                win->x = half_width + 10;
                win->y = half_height + 10;
                win->width = half_width - 30;
                win->height = half_height - 30;
            } else {
                die("Failed to initialize workspaces.");
            }

            win->front_buffer = 0;
            win->mem = NULL;
            win->socket_fd = -1;
            win->pid = -1;
        }
    }
}

static struct window *
root_lookup_window_by_fd(struct root *root, int fd)
{
    for (int i = 0; i < NUM_WORKSPACES; i++) {
        struct workspace *workspace = &root->workspaces[i];
        for (int j = 0; j < NUM_WINDOWS_PER_WORKSPACE; j++) {
            struct window *window = &workspace->windows[j];
            if (window->socket_fd == fd)
                return window;
        }
    }
    return NULL;
}

//  Defined further down with the other client-message senders; declared here
//  because root_set_workspace_focus uses them.
static void send_keyboard_detached(struct window *window);
static void send_keyboard_attached(struct window *window, uint8_t *key_bits);

static void
root_set_workspace_focus(struct root *root, int focus)
{
    if (root->focus == focus)
        return;

    struct workspace *workspace;
    struct window *window;

    workspace = &root->workspaces[root->focus];
    window = &workspace->windows[workspace->focus];
    send_keyboard_detached(window);

    root->focus = focus;

    workspace = &root->workspaces[root->focus];
    window = &workspace->windows[workspace->focus];
    send_keyboard_attached(window, root->keyboard.key_bits);

    draw(root);
}

//  Find a keyboard via udev and store its device number in *major / *minor.
//
//  Enumerates devices with the udev property ID_INPUT_KEYBOARD=1 and takes the
//  first one that has a device number other than 0:0. Terminates the process
//  if udev fails or no such device exists.
static void
choose_keyboard(uint32_t *major, uint32_t *minor)
{
    const char *const failure = "Failed to choose keyboard.";

    struct udev *udev = udev_new();
    if (!udev)
        die(failure);
    struct udev_enumerate *en = udev_enumerate_new(udev);
    if (!en)
        die(failure);
    if (udev_enumerate_add_match_property(en, "ID_INPUT_KEYBOARD", "1") < 0)
        die(failure);
    if (udev_enumerate_scan_devices(en) < 0)
        die(failure);

    struct udev_list_entry *head = udev_enumerate_get_list_entry(en);
    if (!head)
        die(failure);

    *major = 0;
    *minor = 0;

    struct udev_list_entry *it;
    udev_list_entry_foreach(it, head) {
        const char *syspath = udev_list_entry_get_name(it);
        struct udev_device *device =
            syspath ? udev_device_new_from_syspath(udev, syspath) : NULL;
        if (!device)
            continue;
        dev_t devnum = udev_device_get_devnum(device);
        udev_device_unref(device);
        //  Nodes without a device number (0:0) cannot be opened; skip them.
        if (major(devnum) == 0 && minor(devnum) == 0)
            continue;
        *major = major(devnum);
        *minor = minor(devnum);
        break;
    }

    udev_enumerate_unref(en);
    udev_unref(udev);

    if (*major == 0 && *minor == 0)
        die(failure);
}

//  Activate virtual terminal `n` via the controlling terminal and wait until
//  the switch has happened. Any failure terminates the process.
static void
switch_to_vt(int n)
{
    static const char failure[] = "Failed to switch VT.";

    int tty = open("/dev/tty", O_RDWR);
    if (tty == -1)
        die(failure);

    if (ioctl(tty, VT_ACTIVATE, n) == -1 || ioctl(tty, VT_WAITACTIVE, n) == -1)
        die(failure);

    close(tty);
}

//  SIGCHLD handler: reap every child that has changed state, without
//  blocking.
//
//  Exits and terminations are logged, and the client's window is emptied:
//  both of its buffers are unmapped, its socket is closed, and its pid is
//  reset to -1. Stop/continue notifications are only logged. The display is
//  redrawn afterwards. Always returns 0.
//
//  NOTE(review): the sd_event_source that launch_client registered for the
//  window's socket is not released here. It is only unref'd by
//  child_message_event_handler on EPOLLERR/EPOLLHUP. If this handler runs
//  first, the fd is closed while that source is still registered. Confirm
//  this cannot misbehave once the fd number is reused.
static int
sigchld_handler(sd_event_source *source,
        const struct signalfd_siginfo *siginfo, void *data)
{
    struct root *root = data;

    (void)source;
    (void)siginfo;

    //  Loop until no child has a pending state change (waitpid returns 0)
    //  or there are no children at all (ECHILD).
    for (;;) {
        int status;
        pid_t pid;

        pid = waitpid(-1, &status, WUNTRACED|WCONTINUED|WNOHANG);
        if (pid == 0)
            break;
        if (pid == -1) {
            if (errno == ECHILD)
                break;
            die("Failed to handle SIGCHLD.");
        }

        //  Stopped/continued children keep their windows; only exits and
        //  terminations fall through to the cleanup below.
        if (WIFEXITED(status)) {
            fprintf(stderr, "Child exited with status %d.\n",
                    (int)WEXITSTATUS(status));
        } else if (WIFSIGNALED(status)) {
            fprintf(stderr, "Child was terminated by signal %d.\n",
                    (int)WTERMSIG(status));
        } else if (WIFSTOPPED(status)) {
            fprintf(stderr, "Child was stopped.\n");
            continue;
        } else if (WIFCONTINUED(status)) {
            fprintf(stderr, "Child was continued.\n");
            continue;
        }

        //  Empty the window (if any) that belonged to the reaped child.
        for (int i = 0; i < NUM_WORKSPACES; i++) {
            struct workspace *workspace = &root->workspaces[i];
            for (int j = 0; j < NUM_WINDOWS_PER_WORKSPACE; j++) {
                struct window *window = &workspace->windows[j];
                if (window->pid == pid) {
                    window->front_buffer = 0;
                    //  Same size expression as the mapping created in
                    //  launch_client: two buffers of width * height pixels.
                    size_t mem_size =
                        2 * BYTES_PER_PIXEL * window->width * window->height;
                    int r = munmap(window->mem, mem_size);
                    if (r == -1)
                        die("Failed to release draw buffers.");
                    window->mem = NULL;
                    close(window->socket_fd);
                    window->socket_fd = -1;
                    window->pid = -1;
                }
            }
        }
    }

    draw(root);

    return 0;
}

//  Send one fixed-size protocol message to the client of `window`.
//
//  Uses send() with MSG_NOSIGNAL, so writing to a client that has already
//  closed its end yields EPIPE instead of raising SIGPIPE in the shell.
//  EINTR is retried. Failures are only reported on stderr, prefixed with
//  `what`.
//
//  TODO Rethink this: a client that cannot be written to is currently left
//  running (one option is to kill it on EPIPE/EBADF).
static void
send_message(struct window *window, const void *message, size_t size,
        const char *what)
{
    ssize_t r;
    do {
        r = send(window->socket_fd, message, size, MSG_NOSIGNAL);
    } while (r == -1 && errno == EINTR);

    if (r == -1) {
        perror(what);
        return;
    }
    //  A short write does not set errno, so perror() would print a stale,
    //  unrelated error here.
    if ((size_t)r != size)
        fprintf(stderr, "%s: short write\n", what);
}

//  Tell the client the size of its window.
//
//  Unlike the other senders this does not check window->pid: launch_client
//  calls it before the child has been forked and the pid recorded.
static void
send_welcome(struct window *window)
{
    struct shell_message_welcome message = {
        .class = SHELL_MESSAGE_WELCOME,
        .width = window->width,
        .height = window->height,
    };
    send_message(window, &message, sizeof(message),
            "Failed to send welcome message to child process");
}

//  Forward one key event (code and value) to the window's client, if it has
//  one.
static void
send_keyboard_event(struct window *window, struct input_event *event)
{
    if (window->pid == -1)
        return;

    struct shell_message_keyboard_event message = {
        .class = SHELL_MESSAGE_KEYBOARD_EVENT,
        .code = event->code,
        .value = event->value,
    };
    send_message(window, &message, sizeof(message),
            "Failed to send keyboard event message to child process");
}

//  Acknowledge a SHELL_MESSAGE_RESET from the client.
static void
send_reset_completed(struct window *window)
{
    struct shell_message_reset_completed message = {
        .class = SHELL_MESSAGE_RESET_COMPLETED,
    };
    send_message(window, &message, sizeof(message),
            "Failed to send reset completed message to child process");
}

//  Tell the client that its buffers may be used again after a
//  SHELL_MESSAGE_SHOW.
static void
send_buffer_unlocked(struct window *window)
{
    if (window->pid == -1)
        return;

    struct shell_message_buffer_unlocked message = {
        .class = SHELL_MESSAGE_BUFFER_UNLOCKED,
    };
    send_message(window, &message, sizeof(message),
            "Failed to send buffer unlocked message to child process");
}

//  Tell the client it no longer receives keyboard input.
static void
send_keyboard_detached(struct window *window)
{
    if (window->pid == -1)
        return;

    struct shell_message_keyboard_detached message = {
        .class = SHELL_MESSAGE_KEYBOARD_DETACHED,
    };
    send_message(window, &message, sizeof(message),
            "Failed to send keyboard detached message to child process");
}

//  Tell the client it now receives keyboard input, along with the bitmap of
//  keys currently held (SHELL_KEY_BITS_SIZE bytes from `key_bits`).
static void
send_keyboard_attached(struct window *window, uint8_t *key_bits)
{
    if (window->pid == -1)
        return;

    struct shell_message_keyboard_attached message = {
        .class = SHELL_MESSAGE_KEYBOARD_ATTACHED,
    };
    memcpy(message.key_bits, key_bits, SHELL_KEY_BITS_SIZE);
    send_message(window, &message, sizeof(message),
            "Failed to send keyboard attached message to child process");
}

//  Unregister and release the event source watching a client socket.
//  Always returns 0, so handlers can `return drop_client_source(source);`.
static int
drop_client_source(sd_event_source *source)
{
    sd_event_source_unref(source);
    return 0;
}

//  Handle activity on a client socket (registered by launch_client).
//
//  Reads one packet and dispatches on its leading uint16_t message class:
//  - SHELL_MESSAGE_RESET moves the shown buffer into the first slot, then
//    acknowledges and re-sends the welcome message.
//  - SHELL_MESSAGE_SHOW flips the shown buffer, unlocks it for the client,
//    and redraws.
//
//  Errors, hang-ups, and end-of-file unregister the source. A spurious wakeup
//  on the non-blocking socket is ignored.
static int
child_message_event_handler(sd_event_source *source, int fd, uint32_t events,
        void *data)
{
    //  TODO What kinds of errors (EPOLLERR) can we get? How should they be
    //  handled?

    if (events & (EPOLLERR | EPOLLHUP))
        return drop_client_source(source);
    if (events != EPOLLIN)
        die("Unexpected condition on socket.");

    static _Alignas(max_align_t) char packet[SHELL_PACKET_SIZE_MAX];
    ssize_t r;
    do {
        r = read(fd, packet, sizeof(packet));
    } while (r == -1 && errno == EINTR);
    //  The socket is non-blocking; nothing to read must not unregister it.
    if (r == -1 && (errno == EAGAIN || errno == EWOULDBLOCK))
        return 0;
    //  Any other read error, or end-of-file from a client that shut down
    //  its end, ends the conversation rather than the shell.
    if (r <= 0)
        return drop_client_source(source);
    size_t packet_size = (size_t)r;

    struct root *root = data;
    struct window *window = root_lookup_window_by_fd(root, fd);
    if (window == NULL)
        return drop_client_source(source);

    //  TODO The following error should not be fatal. We need a general way to
    //  deal with protocol errors.

    if (packet_size < sizeof(uint16_t))
        die("Protocol error on socket.");

    //  memcpy avoids reading the char buffer through an incompatible type.
    uint16_t class;
    memcpy(&class, packet, sizeof(class));
    switch (class) {
    case SHELL_MESSAGE_RESET:
        if (window->front_buffer != 0) {
            size_t buffer_size =
                (size_t)BYTES_PER_PIXEL * window->width * window->height;
            uint8_t *mem = window->mem;
            memmove(mem, mem + buffer_size, buffer_size);
            window->front_buffer = 0;
        }
        send_reset_completed(window);
        send_welcome(window);
        break;
    case SHELL_MESSAGE_SHOW:
        window->front_buffer = !window->front_buffer;
        send_buffer_unlocked(window);
        draw(root);
        break;
    default:
        fprintf(stderr, "Unexpected message class received from client.\n");
        break;
    }
    return 0;
}

//  Set FD_CLOEXEC on `fd`, keeping its other descriptor flags. Failure
//  terminates the process.
static void
enable_cloexec(int fd)
{
    const char *error_message = "Failed to enable CLOEXEC on file.";
    int fd_flags = fcntl(fd, F_GETFD);
    if (fd_flags == -1 || fcntl(fd, F_SETFD, fd_flags | FD_CLOEXEC) == -1)
        die(error_message);
}

//  Clear O_NONBLOCK on `fd`, keeping its other file status flags. Failure
//  terminates the process.
static void
disable_nonblock(int fd)
{
    const char *error_message = "Failed to disable NONBLOCK on file.";
    int status_flags = fcntl(fd, F_GETFL);
    if (status_flags == -1 || fcntl(fd, F_SETFL, status_flags & ~O_NONBLOCK) == -1)
        die(error_message);
}

//  Report a failure inside the forked child and terminate it with status 1.
//  Uses _exit() instead of exit(), so the child does not run the parent's
//  atexit handlers or flush stdio buffers inherited from the parent.
_Noreturn static void
die_in_child(const char *m)
{
    fprintf(stderr, "%s\n", m);
    _exit(1);
}

//  Start the program `path` (with argv[0] = `arg0`) in the focused window of
//  the focused workspace, unless that window already has a client.
//
//  The shell side:
//  - creates a memfd holding two BYTES_PER_PIXEL * width * height buffers,
//    maps it, and clears the first buffer;
//  - creates a non-blocking SOCK_SEQPACKET socket pair, sends the welcome
//    message on its end, and watches that end with
//    child_message_event_handler.
//
//  In the child, the socket is installed as fd 3 and the memfd as fd 4
//  before exec. Setup failures are fatal.
static void
launch_client(struct root *root, const char *path, const char *arg0)
{
    struct workspace *workspace = &root->workspaces[root->focus];
    struct window *window = &workspace->windows[workspace->focus];

    if (window->pid != -1)
        return;

    int r;

    size_t mem_size =
        (size_t)2 * BYTES_PER_PIXEL * window->width * window->height;
    char filename[64];
    r = snprintf(filename, sizeof(filename), "buffer_%u_%u",
            root->focus, workspace->focus);
    if (r < 0 || (size_t)r >= sizeof(filename))
        die("Failed to create draw buffers for child process.");
    int mem_fd = syscall(__NR_memfd_create, filename, 0);
    if (mem_fd == -1)
        die("Failed to create draw buffers for child process.");
    do { r = ftruncate(mem_fd, mem_size); } while (r == -1 && errno == EINTR);
    if (r == -1)
        die("Failed to create draw buffers for child process.");
    void *mem = mmap(NULL, mem_size, PROT_READ|PROT_WRITE, MAP_SHARED,
            mem_fd, 0);
    if (mem == MAP_FAILED)
        die("Failed to create draw buffers for child process.");

    window->mem = mem;

    int socket_fds[2];
    r = socketpair(AF_UNIX, SOCK_SEQPACKET|SOCK_NONBLOCK, 0, socket_fds);
    if (r == -1)
        die("Failed to create socket for child process.");

    window->socket_fd = socket_fds[0];
    send_welcome(window);

    //  TODO This is weird. We pass the event_source parameter only to trigger
    //  a certain behavior; otherwise, the library allocates an event_source
    //  object that it deallocates when the event loop is deallocated. We have
    //  no use for the actual object here but we want to be able to deallocate
    //  it when the child terminates.

    sd_event_source *event_source;
    r = sd_event_add_io(root->event, &event_source, window->socket_fd,
            EPOLLIN, child_message_event_handler, root);
    if (r < 0)
        die("Failed to add child message event handler.");

    clear_buffer(window->mem, window->width, window->height);

    pid_t pid = fork();
    switch (pid) {
    case -1:
        die("Failed to fork child process.");
        break;
    case 0:
        {
            sigset_t sigset;
            const char error_message[] =
                "Failed to initialize signal handling for child process.";
            r = sigemptyset(&sigset);
            if (r == -1)
                die_in_child(error_message);
            r = sigprocmask(SIG_SETMASK, &sigset, NULL);
            if (r == -1)
                die_in_child(error_message);
            signal(SIGPIPE, SIG_DFL);
            signal(SIGINT, SIG_DFL);
            signal(SIGTERM, SIG_DFL);
            signal(SIGCHLD, SIG_DFL);
        }
        close(socket_fds[0]);
        {
            //  Install the socket as fd 3 and the memfd as fd 4. First copy
            //  both to descriptors >= 5. Otherwise dup2() onto 3 would clobber
            //  mem_fd when mem_fd happens to be 3, and likewise onto 4 for the
            //  socket.
            int high_socket_fd = fcntl(socket_fds[1], F_DUPFD, 5);
            int high_mem_fd = fcntl(mem_fd, F_DUPFD, 5);
            if (high_socket_fd == -1 || high_mem_fd == -1)
                die_in_child("Failed while launching child process.");
            close(socket_fds[1]);
            close(mem_fd);
            do {
                r = dup2(high_socket_fd, 3);
            } while (r == -1 && errno == EINTR);
            if (r == -1)
                die_in_child("Failed while launching child process.");
            do {
                r = dup2(high_mem_fd, 4);
            } while (r == -1 && errno == EINTR);
            if (r == -1)
                die_in_child("Failed while launching child process.");
            close(high_socket_fd);
            close(high_mem_fd);
        }
        execlp(path, arg0, (char *)NULL);
        die_in_child("Failed to exec child process.");
        break;
    default:
        window->pid = pid;
        close(mem_fd);
        close(socket_fds[1]);
        enable_cloexec(socket_fds[0]);
        break;
    }
}

//  Record that key `code` is no longer held. Codes outside 0..KEY_MAX are
//  ignored.
static void
on_key_release(sd_event_source *source, int code)
{
    if (code < 0 || code > KEY_MAX)
        return;

    struct root *root = sd_event_source_get_userdata(source);
    clear_bit(root->keyboard.key_bits, (unsigned)code);
}

static void
on_key_press(sd_event_source *source, int code)
{
    int r;

    struct root *root = sd_event_source_get_userdata(source);

    sd_event *event = sd_event_source_get_event(source);

    struct keyboard *keyboard = &root->keyboard;
    if (code >= 0 && code <= KEY_MAX)
        set_bit(keyboard->key_bits, (unsigned)code);

    if (!fetch_bit(keyboard->key_bits, KEY_LEFTMETA) &&
            !fetch_bit(keyboard->key_bits, KEY_RIGHTMETA))
        return;

    switch (code) {
    case KEY_ESC:
        r = sd_event_exit(event, 0);
        if (r < 0)
            die("Failed to handle keyboard event.");
        return;
    case KEY_TAB:
        {
            struct workspace *workspace = &root->workspaces[root->focus];
            struct window *window = &workspace->windows[workspace->focus];
            uint8_t key_bits[SHELL_KEY_BITS_SIZE];
            memmove(key_bits, keyboard->key_bits, SHELL_KEY_BITS_SIZE);
            clear_bit(key_bits, KEY_LEFTMETA);
            clear_bit(key_bits, KEY_RIGHTMETA);
            clear_bit(key_bits, KEY_TAB);
            send_keyboard_detached(window);
            workspace->focus =
                (workspace->focus + 1) % NUM_WINDOWS_PER_WORKSPACE;
            window = &workspace->windows[workspace->focus];
            send_keyboard_attached(window, key_bits);
            draw(root);
        }
        return;
    case KEY_1:
        root_set_workspace_focus(root, 0);
        return;
    case KEY_2:
        root_set_workspace_focus(root, 1);
        return;
    case KEY_3:
        root_set_workspace_focus(root, 2);
        return;
    case KEY_4:
        root_set_workspace_focus(root, 3);
        return;
    case KEY_F1:
        switch_to_vt(1);
        return;
    case KEY_F2:
        switch_to_vt(2);
        return;
    case KEY_F3:
        switch_to_vt(3);
        return;
    case KEY_F4:
        switch_to_vt(4);
        return;
    }
    switch (keymap_lookup(code)) {
    case 'g':
        launch_client(root, "./germ", "germ");
        return;
    case 'l':
        launch_client(root, "./logo", "logo");
        return;
    case 't':
        launch_client(root, "./term", "term");
        return;
    case 's':
        {
            struct display_buffer *buf = display_front_buffer(&root->display);
            int w = buf->width;
            int h = buf->height;
            int stride = w;
            uint8_t *color = buf->mem;
            uint8_t *gray = malloc(w * h);
            if (gray == NULL)
                return;
            for (int i = 0; i < w * h; i++)
                gray[i] = color[4 * i];
            stbi_write_png("screenshot.png", w, h, 1, gray, stride);
            free(gray);
        }
        return;
    }
}

//  sd-event I/O callback for the bus connection: process queued bus messages
//  until sd_bus_process reports nothing left (0). A bus error terminates
//  the process.
static int
bus_event_handler(sd_event_source *source, int fd, uint32_t revents,
        void *data)
{
    struct root *root = data;
    int r;

    (void)source;
    (void)fd;
    (void)revents;

    while ((r = sd_bus_process(root->bus, NULL)) != 0) {
        if (r < 0)
            die("Failed to process input from bus.");
    }

    return 0;
}

//  libdrm page-flip completion callback (installed by display_event_handler).
//  `data` is the struct root that draw() passed to drmModePageFlip. Clears
//  back_buffer_is_locked and performs any draw that was deferred while it
//  was set.
static void
page_flip_handler(int device, unsigned int frame, unsigned int sec,
        unsigned int usec, void *data)
{
    (void)device;
    (void)frame;
    (void)sec;
    (void)usec;

    struct root *root = data;
    root->display.back_buffer_is_locked = 0;
    if (!root->display.is_draw_needed)
        return;
    draw(root);
}

//  sd-event I/O callback for the DRM device: let libdrm read pending events
//  and dispatch page-flip completions to page_flip_handler.
//
//  EINTR is retried. Other failures are logged (with a message naming the
//  operation that actually failed) and otherwise ignored. Always returns 0.
static int
display_event_handler(sd_event_source *source, int fd, uint32_t revents,
        void *data)
{
    (void)source;
    (void)revents;
    (void)data;

    drmEventContext event_context = {
        .version = DRM_EVENT_CONTEXT_VERSION,
        .page_flip_handler = page_flip_handler,
    };
    int r;
    do {
        r = drmHandleEvent(fd, &event_context);
    } while (r == -1 && errno == EINTR);
    if (r)
        fprintf(stderr, "Failed to handle display event.\n");

    return 0;
}

/*
 *  keyboard_event_handler is the sd-event I/O callback for the keyboard's
 *  evdev descriptor (data is the struct root).
 *
 *  It reads up to 16 struct input_event records in one read(). For each
 *  EV_KEY record it:
 *    - calls on_key_release (value 0) or on_key_press (value 1 or 2);
 *    - forwards the event to the focused window of the focused workspace,
 *      unless that window has no client (pid == -1), the key is a Meta
 *      key, or either Meta key is currently recorded as held in key_bits.
 *
 *  Returns -1 on EPOLLERR/EPOLLHUP; sd-event disables an I/O source whose
 *  handler returns a negative value. Any other unexpected condition or
 *  read failure terminates the process.
 */
static int
keyboard_event_handler(sd_event_source *source, int fd, uint32_t revents,
        void *data)
{
    struct root *root = data;
    struct keyboard *keyboard = &root->keyboard;

    // TODO Review the handling of the various revents cases.

    if (revents & EPOLLERR)
        return -1;
    if (revents & EPOLLHUP)
        return -1;
    if (revents != EPOLLIN)
        die("Unexpected condition on keyboard.");

    struct input_event input_events[16];

    ssize_t ret;
    do {
        ret = read(fd, input_events, sizeof(input_events));
    } while (ret == -1 && errno == EINTR);

    if (ret == -1) {
        // If the descriptor is non-blocking, a wakeup may find nothing to
        // read yet; that is not an error.
        if (errno == EAGAIN || errno == EWOULDBLOCK)
            return 0;
        die("Failed while reading keyboard.");
    }
    if (ret == 0)
        die("Failed while reading keyboard.");
    // evdev only ever returns whole events; a partial one means trouble.
    if ((size_t)ret % sizeof(struct input_event) != 0)
        die("Failed while reading keyboard.");

    size_t num_events = (size_t)ret / sizeof(struct input_event);

    for (size_t i = 0; i < num_events; i++) {
        struct input_event *ev = &input_events[i];
        if (ev->type != EV_KEY)
            continue;
        switch (ev->value) {
        case 0:
            on_key_release(source, ev->code);
            break;
        case 1:
        case 2:
            on_key_press(source, ev->code);
            break;
        }
        // Re-read the focus each time: the key handlers above may have
        // switched workspace or window.
        struct workspace *workspace = &root->workspaces[root->focus];
        struct window *window = &workspace->windows[workspace->focus];
        if (window->pid == -1)
            continue;
        if (ev->code == KEY_LEFTMETA || ev->code == KEY_RIGHTMETA)
            continue;
        if (fetch_bit(keyboard->key_bits, KEY_LEFTMETA))
            continue;
        if (fetch_bit(keyboard->key_bits, KEY_RIGHTMETA))
            continue;
        send_keyboard_event(window, ev);
    }

    return 0;
}

/*
 *  properties_changed_handler handles org.freedesktop.DBus.Properties
 *  PropertiesChanged signals emitted on our logind session object.
 *
 *  The signal body is (s interface, a{sv} changed, as invalidated). Only the
 *  session's "Active" property matters here. When it appears in the changed
 *  dictionary, root->is_active is set to the reported value, the new state
 *  is logged, and draw() is called if the session is active. Signals that
 *  do not carry "Active" (changes to other session properties) are ignored
 *  rather than being mistaken for an activity change.
 */
static int
properties_changed_handler(sd_bus_message *m, void *data,
        sd_bus_error *ret_error)
{
    struct root *root = data;
    const char *interface;
    bool found = false;
    int active = 0;
    int r;

    r = sd_bus_message_read(m, "s", &interface);
    if (r < 0)
        die("Failed to read message from DBus.");
    if (strcmp(interface, "org.freedesktop.login1.Session") != 0)
        return 0;

    r = sd_bus_message_enter_container(m, SD_BUS_TYPE_ARRAY, "{sv}");
    if (r < 0)
        die("Failed to read message from DBus.");

    // Walk the changed-properties dictionary looking for "Active".
    for (;;) {
        const char *name;

        r = sd_bus_message_enter_container(m, SD_BUS_TYPE_DICT_ENTRY, "sv");
        if (r < 0)
            die("Failed to read message from DBus.");
        if (r == 0)
            break;

        r = sd_bus_message_read(m, "s", &name);
        if (r < 0)
            die("Failed to read message from DBus.");

        if (strcmp(name, "Active") == 0) {
            r = sd_bus_message_read(m, "v", "b", &active);
            found = true;
        } else {
            r = sd_bus_message_skip(m, "v");
        }
        if (r < 0)
            die("Failed to read message from DBus.");

        r = sd_bus_message_exit_container(m);
        if (r < 0)
            die("Failed to read message from DBus.");
    }

    if (!found)
        return 0;

    root->is_active = active;

    if (root->is_active)
        fprintf(stderr, "Session is active.\n");
    else
        fprintf(stderr, "Session is not active.\n");

    if (root->is_active)
        draw(root);

    return 0;
}

/*
 *  pause_device_handler handles logind's PauseDevice signal
 *  (u major, u minor, s type) for our session and logs it.
 *
 *  When type is "pause", logind expects the session controller to
 *  acknowledge with Session.PauseDeviceComplete(major, minor). For "force"
 *  and "gone" the device is already paused and no reply is expected. The
 *  acknowledgement is sent on the bus the signal arrived on, to the object
 *  path it was emitted from. A failed acknowledgement is logged but is not
 *  fatal.
 */
static int
pause_device_handler(sd_bus_message *message, void *data,
        sd_bus_error *ret_error)
{
    int r;

    uint32_t major, minor;
    const char *reason;

    r = sd_bus_message_read(message, "uus", &major, &minor, &reason);
    if (r < 0)
        die("Failed to read message from DBus.");

    fprintf(stderr, "Device paused: %d %d %s.\n", (int)major, (int)minor, reason);

    if (strcmp(reason, "pause") == 0) {
        sd_bus_error error = SD_BUS_ERROR_NULL;

        r = sd_bus_call_method(sd_bus_message_get_bus(message),
                "org.freedesktop.login1", sd_bus_message_get_path(message),
                "org.freedesktop.login1.Session", "PauseDeviceComplete",
                &error, NULL,
                "uu", major, minor);
        if (r < 0)
            fprintf(stderr, "Failed to acknowledge device pause: %s.\n",
                    error.message ? error.message : strerror(-r));
        sd_bus_error_free(&error);
    }

    return 0;
}

/*
 *  resume_device_handler handles logind's ResumeDevice signal
 *  (u major, u minor, h fd) for our session.
 *
 *  - Keyboard (major/minor match root->keyboard): a private close-on-exec
 *    duplicate of the new descriptor is passed to keyboard_connect and
 *    registered with the event loop via keyboard_event_handler.
 *  - Display (DRM 226:0), if a controller has already been chosen: the
 *    current front buffer is set back on the saved CRTC/connector/mode and
 *    draw() is called.
 *
 *  Any failure terminates the process.
 */
static int
resume_device_handler(sd_bus_message *message, void *data,
        sd_bus_error *ret_error)
{
    int r;

    struct root *root = data;
    struct keyboard *keyboard = &root->keyboard;

    uint32_t major, minor;
    int fd;

    r = sd_bus_message_read(message, "uuh", &major, &minor, &fd);
    if (r < 0)
        die("Failed to communicate with DBus.");

    fprintf(stderr, "Device resumed: %d %d.\n", (int)major, (int)minor);

    if (major == keyboard->major && minor == keyboard->minor) {
        // The descriptor read from the message stays owned by the message,
        // so keep our own copy.
        fd = dup(fd);
        if (fd < 0)
            die("Failed to connect to the keyboard.");
        enable_cloexec(fd);
        r = keyboard_connect(keyboard, fd, keyboard->major, keyboard->minor);
        if (r)
            die("Failed to connect to the keyboard.");

        // NOTE(review): the event source registered for the previous
        // keyboard descriptor is not removed here; confirm whether
        // keyboard_connect or the EPOLLHUP path in keyboard_event_handler
        // is expected to retire it.
        r = sd_event_add_io(root->event, NULL, fd, EPOLLIN,
                keyboard_event_handler, root);
        if (r < 0)
            die("Failed to add keyboard event handler.");
    } else if (major == 226 && minor == 0) {
        struct display *disp = &root->display;
        if (disp->state & DISPLAY_CONTROLLER_CHOSEN) {
            struct display_buffer *front_buffer = display_front_buffer(disp);
            r = drmModeSetCrtc(disp->fd, disp->crtc, front_buffer->fb, 0, 0,
                    &disp->conn, 1, &disp->mode);
            if (r)
                die("Failed to reconnect to display.");

            draw(root);
        }
    }

    return 0;
}

/*
 *  easy_sprintf is like sprintf but it allocates a buffer for you.
 *
 *  The result is allocated with xmalloc; callers in this file release it
 *  with free(). An empty result ("") is valid. If formatting fails, the
 *  process is terminated via die().
 */
static char *
easy_sprintf(const char *template, ...) __attribute__((format(printf, 1, 2)));

static char *
easy_sprintf(const char *template, ...)
{
    /*
     *  We use vsnprintf twice: once to find out how much space is needed and
     *  once to store the result.
     */

    va_list args;
    int ret;

    va_start(args, template);
    // With a size of zero nothing is written, so no buffer is needed.
    ret = vsnprintf(NULL, 0, template, args);
    va_end(args);
    if (ret < 0)
        goto fail;

    size_t size = (size_t)ret + 1;
    char *s = xmalloc(size);

    va_start(args, template);
    ret = vsnprintf(s, size, template, args);
    va_end(args);
    if (ret < 0 || (size_t)ret + 1 != size)
        goto fail;

    return s;

fail:
    die("Failed while formatting a string.");
}

int
main(void)
{
    int r;
    struct root root;
    sd_bus *bus = NULL;
    sd_bus_error error = SD_BUS_ERROR_NULL;
    sd_bus_message *message = NULL;
    sd_event *event = NULL;
    char *session_path = NULL;

    r = sd_bus_default_system(&bus);
    if (r < 0)
        die("Failed to connect to the system bus.");

    // Determine the logind object path for our current session.

    {
        char *session_short_name = NULL;
        r = sd_pid_get_session(getpid(), &session_short_name);
        if (r < 0)
            die("Failed to determine user session name.");
        session_path = easy_sprintf("/org/freedesktop/login1/session/%s",
                session_short_name);
        free(session_short_name);
    }

    // Take control of the session.

    r = sd_bus_call_method(bus, "org.freedesktop.login1", session_path,
            "org.freedesktop.login1.Session", "TakeControl", &error, &message,
            "b", (int32_t)0);
    if (r < 0)
        die("Failed to take control of session.");

    sd_bus_message_unref(message);

    // Take control of the display.

    struct display *disp = &root.display;
    display_init(disp);
    uint32_t disp_width;
    uint32_t disp_height;
    {
        uint32_t major = 226;
        uint32_t minor = 0;
        int32_t paused;
        int fd;

        r = sd_bus_call_method(bus, "org.freedesktop.login1", session_path,
                "org.freedesktop.login1.Session", "TakeDevice",
                &error, &message,
                "uu", major, minor);
        if (r < 0)
            die("Failed while communicating with logind.");

        r = sd_bus_message_read(message, "hb", &fd, &paused);
        if (r < 0)
            die("Failed while communicating with logind.");

        if (paused)
            die("Failed to take control of the display.");

        fd = dup(fd);
        sd_bus_message_unref(message);

        enable_cloexec(fd);
        disable_nonblock(fd);

        display_connect(disp, fd);

        struct display_buffer *back_buffer = display_back_buffer(disp);
        disp_width = back_buffer->width;
        disp_height = back_buffer->height;
    }

    // Take control of the keyboard.

    struct keyboard *keyboard = &root.keyboard;
    {
        int fd;
        uint32_t major;
        uint32_t minor;
        int32_t paused;

        choose_keyboard(&major, &minor);

        r = sd_bus_call_method(bus, "org.freedesktop.login1", session_path,
                "org.freedesktop.login1.Session", "TakeDevice",
                &error, &message,
                "uu", major, minor);
        if (r < 0)
            die("Failed while communicating with logind.");

        r = sd_bus_message_read(message, "hb", &fd, &paused);
        if (r < 0)
            die("Failed while communicating with logind.");

        if (paused)
            die("Failed to take control of the keyboard.");

        fd = dup(fd);
        enable_cloexec(fd);

        sd_bus_message_unref(message);

        r = keyboard_connect(keyboard, fd, major, minor);
        if (r)
            die("Failed to connect to the keyboard.");
    }

    // Create the event loop object.

    r = sd_event_default(&event);
    if (r < 0)
        die("Failed to initialize the event loop.");

    //  Initialize the root.
    //
    //  TODO This whole initialization process needs attention.

    root_init(&root, bus, event, disp_width, disp_height);

    // Subscribe to DBus signals.

    {
        char *pattern = NULL;

        pattern = easy_sprintf("type='signal',path='%s',"
                "interface='org.freedesktop.login1.Session',"
                "member='PauseDevice'",
                session_path);
        r = sd_bus_add_match(bus, NULL, pattern, pause_device_handler, &root);
        if (r < 0)
            die("Failed to subscribe to PauseDevice signal.");
        free(pattern);

        pattern = easy_sprintf("type='signal',path='%s',"
                "interface='org.freedesktop.login1.Session',"
                "member='ResumeDevice'",
                session_path);
        r = sd_bus_add_match(bus, NULL, pattern, resume_device_handler, &root);
        if (r < 0)
            die("Failed to subscribe to ResumeDevice signal.");
        free(pattern);

        pattern = easy_sprintf("type='signal',path='%s',"
                "interface='org.freedesktop.DBus.Properties',"
                "member='PropertiesChanged'",
                session_path);
        r = sd_bus_add_match(bus, NULL, pattern, properties_changed_handler,
                &root);
        if (r < 0)
            die("Failed to subscribe to PropertiesChanged signal.");
        free(pattern);
    }

    // Add signals to the event loop.

    {
        sigset_t sigset;
        const char error_message[] = "Failed to initialize signal handling.";

        signal(SIGPIPE, SIG_IGN);

        r = sigemptyset(&sigset);
        if (r == -1)
            die(error_message);

        r = sigaddset(&sigset, SIGTERM);
        if (r == -1)
            die(error_message);

        r = sigaddset(&sigset, SIGINT);
        if (r == -1)
            die(error_message);

        r = sigaddset(&sigset, SIGCHLD);
        if (r == -1)
            die(error_message);

        r = sigprocmask(SIG_BLOCK, &sigset, NULL);
        if (r == -1)
            die(error_message);

        r = sd_event_add_signal(event, NULL, SIGTERM, NULL, NULL);
        if (r < 0)
            die(error_message);

        r = sd_event_add_signal(event, NULL, SIGINT, NULL, NULL);
        if (r < 0)
            die(error_message);

        r = sd_event_add_signal(event, NULL, SIGCHLD, sigchld_handler, &root);
        if (r < 0)
            die(error_message);
    }

    // Add the display to the event loop.

    r = sd_event_add_io(event, NULL, disp->fd, EPOLLIN,
            display_event_handler, &root);
    if (r < 0)
        die("Failed to add display event handler.");

    // Add the keyboard to the event loop.

    r = sd_event_add_io(event, NULL, keyboard->fd, EPOLLIN,
            keyboard_event_handler, &root);
    if (r < 0)
        die("Failed to add keyboard event handler.");

    // Add the bus to the event loop.

    {
        int fd = sd_bus_get_fd(bus);
        if (fd < 0)
            die("Failed to initialize DBus monitoring.");

        r = sd_event_add_io(event, NULL, fd, EPOLLIN, bus_event_handler, &root);
        if (r < 0)
            die("Failed to initialize DBus monitoring.");
    }

    // Draw the first frame and enter the event loop.

    draw(&root);

    r = sd_event_loop(event);
    if (r < 0)
        die("Failed while processing events.");

    // Disconnect from display.
    
    display_disconnect(disp);

    // Release control of the session.

    r = sd_bus_call_method(bus, "org.freedesktop.login1", session_path,
            "org.freedesktop.login1.Session", "ReleaseControl",
            &error, &message,
            "");
    if (r < 0)
        die("Failed to release control of session.");

    sd_bus_message_unref(message);

    // Cleanup and exit.

    sd_event_unref(event);
    free(session_path);
    sd_bus_unref(bus);

    return 0;
}