#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <ctype.h>
#include <fcntl.h>
#include <stdint.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <unistd.h>

#include "runtime.h"

// Abort the running program with exit status 1. Used as the universal
// "this cannot happen / invalid operand" trap throughout the runtime.
_Noreturn void
halt(void)
{
    const int failure_status = 1;
    exit(failure_status);
}

_Noreturn void
die(const char *s)
{
    fprintf(stderr, "Fatal: %s\n", s);
    halt();
}

// malloc() wrapper that never returns NULL: on allocation failure the
// program is terminated through halt().
static void *
xmalloc(size_t size)
{
    void *block = malloc(size);
    if (!block)
        halt();
    return block;
}

// True when the tag bits of `x` (selected by TAG_MASK) equal `tag`.
_Bool
has_tag(struct value x, uint32_t tag)
{
    return tag == (x.bits & TAG_MASK);
}

// True when `x` is a boxed integer. number() stores n as 2*n, so integers
// are exactly the values whose lowest bit is clear.
_Bool
is_number(struct value x)
{
    return !(x.bits & 1);
}

// Translate a reference value into a pointer into the heap.
//
// reference_value() builds references as (offset << 4) | tag, so shifting
// out the low 4 tag bits recovers the offset from heap->values. The tag is
// NOT checked here; callers in this file check it with has_tag() first.
// NOTE(review): heap->values is presumably a byte pointer, since the
// allocators advance heap->top by sizeof(...) bytes - confirm in runtime.h.
void *
address(struct heap *heap, struct value x)
{
    return heap->values + (x.bits >> 4);
}

// Build a reference value for the heap object at byte `offset`: the offset
// occupies the bits above the 4-bit `tag`.
struct value
reference_value(uint32_t offset, uint32_t tag)
{
    uint32_t bits = offset << 4;
    bits |= tag;
    return value(bits);
}

// Recover the integer n from a value produced by number() (stored as 2*n).
// Halts if `x` is not a number.
int32_t
value_unbox_int32(struct value x)
{
    if (is_number(x)) {
        int32_t doubled = (int32_t)x.bits;
        return doubled / 2;
    }
    halt();
}

// Immediate (heap-free) encoding of `label` applied to the empty tuple:
// the label sits in the upper 16 bits, tagged SECOND_TAG_LABELED|TAG_IMMED.
struct value
labeled_empty_tuple(uint16_t label)
{
    struct value x = { .bits = label };
    x.bits = (x.bits << 16) | SECOND_TAG_LABELED | TAG_IMMED;
    return x;
}

// Allocate a tuple of `size` entries copied from `values`. A tuple is a
// module whose index_begin is UINT16_MAX (no label index), which is what
// open_tuple() and tuple_fetch() check for.
struct value
alloc_tuple(struct heap *heap, struct value *values, uint16_t size)
{
    const uint16_t no_label_index = UINT16_MAX;
    return alloc_module(heap, values, size, no_label_index);
}

// Wrap `value` under `label`.
//
// A label applied to the empty tuple needs no heap storage; it uses the
// immediate encoding produced by labeled_empty_tuple() (previously this was
// duplicated inline here). Any other payload gets a struct labeled_value
// bump-allocated at heap->top and is returned as a TAG_LABELED reference.
struct value
alloc_labeled_value(struct heap *heap, uint16_t label, struct value value)
{
    if (value.bits == empty_tuple.bits)
        return labeled_empty_tuple(label);

    struct labeled_value *labeled_value =
        (struct labeled_value *)(heap->values + heap->top);
    struct value x = reference_value(heap->top, TAG_LABELED);
    heap->top += sizeof(struct labeled_value);
    labeled_value->label = label;
    labeled_value->value = value;
    return x;
}

uint16_t
value_label(struct heap *heap, struct value x)
{
    if (has_tag(x, TAG_IMMED)) {
        uint16_t label = (uint16_t)(x.bits >> 16);
        uint32_t second_tag = x.bits & 0xfff0;
        if (second_tag != SECOND_TAG_LABELED) halt();
        return label;
    }
    if (!has_tag(x, TAG_LABELED)) halt();
    struct labeled_value *value = (struct labeled_value *)address(heap, x);
    return value->label;
}

