#!/usr/bin/env python

import sys
import symbols
try:
	import io
except ImportError:
	import StringIO as io

class ParseError(Exception):
	"""Raised by Scanner.raise_error when the input does not match the expected syntax."""

def maybe_ord(c):
	"""Return *c* unchanged when it is already an int, else ord(c).

	Indexing bytes yields an int on Python 3 but a 1-char str on Python 2;
	this normalizes both to an int.
	"""
	if isinstance(c, int):
		return c
	return ord(c)
def maybe_chr(c):
	"""Return chr(c) when *c* is an int, otherwise return *c* unchanged."""
	if not isinstance(c, int):
		return c
	return chr(c)
def whitespace_P(c):
	"""Return True if *c* is a PDF white-space byte or "%".

	"%" is included because a comment runs to the end of the line and the
	scanner skips it exactly like white space.
	"""
	return c in (
		b"\x00", # NUL
		b"\x09", # HT
		b"\x0A", # LF
		b"\x0C", # FF
		b"\x0D", # CR
		b"\x20", # SP
		b"%", # comment start
	)
def number_P(c):
	"""Return True if *c* can start a PDF numeric object.

	PDF numbers may begin with a digit, a sign or a decimal point
	(e.g. "34.5", "+17", "-.002", ".5"), so "." is deliberately accepted.

	The empty byte string (end of input) is rejected explicitly: b"" is a
	substring of every bytes object, so a bare ``in`` test would accept it
	and make end-of-input look like the start of a number.
	"""
	return c != b"" and c in b"0123456789+-."
def list_P(c):
	"""Return True if *c* opens an array ("[")."""
	return b"[" == c
def dictionary_or_string_P(c):
	"""Return True if *c* may open a dictionary ("<<"), a hex string ("<") or a literal string ("(")."""
	return c in (b"<", b"(")
def string_P(c):
	"""Return True if *c* may open a string: "(" (literal) or "<" (hex).

	Not unique: "<" also opens a dictionary ("<<").
	"""
	return c in (b"(", b"<")
def boolean_P(c):
	"""Return True if *c* may start the keyword "true" or "false".

	Not unique on its own: only the first byte is inspected.
	"""
	return c in (b"t", b"f")
def name_P(c):
	"""Return True if *c* starts a name object ("/")."""
	return b"/" == c
def n_P(c):
	"""Return True if *c* may start the keyword "null"."""
	return b"n" == c
# Hex digit lookup: the index of a lower-case digit is its value (0-15);
# upper-case letters are present so membership tests accept them too.
hexdigits = b"0123456789" b"abcdef" b"ABCDEF"
def lower2(c):
	"""Return the ASCII lower-case form of the single byte *c*.

	Bytes outside "A".."Z" are returned unchanged.
	"""
	if not (b"A" <= c <= b"Z"):
		return c
	offset = maybe_ord(c) - ord(b"A")
	return byr(ord(b"a") + offset)
def byr(c):
	"""Return a one-byte string holding the byte value *c* (0-255).

	On Python 3 ``bytes([c])`` is that byte; on Python 2 it is the text
	"[c]" (a str), in which case chr(c) is the right answer instead.
	"""
	as_bytes = bytes([c])
	if not isinstance(as_bytes, str):
		return(as_bytes)
	return(chr(c))
