#!/usr/bin/env python3
from enum import Enum, auto
from dataclasses import dataclass
import sys
import os
import subprocess
import tempfile 

import shutil
# Compiler version string; printed by the --version command-line flag.
VERSION = "0.2"
def read_file(file_path: str) -> str:
    """Return the entire contents of the text file at *file_path*.

    The file is decoded as UTF-8 explicitly rather than with the platform's
    default encoding, so the same source file reads identically on every
    machine.

    Raises:
        OSError: If the file cannot be opened or read.
        UnicodeDecodeError: If the file is not valid UTF-8.
    """
    with open(file_path, encoding="utf-8") as f:
        return f.read()


class TokenKind(Enum):
    """Every kind of token the lexer can produce.

    Member values come from ``auto()``, so they depend on declaration order.
    """

    NUMBER = auto()  # run of numeric characters
    STRING = auto()  # double-quoted string literal
    IDENT = auto()  # name that is not a keyword
    RETURN = auto()  # keyword "return"
    LPAREN = auto()  # "("
    RPAREN = auto()  # ")"
    LBRACE = auto()  # "{"
    RBRACE = auto()  # "}"
    SEMICOLON = auto()  # ";"
    EOF = auto()  # end-of-input marker appended after the last real token
    PLUS = auto()  # "+"
    MINUS = auto()  # "-"
    STAR = auto()  # "*" (multiplication, and dereference in the parser)
    SLASH = auto()  # "/"
    EQUAL = auto()  # "=" (assignment)
    LT = auto()  # "<"
    GT = auto()  # ">"
    LE = auto()  # "<="
    GE = auto()  # ">="
    EQEQ = auto()  # "=="
    NEQ = auto()  # "!="
    IF = auto()  # keyword "if"
    ELSE = auto()  # keyword "else"
    WHILE = auto()  # keyword "while"
    EXTRN = auto()  # keyword "extrn"
    COMMA = auto()  # ","
    REFOF = auto()  # "&" (address-of)

# Reserved words mapped to the token kind the lexer emits for them.  Each
# keyword's TokenKind member is simply the keyword spelled in upper case.
KW_TABLE = {
    word: TokenKind[word.upper()]
    for word in ("return", "if", "else", "while", "extrn")
}


@dataclass()
class Token:
    """A single lexical token together with the source position of its start."""

    kind: TokenKind  # category of the token
    value: str  # token text; string contents without the quotes, "" for EOF
    line: int  # line recorded by the lexer (counting starts at 1)
    col: int  # column recorded by the lexer (counting starts at 1)


class CompileError(Exception):
    """Compilation failure attached to a position in a source file.

    The exception text (and ``args[0]``) has the form
    ``file:line:col: error: message``; the individual parts stay available
    as the ``message``, ``file``, ``line`` and ``col`` attributes.
    """

    def __init__(self, message, file, line, col):
        self.message, self.file = message, file
        self.line, self.col = line, col
        # The attributes must be set before formatting, which reads them.
        super().__init__(str(self))

    def __str__(self):
        location = "{}:{}:{}".format(self.file, self.line, self.col)
        return "{}: error: {}".format(location, self.message)


