#include <ctype.h>
#include <errno.h>
#include <fcntl.h>
#include <langinfo.h>
#include <linux/input.h>
#include <locale.h>
#include <signal.h>
#include <stdarg.h>
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/epoll.h>
#include <sys/ioctl.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <sys/uio.h>
#include <sys/wait.h>
#include <unistd.h>

#include <fira_mono_tables.h>
#include <keymap.h>
#include <shell.h>

static void
germ_init(void);

static void
germ_interact(void);

//  Entry point: one-time setup, then an endless event loop.  The process
//  leaves the loop only via exit, die or execvp inside the event handlers.
int
main(void)
{
    germ_init();
    while (1) {
        germ_interact();
    }
}

static void
germ_io_add(int fd, uint32_t events);

//  State of the prompt client: the shared frame buffers, the command line
//  being edited, and the keyboard modifier state.
struct germ {
    //  Start of the shared mapping that holds two frames back to back
    //  (mapped in germ_init, indexed by back_buffer in germ_draw).
    unsigned char *buffer;
    //  Frame (0 or 1) that the next germ_draw renders into.
    unsigned int back_buffer:1;
    //  Set by germ_draw after sending a show message; cleared by
    //  on_buffer_unlocked().  While set, germ_draw only records a request.
    unsigned int back_buffer_is_locked:1;
    //  A draw was requested while back_buffer_is_locked was set.
    unsigned int is_draw_needed:1;
    //  Bitwise OR of the MOD_*_MASK values currently held.
    uint32_t keyboard_modifiers;
    //  Display size in pixels, taken from the shell's welcome message.
    uint32_t width, height;
    //  epoll instance used by germ_interact.
    int epoll_fd;
    //  Size in bytes of command_buffer, including the terminating NUL.
    int line_size;
    //  NUL-terminated command line being edited.
    char *command_buffer;
    //  Insertion point: index into command_buffer.
    int cursor;
};

//  Number of elements in the array `a`.  Only valid on true arrays, never on
//  pointers (including array-typed function parameters, which decay).
#define array_size(a) (sizeof (a) / sizeof *(a))

//  Prints "Error: ", the printf-style message and a newline to stderr, then
//  exits with status 1.
//
//  errno is saved on entry and restored before the caller's format is
//  expanded: writing the prefix is allowed to modify errno, which would
//  corrupt a "%m" (glibc extension) in `fmt`.
static _Noreturn void
die(const char *fmt, ...)
{
    int saved_errno = errno;
    va_list args;

    fputs("Error: ", stderr);
    errno = saved_errno;

    va_start(args, fmt);
    vfprintf(stderr, fmt, args);
    va_end(args);
    fputc('\n', stderr);
    exit(1);
}

//  malloc() that dies on failure.  A NULL result for a zero-byte request is
//  passed through, since malloc(0) may legitimately return NULL.
void *
malloc_or_die(size_t size)
{
    void *block = malloc(size);
    if (block == NULL && size != 0)
        die("Failed to allocate memory.");
    return block;
}

static void
germ_draw(struct germ *germ);

static void
germ_launch(struct germ *germ, uint32_t width, uint32_t height,
        unsigned char *buffer);

static void
germ_on_input(struct germ *germ, uint16_t code, int32_t value);

//  Size in bytes of one pixel in the shared frame buffers: one 32-bit value
//  as produced by encode_color_xrgb.
#define BYTES_PER_PIXEL 4

//  TODO Consider storing the font metrics with the generated font tables
//  (fira_mono_tables.h) instead of hard-coding them in init_metrics().

//  Pixel layout metrics for the monospace font and the text area, filled in
//  by init_metrics().
static struct metrics {
    //  Height of a text line; used to position the cursor bar.
    int line_spacing;
    //  Horizontal pen advance per character.
    int advance;
    //  Distance from the top border down to the text baseline.
    int baseline;
    //  Padding around the text area.
    int border_size_top;
    int border_size_bottom;
    int border_size_left;
    int border_size_right;
} metrics;

