#!/usr/bin/env python
# I, Danny Milosavljevic, hereby place this file into the public domain

import struct
import time
import symbols
import os

# COFF file header: Machine field values.
IMAGE_FILE_MACHINE_AMD64 = 0x8664
IMAGE_FILE_MACHINE_I386 = 0x14C

# COFF file header: Characteristics flags.
IMAGE_FILE_LARGE_ADDRESS_AWARE = 0x0020
IMAGE_FILE_32BIT_MACHINE = 0x0100
IMAGE_FILE_DEBUG_STRIPPED = 0x0200
IMAGE_FILE_REMOVABLE_RUN_FROM_SWAP = 0x0400
IMAGE_FILE_NET_RUN_FROM_SWAP = 0x0800
IMAGE_FILE_SYSTEM = 0x1000
IMAGE_FILE_DLL = 0x2000
IMAGE_FILE_UP_SYSTEM_ONLY = 0x4000 # uniprocessor only

# Section header: Characteristics flags.

IMAGE_SCN_CNT_CODE = 0x00000020
IMAGE_SCN_CNT_INITIALIZED_DATA = 0x00000040
IMAGE_SCN_CNT_UNINITIALIZED_DATA = 0x00000080
IMAGE_SCN_LNK_INFO = 0x00000200
IMAGE_SCN_LNK_REMOVE = 0x00000800
# Alignment values (object files only) occupy bits 20-23 and are mutually
# exclusive: each step doubles the alignment, from 1 byte up to 8192 bytes.
IMAGE_SCN_ALIGN_1BYTES = 0x00100000
IMAGE_SCN_ALIGN_2BYTES = 0x00200000
IMAGE_SCN_ALIGN_4BYTES = 0x00300000
IMAGE_SCN_ALIGN_8BYTES = 0x00400000
IMAGE_SCN_ALIGN_16BYTES = 0x00500000
IMAGE_SCN_ALIGN_32BYTES = 0x00600000
IMAGE_SCN_ALIGN_64BYTES = 0x00700000
IMAGE_SCN_ALIGN_128BYTES = 0x00800000
IMAGE_SCN_ALIGN_256BYTES = 0x00900000
IMAGE_SCN_ALIGN_512BYTES = 0x00A00000
IMAGE_SCN_ALIGN_1024BYTES = 0x00B00000
IMAGE_SCN_ALIGN_2048BYTES = 0x00C00000
IMAGE_SCN_ALIGN_4096BYTES = 0x00D00000
IMAGE_SCN_ALIGN_8192BYTES = 0x00E00000
IMAGE_SCN_LNK_NRELOC_OVFL = 0x01000000
IMAGE_SCN_MEM_DISCARDABLE = 0x02000000
IMAGE_SCN_MEM_NOT_CACHED = 0x04000000
IMAGE_SCN_MEM_NOT_PAGED = 0x08000000
IMAGE_SCN_MEM_SHARED = 0x10000000
IMAGE_SCN_MEM_EXECUTE = 0x20000000
IMAGE_SCN_MEM_READ = 0x40000000
IMAGE_SCN_MEM_WRITE = 0x80000000

# Relocation entry: Type values for IMAGE_FILE_MACHINE_I386.
IMAGE_REL_I386_ABSOLUTE = 0x0000
IMAGE_REL_I386_DIR32 = 0x0006
IMAGE_REL_I386_DIR32NB = 0x0007
IMAGE_REL_I386_REL32 = 0x0014

# Symbol table entry: special SectionNumber values. SectionNumber is a signed
# 16-bit field, so the negative values must be packed with a signed format.
IMAGE_SYM_UNDEFINED = 0
IMAGE_SYM_ABSOLUTE = (-1)
IMAGE_SYM_DEBUG = (-2)

# Symbol table entry: StorageClass values.
IMAGE_SYM_CLASS_NULL, IMAGE_SYM_CLASS_AUTOMATIC, IMAGE_SYM_CLASS_EXTERNAL, IMAGE_SYM_CLASS_STATIC, IMAGE_SYM_CLASS_REGISTER, IMAGE_SYM_CLASS_EXTERNAL_DEF, IMAGE_SYM_CLASS_LABEL, IMAGE_SYM_CLASS_UNDEFINED_LABEL, IMAGE_SYM_CLASS_MEMBER_OF_STRUCT = range(9)
IMAGE_SYM_CLASS_UNDEFINED_STATIC = 14

