#include <errno.h>
#include <fcntl.h>
#include <libudev.h>
#include <linux/input.h>
#include <linux/memfd.h>
#include <linux/vt.h>
#include <signal.h>
#include <stdarg.h>
#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/ioctl.h>
#include <sys/mman.h>
#include <sys/socket.h>
#include <sys/stat.h>
#include <sys/syscall.h>
#include <sys/types.h>
#include <sys/wait.h>
#include <systemd/sd-bus.h>
#include <systemd/sd-event.h>
#include <systemd/sd-login.h>
#include <unistd.h>
#include <xf86drm.h>
#include <xf86drmMode.h>

#include <shell.h>
#include <keymap.h>

#define STB_IMAGE_WRITE_IMPLEMENTATION
#include <stb_image_write.h>

// Number of elements in a true array (does not work on pointers).
#define array_size(a) (sizeof(a) / sizeof((a)[0]))

// Bytes per pixel in every pixel buffer handled in this file: display buffers
// are created with bpp = 32 and pixels are written as uint32_t values.
#define BYTES_PER_PIXEL 4

// Bit flags kept in display.state. They record how far display_connect got
// (so display_disconnect undoes only what was done) and whether the device
// is currently paused by logind.
enum {
    DISPLAY_CONTROLLER_CHOSEN = 1,
    DISPLAY_HAS_FRAME_BUFFERS = 2,
    DISPLAY_SAVED = 4,
    DISPLAY_PAUSED = 8,
};

// Fixed layout: NUM_WORKSPACES workspaces, each with a fixed number of
// window slots.
enum {
    NUM_WORKSPACES = 4,
    NUM_WINDOWS_PER_WORKSPACE = 3,
};

// One full-screen DRM "dumb" buffer, as set up by display_connect.
struct display_buffer {
    // CPU mapping of the buffer's pixels (mmap of the dumb buffer).
    void *mem;
    // Size in bytes of the buffer, as reported by DRM_IOCTL_MODE_CREATE_DUMB.
    uint64_t size;
    // Dimensions in pixels (the chosen mode's hdisplay x vdisplay).
    uint32_t width;
    uint32_t height;
    // Bytes per row (the "pitch" reported by the kernel).
    uint32_t stride;
    // GEM handle of the dumb buffer.
    uint32_t handle;
    // Framebuffer id from drmModeAddFB, used for mode setting and page flips.
    uint32_t fb;
};

// State of the DRM display the shell draws on (double buffered).
struct display {
    // DRM device file descriptor.
    int fd;
    // Two full-screen buffers; one is scanned out, the other is drawn into.
    struct display_buffer buffers[2];
    // Index into buffers[] of the buffer that the next draw() renders into.
    unsigned int back_buffer_index:1;
    // Set after a page flip is requested, cleared by the page-flip event.
    unsigned int back_buffer_is_locked:1;
    // Set when a draw had to be deferred; it is retried later.
    unsigned int is_draw_needed:1;

    // DISPLAY_* flags.
    uint32_t state;

    // Mode, connector and CRTC chosen by display_connect.
    drmModeModeInfo mode;
    uint32_t conn;
    uint32_t crtc;
    // CRTC configuration found at connect time, restored on disconnect.
    drmModeCrtc *saved_crtc;

    // sd-event IO source watching fd for DRM events (page-flip completions).
    sd_event_source *event_source;
};

// One window slot. pid == -1 marks a slot without a client: it is not drawn
// and no messages are sent to it.
struct window {
    // Top-left corner of the content area on the display, in pixels.
    uint32_t x;
    uint32_t y;
    // Size of the content area in pixels (sent to the client in the welcome
    // message).
    uint32_t width;
    uint32_t height;
    // Which of the two buffers in `mem` is currently shown (0 or 1).
    unsigned int front_buffer:1;
    // Two width * height * BYTES_PER_PIXEL pixel buffers stored back to back.
    // The mapping is removed with munmap when the child exits.
    // NOTE(review): presumably shared with the client process; the mapping
    // is set up outside this view, so confirm.
    void *mem;
    // Socket used to exchange shell messages with the client.
    int socket_fd;
    // Client process id, or -1 if the slot is empty.
    pid_t pid;
    // sd-event IO source for socket_fd (NULL once dropped).
    sd_event_source *event_source;
};

// A workspace: a fixed set of window slots plus the slot that has keyboard
// focus.
struct workspace {
    struct window windows[NUM_WINDOWS_PER_WORKSPACE];
    // Index into windows[] of the focused window.
    unsigned focus;
};

// The evdev keyboard the shell reads from.
struct keyboard {
    // evdev device file descriptor.
    int fd;
    // Device numbers; they are matched against logind ResumeDevice signals.
    uint32_t major;
    uint32_t minor;
    // Bitmap of currently pressed keys (bit n = key code n).
    uint8_t key_bits[SHELL_KEY_BITS_SIZE];
    // sd-event IO source watching fd.
    sd_event_source *event_source;
};

// Top-level state of the shell.
struct root {
    // NOTE(review): presumably the logind session name; not used in this view.
    char *session_name;
    // D-Bus object path of the logind session; used to query "Active".
    char *session_path;
    // Whether the session is active (foreground); drawing waits while false.
    bool is_active;
    struct workspace workspaces[NUM_WORKSPACES];
    // Index into workspaces[] of the workspace being shown.
    unsigned focus;
    sd_bus *bus;
    sd_event *event;
    struct keyboard keyboard;
    struct display display;
};

// Address of a zero-initialized struct root compound literal.
// NOTE(review): the expansion ends with ';', so ROOT_INIT only works where a
// statement/declaration terminator may follow (e.g. `p = ROOT_INIT`), not
// inside a larger expression. Confirm call sites before removing it.
#define ROOT_INIT &(struct root){};

// Switches for the "STATE: " and "EVENT: " diagnostics that report_state()
// and report_event() print to stderr.
static bool want_state_reporting = true;
static bool want_event_reporting = true;

// report_state prints a printf-style "STATE: ..." line to stderr and flushes
// it, unless want_state_reporting is false.
//
// The format attribute lets the compiler check every call's arguments against
// its format string, as is already done for easy_sprintf.
static void
report_state(const char *fmt, ...) __attribute__((format(printf, 1, 2)));
static void
report_state(const char *fmt, ...)
{
    if (!want_state_reporting)
        return;
    va_list args;
    va_start(args, fmt);
    fprintf(stderr, "STATE: ");
    vfprintf(stderr, fmt, args);
    fputc('\n', stderr);
    fflush(stderr);
    va_end(args);
}

// report_event prints a printf-style "EVENT: ..." line to stderr and flushes
// it, unless want_event_reporting is false.
//
// The format attribute lets the compiler check every call's arguments against
// its format string, as is already done for easy_sprintf.
static void
report_event(const char *fmt, ...) __attribute__((format(printf, 1, 2)));
static void
report_event(const char *fmt, ...)
{
    if (!want_event_reporting)
        return;
    va_list args;
    va_start(args, fmt);
    fprintf(stderr, "EVENT: ");
    vfprintf(stderr, fmt, args);
    fputc('\n', stderr);
    fflush(stderr);
    va_end(args);
}

// die writes "ERROR: " followed by the printf-formatted message and a newline
// to stderr, flushes stderr, and terminates the process with status 1.
_Noreturn static void
die(const char *fmt, ...)
{
    fputs("ERROR: ", stderr);

    va_list ap;
    va_start(ap, fmt);
    vfprintf(stderr, fmt, ap);
    va_end(ap);

    putc('\n', stderr);
    fflush(stderr);
    exit(1);
}

// must_malloc returns a new heap block of `size` bytes. A zero-sized request
// and an allocation failure are both fatal, with different messages.
static void *
must_malloc(size_t size)
{
    if (size != 0) {
        void *block = malloc(size);
        if (block != NULL)
            return block;
        die("Memory allocation failed.");
    }
    die("Memory allocation error.");
}

//  Forward declarations: these functions are defined further down but are
//  referenced (as calls or as callbacks) by code that appears before them.

//  sd-event IO callback for the display fd; registered in display_connect.
static int
handle_device_event_for_display(sd_event_source *event_source, int fd,
        uint32_t revents, void *data);

//  sd-event IO callback for the keyboard fd; registered in keyboard_connect.
static int
handle_device_event_for_keyboard(sd_event_source *event_source, int fd, uint32_t revents,
        void *data);

//  Used by the Meta+F1..F4 shortcuts in handle_key_press.
static void
root_switch_to_vt(int n);

//  Used by the Meta+g/l/t shortcuts, e.g. ("./germ", "germ").
static void
root_launch_client(struct root *root, const char *path, const char *arg0);