def lex_file(contents: str, file: str = "<input>") -> list[Token]:
    """Split source text into tokens, ending with a single EOF token.

    Args:
        contents: Complete source text to tokenize.
        file: Name reported in lexer error messages.  Optional, so existing
            single-argument callers keep working.

    Returns:
        The tokens in source order.  Each token records the 1-based line and
        column at which it starts; the last token is always ``EOF``.

    Raises:
        CompileError: If a string literal is still open at end of input.

    Notes:
        * ``//`` starts a comment that runs to the end of the line.
        * String literals have no escape sequences: everything up to the
          next ``"`` is taken verbatim (newlines included).
        * Identifiers start with a letter or ``_`` and continue with
          letters, digits or ``_``.
        * Characters that start no known token (including a lone ``!``) are
          skipped; they still advance the column counter so that later
          tokens report the right position.
    """
    single_char = {
        "(": TokenKind.LPAREN,
        ")": TokenKind.RPAREN,
        ",": TokenKind.COMMA,
        "&": TokenKind.REFOF,
        "{": TokenKind.LBRACE,
        "}": TokenKind.RBRACE,
        ";": TokenKind.SEMICOLON,
        "+": TokenKind.PLUS,
        "-": TokenKind.MINUS,
        "*": TokenKind.STAR,
        "/": TokenKind.SLASH,
        "=": TokenKind.EQUAL,
        ">": TokenKind.GT,
        "<": TokenKind.LT,
    }
    two_char = {
        "==": TokenKind.EQEQ,
        ">=": TokenKind.GE,
        "<=": TokenKind.LE,
        "!=": TokenKind.NEQ,
    }

    tokens = []
    length = len(contents)
    i = 0
    line = 1
    col = 1

    while i < length:
        ch = contents[i]

        if ch == "\n":
            line += 1
            col = 1
            i += 1
            continue

        if ch.isspace():
            i += 1
            col += 1
            continue

        if ch == '"':
            start_line, start_col = line, col
            i += 1
            col += 1  # the opening quote occupies a column too
            start = i
            while i < length and contents[i] != '"':
                if contents[i] == "\n":
                    line += 1
                    col = 1
                else:
                    col += 1
                i += 1
            if i >= length:
                raise CompileError(
                    "unterminated string literal", file, start_line, start_col
                )
            tokens.append(
                Token(TokenKind.STRING, contents[start:i], start_line, start_col)
            )
            i += 1  # closing quote
            col += 1
            continue

        if ch.isnumeric():
            start = i
            # Bounds check first: a number may be the last thing in the input.
            while i < length and contents[i].isnumeric():
                i += 1
            tokens.append(Token(TokenKind.NUMBER, contents[start:i], line, col))
            col += i - start
            continue

        if ch.isalpha() or ch == "_":
            start = i
            while i < length and (contents[i].isalnum() or contents[i] == "_"):
                i += 1
            word = contents[start:i]
            tokens.append(Token(KW_TABLE.get(word, TokenKind.IDENT), word, line, col))
            col += i - start
            continue

        if contents.startswith("//", i):
            # Stop on the newline itself so the main loop does the line
            # bookkeeping; a comment on the last line may have no newline.
            newline_at = contents.find("\n", i)
            i = length if newline_at == -1 else newline_at
            continue

        pair = contents[i:i + 2]
        if pair in two_char:
            tokens.append(Token(two_char[pair], pair, line, col))
            i += 2
            col += 2
            continue

        if ch in single_char:
            tokens.append(Token(single_char[ch], ch, line, col))
            i += 1
            col += 1
            continue

        # Unrecognised character: skipped, but it still occupies a column.
        i += 1
        col += 1

    tokens.append(Token(TokenKind.EOF, "", line, col))
    return tokens


