#!/usr/bin/python3 -u
# coding=utf-8
from argparse import ArgumentParser, Namespace as Args
from _dbus_glib_bindings import DBusGMainLoop
from datetime import datetime
import traceback
import dbus
from gi.repository import GLib as glib
import signal
import sys

from _dbus_bindings import Connection, Message, MethodCallMessage, MethodReturnMessage, SignalMessage, ErrorMessage, \
	BUS_DAEMON_NAME, BUS_DAEMON_PATH, BUS_DAEMON_IFACE, HANDLER_RESULT_HANDLED, BUS_SYSTEM, BUS_SESSION

# noinspection PyUnreachableCode
if False:
	from typing import Optional, AnyStr, NoReturn, Dict, Any, List, Callable


# ANSI SGR foreground color numbers; color_code() turns one into '\033[<n>m'.
RESET_COLOR = 0  # SGR 0: reset all attributes

# normal-intensity foreground colors (SGR 30-37)
RED        = 31
GREEN      = 32
ORANGE     = 33
BLUE       = 34
PURPLE     = 35
CYAN       = 36
LIGHT_GREY = 37

# bright foreground colors (SGR 90-97)
DARK_GREY   = 90
LIGHT_RED   = 91
LIGHT_GREEN = 92
YELLOW      = 93
LIGHT_BLUE  = 94
PINK        = 95
LIGHT_CYAN  = 96

# Palette for method calls and replies. get_message_color() indexes it with
# serial % N_COLORS, so a call and the reply carrying its serial get the same color.
COLORS = [
	RED, LIGHT_GREEN, LIGHT_GREY, ORANGE, BLUE, PURPLE, CYAN,
	GREEN, LIGHT_RED, YELLOW, LIGHT_BLUE, PINK, LIGHT_CYAN,
]
N_COLORS = len(COLORS)


def raise_on_error(msg):
	# type: (Message) -> None
	"""Reply handler that turns a D-Bus error reply into a Python exception.

	Returns normally for any non-error reply; for an ErrorMessage it raises
	a plain Exception whose message is the D-Bus error name.
	"""
	if not isinstance(msg, ErrorMessage):
		return
	raise Exception(msg.get_error_name())


def ignore_errors(func):
	"""Decorator making *func* best-effort: any Exception it raises yields None.

	Only ``Exception`` subclasses are suppressed. ``SystemExit`` (raised by
	this script's SIGINT handler via sys.exit) and ``KeyboardInterrupt`` still
	propagate, so an interrupt during a lookup cannot be silently swallowed.
	Keyword arguments are forwarded and the wrapped function's metadata
	(name, docstring) is preserved.
	"""
	import functools

	@functools.wraps(func)
	def wrapper(*args, **kwargs):
		try:
			return func(*args, **kwargs)
		except Exception:
			return None
	return wrapper


def catch(func):
	"""Decorator for callbacks that must never let an exception escape.

	KeyboardInterrupt is swallowed silently. Any other Exception has its
	traceback printed (traceback.print_exc) and the call returns None.
	On success the wrapped function's result is returned unchanged.
	"""
	def wrapper(*args):
		outcome = None
		try:
			outcome = func(*args)
		except KeyboardInterrupt:
			pass
		except Exception:
			traceback.print_exc()
		return outcome
	return wrapper


def call_daemon(con, on_reply, member, *args):
	# type: (Connection, Callable[[Message], None], str, *Any) -> None
	"""Asynchronously call *member* on the bus daemon's own interface.

	Each extra positional argument is appended to the message body in order.
	*on_reply* is invoked with the reply (or error) message from the main loop.
	"""
	request = MethodCallMessage(BUS_DAEMON_NAME, BUS_DAEMON_PATH, BUS_DAEMON_IFACE, member)
	for value in args:
		request.append(value)
	con.send_message_with_reply(request, on_reply, require_main_loop=True)