//  Called once the shell releases the locked buffer: clears the lock and
//  performs any draw that germ_draw deferred while it was held.
static void
on_buffer_unlocked(struct germ *germ)
{
    germ->back_buffer_is_locked = 0;
    if (!germ->is_draw_needed)
        return;
    germ_draw(germ);
}

//  Packs a color into a 32-bit pixel whose in-memory byte order is B, G, R,
//  0 regardless of host endianness (the value is built byte by byte).
static uint32_t
encode_color_xrgb(uint8_t r, uint8_t g, uint8_t b)
{
    const uint8_t layout[4] = { b, g, r, 0 };
    uint32_t pixel;
    memcpy(&pixel, layout, sizeof(pixel));
    return pixel;
}

//  Blits a `width` x `rows` coverage bitmap (row-major in `bytes`) into the
//  pixel buffer `mem`, with its top-left corner at (x_start, y_start).
//  `stride` is the length of one pixel row in bytes.  No clipping is done:
//  the caller must keep the bitmap inside the buffer.
//
//  Coverage maps linearly to gray: a clear byte (0x00) becomes white (0xff)
//  and a fully-set byte (0xff) becomes dark gray (0x47).
static void
draw_pixmap(void *mem, const uint8_t *bytes, int rows, int width,
        int x_start, int y_start, uint32_t stride)
{
    uint32_t *pixels = mem;
    const uint32_t pitch = stride / BYTES_PER_PIXEL;

    for (int v = 0; v < rows; v++) {
        const int y = y_start + v;
        const uint8_t *src_row = bytes + v * width;
        for (int u = 0; u < width; u++) {
            const int x = x_start + u;
            const uint32_t coverage = src_row[u];
            const uint32_t level =
                (coverage * 0x48 + (0xff - coverage) * 0x100) / 0x100;
            pixels[y * pitch + x] = encode_color_xrgb(level, level, level);
        }
    }
}

//  Returns x limited to the range [a, b].  The lower bound is checked first,
//  so if a > b, any x below a yields a.
static int
clamp(int x, int a, int b)
{
    return x < a ? a : (x > b ? b : x);
}

//  Maps a codepoint to the offset of its glyph record in bitmap_bytes.
//  Codepoints not covered by any interval of the font tables, including
//  negative ones, map to offset 0 (the null glyph).
static uint32_t
find_bitmap_offset(int codepoint)
{
    //  Assumption: interval_starts is not empty.
    //  Assumption: interval_starts is strictly increasing.
    //  Assumption: interval_starts[0] == 0.
    //  Assumption: intervals do not overlap.
    //  Assumption: The null glyph has bitmap offset 0.

    //  A negative codepoint (e.g. a sign-extended char) can, depending on the
    //  table element types, leave i == 0 below and produce a negative index
    //  into bitmap_offsets.  Reject it up front.
    if (codepoint < 0)
        return 0;

    int i = 0;
    while (i + 1 < array_size(interval_starts) && interval_starts[i + 1] <= codepoint)
        i++;

    //  Invariant: i is the highest index such that interval_starts[i] <= codepoint.

    if (codepoint >= interval_starts[i] + interval_sizes[i]) {

        //  Invariant: codepoint is not part of any interval.

        return 0;
    }

    //  Invariant: codepoint belongs to interval i.  Its index in
    //  bitmap_offsets is its position within interval i plus the sizes of all
    //  earlier intervals.

    int j = codepoint - interval_starts[i];
    for (int k = 0; k < i; k++)
        j += interval_sizes[k];

    return bitmap_offsets[j];
}

//  Renders `text` on the first line of the text area, one glyph per byte at
//  a fixed advance, starting at the left border with the baseline at
//  border_size_top + baseline.
//
//  Each glyph record in bitmap_bytes starts with a 4-byte header: a signed
//  8-bit top offset above the baseline, a signed 8-bit left offset from the
//  pen, the row count and the width.  The header is followed by rows*width
//  coverage bytes.
static void
draw_text(void *mem, uint32_t stride, const char *text)
{
    int pen_y = metrics.border_size_top + metrics.baseline;
    int pen_x = metrics.border_size_left;
    for (const char *p = text; *p != '\0'; p++) {
        //  Go through unsigned char: with a signed char, bytes >= 0x80 would
        //  otherwise become negative codepoints.
        uint32_t bitmap_offset = find_bitmap_offset((unsigned char)*p);
        //  Sign-extend the two's-complement offsets portably.
        int top = bitmap_bytes[bitmap_offset + 0];
        if (top > 127) top -= 256;
        int left = bitmap_bytes[bitmap_offset + 1];
        if (left > 127) left -= 256;
        int rows = bitmap_bytes[bitmap_offset + 2];
        int width = bitmap_bytes[bitmap_offset + 3];
        draw_pixmap(mem, &bitmap_bytes[bitmap_offset + 4],
                rows, width,
                pen_x + left, pen_y - top, stride);
        pen_x += metrics.advance;
    }
}