class Scanner(object):
	"""Byte-at-a-time recursive-descent scanner for PDF object syntax.

	One byte of lookahead is kept in ``self.input`` (``b""`` once all input is
	exhausted) and its offset within the current file in ``self.position``.
	When the current file runs out, reading continues with the next file
	queued in ``self.next_input_files``.  Most ``parse_*`` methods consume
	their construct plus any trailing white space/comments and report
	malformed input through raise_error(), which raises ParseError.

	Typical use: start_parsing(f), then consume() once to load the first
	lookahead byte (seek() does this itself), then the parse_* methods.
	"""
	# Single-character escapes allowed after a backslash in a literal string.
	_STRING_ESCAPES = {
		b"n": b"\x0A",
		b"r": b"\x0D",
		b"t": b"\x09",
		b"b": b"\x08",
		b"f": b"\x0C",
		b"(": b"(",
		b")": b")",
		b"\\": b"\\",
	}
	def __init__(self):
		self.input_file = None # file currently being read; read(1) must return bytes
		self.next_input_files = [] # files to continue with once input_file is exhausted
		self.position = -1 # offset of self.input within the current file
		self.input = None # one byte of lookahead; b"" at end of all input
	def start_parsing(self, input_file):
		"""Read from *input_file*; call consume() afterwards to load the first lookahead byte."""
		self.input_file = input_file
	def seek(self, position, whence = 0):
		"""Seek the current file (same arguments as file.seek) and reload the lookahead byte there."""
		self.input_file.seek(position, whence)
		self.input = None
		# consume() below increments position as it reads, leaving it equal to the lookahead's offset.
		self.position = self.input_file.tell() - 1
		self.consume()
	def raise_error(self, expected_text, got_text = None):
		"""Raise ParseError naming what was expected, what was found (default: the lookahead) and where."""
		raise ParseError("error: expected %r but got %r near offset %d" % (expected_text, got_text or self.input, self.position))
	def consume(self, expected_text = None):
		"""Advance past one byte, or past len(expected_text) bytes, and return the last byte passed.

		When *expected_text* is given, every byte passed must match it; otherwise
		ParseError is raised (also when the input ends early).
		"""
		count = 1 if expected_text is None else len(expected_text)
		old_input = self.input # returned as-is if expected_text is empty
		for i in range(count):
			old_input = self.input
			# Exhausted input (b"" or never loaded) can never match; test it first so
			# it is reported as ParseError instead of a TypeError from ord(b"").
			if expected_text is not None and (not old_input or maybe_ord(expected_text[i]) != maybe_ord(old_input)):
				self.raise_error(expected_text)
			while self.input_file is not None:
				self.input = self.input_file.read(1)
				if len(self.input) > 0:
					assert(isinstance(self.input, bytes))
					break
				self.input_file = self.goto_next_file()
			self.position += 1
		return(old_input)
	def goto_next_file(self):
		"""Make the next queued file current and return it, or return None if none is queued."""
		if len(self.next_input_files) == 0:
			return(None)
		self.input_file = self.next_input_files[0]
		# consume() increments position after reading, so -1 puts the new file's
		# first byte at offset 0, the same convention as for the first file.
		self.position = -1
		# Slice (rather than pop) so a list the caller handed in is not mutated.
		self.next_input_files = self.next_input_files[1:]
		return(self.input_file)
	def skip(self, count):
		"""Discard *count* bytes starting with the current lookahead byte, then reload the lookahead.

		The first count-1 bytes are read directly from input_file, so that part
		does not continue into the next queued file.
		"""
		if count == 0:
			return
		s = self.input_file.read(count - 1)
		assert(len(s) == count - 1)
		self.position += count - 1
		self.consume()
	def parse_list(self):
		"""Parse a ``[ ... ]`` array and return its elements as a Python list.

		If parse_ref_expression() leaves ``ID generation`` as two separate items
		followed by a bare ``R``, they are folded into one item via
		ensure_object_loaded(ID, generation).  That method is not defined in this
		class -- presumably supplied by a subclass.
		"""
		result = []
		self.consume(b"[")
		self.parse_optional_whitespace()
		while self.input != b"" and self.input != b"]":
			if self.input == b"R": # whoops. See anti-recursion comment in parse_ref_expression.
				self.consume(b"R")
				self.parse_optional_whitespace()
				ID = result[-2]
				generation = result[-1]
				del result[-2:]
				result.append(self.ensure_object_loaded(ID, generation))
				if self.input == b"" or self.input == b"]":
					break
			result += self.parse_ref_expression()
		self.consume(b"]")
		self.parse_optional_whitespace()
		return(result)
	def parse_dictionary(self, skip = 0):
		"""Parse a ``<< /Key value ... >>`` dictionary into a dict keyed by parse_name() results.

		*skip* is how many of the two opening "<" the caller already consumed (0 or 1).
		"""
		if skip <= 0:
			self.consume(b"<")
		if skip <= 1:
			self.consume(b"<")
		assert(skip < 2)
		self.parse_optional_whitespace()
		result = {}
		while self.input != b">":
			key = self.parse_name()
			value = self.parse_ref_expression()
			assert(len(value) == 1) # can be a two-element list on misdetection!
			result[key] = value[0]
		self.consume(b">>")
		self.parse_optional_whitespace()
		return(result)
	def parse_actual_EOL(self):
		"""Consume a CR, plus an immediately following LF if present."""
		self.consume(b"\x0D")
		if self.input == b"\x0A":
			self.consume(b"\x0A")
	def parse_optional_whitespace(self):
		"""Skip any run of white space and ``%`` comments (which may be empty)."""
		while self.input != b"" and whitespace_P(self.input):
			if self.input == b"%":
				# A comment runs up to the end of the line; CR alone is also a valid
				# end-of-line, so stop at either.  The outer loop then eats the EOL.
				while self.input != b"" and self.input != b"\x0A" and self.input != b"\x0D":
					self.consume()
			else:
				self.consume()
	def parse_whitespace(self):
		"""Like parse_optional_whitespace(), but at least one white-space byte (or comment) is required."""
		if not whitespace_P(self.input):
			self.raise_error("<whitespace>")
		self.parse_optional_whitespace()
	def parse_EOL(self):
		"""Skip white space; CR/LF count as white space, so no explicit EOL check is made."""
		self.parse_optional_whitespace()
	def parse_name(self):
		"""Parse a ``/Name`` (decoding ``#xx`` hex escapes) and return symbols.intern() of its bytes.

		The returned bytes include the leading "/".
		"""
		if self.input != b"/":
			self.raise_error("/")
		result = io.BytesIO()
		result.write(self.consume(b"/"))
		while self.input != b"" and self.input not in b"()<>[]{}/%" and not whitespace_P(self.input):
			if self.input == b"#":
				self.consume()
				d1 = self._hex_digit_value(self.consume())
				d2 = self._hex_digit_value(self.consume()) if self.input != b">" and self.input != b"" else 0
				result.write(byr(d1 * 16 + d2))
			else:
				result.write(self.consume())
		self.parse_optional_whitespace()
		return(symbols.intern(result.getvalue()))
	def parse_boolean(self):
		"""Parse the keyword ``true`` or ``false`` and return the corresponding bool."""
		if self.input == b"t":
			self.consume(b"true")
			self.parse_optional_whitespace()
			return(True)
		elif self.input == b"f":
			self.consume(b"false")
			self.parse_optional_whitespace()
			return(False)
		else:
			self.raise_error("<boolean>")
	def parse_integer(self, B_skip_whitespace = True, B_none_is_0 = False):
		"""Parse an optionally signed decimal integer.

		B_skip_whitespace: also skip trailing white space/comments.
		B_none_is_0: return 0 instead of raising when no digit follows (a sign
		that was present has already been consumed).
		"""
		negate = 1
		if self.input in [b"+", b"-"]:
			negate = -1 if self.consume() == b"-" else +1
		result = 0
		if self.input == b"" or self.input not in b"0123456789":
			if B_none_is_0:
				return 0
			else:
				self.raise_error("<integer>")
		while self.input != b"" and self.input in b"0123456789":
			result = result * 10
			result += maybe_ord(self.consume()) - ord('0')
		if B_skip_whitespace:
			self.parse_optional_whitespace()
		return(result * negate)
	def parse_octal_integer(self, B_sign = False, max_len = None, B_skip_whitespace = True):
		"""Parse an octal integer (optionally signed when B_sign) of at most *max_len* digits."""
		negate = 1
		if B_sign:
			if self.input in [b"+", b"-"]:
				negate = -1 if self.consume() == b"-" else +1
		if self.input == b"" or self.input not in b"01234567":
			self.raise_error("<integer>")
		result = 0
		l = 0
		while self.input != b"" and self.input in b"01234567":
			result = result * 8
			c = self.consume()
			result += maybe_ord(c) - ord('0')
			l += 1
			if max_len is not None and l >= max_len:
				break
		if B_skip_whitespace:
			self.parse_optional_whitespace()
		return(result * negate)
	def parse_number(self):
		"""Parse an integer or real number; return an int, or a float if a "." is present.

		Accepts an optional sign, optional integer digits and an optional
		fractional part, e.g. "42", "-3.5", "+.25", "4.".  Missing integer digits
		count as 0.
		"""
		sign = -1 if self.input == b"-" else +1
		if self.input in [b"+", b"-"]:
			self.consume()
		integer_part = 0
		while self.input != b"" and self.input in b"0123456789":
			integer_part = integer_part * 10 + maybe_ord(self.consume()) - ord('0')
		if self.input != b".":
			self.parse_optional_whitespace()
			return(sign * integer_part)
		self.consume(b".")
		fraction = 0
		scale = 1 # 10 ** (number of fractional digits read)
		while self.input != b"" and self.input in b"0123456789":
			fraction = fraction * 10 + maybe_ord(self.consume()) - ord('0')
			scale *= 10
		self.parse_optional_whitespace()
		# The sign applies to the whole number, including the fractional part.
		return(sign * (integer_part + float(fraction) / scale))
	def _hex_digit_value(self, c):
		"""Return the value (0-15) of the single hex-digit byte *c*; raise ParseError otherwise."""
		if c is None or len(c) != 1 or c not in hexdigits:
			self.raise_error("<hex-digit>", c)
		# hexdigits lists lower-case digits first, so index() of the lower-cased digit is its value.
		return(hexdigits.index(lower2(c)))
	def parse_hex_string(self, skip = 0, B_allow_whitespace = False):
		"""Parse a ``<hex digits>`` string and return the decoded bytes.

		*skip* is 1 if the caller already consumed the "<".  When
		B_allow_whitespace is set, white space between digits is ignored.  An odd
		final digit is padded with 0 ("<4>" decodes like "<40>").
		"""
		result = io.BytesIO()
		if skip <= 0:
			self.consume(b"<")
		assert(skip <= 1)
		if B_allow_whitespace:
			self.parse_optional_whitespace()
		while self.input != b"" and self.input != b">":
			d1 = self._hex_digit_value(self.consume())
			if B_allow_whitespace:
				self.parse_optional_whitespace()
			d2 = self._hex_digit_value(self.consume()) if self.input != b">" and self.input != b"" else 0
			if B_allow_whitespace:
				self.parse_optional_whitespace()
			result.write(byr(d1 * 16 + d2))
		self.consume(b">")
		self.parse_optional_whitespace()
		return(result.getvalue())
	def parse_string(self):
		r"""Parse a string object and return its bytes.

		A ``<...>`` hex string is delegated to parse_hex_string().  For a ``(...)``
		literal string:
		- balanced unescaped parentheses are kept as part of the string;
		- ``\n \r \t \b \f \( \) \\`` and up to three octal digits are decoded
		  (octal values above 255 keep only the low byte);
		- a backslash before an end-of-line is a line continuation and yields nothing;
		- a backslash before any other character is dropped and the character kept;
		- an unescaped CR or CR LF is read as a single LF.
		"""
		if self.input == b"<":
			return(self.parse_hex_string())
		self.consume(b"(")
		result = io.BytesIO()
		nesting_count = 1
		while self.input != b"" and (self.input != b")" or nesting_count > 1):
			if self.input == b"\\":
				self.consume()
				e = self.input
				if e == b"":
					break # unterminated string; consume(b")") below reports it
				elif e in self._STRING_ESCAPES:
					self.consume()
					result.write(self._STRING_ESCAPES[e])
				elif e == b"\x0D" or e == b"\x0A":
					# Line continuation: drop the backslash and the whole EOL (CR, LF or CR LF).
					self.consume()
					if e == b"\x0D" and self.input == b"\x0A":
						self.consume()
				elif e in b"01234567":
					# \ddd: at most three octal digits; overflow beyond one byte is ignored.
					value = self.parse_octal_integer(max_len = 3, B_skip_whitespace = False)
					result.write(byr(value & 0xFF))
				# Any other escaped character: the backslash is ignored and the character
				# is read as ordinary text on the next iteration (written exactly once).
			elif self.input == b"(":
				result.write(self.consume())
				nesting_count += 1
			elif self.input == b")":
				result.write(self.consume())
				nesting_count -= 1
			elif self.input == b"\x0D":
				# An unescaped end-of-line (CR or CR LF) is read as a single LF.
				self.consume()
				if self.input == b"\x0A":
					self.consume()
				result.write(b"\x0A")
			else:
				result.write(self.consume())
		self.consume(b")")
		self.parse_optional_whitespace()
		return(result.getvalue())
	def parse_expression(self):
		"""Parse one object, dispatching on the lookahead byte."""
		return(self.parse_number() if number_P(self.input) else
		       self.parse_dictionary_or_string() if dictionary_or_string_P(self.input) else
		       self.parse_list() if list_P(self.input) else
		       self.parse_boolean() if boolean_P(self.input) else
		       self.parse_name() if name_P(self.input) else
		       self.parse_n() if n_P(self.input) else
		       self.xraise_error("<expression>"))
	def parse_n(self):
		"""Parse the keyword ``null`` and return None."""
		self.consume(b"null")
		self.parse_optional_whitespace()
		return(None)
	def xraise_error(self, expected):
		"""Write the next few bytes of input to stderr, then raise ParseError.

		The error names the byte and offset where parsing failed; both are
		captured before the context bytes are consumed for the report.
		"""
		got = self.input
		position = self.position
		near = b"".join([self.consume() or b"" for i in range(10)])
		# sys.stderr.write + repr() works on Python 2 and 3; the old "print >>"
		# form and writing raw bytes to a text stream both fail on Python 3.
		sys.stderr.write("info: the following was near: %r\n" % (near,))
		raise ParseError("error: expected %r but got %r near offset %d" % (expected, got, position))
	def parse_dictionary_or_string(self):
		"""Parse ``<<...>>`` (dictionary), ``<...>`` (hex string, white space allowed) or ``(...)``."""
		if self.input != b"<":
			return(self.parse_string())
		self.consume(b"<")
		if self.input == b"<":
			return(self.parse_dictionary(1))
		else:
			return(self.parse_hex_string(1, True))
	def parse_ref_expression(self): # override this
		"""Parse one expression and return it wrapped in a one-element list.

		Hook for subclasses.  The list return type lets an override hand back
		several items; parse_list() copes with a trailing "R" after two of them.
		"""
		return[self.parse_expression()]

#if __name__ == "__main__":
#	scanner = Scanner()
#	import StringIO
#	#print(r"(\000%\000J\000F\000T\000F)Tj.")
#	scanner.start_parsing(StringIO.StringIO(r"(\000%\000J\000F\000T\000F)Tj."))
#	scanner.consume()
#	#print(scanner.parse_string())
	