def call_daemon_blocking(con, member, *args):
	# type: (Connection, str, *Any) -> Message
	"""Synchronously call *member* on the bus daemon and return the reply message.

	Each extra positional argument is appended to the message body in order.
	Blocks until the daemon answers.
	"""
	request = MethodCallMessage(BUS_DAEMON_NAME, BUS_DAEMON_PATH, BUS_DAEMON_IFACE, member)
	for value in args:
		request.append(value)
	reply = con.send_message_with_reply_and_block(request)
	return reply


def add_message_callback(connection, callback):
	# type: (Connection, Callable[[Message], None]) -> None
	"""Install a message filter on *connection* that passes every message to *callback*.

	The filter always reports the message as handled (HANDLER_RESULT_HANDLED).
	"""
	def dispatch(_connection, msg):
		# type: (Connection, Message) -> int
		callback(msg)
		return HANDLER_RESULT_HANDLED

	connection.add_message_filter(dispatch)


def add_match_rule(connection, match_rule):
	# type: (Connection, AnyStr) -> None
	"""Asynchronously register *match_rule* with the bus daemon (AddMatch).

	An error reply is turned into an exception by raise_on_error.
	"""
	method = 'AddMatch'
	call_daemon(connection, raise_on_error, method, match_rule)


@ignore_errors
def resolve_service_blocking(sid):
	# type: (AnyStr) -> None
	"""Synchronously look up which process owns bus name *sid* and cache it.

	Asks the daemon for the connection's Unix PID (blocking), stores it as a
	string in PIDS[sid], then reads /proc/<pid>/comm into PROCESSES[sid].
	Any failure leaves the caches partially or entirely unfilled; exceptions
	are suppressed by @ignore_errors.
	"""
	reply = call_daemon_blocking(DBUS, 'GetConnectionUnixProcessID', sid)
	if isinstance(reply, ErrorMessage):
		return None

	reply_args = reply.get_args_list()
	if not reply_args:
		return None

	pid = str(reply_args[0])
	PIDS[sid] = pid

	comm_path = '/proc/{0}/comm'.format(pid)
	with open(comm_path) as comm_file:
		raw_name = comm_file.read()
	PROCESSES[sid] = raw_name.replace('\0', ' ').rstrip()


def resolve_service(sid):
	# type: (AnyStr) -> None
	"""Asynchronously look up which process owns bus name *sid*.

	Fires GetConnectionUnixProcessID at the daemon and returns immediately.
	When a successful reply arrives, PIDS[sid] and PROCESSES[sid] (from
	/proc/<pid>/comm) are filled in. Failures in the reply handler are
	suppressed by @ignore_errors.
	"""
	@ignore_errors
	def on_reply(reply):
		# type: (Message) -> None
		if not isinstance(reply, MethodReturnMessage):
			return  # error reply: leave the caches empty
		values = reply.get_args_list()
		if not values:
			return
		pid = values[0]
		PIDS[sid] = str(pid)
		with open('/proc/{0}/comm'.format(pid)) as comm_file:
			PROCESSES[sid] = comm_file.read().replace('\0', ' ').rstrip()

	call_daemon(DBUS, on_reply, 'GetConnectionUnixProcessID', sid)


def color_code(c):
	# type: (int) -> str
	"""Return the ANSI escape sequence ESC[<c>m that selects SGR code *c*."""
	return ''.join(('\033[', str(c), 'm'))


def colorize(txt, color):
	# type: (AnyStr, int) -> str
	"""Wrap *txt* in the escape codes for *color*, followed by a reset.

	When coloring is disabled (NO_COLOR), *txt* is returned unchanged.
	Empty text yields '' so no bare color codes are emitted.
	"""
	if NO_COLOR:
		return txt
	if txt:
		return color_code(color) + txt + color_code(RESET_COLOR)
	return ''  # do not surround empty strings with color codes


