#!/usr/bin/env python
# I, Danny Milosavljevic, hereby place this file into the public domain.

# TODO volume label

import struct
from entries import DiskEntry

blocks_frontier = 40 * 21 # upper bound on the number of blocks in an image (40 tracks x at most 21 sectors); used to bound chain walks.
sector_size = 256 # bytes
usable_block_size = 254 # bytes
sector_counts = [0] + [21]*17 + [19]*(24-18+1) + [18]*(30-25+1) + [17]*(40-31+1) # per track, track 0 is dummy.
track_block = []
total_count = 0
for count in sector_counts:
    track_block.append(total_count)
    total_count += count
def block_from_TS(track, sector):
    # track is 1-based, sector is 0-based.
    assert(track >= 1)
    return(track_block[track] + sector)
assert(block_from_TS(1, 0) == 0)
assert(block_from_TS(2, 0) == 21)
assert(block_from_TS(3, 0) == 21 * 2)

class BlockAccessor(object): # the optional per-sector error info appended to some images is not handled here since this layer is too low-level anyway.
    def __init__(self, stream):
        self.stream = stream
        self.block_size = sector_size
    def read(self, block_number):
        offset = self.block_size * block_number
        self.stream.seek(offset)
        data = self.stream.read(self.block_size)
        return(data)
    def write(self, block_number, data):
        offset = self.block_size * block_number
        self.stream.seek(offset)
        assert(len(data) == self.block_size)
        self.stream.write(data)
def get_node_size(data):
    T, S = parse_TS(data)
    if T == 0:
        return(S + 1 - 2)
    else:
        return(sector_size - 2)
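# Worked example: the final sector of a chain usually carries the link T=0, S=$FF,
# which by the formula above yields $FF + 1 - 2 = 254 payload bytes, i.e. the whole
# data area of that sector.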
def get_next_node(data):
    # TODO limit to #blocks max in order to be fault-tolerant.
    T, S = parse_TS(data)
    if T == 0:
        return(-1)
    else:
        return(block_from_TS(T, S))
def get_chain_and_size(block_accessor, block_number): # size is in bytes.
    # Each sector starts with a Track/Sector link to the next sector. Track=0 marks the last
    # sector of the chain; its sector byte then holds the offset of the last used byte (usually $FF, i.e. the whole sector is used).
    chain = []
    chain_set = set()
    total_size = 0
    while block_number != -1 and len(chain) < blocks_frontier:
        if block_number in chain_set:
            print("chain is corrupt") # FIXME proper error handling
            break
        chain.append(block_number)
        chain_set.add(block_number)
        data = block_accessor.read(block_number)
        mini_size = get_node_size(data)
        total_size += mini_size
        block_number = get_next_node(data)
    return(chain, total_size)
class FileStream(object): # caches the current block so that repeated small reads do not re-read it from the BlockAccessor.
    def __init__(self, block_accessor, first_block):
        self.block_accessor = block_accessor
        self.chain, self.size = get_chain_and_size(block_accessor, first_block)
        self.block_number = first_block
        self.position = 0
        self.block_data = b""
        self.seek(0)
    def seek(self, position, whence = 0):
        assert(whence == 0)
        self.position = position
        self.fetch_block()
    def fetch_block(self):
        i = self.position // usable_block_size
        if i < len(self.chain):
            self.block_number = self.chain[i]
            self.block_data = self.block_accessor.read(self.block_number)[2:] # only the usable part.
        else:
            self.block_number = -1
            self.block_data = b""
    def read(self, count):
        result = []
        assert(count > 0)
        if not (self.position + count <= self.size):
            count = self.size - self.position
        while count > 0:
            # recompute the offset within the current block on every iteration (it drops back to 0 after a block boundary is crossed).
            data_offset = self.position % usable_block_size
            local_count = min(count, usable_block_size - data_offset)
            local_data = self.block_data[data_offset : data_offset + local_count]
            assert(len(local_data) == local_count)
            result.append(local_data)
            count -= local_count
            old_position = self.position
            self.position += local_count
            if self.position // usable_block_size != old_position // usable_block_size:
                self.fetch_block()
        return(b"".join(result))
class FileAccessor(object):
    def __init__(self, block_accessor):
        self.block_accessor = block_accessor
    def get_size(self, first_block):
        chain, size = get_chain_and_size(self.block_accessor, first_block)
        return(size)
    def stat(self, directory_entry, first_block):
        get_readable_type(directory_entry["file_type"])
        self.get_size(first_block)
        pass # TODO
    def open(self, first_block, directory_entry = None):
        return(FileStream(self.block_accessor, first_block))
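# Example usage (illustrative sketch, not part of the original flow; it assumes a
# directory entry "e" as yielded by parse_directory below, exposing file_data_T and
# file_data_S as attributes):
#   prg_stream = file_accessor.open(block_from_TS(e.file_data_T, e.file_data_S))
#   load_address = struct.unpack("<H", prg_stream.read(2))[0] # a PRG file starts with its 2-byte load address.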
# http://www.unusedino.de/ec64/technical/formats/d64.html
# size 174848. Then optional error info, 1 byte per sector.
# 35 tracks (1..35), each having a variable number of sectors.
# track 1 is outermost, largest.
# track 1 to 17: 21 sectors.
# track 18 to 24: 19 sectors.
# track 25 to 30: 18 sectors.
# track 31 to 40 (or 35): 17 sectors.

# sector index starts at 0.
# track 18, sector 0: BAM area
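# Sanity check: the 35-track geometry above accounts for exactly the 174848 bytes
# of a standard error-free .d64 image.
assert(sum(sector_counts[1:36]) * sector_size == 174848)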
"""
$00..$01 T/S for directory (junk)
$02 DOS version, default = $41. $00 OK. Others invalid (soft write protection).
$03 unused
$04..$8F BAM entries for each track, 4 bytes per track.
$90..$9F disk name
$A0..$A1 filled with $A0
$A2..$A3 Disk ID
$A4 filled with $A0
$A5..$A6 DOS type, default "2A"
$A7..$AA filled with $A0
$AB unused
$AC..$BF Dolphin DOS track 36..40 BAM entries for 40-track floppies
$C0..$D3 SPEED DOS track 36..40 BAM entries for 40-track floppies

BAM entry:
bit=1=free.
"""
def parse_BAM(data):
    r = {}
    r["next_T"], r["next_S"], r["DOS_version"], r["unused"] = struct.unpack("BBBB", data[:4])
    r["disk_name"], r["$A0"], r["$A1"], r["disk_ID"], r["$A4"], r["DOS_type"] = struct.unpack("16sBB2sB2s", data[0x90:0xA7])
    r["BAM"] = []
    for i in range(0x04, 0x90, 4):
        r["BAM"].append(struct.unpack("<I", data[i:i+4])[0]) # per track: low byte is the free-sector count, the upper 3 bytes are the free-sector bitmap.
    # TODO Dolphin, SPEED.
    return(r)
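# Example (an illustrative sketch, not part of the original tool): the drive's
# "BLOCKS FREE" figure is the sum of the per-track free-sector counts (the low byte
# of each BAM entry), traditionally excluding the directory track 18.
def count_blocks_free(bam):
    return(sum(entry & 0xFF for track, entry in enumerate(bam["BAM"], 1) if track != 18))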
def parse_TS(data):
    T, S = struct.unpack("BB", data[:2])
    return(T, S)

# track 18, sector 1: directory track.

"""
directory entry: $20 bytes.
$00..$01 next T/S
$02 file type (0=deleted) and flags (locked, closed)
$03..$04 file data T/S
$05..$14 16 character file name (PETSCII, $A0 pad)
$15..$16 T/S first side-sector block (REL file only)
$17 REL file record length (max. value 254)
$18..$1D unused (except by GEOS)
$1E..$1F file size in sectors (#sectors * 254 approximates the file size in bytes).
"""
def get_readable_type(value):
    # Mask off the "locked" flag (bit 6) but keep the "closed" flag (bit 7), so that e.g. a locked PRG ($C2) still reads as "PRG".
    return {
        0: "deleted",
        0x80: "DEL",
        0x81: "SEQ",
        0x82: "PRG",
        0x83: "USR",
        0x84: "REL",
    }.get(value & 0x87) or hex(value)
def parse_directory_entry(data):
    file_type, file_data_T, file_data_S, file_name, REL_side_sector_T, REL_side_sector_S, REL_record_length, Dx18, Dx19, Dx1A, Dx1B, Dx1C, Dx1D, file_size_sectors = struct.unpack("<BBB16sBBBBBBBBBH", data[:0x1E])
    return(DiskEntry(file_type, file_data_T, file_data_S, file_name, REL_side_sector_T, REL_side_sector_S, REL_record_length, Dx18, Dx19, Dx1A, Dx1B, Dx1C, Dx1D, file_size_sectors))
def parse_pure_directory(directory_stream):
    i = 0
    while True:
        if (i % 8) != 0: # not a sector boundary, so read the junk (i.e. T/S) ourselves.
            directory_stream.read(2) # junk
        data = directory_stream.read(30)
        if data == "": # EOF
            break
        entry = parse_directory_entry(data)
        yield(entry)
        i += 1
def parse_directory(directory_stream):
    for entry in parse_pure_directory(directory_stream):
        if entry.file_type != 0:
            yield(entry)
if __name__ == "__main__":
    block_accessor = BlockAccessor(open("disk-23.d64", "rb"))
    file_accessor = FileAccessor(block_accessor)
    directory_stream = file_accessor.open(block_from_TS(18, 1))
    for entry in parse_directory(directory_stream):
        print(entry)
