#!/usr/bin/env python
# I, Danny Milosavljevic, hereby place this file into the public domain

import struct
import symbols
import os
import bisect

# Section header types (sh_type), numbered consecutively from 0.
SHT_NULL = 0
SHT_PROGBITS = 1
SHT_SYMTAB = 2
SHT_STRTAB = 3
SHT_RELA = 4
SHT_HASH = 5
SHT_DYNAMIC = 6
SHT_NOTE = 7
SHT_NOBITS = 8
SHT_REL = 9
SHT_SHLIB = 10
SHT_DYNSYM = 11

# i386 relocation types (stored in the low byte of r_info).
R_386_NONE = 0
R_386_32 = 1
R_386_PC32 = 2
R_386_GOT32 = 3
R_386_PLT32 = 4
R_386_COPY = 5
R_386_GLOB_DAT = 6
R_386_JMP_SLOT = 7
R_386_RELATIVE = 8

# Section flags (sh_flags): one bit per flag, starting at bit 0.
SHF_WRITE = 0x1
SHF_ALLOC = 0x2
SHF_EXECINSTR = 0x4
SHF_DUMMY = 0x8  # placeholder name for bit 3, which has no defined flag
SHF_MERGE = 0x10
SHF_STRINGS = 0x20
SHF_INFO_LINK = 0x40
SHF_LINK_ORDER = 0x80
SHF_OS_NONCONFORMING = 0x100