def column(column_name, right_align=False):
	"""Decorator factory that turns a cell renderer into an aligned output column.

	The decorated function is skipped (result '') when *column_name* is not
	in COLUMNS. In TABS mode the stripped text plus a tab is returned.
	Otherwise the text is padded to the widest value printed so far in this
	column (remembered in COLUMNS[column_name]). It is right-aligned when
	*right_align* is set and is always followed by a single space.
	"""
	def decorator(func):
		def wrapper(*args):
			if column_name not in COLUMNS:
				return ''

			cell = func(*args)
			if TABS:
				return cell.strip() + '\t'

			# widths only ever grow, so earlier lines may be narrower than later ones
			width = max(COLUMNS.get(column_name) or 0, len(cell))
			COLUMNS[column_name] = width

			padded = cell.rjust(width) if right_align else cell.ljust(width)
			return padded + ' '
		return wrapper
	return decorator


def color(col):
	"""Decorator factory: pass the decorated function's string result through colorize(..., col)."""
	def decorator(func):
		def wrapper(*args):
			text = func(*args)
			return colorize(text, col)
		return wrapper
	return decorator


def pretty_print_kv_pair(kv_pair):
	# type: (Any) -> str
	"""Render one (key, value) dict entry as '"key": value'.

	The key is always shown double-quoted in BLUE. The value is rendered
	recursively by pretty_print_arg.

	D-Bus dict keys may be any basic type (e.g. a{qv}, a{uv}), not only
	strings. Previously ``'"' + key`` raised TypeError for integer keys, and
	the whole payload was then reported as unparseable. Integer keys are
	printed as their decimal value, mirroring how pretty_print_arg prints
	dbus.Byte.
	"""
	key, value = kv_pair

	if isinstance(key, str):
		key_text = key
	elif isinstance(key, int):
		key_text = str(int(key))
	else:
		key_text = str(key)

	k = colorize('"' + key_text + '"', BLUE)
	v = pretty_print_arg(value)

	return k + ': ' + v


def pretty_print_arg(arg):
	# type: (Any) -> str
	"""Render one D-Bus value in a JSON-like notation.

	- dict -> {"key": value, ...} (entries rendered by pretty_print_kv_pair)
	- dbus.Array / dbus.Struct -> [item, ...], or 'null' when empty
	- dbus.String -> double-quoted; dbus.Boolean -> true/false;
	  dbus.Byte -> its decimal value; anything else -> str(arg)
	Scalar results are colored LIGHT_BLUE via colorize().
	"""
	if isinstance(arg, dict):
		pairs = [pretty_print_kv_pair(entry) for entry in arg.items()]
		return '{' + ', '.join(pairs) + '}'

	if isinstance(arg, (dbus.Array, dbus.Struct)):
		if not len(arg):
			return 'null'
		items = [pretty_print_arg(item) for item in arg]
		return '[' + ', '.join(items) + ']'

	if isinstance(arg, dbus.String):
		scalar = '"' + arg + '"'
	elif isinstance(arg, dbus.Boolean):
		scalar = 'true' if arg != 0 else 'false'
	elif isinstance(arg, dbus.Byte):
		scalar = str(int(arg))
	else:
		scalar = str(arg)

	return colorize(scalar, LIGHT_BLUE)


def get_args(msg):
	# type: (Message) -> str
	"""Render the message payload as space-separated pretty-printed values.

	Returns '' when the 'arguments' column is disabled. If the payload cannot
	be decoded or rendered, a red 'Failed to parse payload' marker is returned
	instead, so one bad message does not abort the log line.
	"""
	if 'arguments' not in COLUMNS:
		return ''

	try:
		return ' '.join(map(pretty_print_arg, msg.get_args_list()))
	except Exception:
		# Exception rather than a bare except: SystemExit from the SIGINT
		# handler (sys.exit) and KeyboardInterrupt must still propagate.
		return colorize('Failed to parse payload', RED)