//  Draws the cursor for column `cursor` on the first line.  The cursor is
//  drawn in two steps:
//    1. A full-width light-gray (0xcc) guide row.
//    2. A dark (0x48) bar, one advance wide and three pixels tall, starting
//       5 pixels above border_size_top + line_spacing.
//  The bar's right edge is clamped to the row width.  The rows themselves
//  are not bounds-checked against the buffer height.
static void
draw_cursor(void *mem, uint32_t stride, int cursor)
{
    uint32_t *pixels = mem;
    uint32_t pitch = stride / BYTES_PER_PIXEL;

    int top = metrics.border_size_top + metrics.line_spacing - 5;
    int bottom = top + 3;
    int left = metrics.border_size_left + cursor * metrics.advance;
    int right = clamp(left + metrics.advance, 0, pitch);

    //  Guide row: the middle row of the bar, painted first so the bar
    //  covers it.
    uint32_t guide_color = encode_color_xrgb(0xcc, 0xcc, 0xcc);
    int guide_first = (top + 1) * pitch;
    int guide_last = guide_first + pitch;
    for (int i = guide_first; i < guide_last; i++)
        pixels[i] = guide_color;

    uint32_t bar_color = encode_color_xrgb(0x48, 0x48, 0x48);
    for (int y = top; y < bottom; y++) {
        for (int x = left; x < right; x++)
            pixels[y * pitch + x] = bar_color;
    }
}

//  Fills a contiguous width x height pixel buffer (no row padding) with
//  white, row by row.
static void
draw_background(void *mem, uint32_t width, uint32_t height)
{
    const uint32_t white = encode_color_xrgb(0xff, 0xff, 0xff);
    uint32_t *pixels = mem;

    for (uint32_t row = 0; row < height; row++) {
        uint32_t *line = pixels + (size_t)row * width;
        for (uint32_t col = 0; col < width; col++)
            line[col] = white;
    }
}

static void
send_reset(int fd);

static void
send_show(int fd);

//  Renders the prompt (background, cursor, command text) into the current
//  back buffer and sends a show message on the shell socket (fd 3).  It then
//  flips back_buffer and sets back_buffer_is_locked.
//
//  If the lock is still held, nothing is drawn: is_draw_needed is set and
//  on_buffer_unlocked() performs the draw once the lock is released.
static void
germ_draw(struct germ *germ)
{
    if (germ->back_buffer_is_locked) {
        germ->is_draw_needed = 1;
        return;
    }

    //  The shared mapping holds two frames back to back.  The frame size is
    //  computed in size_t so the offset cannot wrap in 32-bit arithmetic.
    size_t frame_size = (size_t)BYTES_PER_PIXEL * germ->width * germ->height;
    unsigned char *buffer = germ->buffer + germ->back_buffer * frame_size;
    uint32_t stride = BYTES_PER_PIXEL * germ->width;

    draw_background(buffer, germ->width, germ->height);
    draw_cursor(buffer, stride, germ->cursor);
    draw_text(buffer, stride, germ->command_buffer);
    send_show(3);

    //  NOTE(review): presumably the shell now displays the frame just drawn
    //  and holds it until it sends BUFFER_UNLOCKED; confirm against the
    //  shell protocol.
    germ->back_buffer = !germ->back_buffer;
    germ->back_buffer_is_locked = 1;
    germ->is_draw_needed = 0;
}

