#!/usr/bin/python2 -u
# coding=utf-8

import logging
import re
import socket
import sys
import gobject
import signals
import config as cfg

from dbus.mainloop.glib import DBusGMainLoop
from pymodbus.client.sync import ModbusSerialClient as Modbus
from pymodbus.exceptions import ModbusException, ModbusIOException
from pymodbus.other_message import ReportSlaveIdRequest
from pymodbus.pdu import ExceptionResponse
from pymodbus.register_read_message import ReadInputRegistersResponse
from data import BatteryStatus, BatterySignal, Battery, ServiceSignal
from python_libs.ie_dbus.dbus_service import DBusService

import time
import os
import csv
import pika
import zipfile
import hashlib
import base64
import hmac
import requests
from datetime import datetime
import io
import json
from convert import first
# Directory where create_csv_files writes its CSV snapshots; snapshots whose
# S3 upload fails are moved into its "failed" sub-directory.
CSV_DIR = "/data/csv_files/"
# NOTE(review): not referenced anywhere else in this file as shown -- confirm it is still needed.
INSTALLATION_NAME_FILE = '/data/innovenergy/openvpn/installation-name'

# Trick the PyCharm type checker into thinking the typing names used in the
# "# type:" comments are in scope; this block never executes at runtime.
# noinspection PyUnreachableCode
if False:
	from typing import Callable, List, Iterable, NoReturn


# Register address that reset_batteries writes the value 1 to in order to reset a battery.
RESET_REGISTER = 0x2087


def compress_csv_data(csv_data, file_name="data.csv"):
	"""
	Pack CSV text into a single-entry ZIP archive and return it base64-encoded.

	:param csv_data: CSV content as text; it is stored UTF-8 encoded
	:param file_name: name of the entry inside the ZIP archive
	:return: the DEFLATE-compressed ZIP archive as a base64 text string
	"""
	buffer = io.BytesIO()
	with zipfile.ZipFile(buffer, 'w', zipfile.ZIP_DEFLATED) as zip_file:
		zip_file.writestr(file_name, csv_data.encode('utf-8'))

	zipped = buffer.getvalue()
	return base64.b64encode(zipped).decode('utf-8')

class S3config:
	"""
	Connection settings and request signing for the S3 bucket that the CSV
	snapshots are uploaded to.

	Bucket, key and secret come from the config module. Requests are signed
	with the "AWS <key>:<signature>" scheme: an HMAC-SHA1 over the method,
	Content-MD5, Content-Type, Date and the /bucket/path resource.
	"""

	def __init__(self):
		self.bucket = cfg.S3BUCKET
		self.region = "sos-ch-dk-2"
		self.provider = "exo.io"
		self.key = cfg.S3KEY
		self.secret = cfg.S3SECRET
		self.content_type = "application/base64; charset=utf-8"

	@property
	def host(self):
		"""Host name of the bucket: <bucket>.<region>.<provider>."""
		return "{}.{}.{}".format(self.bucket, self.region, self.provider)

	@property
	def url(self):
		"""HTTPS base URL of the bucket."""
		return "https://{}".format(self.host)

	def create_put_request(self, s3_path, data, timeout=10):
		"""
		Upload data to s3_path in the bucket with a signed HTTP PUT.

		:param s3_path: object key relative to the bucket
		:param data: request body
		:param timeout: seconds passed to requests as the connect/read timeout.
			Without it a stalled server could block the caller (the gobject
			main loop) indefinitely.
		:return: the requests response object
		:raises requests.exceptions.RequestException: on connection errors and timeouts
		"""
		headers = self._create_request("PUT", s3_path)
		url = "{}/{}".format(self.url, s3_path)
		response = requests.put(url, headers=headers, data=data, timeout=timeout)
		return response

	def _create_request(self, method, s3_path):
		"""Build the Host/Date/Authorization/Content-Type headers for a signed request."""
		date = datetime.utcnow().strftime('%a, %d %b %Y %H:%M:%S GMT')
		auth = self._create_authorization(method, self.bucket, s3_path, date, self.key, self.secret, self.content_type)
		headers = {
			"Host": self.host,
			"Date": date,
			"Authorization": auth,
			"Content-Type": self.content_type
		}
		return headers

	@staticmethod
	def _create_authorization(method, bucket, s3_path, date, s3_key, s3_secret, content_type="", md5_hash=""):
		"""
		Return the Authorization header value "AWS <key>:<base64 HMAC-SHA1>".

		Leading and trailing slashes are stripped from bucket and s3_path
		before they are joined into the signed "/bucket/path" resource.
		"""
		payload = "{}\n{}\n{}\n{}\n/{}/{}".format(
			method, md5_hash, content_type, date, bucket.strip('/'), s3_path.strip('/')
		)
		signature = base64.b64encode(
			hmac.new(s3_secret.encode(), payload.encode(), hashlib.sha1).digest()
		).decode()
		return "AWS {}:{}".format(s3_key, signature)