# FIXME alignment
# TODO dollar sign section names
class COFFWriter(object):
    """Write a relocatable COFF object file.

    write() drives the process: the COFF header, section headers, symbol
    table and string table are written first (some fields as placeholders),
    then the raw section bodies, and finally finish() seeks back to offset 0
    and rewrites the COFF header and section headers with the positions and
    sizes recorded along the way. Output goes to "<name>.tmp" and close()
    moves it to <name>.

    Every multi-byte field is packed explicitly little-endian ("<") so the
    output does not depend on the host's byte order.
    """
    defaultObjectFileNameFormat = "%s.OBJ"

    def __init__(self, name):
        self.name = name
        self.tempName = "%s.tmp" % name
        self.stream = open(self.tempName, "wb")
        # Each entry: [name, characteristics, unused, positionOfRawData, sizeOfRawData].
        # List order is section-table order; the COFF section number of an
        # entry is its list index plus one.
        self.sectionHeaders = [
            [".bss", IMAGE_SCN_CNT_UNINITIALIZED_DATA | IMAGE_SCN_MEM_READ | IMAGE_SCN_MEM_WRITE, None, 0, 0],
            [".data", IMAGE_SCN_CNT_INITIALIZED_DATA | IMAGE_SCN_MEM_READ | IMAGE_SCN_MEM_WRITE, None, 0, 0],
            [".edata", IMAGE_SCN_CNT_INITIALIZED_DATA | IMAGE_SCN_MEM_READ, None, 0, 0],
            [".idata", IMAGE_SCN_CNT_INITIALIZED_DATA | IMAGE_SCN_MEM_READ | IMAGE_SCN_MEM_WRITE, None, 0, 0],
            #(".idlsym", IMAGE_SCN_LNK_INFO),
            #(".pdata", IMAGE_SCN_CNT_INITIALIZED_DATA | IMAGE_SCN_MEM_READ), # exception info
            [".rdata", IMAGE_SCN_CNT_INITIALIZED_DATA | IMAGE_SCN_MEM_READ, None, 0, 0], # rodata
            #(".reloc", IMAGE_SCN_CNT_INITIALIZED_DATA | IMAGE_SCN_MEM_READ | IMAGE_SCN_MEM_DISCARDABLE), # image relocations
            #(".rsrc", IMAGE_SCN_CNT_INITIALIZED_DATA | IMAGE_SCN_MEM_READ),
            #(".sxdata", IMAGE_SCN_LNK_INFO), # symbol indices for each of the exception handlers.
            [".text", IMAGE_SCN_CNT_CODE | IMAGE_SCN_MEM_EXECUTE | IMAGE_SCN_MEM_READ, None, 0, 0],
            #(".tls", IMAGE_SCN_CNT_INITIALIZED_DATA | IMAGE_SCN_MEM_READ | IMAGE_SCN_MEM_WRITE),
            #(".tls$", IMAGE_SCN_CNT_INITIALIZED_DATA | IMAGE_SCN_MEM_READ | IMAGE_SCN_MEM_WRITE),
            #(".xdata", IMAGE_SCN_CNT_INITIALIZED_DATA | IMAGE_SCN_MEM_READ), # exception user data
        ]
        # String-table body as a list of single characters; entries are
        # NUL-terminated. writeStringTable prepends the 4-byte size field.
        self.stringTable = []
        # Entries are [sectionName, ID, addr, size]. Slot 0 is a dummy that
        # writeSymbolTable does not emit.
        self.symbolTable = [None]
        # File offset of the symbol table. writeSymbolTable sets it, and
        # writeCOFFHeader emits it when finish() rewrites the header.
        self.symbolTableOffset = 0

    @staticmethod
    def _asBytes(text):
        """Return *text* as bytes (latin-1, i.e. one byte per character)."""
        if isinstance(text, bytes):
            return text
        return text.encode("latin-1")

    def ensureStringTableEntry(self, text):
        """Return the offset of *text* in the string-table body, adding it if absent.

        The offset is relative to the body, so it does NOT include the 4-byte
        size field that writeStringTable prepends. A COFF name field that
        refers to the string table needs this value plus 4.
        """
        # Prefix a NUL so an entry at offset 0 is matched like every other
        # entry. Its match position then equals the body offset.
        haystack = "\0" + "".join(self.stringTable)
        found = haystack.find("\0%s\0" % text)
        if found != -1:
            return found
        offset = len(self.stringTable)
        self.stringTable.extend("%s\0" % text)
        return offset

    def recordSymbol(self, ID, addr, bResolved, bindingScope, sectionName = None):
        """Append a symbol entry for *ID* at *addr* in *sectionName* (None = undefined).

        *bResolved* and *bindingScope* are currently ignored.
        """
        size = 4 # FIXME
        self.symbolTable.append([sectionName, ID, addr, size])

    def getSymbolIndex(self, ID):
        """Return the list of self.symbolTable positions whose entry has *ID*.

        The dummy slot 0 is skipped. NOTE(review): writeSymbolTable does not
        emit slot 0, so the on-disk index of position i is i - 1. Confirm
        what callers expect before using this for relocations.
        """
        return [i for i, entry in enumerate(self.symbolTable)
                if entry is not None and entry[1] == ID]

    def getSectionIndex(self, name):
        """Return the zero-based position of section *name* (IndexError if unknown)."""
        return [i for i, h in enumerate(self.sectionHeaders) if h[0] == name][0]

    def recordLocation(self, name, position, size):
        """Remember where section *name*'s raw data lives, for finish()."""
        i = self.getSectionIndex(name)
        self.sectionHeaders[i][3] = position
        self.sectionHeaders[i][4] = size

    def writeCOFFHeader(self):
        """Write the 20-byte COFF file header at the current stream position."""
        stream = self.stream
        machine = IMAGE_FILE_MACHINE_I386 # FIXME
        timeDateStamp = int(time.time())
        sectionCount = len(self.sectionHeaders)
        # Slot 0 of symbolTable is a dummy that writeSymbolTable skips.
        symbolCount = len(self.symbolTable) - 1
        stream.write(struct.pack("<H", machine))
        stream.write(struct.pack("<H", sectionCount))
        stream.write(struct.pack("<I", timeDateStamp)) # time_t
        stream.write(struct.pack("<I", self.symbolTableOffset)) # 0 until writeSymbolTable ran
        stream.write(struct.pack("<I", symbolCount))
        stream.write(struct.pack("<H", 0)) # optional header size
        stream.write(struct.pack("<H", IMAGE_FILE_32BIT_MACHINE)) # characteristics

    def writeSectionHeaders(self):
        """Write one section header per entry of self.sectionHeaders."""
        for entry in self.sectionHeaders:
            self.writeSectionHeaderEntry(*entry)

    def writeSectionHeaderEntry(self, name, typ, dummy, positionOfRawData, sizeOfRawData):
        """Write one 40-byte section header; *dummy* is unused."""
        stream = self.stream
        characteristics = typ
        positionOfRelocations = 0 # FIXME: relocations are not written yet
        positionOfLineNumbers = 0
        relocationCount = 0 # FIXME: relocations are not written yet
        lineNumberCount = 0
        # Null-padded; a "/<number>" name would mean an offset into the string table.
        stream.write(struct.pack("<8s", self._asBytes(name)))
        stream.write(struct.pack("<I", 0)) # virtual size
        stream.write(struct.pack("<I", 0)) # virtual address
        stream.write(struct.pack("<I", sizeOfRawData))
        stream.write(struct.pack("<I", positionOfRawData))
        stream.write(struct.pack("<I", positionOfRelocations))
        stream.write(struct.pack("<I", positionOfLineNumbers))
        stream.write(struct.pack("<H", relocationCount))
        stream.write(struct.pack("<H", lineNumberCount))
        stream.write(struct.pack("<I", characteristics))

    def writeSymbolTableEntry(self, sectionName, name, addr, size):
        """Write one 18-byte symbol record for *name* at *addr* in *sectionName*.

        *size* is not written. Names longer than 8 bytes go into the string
        table. That must happen before writeStringTable() runs, and
        writeHeader() calls the two in that order.
        """
        stream = self.stream
        typ = 0 # not a function; 0x20 would be function
        # NOTE(review): every symbol, including undefined ones, is emitted as
        # STATIC. Undefined externals likely need IMAGE_SYM_CLASS_EXTERNAL;
        # revisit once recordSymbol keeps bindingScope.
        storageClass = IMAGE_SYM_CLASS_STATIC
        auxSymbolCount = 0
        if sectionName is None:
            sectionNumber = IMAGE_SYM_UNDEFINED
        else:
            # Section numbers are one-based; 0 is reserved for "undefined".
            sectionNumber = self.getSectionIndex(sectionName) + 1
        encodedName = self._asBytes(name)
        if len(encodedName) <= 8:
            stream.write(struct.pack("<8s", encodedName)) # null-padded short name
        else:
            # Long name: four zero bytes, then the string-table offset counted
            # from the start of the table's 4-byte size field.
            stream.write(struct.pack("<II", 0, self.ensureStringTableEntry(name) + 4))
        stream.write(struct.pack("<I", addr)) # FIXME sometimes something else
        stream.write(struct.pack("<h", sectionNumber)) # signed: ABSOLUTE/DEBUG are negative
        stream.write(struct.pack("<H", typ))
        stream.write(struct.pack("<B", storageClass))
        stream.write(struct.pack("<B", auxSymbolCount))

    def writeSymbolTable(self):
        """Write every real symbol entry and remember where the table starts."""
        self.symbolTableOffset = self.stream.tell()
        for entry in self.symbolTable[1:]:
            self.writeSymbolTableEntry(*entry)

    def writeStringTable(self):
        """Write the string table: a 4-byte total size (itself included), then the body."""
        body = self._asBytes("".join(self.stringTable))
        self.stream.write(struct.pack("<I", len(body) + 4))
        self.stream.write(body)

    def writeRelocationEntry(self, addr, symbolTableIndex, typ):
        """Write one 10-byte relocation record."""
        stream = self.stream
        stream.write(struct.pack("<I", addr)) # RVA+bodyOffset
        stream.write(struct.pack("<I", symbolTableIndex))
        stream.write(struct.pack("<H", typ))

    # TODO: limit the number of sections to 96.
    def writeHeader(self):
        """Write headers, symbol table and string table (headers are fixed up by finish())."""
        self.writeCOFFHeader()
        self.writeSectionHeaders()
        self.writeSymbolTable()
        self.writeStringTable()

    def finish(self):
        """Rewrite the COFF header and section headers in place with final values."""
        self.stream.seek(0)
        self.writeCOFFHeader()
        self.writeSectionHeaders()

    def close(self):
        """Close the stream and move the temporary file onto self.name."""
        self.stream.close()
        # os.rename will not replace an existing file on Windows, so drop any
        # previous output first; it is fine if there is none.
        try:
            os.unlink(self.name)
        except OSError:
            pass
        os.rename(self.tempName, self.name)

    # TODO move this to common base class?
    def _writeSectionBody(self, sectionName, data):
        """Write byte values *data* (ints 0-255) and record them as *sectionName*'s raw data."""
        beginning = self.stream.tell()
        self.stream.write(bytes(bytearray(data)))
        self.recordLocation(sectionName, beginning, self.stream.tell() - beginning)

    def writeData(self, data):
        self._writeSectionBody(".data", data)

    def writeBSS(self, data):
        self._writeSectionBody(".bss", data)

    def writeText(self, data):
        self._writeSectionBody(".text", data)

    def writeRoData(self, data):
        self._writeSectionBody(".rdata", data)

    def write(self, segments):
        """Write the whole object file from *segments*, keyed by interned section names."""
        self.writeHeader()
        self.writeText(segments[symbols.intern(".text")].data)
        # TODO self.writeRelText()
        self.writeData(segments[symbols.intern(".data")].data)
        self.writeBSS(segments[symbols.intern(".bss")].data)
        # TODO self.writeNote()
        self.writeRoData(segments[symbols.intern(".rodata")].data)
        self.finish()
        # etc
        