static void
init_metrics(void)
{
    metrics.advance = 10;
    metrics.line_spacing = 24;
    metrics.baseline = 14;
    metrics.border_size_top = 10;
    metrics.border_size_bottom = 10;
    metrics.border_size_left = 5;
    metrics.border_size_right = 5;
}

//  Initializes `germ` for a width x height display whose frame buffers start
//  at `buffer`, and creates its (close-on-exec) epoll instance.  Dies if
//  epoll_create1 fails.
static void
germ_launch(struct germ *germ, uint32_t width, uint32_t height,
        unsigned char *buffer)
{
    int fd = epoll_create1(EPOLL_CLOEXEC);
    if (fd == -1)
        die("Failed to initialize.");

    germ->epoll_fd = fd;
    germ->buffer = buffer;
    germ->width = width;
    germ->height = height;
    //  The first germ_draw renders into frame 1, and no frame is locked yet.
    germ->back_buffer = 1;
    germ->back_buffer_is_locked = 0;
    germ->is_draw_needed = 0;
    germ->keyboard_modifiers = 0;
}

//  Bits of struct germ's keyboard_modifiers.
enum {
    MOD_CONTROL_MASK = 1 << 0,
    MOD_ALT_MASK = 1 << 1,
    MOD_SHIFT_MASK = 1 << 2,
};

//  Updates germ->keyboard_modifiers if `code` is a tracked modifier key: a
//  value of 0 clears the modifier bit, any other value sets it.  Returns
//  nonzero if `code` was a modifier key, in which case the event is fully
//  handled.
//
//  TODO Handle the following weird case: left-shift is pressed,
//  right-shift is pressed, left-shift is released: the modifier mask
//  should still indicate that shift is pressed.
static int
germ_update_modifiers(struct germ *germ, uint16_t code, int32_t value)
{
    uint32_t modifier_mask;
    switch (code) {
    case KEY_CAPSLOCK:
        //  Caps Lock acts as the Control modifier.
        modifier_mask = MOD_CONTROL_MASK;
        break;
    case KEY_LEFTALT:
    case KEY_RIGHTALT:
        modifier_mask = MOD_ALT_MASK;
        break;
    case KEY_LEFTSHIFT:
    case KEY_RIGHTSHIFT:
        modifier_mask = MOD_SHIFT_MASK;
        break;
    default:
        return 0;
    }
    if (value == 0)
        germ->keyboard_modifiers &= ~modifier_mask;
    else
        germ->keyboard_modifiers |= modifier_mask;
    return 1;
}

//  Applies the line-editing binding for Control+`ascii` (a, b, d, e, f, k,
//  u, w).  Returns nonzero if `ascii` has a binding, zero otherwise.
static int
germ_handle_control_key(struct germ *germ, char ascii)
{
    char *line = germ->command_buffer;

    switch (ascii) {
    case 'u':
        //  Delete everything before the cursor.
        if (germ->cursor != 0) {
            memmove(line, line + germ->cursor,
                    strlen(line + germ->cursor) + 1);
            germ->cursor = 0;
            germ_draw(germ);
        }
        return 1;
    case 'w':
        //  Delete the space-delimited word before the cursor, together with
        //  any spaces between it and the cursor.
        {
            int cursor_begin = germ->cursor;
            while (germ->cursor > 0 && line[germ->cursor - 1] == ' ')
                germ->cursor--;
            while (germ->cursor > 0 && line[germ->cursor - 1] != ' ')
                germ->cursor--;
            memmove(line + germ->cursor, line + cursor_begin,
                    strlen(line + cursor_begin) + 1);
        }
        germ_draw(germ);
        return 1;
    case 'd':
        //  Delete the character under the cursor.
        if (line[germ->cursor] != '\0') {
            memmove(line + germ->cursor, line + germ->cursor + 1,
                    strlen(line + germ->cursor + 1) + 1);
            germ_draw(germ);
        }
        return 1;
    case 'k':
        //  Delete from the cursor to the end of the line.
        line[germ->cursor] = '\0';
        germ_draw(germ);
        return 1;
    case 'f':
        //  Move one character right.
        if (line[germ->cursor] != '\0') {
            germ->cursor++;
            germ_draw(germ);
        }
        return 1;
    case 'b':
        //  Move one character left.
        if (germ->cursor > 0) {
            germ->cursor--;
            germ_draw(germ);
        }
        return 1;
    case 'a':
        //  Move to the start of the line.
        germ->cursor = 0;
        germ_draw(germ);
        return 1;
    case 'e':
        //  Move to the end of the line.
        germ->cursor += (int)strlen(line + germ->cursor);
        germ_draw(germ);
        return 1;
    default:
        return 0;
    }
}