class Parser:
    """Recursive-descent parser that turns a token list into a dict-based AST.

    Every AST node is a plain ``dict`` with a ``"type"`` key.  Binary
    operators bind, from loosest to tightest: equality (``== !=``),
    comparison (``< <= > >=``), additive (``+ -``), multiplicative
    (``* /``); below those come unary minus and primary expressions.

    All syntax errors are raised as ``CompileError`` carrying the position
    of the offending token.
    """

    # Operator token -> operator string stored in "binary" nodes, one table
    # per precedence level.
    _EQ_OPS = {TokenKind.EQEQ: "==", TokenKind.NEQ: "!="}
    _CMP_OPS = {
        TokenKind.LT: "<",
        TokenKind.LE: "<=",
        TokenKind.GT: ">",
        TokenKind.GE: ">=",
    }
    _ADD_OPS = {TokenKind.PLUS: "+", TokenKind.MINUS: "-"}
    _MUL_OPS = {TokenKind.STAR: "*", TokenKind.SLASH: "/"}

    def __init__(self, tokens: list[Token], file):
        """Store the token stream and the file name used in error messages.

        The parsing loops stop on an EOF token, so *tokens* is expected to
        end with one.
        """
        self.tokens = tokens
        self.file = file
        self.i = 0

    def peek(self) -> Token:
        """Return the current token without consuming it."""
        return self.tokens[self.i]

    def advance(self) -> Token:
        """Consume and return the current token."""
        tok = self.peek()
        self.i += 1
        return tok

    def expect(self, kind: TokenKind) -> Token:
        """Consume and return the current token, which must be of *kind*.

        Raises:
            CompileError: If the current token has a different kind.
        """
        tok = self.peek()
        if tok.kind != kind:
            raise self._error(f"expected {kind.name}, got {tok.kind.name}", tok)
        return self.advance()

    def _error(self, message, tok: Token) -> CompileError:
        """Build a CompileError located at *tok*."""
        return CompileError(message, self.file, tok.line, tok.col)

    def _parse_binary(self, parse_operand, ops) -> dict:
        """Parse a left-associative chain ``operand (op operand)*``.

        *ops* maps the operator token kinds of this precedence level to the
        operator string stored in the resulting "binary" nodes.
        """
        left = parse_operand()
        while self.peek().kind in ops:
            op_tok = self.advance()
            right = parse_operand()
            left = {
                "type": "binary",
                "op": ops[op_tok.kind],
                "left": left,
                "right": right,
            }
        return left

    def parse_program(self) -> dict:
        """Parse the whole token stream into a "program" node.

        A program is a sequence of function definitions and
        ``extrn name, name, ...;`` declarations; each declared name becomes
        its own "extrn" item.
        """
        items = []

        while self.peek().kind != TokenKind.EOF:
            if self.peek().kind == TokenKind.EXTRN:
                self.advance()
                # Names must be identifiers and the list must end in ";" --
                # checked explicitly so a malformed declaration is reported
                # instead of silently swallowing the following token.
                name = self.expect(TokenKind.IDENT)
                items.append({"type": "extrn", "value": name.value})
                while self.peek().kind == TokenKind.COMMA:
                    self.advance()
                    name = self.expect(TokenKind.IDENT)
                    items.append({"type": "extrn", "value": name.value})
                self.expect(TokenKind.SEMICOLON)
                continue
            items.append(self.parse_function())

        return {
            "type": "program",
            "items": items,
        }

    def parse_function(self) -> dict:
        """Parse ``name(param, ...) { ... }`` into a "function" node."""
        name = self.expect(TokenKind.IDENT)

        self.expect(TokenKind.LPAREN)
        params = []
        while self.peek().kind != TokenKind.RPAREN:
            params.append(self.expect(TokenKind.IDENT).value)
            # Another parameter may only follow a comma (a trailing comma
            # before ")" is tolerated).
            if self.peek().kind != TokenKind.COMMA:
                break
            self.advance()
        self.expect(TokenKind.RPAREN)

        body = self.parse_block()

        return {
            "type": "function",
            "name": name.value,
            "params": params,
            "body": body,
        }

    def parse_block(self) -> dict:
        """Parse ``{ statement* }`` into a "block" node."""
        self.expect(TokenKind.LBRACE)

        statements = []
        while self.peek().kind != TokenKind.RBRACE:
            statements.append(self.parse_statement())

        self.expect(TokenKind.RBRACE)

        return {
            "type": "block",
            "statements": statements,
        }

    def parse_if(self) -> dict:
        """Parse ``if (expr) block [else block]``.

        Returns an "ifelse" node when an ``else`` branch is present and an
        "if" node otherwise.
        """
        self.expect(TokenKind.IF)
        self.expect(TokenKind.LPAREN)
        expr = self.parse_expr()
        self.expect(TokenKind.RPAREN)
        block = self.parse_block()

        if self.peek().kind == TokenKind.ELSE:
            self.advance()
            next_block = self.parse_block()

            return {
                "type": "ifelse",
                "condition": expr,
                "block": block,
                "next": next_block,
            }

        return {"type": "if", "condition": expr, "block": block}

    def parse_while(self) -> dict:
        """Parse ``while (expr) block`` into a "while" node."""
        self.expect(TokenKind.WHILE)
        self.expect(TokenKind.LPAREN)
        expr = self.parse_expr()
        self.expect(TokenKind.RPAREN)
        block = self.parse_block()
        return {"type": "while", "condition": expr, "block": block}

    def parse_statement(self) -> dict:
        """Parse one statement, chosen by looking at the current token.

        Raises:
            CompileError: If the current token cannot start a statement.
        """
        tok = self.peek()
        if tok.kind == TokenKind.RETURN:
            return self.parse_return()
        if tok.kind == TokenKind.IDENT:
            return self.parse_assign()
        if tok.kind == TokenKind.STAR:
            return self.parse_pointer_assign()
        if tok.kind == TokenKind.IF:
            return self.parse_if()
        if tok.kind == TokenKind.WHILE:
            return self.parse_while()

        raise self._error(
            f"unexpected {tok.kind.name} at start of statement", tok
        )

    def parse_return(self) -> dict:
        """Parse ``return expr;`` into a "return" node."""
        self.expect(TokenKind.RETURN)
        value = self.parse_expr()
        self.expect(TokenKind.SEMICOLON)

        return {
            "type": "return",
            "value": value,
        }

    def parse_assign(self) -> dict:
        """Parse a statement that starts with an identifier.

        ``name(args);`` yields a "call" node, ``name = expr;`` an
        "assignment" node.
        """
        tok = self.expect(TokenKind.IDENT)

        if self.peek().kind == TokenKind.LPAREN:
            call = self.parse_call(tok.value)
            self.expect(TokenKind.SEMICOLON)
            return call

        self.expect(TokenKind.EQUAL)
        value = self.parse_expr()
        self.expect(TokenKind.SEMICOLON)

        return {
            "type": "assignment",
            "name": tok.value,
            "value": value,
        }

    def parse_pointer_assign(self) -> dict:
        """Parse ``*primary = expr;`` into a "ptr_assign" node.

        "lhs" is the expression after the ``*`` (the address written to),
        "rhs" the value stored there.
        """
        self.expect(TokenKind.STAR)
        lhs = self.parse_primary()
        self.expect(TokenKind.EQUAL)
        rhs = self.parse_expr()
        self.expect(TokenKind.SEMICOLON)

        return {"type": "ptr_assign", "lhs": lhs, "rhs": rhs}

    def parse_primary(self) -> dict:
        """Parse a primary expression.

        Handles ``&primary``, ``*primary``, number and string literals,
        parenthesised expressions, variables and calls.

        Raises:
            CompileError: If the current token cannot start an expression.
        """
        tok = self.peek()
        if tok.kind == TokenKind.REFOF:
            self.advance()
            return {"type": "reference", "contents": self.parse_primary()}
        if tok.kind == TokenKind.STAR:
            self.advance()
            return {"type": "dereference", "contents": self.parse_primary()}
        if tok.kind == TokenKind.NUMBER:
            self.advance()
            return {
                "type": "number",
                "value": int(tok.value),
            }
        if tok.kind == TokenKind.STRING:
            self.advance()
            return {
                "type": "string",
                "value": str(tok.value),
            }
        if tok.kind == TokenKind.LPAREN:
            self.advance()
            expr = self.parse_expr()
            self.expect(TokenKind.RPAREN)
            return expr
        if tok.kind == TokenKind.IDENT:
            self.advance()
            if self.peek().kind == TokenKind.LPAREN:
                return self.parse_call(tok.value)
            return {"type": "variable", "value": tok.value}

        raise self._error(f"unexpected {tok.kind.name} in expression", tok)

    def parse_add(self) -> dict:
        """Parse a chain of ``+`` / ``-`` over multiplicative operands."""
        return self._parse_binary(self.parse_mul, self._ADD_OPS)

    def parse_unary(self) -> dict:
        """Parse any number of leading ``-`` signs followed by a primary."""
        if self.peek().kind == TokenKind.MINUS:
            self.advance()
            return {
                "type": "unary",
                "op": "-",
                "value": self.parse_unary(),
            }
        return self.parse_primary()

    def parse_mul(self) -> dict:
        """Parse a chain of ``*`` / ``/`` over unary operands."""
        return self._parse_binary(self.parse_unary, self._MUL_OPS)

    def parse_call(self, name) -> dict:
        """Parse ``(arg, ...)`` for a call to *name*.

        The function name itself has already been consumed by the caller.
        """
        self.expect(TokenKind.LPAREN)
        params = []
        while self.peek().kind != TokenKind.RPAREN:
            params.append(self.parse_expr())
            # As with parameter lists: arguments are separated by commas and
            # a trailing comma is tolerated.
            if self.peek().kind != TokenKind.COMMA:
                break
            self.advance()
        self.expect(TokenKind.RPAREN)
        return {"type": "call", "name": name, "params": params}

    def parse_cmp(self) -> dict:
        """Parse a chain of ``< <= > >=`` over additive operands."""
        return self._parse_binary(self.parse_add, self._CMP_OPS)

    def parse_eq(self) -> dict:
        """Parse a chain of ``==`` / ``!=`` over comparison operands."""
        return self._parse_binary(self.parse_cmp, self._EQ_OPS)

    def parse_expr(self) -> dict:
        """Parse a full expression, starting at the loosest precedence."""
        return self.parse_eq()