//  Finds the window whose client socket is `fd`; handle_client_message
//  treats a NULL result as an unknown source.
static struct window *
root_lookup_window_by_fd(struct root *root, int fd);

//  Used by the Meta+1..4 shortcuts in handle_key_press.
static void
root_set_workspace_focus(struct root *root, int focus);

// display_init puts a display into its "nothing set up yet" state: no state
// flags, buffer 1 as the back buffer, no pending page flip and no deferred
// draw. The device and the buffers are not touched.
static void
display_init(struct display *disp)
{
    disp->is_draw_needed = 0;
    disp->back_buffer_is_locked = 0;
    disp->back_buffer_index = 1;
    disp->state = 0;
}

// Returns the buffer that is not the back buffer, i.e. the one most recently
// given to the display controller.
static struct display_buffer *
display_front_buffer(struct display *disp)
{
    unsigned int front = disp->back_buffer_index ? 0 : 1;
    return disp->buffers + front;
}

// Returns the buffer that the next draw() renders into.
static struct display_buffer *
display_back_buffer(struct display *disp)
{
    unsigned int back = disp->back_buffer_index;
    return disp->buffers + back;
}

// display_connect takes over the DRM device open on `fd` and prepares it for
// double-buffered drawing:
//
//   1. require dumb-buffer support;
//   2. pick the first connected connector that has modes, a mode for it (the
//      highest-indexed mode flagged DRM_MODE_TYPE_PREFERRED, else modes[0]),
//      and the first CRTC usable by one of the connector's encoders;
//   3. create, register (drmModeAddFB) and mmap two full-screen 32 bpp dumb
//      buffers;
//   4. save the CRTC's current configuration and show the front buffer;
//   5. watch `fd` with sd-event so page-flip completions get handled.
//
// Each completed stage is recorded in disp->state so display_disconnect()
// undoes exactly what was set up. Every failure is fatal.
static void
display_connect(struct root *root, struct display *disp, int fd)
{
    disp->fd = fd;

    //  We need "dumb buffer" support so fail if the device doesn't have that.

    {
        uint64_t has_dumb;
        int r = drmGetCap(fd, DRM_CAP_DUMB_BUFFER, &has_dumb);
        if (r < 0 || !has_dumb)
            die("Failed to connect to the display.");
    }

    //  We choose an arbitrary connector that's currently active, a mode to use
    //  on it (determines resolution), and a controller (CRTC) that's
    //  compatible with one of the connector's encoders.

    {
        drmModeRes *res = drmModeGetResources(disp->fd);
        if (!res)
            die("Failed to connect to display.");
        bool did_choose_crtc = false;
        for (int i = 0; i < res->count_connectors; i++) {
            drmModeConnector *conn =
                drmModeGetConnector(disp->fd, res->connectors[i]);
            if (!conn)
                continue;
            if (conn->connection == DRM_MODE_CONNECTED && conn->count_modes > 0) {
                drmModeModeInfo *mode;
                {
                    int j = conn->count_modes - 1;
                    while (j > 0 && !(conn->modes[j].type & DRM_MODE_TYPE_PREFERRED))
                        j--;
                    mode = &conn->modes[j];
                }
                for (int j = 0; j < conn->count_encoders; j++) {
                    drmModeEncoder *enc =
                        drmModeGetEncoder(disp->fd, conn->encoders[j]);
                    if (!enc)
                        continue;
                    //  possible_crtcs is a 32-bit mask indexed by position
                    //  in res->crtcs. Shift an unsigned 1 and stop at bit 31:
                    //  `1 << 31` on a signed int is undefined behavior.
                    int max_k = res->count_crtcs < 32 ? res->count_crtcs : 32;
                    for (int k = 0; k < max_k; k++) {
                        if (enc->possible_crtcs & (UINT32_C(1) << k)) {
                            disp->crtc = res->crtcs[k];
                            disp->conn = conn->connector_id;
                            disp->mode = *mode;
                            did_choose_crtc = true;
                            break;
                        }
                    }
                    drmModeFreeEncoder(enc);
                    if (did_choose_crtc)
                        break;
                }
            }
            drmModeFreeConnector(conn);
            if (did_choose_crtc)
                break;
        }
        drmModeFreeResources(res);
        if (!did_choose_crtc)
            die("Failed to connect to display.");
    }

    disp->state |= DISPLAY_CONTROLLER_CHOSEN;

    //  Allocate and mmap the two full-screen pixel buffers.

    for (int i = 0; i < 2; i++) {
        struct display_buffer *buffer = &disp->buffers[i];
        buffer->width = disp->mode.hdisplay;
        buffer->height = disp->mode.vdisplay;
        struct drm_mode_create_dumb ioctl_create = {
            .width = buffer->width,
            .height = buffer->height,
            .bpp = 32,
        };
        int r = drmIoctl(fd, DRM_IOCTL_MODE_CREATE_DUMB, &ioctl_create);
        if (r < 0)
            die("Failed to connect to display.");
        buffer->stride = ioctl_create.pitch;
        buffer->size = ioctl_create.size;
        buffer->handle = ioctl_create.handle;
        r = drmModeAddFB(fd, buffer->width, buffer->height, 24, 32,
            buffer->stride, buffer->handle, &buffer->fb);
        if (r)
            die("Failed to connect to display.");
        struct drm_mode_map_dumb ioctl_map = {
            .handle = buffer->handle,
        };
        r = drmIoctl(fd, DRM_IOCTL_MODE_MAP_DUMB, &ioctl_map);
        if (r)
            die("Failed to connect to display.");
        buffer->mem = mmap(0, buffer->size, PROT_READ | PROT_WRITE,
            MAP_SHARED, fd, ioctl_map.offset);
        if (buffer->mem == MAP_FAILED)
            die("Failed to connect to display.");
    }

    disp->state |= DISPLAY_HAS_FRAME_BUFFERS;

    //  Save the current state for the CRTC we are about to take over. We will
    //  restore this state when we disconnect from the display.
    //
    //  Take over the connector and CRTC that we selected above.

    {
        struct display_buffer *front_buffer = display_front_buffer(disp);

        //  NULL if the current configuration cannot be read. We still take
        //  over the display in that case, but there is nothing to restore.
        disp->saved_crtc = drmModeGetCrtc(disp->fd, disp->crtc);

        int r = drmModeSetCrtc(disp->fd, disp->crtc, front_buffer->fb, 0, 0,
            &disp->conn, 1, &disp->mode);
        if (r)
            die("Failed to connect to display.");
    }

    //  display_disconnect dereferences saved_crtc when DISPLAY_SAVED is set,
    //  so only claim a saved state when there actually is one.
    if (disp->saved_crtc != NULL)
        disp->state |= DISPLAY_SAVED;

    {
        disp->event_source = NULL;
        int r = sd_event_add_io(root->event, &disp->event_source, disp->fd, EPOLLIN,
                handle_device_event_for_display, root);
        if (r < 0)
            die("Failed to add display event handler.");
    }
}

// display_disconnect undoes as much of display_connect as disp->state says
// was done. It restores the saved CRTC configuration (if any), then unmaps,
// unregisters and destroys both frame buffers. It does not close disp->fd or
// remove disp->event_source.
static void
display_disconnect(struct display *disp)
{
    if (disp->state & DISPLAY_SAVED) {
        //  Guard against a missing saved configuration rather than
        //  dereferencing NULL.
        if (disp->saved_crtc != NULL) {
            drmModeSetCrtc(
                disp->fd,
                disp->saved_crtc->crtc_id,
                disp->saved_crtc->buffer_id,
                disp->saved_crtc->x,
                disp->saved_crtc->y,
                &disp->conn,
                1,
                &disp->saved_crtc->mode);
            drmModeFreeCrtc(disp->saved_crtc);
            disp->saved_crtc = NULL;
        }
        disp->state &= ~DISPLAY_SAVED;
    }
    if (disp->state & DISPLAY_HAS_FRAME_BUFFERS) {
        for (int i = 0; i < 2; i++) {
            struct display_buffer *buffer = &disp->buffers[i];
            munmap(buffer->mem, buffer->size);
            buffer->mem = NULL;
            drmModeRmFB(disp->fd, buffer->fb);
            drmIoctl(disp->fd, DRM_IOCTL_MODE_DESTROY_DUMB,
                &(struct drm_mode_destroy_dumb){
                    .handle = buffer->handle,
                });
        }
        disp->state &= ~DISPLAY_HAS_FRAME_BUFFERS;
    }
}