void
open_labeled_value(struct heap *heap, struct value *frame, struct value x,
        int n)
{
    if (has_tag(x, TAG_IMMED)) {
        uint32_t second_tag = x.bits & 0xfff0;
        if (second_tag != SECOND_TAG_LABELED) halt();
        if (n != 0) halt();
        return;
    }
    if (!has_tag(x, TAG_LABELED)) halt();
    struct labeled_value *value = (struct labeled_value *)address(heap, x);
    if (n == 1) {
        frame[0] = value->value;
        return;
    }
    if (!has_tag(value->value, TAG_MODULE)) halt();
    struct module *module = (struct module *)address(heap, value->value);
    if (module->size != n) halt();
    memmove(frame, &module->entries, n*sizeof(struct value));
}

// Strip the label from a labeled value and return its payload; immediate
// labeled values yield empty_tuple. Halts if `x` is not labeled.
struct value
remove_label(struct heap *heap, struct value x)
{
    if (!has_tag(x, TAG_IMMED)) {
        if (!has_tag(x, TAG_LABELED)) halt();
        return ((const struct labeled_value *)address(heap, x))->value;
    }
    if ((x.bits & 0xfff0) != SECOND_TAG_LABELED) halt();
    return empty_tuple;
}

void
open_tuple(struct heap *heap, struct value *frame, struct value x, int n)
{
    if (x.bits == empty_tuple.bits) {
        if (n != 0) halt();
        return;
    }
    if (!has_tag(x, TAG_MODULE)) halt();
    struct module *tuple = (struct module *)address(heap, x);
    if (tuple->index_begin != UINT16_MAX) halt();
    if (tuple->size != n) halt();
    memmove(frame, &tuple->entries, n*sizeof(struct value));
}

struct value
alloc_closure(struct heap *heap, int8_t num_params,
        uint32_t env_size, int32_t code_offset)
{
    struct closure *c = (struct closure *)(heap->values + heap->top);
    struct value x = reference_value(heap->top, TAG_CLOSURE);
    heap->top += sizeof(struct closure) + sizeof(struct value)*env_size;
    c->num_params = num_params;
    c->env_size = env_size;
    c->code_offset = code_offset;
    return x;
}

// Store `value` into captured-variable slot `i` of closure `x`.
// Halts if `x` is not a closure or `i` is outside its environment.
void
closure_store(struct heap *heap, struct value x, uint32_t i,
        struct value value)
{
    if (!has_tag(x, TAG_CLOSURE)) halt();
    struct closure *closure = (struct closure *)address(heap, x);
    // alloc_closure() reserves exactly env_size slots; an index past that
    // would overwrite whatever object was allocated after the closure.
    if (i >= closure->env_size) halt();
    closure->free_values[i] = value;
}

// Length in bytes (excluding the NUL padding) of string `s`; halts if `s`
// is not a string.
static uint32_t
string_size(struct heap *heap, struct value s)
{
    if (has_tag(s, TAG_STRING))
        return ((const struct string *)address(heap, s))->size;
    halt();
}

// Pointer to the first byte of string `s`'s contents; halts if `s` is not
// a string.
static uint8_t *
string_bytes(struct heap *heap, struct value s)
{
    if (!has_tag(s, TAG_STRING)) halt();
    struct string *str = address(heap, s);
    return str->bytes;
}

// string_bytes() viewed as a C string. Strings built by alloc_string() are
// always NUL-terminated.
static const char *
string_chars(struct heap *heap, struct value s)
{
    const uint8_t *bytes = string_bytes(heap, s);
    return (const char *)bytes;
}

// Copy the NUL-terminated C string `s` into a new heap string object.
//
// The payload is rounded up to a multiple of 4 bytes and always includes
// at least one trailing zero byte, so string_chars() can hand the bytes to
// C string functions directly. Because the length comes from strlen(),
// any bytes after an embedded NUL in `s` are not copied.
struct value
alloc_string(struct heap *heap, const char *s)
{
    size_t text_size = strlen(s);
    size_t object_size = sizeof(struct string) + 4*(text_size/4 + 1);
    struct string *string = (struct string *)(heap->values + heap->top);
    struct value x = reference_value(heap->top, TAG_STRING);
    heap->top += object_size;
    string->size = text_size;
    // Use a byte pointer: arithmetic on `void *` is a GNU extension, not ISO C.
    uint8_t *payload = (uint8_t *)string + sizeof(struct string);
    memmove(payload, s, text_size);
    memset(payload + text_size, 0,
            object_size - (sizeof(struct string) + text_size));
    return x;
}