def find_strings(node):
    """Yield the ``"value"`` of every string-literal node beneath *node*.

    *node* may be a dict, a list, or any other object.  Dicts and lists are
    walked depth-first in their own iteration order; a dict whose ``"type"``
    is ``"string"`` yields its ``"value"`` before its entries are visited.
    Anything else yields nothing.
    """
    if isinstance(node, list):
        for child in node:
            yield from find_strings(child)
        return

    if not isinstance(node, dict):
        return

    if node.get("type") == "string":
        yield node.get("value")
    for child in node.values():
        yield from find_strings(child)
class CodeGen:
    """Translate the dict-based AST into NASM x86-64 assembly text.

    Every expression leaves its result in ``rax``.  Local variables and
    parameters live in 8-byte slots below ``rbp``; ``local_vars`` maps a
    name to its positive offset from ``rbp``.  Temporaries are kept on the
    machine stack with ``push``/``pop``.
    """

    # Registers carrying the first six integer arguments, in order (the
    # System V AMD64 calling convention).
    _ARG_REGS = ("rdi", "rsi", "rdx", "rcx", "r8", "r9")

    # Comparison operator -> instruction that sets ``al`` from the flags.
    _SETCC = {
        "<": "setl",
        "<=": "setle",
        ">": "setg",
        ">=": "setge",
        "==": "sete",
        "!=": "setne",
    }

    def __init__(self):
        self.lines = []  # emitted assembly, one entry per line
        self.label_count = 0  # makes generated labels unique
        self.strings = {}  # string literal -> label of its .rodata bytes
        # Per-function state, reset by compile_function.
        self.local_vars = {}
        self.offset = 0

    def emit(self, asm):
        """Append one line of assembly to the output."""
        self.lines.append(asm)

    def create_label(self, name):
        """Return a fresh label of the form ``L_<name>_<counter>``."""
        self.label_count += 1
        return f"L_{name}_{self.label_count}"

    def write_string(self, name):
        """Emit the bytes of string literal *name* under a new label.

        Each distinct literal is emitted once; later calls with the same
        text are no-ops.  The text is written as its UTF-8 bytes followed by
        a terminating 0, so every ``db`` operand fits in one byte.
        """
        if name in self.strings:
            return
        lbl = self.create_label("string")
        self.strings[name] = lbl
        self.emit(f"{lbl}:")
        data = list(name.encode("utf-8")) + [0]
        self.emit("    db " + ", ".join(map(str, data)))

    def compile_program(self, ast):
        """Compile a "program" node and return the complete assembly text.

        Emits the ``extern`` declarations, then every string literal in
        ``.rodata``, then every function in ``.text``.
        """
        for item in ast["items"]:
            if item["type"] == "extrn":
                # Inner quotes differ from the outer ones on purpose:
                # reusing the same quote inside an f-string only parses on
                # Python 3.12 and later.
                self.emit(f"extern {item['value']}")
        self.emit("bits 64")
        self.emit("global main")
        self.emit("section .rodata")

        for string_value in find_strings(ast):
            self.write_string(string_value)
        self.emit("section .text")

        for item in ast["items"]:
            if item["type"] == "function":
                self.compile_function(item)

        return "\n".join(self.lines) + "\n"

    def compile_block(self, block: dict):
        """Compile every statement of a "block" node in order."""
        for stmt in block["statements"]:
            self.compile_stmt(stmt)

    def _declare(self, name):
        """Give *name* a stack slot unless it already has one."""
        if name not in self.local_vars:
            self.offset += 8
            self.local_vars[name] = self.offset

    def _slot(self, name):
        """Return the ``rbp`` offset of variable *name*.

        Raises:
            KeyError: If the variable has no slot (it was never assigned).
        """
        try:
            return self.local_vars[name]
        except KeyError:
            raise KeyError(f"undefined variable {name!r}") from None

    def _assigned_names(self, block):
        """Yield every name assigned in *block*, including nested blocks."""
        for stmt in block["statements"]:
            kind = stmt["type"]
            if kind == "assignment":
                yield stmt["name"]
            elif kind in ("if", "while", "ifelse"):
                yield from self._assigned_names(stmt["block"])
                if kind == "ifelse":
                    yield from self._assigned_names(stmt["next"])

    def _emit_epilogue(self):
        """Emit the frame teardown and ``ret``."""
        self.emit("    mov rsp, rbp")
        self.emit("    pop rbp")
        self.emit("    ret")

    def compile_stmt(self, stmt):
        """Compile one statement node.

        Raises:
            NotImplementedError: For a statement type this generator does
                not know.
        """
        kind = stmt["type"]

        if kind == "return":
            self.compile_expr(stmt["value"])
            self._emit_epilogue()
            return

        if kind == "assignment":
            self.compile_expr(stmt["value"])
            # compile_function reserves slots for all assigned names up
            # front; this only allocates when a statement is compiled on
            # its own, outside compile_function.
            self._declare(stmt["name"])
            offset = self.local_vars[stmt["name"]]
            self.emit(f"    mov [rbp - {offset}], rax")
            return

        if kind == "ptr_assign":
            # Address first (saved on the stack), then the value to store.
            self.compile_expr(stmt["lhs"])
            self.emit("    push rax")
            self.compile_expr(stmt["rhs"])
            self.emit("    pop rcx")
            self.emit("    mov qword [rcx], rax")
            return

        if kind == "if":
            self.compile_expr(stmt["condition"])
            end_label = self.create_label("if")
            self.emit("    cmp rax, 0")
            self.emit(f"    je {end_label}")
            self.compile_block(stmt["block"])
            self.emit(f"{end_label}:")
            return

        if kind == "while":
            start_label = self.create_label("while_start")
            end_label = self.create_label("while_end")
            self.emit(f"{start_label}:")
            self.compile_expr(stmt["condition"])
            self.emit("    cmp rax, 0")
            self.emit(f"    je {end_label}")
            self.compile_block(stmt["block"])
            self.emit(f"    jmp {start_label}")
            self.emit(f"{end_label}:")
            return

        if kind == "ifelse":
            self.compile_expr(stmt["condition"])
            else_label = self.create_label("else_start")
            end_label = self.create_label("if_end")

            self.emit("    cmp rax, 0")
            self.emit(f"    je {else_label}")
            self.compile_block(stmt["block"])
            self.emit(f"    jmp {end_label}")
            self.emit(f"{else_label}:")
            self.compile_block(stmt["next"])
            self.emit(f"{end_label}:")
            return

        if kind == "call":
            # A call used as a statement; its result in rax is ignored.
            self.compile_expr(stmt)
            return

        raise NotImplementedError(kind)

    def _compile_call(self, expr):
        """Emit a call: evaluate arguments left to right, then load registers.

        NOTE(review): arguments are staged with push/pop, so a call nested
        inside another call's argument list or inside a binary expression is
        made while earlier pushes are still outstanding -- confirm whether
        the callees care about stack alignment in that situation.
        """
        params = expr["params"]
        if len(params) > len(self._ARG_REGS):
            raise NotImplementedError("I don't support more than 6 arguments")

        for param in params:
            self.compile_expr(param)
            self.emit("    push rax")

        # Pop in reverse so each value lands in the register of its position.
        for index in reversed(range(len(params))):
            self.emit(f"    pop {self._ARG_REGS[index]}")

        # rax is zeroed before every call (al = 0 tells a variadic callee
        # that no vector registers carry arguments).
        self.emit("    xor rax, rax")
        self.emit(f"    call {expr['name']}")

    def _compile_binary(self, expr):
        """Emit a binary operation; left operand ends in rcx, right in rax."""
        op = expr["op"]
        if op not in ("+", "-", "*", "/") and op not in self._SETCC:
            raise NotImplementedError(f"binary operator {op!r}")

        self.compile_expr(expr["left"])
        self.emit("    push rax")
        self.compile_expr(expr["right"])
        self.emit("    pop rcx")

        if op == "+":
            self.emit("    add rax, rcx")
        elif op == "-":
            self.emit("    sub rcx, rax")
            self.emit("    mov rax, rcx")
        elif op == "*":
            self.emit("    imul rax, rcx")
        elif op == "/":
            # idiv divides rdx:rax, so move the divisor out of rax first and
            # sign-extend the dividend into rdx with cqo.
            self.emit("    mov rdi, rax")
            self.emit("    mov rax, rcx")
            self.emit("    cqo")
            self.emit("    idiv rdi")
        else:
            self.emit("    cmp rcx, rax")
            self.emit(f"    {self._SETCC[op]} al")
            self.emit("    movzx rax, al")

    def compile_expr(self, expr):
        """Emit code that leaves the value of *expr* in ``rax``.

        Raises:
            KeyError: If a variable is read or referenced before any
                assignment gave it a slot.
            NotImplementedError: For an unknown expression type or operator,
                a call with more than 6 arguments, or ``&`` applied to
                something other than a variable.
        """
        kind = expr["type"]

        if kind == "number":
            self.emit(f"    mov rax, {expr['value']}")
        elif kind == "string":
            label = self.strings[expr["value"]]
            self.emit(f"    lea rax, [rel {label}]")
        elif kind == "call":
            self._compile_call(expr)
        elif kind == "variable":
            offset = self._slot(expr["value"])
            self.emit(f"    mov rax, [rbp - {offset}]")
        elif kind == "reference":
            target = expr["contents"]
            if target.get("type") != "variable":
                raise NotImplementedError(
                    "can only take the address of a variable"
                )
            offset = self._slot(target["value"])
            self.emit(f"    lea rax, [rbp - {offset}]")
        elif kind == "dereference":
            self.compile_expr(expr["contents"])
            self.emit("    mov rax, [rax]")
        elif kind == "binary":
            self._compile_binary(expr)
        elif kind == "unary":
            if expr["op"] != "-":
                raise NotImplementedError(f"unary operator {expr['op']!r}")
            self.compile_expr(expr["value"])
            self.emit("    neg rax")
        else:
            raise NotImplementedError(kind)

    def compile_function(self, fn):
        """Compile a "function" node: label, prologue, body and epilogue.

        Slots are reserved for the parameters and for every variable
        assigned anywhere in the body (nested blocks included) before the
        frame is sized, so no local ends up below ``rsp``.  The frame size
        is rounded up to a multiple of 16 to keep ``rsp`` 16-byte aligned
        after the prologue.  A function that falls off its end returns 0.

        Raises:
            NotImplementedError: If the function has more than 6 parameters.
        """
        params = fn["params"]
        if len(params) > len(self._ARG_REGS):
            raise NotImplementedError("I don't support more than 6 parameters")

        self.local_vars = {}
        self.offset = 0
        for var_name in params:
            self._declare(var_name)
        for var_name in self._assigned_names(fn["body"]):
            self._declare(var_name)

        frame_size = (self.offset + 15) // 16 * 16

        self.emit(f"{fn['name']}:")
        self.emit("    push rbp")
        self.emit("    mov rbp, rsp")
        self.emit(f"    sub rsp, {frame_size}")
        # Spill incoming register arguments into their stack slots.
        for reg, var_name in zip(self._ARG_REGS, params):
            self.emit(f"    mov [rbp - {self.local_vars[var_name]}], {reg}")
        self.compile_block(fn["body"])
        self.emit("    mov rax, 0")
        self._emit_epilogue()


