#!/usr/bin/env python
# I, Danny Milosavljevic, hereby place this file into the public domain

import symbols
from core import SegmentAssembler
from core import Label
from core import Register
from core import registers
from core import Operation
import sys

"""
general x86 instruction layout (32-bit mode):
    prefixes      0..4 bytes
    opcode        1..3 bytes (this assembler emits 1, or 2 for 0F xx)
    modrm         0..1 byte
    sib           0..1 byte
    displacement  0, 1 or 4 bytes
    immediate     0, 1, 2 or 4 bytes
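modrm=
    76543210
    |||||||+- r/m
    ||||||+-- r/m
    |||||+--- r/m
    ||||+---- reg (or /digit opcode extension)
    |||+----- reg (or /digit opcode extension)
    ||+------ reg (or /digit opcode extension)
    |+------- mod
    +-------- mod
mod=
  00 indirect: memory at [r/m] (r/m=101 means absolute [disp32])
  11 direct: r/m is a register
  01 indirect with 1-byte signed displacement
  10 indirect with 4-byte signed displacement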
"""

def getOperationArgumentCount(opcode):
    """Return how many operands the instruction with table value *opcode* takes.

    *opcode* is a value from the ``opcodes`` table below, not a raw byte
    string: two-byte opcodes are stored as 0xXX0F (bytes 0F XX) and group
    opcodes as 0xDDXX, where DD is the ModRM /digit extension.
    Values not listed explicitly are assumed to take two operands.
    """
    if 0x8000 <= opcode <= 0x8F0F and (opcode & 0xFF) == 0x0F:  # jcc rel32
        return 1
    if 0xF8 <= opcode <= 0xFD:  # CLC, STC, CLI, STI, CLD, STD
        return 0
    if opcode in (
            0x16, 0x17,  # push ss, pop ss
            0x9C, 0x9D,  # pushf, popf
            0xC3, 0xCB,  # ret (near), retf (far; kept for raw-value callers)
            0xC9,        # leave
            0xCC,        # int3
            0xCF,        # iret
    ):
        return 0
    if opcode in (
            0x8F,                                            # pop r/m32
            0x00FF, 0x01FF, 0x02FF, 0x04FF, 0x06FF,          # inc, dec, call, jmp, push r/m32
            0x02F7, 0x03F7, 0x04F7, 0x05F7, 0x06F7, 0x07F7,  # not, neg, mul, imul, div, idiv
            0xC2, 0xCA,                                      # ret imm16 (near, far)
            0xCD,                                            # int imm8
            0xE8, 0xE9, 0xEB,                                # call rel32, jmp rel32, jmp rel8
    ):
        return 1
    return 2
# Register codes are the 3-bit values placed in the ModRM reg / r/m fields,
# so there MUST only be 8 registers for general-purpose instructions.
for i, name in enumerate(["EAX", "ECX", "EDX", "EBX", "ESP", "EBP", "ESI", "EDI"]):
    registers[symbols.intern(name)] = Register(symbols.intern(name), i)
# Mnemonic -> opcode table value.  Encodings:
#   0xXX    single opcode byte XX
#   0xXX0F  two-byte opcode 0F XX (emitted low byte first)
#   0xDDXX  group opcode XX (8F, C1, F7 or FF) whose /digit extension DD
#           belongs in the ModRM reg field
# Suffix letters: I = immediate, M = r/m operand, R = register,
# r = PC-relative target.
opcodes = {
        # TODO IN OUT
        "ADDMR": 0x01,
        "ADDRM": 0x03,
        "ORMR": 0x09,
        "ORRM": 0x0B,
        "OMRR": 0x0B,  # misspelled alias of ORRM, kept for compatibility
        "ADCMR": 0x11,
        "ADCRM": 0x13,
        "PUSHSS": 0x16,
        "POPSS": 0x17,
        "SBBMR": 0x19,
        "SBBRM": 0x1B,
        "ANDMR": 0x21,
        "ANDRM": 0x23,
        "SUBMR": 0x29,
        "SUBRM": 0x2B,
        "XORMR": 0x31,
        "XORRM": 0x33,
        "XOMRR": 0x33,  # misspelled alias of XORRM, kept for compatibility
        "CMPMR": 0x39,
        "CMPRM": 0x3B,
        "MOVMR": 0x89,  # bit 1 (direction) is clear
        "MOVRM": 0x8B,  # bit 1 (direction) is set
        "MOVSM": 0x8C,  # segment to memory/reg
        "LEARM": 0x8D,
        "MOVMS": 0x8E,
        "POPM": 0x8F,  # 8F /0
        "XCHGMM": 0x87,  # xchg
        # TODO 0x9A CALLF
        "PUSHF": 0x9C,
        "POPF": 0x9D,
        # ROL ROR RCL RCR SHL SHR SAR by an imm8 count: C1 /digit ib
        # (SAL is the same encoding as SHL)
        "ROLIM": 0x00C1,
        "RORIM": 0x01C1,
        "RCLIM": 0x02C1,
        "RCRIM": 0x03C1,
        "SHLIM": 0x04C1,
        "SHRIM": 0x05C1,
        "SARIM": 0x07C1,  # FIXME
        # TODO LES load far pointer 0xC4
        # TODO LDS load far pointer 0xC5
        "ENTER": 0xC8,
        "LEAVE": 0xC9,
        "RETC": 0xC2,  # near return, then pop imm16 bytes (0xCA is the far form)
        "RET": 0xC3,  # near return, matching CALLr/CALL (0xCB is the far form)
        "INT3": 0xCC,
        "INTI": 0xCD,
        "IRET": 0xCF,
        "CALLr": 0xE8,
        "JMPr": 0xE9,
        "JMPrsmall": 0xEB,
        "TESTIM": 0x00F7,
        "NOTM": 0x02F7,
        "NEGM": 0x03F7,
        "UMULM": 0x04F7,  # MUL  F7 /4
        "IMULM": 0x05F7,  # IMUL F7 /5
        "UDIVM": 0x06F7,  # DIV  F7 /6
        "IDIVM": 0x07F7,  # IDIV F7 /7
        "CLC": 0xF8,
        "STC": 0xF9,
        "CLI": 0xFA,
        "STI": 0xFB,
        "CLD": 0xFC,
        "STD": 0xFD,
        # TODO register short forms: INCR 0x40+r, DECR 0x48+r,
        #      PUSHR 0x50+r, POPR 0x58+r
        "INCM": 0x00FF,
        "DECM": 0x01FF,
        "CALL": 0x02FF,
        # TODO CALLF (FF /3)
        "JMPi": 0x04FF,
        # TODO JMPF (FF /5)
        "PUSHM": 0x06FF,
        "JOr": 0x800F,
        "JNOr": 0x810F,
        "JBr": 0x820F,
        "JCr": 0x820F,
        "JNBr": 0x830F,
        "JNCr": 0x830F,
        "JEr": 0x840F,
        "JNEr": 0x850F,
        "JBEr": 0x860F,
        "JNBEr": 0x870F,
        "JSr": 0x880F,
        "JNSr": 0x890F,
        "JPEr": 0x8A0F,
        "JPOr": 0x8B0F,
        "JLr": 0x8C0F,
        "JNLr": 0x8D0F,
        "JLEr": 0x8E0F,
        "JNLEr": 0x8F0F,
        # TODO short jcc forms (0x70..0x7F rel8), e.g. "JNErsmall"
}
# Re-key the table by interned symbol so lookups match parsed operators.
opcodes = dict((symbols.intern(name), value) for name, value in opcodes.items())
def getRegisterCode(r):
    """Return the numeric encoding stored on register operand *r* (its ``code``)."""
    code = r.code
    return code
def registerP(i):
    """Return True when operand *i* is a ``Register`` instance."""
    isRegister = isinstance(i, Register)
    return isRegister
def operandRelativeP(opcode):
    """Return True if *opcode*'s label operand is encoded PC-relative.

    Covers jcc rel32 (table values 0x800F..0x8F0F, i.e. 0F 80..0F 8F),
    CALL rel32 (0xE8), JMP rel32 (0xE9) and JMP rel8 (0xEB).  Every other
    opcode takes an absolute address.
    """
    if 0x8000 <= opcode <= 0x8F0F and (opcode & 0xFF) == 0x0F:
        return True
    # CALLr and JMPrsmall are relative just like JMPr; previously only 0xE9
    # was listed, so "CALLr label" got an absolute address.
    return opcode in (0xE8, 0xE9, 0xEB)
def getOperatorDirection(opcode):
    """Return the direction bit (bit 1) of *opcode* as 0 or 1.

    1 means the ModRM reg field is the destination; 0 means it is the source.
    Only odd opcodes (bit 0 set, the 32-bit forms) are accepted, as a
    sanity check.
    """
    assert opcode & 1
    if opcode & 2:
        return 1
    return 0
def immediateOperatorP(opcode):
    """Return True when bit 7 (0x80) of *opcode* is set.

    NOTE(review): this bit is also set on non-immediate opcodes such as
    MOV 0x89/0x8B, so the name overstates what is checked; confirm intent
    with any callers outside this file.
    """
    return bool(opcode & 0x80)
# Size in bytes of the immediate/label operand for opcodes whose operand is
# not the default 4 bytes.  Keys are full ``opcodes`` table values or bare
# opcode bytes; getRequiredOperandSize() falls back to the low byte so that
# group encodings such as 0x04C1 (SHL r/m32, imm8) find the 0xC1 entry.
operandSizes = {
    0xC1: 1,  # shift/rotate group by imm8 count (C1 /digit ib)
    0xC2: 2,  # ret imm16 (near)
    0xCA: 2,  # ret imm16 (far)
    0xCD: 1,  # int imm8
    0xEB: 1,  # jmp rel8
    # mov r8, imm8 (0xB0+r)
    0xB0: 1,
    0xB1: 1,
    0xB2: 1,
    0xB3: 1,
    0xB4: 1,
    0xB5: 1,
    0xB6: 1,
    0xB7: 1,
}
def getRequiredOperandSize(opcode):
    """Return the byte size of *opcode*'s immediate/label operand (default 4).

    Looks up the full table value first, then its low byte.  That way the
    /digit group forms (e.g. 0x04C1) share the entry of their base opcode,
    while two-byte 0F xx opcodes (low byte 0x0F) keep the 4-byte default.
    """
    size = operandSizes.get(opcode)
    if size is None:
        size = operandSizes.get(opcode & 0xFF, 4)
    return size
class Assembler(object):
    """Encode labels and operations into per-section byte buffers.

    One ``SegmentAssembler`` is kept for each of ``.text``, ``.data``,
    ``.rodata`` and ``.bss``; ``self.segment`` is the one currently being
    written and starts out as ``.text``.

    Only these operand forms are encoded:
    - register-direct ModRM (mod=11);
    - immediates and labels;
    - for group opcodes, an absolute ``[disp32]`` address.
    Memory operands with base/index registers are not supported.
    """

    # Opcode bytes whose ModRM reg field holds an opcode extension ("/digit")
    # instead of a register.  The ``opcodes`` table stores that extension in
    # the high byte of the value, e.g. 0x02FF means FF /2 (CALL r/m32).
    _GROUP_OPCODES = (0x8F, 0xC1, 0xF7, 0xFF)
    # Group opcodes whose immediate operand is a single byte (C1 /digit ib).
    _BYTE_IMMEDIATE_OPCODES = (0xC1,)

    def __init__(self):
        self.segments = {
            symbols.intern(".text"): SegmentAssembler(symbols.intern(".text")),
            symbols.intern(".data"): SegmentAssembler(symbols.intern(".data")),
            symbols.intern(".rodata"): SegmentAssembler(symbols.intern(".rodata")),
            symbols.intern(".bss"): SegmentAssembler(symbols.intern(".bss")),
        }
        self.segment = self.segments[symbols.intern(".text")]
        # label -> [True]; only key presence is checked by globalP().
        self.globals = {}

    def ensureGlobal(self, label):
        """Mark *label* as global; calling it again for the same label is a no-op."""
        if label not in self.globals:
            self.globals[label] = [True]

    def globalP(self, label):
        """Return True if *label* was marked by ensureGlobal()."""
        return label in self.globals

    def immediateP(self, i):
        """Return True if operand *i* is an integer constant or a Label."""
        return isinstance(i, int) or isinstance(i, Label)

    def addBuiltin(self, item):
        """Handle a dot-directive.

        .DB        emits an int operand as one byte, or each character of a
                   string operand as its ord() value.
        .GLOBAL    marks the operand as global.
        .EXTERNAL  is recorded the same way as .GLOBAL.
        Any other directive prints an error and exits with status 1.
        """
        if item.operator == symbols.intern(".DB"):
            operand = item.operands[0]
            if isinstance(operand, int):
                self.segment.add(operand)
            else:
                for c in operand:
                    self.segment.add(ord(c))
        elif item.operator == symbols.intern(".GLOBAL"):
            operand = item.operands[0]
            self.ensureGlobal(operand)
        elif item.operator == symbols.intern(".EXTERNAL"):
            operand = item.operands[0]
            self.ensureGlobal(operand)
        else:
            sys.stderr.write("error: unknown builtin %r\n" % item.operator)
            sys.exit(1)

    @classmethod
    def _opcodeExtension(cls, opcode):
        """Return the /digit extension (0..7) of a group opcode, else None."""
        if (opcode & 0xFF) in cls._GROUP_OPCODES:
            return (opcode >> 8) & 7
        return None

    @classmethod
    def _opcodeBytes(cls, opcode):
        """Return the opcode bytes to emit for a table value, in emission order."""
        if cls._opcodeExtension(opcode) is not None:
            # The high byte is the ModRM extension, not a second opcode byte.
            return [opcode & 0xFF]
        result = []
        while opcode != 0:  # e.g. 0x840F -> 0F 84
            result.append(opcode & 0xFF)
            opcode >>= 8
        return result

    def _immediateSize(self, opcode):
        """Return the byte size of the immediate/label operand of *opcode*."""
        if (opcode & 0xFF) in self._BYTE_IMMEDIATE_OPCODES:
            return 1
        return getRequiredOperandSize(opcode)

    def _needsModRM(self, opcode, operands):
        """Return True if an instruction with these operands takes a ModRM byte."""
        if len(operands) == 0:
            return False
        if self._opcodeExtension(opcode) is not None:
            # Group opcodes always carry their /digit in the ModRM byte.
            return True
        return not self.immediateP(operands[0]) or len(operands) > 1

    def add(self, item):
        """Append *item* to the current segment.

        A Label is resolved at the current offset of the current segment.
        A dot-operator Operation goes to addBuiltin().  Any other Operation
        is encoded in this order: opcode bytes, optional ModRM, then the
        first operand if it is an immediate or label.  For a label operand,
        the result of ``addLabelReference`` is returned; otherwise None.
        """
        if isinstance(item, Label):
            assert(not item.bResolved)
            item.resolve(self.segment, len(self.segment.data))
            return None
        if not isinstance(item, Operation):
            return None
        if item.operator.name.startswith("."):  # builtin
            self.addBuiltin(item)
            return None
        operands = item.operands
        opcode = opcodes[item.operator]
        if opcode == 0x8B and len(operands) == 2 and self.immediateP(operands[0]) and registerP(operands[1]):
            # mov imm, %e* has its very own instruction: B8+r id, no ModRM.
            opcode = 0xB8 + getRegisterCode(operands[1])
            self.segment.add(opcode)
        else:
            for byte in self._opcodeBytes(opcode):
                self.segment.add(byte)
            if self._needsModRM(opcode, operands):
                self.segment.add(self.getModMR(opcode, item))
            # TODO displacement etc.
        if len(operands) > 0 and self.immediateP(operands[0]):
            # Only the source (first operand) can be an immediate.
            operand = operands[0]
            operandSize = self._immediateSize(opcode)
            if isinstance(operand, Label):
                return self.segment.addLabelReference(operand, operandSize, operandRelativeP(opcode))
            self.segment.add(operand, operandSize)
        return None

    def getModMR(self, opcode, item):
        """Return the ModRM byte for *item*, encoded with table value *opcode*.

        ModRM bit layout: bits 7-6 mod, bits 5-3 reg (or /digit), bits 2-0 r/m.

        Group opcodes (8F, C1, F7, FF):
        - reg holds the /digit taken from the table value's high byte;
        - r/m is the register operand: the only operand, or the one after
          the immediate;
        - a lone immediate/label operand is encoded as mod=00, r/m=101,
          i.e. an absolute [disp32] address that add() emits afterwards.

        Other opcodes: mod=11 (register-direct).
        - operands[0] goes into reg and operands[1] into r/m.
        - For the MR forms (direction bit clear), reg is the source.
        - For the RM forms (direction bit set), reg is the destination.
        """
        assert(len(item.operands) > 0)
        operands = list(item.operands) + [None]
        extension = self._opcodeExtension(opcode)
        if extension is not None:
            if not self.immediateP(operands[0]):
                return 0xC0 | (extension << 3) | getRegisterCode(operands[0])
            if operands[1] is None:
                return (extension << 3) | 0x05
            return 0xC0 | (extension << 3) | getRegisterCode(operands[1])
        modrm = 0xC0  # register-direct; memory operands are not supported yet
        if self.immediateP(operands[0]):
            assert((opcode & 1) == 1)  # 32-bit operand
        else:
            modrm |= getRegisterCode(operands[0]) << 3
        if operands[1] is not None:
            modrm |= getRegisterCode(operands[1])
        return modrm

    def dump(self):
        """Print the current segment's bytes to stdout as space-separated hex."""
        sys.stdout.write(" ".join("%02X" % byte for byte in self.segment.data) + "\n")