def SubscribeToQueue():
	"""
	Connect to the RabbitMQ broker and declare the durable "statusQueue".

	:return: the pika channel, or None if connecting or declaring the queue
		failed. The error is printed and a half-opened connection is closed.
	"""
	connection = None
	try:
		# NOTE(review): broker address and credentials are hard-coded in source; consider moving them to config.
		connection = pika.BlockingConnection(pika.ConnectionParameters(
			host="10.2.0.11",
			port=5672,
			virtual_host="/",
			credentials=pika.PlainCredentials("producer", "b187ceaddb54d5485063ddc1d41af66f")))
		channel = connection.channel()
		channel.queue_declare(queue="statusQueue", durable=True)
	except Exception as ex:
		# one formatted string: under Python 2, print(a, b) prints a tuple
		print("An error occurred while connecting to the RabbitMQ queue: {}".format(ex))
		if connection is not None:
			# best-effort cleanup of a connection whose channel setup failed
			try:
				connection.close()
			except Exception as close_error:
				logging.debug('closing the RabbitMQ connection failed: {0}'.format(close_error))
		# previously `return channel` raised UnboundLocalError here
		return None
	print("Subscribed to queue")
	return channel


# Copies of the flag dictionaries from the previous update_state_from_dictionaries
# call, used to detect which warnings/alarms changed since then.
previous_warnings = {}
previous_alarms = {}

class MessageType:
	"""String constants naming the kinds of status messages."""

	HEARTBEAT = "Heartbeat"
	ALARM_OR_WARNING = "AlarmOrWarning"

class AlarmOrWarning:
	"""
	A single warning or alarm entry of a status message.

	:param description: text describing the flag
	:param created_by: originator recorded in the entry
	:param timestamp: optional datetime to stamp the entry with; defaults to now
	"""

	def __init__(self, description, created_by, timestamp=None):
		# take one timestamp so that Date and Time cannot straddle midnight
		now = datetime.now() if timestamp is None else timestamp
		self.date = now.strftime('%Y-%m-%d')
		self.time = now.strftime('%H:%M:%S')
		self.description = description
		self.created_by = created_by

	def to_dict(self):
		"""Return the entry as a JSON-serialisable dict (Date, Time, Description, CreatedBy)."""
		return {
			"Date": self.date,
			"Time": self.time,
			"Description": self.description,
			"CreatedBy": self.created_by
		}

# RabbitMQ channel used to publish status messages; the connection is opened at import time.
channel = SubscribeToQueue()
# Create an S3config instance
s3_config = S3config()
# The installation id is the integer before the first '-' in the S3 bucket name.
INSTALLATION_ID=int(s3_config.bucket.split('-')[0])
PRODUCT_ID = 1
# State kept across update cycles by update_state_from_dictionaries:
is_first_update = True   # on the first call every active flag is reported
prev_status = 0          # last status published in an event message (0 none, 1 warnings, 2 alarms)
subscribed_to_queue_first_time = False   # set once the first heartbeat has been published
heartbit_interval = 0    # calls since the last heartbeat

