#!/usr/bin/env python2

import select
import socket
import struct
import StringIO
import sys

protocol_version = 3 << 16 # == 196608; StartupMessage version: major 3 in the high 16 bits, minor 0 in the low 16

def pack1(format, value):
	"""Pack one *value* with struct *format* in big-endian (network) byte order."""
	network_format = ">%s" % format
	return struct.pack(network_format, value)
def unpack1(format, raw):
	"""Unpack *raw* big-endian with struct *format* and return the first value."""
	values = struct.unpack(">" + format, raw)
	first = values[0]
	return first
def packz(value):
	"""Encode *value* as a NUL-terminated string."""
	width = len(value) + 1 # one spare byte, which struct zero-fills
	return struct.pack("{0}s".format(width), value)
def send_packet(destination, payload, packet_type):
	"""Frame *payload* as one protocol message and write all of it to *destination*.

	The frame is the optional one-byte *packet_type*, then an Int32 length that
	counts itself plus the payload (not the type byte), then *payload*.
	Pass packet_type=None for the untyped StartupMessage.
	"""
	frame = pack1("i", len(payload) + 4) + payload
	if packet_type is not None:
		frame = packet_type + frame
	# socket.send() may transmit only part of the buffer; sendall() keeps
	# sending until the whole frame is written (or raises on error).
	destination.sendall(frame)
def readz(IO):
	"""Read a NUL-terminated string from the file-like *IO*, one character at a time.

	The terminating NUL is consumed but not included in the result. If *IO*
	hits end of file first, whatever was read so far is returned.
	"""
	chars = []
	while True:
		ch = IO.read(1)
		if ch == "" or ch == "\0":
			break
		chars.append(ch)
	return "".join(chars)
def send_startup_message(destination, parameters):
	"""Send the untyped StartupMessage that opens a session.

	*parameters* is either an iterable of (name, value) string pairs, such as
	[("user", "postgres"), ("database", "template1")], or a mapping of name
	to value. (Iterating a plain dict yields only its keys, so a mapping is
	turned into its items first.)
	"""
	if hasattr(parameters, "items"):
		parameters = parameters.items()
	packet = StringIO.StringIO()
	# The Int32 length prefix is added by send_packet.
	packet.write(pack1("i", protocol_version))
	for key, value in parameters:
		packet.write(packz(key))
		packet.write(packz(value))
	packet.write(packz("")) # an empty name terminates the parameter list
	send_packet(destination, packet.getvalue(), None)
def send_simple_query(destination, query):
	"""Send a simple-protocol Query ('Q') whose body is *query* as a NUL-terminated string."""
	body = packz(query)
	send_packet(destination, body, 'Q')
def send_parse_command(destination, prepared_statement_name, query_string, parameter_types = None):
	"""Send a Parse ('P') message that prepares *query_string* on the server.

	prepared_statement_name: None or "" selects the unnamed statement.
	parameter_types: optional sequence of parameter type OIDs. None entries
		are sent as 0, which the protocol treats as "unspecified".
	"""
	if parameter_types is None:
		parameter_types = [] # avoid a shared mutable default argument
	packet = StringIO.StringIO()
	packet.write(packz(prepared_statement_name or ""))
	packet.write(packz(query_string))
	packet.write(pack1("h", len(parameter_types)))
	for parameter_type in parameter_types: # type OIDs
		packet.write(pack1("i", parameter_type or 0))
	# Parse is message type 'P'; 'Q' would make the server treat the body
	# as a simple Query.
	send_packet(destination, packet.getvalue(), 'P')
	# -> parse_complete response
def send_execute_portal(destination, name = None, max_row_count = None):
	"""Send Execute ('E') for portal *name* (None is sent as "", the unnamed portal).

	max_row_count of None is sent as 0 (in the protocol, 0 means no row limit).
	"""
	portal = packz(name or "")
	row_limit = pack1("i", max_row_count or 0)
	send_packet(destination, portal + row_limit, 'E')
def send_flush(destination):
	"""Send a Flush ('H') message; it has an empty body."""
	send_packet(destination, payload="", packet_type='H')
def send_close_thing(destination, kind, name):
	"""Send a Close ('C') message.

	kind is either 'S' (prepared statement) or 'P' (portal); a None name is
	sent as the empty string.
	"""
	assert len(kind) == 1
	body = kind + packz(name or "")
	send_packet(destination, body, 'C')
	# the server answers with CloseComplete ('3')