def get_service(sid):
	# type: (AnyStr) -> str
	"""Describe bus peer *sid* using the enabled service columns.

	Joins the process name, PID and bus id with spaces. Each part appears
	only if its column is enabled; empty parts are dropped. On a cache miss
	the peer is resolved: asynchronously in REAL_TIME mode (this line then
	shows '<unknown>'), otherwise with a blocking lookup followed by a re-read.
	"""
	def cached():
		name = PROCESSES.get(sid) if 'process' in COLUMNS else ''
		number = PIDS.get(sid) if 'pid' in COLUMNS else ''
		return name, number

	proc, pid = cached()

	if proc is None or pid is None:
		if not REAL_TIME:
			resolve_service_blocking(sid)
			proc, pid = cached()  # re-read what the blocking lookup stored
		else:
			resolve_service(sid)  # async (callback) resolve service and report <unknown>

	if proc is None:
		# when the PID is unknown as well, print '<unknown>' only once (for the PID)
		proc = '' if pid is None else '<unknown>'
	if pid is None:
		pid = '<unknown>'

	bus_id = sid if 'bus_id' in COLUMNS else ''
	return ' '.join(part for part in (proc, pid, bus_id) if part != '')


def get_message_color(msg):
	# type: (Message) -> int
	"""Pick the color code used for *msg*'s serial/source/destination part.

	Signals use RESET_COLOR. A method call is keyed by its own serial; any
	other message is keyed by its reply serial. A call and its reply therefore
	map to the same COLORS entry.
	"""
	if isinstance(msg, SignalMessage):
		return RESET_COLOR

	if isinstance(msg, MethodCallMessage):
		serial = msg.get_serial()
	else:
		serial = msg.get_reply_serial()

	return COLORS[serial % N_COLORS]


@column('src')
def get_source(msg):
	# type: (Message) -> str
	"""Describe the logical origin of *msg* for the 'src' column.

	For replies (method return / error) this is the original caller, which is
	the reply's destination; otherwise it is the sender.
	"""
	if is_reply(msg):
		origin = msg.get_destination()
	else:
		origin = msg.get_sender()
	return get_service(origin)


@column('dst')
def get_destination(msg):
	# type: (Message) -> str
	"""Describe the logical target of *msg* for the 'dst' column.

	For replies this is the service that answered (the reply's sender);
	otherwise it is the destination. Messages without a destination
	(e.g. broadcast signals) show '*'.
	"""
	if is_reply(msg):
		target = msg.get_sender()
	else:
		target = msg.get_destination()

	return '*' if target is None else get_service(target)


@color(PURPLE)
@column('member')
def get_member(msg):
	# type: (Message) -> str
	"""Return the message's member (method or signal name), or '' if it has none."""
	member = msg.get_member()
	return member if member else ''


@color(LIGHT_RED)
@column('member')
def get_error(msg):
	# type: (Message) -> str
	"""Return the message's D-Bus error name, or '' if it has none (shares the 'member' column)."""
	error_name = msg.get_error_name()
	return error_name if error_name else ''


@color(CYAN)
@column('object_path')
def get_path(msg):
	# type: (Message) -> str
	"""Return the message's object path, or '' if it has none."""
	path = msg.get_path()
	return path if path else ''


@color(PINK)
@column('interface')
def get_interface(msg):
	# type: (Message) -> str
	"""Return the message's interface name, or '' if it has none."""
	interface = msg.get_interface()
	return interface if interface else ''


@color(GREEN)
@column('signature')
def get_signature(msg):
	# type: (Message) -> str
	"""Return the D-Bus type signature of the message body, or '' if it has none."""
	signature = msg.get_signature()
	return signature if signature else ''


@color(DARK_GREY)
@column('time')
def get_timestamp():
	# type: () -> str
	"""Return the current local time as 'YYYY-MM-DD HH:MM:SS.mmm'."""
	stamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")
	return stamp[:-3]  # %f gives microseconds; drop the last 3 digits to keep milliseconds


@column('serial', right_align=True)
def get_serial(msg):
	# type: (Message) -> str
	"""Return the serial linking a call to its reply, as a string.

	The reply serial is used if it is set (non-zero); otherwise the message's
	own serial. Returns '' if neither is set.
	"""
	serial = msg.get_reply_serial()
	if not serial:
		serial = msg.get_serial()
	return str(serial) if serial else ''