def _publish_status(body):
	"""
	Publish body on the "statusQueue" RabbitMQ queue.

	:return: True if the message was handed to the channel, False if there is
		no channel (e.g. the broker was unreachable at start-up). Errors raised
		by basic_publish itself are not caught.
	"""
	if channel is None:
		logging.warning('no RabbitMQ channel available, status message not sent')
		return False
	channel.basic_publish(exchange="", routing_key="statusQueue", body=body)
	return True


def _changed_flags(current, previous):
	"""Map every key of current to True if its value differs from previous (missing keys count as changed)."""
	missing = object()
	return dict((key, value != previous.get(key, missing)) for key, value in current.items())


def _count_active_per_node(flags, node_count):
	"""
	Count the truthy flags of each node, returned as a list indexed by node position.

	A flag key belongs to the node whose 1-based position equals key.split("/")[3].
	That is the id update_task inserts with insert_id(name, i + 1). It assumes
	"Devices" is the second path segment, as the original indexing did.
	"""
	position_by_id = dict((str(position + 1), position) for position in range(node_count))
	counts = [0] * node_count
	for key, value in flags.items():
		if not value:
			continue
		position = position_by_id.get(key.split("/")[3])
		if position is not None:
			counts[position] += 1
	return counts


def update_state_from_dictionaries(current_warnings, current_alarms, node_numbers):
	"""
	Diff the warning/alarm flags against the previous call, count the active
	flags per node and publish a status message to RabbitMQ.

	Status is 2 if any alarm is active, 1 if only warnings are active, else 0.
	An event message (Type 0) listing newly active warnings/alarms is published
	when the status changed or a flag became active. Otherwise a heartbeat
	(Type 1) is published on the first opportunity and whenever
	heartbit_interval, incremented on every call, reaches 15.

	:param current_warnings: dict path -> warning flag value
	:param current_alarms: dict path -> alarm flag value
	:param node_numbers: slave addresses of the batteries, in node order
	:return: (status_message, alarms_number_list, warnings_number_list).
		status_message is the JSON string if a message was published, otherwise
		the dict. The lists hold the number of active flags per node position.
	"""
	global previous_warnings, previous_alarms, INSTALLATION_ID, PRODUCT_ID, is_first_update, channel, prev_status, heartbit_interval, subscribed_to_queue_first_time

	heartbit_interval += 1

	if is_first_update:
		# every currently active flag is reported once after start-up
		changed_warnings = dict.fromkeys(current_warnings, True)
		changed_alarms = dict.fromkeys(current_alarms, True)
		is_first_update = False
	else:
		# compare by key: dict iteration order is not guaranteed to match between the two dicts
		changed_warnings = _changed_flags(current_warnings, previous_warnings)
		changed_alarms = _changed_flags(current_alarms, previous_alarms)

	status_message = {
		"InstallationId": INSTALLATION_ID,
		"Product": PRODUCT_ID,
		"Status": 0,
		"Type": 1,
		"Warnings": [],
		"Alarms": []
	}

	node_count = len(node_numbers)
	alarms_number_list = _count_active_per_node(current_alarms, node_count)
	warnings_number_list = _count_active_per_node(current_warnings, node_count)

	# report only the flags that are active and changed since the last call
	for key, active in current_alarms.items():
		if active and changed_alarms.get(key):
			status_message["Alarms"].append(AlarmOrWarning(key, "System").to_dict())

	for key, active in current_warnings.items():
		if active and changed_warnings.get(key):
			status_message["Warnings"].append(AlarmOrWarning(key, "System").to_dict())

	if any(current_alarms.values()):
		status_message["Status"] = 2
	elif any(current_warnings.values()):
		status_message["Status"] = 1
	else:
		status_message["Status"] = 0

	if status_message["Status"] != prev_status or status_message["Warnings"] or status_message["Alarms"]:
		prev_status = status_message["Status"]
		status_message["Type"] = 0
		status_message = json.dumps(status_message)
		if _publish_status(status_message):
			print(status_message)
			print("Message sent successfully")
	elif heartbit_interval >= 15 or not subscribed_to_queue_first_time:
		print("Send heartbit message to rabbitmq")
		heartbit_interval = 0
		subscribed_to_queue_first_time = True
		status_message = json.dumps(status_message)
		_publish_status(status_message)

	previous_warnings = current_warnings.copy()
	previous_alarms = current_alarms.copy()

	return status_message, alarms_number_list, warnings_number_list