//  Inserts `ascii` at the cursor and redraws, provided two conditions hold:
//    - it is an accepted command-line character (alphanumeric or one of
//      " -_/."), and
//    - command_buffer still has room for it plus the terminating NUL.
static void
germ_insert_char(struct germ *germ, char ascii)
{
    switch (ascii) {
    case ' ':
    case '-':
    case '_':
    case '/':
    case '.':
        break;
    default:
        //  isalnum requires a value representable as unsigned char; a
        //  negative char would be undefined behavior.
        if (!isalnum((unsigned char)ascii))
            return;
        break;
    }

    char *line = germ->command_buffer;
    if (strlen(line) + 1 >= (size_t)germ->line_size)
        return;

    memmove(line + germ->cursor + 1, line + germ->cursor,
            strlen(line + germ->cursor) + 1);
    line[germ->cursor] = ascii;
    germ->cursor++;
    germ_draw(germ);
}

//  Handles one key event from the shell.  `code` is a KEY_* code; `value`
//  0 is treated as a release and any other value as a press.
static void
germ_on_input(struct germ *germ, uint16_t code, int32_t value)
{
    if (germ_update_modifiers(germ, code, value))
        return;

    if (value == 0)
        return;

    //  With the cursor at column 0, backspace is not consumed here and falls
    //  through to the keymap lookup below.
    if (code == KEY_BACKSPACE && germ->cursor > 0) {
        char *line = germ->command_buffer;
        memmove(line + (germ->cursor - 1), line + germ->cursor,
                strlen(line + germ->cursor) + 1);
        germ->cursor--;
        germ_draw(germ);
        return;
    }

    switch (code) {
    case KEY_UP:
    case KEY_DOWN:
        //  Consumed, but currently do nothing.
        return;
    case KEY_RIGHT:
        if (germ->command_buffer[germ->cursor] != '\0') {
            germ->cursor++;
            germ_draw(germ);
        }
        return;
    case KEY_LEFT:
        if (germ->cursor > 0) {
            germ->cursor--;
            germ_draw(germ);
        }
        return;
    case KEY_ENTER:
        //  A blank line would yield argc == 0 once the reset completes, and
        //  handle_shell_message then dies.  So only request a reset when the
        //  line contains a non-space character.
        if (germ->command_buffer[strspn(germ->command_buffer, " ")] != '\0')
            send_reset(3);
        return;
    }

    char ascii = (germ->keyboard_modifiers & MOD_SHIFT_MASK) ?
        keymap_lookup_shifted(code) : keymap_lookup(code);

    //  NOTE(review): Control combined with a key that has no binding falls
    //  through and inserts the plain character.
    if ((germ->keyboard_modifiers & MOD_CONTROL_MASK) &&
            germ_handle_control_key(germ, ascii))
        return;

    if (ascii != '\0')
        germ_insert_char(germ, ascii);
}

//  Reads the first message from the shell socket into `message`.  Dies
//  unless exactly one full message is read and its class is
//  SHELL_MESSAGE_WELCOME.  Retries if the read is interrupted by a signal,
//  like the other read loops in this file.
static void
receive_welcome(int fd, struct shell_message_welcome *message)
{
    ssize_t r;
    do {
        r = read(fd, message, sizeof(*message));
    } while (r == -1 && errno == EINTR);
    if (r != (ssize_t)sizeof(*message))
        die("Failed to read welcome message.");

    if (message->class != SHELL_MESSAGE_WELCOME)
        die("Failed to receive welcome.");
}

static void
receive_keyboard_event(void *data, struct germ *germ)
{
    struct shell_message_keyboard_event *message = data;

    germ_on_input(germ, message->code, message->value);
}