static void
display_reconnect(struct display *disp, int fd)
{
    disp->fd = fd;

    {
        int r = sd_event_source_set_io_fd(disp->event_source, fd);
        if (r < 0)
            die("Failed to reconnect the display.");
        r = sd_event_source_set_enabled(disp->event_source, SD_EVENT_ON);
        if (r < 0)
            die("Failed to reconnect the display.");
    }

    if (disp->state & DISPLAY_CONTROLLER_CHOSEN) {
        struct display_buffer *front_buffer = display_front_buffer(disp);
        int r = drmModeSetCrtc(disp->fd, disp->crtc, front_buffer->fb, 0, 0,
                &disp->conn, 1, &disp->mode);
        if (r)
            die("Failed to reconnect to display.");
    }

    report_event("Connected to display: fd=%d.", fd);
}

// Sets bit `i` of the bitmap `bits`; bit i is bit (i % 8) of byte i / 8.
static void
set_bit(uint8_t *bits, unsigned i)
{
    uint8_t mask = (uint8_t)(1u << (i & 7u));
    bits[i >> 3] |= mask;
}

// Clears bit `i` of the bitmap `bits`; bit i is bit (i % 8) of byte i / 8.
static void
clear_bit(uint8_t *bits, unsigned i)
{
    uint8_t mask = (uint8_t)(1u << (i & 7u));
    bits[i >> 3] &= (uint8_t)~mask;
}

// Returns 1 if bit `i` of the bitmap `bits` is set, else 0; bit i is
// bit (i % 8) of byte i / 8.
static int
fetch_bit(const uint8_t *bits, unsigned i)
{
    return (bits[i >> 3] >> (i & 7u)) & 1;
}

// Refreshes keyboard->key_bits with the kernel's current key-state bitmap
// for keyboard->fd (EVIOCGKEY). Failure is fatal.
static void
keyboard_load_key_bits(struct keyboard *keyboard)
{
    if (ioctl(keyboard->fd, EVIOCGKEY(SHELL_KEY_BITS_SIZE), keyboard->key_bits) == -1)
        die("Failed to read keyboard state.");
}

// keyboard_connect starts reading the evdev keyboard open on `fd`. It records
// the descriptor and device numbers, snapshots which keys are already held,
// and registers an sd-event IO source for it. Failures are fatal.
static void
keyboard_connect(struct root *root, struct keyboard *keyboard, int fd,
        unsigned int major, unsigned int minor)
{
    keyboard->major = major;
    keyboard->minor = minor;
    keyboard->fd = fd;
    keyboard_load_key_bits(keyboard);

    keyboard->event_source = NULL;
    if (sd_event_add_io(root->event, &keyboard->event_source, keyboard->fd,
                EPOLLIN, handle_device_event_for_keyboard, root) < 0)
        die("Failed to add keyboard event handler.");

    report_event("Connected to keyboard: fd=%d.", fd);
}

static void
keyboard_reconnect(struct keyboard *keyboard, int fd)
{
    keyboard->fd = fd;

    keyboard_load_key_bits(keyboard);

    {
        int r = sd_event_source_set_io_fd(keyboard->event_source, fd);
        if (r < 0)
            die("Failed to reconnect the keyboard.");
        r = sd_event_source_set_enabled(keyboard->event_source, SD_EVENT_ON);
        if (r < 0)
            die("Failed to reconnect the keyboard.");
    }

    report_event("Connected to keyboard: fd=%d.", fd);
}

// blit copies a source_width x source_height block of pixels into `dest`,
// placing its top-left corner at (x, y). `source` holds the block's rows
// contiguously; `dest` rows are dest_width pixels long. No clipping is done:
// the block must fit inside `dest`.
static void
blit(void *dest, uint32_t dest_width, uint32_t x, uint32_t y,
        void *source, uint32_t source_width, uint32_t source_height)
{
    const uint32_t row_bytes = BYTES_PER_PIXEL * source_width;
    uint32_t row = 0;
    while (row < source_height) {
        uint32_t dest_offset = BYTES_PER_PIXEL * (dest_width * (y + row) + x);
        uint32_t source_offset = row_bytes * row;
        memmove(dest + dest_offset, source + source_offset, row_bytes);
        row++;
    }
}

// Packs an 8-bit-per-channel color into one 32-bit pixel whose bytes, in
// memory order, are blue, green, red, 0 (XRGB8888). The result can be stored
// straight into a pixel buffer.
static uint32_t
encode_color_xrgb(uint8_t r, uint8_t g, uint8_t b)
{
    const uint8_t bytes[4] = { b, g, r, 0 };
    uint32_t pixel;
    memcpy(&pixel, bytes, sizeof(pixel));
    return pixel;
}

// Paints width * height consecutive 32-bit pixels, starting at `buffer`,
// white. The writes are contiguous, so this covers the whole image only if
// its rows are packed with no padding.
static void
clear_buffer(void *buffer, uint32_t width, uint32_t height)
{
    const uint32_t white = encode_color_xrgb(0xff, 0xff, 0xff);
    uint32_t *pixel = buffer;
    for (unsigned int row = 0; row < height; row++) {
        for (unsigned int col = 0; col < width; col++)
            *pixel++ = white;
    }
}

// Draws a one-pixel rectangular border around `window` into `mem`, a 32-bit
// pixel array whose rows are `width` pixels long. The border runs 5 pixels
// outside the window's content area, leaving a 4-pixel gap. It is dark gray
// when `has_focus` is true and light gray otherwise.
//
// Border pixels left of column 0, right of column width - 1, or above row 0
// are skipped, so windows near the top/left/right screen edges cannot write
// outside the buffer. The bottom edge is NOT checked because the buffer
// height is not passed in; callers must keep y + height + 4 on screen.
static void
draw_frame(void *mem, uint32_t width, struct window *window, bool has_focus)
{
    //  Inclusive border corners. Signed 64-bit arithmetic keeps coordinates
    //  that fall off the top/left edge negative, where uint32_t arithmetic
    //  would wrap around to huge values.
    int64_t x0 = (int64_t)window->x - 5;
    int64_t y0 = (int64_t)window->y - 5;
    int64_t x1 = (int64_t)window->x + window->width + 4;
    int64_t y1 = (int64_t)window->y + window->height + 4;
    int64_t w = width;

    uint32_t color = has_focus
        ? encode_color_xrgb(0x48, 0x48, 0x48)
        : encode_color_xrgb(0xcc, 0xcc, 0xcc);

    uint32_t *pixels = mem;

    //  Top and bottom edges.
    for (int64_t x = x0; x <= x1; x++) {
        if (x < 0 || x >= w)
            continue;
        if (y0 >= 0)
            pixels[y0 * w + x] = color;
        if (y1 >= 0)
            pixels[y1 * w + x] = color;
    }

    //  Left and right edges.
    for (int64_t y = y0; y <= y1; y++) {
        if (y < 0)
            continue;
        if (x0 >= 0 && x0 < w)
            pixels[y * w + x0] = color;
        if (x1 >= 0 && x1 < w)
            pixels[y * w + x1] = color;
    }
}

// Copies the window's shown client buffer into the display back buffer at
// the window's position. window->mem holds two width * height buffers back to
// back, and front_buffer selects which one is shown. Empty slots
// (pid == -1) are skipped.
static void
draw_window(void *back_buffer, uint32_t disp_width, struct window *window)
{
    if (window->pid != -1) {
        uint32_t frame_bytes = BYTES_PER_PIXEL * window->width * window->height;
        void *shown = window->mem + (window->front_buffer ? frame_bytes : 0);
        blit(back_buffer, disp_width, window->x, window->y,
                shown, window->width, window->height);
    }
}

// draw renders the focused workspace into the display's back buffer and
// schedules a page flip to show it.
//
// Drawing is deferred, by setting is_draw_needed, while a previous flip is
// still pending, while the session is inactive, or while the display is
// paused. The page-flip handler and the resume paths call draw() again.
static void
draw(struct root *root)
{
    struct display *disp = &root->display;

    bool must_wait =
        disp->back_buffer_is_locked
        || !root->is_active
        || (disp->state & DISPLAY_PAUSED);

    if (must_wait) {
        disp->is_draw_needed = 1;
        return;
    }

    struct display_buffer *back_buffer = display_back_buffer(disp);
    uint32_t disp_width = back_buffer->width;
    uint32_t disp_height = back_buffer->height;

    clear_buffer(back_buffer->mem, disp_width, disp_height);

    struct workspace *workspace = &root->workspaces[root->focus];
    for (int i = 0; i < NUM_WINDOWS_PER_WORKSPACE; i++) {
        struct window *window = &workspace->windows[i];
        draw_frame(back_buffer->mem, disp_width, window, i == workspace->focus);
        draw_window(back_buffer->mem, disp_width, window);
    }

    int r = drmModePageFlip(disp->fd, disp->crtc, back_buffer->fb,
        DRM_MODE_PAGE_FLIP_EVENT, root);
    if (r) {
        //  The flip did not happen, so the front buffer is still on screen.
        //  Keep back_buffer_index as it is; swapping it here would make the
        //  next draw render into the buffer that is being scanned out.
        report_event("Failed to flip back buffer onto display.");
        return;
    }

    //  The buffer just submitted becomes the front buffer.
    disp->back_buffer_index = !disp->back_buffer_index;
    disp->back_buffer_is_locked = 1;
    disp->is_draw_needed = 0;
}