def read_csv_as_string(file_path):
	"""
	Return the whole content of the file at file_path as a UTF-8 decoded string.

	Prints a message and returns None if the file does not exist or another
	IOError occurs while reading it.
	"""
	# open() has no 'encoding' argument in Python 2.7, hence codecs
	import codecs
	try:
		with codecs.open(file_path, 'r', encoding='utf-8') as handle:
			content = handle.read()
	except IOError as err:
		if err.errno != 2:  # 2 == "No such file or directory"
			print("IO error occurred: {}".format(str(err)))
		else:
			print("Error: The file {} does not exist.".format(file_path))
		return None
	return content



def init_modbus(tty):
	# type: (str) -> Modbus
	"""
	Create the Modbus serial client for /dev/<tty> using the serial settings
	from the config module. The client is returned unconnected.
	"""
	logging.debug('initializing Modbus')

	settings = dict(
		port='/dev/' + tty,
		method=cfg.MODE,
		baudrate=cfg.BAUD_RATE,
		stopbits=cfg.STOP_BITS,
		bytesize=cfg.BYTE_SIZE,
		timeout=cfg.TIMEOUT,
		parity=cfg.PARITY,
	)
	return Modbus(**settings)


def init_udp_socket():
	# type: () -> socket
	"""Create a non-blocking IPv4 UDP socket for the status upload."""
	udp = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
	udp.setblocking(False)
	return udp


def report_slave_id(modbus, slave_address):
	# type: (Modbus, int) -> str
	"""
	Send a Modbus "Report Slave ID" request to one slave and return its identifier.

	:raises Exception: if there is no response, the slave answers with an
		exception response, or the client returns a ModbusException
	"""
	slave = str(slave_address)

	logging.debug('requesting slave id from node ' + slave)

	with modbus:

		request = ReportSlaveIdRequest(unit=slave_address)
		response = modbus.execute(request)

		# isinstance (not `is`, which compared an instance with the class and was
		# always False) so Modbus exception replies are rejected with a clear message
		if response is None or isinstance(response, (ExceptionResponse, ModbusException)):
			raise Exception('failed to get slave id from ' + slave + ' : ' + str(response))

		return response.identifier


def identify_battery(modbus, slave_address):
	# type: (Modbus, int) -> Battery
	"""
	Query one Modbus slave and describe it as a Battery.

	The slave id is parsed first, then the firmware version is read. Any
	failure of those calls propagates to the caller.
	"""
	logging.info('identifying battery...')

	hardware, bms, capacity_ah = parse_slave_id(modbus, slave_address)
	firmware = read_firmware_version(modbus, slave_address)

	battery = Battery(
		slave_address=slave_address,
		hardware_version=hardware,
		firmware_version=firmware,
		bms_version=bms,
		ampere_hours=capacity_ah)

	logging.info('battery identified:\n{0}'.format(str(battery)))
	return battery


def identify_batteries(modbus):
	# type: (Modbus) -> List[Battery]
	"""
	Scan Modbus slave addresses upward from 1 and return every identified battery.

	The miss counter starts at -255, so up to 258 addresses are probed while no
	battery has been found. After a success the counter resets to 0, and the
	scan stops after 3 consecutive failures.
	"""
	found = []
	address = 0
	misses = -255

	while misses < 3:
		address += 1
		try:
			battery = identify_battery(modbus, address)
		except Exception as e:
			logging.info('failed to identify battery at {0} : {1}'.format(str(address), str(e)))
			misses += 1
		else:
			found.append(battery)
			misses = 0

	logging.info('giving up searching for further batteries')

	count = len(found)
	logging.info('found ' + str(count) + (' battery' if count == 1 else ' batteries'))

	return found