//  Handles SHELL_MESSAGE_BUFFER_UNLOCKED.  The packet contents are unused.
static void
receive_buffer_unlocked(void *data, struct germ *germ)
{
    (void)data;
    on_buffer_unlocked(germ);
}

//  Handles SHELL_MESSAGE_KEYBOARD_DETACHED by logging it to stderr.  The
//  packet contents are unused.
static void
receive_keyboard_detached(void *data)
{
    (void)data;
    fputs("Keyboard detached.\n", stderr);
}

//  Returns bit `i` (0 or 1) of a little-endian-within-byte bitmap: bit i is
//  bit (i % 8) of byte (i / 8).
static int
fetch_bit(const uint8_t *bits, unsigned i)
{
    unsigned byte_index = i / 8;
    unsigned bit_index = i % 8;
    return (bits[byte_index] >> bit_index) & 1;
}

static void
receive_keyboard_attached(void *data, struct germ *germ)
{
    struct shell_message_keyboard_attached *message = data;

    const uint8_t *key_bits = message->key_bits;

    if (fetch_bit(key_bits, KEY_CAPSLOCK))
        germ->keyboard_modifiers |= MOD_CONTROL_MASK;
    else
        germ->keyboard_modifiers &= ~MOD_CONTROL_MASK;

    if (fetch_bit(key_bits, KEY_LEFTALT) || fetch_bit(key_bits, KEY_RIGHTALT))
        germ->keyboard_modifiers |= MOD_ALT_MASK;
    else
        germ->keyboard_modifiers &= ~MOD_ALT_MASK;

    if (fetch_bit(key_bits, KEY_LEFTSHIFT) || fetch_bit(key_bits, KEY_RIGHTSHIFT))
        germ->keyboard_modifiers |= MOD_SHIFT_MASK;
    else
        germ->keyboard_modifiers &= ~MOD_SHIFT_MASK;

    fprintf(stderr, "Keyboard attached.\n");
}

//  Sends a SHELL_MESSAGE_RESET message on `fd`.  The command is exec'd once
//  the shell answers with SHELL_MESSAGE_RESET_COMPLETED (see
//  handle_shell_message).  Retries if interrupted by a signal, like the
//  read loops in this file.  Dies unless the whole message is written.
static void
send_reset(int fd)
{
    struct shell_message_reset message = {
        .class = SHELL_MESSAGE_RESET,
    };

    ssize_t r;
    do {
        r = write(fd, &message, sizeof(message));
    } while (r == -1 && errno == EINTR);
    if (r != (ssize_t)sizeof(message))
        die("Failed to write reset message.");
}

//  Sends a SHELL_MESSAGE_SHOW message on `fd` after a frame has been drawn.
//  Retries if interrupted by a signal, like the read loops in this file.
//  Dies unless the whole message is written.
static void
send_show(int fd)
{
    struct shell_message_show message = {
        .class = SHELL_MESSAGE_SHOW,
    };

    ssize_t r;
    do {
        r = write(fd, &message, sizeof(message));
    } while (r == -1 && errno == EINTR);
    if (r != (ssize_t)sizeof(message))
        die("Failed to write show message.");
}

//  Replaces the process image with the command on the prompt line.
//
//  germ->command_buffer is split in place into maximal runs of non-space
//  characters (each space is overwritten with NUL); the runs become argv for
//  execvp.  Never returns: dies if the line has no words or execvp fails.
static _Noreturn void
germ_exec_command(struct germ *germ)
{
    char *line = germ->command_buffer;

    //  Count the words to size argv.
    int argc = 0;
    for (int i = 0, in_word = 0; line[i] != '\0'; i++) {
        if (line[i] == ' ') {
            in_word = 0;
        } else if (!in_word) {
            argc++;
            in_word = 1;
        }
    }
    if (argc == 0)
        die("Failed to launch client.");

    //  Split the line in place and collect a pointer to each word.
    char *argv[argc + 1];
    argc = 0;
    for (int i = 0, in_word = 0; line[i] != '\0'; i++) {
        if (line[i] == ' ') {
            line[i] = '\0';
            in_word = 0;
        } else if (!in_word) {
            argv[argc++] = line + i;
            in_word = 1;
        }
    }
    argv[argc] = NULL;

    execvp(argv[0], argv);
    die("Failed to launch client.");
}