// easy_sprintf is like sprintf but it allocates a buffer for you. It returns
// a NUL-terminated string that the caller owns and must free(). An empty
// result ("") is valid. Formatting or allocation failure is fatal.
static char *
easy_sprintf(const char *template, ...) __attribute__((format(printf, 1, 2)));
static char *
easy_sprintf(const char *template, ...)
{
    // We use vsnprintf twice: once to find out how much space is needed and
    // once to store the result.

    va_list args;

    va_start(args, template);
    int ret = vsnprintf(NULL, 0, template, args);
    va_end(args);
    // Only a negative return is an error; 0 means the result is empty.
    if (ret < 0)
        goto fail;

    size_t size = (size_t)ret + 1;
    char *s = must_malloc(size);

    va_start(args, template);
    ret = vsnprintf(s, size, template, args);
    va_end(args);
    if (ret < 0 || (size_t)ret + 1 != size) {
        free(s);
        goto fail;
    }

    return s;

fail:
    die("Failed while formatting a string.");
}

// Sets FD_CLOEXEC on `fd`, keeping its other descriptor flags, so the
// descriptor is not inherited across exec(). Failure is fatal.
static void
enable_cloexec(int fd)
{
    const char *error_message = "Failed to enable CLOEXEC on file.";
    int flags = fcntl(fd, F_GETFD);
    if (flags == -1)
        die("%s", error_message);
    if (fcntl(fd, F_SETFD, flags | FD_CLOEXEC) == -1)
        die("%s", error_message);
}

// Clears O_NONBLOCK on `fd`, keeping its other file status flags, so reads
// and writes block. Failure is fatal.
static void
disable_nonblock(int fd)
{
    const char *error_message = "Failed to disable NONBLOCK on file.";
    int flags = fcntl(fd, F_GETFL);
    if (flags == -1)
        die("%s", error_message);
    if (fcntl(fd, F_SETFL, flags & ~O_NONBLOCK) == -1)
        die("%s", error_message);
}

static void
send_welcome(struct window *window)
{
    struct shell_message_welcome message = {
        .class = SHELL_MESSAGE_WELCOME,
        .width = window->width,
        .height = window->height,
    };

    ssize_t r = write(window->socket_fd, &message, sizeof(message));
    if (r != sizeof(message))
        report_event("Failed to send welcome message to child process.");
}

static void
send_keyboard_event(struct window *window, struct input_event *event)
{
    if (window->pid == -1)
        return;

    struct shell_message_keyboard_event message = {
        .class = SHELL_MESSAGE_KEYBOARD_EVENT,
        .code = event->code,
        .value = event->value,
    };

    ssize_t r = write(window->socket_fd, &message, sizeof(message));
    if (r != sizeof(message))
        report_event("Failed to send keyboard event message to child process.");
}

// Sends SHELL_MESSAGE_RESET_COMPLETED to the window's client. A failed or
// short write is only reported.
static void
send_reset_completed(struct window *window)
{
    struct shell_message_reset_completed message = {
        .class = SHELL_MESSAGE_RESET_COMPLETED,
    };

    ssize_t r = write(window->socket_fd, &message, sizeof(message));
    if (r != sizeof(message))
        report_event("Failed to send reset completed message to child process.");
}

static void
send_buffer_unlocked(struct window *window)
{
    if (window->pid == -1)
        return;

    struct shell_message_buffer_unlocked message = {
        .class = SHELL_MESSAGE_BUFFER_UNLOCKED,
    };

    ssize_t r = write(window->socket_fd, &message, sizeof(message));
    if (r != sizeof(message))
        report_event("Failed to send buffer unlocked message to child process.");
}

static void
send_keyboard_detached(struct window *window)
{
    if (window->pid == -1)
        return;

    struct shell_message_keyboard_detached message = {
        .class = SHELL_MESSAGE_KEYBOARD_DETACHED,
    };

    ssize_t r = write(window->socket_fd, &message, sizeof(message));
    if (r != sizeof(message))
        report_event("Failed to send keyboard detached message to child process.");
}

// Tells the window's client that it gained the keyboard, including the
// SHELL_KEY_BITS_SIZE-byte bitmap of keys currently held. Empty slots are
// skipped; a failed or short write is only reported.
static void
send_keyboard_attached(struct window *window, uint8_t *key_bits)
{
    if (window->pid != -1) {
        struct shell_message_keyboard_attached message = {
            .class = SHELL_MESSAGE_KEYBOARD_ATTACHED,
        };
        memcpy(message.key_bits, key_bits, SHELL_KEY_BITS_SIZE);

        ssize_t written = write(window->socket_fd, &message, sizeof(message));
        if (written != sizeof(message))
            report_event("Failed to send keyboard attached message to child process.");
    }
}

// Records that key `code` is no longer held. Codes outside 0..KEY_MAX are
// ignored.
static void
handle_key_release(struct root *root, int code)
{
    if (code < 0 || code > KEY_MAX)
        return;
    clear_bit(root->keyboard.key_bits, (unsigned)code);
}

static void
handle_key_press(struct root *root, int code)
{
    int r;

    struct keyboard *keyboard = &root->keyboard;
    if (code >= 0 && code <= KEY_MAX)
        set_bit(keyboard->key_bits, (unsigned)code);

    if (!fetch_bit(keyboard->key_bits, KEY_LEFTMETA) &&
            !fetch_bit(keyboard->key_bits, KEY_RIGHTMETA))
        return;

    switch (code) {
    case KEY_ESC:
        r = sd_event_exit(root->event, 0);
        if (r < 0)
            die("Failed to handle keyboard event.");
        return;
    case KEY_TAB:
        {
            struct workspace *workspace = &root->workspaces[root->focus];
            struct window *window = &workspace->windows[workspace->focus];
            uint8_t key_bits[SHELL_KEY_BITS_SIZE];
            memmove(key_bits, keyboard->key_bits, SHELL_KEY_BITS_SIZE);
            clear_bit(key_bits, KEY_LEFTMETA);
            clear_bit(key_bits, KEY_RIGHTMETA);
            clear_bit(key_bits, KEY_TAB);
            send_keyboard_detached(window);
            workspace->focus =
                (workspace->focus + 1) % NUM_WINDOWS_PER_WORKSPACE;
            window = &workspace->windows[workspace->focus];
            send_keyboard_attached(window, key_bits);
            draw(root);
        }
        return;
    case KEY_1:
        root_set_workspace_focus(root, 0);
        return;
    case KEY_2:
        root_set_workspace_focus(root, 1);
        return;
    case KEY_3:
        root_set_workspace_focus(root, 2);
        return;
    case KEY_4:
        root_set_workspace_focus(root, 3);
        return;
    case KEY_F1:
        root_switch_to_vt(1);
        return;
    case KEY_F2:
        root_switch_to_vt(2);
        return;
    case KEY_F3:
        root_switch_to_vt(3);
        return;
    case KEY_F4:
        root_switch_to_vt(4);
        return;
    }
    switch (keymap_lookup(code)) {
    case 'g':
        root_launch_client(root, "./germ", "germ");
        return;
    case 'l':
        root_launch_client(root, "./logo", "logo");
        return;
    case 't':
        root_launch_client(root, "./term", "term");
        return;
    case 's':
        {
            struct display_buffer *buf = display_front_buffer(&root->display);
            int w = buf->width;
            int h = buf->height;
            int stride = w;
            uint8_t *color = buf->mem;
            uint8_t *gray = malloc(w * h);
            if (gray == NULL)
                return;
            for (int i = 0; i < w * h; i++)
                gray[i] = color[4 * i];
            stbi_write_png("screenshot.png", w, h, 1, gray, stride);
            free(gray);
        }
        return;
    }
}