def parse_slave_id(modbus, slave_address):
	# type: (Modbus, int) -> (str, str, int)
	"""
	Read the slave id and split it into hardware version, BMS version and ampere hours.

	Non-printable characters are removed first. The id must start with
	"48TL<digits>"; the digits are the ampere hours and the rest is the BMS
	version.

	:raises Exception: if the id does not match that pattern
	"""
	printable = re.sub(r'[^\x20-\x7E]', '', report_slave_id(modbus, slave_address))

	found = re.match('(?P<hw>48TL(?P<ah>[0-9]+)) *(?P<bms>.*)', printable)
	if not found:
		raise Exception('no known battery found')

	hardware, bms, capacity = [found.group(name).strip() for name in ('hw', 'bms', 'ah')]
	return hardware, bms, int(capacity)


def read_firmware_version(modbus, slave_address):
	# type: (Modbus, int) -> str
	"""Read input register 1054 of a slave and return it as 4 upper-case hex digits."""
	logging.debug('reading firmware version')

	with modbus:
		version_word = read_modbus_registers(modbus, slave_address, base_address=1054, count=1).registers[0]
		return '{0:0>4X}'.format(version_word)


def read_modbus_registers(modbus, slave_address, base_address=cfg.BASE_ADDRESS, count=cfg.NO_OF_REGISTERS):
	# type: (Modbus, int, int, int) -> ReadInputRegistersResponse
	"""
	Read count input registers of one slave, starting at base_address.

	The defaults come from the config module and are evaluated when the
	function is defined.
	"""
	end_address = base_address + count  # the logged upper bound is exclusive
	logging.debug('requesting modbus registers {0}-{1}'.format(base_address, end_address))

	return modbus.read_input_registers(address=base_address, count=count, unit=slave_address)


def read_battery_status(modbus, battery):
	# type: (Modbus, Battery) -> BatteryStatus
	"""
	Read the battery's status registers (default range of read_modbus_registers)
	and wrap them in a BatteryStatus.
	"""
	logging.debug('reading battery status')

	with modbus:
		registers = read_modbus_registers(modbus, battery.slave_address).registers
		return BatteryStatus(battery, registers)


def publish_values_on_dbus(service, battery_signals, battery_statuses):
	# type: (DBusService, Iterable[BatterySignal], Iterable[BatteryStatus]) -> ()
	"""Publish the per-battery values first, then the aggregated values."""
	for publish in (publish_individuals, publish_aggregates):
		publish(service, battery_signals, battery_statuses)


def publish_aggregates(service, signals, battery_statuses):
	# type: (DBusService, Iterable[BatterySignal], Iterable[BatteryStatus]) -> ()
	"""
	For every signal that has an aggregate function, combine its value over all
	batteries and set the result on the signal's own D-Bus path.

	battery_statuses is iterated once per aggregated signal, so it must be
	re-iterable (e.g. a list).
	"""
	for sig in signals:
		if sig.aggregate is not None:
			per_battery = [sig.get_value(status) for status in battery_statuses]
			service.own_properties.set(sig.dbus_path, sig.aggregate(per_battery), sig.unit)


def publish_individuals(service, signals, battery_statuses):
	# type: (DBusService, Iterable[BatterySignal], Iterable[BatteryStatus]) -> ()
	"""
	Set every signal of every battery on D-Bus under
	/_Battery/<slave address><signal path>.
	"""
	for sig in signals:
		for status in battery_statuses:
			path = '/_Battery/' + str(status.battery.slave_address) + sig.dbus_path
			service.own_properties.set(path, sig.get_value(status), sig.unit)