//  Handles readiness `events` on the shell socket `fd`.  A hang-up exits the
//  process with status 0.  Otherwise one packet is read and dispatched on
//  its leading 16-bit message class.
static void
handle_shell_message(int fd, struct germ *germ, uint32_t events)
{
    if (events & EPOLLHUP)
        exit(0);

    //  epoll always reports EPOLLERR even if it was not requested.  Treat it
    //  like EPOLLIN so the read below surfaces the error, instead of the
    //  event loop waking up forever with nothing to do.
    if (!(events & (EPOLLIN | EPOLLERR)))
        return;

    //  Aligned to max_align_t so the receive_* helpers can view the packet
    //  as a message struct.
    static _Alignas(max_align_t) char packet[SHELL_PACKET_SIZE_MAX];
    ssize_t r;
    do {
        r = read(fd, packet, sizeof(packet));
    } while (r == -1 && errno == EINTR);
    //  A packet must at least contain its 16-bit class.  A shorter read would
    //  otherwise dispatch on stale bytes left by an earlier packet.
    if (r < (ssize_t)sizeof(uint16_t))
        die("Failed to read from shell socket.");

    //  memcpy avoids the strict-aliasing violation of *(uint16_t *)packet.
    uint16_t class;
    memcpy(&class, packet, sizeof(class));

    switch (class) {
    case SHELL_MESSAGE_RESET_COMPLETED:
        germ_exec_command(germ);
        break;
    case SHELL_MESSAGE_KEYBOARD_EVENT:
        receive_keyboard_event(packet, germ);
        break;
    case SHELL_MESSAGE_BUFFER_UNLOCKED:
        receive_buffer_unlocked(packet, germ);
        break;
    case SHELL_MESSAGE_KEYBOARD_DETACHED:
        receive_keyboard_detached(packet);
        break;
    case SHELL_MESSAGE_KEYBOARD_ATTACHED:
        receive_keyboard_attached(packet, germ);
        break;
    default:
        //  Unknown message classes are ignored.
        break;
    }
}

static struct germ germ;

static void
germ_init(void)
{
    init_metrics();

    uint32_t width, height;
    {
        struct shell_message_welcome message;
        receive_welcome(3, &message);
        width = message.width;
        height = message.height;
    }

    size_t mem_size = 8 * width * height;
    void *mem = mmap(NULL, mem_size, PROT_READ|PROT_WRITE, MAP_SHARED, 4, 0);
    if (mem == MAP_FAILED)
        die("Failed to create draw buffers for child process.");

    germ_launch(&germ, width, height, mem);

    germ_io_add(3, EPOLLIN);

    germ.line_size =
        (width - metrics.border_size_left - metrics.border_size_right) /
            metrics.advance;
    germ.command_buffer = malloc_or_die(germ.line_size);
    germ.command_buffer[0] = '\0';
    germ.cursor = 0;

    germ_draw(&germ);
}

//  Blocks until the epoll set reports readiness (retrying if a signal
//  interrupts the wait), then hands the events for fd 3 to
//  handle_shell_message.  Dies on any other epoll_wait error.
static void
germ_interact(void)
{
    struct epoll_event ready[2];
    int count;

    do {
        count = epoll_wait(germ.epoll_fd, ready, array_size(ready), -1);
    } while (count == -1 && errno == EINTR);
    if (count == -1)
        die("epoll_wait: %m");

    for (int k = 0; k < count; k++) {
        if (ready[k].data.fd != 3)
            continue;
        handle_shell_message(3, &germ, ready[k].events);
    }
}

//  Registers `fd` with the client's epoll instance for `events`, tagging the
//  registration with the fd itself.  Dies if epoll_ctl fails.
static void
germ_io_add(int fd, uint32_t events)
{
    struct epoll_event registration = { .events = events };
    registration.data.fd = fd;

    if (epoll_ctl(germ.epoll_fd, EPOLL_CTL_ADD, fd, &registration) == -1)
        die("epoll_ctl: %m");
}