class ELFWriter(object):
    """Write an ELF relocatable object file, one section at a time.

    Typical flow: construct with the output path, register symbols and
    relocations with recordSymbol()/recordFixup(), call write(segments),
    then close().  Bytes go to "<name>.tmp" and close() renames that file
    to ``name``.

    The class attributes describe 32-bit little-endian i386 ELF.  Branches
    testing ``addrFormat == "Q"`` pick the 64-bit record layouts; presumably
    a subclass overriding the format attributes relies on them (not visible
    here).
    """
    defaultObjectFileNameFormat = "%s.o"
    # e_ident: "\x7fELF", ELFCLASS32, ELFDATA2LSB, EV_CURRENT, then zero padding.
    ELFMagic = "\x7f\x45\x4c\x46\x01\x01\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00"
    offFormat = "I"        # file offsets (Elf32_Off)
    addrFormat = "I"       # addresses / symbol values (Elf32_Addr)
    bigFormat = "I"        # word-sized flags, sizes and r_info (Elf32_Word)
    # r_addend is a signed Elf32_Sword; the former "I" rejected negative addends.
    bigSignedFormat = "i"
    typeMachineVersion = (1, 3, 1)  # e_type ET_REL, e_machine EM_386, e_version EV_CURRENT
    # st_info byte per binding: (STB_* << 4) | STT_*; "sections" is STB_LOCAL|STT_SECTION.
    knownBindings = {
        symbols.intern("global"): 0x10,
        symbols.intern("local"): 0x00,
        symbols.intern("weak"): 0x20,
        symbols.intern("sections"): 3
    }

    def __init__(self, name):
        """Open "<name>.tmp" for writing and set up the fixed section list.

        name -- final output path; the file is renamed to it by close().
        """
        self.name = name
        self.tempName = "%s.tmp" % name
        self.stream = open(self.tempName, "wb")
        # Section header rows: [name, sh_type, name of the sh_link section,
        # file offset, size].  recordLocation() fills in offset and size.
        self.sectionHeaders = [
            ["", SHT_NULL, "", 0, 0],
            [".text", SHT_PROGBITS, "", 0, 0],
            # Replaced by .rela.text in recordFixup() when an addend is needed.
            [".rel.text", SHT_REL, ".symtab", 0, 0],
            [".data", SHT_PROGBITS, "", 0, 0],
            [".bss", SHT_NOBITS, "", 0, 0],
            #[".note", SHT_NOTE, "", 0, 0],
            [".rodata", SHT_PROGBITS, "", 0, 0],
            #[".comment", SHT_PROGBITS, "", 0, 0],
            [".shstrtab", SHT_STRTAB, "", 0, 0],
            [".symtab", SHT_SYMTAB, ".strtab", 0, 0], # info has the first global
            [".strtab", SHT_STRTAB, "", 0, 0],
        ]
        # Relocations: [r_offset, symbol name, addend, R_386_* type].  The
        # symbol *name* is kept and resolved to an index only when the table
        # is written, because recordSymbol() inserts locals in front of the
        # globals and would invalidate indices captured earlier.
        self.fixups = []
        self.sectionHeaderStringTable = ["\0"]
        for h in self.sectionHeaders:
            self.ensureSectionHeaderStringTableEntry(h[0])
        self.stringTable = []
        self.symbolTable = []
        self.ensureStringTableEntry("")
        self.firstGlobalSymbolIndex = 0
        self.recordSymbol("", 0, True, symbols.intern("local")) # dummy entry for "undefined"
        # TODO move that to addSection or whatever.
        self.recordSymbol(".text", 0, True, symbols.intern("sections"), ".text")
        self.recordSymbol(".data", 0, True, symbols.intern("sections"), ".data")
        self.recordSymbol(".bss", 0, True, symbols.intern("sections"), ".bss")

    # --- low-level helpers -------------------------------------------------

    def _is64(self):
        """Return True when the format attributes select 64-bit layouts."""
        return self.__class__.addrFormat == "Q"

    def _pack(self, fmt, *values):
        """Write ``values`` packed little-endian, matching ELFDATA2LSB in e_ident."""
        self.stream.write(struct.pack("<" + fmt, *values))

    def _align(self, alignment):
        """Pad the stream with NUL bytes up to a multiple of ``alignment``."""
        remainder = self.stream.tell() % alignment
        if remainder:
            self.stream.write("\0" * (alignment - remainder))

    def _wordAlignment(self):
        """Alignment used for word-sized tables: 8 for 64-bit layouts, else 4."""
        return 8 if self._is64() else 4

    def _writeBytes(self, data):
        """Write an iterable of byte values (ints 0..255) to the stream."""
        self.stream.write("".join([chr(item) for item in data]))

    @staticmethod
    def _ensureEntry(table, text):
        """Return the offset of ``text`` in a NUL-separated string table.

        ``table`` is a list of single characters.  Offset 0 always holds the
        empty string.  A missing entry is appended with its terminating NUL.
        """
        if not table:
            table.append("\0")
        if not text:
            return 0
        i = "".join(table).find("\0%s\0" % text)
        if i != -1:
            return i + 1  # skip the preceding entry's terminator
        i = len(table)
        table.extend("%s\0" % text)
        return i

    def _relocationInfo(self, symbolIndex, typ):
        """Build r_info: ELF32_R_INFO shifts the symbol by 8, ELF64_R_INFO by 32."""
        return (symbolIndex << (32 if self._is64() else 8)) | typ

    def _resolvedFixups(self):
        """Yield fixups with the symbol name replaced by its current index."""
        for addr, name, offset, typ in self.fixups:
            yield addr, self.getSymbolIndex(name), offset, typ

    def _sectionHeaderOffsetFieldPosition(self):
        """File offset of e_shoff (header assumed at offset 0)."""
        # e_ident(16) + e_type(2) + e_machine(2) + e_version(4) + e_entry + e_phoff
        cls = self.__class__
        return 24 + struct.calcsize("<" + cls.addrFormat) + struct.calcsize("<" + cls.offFormat)

    # --- string and symbol tables -------------------------------------------

    def ensureStringTableEntry(self, text):
        """Return the .strtab offset of ``text``, adding it if necessary."""
        return self._ensureEntry(self.stringTable, text)

    def recordSymbol(self, ID, addr, bResolved, bindingScope, sectionName = None):
        """Register a symbol for .symtab.

        ID -- symbol name; addr -- st_value; bResolved -- currently unused;
        bindingScope -- key of knownBindings; sectionName -- name of the
        defining section (st_shndx is 0 when None or unknown).
        Non-global symbols are inserted before all globals, as ELF requires.
        """
        size = 4 # FIXME
        entry = [bindingScope, sectionName, ID, addr, size]
        if bindingScope != symbols.intern("global"):
            self.symbolTable.insert(self.firstGlobalSymbolIndex, entry)
            self.firstGlobalSymbolIndex += 1
        else:
            self.symbolTable.append(entry)

    def getSymbolIndex(self, ID):
        """Return the .symtab index of the first symbol named ``ID``.

        Raises IndexError (the type the original lookup raised) if unknown.
        """
        # FIXME section name.
        for i, entry in enumerate(self.symbolTable):
            if entry[2] == ID:
                return i
        raise IndexError("unknown symbol %r" % (ID,))

    def ensureSectionHeaderStringTableEntry(self, text):
        """Return the .shstrtab offset of ``text``, adding it if necessary."""
        return self._ensureEntry(self.sectionHeaderStringTable, text)

    # --- headers --------------------------------------------------------------

    def writeELFHeader(self):
        """Write the ELF header; e_shoff is a placeholder patched by finish()."""
        cls = self.__class__
        is64 = self._is64()
        headerSize = 64 if is64 else 52
        o = self.stream.tell()
        assert(len(cls.ELFMagic) == 16)
        self.stream.write(cls.ELFMagic)
        self._pack("HHI", *cls.typeMachineVersion) # type, machine, version
        self._pack(cls.addrFormat, 0) # entry point
        self._pack(cls.offFormat, 0) # program header table file offset
        self._pack(cls.offFormat, 0) # section header file offset (patched later)
        self._pack("I", 0) # processor flags
        self._pack("H", headerSize) # ELF header size in bytes
        self._pack("H", 0) # program header size in bytes
        self._pack("H", 0) # program header table entry count
        self._pack("H", 64 if is64 else 40) # section header entry size
        self._pack("H", len(self.sectionHeaders)) # section header entry count
        self._pack("H", self.getSectionIndex(".shstrtab"))
        assert(self.stream.tell() - o == headerSize)

    def getSectionIndex(self, name):
        """Return the section header index for ``name``, or None if absent."""
        for i, h in enumerate(self.sectionHeaders):
            if h[0] == name:
                return i
        return None

    def writeSectionHeaders(self):
        """Write one section header per entry of self.sectionHeaders."""
        for entry in self.sectionHeaders:
            self.writeSectionHeaderEntry(*entry)

    def writeSectionHeaderEntry(self, name, typ, link, position, size):
        """Write a single section header (Elf32_Shdr or Elf64_Shdr layout)."""
        cls = self.__class__
        is64 = self._is64()
        entrySize = 0  # sh_entsize is 0 unless the section is a table of fixed-size records
        alignment = 0
        info = 0
        flags = 0
        if name == ".text":
            alignment = 0x10
            # TODO configurable in the actual header?
            flags = SHF_ALLOC | SHF_EXECINSTR
        elif name in (".rel.text", ".rela.text"):
            # Elf32_Rel is 8 bytes and Elf32_Rela 12; the 64-bit records are twice that.
            entrySize = (0xc if name.startswith(".rela.") else 8) * (2 if is64 else 1)
            alignment = self._wordAlignment()
            info = self.getSectionIndex(".text") # TODO how to find out which?
        elif name in (".data", ".bss"):
            alignment = self._wordAlignment()
            flags = SHF_ALLOC | SHF_WRITE
        elif name == ".rodata":
            # Must be allocated, otherwise linkers treat it as non-loaded data.
            flags = SHF_ALLOC
        elif name == ".symtab":
            entrySize = 24 if is64 else 16
            alignment = self._wordAlignment()
            info = self.firstGlobalSymbolIndex  # one past the last local symbol
        self._pack("I", self.ensureSectionHeaderStringTableEntry(name)) # sh_name
        self._pack("I", typ)
        self._pack(cls.bigFormat, flags)
        self._pack(cls.addrFormat, 0) # virtual addr
        self._pack(cls.offFormat, position) # file offset
        self._pack(cls.bigFormat, size) # size in bytes
        self._pack("I", self.getSectionIndex(link)) # section link
        self._pack("I", info)
        self._pack(cls.bigFormat, alignment)
        self._pack(cls.bigFormat, entrySize)

    def writeSymbolTableEntry(self, bindingScope, sectionName, name, addr, size):
        """Write one symbol (Elf32_Sym or Elf64_Sym layout)."""
        cls = self.__class__
        is64 = self._is64()
        o = self.stream.tell()
        nameIndex = self.ensureStringTableEntry(name)
        size = 0  # the recorded size is a FIXME placeholder, so emit "unknown"
        info = cls.knownBindings[bindingScope] # TODO other types than "NOTYPE".
        other = 0
        # st_shndx: index of the defining section, 0 (SHN_UNDEF) if none/unknown.
        shndx = self.getSectionIndex(sectionName) or 0
        self._pack("I", nameIndex) # string table index for name
        if is64:
            self._pack("BBH", info, other, shndx)
        self._pack(cls.addrFormat, addr) # value (addr)
        self._pack(cls.bigFormat, size) # size
        if not is64:
            self._pack("BBH", info, other, shndx)
        assert(self.stream.tell() - o == (24 if is64 else 16))

    def getSectionSymbolIndex(self, name):
        """Return the .symtab index of the STT_SECTION symbol for ``name``, or 0."""
        sections = symbols.intern("sections")
        for i, entry in enumerate(self.symbolTable):
            # entry = [bindingScope, sectionName, ID, addr, size]
            if entry[1] == name and entry[0] == sections:
                return i
        return 0

    def recordLocation(self, name, position, size):
        """Store the file offset and size of section ``name``.

        Raises KeyError if no such section header exists.
        """
        i = self.getSectionIndex(name)
        if i is None:
            raise KeyError("no section header named %r" % (name,))
        self.sectionHeaders[i][3] = position
        self.sectionHeaders[i][4] = size

    # --- section contents ----------------------------------------------------

    def writeNull(self):
        """Record the NULL section, whose header must be all zeros."""
        self.recordLocation("", 0, 0)

    def writeHeader(self):
        """Write the ELF header and set up the NULL section entry."""
        self.writeELFHeader()
        self.writeNull()

    def writeText(self, data):
        """Write .text from an iterable of byte values."""
        beginning = self.stream.tell()
        self._writeBytes(data)
        self.recordLocation(".text", beginning, self.stream.tell() - beginning)

    def writeRelTextEntry(self, addr, symbolIndex, offset, typ):
        """Write one REL entry; REL cannot carry an addend, so offset must be 0."""
        assert(offset == 0)
        cls = self.__class__
        self._pack(cls.addrFormat, addr)
        self._pack(cls.bigFormat, self._relocationInfo(symbolIndex, typ))

    def writeRelaTextEntry(self, addr, symbolIndex, offset, typ):
        """Write one RELA entry with ``offset`` as the signed addend."""
        cls = self.__class__
        self._pack(cls.addrFormat, addr)
        self._pack(cls.bigFormat, self._relocationInfo(symbolIndex, typ))
        self._pack(cls.bigSignedFormat, offset)

    def writeRelaText(self):
        """Write .rela.text containing *all* fixups.

        Once RELA is selected the REL section no longer exists, so
        zero-addend fixups have to be emitted here as well.
        """
        self._align(self._wordAlignment())
        beginning = self.stream.tell()
        for relocation in self._resolvedFixups():
            self.writeRelaTextEntry(*relocation)
        self.recordLocation(".rela.text", beginning, self.stream.tell() - beginning)

    def writeRelText(self):
        """Write .rel.text containing the fixups with a zero addend."""
        self._align(self._wordAlignment())
        beginning = self.stream.tell()
        for relocation in self._resolvedFixups():
            if relocation[2] == 0:
                self.writeRelTextEntry(*relocation)
            # non-zero addends switch the file to .rela.text (see recordFixup)
        self.recordLocation(".rel.text", beginning, self.stream.tell() - beginning)

    def writeData(self, data):
        """Write .data from an iterable of byte values."""
        beginning = self.stream.tell()
        self._writeBytes(data)
        self.recordLocation(".data", beginning, self.stream.tell() - beginning)

    def writeBSS(self, data):
        """Record .bss with size len(data).

        SHT_NOBITS occupies no space in the file, so no bytes are written.
        """
        self.recordLocation(".bss", self.stream.tell(), len(data))

    def writeNote(self):
        """Placeholder: records an empty .note if that section is enabled."""
        if self.getSectionIndex(".note") is not None:
            self.recordLocation(".note", self.stream.tell(), 0)

    def writeRoData(self, data):
        """Write .rodata from an iterable of byte values."""
        beginning = self.stream.tell()
        self._writeBytes(data)
        self.recordLocation(".rodata", beginning, self.stream.tell() - beginning)

    def writeComment(self):
        """Placeholder: records an empty .comment if that section is enabled."""
        if self.getSectionIndex(".comment") is not None:
            self.recordLocation(".comment", self.stream.tell(), 0)

    def writeSectionHeaderStringTable(self):
        """Write .shstrtab, first registering every current section name."""
        # Headers may have been renamed (e.g. .rel.text -> .rela.text), and
        # names added after this table is on disk would get a bogus sh_name.
        for h in self.sectionHeaders:
            self.ensureSectionHeaderStringTableEntry(h[0])
        beginning = self.stream.tell()
        self.stream.write("".join(self.sectionHeaderStringTable))
        self.recordLocation(".shstrtab", beginning, self.stream.tell() - beginning)

    def writeSymbolTable(self):
        """Write .symtab; this also fills .strtab with the symbol names."""
        self._align(self._wordAlignment())
        beginning = self.stream.tell()
        for entry in self.symbolTable:
            self.writeSymbolTableEntry(*entry)
        self.recordLocation(".symtab", beginning, self.stream.tell() - beginning)

    def writeStringTable(self):
        """Write .strtab (must follow writeSymbolTable, which populates it)."""
        beginning = self.stream.tell()
        self.stream.write("".join(self.stringTable))
        self.recordLocation(".strtab", beginning, self.stream.tell() - beginning)

    def finish(self):
        """Write the trailing tables and section headers, then patch e_shoff."""
        #self.writeComment()
        self.writeSectionHeaderStringTable()
        self.writeSymbolTable()
        self.writeStringTable()
        self._align(self._wordAlignment())
        sectionHeaderOffset = self.stream.tell()
        self.writeSectionHeaders()
        self.stream.seek(self._sectionHeaderOffsetFieldPosition())
        self._pack(self.__class__.offFormat, sectionHeaderOffset)

    def close(self):
        """Close the stream and move the temporary file to its final name."""
        self.stream.close()
        os.rename(self.tempName, self.name)

    def recordFixup(self, name, offset, targetAddr, targetSize, method):
        """Queue a relocation against symbol ``name``.

        offset -- addend; a non-zero value switches .rel.text to .rela.text.
        targetAddr -- written as r_offset.  targetSize -- currently unused.
        method -- truthy selects R_386_PC32, falsy R_386_32.
        The symbol may be recorded later; it is resolved when written.
        """
        if offset != 0 and self.getSectionIndex(".rela.text") is None: # we need RELA, so drop REL.
            i = self.getSectionIndex(".rel.text")
            self.sectionHeaders[i] = [".rela.text", SHT_RELA, ".symtab", 0, 0]
            self.ensureSectionHeaderStringTableEntry(".rela.text")
        bRelative = method # TODO more
        self.fixups.append([targetAddr, name, offset, R_386_PC32 if bRelative else R_386_32])

    def write(self, segments):
        """Write the whole object file from ``segments`` (keyed by interned section name)."""
        def data(sectionName):
            return segments[symbols.intern(sectionName)].data
        self.writeHeader()
        self.writeText(data(".text"))
        # Exactly one of .rel.text / .rela.text exists (see recordFixup).
        if self.getSectionIndex(".rela.text") is not None:
            self.writeRelaText()
        else:
            self.writeRelText()
        self.writeData(data(".data"))
        self.writeBSS(data(".bss"))
        #self.writeNote()
        self.writeRoData(data(".rodata"))
        self.finish()