def get_arrow(msg):
	# type: (Message) -> str
	"""Return the direction glyph for *msg*.

	'==>' for signals and method calls, '<==' for method returns and '<=!'
	for errors. Raises Exception for any other message type.
	"""
	arrows = (
		((SignalMessage, MethodCallMessage), '==>'),
		(MethodReturnMessage, '<=='),
		(ErrorMessage, '<=!'),
	)
	for kinds, arrow in arrows:
		if isinstance(msg, kinds):
			return arrow
	raise Exception('unknown message type')


def is_reply(msg):
	# type: (Message) -> bool
	"""True for method returns and error replies, False for calls and signals."""
	return isinstance(msg, MethodReturnMessage) or isinstance(msg, ErrorMessage)


def get_from_to(msg):
	# type: (Message) -> str
	"""Build the '<serial> <source> <arrow> <destination>' part of a log line.

	The whole part is colored with get_message_color(msg).
	"""
	serial_color = get_message_color(msg)
	# evaluated left to right: serial, source, arrow, destination (column widths update in this order)
	parts = (get_serial(msg), get_source(msg), get_arrow(msg), ' ', get_destination(msg))
	return colorize(''.join(parts), serial_color)


@catch
def on_message(msg):
	# type: (Message) -> Optional[Any]
	"""Format and print one log line for a message seen by the connection filter.

	Messages sent by or addressed to this logger's own connection are skipped.
	Exceptions are caught and reported by @catch.
	"""
	if msg.get_sender() == OWN_SID or msg.get_destination() == OWN_SID:  # do not show own messages
		return None

	t_stamp = get_timestamp()
	from_to = get_from_to(msg)
	path    = get_path(msg)
	itf     = get_interface(msg)
	# D-Bus error messages carry no member, so get_member() would render an
	# empty column. Show the error name there instead (get_error shares the
	# 'member' column, so alignment is unchanged).
	if isinstance(msg, ErrorMessage):
		member = get_error(msg)
	else:
		member = get_member(msg)
	sig     = get_signature(msg)
	args    = get_args(msg)

	print(t_stamp + from_to + member + path + itf + sig + args)

	return None   # make the "type checker" happy


def signal_handler(*_ignored):
	"""SIGINT handler: stop the GLib main loop, then exit with status 0 (raises SystemExit).

	The signal number and frame passed by the signal module are unused.
	"""
	loop = main_loop
	loop.quit()
	sys.exit(0)


def init_subscriptions(con, args):
	# type: (Connection, Args) -> None
	"""Register eavesdropping match rules for the message types chosen on the command line.

	If no message-type flag was given, all four types are subscribed.
	Rules are added in the order: method_call, method_return, signal, error.
	"""
	rules = (
		(args.method_call,   "type='method_call',eavesdrop=true"),
		(args.method_return, "type='method_return',eavesdrop=true"),
		(args.signal,        "type='signal',eavesdrop=true"),
		(args.error,         "type='error',eavesdrop=true"),
	)
	subscribe_all = not any(selected for selected, _ in rules)

	for selected, rule in rules:
		if selected or subscribe_all:
			add_match_rule(con, rule)


def init_columns(args):
	# type: (Args) -> Dict[str, int]
	"""Choose the output columns from the parsed command line.

	Returns a dict mapping every enabled column name to its width, initially
	0 (the column decorator widens it as lines are printed). 'src' and 'dst'
	are always present. With no message-field flag, object_path, member and
	arguments are enabled. With no service flag, args.process_name is forced
	to True; note that this mutates *args*.
	"""
	message_fields = ('object_path', 'member', 'interface', 'arguments', 'signature', 'serial')
	use_msg_defaults = not any(getattr(args, field) for field in message_fields)

	if not (args.process_name or args.pid or args.bus_id):
		args.process_name = True  # default service info: process name only

	enabled = (
		('object_path', args.object_path or use_msg_defaults),
		('member',      args.member or use_msg_defaults),
		('interface',   args.interface),
		('arguments',   args.arguments or use_msg_defaults),
		('signature',   args.signature),
		('serial',      args.serial),
		('process',     args.process_name),
		('pid',         args.pid),
		('bus_id',      args.bus_id),
		('time',        args.time),
	)

	columns = {'src': 0, 'dst': 0}
	for name, wanted in enabled:
		if wanted:
			columns[name] = 0
	return columns