// Box integer `n` as the value 2*n. Only integers in
// [INT32_MIN/2, INT32_MAX/2] are representable; anything else halts.
struct value
number(int32_t n)
{
    const int32_t lo = INT32_MIN / 2;
    const int32_t hi = INT32_MAX / 2;
    if (lo <= n && n <= hi)
        return value(2*n);
    halt();
}

struct value
prim_die(struct heap *heap, struct value s)
{
    if (!has_tag(s, TAG_STRING)) halt();
    struct string *string = (struct string *)address(heap, s);
    die((const char *)string->bytes);
    return empty_tuple;
}

struct value
prim_print_line(struct heap *heap, struct value s)
{
    if (!has_tag(s, TAG_STRING)) halt();
    struct string *string = (struct string *)address(heap, s);
    printf("%s\n", string->bytes);
    return empty_tuple;
}

struct value
prim_file_create(struct heap *heap, struct value name)
{
    FILE *stream = fopen(string_chars(heap, name), "w");
    if (stream == NULL) halt();
    struct file *heap_file = (struct file *)(heap->values + heap->top);
    struct value x = reference_value(heap->top, TAG_FILE);
    heap->top += sizeof(struct file);
    heap_file->stream = stream;
    return x;
}

struct value
prim_file_open(struct heap *heap, struct value name)
{
    FILE *stream = fopen(string_chars(heap, name), "r");
    if (stream == NULL) halt();
    struct file *heap_file = (struct file *)(heap->values + heap->top);
    struct value x = reference_value(heap->top, TAG_FILE);
    heap->top += sizeof(struct file);
    heap_file->stream = stream;
    return x;
}

// Primitive: close the stream behind file handle `file`; returns the empty
// tuple. Halts if `file` is not a file, is already closed, or if fclose()
// reports an error.
struct value
prim_file_close(struct heap *heap, struct value file)
{
    if (!has_tag(file, TAG_FILE)) halt();
    struct file *heap_file = (struct file *)address(heap, file);
    // A second close would pass fclose() a dangling FILE *; streams from
    // fopen() are never NULL, so NULL marks "already closed".
    if (heap_file->stream == NULL) halt();
    int rc = fclose(heap_file->stream);
    heap_file->stream = NULL;
    // fclose() flushes buffered output; failure means written data may be
    // lost, which prim_file_write() likewise treats as fatal.
    if (rc != 0) halt();
    return empty_tuple;
}

// Primitive: read the contents of file handle `file` into a new string.
//
// The byte count comes from fstat(), so this expects a regular file.
// NOTE(review): pipes/ttys usually report st_size 0 and would yield "".
// Because alloc_string() uses strlen(), contents after an embedded NUL are
// dropped.
struct value
prim_file_read_all(struct heap *heap, struct value file)
{
    if (!has_tag(file, TAG_FILE)) halt();
    struct file *heap_file = (struct file *)address(heap, file);
    struct stat statbuf;
    if (-1 == fstat(fileno(heap_file->stream), &statbuf)) halt();
    // Reject sizes that are negative or cannot fit, with the NUL terminator,
    // in a 32-bit length (this also keeps `size + 1` from wrapping).
    if (statbuf.st_size < 0 || (uintmax_t)statbuf.st_size >= UINT32_MAX) halt();
    size_t size = (size_t)statbuf.st_size;
    char *source = xmalloc(size + 1);
    // Count bytes rather than one `size`-byte item: for an empty file the
    // latter returns 0 and would be mistaken for a failed read.
    if (fread(source, 1, size, heap_file->stream) != size) halt();
    source[size] = 0;
    struct value s = alloc_string(heap, source);
    free(source);
    return s;
}