// SIGCHLD handler. It reaps, without blocking, every child that changed
// state. Stop/continue notifications are only reported. For a child that
// exited or was killed, each window slot running that pid is released: its
// socket event source is dropped, its two draw buffers are unmapped, its
// socket is closed, and the slot is marked empty (pid = -1). The screen is
// redrawn afterwards.
static int
handle_os_signal_sigchld(sd_event_source *event_source,
        const struct signalfd_siginfo *siginfo, void *data)
{
    struct root *root = data;

    (void)event_source;
    (void)siginfo;

    for (;;) {
        int status;
        pid_t pid;

        pid = waitpid(-1, &status, WUNTRACED|WCONTINUED|WNOHANG);
        if (pid == 0)
            break;
        if (pid == -1) {
            if (errno == ECHILD)
                break;
            die("Failed to handle SIGCHLD.");
        }

        if (WIFEXITED(status)) {
            report_event("Child exited with status %d.", (int)WEXITSTATUS(status));
        } else if (WIFSIGNALED(status)) {
            report_event("Child was terminated by signal %d.", (int)WTERMSIG(status));
        } else if (WIFSTOPPED(status)) {
            report_event("Child was stopped.");
            continue;
        } else if (WIFCONTINUED(status)) {
            report_event("Child was continued.");
            continue;
        }

        for (int i = 0; i < NUM_WORKSPACES; i++) {
            struct workspace *workspace = &root->workspaces[i];
            for (int j = 0; j < NUM_WINDOWS_PER_WORKSPACE; j++) {
                struct window *window = &workspace->windows[j];
                if (window->pid == pid) {
                    //  Drop the socket's event source before closing the
                    //  socket. If the child is reaped before its hang-up was
                    //  handled, the source would otherwise outlive the fd and
                    //  could be confused with a later client that reuses the
                    //  same descriptor number.
                    if (window->event_source != NULL) {
                        sd_event_source_set_enabled(window->event_source, SD_EVENT_OFF);
                        window->event_source = sd_event_source_unref(window->event_source);
                    }
                    window->front_buffer = 0;
                    size_t mem_size =
                        2 * BYTES_PER_PIXEL * window->width * window->height;
                    int r = munmap(window->mem, mem_size);
                    if (r == -1)
                        die("Failed to release draw buffers.");
                    window->mem = NULL;
                    close(window->socket_fd);
                    window->socket_fd = -1;
                    window->pid = -1;
                }
            }
        }
    }

    draw(root);

    return 0;
}

// sd-event IO callback for a client's socket. It reads one message and
// handles it:
//   SHELL_MESSAGE_RESET  copy the shown buffer into slot 0 if needed, then
//                        reply with RESET_COMPLETED followed by WELCOME;
//   SHELL_MESSAGE_SHOW   swap the window's front/back buffers, tell the
//                        client its back buffer is free, and redraw.
// On socket error, hang-up or end-of-file the event source is dropped.
static int
handle_client_message(sd_event_source *event_source, int fd, uint32_t events,
        void *data)
{
    struct root *root = data;
    struct window *window = root_lookup_window_by_fd(root, fd);
    if (window == NULL || window->event_source != event_source)
        die("Event received from unknown source.");

    //  Stop watching this client's socket. The socket itself stays open;
    //  handle_os_signal_sigchld closes it when the child is reaped.
    int drop(void) {
        sd_event_source_set_enabled(event_source, SD_EVENT_OFF);
        sd_event_source_unref(event_source);
        window->event_source = NULL;
        return 0;
    }

    //  TODO What kinds of errors (EPOLLERR) can we get? How should they be
    //  handled?

    if (events & EPOLLERR)
        return drop();
    if (events & EPOLLHUP)
        return drop();
    if (events != EPOLLIN)
        die("Unexpected condition on socket.");

    static _Alignas(max_align_t) char packet[SHELL_PACKET_SIZE_MAX];
    size_t packet_size;
    for (;;) {
        ssize_t r = read(fd, packet, SHELL_PACKET_SIZE_MAX);
        if (r == -1 && errno == EINTR)
            continue;
        //  Spurious wake-up on a non-blocking socket: nothing to read yet.
        if (r == -1 && (errno == EAGAIN || errno == EWOULDBLOCK))
            return 0;
        //  Any other error, or end-of-file (r == 0: the client closed its
        //  end), means this client can no longer be talked to. Drop it rather
        //  than let the empty read hit the fatal protocol-error check below.
        if (r <= 0)
            return drop();
        packet_size = (size_t)r;
        break;
    }

    //  TODO The following error should not be fatal. We need a general way to
    //  deal with protocol errors.

    if (packet_size < sizeof(uint16_t))
        die("Protocol error on socket.");

    uint16_t class = *(uint16_t *)packet;
    switch (class) {
    case SHELL_MESSAGE_RESET:
        if (window->front_buffer != 0) {
            size_t buffer_size = BYTES_PER_PIXEL * window->width * window->height;
            memmove(window->mem, window->mem + buffer_size, buffer_size);
            window->front_buffer = 0;
        }
        send_reset_completed(window);
        send_welcome(window);
        break;
    case SHELL_MESSAGE_SHOW:
        window->front_buffer = !window->front_buffer;
        send_buffer_unlocked(window);
        draw(root);
        break;
    default:
        report_event("Unexpected message class received from client.");
        break;
    }
    return 0;
}

// libdrm page-flip callback, dispatched by drmHandleEvent. `data` is the root
// passed to drmModePageFlip. The flip is done, so the back buffer may be drawn
// into again; a draw deferred in the meantime is run now.
static void
handle_display_page_flip_event(int device, unsigned int frame,
        unsigned int sec, unsigned int usec, void *data)
{
    (void)device;
    (void)frame;
    (void)sec;
    (void)usec;

    struct root *root = data;
    root->display.back_buffer_is_locked = 0;
    if (root->display.is_draw_needed)
        draw(root);
}

// sd-event IO callback for the DRM device. Error or hang-up conditions return
// -1. Otherwise it dispatches pending DRM events (page-flip completions go to
// handle_display_page_flip_event), retrying on EINTR. A dispatch failure is
// only reported.
static int
handle_device_event_for_display(sd_event_source *event_source, int fd,
        uint32_t revents, void *data)
{
    struct root *root = data;
    if (root->display.event_source != event_source)
        die("Event received from unknown source.");

    if (revents & (EPOLLERR | EPOLLHUP))
        return -1;

    drmEventContext event_context = {
        .version = DRM_EVENT_CONTEXT_VERSION,
        .page_flip_handler = handle_display_page_flip_event,
    };

    int r = drmHandleEvent(fd, &event_context);
    while (r == -1 && errno == EINTR)
        r = drmHandleEvent(fd, &event_context);
    if (r != 0)
        report_event("Failed to flip back buffer onto display.");

    return 0;
}

// sd-event IO callback for the keyboard. On error or hang-up it disables the
// source and closes the descriptor; keyboard_reconnect installs a new one.
// Otherwise it reads a batch of evdev events. Each EV_KEY event updates the
// held-key state and the shortcut handling (values 1 and 2 count as presses,
// 0 as a release). It is then forwarded to the focused window's client,
// unless it is a Meta key or a Meta key is held.
static int
handle_device_event_for_keyboard(sd_event_source *event_source, int fd,
        uint32_t revents, void *data)
{
    struct root *root = data;
    struct keyboard *keyboard = &root->keyboard;
    if (keyboard->event_source != event_source)
        die("Event from unknown source.");

    if ((revents & EPOLLERR) || (revents & EPOLLHUP)) {
        report_event("Keyboard disconnected.");
        sd_event_source_set_enabled(event_source, SD_EVENT_OFF);
        close(fd);
        //  Don't leave a closed descriptor number in the keyboard state.
        keyboard->fd = -1;
        return 0;
    }
    if (revents != EPOLLIN)
        die("Unexpected condition on keyboard.");

    struct input_event input_events[16];

    ssize_t ret;
    do {
        ret = read(fd, input_events, sizeof(input_events));
    } while (ret == -1 && errno == EINTR);

    if (ret == -1) {
        //  Nothing to read after all on a non-blocking descriptor. Wait for
        //  the next readiness notification instead of returning an error,
        //  which would get the event source disabled.
        if (errno == EAGAIN || errno == EWOULDBLOCK)
            return 0;
        report_event("Failed while reading keyboard: %s.", strerror(errno));
        return -1;
    }
    if (ret == 0)
        die("Failed while reading keyboard: no input.");
    if ((size_t)ret % sizeof(struct input_event) != 0)
        die("Failed while reading keyboard: invalid input.");

    int num_events = (int)((size_t)ret / sizeof(struct input_event));

    for (int i = 0; i < num_events; i++) {
        struct input_event *ev = &input_events[i];
        if (ev->type != EV_KEY)
            continue;
        switch (ev->value) {
        case 0:
            handle_key_release(root, ev->code);
            break;
        case 1:
        case 2:
            handle_key_press(root, ev->code);
            break;
        }
        //  Look up the focused window only now, because a shortcut handled
        //  above may have moved the focus.
        struct workspace *workspace = &root->workspaces[root->focus];
        struct window *window = &workspace->windows[workspace->focus];
        if (window->pid == -1)
            continue;
        if (ev->code == KEY_LEFTMETA || ev->code == KEY_RIGHTMETA)
            continue;
        if (fetch_bit(keyboard->key_bits, KEY_LEFTMETA))
            continue;
        if (fetch_bit(keyboard->key_bits, KEY_RIGHTMETA))
            continue;
        send_keyboard_event(window, ev);
    }

    return 0;
}