def publish_service_signals(service, signals):
	# type: (DBusService, Iterable[ServiceSignal]) -> NoReturn
	"""Set each service signal's fixed value on its D-Bus path."""
	for service_signal in signals:
		path, value, unit = service_signal.dbus_path, service_signal.value, service_signal.unit
		service.own_properties.set(path, value, unit)


def upload_status_to_innovenergy(sock, statuses):
	# type: (socket, Iterable[BatteryStatus]) -> bool
	"""
	Send each battery status as one UDP datagram to the InnovEnergy server.

	This is best effort: an error while serialising or sending is logged at
	debug level and reported through the return value instead of being raised.

	:return: True if every datagram was handed to the socket, False otherwise
	"""
	logging.debug('upload status')

	try:
		for s in statuses:
			sock.sendto(s.serialize(), (cfg.INNOVENERGY_SERVER_IP, cfg.INNOVENERGY_SERVER_PORT))
	except Exception as e:
		# Exception rather than a bare except so KeyboardInterrupt/SystemExit still propagate
		logging.debug('FAILED: {0}'.format(e))
		return False
	else:
		return True


def print_usage():
	"""Print the command line usage and an example invocation."""
	for label, suffix in (('Usage:   ', ' <serial device>'), ('Example: ', ' ttyUSB0')):
		print (label + __file__ + suffix)


def parse_cmdline_args(argv):
	# type: (List[str]) -> str
	"""
	Return the tty device name (first argument).

	Prints the usage and exits with status 1 when no argument is given.
	"""
	if argv:
		return argv[0]

	logging.info('missing command line argument for tty device')
	print_usage()
	sys.exit(1)


def reset_batteries(modbus, batteries):
	# type: (Modbus, Iterable[Battery]) -> NoReturn
	"""
	Write 1 to RESET_REGISTER of every battery, then terminate the driver with status 0.

	A ModbusIOException result counts as success, because a BMS that has
	already reset cannot reply.
	"""
	logging.info('Resetting batteries...')

	for battery in batteries:

		result = modbus.write_registers(RESET_REGISTER, [1], unit=battery.slave_address)

		# expecting a ModbusIOException (timeout)
		# BMS can no longer reply because it is already reset
		success = isinstance(result, ModbusIOException)

		outcome = 'successfully' if success else 'FAILED to'
		logging.info('Battery {0} {1} reset'.format(str(battery.slave_address), outcome))

	logging.info('Shutting down fz-sonick driver')
	# sys.exit instead of the site-provided exit() builtin, which is not
	# guaranteed to exist (e.g. when Python runs with -S)
	sys.exit(0)


alive = True   # global alive flag, watchdog_task clears it, update_task sets it

# Reference time for the CSV snapshots: update_task writes one once 30 s have elapsed and then resets it.
start_time = time.time()
def create_update_task(modbus, service, batteries):
	# type: (Modbus, DBusService, Iterable[Battery]) -> Callable[[],bool]
	"""
	Creates an update task which runs the main update function
	and resets the alive flag.

	Each call of the returned task:
	  * resets the batteries (which exits the process) if /ResetBatteries is 1,
	  * reads the status registers of every battery,
	  * refreshes the warning/alarm flag dictionaries and passes them to
	    update_state_from_dictionaries,
	  * publishes the battery values on D-Bus,
	  * writes and uploads a CSV snapshot once at least 30 s have passed,
	  * sends the statuses via UDP and sets the global alive flag.
	It always returns True, so gobject keeps scheduling it.
	"""
	_socket = init_udp_socket()
	_signals = signals.init_battery_signals()

	csv_signals = signals.create_csv_signals(first(batteries).firmware_version)
	node_numbers = [battery.slave_address for battery in batteries]
	warnings_signals, alarm_signals = signals.read_warning_and_alarm_flags()

	# kept in this closure across cycles and refreshed in place on every cycle
	current_warnings = {}
	current_alarms = {}

	def update_task():
		# type: () -> bool
		global alive, start_time

		logging.debug('starting update cycle')

		if service.own_properties.get('/ResetBatteries').value == 1:
			reset_batteries(modbus, batteries)

		statuses = [read_battery_status(modbus, battery) for battery in batteries]

		# flag paths carry the 1-based position of the battery (not its slave address)
		for position, status in enumerate(statuses):
			device_id = position + 1
			for s in warnings_signals:
				current_warnings[insert_id(s.name, device_id)] = s.get_value(status)
			for s in alarm_signals:
				current_alarms[insert_id(s.name, device_id)] = s.get_value(status)

		_status, alarms_number_list, warnings_number_list = update_state_from_dictionaries(current_warnings, current_alarms, node_numbers)

		publish_values_on_dbus(service, _signals, statuses)

		elapsed_time = time.time() - start_time
		if elapsed_time >= 30:
			create_csv_files(csv_signals, statuses, node_numbers, alarms_number_list, warnings_number_list)
			start_time = time.time()
		print("Elapsed time: {:.2f} seconds".format(elapsed_time))

		upload_status_to_innovenergy(_socket, statuses)

		logging.debug('finished update cycle\n')

		alive = True

		return True

	return update_task