def send_close_portal(destination, name = None):
	"""Close portal *name* (None is sent as the empty name)."""
	return send_close_thing(destination, kind='P', name=name)
def send_close_prepared_statement(destination, name = None):
	"""Close prepared statement *name* (None is sent as the empty name)."""
	return send_close_thing(destination, kind='S', name=name)
def send_describe_thing(destination, kind, name):
	"""Send a Describe ('D') message.

	kind is either 'S' (prepared statement) or 'P' (portal); a None name is
	sent as the empty string.
	"""
	assert len(kind) == 1
	body = kind + packz(name or "")
	send_packet(destination, body, 'D')
def send_describe_portal(destination, name = None):
	"""Send Describe ('D') for portal *name*.

	*name* now defaults to None (sent as the empty name), matching
	send_close_portal; existing callers that pass a name are unaffected.
	"""
	return(send_describe_thing(destination, 'P', name))
def send_describe_prepared_statement(destination, name = None):
	"""Send Describe ('D') for prepared statement *name*.

	*name* now defaults to None (sent as the empty name), matching
	send_close_prepared_statement; existing callers are unaffected.
	"""
	return(send_describe_thing(destination, 'S', name))
def send_bind(destination, portal_name, prepared_statement_name, arguments):
	"""Send a Bind ('B') message that binds *arguments* to a prepared statement.

	portal_name / prepared_statement_name: None or "" selects the unnamed one.
	arguments: sequence of parameter values. None is sent as SQL NULL; any
		other value is converted with str() and sent in text format.
	All result columns are requested in text format.
	"""
	packet = StringIO.StringIO()
	packet.write(packz(portal_name or ""))
	packet.write(packz(prepared_statement_name or ""))
	# Parameter format codes: an Int16 count, then that many Int16 codes
	# (0=text, 1=binary). A count of zero means all parameters are text.
	parameter_format_codes = []
	packet.write(pack1("h", len(parameter_format_codes)))
	for format_code in parameter_format_codes:
		packet.write(pack1("h", format_code))
	packet.write(pack1("h", len(arguments)))
	for argument in arguments:
		# Each parameter length is an Int32 in the protocol (not Int16).
		if argument is None:
			packet.write(pack1("i", -1)) # -1 = NULL; no value bytes follow
		else:
			value = str(argument)
			packet.write(pack1("i", len(value)))
			packet.write(value)
	# Result-column format codes use the same scheme: zero count = all text.
	result_format_codes = []
	packet.write(pack1("h", len(result_format_codes)))
	for format_code in result_format_codes:
		packet.write(pack1("h", format_code))
	send_packet(destination, packet.getvalue(), 'B')
	# gets bind_complete eventually.
def parse_parameter_status(packet_data):
	"""Decode ParameterStatus ('S') into a (name, value) tuple."""
	assert packet_data[0] == 'S'
	fields = packet_data[5:].split("\0")
	name = fields[0]
	value = fields[1]
	return (name, value)
def parse_backend_key(packet_data):
	"""Decode BackendKeyData ('K') into a (process_ID, secret_key) tuple of Int32s."""
	assert packet_data[0] == 'K'
	pid_bytes = packet_data[5:9]
	key_bytes = packet_data[9:13]
	return (unpack1("i", pid_bytes), unpack1("i", key_bytes))
def parse_parse_complete(packet_data):
	"""Check that *packet_data* is a ParseComplete ('1') message and return None.

	The type check matches the other parse_* helpers; the message has no body.
	"""
	assert(packet_data[0] == '1')
	return(None)
def parse_portal_suspended(packet_data):
	"""Check that this is PortalSuspended ('s'); it has no body, so return None."""
	tag = packet_data[0]
	assert tag == 's'
	return None
def parse_bind_complete(packet_data):
	"""Check that this is BindComplete ('2'); it has no body, so return None."""
	tag = packet_data[0]
	assert tag == '2'
	return None
def parse_close_complete(packet_data):
	"""Check that this is CloseComplete ('3'); it has no body, so return None."""
	tag = packet_data[0]
	assert tag == '3'
	return None