// Primitive: write the bytes of string `str` to file handle `file`;
// returns the empty tuple. Halts on a type mismatch or a short write.
struct value
prim_file_write(struct heap *heap, struct value file, struct value str)
{
    if (!has_tag(file, TAG_FILE)) halt();
    if (!has_tag(str, TAG_STRING)) halt();
    struct file *heap_file = (struct file *)address(heap, file);
    uint32_t size = string_size(heap, str);
    // Count bytes rather than one `size`-byte item: writing one item of
    // size 0 returns 0, which would make writing "" look like an error.
    size_t written = fwrite(string_chars(heap, str), 1, size,
            heap_file->stream);
    if (written != size) halt();
    return empty_tuple;
}

// Primitive: render number `n` in decimal as a new string. Halts if `n` is
// not a number.
struct value
prim_show_integer(struct heap *heap, struct value n)
{
    if (!is_number(n)) halt();
    // Any int32_t needs at most 11 characters ("-2147483648") plus the NUL.
    char text[16];
    int len = snprintf(text, sizeof text, "%d", (int)value_unbox_int32(n));
    // Check the negative (encoding error) case explicitly instead of relying
    // on a signed/unsigned comparison to convert it into a huge value.
    if (len < 0 || (size_t)len >= sizeof text) halt();
    return alloc_string(heap, text);
}

// Primitive: expose the raw bit pattern of `x` as a number. Since number()
// only accepts [INT32_MIN/2, INT32_MAX/2], this halts for bit patterns
// outside that range. NOTE(review): that covers most heap references -
// confirm this is intended.
struct value
prim_bits(struct heap *heap, struct value x)
{
    (void)heap;
    return number((int32_t)x.bits);
}

// Primitive: multiply two numbers. Halts if either operand is not a number
// or the product is outside the boxable range [INT32_MIN/2, INT32_MAX/2].
struct value
prim_multiply(struct heap *heap, struct value a, struct value b)
{
    // Unboxed operands lie in [INT32_MIN/2, INT32_MAX/2], so their product
    // always fits in 64 bits, but multiplying in int32_t could overflow
    // (undefined behaviour) before number() ever sees the result.
    int64_t product = (int64_t)value_unbox_int32(a) * value_unbox_int32(b);
    if (product < INT32_MIN / 2 || product > INT32_MAX / 2) halt();
    return number((int32_t)product);
}

// Primitive: add two numbers. Halts if either operand is not a number or
// the sum is not boxable.
struct value
prim_add(struct heap *heap, struct value a, struct value b)
{
    int32_t lhs = value_unbox_int32(a);
    int32_t rhs = value_unbox_int32(b);
    // Both operands lie in [INT32_MIN/2, INT32_MAX/2], so the int32_t sum
    // cannot overflow; number() halts if it leaves the boxable range.
    return number(lhs + rhs);
}

// Primitive: arithmetic negation. Halts if `n` is not a number or the
// result is not boxable.
struct value
prim_negate(struct heap *heap, struct value n)
{
    int32_t operand = value_unbox_int32(n);
    // operand >= INT32_MIN/2, so negating cannot overflow int32_t; for
    // operand == INT32_MIN/2 the result exceeds INT32_MAX/2 and number() halts.
    return number(-operand);
}

// Primitive: numeric equality. Both operands must be numbers (else halt);
// boxed numbers are equal exactly when their bit patterns are.
struct value
prim_equal(struct heap *heap, struct value a, struct value b)
{
    if (is_number(a) && is_number(b))
        return a.bits == b.bits ? true_value : false_value;
    halt();
}

// Primitive: a < b on numbers. Like prim_equal(), non-number operands halt
// instead of being compared as raw tagged bits.
struct value
prim_less(struct heap *heap, struct value a, struct value b)
{
    if (!is_number(a) || !is_number(b)) halt();
    // Numbers are stored as 2*n, so the signed bit patterns order like n.
    if ((int32_t)a.bits < (int32_t)b.bits)
        return true_value;
    return false_value;
}

// Primitive: a <= b on numbers. Like prim_equal(), non-number operands halt
// instead of being compared as raw tagged bits.
struct value
prim_less_or_equal(struct heap *heap, struct value a, struct value b)
{
    if (!is_number(a) || !is_number(b)) halt();
    // Numbers are stored as 2*n, so the signed bit patterns order like n.
    if ((int32_t)a.bits <= (int32_t)b.bits)
        return true_value;
    return false_value;
}