def parse_args():
	# type: () -> Args
	"""Parse the command line (sys.argv) into an argparse Namespace.

	Every option is a boolean switch (store_true) and defaults to False.
	The defaults mentioned in the help texts are not applied here; they are
	applied later by init_subscriptions, init_columns and the bus selection.
	"""
	parser = ArgumentParser(description='DBus Logger', add_help=True)

	def add_switches(title, description, switches):
		# one help section per call; each switch is (option strings, help text)
		group = parser.add_argument_group(title=title, description=description)
		for option_strings, help_text in switches:
			group.add_argument(*option_strings, action="store_true", help=help_text)

	add_switches('bus selection', None, [
		(('--session',), 'use session bus'),
		(('--system',), 'use system bus (default)'),
	])

	add_switches('message types', 'if none of these are specified, all messages will be shown', [
		(('-c', '--method_call'), 'log method calls'),
		(('-r', '--method_return'), 'log method replies'),
		(('-e', '--error'), 'log error replies'),
		(('-s', '--signal'), 'log signals'),
	])

	add_switches('service info', 'if none of these are specified, only the process name will be shown', [
		(('-p', '--process_name'), 'show process name'),
		(('-d', '--pid'), 'show process id'),
		(('-b', '--bus_id'), 'show bus id'),
	])

	add_switches('message fields', 'if none of these are specified, it will default to -oma', [
		(('-o', '--object_path'), 'show object path'),
		(('-m', '--member'), 'show member'),
		(('-i', '--interface'), 'show interface'),
		(('-t', '--signature'), 'show signature (argument types)'),
		(('-a', '--arguments'), 'show message arguments (payload)'),
		(('-l', '--serial'), 'show message serial number'),
	])

	add_switches('output formatting', None, [
		(('--no_color',), 'do not color output'),
		(('--tab_separated',), 'output unaligned columns separated by tab characters'),
		(('--script',), 'shorthand for selecting above two options'),
		(('--time',), 'add timestamps to output. Recommended to be used only together with the --real_time option.'),
		(('--real_time',), 'process names/pids are resolved asynchronously to allow real-time output. This will introduce some <unknown> processes/pids, until they are resolved.'),
	])

	return parser.parse_args()

######################################################################################################


# Caches filled by resolve_service()/resolve_service_blocking(), keyed by bus name:
# PROCESSES holds the process name (from /proc/<pid>/comm), PIDS the PID as a string.
PROCESSES = dict()  # type: Dict[AnyStr, AnyStr]
PIDS      = dict()  # type: Dict[AnyStr, AnyStr]

signal.signal(signal.SIGINT, signal_handler)

DBusGMainLoop(set_as_default=True)
main_loop = glib.MainLoop()

args = parse_args()
bus  = BUS_SESSION if args.session else BUS_SYSTEM

# noinspection PyProtectedMember
DBUS      = Connection._new_for_bus(bus)  # important! must keep a reference to the connection, otherwise it will be gc'd!
OWN_SID   = DBUS.get_unique_name()
NO_COLOR  = args.no_color or args.script
TABS      = args.tab_separated or args.script
REAL_TIME = args.real_time

add_message_callback(DBUS, on_message)
init_subscriptions(DBUS, args)
COLUMNS = init_columns(args)

if 'time' in COLUMNS and not REAL_TIME:
	msg = 'WARNING:\nTimestamps cannot be guaranteed to be accurate!\nConsider using the --real_time option.\n\n'
	# Diagnostics go to stderr so they never end up mixed into the log lines
	# on stdout (e.g. the tab-separated output of --script piped into tools).
	print(colorize(msg, RED), file=sys.stderr)

try:
	main_loop.run()
except KeyboardInterrupt:
	main_loop.quit()