def parse_error_response(packet_data): # or notice
	"""Decode an ErrorResponse ('E') or NoticeResponse ('N') message.

	The body is a sequence of fields, each a one-byte field type followed by
	a NUL-terminated value, and a lone NUL byte ends the sequence. Field types:
		'S' severity           'C' code             'M' message
		'D' detail             'H' hint             'P' position
		'p' internal position  'q' internal query   'W' where
		'F' file               'L' line             'R' routine

	Returns a list of (field_type, value) tuples in message order. A body that
	ends without the final NUL byte (truncated or empty) yields the fields
	read so far instead of looping forever.
	"""
	assert(packet_data[0] == 'E' or packet_data[0] == 'N')
	result = []
	# Each field is "<type><value>\0" and the list ends with one more "\0",
	# so splitting on NUL gives one "<type><value>" piece per field followed
	# by an empty piece at the terminator.
	for piece in packet_data[5:].split("\0"):
		if piece == "":
			break # terminator, or end of a truncated body
		result.append((piece[0], piece[1:]))
	return(result)
def parse_notice_response(packet_data):
	"""Decode NoticeResponse ('N'); its field layout is the same as ErrorResponse."""
	fields = parse_error_response(packet_data)
	return fields
def parse_no_data(packet_data):
	"""Check that this is NoData ('n'); it has no body, so return None."""
	tag = packet_data[0]
	assert tag == 'n'
	return None
def parse_ready_for_query(packet_data):
	"""Decode ReadyForQuery ('Z') and return its transaction-status byte.

	The status is 'I' (idle), 'T' (in a transaction) or 'E' (in a failed
	transaction). The type check matches the other parse_* helpers.
	"""
	assert(packet_data[0] == 'Z')
	transaction_status = packet_data[5]
	return(transaction_status)
def parse_command_complete(packet_data):
	"""Return the command tag (e.g. "SELECT 3") of a CommandComplete ('C') message.

	If the query string was empty, the server sends EmptyQueryResponse ('I')
	instead of this message.
	"""
	assert packet_data[0] == 'C'
	tag, _, _ = packet_data[5:].partition("\0")
	return tag
def parse_empty_query_response(packet_data):
	"""Check that this is EmptyQueryResponse ('I'); it has no body, so return None.

	This replaces CommandComplete when the query string was empty.
	"""
	tag = packet_data[0]
	assert tag == 'I'
	return None
class ColumnInfo(object):
	"""Plain attribute holder for one result column; its repr shows all attributes."""
	def __repr__(self):
		return "ColumnInfo(" + repr(vars(self)) + ")"
def parse_authentication_OK_response(packet_data):
	"""Return the Int32 authentication code carried by an Authentication ('R') message.

	The code is returned unchecked. It is 0 for AuthenticationOk; other values
	are requests for credentials.
	"""
	assert packet_data[0] == 'R'
	code_bytes = packet_data[5:9]
	return unpack1("i", code_bytes)
def parse_row_description(packet_data):
	"""Decode RowDescription ('T') into a list of ColumnInfo objects, one per column.

	Each ColumnInfo gets the attributes name, table_ID, ID, type_ID, typlen,
	typmod and format, in that order.
	"""
	assert packet_data[0] == 'T'
	field_count = unpack1("h", packet_data[5:7])
	stream = StringIO.StringIO(packet_data[7:])
	columns = []
	for _ in range(field_count):
		info = ColumnInfo()
		info.name = readz(stream)
		# Fixed part after the name: Int32, Int16, Int32, Int16, Int32, Int16 = 18 bytes.
		fixed = stream.read(18)
		(info.table_ID, info.ID, info.type_ID,
			info.typlen, info.typmod, info.format) = struct.unpack(">ihihih", fixed)
		columns.append(info)
	return columns
def parse_data_row(packet_data):
	"""Decode DataRow ('D') into a list of column values.

	Each value is the raw string read from the message, or None when its
	Int32 length is -1 (SQL NULL).
	"""
	assert packet_data[0] == 'D'
	stream = StringIO.StringIO(packet_data[7:])
	values = []
	for _ in range(unpack1("h", packet_data[5:7])):
		length = unpack1("i", stream.read(4))
		values.append(None if length == -1 else stream.read(length))
	return values