// sd-event IO callback for the D-Bus connection. It processes queued bus
// messages until sd_bus_process reports none are left. Processing errors are
// fatal.
static int
handle_bus_event(sd_event_source *event_source, int fd, uint32_t revents,
        void *data)
{
    struct root *root = data;
    int r;

    while ((r = sd_bus_process(root->bus, NULL)) > 0)
        report_event("Processed a DBus message.");
    if (r < 0)
        die("Failed to process input from bus.");

    return 0;
}

// sd-bus handler for org.freedesktop.DBus.Properties.PropertiesChanged. Only
// changes to the org.freedesktop.login1.Session interface are examined. The
// session's "Active" value is taken from the changed-properties dictionary
// if it is there. Otherwise, if "Active" is among the invalidated property
// names, it is fetched from logind. A change is stored in root->is_active and
// reported; becoming active triggers a redraw.
static int
handle_bus_signal_properties_changed(sd_bus_message *m, void *data,
        sd_bus_error *ret_error)
{
    report_event("Received a PropertiesChanged message.");

    struct root *root = data;

    void check(int r)
    {
        if (r < 0)
            die("Failed to read DBus message.");
    }

    void update(bool is_active)
    {
        if (root->is_active == is_active)
            return;

        root->is_active = is_active;

        report_event("%s focus.", is_active ? "Gained" : "Lost");
        report_state("Running in the %sground.", is_active ? "fore" : "back");

        if (is_active)
            draw(root);
    }

    int r;
    const char *interface;
    const char *s;

    r = sd_bus_message_read_basic(m, 's', &interface);
    check(r);

    if (strcmp(interface, "org.freedesktop.login1.Session") != 0)
        return 0;

    r = sd_bus_message_enter_container(m, 'a', "{sv}");
    check(r);

    for (;;) {
        r = sd_bus_message_enter_container(m, 'e', "sv");
        check(r);
        if (r == 0)
            break;

        r = sd_bus_message_read_basic(m, 's', &s);
        check(r);

        if (strcmp(s, "Active") == 0) {
            r = sd_bus_message_enter_container(m, 'v', "b");
            check(r);

            //  sd-bus stores a D-Bus boolean ('b') into an int, not a bool.
            //  Reading into a one-byte bool would write past it.
            int is_active;
            r = sd_bus_message_read_basic(m, 'b', &is_active);
            check(r);

            update(is_active != 0);

            return 0;
        } else {
            r = sd_bus_message_skip(m, "v");
            check(r);
        }

        r = sd_bus_message_exit_container(m);
        check(r);
    }

    r = sd_bus_message_exit_container(m);
    check(r);

    r = sd_bus_message_enter_container(m, 'a', "s");
    check(r);

    for (;;) {
        r = sd_bus_message_read_basic(m, 's', &s);
        check(r);
        if (r == 0)
            break;

        if (strcmp(s, "Active") == 0) {
            sd_bus_error error = SD_BUS_ERROR_NULL;
            //  As above, a 'b' property is returned through an int.
            int is_active;
            r = sd_bus_get_property_trivial(root->bus,
                    "org.freedesktop.login1", root->session_path,
                    "org.freedesktop.login1.Session", "Active", &error,
                    'b', &is_active);
            sd_bus_error_free(&error);
            check(r);

            update(is_active != 0);

            return 0;
        }
    }

    return 0;
}

// logind PauseDevice signal handler; the arguments are major, minor and
// reason. When the display device (226:0) is paused, drawing is suspended
// (DISPLAY_PAUSED), its event source is disabled and its descriptor is
// closed. Other devices are only reported.
//
// TODO(review): for reason "pause", logind's session-controller protocol
// expects a PauseDeviceComplete(major, minor) call; this handler does not
// send one. Confirm whether it should.
static int
handle_bus_signal_pause_device(sd_bus_message *message, void *data,
        sd_bus_error *ret_error)
{
    struct root *root = data;

    int r;

    uint32_t major, minor;
    //  The string points into the message; sd-bus hands it out as const.
    const char *reason;

    r = sd_bus_message_read(message, "uus", &major, &minor, &reason);
    if (r < 0)
        die("Failed to read message from DBus.");

    //  226 is the DRM character-device major number.
    //  NOTE(review): minor 0 is assumed to be the display this shell uses.
    if (major == 226 && minor == 0) {
        struct display *disp = &root->display;
        //  If another PauseDevice arrives while the display is already
        //  paused, its descriptor has already been closed. Closing it again
        //  could close an unrelated descriptor that reused the number.
        if (!(disp->state & DISPLAY_PAUSED)) {
            disp->state |= DISPLAY_PAUSED;
            sd_event_source_set_enabled(disp->event_source, SD_EVENT_OFF);
            close(disp->fd);
            disp->fd = -1;
        }
    }

    report_event("Received a PauseDevice message: major=%d, minor=%d, reason=%s.",
            (int)major, (int)minor, reason);

    return 0;
}

/*
 * Handle logind's ResumeDevice(major, minor, fd) signal for this session.
 *
 * The fd in the signal is owned by `message` and is closed when the
 * message is released, so a private duplicate is taken before handing it
 * to the keyboard or display code.  The display fd gets the same
 * treatment as in root_take_control_of_display (close-on-exec, blocking),
 * so that it does not leak into client processes started afterwards.
 * Failure to duplicate the fd is fatal.
 */
static int
handle_bus_signal_resume_device(sd_bus_message *message, void *data,
        sd_bus_error *ret_error)
{
    int r;

    struct root *root = data;
    struct keyboard *keyboard = &root->keyboard;

    uint32_t major, minor;
    int fd;

    r = sd_bus_message_read(message, "uuh", &major, &minor, &fd);
    if (r < 0)
        die("Failed to communicate with DBus.");

    report_event("Received a ResumeDevice message: major=%d, minor=%d, fd=%d.",
            (int)major, (int)minor, fd);

    if (major == keyboard->major && minor == keyboard->minor) {
        int own_fd = dup(fd);
        if (own_fd == -1)
            die("Failed to communicate with DBus.");
        enable_cloexec(own_fd);
        keyboard_reconnect(keyboard, own_fd);
    } else if (major == 226 && minor == 0) {
        /* 226:0 is the first DRM card, as in root_take_control_of_display. */
        int own_fd = dup(fd);
        if (own_fd == -1)
            die("Failed to communicate with DBus.");
        enable_cloexec(own_fd);
        disable_nonblock(own_fd);
        struct display *disp = &root->display;
        disp->state &= ~DISPLAY_PAUSED;
        display_reconnect(disp, own_fd);
        draw(root);
    }

    return 0;
}

/*
 * Pick a keyboard via udev: the first device tagged ID_INPUT_KEYBOARD=1
 * that has a device node (non-zero devnum).  Its major/minor numbers are
 * stored in *major and *minor.  Every failure, including "no such
 * device", is fatal.
 */
static void
root_find_keyboard(uint32_t *major, uint32_t *minor)
{
    static const char failure[] = "Failed to choose keyboard.";

    struct udev *ctx = udev_new();
    if (!ctx)
        die(failure);

    struct udev_enumerate *scan = udev_enumerate_new(ctx);
    if (!scan)
        die(failure);

    if (udev_enumerate_add_match_property(scan, "ID_INPUT_KEYBOARD", "1") < 0
            || udev_enumerate_scan_devices(scan) < 0)
        die(failure);

    struct udev_list_entry *head = udev_enumerate_get_list_entry(scan);
    if (!head)
        die(failure);

    *major = 0;
    *minor = 0;

    struct udev_list_entry *item;
    udev_list_entry_foreach(item, head) {
        const char *syspath = udev_list_entry_get_name(item);
        struct udev_device *dev =
            syspath ? udev_device_new_from_syspath(ctx, syspath) : NULL;
        if (!dev)
            continue;

        dev_t num = udev_device_get_devnum(dev);
        udev_device_unref(dev);

        /* Entries without a device node report 0:0; keep looking. */
        if (major(num) != 0 || minor(num) != 0) {
            *major = major(num);
            *minor = minor(num);
            break;
        }
    }

    udev_enumerate_unref(scan);
    udev_unref(ctx);

    if (*major == 0 && *minor == 0)
        die(failure);
}