def usage():
    """Print the command-line help text to stdout and exit with status 0."""
    help_lines = (
        "Usage: tbb [-o output] input",
        "",
        "Options:",
        "  -o output   Write executable to output",
        "  -S          Emit assembly",
        "  -c          Emit an object file",
        "  -h, --help  Show this help message",
        "  --version   Show version info",
    )
    print("\n".join(help_lines))
    sys.exit(0)


def parse_arguments():
    """Parse ``sys.argv`` into ``(output, inputfile, emit_type)``.

    Returns:
        A tuple of the output path, the input path and the kind of output:
        ``"out"`` (the default), ``"s"`` after ``-S`` or ``"o"`` after
        ``-c``.  Without ``-o`` the output path is the input path with its
        extension replaced by ``.`` plus the emit type.  If several
        positional arguments are given, the last one is the input.

    ``-h``/``--help`` and ``--version`` print their text and exit with
    status 0.  A missing input file, or ``-o`` without a value, exits with
    an error message on stderr.
    """
    output = None
    inputfile = ""
    emit_type = "out"

    args = sys.argv[1:]
    i = 0
    while i < len(args):
        arg = args[i]
        if arg in ("-h", "--help"):
            usage()
        elif arg == "--version":
            print(f"tbb {VERSION}")
            print('''
Copyright (C) 2026 Playful Mathematician
This is free software; see the source for copying conditions.  There is NO
warranty; not even for MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.

Written by Playful Mathematician''')
            sys.exit(0)
        elif arg == "-o":
            # -o is the only flag that consumes the following argument.
            i += 1
            if i >= len(args):
                sys.exit("tbb: error: -o requires an output path")
            output = args[i]
        elif arg == "-S":
            emit_type = "s"
        elif arg == "-c":
            emit_type = "o"
        else:
            inputfile = arg
        i += 1

    if inputfile == "":
        sys.exit("tbb: error: One must have an input")
    if output is None:
        output = os.path.splitext(inputfile)[0] + "." + emit_type
    return (output, inputfile, emit_type)