class Receiver(object):
	"""Reassemble backend messages from raw socket chunks and print each one.

	Each message on the wire is <type byte(s)><Int32 length><body>, where the
	length counts itself and the body but not the type byte(s).
	"""
	def __init__(self):
		self.received_data = StringIO.StringIO() # received bytes not yet dispatched
		self.expected_type_count = 1 # number of message-type bytes before the length
		self.expected_size = None # full wire size of the message being assembled, once known
		self.B_expect_header = True # NOTE(review): not read by this class; kept for compatibility
	def handleReceivedPacket(self, packet_data):
		"""Parse one complete message (type byte included) and print the result.

		Message types without a parser are printed raw.
		"""
		parsers = {
			'S': parse_parameter_status,
			'K': parse_backend_key,
			'Z': parse_ready_for_query,
			'C': parse_command_complete,
			'T': parse_row_description,
			'E': parse_error_response,
			'n': parse_no_data,
			'N': parse_notice_response,
			'1': parse_parse_complete,
			's': parse_portal_suspended,
			'D': parse_data_row,
			'I': parse_empty_query_response,
			'R': parse_authentication_OK_response,
			'2': parse_bind_complete,
			'3': parse_close_complete,
			# TODO: 'A' (notification) and 't' (parameter description) have no
			# parser yet and fall through to the raw print below.
		}
		parser = parsers.get(packet_data[0])
		if parser is None:
			print("packet", repr(packet_data))
		else:
			print(parser(packet_data))
	def handleReceivedData(self, chunk):
		"""Buffer *chunk* and dispatch every message it completes.

		A message may be split across chunks, or a chunk may carry several
		messages, so any incomplete tail stays buffered for the next call.
		"""
		self.received_data.write(chunk)
		while True:
			# The write position is kept at the end of the buffer, so tell()
			# is the number of bytes currently buffered.
			buffered = self.received_data.tell()
			header_size = self.expected_type_count + 4 # type byte(s) + Int32 length
			if self.expected_size is None:
				if buffered < header_size:
					break # length field not complete yet
				data = self.received_data.getvalue()
				size = unpack1("i", data[self.expected_type_count : header_size])
				assert(size >= 4) # the length always counts itself
				# The length excludes the type byte(s); add them back so that
				# expected_size is the whole message as it appears in the buffer.
				self.expected_size = size + self.expected_type_count
			elif buffered >= self.expected_size: # we got (at least) an entire packet.
				data = self.received_data.getvalue()
				packet_data = data[: self.expected_size]
				# Keep only the unconsumed tail in the buffer.
				self.received_data.seek(0)
				self.received_data.write(data[self.expected_size : ])
				self.received_data.truncate()
				self.expected_size = None
				self.handleReceivedPacket(packet_data)
			else:
				break

a = socket.socket()
a.connect(("localhost", 5432))
try:
	send_startup_message(a, [("user", "postgres"), ("database", "template1")])
	send_simple_query(a, "SELECT * FROM users")
	# Terminate ('X', empty body). The server handles messages in order, so
	# it answers the query first and then closes the connection. That close
	# produces the EOF that ends the loop below; without it the loop would
	# wait forever after the final ReadyForQuery.
	send_packet(a, "", 'X')
	receiver = Receiver()
	while True:
		read_FDs, write_FDs, except_FDs = select.select([a], [], [a])
		if len(read_FDs) < 1:
			continue
		chunk = a.recv(4096)
		if chunk != "":
			receiver.handleReceivedData(chunk)
			sys.stdout.flush()
		else: # EOF
			break
finally:
	a.close()

"""
By TCP rules, the only way for a server program to know whether a client has disconnected
is to try to read from the socket. Specifically, if select() reports the socket as readable
but recv() returns 0 bytes, the client has disconnected.

But a server program might want to confirm that a TCP client is still connected without
consuming any data, for example before it performs some task or sends data to the client.
The steps below show how to detect a TCP client disconnect without reading data.

The method to do this:
1) Call select() on the socket as a poll (zero timeout, no waiting).
2) If no data is waiting to be received, the client is still connected.
3) If data is waiting, read one byte using the MSG_PEEK flag.
4) If the peeked read returns 0 bytes, the client has disconnected; otherwise it is still connected.
Note: MSG_PEEK reads data without removing it from the TCP receive queue.
"""