/*
 * Ask the kernel to switch to virtual terminal `n` through the controlling
 * terminal, and block until the switch has happened.  Any failure is fatal.
 */
static void
root_switch_to_vt(int n)
{
    static const char failure[] = "Failed to switch VT.";

    int tty = open("/dev/tty", O_RDWR);
    if (tty == -1)
        die(failure);

    if (ioctl(tty, VT_ACTIVATE, n) == -1 || ioctl(tty, VT_WAITACTIVE, n) == -1)
        die(failure);

    close(tty);
}

/*
 * Child-side helper for root_launch_client: place `sock_fd` at fd 3 and
 * `mem_fd` at fd 4, closing the original descriptors.  Dies on failure.
 *
 * Both descriptors are first copied to numbers >= 5, so neither dup2()
 * below can overwrite the other one (for example when mem_fd happens to
 * be 3 and would be clobbered by dup2(sock_fd, 3)).
 */
static void
launch_client_install_fds(int sock_fd, int mem_fd)
{
    const char error_message[] = "Failed while launching child process.";
    int r;

    int sock_tmp = fcntl(sock_fd, F_DUPFD, 5);
    int mem_tmp = fcntl(mem_fd, F_DUPFD, 5);
    if (sock_tmp == -1 || mem_tmp == -1)
        die(error_message);

    close(sock_fd);
    close(mem_fd);

    do {
        r = dup2(sock_tmp, 3);
    } while (r == -1 && errno == EINTR);
    if (r == -1)
        die(error_message);

    do {
        r = dup2(mem_tmp, 4);
    } while (r == -1 && errno == EINTR);
    if (r == -1)
        die(error_message);

    close(sock_tmp);
    close(mem_tmp);
}

/*
 * Start the program `path` (argv[0] = arg0) in the focused window of the
 * focused workspace.  Does nothing if that window already has a client.
 *
 * The child inherits two descriptors:
 *   fd 3 - one end of a SOCK_SEQPACKET socket pair; the other end stays
 *          here as window->socket_fd and is watched by the event loop;
 *   fd 4 - a memfd of 2 * BYTES_PER_PIXEL * width * height bytes, which
 *          is also mapped here as window->mem.
 * Every failure is fatal (die()).
 */
static void
root_launch_client(struct root *root, const char *path, const char *arg0)
{
    struct workspace *workspace = &root->workspaces[root->focus];
    struct window *window = &workspace->windows[workspace->focus];

    if (window->pid != -1)
        return;

    int r;

    /* Two frames of width x height pixels (presumably front/back, cf.
     * window->front_buffer).  Widen to size_t before multiplying so the
     * product is not computed in 32-bit arithmetic. */
    size_t mem_size = (size_t)2 * BYTES_PER_PIXEL * window->width * window->height;
    char filename[64];
    r = snprintf(filename, sizeof(filename), "buffer_%u_%u",
            root->focus, workspace->focus);
    if (r < 0 || (size_t)r >= sizeof(filename))
        die("Failed to create draw buffers for child process.");
    int mem_fd = syscall(__NR_memfd_create, filename, 0);
    if (mem_fd == -1)
        die("Failed to create draw buffers for child process.");
    do { r = ftruncate(mem_fd, mem_size); } while (r == -1 && errno == EINTR);
    if (r == -1)
        die("Failed to create draw buffers for child process.");
    void *mem = mmap(NULL, mem_size, PROT_READ|PROT_WRITE, MAP_SHARED,
            mem_fd, 0);
    if (mem == MAP_FAILED)
        die("Failed to create draw buffers for child process.");

    window->mem = mem;

    int socket_fds[2];
    r = socketpair(AF_UNIX, SOCK_SEQPACKET|SOCK_NONBLOCK, 0, socket_fds);
    if (r == -1)
        die("Failed to create socket for child process.");

    window->socket_fd = socket_fds[0];
    send_welcome(window);

    r = sd_event_add_io(root->event, &window->event_source, window->socket_fd,
            EPOLLIN, handle_client_message, root);
    if (r < 0)
        die("Failed to add child message event handler.");

    clear_buffer(window->mem, window->width, window->height);

    pid_t pid = fork();
    switch (pid) {
    case -1:
        die("Failed to fork child process.");
        break;
    case 0:
        {
            /* The parent blocks SIGTERM/SIGINT/SIGCHLD and ignores SIGPIPE
             * (root_subscribe_to_os_signals); the mask and ignored
             * dispositions survive exec, so restore the defaults. */
            sigset_t sigset;
            const char error_message[] =
                "Failed to initialize signal handling for child process.";
            r = sigemptyset(&sigset);
            if (r == -1)
                die(error_message);
            r = sigprocmask(SIG_SETMASK, &sigset, NULL);
            if (r == -1)
                die(error_message);
            signal(SIGPIPE, SIG_DFL);
            signal(SIGINT, SIG_DFL);
            signal(SIGTERM, SIG_DFL);
            signal(SIGCHLD, SIG_DFL);
        }
        close(socket_fds[0]);
        launch_client_install_fds(socket_fds[1], mem_fd);
        execlp(path, arg0, (char *)NULL);
        die("Failed to exec child process.");
        break;
    default:
        window->pid = pid;
        close(mem_fd);
        close(socket_fds[1]);
        enable_cloexec(socket_fds[0]);
        break;
    }
}

/*
 * Return the window whose client socket is `fd`, scanning workspaces in
 * order and windows within each workspace in order; NULL if none matches.
 */
static struct window *
root_lookup_window_by_fd(struct root *root, int fd)
{
    struct window *found = NULL;

    for (int ws = 0; ws < NUM_WORKSPACES && found == NULL; ws++) {
        struct window *slots = root->workspaces[ws].windows;
        for (int k = 0; k < NUM_WINDOWS_PER_WORKSPACE; k++) {
            if (slots[k].socket_fd == fd) {
                found = &slots[k];
                break;
            }
        }
    }

    return found;
}

/*
 * Make workspace `focus` the focused one.  The focused window of the old
 * workspace is told the keyboard was detached, the focused window of the
 * new one is told it was attached (with the current key state), and the
 * screen is redrawn.  No-op if `focus` is already focused.
 */
static void
root_set_workspace_focus(struct root *root, int focus)
{
    if (root->focus == focus)
        return;

    struct workspace *old_ws = &root->workspaces[root->focus];
    send_keyboard_detached(&old_ws->windows[old_ws->focus]);

    root->focus = focus;

    struct workspace *new_ws = &root->workspaces[root->focus];
    send_keyboard_attached(&new_ws->windows[new_ws->focus],
            root->keyboard.key_bits);

    draw(root);
}

/* Open (or reuse) the default system-bus connection; fatal on failure. */
static void
root_connect_to_bus(struct root *root)
{
    if (sd_bus_default_system(&root->bus) < 0)
        die("Failed to connect to the system bus.");
}

/* Obtain the default sd-event loop into root->event; fatal on failure. */
static void
root_create_event_loop(struct root *root)
{
    if (sd_event_default(&root->event) < 0)
        die("Failed to create the event loop.");
}

static void
root_take_control_of_session(struct root *root)
{
    int r;

    r = sd_pid_get_session(getpid(), &root->session_name);
    if (r < 0)
        die("Failed to determine user session name.");
    report_state("User session name is %s.", root->session_name);

    sd_bus_error error = SD_BUS_ERROR_NULL;
    sd_bus_message *message = NULL;
    r = sd_bus_call_method(root->bus, "org.freedesktop.login1",
            "/org/freedesktop/login1", "org.freedesktop.login1.Manager",
            "GetSessionByPID",
            &error, &message,
            "u", (uint32_t)getpid());
    if (r < 0)
        die("Failed to determine the logind object path of the session.");
    r = sd_bus_message_read(message, "o", &root->session_path);
    if (r < 0)
        die("Failed while communicating with logind.");
    root->session_path = strdup(root->session_path);
    sd_bus_message_unref(message);

    r = sd_bus_call_method(root->bus, "org.freedesktop.login1",
            root->session_path,
            "org.freedesktop.login1.Session", "TakeControl",
            &error, &message,
            "b", (int32_t)0);
    if (r < 0)
        die("Failed to take control of session.");
    sd_bus_message_unref(message);
}

/*
 * Initialise root->display and connect it to the first DRM card
 * (char device 226:0), obtained from logind via TakeDevice.
 *
 * The fd in the reply belongs to the reply message and is closed with it,
 * so a private duplicate is kept.  That duplicate is made close-on-exec
 * (so client processes do not inherit it) and blocking (NOTE(review):
 * presumably because logind hands the fd out non-blocking and the DRM
 * code expects blocking calls — confirm).  Fatal if the device is
 * reported as paused or on any error.
 */