def manage_csv_files(directory_path, max_files=20):
	"""
	Keep at most max_files regular files in directory_path.

	Files are ordered by creation/change time (os.path.getctime) and the oldest
	ones are deleted. Sub-directories are ignored.
	"""
	def _full_path(name):
		return os.path.join(directory_path, name)

	regular_files = sorted(
		(name for name in os.listdir(directory_path) if os.path.isfile(_full_path(name))),
		key=lambda name: os.path.getctime(_full_path(name)))

	surplus = len(regular_files) - max_files
	for name in regular_files[:max(surplus, 0)]:
		os.remove(_full_path(name))
def insert_id(path, id_number):
	"""
	Insert id_number as a new path segment right after the first "Devices" segment.

	Example: insert_id("/Battery/Devices/Soc", 2) -> "/Battery/Devices/2/Soc".
	Raises ValueError if the path has no "Devices" segment.
	"""
	segments = path.split("/")
	cut = segments.index("Devices") + 1
	return "/".join(segments[:cut] + [str(id_number)] + segments[cut:])

def _write_csv_rows(csv_path, signals, statuses, node_numbers, alarms_number_list, warnings_number_list):
	"""Append the node config row plus each node's counter and signal rows (';'-separated) to csv_path."""
	with open(csv_path, 'ab') as csvfile:
		csv_writer = csv.writer(csvfile, delimiter=';')
		nodes_list = ",".join(str(node) for node in node_numbers)
		csv_writer.writerow(["/Config/Devices/BatteryNodes", nodes_list, ""])
		for i in range(len(node_numbers)):
			device_id = i + 1  # rows are labelled with the 1-based node position
			csv_writer.writerow(["/Battery/Devices/{}/Alarms".format(device_id), alarms_number_list[i], ""])
			csv_writer.writerow(["/Battery/Devices/{}/Warnings".format(device_id), warnings_number_list[i], ""])
			for s in signals:
				csv_writer.writerow([insert_id(s.name, device_id), s.get_value(statuses[i]), s.get_text])


def _upload_csv(file_name, csv_data):
	"""
	Zip and base64-encode csv_data, then PUT it to S3 as file_name.

	:return: True on HTTP 200. Returns False for any other status, or if
		requests raised (e.g. connection error or timeout).
	"""
	compressed_csv = compress_csv_data(csv_data)
	try:
		response = s3_config.create_put_request(file_name, compressed_csv)
	except requests.exceptions.RequestException as e:
		logging.warning('S3 upload of {0} failed: {1}'.format(file_name, e))
		return False
	return response.status_code == 200