// Primitive: a > b on numbers. Like prim_equal(), non-number operands halt
// instead of being compared as raw tagged bits.
struct value
prim_greater(struct heap *heap, struct value a, struct value b)
{
    if (!is_number(a) || !is_number(b)) halt();
    // Numbers are stored as 2*n, so the signed bit patterns order like n.
    if ((int32_t)a.bits > (int32_t)b.bits)
        return true_value;
    return false_value;
}

// Primitive: a >= b on numbers. Like prim_equal(), non-number operands halt
// instead of being compared as raw tagged bits.
struct value
prim_greater_or_equal(struct heap *heap, struct value a, struct value b)
{
    if (!is_number(a) || !is_number(b)) halt();
    // Numbers are stored as 2*n, so the signed bit patterns order like n.
    if ((int32_t)a.bits >= (int32_t)b.bits)
        return true_value;
    return false_value;
}

// Primitive: byte length of string `s` as a number.
// TODO Which is it, "size" or "length"? Pick one name across the runtime.
struct value
prim_string_length(struct heap *heap, struct value s)
{
    uint32_t byte_count = string_size(heap, s);
    return number(byte_count);
}

// Primitive: the byte at index `i` of string `s`, as a number in 0..255.
// Halts if `s` is not a string, `i` is not a number, or `i` is out of range.
struct value
prim_string_fetch(struct heap *heap, struct value s, struct value i)
{
    const uint8_t *bytes = string_bytes(heap, s);
    int32_t ii = value_unbox_int32(i);
    if (ii < 0 || (uint32_t)ii >= string_size(heap, s)) halt();
    // Read through uint8_t: via plain `char`, bytes >= 0x80 would come out
    // negative on platforms where char is signed.
    return number(bytes[ii]);
}

// Primitive: three-way comparison of two strings, returning the number -1,
// 0 or 1 when s1 sorts before, equal to, or after s2.
//
// strcmp() compares bytes as unsigned char (the same ordering on every
// platform, unlike a loop over plain `char`), but it only guarantees the
// *sign* of its result. Callers of this primitive get exactly -1/0/1, so
// the result is normalised here.
struct value
prim_string_compare(struct heap *heap, struct value s1, struct value s2)
{
    int c = strcmp(string_chars(heap, s1), string_chars(heap, s2));
    return number((c > 0) - (c < 0));
}

// Primitive: true when strings s1 and s2 have identical contents.
struct value
prim_string_equal(struct heap *heap, struct value s1, struct value s2)
{
    const char *lhs = string_chars(heap, s1);
    const char *rhs = string_chars(heap, s2);
    return strcmp(lhs, rhs) == 0 ? true_value : false_value;
}

// Primitive: concatenate strings s1 and s2 into a new string. Halts if
// either is not a string or the combined length does not fit in 32 bits.
struct value
prim_string_append(struct heap *heap, struct value s1, struct value s2)
{
    uint32_t s1_size = string_size(heap, s1);
    uint32_t s2_size = string_size(heap, s2);
    // Without this check `s1_size + s2_size` (or the +1 for the NUL) could
    // wrap, under-allocating `fresh` and overflowing it below.
    if (s2_size >= UINT32_MAX - s1_size) halt();
    uint32_t size = s1_size + s2_size;
    char *fresh = xmalloc((size_t)size + 1);
    memcpy(fresh, string_bytes(heap, s1), s1_size);
    // Copy s2 along with its terminator: alloc_string() always pads the
    // payload with at least one zero byte.
    memcpy(fresh + s1_size, string_bytes(heap, s2), (size_t)s2_size + 1);
    struct value s = alloc_string(heap, fresh);
    free(fresh);
    return s;
}