def main():
    """Command-line entry point: compile the input file to the chosen output.

    The source is lexed, parsed and translated to assembly.  What happens
    next depends on the emit type from the command line:

    * ``"s"``   -- the assembly text is written to the output path;
    * ``"o"``   -- the assembly is assembled with ``nasm`` and the object
      file is copied to the output path;
    * ``"out"`` -- the object file is additionally linked with ``cc`` and
      the executable is copied to the output path.

    Only the tools needed for the requested output are run.  A
    ``CompileError`` is printed to stderr and ends the process with status
    1; a failing ``nasm`` or ``cc`` raises ``subprocess.CalledProcessError``.
    """
    output, inputfile, emit_type = parse_arguments()

    try:
        toks = lex_file(read_file(inputfile))
        # Pass the input path (not the builtin ``input``) so that error
        # messages name the file being compiled.
        parser = Parser(toks, inputfile)
        parsed = parser.parse_program()
    except CompileError as err:
        print(err, file=sys.stderr)
        sys.exit(1)

    codegen = CodeGen()
    asm = codegen.compile_program(parsed)

    if emit_type == "s":
        with open(output, "w") as f:
            f.write(asm)
        return

    with tempfile.TemporaryDirectory() as tmpdir:
        asm_path = os.path.join(tmpdir, "out.s")
        obj_path = os.path.join(tmpdir, "out.o")
        out_path = os.path.join(tmpdir, "out.out")
        with open(asm_path, "w") as f:
            f.write(asm)

        subprocess.run(
            [
                "nasm",
                "-felf64",
                asm_path,
                "-o",
                obj_path,
            ],
            check=True,
        )

        if emit_type == "o":
            shutil.copy2(obj_path, output)
            return

        subprocess.run(
            [
                "cc",
                "-no-pie",
                obj_path,
                "-o",
                out_path,
            ],
            check=True,
        )

        shutil.copy2(out_path, output)


# Run the compiler only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()