static void
root_take_control_of_display(struct root *root)
{
    sd_bus_error error = SD_BUS_ERROR_NULL;
    sd_bus_message *message = NULL;
    struct display *disp = &root->display;
    display_init(disp);
    uint32_t major = 226;
    uint32_t minor = 0;
    int r = sd_bus_call_method(root->bus, "org.freedesktop.login1",
            root->session_path,
            "org.freedesktop.login1.Session", "TakeDevice",
            &error, &message,
            "uu", major, minor);
    if (r < 0)
        die("Failed while communicating with logind.");
    int fd;
    int32_t paused;
    r = sd_bus_message_read(message, "hb", &fd, &paused);
    if (r < 0)
        die("Failed while communicating with logind.");
    if (paused)
        die("Failed to take control of the display.");
    fd = dup(fd);
    message = sd_bus_message_unref(message);
    if (fd == -1)
        die("Failed to take control of the display.");
    enable_cloexec(fd);
    disable_nonblock(fd);
    display_connect(root, disp, fd);
}

/*
 * Choose a keyboard with root_find_keyboard, obtain its fd from logind
 * via TakeDevice and connect root->keyboard to it.
 *
 * The fd in the reply belongs to the reply message and is closed with it,
 * so a close-on-exec duplicate is kept instead.  Fatal if the device is
 * reported as paused or on any error.
 */
static void
root_take_control_of_keyboard(struct root *root)
{
    sd_bus_error error = SD_BUS_ERROR_NULL;
    sd_bus_message *message = NULL;
    struct keyboard *keyboard = &root->keyboard;
    uint32_t major;
    uint32_t minor;
    root_find_keyboard(&major, &minor);
    int r = sd_bus_call_method(root->bus, "org.freedesktop.login1",
            root->session_path,
            "org.freedesktop.login1.Session", "TakeDevice",
            &error, &message,
            "uu", major, minor);
    if (r < 0)
        die("Failed while communicating with logind.");
    int fd;
    int32_t paused;
    r = sd_bus_message_read(message, "hb", &fd, &paused);
    if (r < 0)
        die("Failed while communicating with logind.");
    if (paused)
        die("Failed to take control of the keyboard.");
    fd = dup(fd);
    message = sd_bus_message_unref(message);
    if (fd == -1)
        die("Failed to take control of the keyboard.");
    enable_cloexec(fd);
    keyboard_connect(root, keyboard, fd, major, minor);
}

static void
root_subscribe_to_bus_signals(struct root *root)
{
    char *pattern = NULL;
    int r;

    pattern = easy_sprintf("type='signal',path='%s',"
            "interface='org.freedesktop.login1.Session',"
            "member='PauseDevice'",
            root->session_path);
    r = sd_bus_add_match(root->bus, NULL, pattern,
            handle_bus_signal_pause_device, root);
    if (r < 0)
        die("Failed to subscribe to PauseDevice signal.");
    free(pattern);

    pattern = easy_sprintf("type='signal',path='%s',"
            "interface='org.freedesktop.login1.Session',"
            "member='ResumeDevice'",
            root->session_path);
    r = sd_bus_add_match(root->bus, NULL, pattern,
            handle_bus_signal_resume_device, root);
    if (r < 0)
        die("Failed to subscribe to ResumeDevice signal.");
    free(pattern);

    pattern = easy_sprintf("type='signal',path='%s',"
            "interface='org.freedesktop.DBus.Properties',"
            "member='PropertiesChanged'",
            root->session_path);
    r = sd_bus_add_match(root->bus, NULL, pattern,
            handle_bus_signal_properties_changed, root);
    if (r < 0)
        die("Failed to subscribe to PropertiesChanged signal.");
    free(pattern);

    int fd = sd_bus_get_fd(root->bus);
    if (fd < 0)
        die("Failed to initialize DBus monitoring.");
    r = sd_event_add_io(root->event, NULL, fd, EPOLLIN,
            handle_bus_event, root);
    if (r < 0)
        die("Failed to initialize DBus monitoring.");
}

/*
 * Route process signals through the event loop.
 *
 * SIGPIPE is ignored.  SIGTERM, SIGINT and SIGCHLD are blocked for normal
 * delivery and registered with sd-event instead.  SIGTERM and SIGINT get
 * no callback (NULL), so sd-event's default handling applies; SIGCHLD is
 * dispatched to handle_os_signal_sigchld.  Every failure is fatal.
 */
static void
root_subscribe_to_os_signals(struct root *root)
{
    const char error_message[] = "Failed to initialize signal handling.";
    static const int watched[] = { SIGTERM, SIGINT, SIGCHLD };
    sigset_t mask;

    signal(SIGPIPE, SIG_IGN);

    if (sigemptyset(&mask) == -1)
        die(error_message);
    for (size_t i = 0; i < array_size(watched); i++) {
        if (sigaddset(&mask, watched[i]) == -1)
            die(error_message);
    }
    if (sigprocmask(SIG_BLOCK, &mask, NULL) == -1)
        die(error_message);

    if (sd_event_add_signal(root->event, NULL, SIGTERM, NULL, NULL) < 0
            || sd_event_add_signal(root->event, NULL, SIGINT, NULL, NULL) < 0
            || sd_event_add_signal(root->event, NULL, SIGCHLD,
                    handle_os_signal_sigchld, root) < 0)
        die(error_message);
}

/*
 * Lay out every workspace, draw the first frame and run the event loop
 * until it exits.
 *
 * Each workspace gets the same three-window layout based on the back
 * buffer size: window 0 fills the left half, windows 1 and 2 stack in the
 * right half (top and bottom).  All windows start with no client
 * (pid -1, socket_fd -1, no memory), and workspace 0 / window 0 are
 * focused.  A failing event loop is fatal.
 */
static void
root_interact(struct root *root)
{
    struct display_buffer *bb = display_back_buffer(&root->display);
    uint32_t disp_width = bb->width;
    uint32_t disp_height = bb->height;
    uint32_t half_w = disp_width / 2;
    uint32_t half_h = disp_height / 2;

    for (int ws = 0; ws < NUM_WORKSPACES; ws++) {
        struct workspace *workspace = &root->workspaces[ws];
        workspace->focus = 0;

        for (int k = 0; k < NUM_WINDOWS_PER_WORKSPACE; k++) {
            struct window *win = &workspace->windows[k];

            if (k == 0) {
                /* Left half, full height. */
                win->x = 20;
                win->y = 20;
                win->width = half_w - 30;
                win->height = disp_height - 40;
            } else if (k == 1) {
                /* Right half, top. */
                win->x = half_w + 10;
                win->y = 20;
                win->width = half_w - 30;
                win->height = half_h - 30;
            } else if (k == 2) {
                /* Right half, bottom. */
                win->x = half_w + 10;
                win->y = half_h + 10;
                win->width = half_w - 30;
                win->height = half_h - 30;
            } else {
                die("Failed to initialize workspaces.");
            }

            win->front_buffer = 0;
            win->mem = NULL;
            win->socket_fd = -1;
            win->pid = -1;
        }
    }

    root->is_active = true;
    root->focus = 0;

    draw(root);

    if (sd_event_loop(root->event) < 0)
        die("Failed while processing events.");
}

/*
 * Tear down: disconnect the display, hand session control back to logind
 * (ReleaseControl), then free the session strings and drop the event loop
 * and bus references.  A failed ReleaseControl is fatal.
 */
static void
root_shutdown(struct root *root)
{
    display_disconnect(&root->display);

    sd_bus_error error = SD_BUS_ERROR_NULL;
    sd_bus_message *reply = NULL;
    if (sd_bus_call_method(root->bus, "org.freedesktop.login1",
                root->session_path,
                "org.freedesktop.login1.Session", "ReleaseControl",
                &error, &reply,
                "") < 0)
        die("Failed to release control of session.");
    sd_bus_message_unref(reply);

    free(root->session_path);
    free(root->session_name);
    sd_event_unref(root->event);
    sd_bus_unref(root->bus);
}

int
main(void)
{
    report_event("Started.");

    struct root *root = ROOT_INIT;

    root_connect_to_bus(root);

    root_create_event_loop(root);

    root_take_control_of_session(root);

    root_take_control_of_display(root);

    root_take_control_of_keyboard(root);

    root_subscribe_to_bus_signals(root);

    root_subscribe_to_os_signals(root);

    root_interact(root);

    root_shutdown(root);

    report_event("Finished.");

    return 0;
}