def create_csv_files(signals, statuses, node_numbers, alarms_number_list, warnings_number_list):
	"""
	Write a CSV snapshot of all battery signals, upload it to S3 and prune old files.

	The file is named after the current Unix time rounded down to an even
	second and written to CSV_DIR. On a failed upload it is moved to
	CSV_DIR/failed, which keeps at most 10 files; CSV_DIR keeps at most 20.

	:param signals: CSV signal definitions (name, get_value, get_text)
	:param statuses: battery statuses in the same order as node_numbers
	:param node_numbers: slave addresses of the batteries
	:param alarms_number_list: active alarm count per node position
	:param warnings_number_list: active warning count per node position
	"""
	timestamp = int(time.time())
	timestamp -= timestamp % 2  # round down to an even second

	if not os.path.exists(CSV_DIR):
		os.makedirs(CSV_DIR)
	csv_filename = "{}.csv".format(timestamp)
	csv_path = os.path.join(CSV_DIR, csv_filename)

	_write_csv_rows(csv_path, signals, statuses, node_numbers, alarms_number_list, warnings_number_list)

	csv_data = read_csv_as_string(csv_path)
	if csv_data is None:
		print("error while reading csv as string")
		return

	if _upload_csv(csv_filename, csv_data):
		print("Success")
	else:
		failed_dir = os.path.join(CSV_DIR, "failed")
		if not os.path.exists(failed_dir):
			os.makedirs(failed_dir)
		os.rename(csv_path, os.path.join(failed_dir, csv_filename))
		print("Uploading failed")
		manage_csv_files(failed_dir, 10)

	manage_csv_files(CSV_DIR)


def create_watchdog_task(main_loop):
	# type: (DBusGMainLoop) -> Callable[[],bool]
	"""
	Creates a Watchdog task that monitors the alive flag.

	Each run clears the global alive flag. If update_task has not set it again
	since the previous run, the watchdog quits the main loop and returns False,
	so gobject stops scheduling it.
	Who watches the watchdog?
	"""
	def watchdog_task():
		# type: () -> bool
		global alive

		if not alive:
			logging.info('watchdog_task: killing main loop because update_task is no longer alive')
			main_loop.quit()
			return False

		logging.debug('watchdog_task: update_task is alive')
		alive = False
		return True

	return watchdog_task


def main(argv):
	# type: (List[str]) -> ()
	"""
	Driver entry point: identify the batteries on the given serial device,
	expose them as a D-Bus service and run the update/watchdog main loop.

	Exits with status 1 when the tty argument is missing (parse_cmdline_args),
	2 when no battery was found, and 0xFF after the main loop has stopped.
	In this file, only the watchdog stops the loop.
	"""
	print("INSIDE DBUS SONICK")
	logging.basicConfig(level=cfg.LOG_LEVEL)
	logging.info('starting ' + __file__)

	tty = parse_cmdline_args(argv)
	modbus = init_modbus(tty)
	batteries = identify_batteries(modbus)
	if not batteries:
		sys.exit(2)

	service = DBusService(service_name=cfg.SERVICE_NAME_PREFIX + tty)
	# writable flag; update_task resets the batteries when it reads 1
	service.own_properties.set('/ResetBatteries', value=False, writable=True)

	main_loop = gobject.MainLoop()

	publish_service_signals(service, signals.init_service_signals(batteries))

	update_task = create_update_task(modbus, service, batteries)
	# run once right away so every property exists before anyone can query it
	update_task()
	watchdog_task = create_watchdog_task(main_loop)

	interval = cfg.UPDATE_INTERVAL
	# the watchdog is registered first and runs at half the update frequency
	gobject.timeout_add(interval * 2, watchdog_task, priority=gobject.PRIORITY_LOW)
	gobject.timeout_add(interval, update_task, priority=gobject.PRIORITY_LOW)

	logging.info('starting gobject.MainLoop')
	main_loop.run()
	logging.info('gobject.MainLoop was shut down')

	sys.exit(0xFF)  # reaches this only on error


# Start the driver; the first command line argument is the tty device name (see parse_cmdline_args).
# NOTE(review): there is no `if __name__ == "__main__"` guard, so importing this module also starts the driver.
main(sys.argv[1:])