// Primitive: the substring of `s` covering byte range [begin, end).
//
// Requires 0 <= begin <= end <= length(s); otherwise halts. An empty clip
// at the very end of the string (begin == end == length), including any
// clip of "", is valid and yields "". The result is laid out like
// alloc_string() output: payload padded to a multiple of 4 bytes with at
// least one trailing NUL.
struct value
prim_string_clip(struct heap *heap, struct value s, struct value begin,
        struct value end)
{
    if (!has_tag(s, TAG_STRING)) halt();
    if (!is_number(begin)) halt();
    if (!is_number(end)) halt();
    int32_t b = value_unbox_int32(begin);
    int32_t e = value_unbox_int32(end);
    if (b < 0 || e < b) halt();
    // Since 0 <= b <= e, bounding e also bounds b. (Previously `b >= size`
    // was rejected, which made empty clips at the end of a string halt.)
    uint32_t s_size = string_size(heap, s);
    if ((uint32_t)e > s_size) halt();
    size_t text_size = (size_t)(e - b);
    size_t object_size = sizeof(struct string) + 4*(text_size/4 + 1);
    struct string *string = (struct string *)(heap->values + heap->top);
    struct value x = reference_value(heap->top, TAG_STRING);
    heap->top += object_size;
    string->size = text_size;
    // Byte pointer instead of GNU-only arithmetic on `void *`.
    uint8_t *payload = (uint8_t *)string + sizeof(struct string);
    const char *s_chars = string_chars(heap, s);
    memmove(payload, s_chars + b, text_size);
    memset(payload + text_size, 0,
            object_size - (sizeof(struct string) + text_size));
    return x;
}

struct value
prim_ref_new(struct heap *heap, struct value x)
{
    struct ref *ref = (struct ref *)(heap->values + heap->top);
    struct value r = reference_value(heap->top, TAG_REF);
    heap->top += sizeof(struct ref);
    ref->x = x;
    return r;
}

// Primitive: overwrite the contents of reference cell `r` with `x` and
// return `x`. Halts if `r` is not a reference.
struct value
prim_ref_store(struct heap *heap, struct value r, struct value x)
{
    if (has_tag(r, TAG_REF)) {
        ((struct ref *)address(heap, r))->x = x;
        return x;
    }
    halt();
}

// Primitive: read the current contents of reference cell `r`. Halts if `r`
// is not a reference.
struct value
prim_ref_fetch(struct heap *heap, struct value r)
{
    if (has_tag(r, TAG_REF))
        return ((const struct ref *)address(heap, r))->x;
    halt();
}

struct value
alloc_module(struct heap *heap, struct value *values, uint16_t size, uint16_t index_begin)
{
    struct module *module = (struct module *)(heap->values + heap->top);
    struct value x = reference_value(heap->top, TAG_MODULE);
    heap->top += sizeof(struct module) + size*sizeof(struct value);
    module->size = size;
    module->index_begin = index_begin;
    for (int i = 0; i < size; i++)
        module->entries[i] = values[i];
    return x;
}

// Store `value` into entry `i` of module/tuple `x`. Halts if `x` is not a
// module or `i` is not below its size.
void
module_store(struct heap *heap, struct value x, uint16_t i, struct value value)
{
    if (!has_tag(x, TAG_MODULE)) halt();
    struct module *module = (struct module *)address(heap, x);
    // alloc_module() reserves exactly `size` entries; writing beyond them
    // would clobber the next heap object.
    if (i >= module->size) halt();
    module->entries[i] = value;
}

// Read positional entry `i` of tuple `x`. Halts if `x` is not a tuple
// (a module with index_begin == UINT16_MAX) or `i` is out of range.
struct value
tuple_fetch(struct heap *heap, struct value x, uint16_t i)
{
    if (!has_tag(x, TAG_MODULE)) halt();
    struct module *module = (struct module *)address(heap, x);
    if (module->index_begin != UINT16_MAX) halt();
    // Reject reads past the tuple's entries into adjacent heap memory.
    if (i >= module->size) halt();
    return module->entries[i];
}

struct value
module_fetch(struct heap *heap, struct value x, uint16_t label)
{
    if (!has_tag(x, TAG_MODULE)) halt();
    struct module *module = (struct module *)address(heap, x);
    if (module->index_begin == UINT16_MAX) halt();
    uint16_t begin = module->index_begin;
    for (uint16_t i = begin; module_indices[i] != UINT16_MAX; i++) {
        if (module_indices[i] == label)
            return module->entries[i-begin];
    }
    halt();
}

// Program entry point: hand the command line to the virtual machine. If
// run_machine() returns, the program exits successfully.
int
main(int argc, const char *argv[])
{
    run_machine(argc, argv);
    return EXIT_SUCCESS;
}