import struct

import config as cfg
from data import LedState, BatteryStatus

# trick the pycharm type-checker into thinking Callable is in scope, not used at runtime
# noinspection PyUnreachableCode
if False:
	from typing import Callable, List, Iterable, Union, AnyStr, Any


def read_bool(base_register, bit):
	# type: (int, int) -> Callable[[BatteryStatus], bool]
	"""
	Build a reader for a single flag bit of the modbus register image.

	A register holds 16 bits, so a bit number of 16 or more addresses a
	register after base_register: bit 16 is bit 0 of base_register + 1,
	bit 33 is bit 1 of base_register + 2, and so on.

	:param base_register: modbus address of the register that holds bit 0
	:param bit: bit number, counted from bit 0 of base_register
	:return: function taking a BatteryStatus and returning True when the
		addressed bit is set in status.modbus_data
	"""
	# divmod stays in integer arithmetic on Python 2 and 3 alike and keeps the
	# register offset consistent with the (floor based) bit position
	register_offset, bit_in_register = divmod(bit, 16)
	register = base_register + register_offset
	mask = 1 << bit_in_register

	def get_value(status):
		# type: (BatteryStatus) -> bool
		# modbus_data is indexed relative to cfg.BASE_ADDRESS
		word = status.modbus_data[register - cfg.BASE_ADDRESS]
		return (word & mask) != 0

	return get_value


def read_float(register, scale_factor=1.0, offset=0.0):
	# type: (int, float, float) -> Callable[[BatteryStatus], float]
	"""
	Build a reader that decodes one modbus register as a scaled number.

	The raw register is reinterpreted as a signed 16 bit integer, then
	offset is added and the sum is multiplied by scale_factor.

	:param register: modbus address of the register to decode
	:param scale_factor: multiplier applied after the offset
	:param offset: value added to the signed register value before scaling
	:return: function taking a BatteryStatus and returning the decoded value
	"""

	def get_value(status):
		# type: (BatteryStatus) -> float
		raw = status.modbus_data[register - cfg.BASE_ADDRESS]

		# values from 0x8000 upwards represent negative int16 numbers
		signed = raw - 0x10000 if raw >= 0x8000 else raw

		return (signed + offset) * scale_factor

	return get_value


def read_registers(register, count):
	# type: (int, int) -> Callable[[BatteryStatus], List[int]]
	"""
	Build a reader returning a run of consecutive raw modbus registers.

	:param register: modbus address of the first register
	:param count: number of registers to return
	:return: function taking a BatteryStatus and returning a new list with
		the selected slice of status.modbus_data
	"""
	first_index = register - cfg.BASE_ADDRESS
	stop_index = first_index + count

	def get_value(status):
		# type: (BatteryStatus) -> List[int]
		# list() copies the slice, so callers never share storage with modbus_data
		return list(status.modbus_data[first_index:stop_index])

	return get_value


def comma_separated(values):
	# type: (Iterable[str]) -> str
	"""
	Join the distinct values into one ", "-separated string.

	Duplicates are dropped; every value appears once, at the position of
	its first occurrence, so the result is stable between runs.

	:param values: strings to join, may contain duplicates
	:return: the distinct values separated by ", " ('' for no values)
	"""
	seen = set()
	unique = []
	for value in values:
		if value not in seen:
			seen.add(value)
			unique.append(value)

	return ", ".join(unique)


def count_bits(base_register, nb_of_registers, nb_of_bits, first_bit=0):
	# type: (int, int, int, int) -> Callable[[BatteryStatus], int]
	"""
	Build a reader counting the set bits in a window over several registers.

	The registers are laid out as one long bit sequence, least significant
	bit first, 16 bits per register; the window covers nb_of_bits bits
	beginning at first_bit of that sequence.

	:param base_register: modbus address of the first register
	:param nb_of_registers: number of consecutive registers to read
	:param nb_of_bits: width of the window in bits
	:param first_bit: position of the first bit of the window
	:return: function taking a BatteryStatus and returning the number of
		set bits inside the window
	"""
	read_words = read_registers(base_register, nb_of_registers)
	stop_bit = first_bit + nb_of_bits

	def get_value(status):
		# type: (BatteryStatus) -> int

		# Render each register LSB first (bin() text reversed, "0b" prefix
		# dropped) and pad it on the right with "0" up to 16 characters; for
		# register values in 0..0xFFFF character i of the joined string is
		# then bit i % 16 of register i // 16.
		lsb_first_bits = ''.join(
			bin(word)[-1:1:-1].ljust(16, "0")
			for word in read_words(status)
		)

		window = lsb_first_bits[first_bit:stop_bit]
		return window.count('1')

	return get_value


def read_led_state(register, led):
	# type: (int, int) -> Callable[[BatteryStatus], int]
	"""
	Build a reader decoding the two-bit state of one LED.

	LED number led occupies bits led * 2 (low) and led * 2 + 1 (high),
	counted from bit 0 of register:

		high  low   result
		0     0     LedState.off
		0     1     LedState.on
		1     0     LedState.blinking_slow
		1     1     LedState.blinking_fast

	:param register: modbus address of the register holding LED 0
	:param led: index of the LED
	:return: function taking a BatteryStatus and returning a LedState member
	"""
	low_bit = led * 2
	read_lo = read_bool(register, low_bit)
	read_hi = read_bool(register, low_bit + 1)

	def get_value(status):
		# type: (BatteryStatus) -> int
		lo = read_lo(status)
		hi = read_hi(status)

		if hi and lo:
			return LedState.blinking_fast
		if hi:
			return LedState.blinking_slow
		if lo:
			return LedState.on
		return LedState.off

	return get_value


# noinspection PyShadowingNames
def unit(unit):
	# type: (unicode) -> Callable[[unicode], unicode]
	"""
	Build a formatter that appends a fixed unit suffix to a value.

	:param unit: text placed directly after the value, without a separator
	:return: function that converts its argument with str() and appends unit
	"""
	template = "{0}{1}"

	def get_text(v):
		# type: (unicode) -> unicode
		text = str(v)
		return template.format(text, unit)

	return get_text


def const(constant):
	# type: (Any) -> Callable[..., Any]
	"""
	Build a function that ignores its positional arguments and always
	returns constant (the very same object on every call).

	:param constant: value to hand back
	:return: function accepting any number of positional arguments
	"""
	return lambda *args: constant


def mean(numbers):
	# type: (List[Union[float,int]]) -> float
	"""
	Return the arithmetic mean of numbers as a float.

	:param numbers: non-empty list of numbers
	:return: sum of the numbers divided by their count
	:raises ZeroDivisionError: when numbers is empty
	"""
	count = len(numbers)
	total = float(sum(numbers))  # float() keeps the division fractional for ints
	return total / count


def first(ts, default=None):
	"""
	Return the first item produced by iterating ts, or default when ts
	yields nothing.

	NOTE: the name `first` is defined a second time further down in this
	module; at import time that later definition replaces this one.

	:param ts: any iterable
	:param default: value returned for an empty iterable
	"""
	for item in ts:
		return item
	return default


def bitfields_to_str(lists):
	# type: (List[List[int]]) -> str
	"""
	OR several equally laid out bitfields together and render the result
	as space-separated 4-digit uppercase hex words, e.g. "0005 000A".

	The number of words is taken from the first bitfield. Extra words in
	the other bitfields are ignored.

	:param lists: bitfields to combine, each a list of register values
	:return: hex representation of the combined bitfield, '' when lists
		is empty
	:raises IndexError: when a bitfield is shorter than the first one
	"""
	if not lists:
		# nothing to combine; handled here so the result does not depend on
		# how an exhausted iterator would propagate out of a helper
		return ''

	n_words = len(lists[0])

	combined = []
	for i in range(n_words):
		word = 0
		for bitfield in lists:
			word |= bitfield[i]
		combined.append(word)

	return ' '.join('{0:0>4X}'.format(word) for word in combined)


def pack_string(string):
	# type: (AnyStr) -> Any
	"""
	Serialise a string as a one-byte length prefix followed by its bytes.

	Text is encoded as UTF-8; a value that already is a byte string is
	packed as is.

	:param string: text or byte string to pack
	:return: byte string of 1 + len(payload) bytes
	:raises struct.error: when the payload is longer than 255 bytes, the
		largest length an unsigned byte ('B') prefix can express
	"""
	if isinstance(string, bytes):
		# already encoded; bytes has no .encode() on Python 3
		data = string
	else:
		data = string.encode('UTF-8')

	return struct.pack('B', len(data)) + data


def read_bitmap(register):
	# type: (int) -> Callable[[BatteryStatus], int]
	"""
	Build a reader returning one modbus register unmodified.

	:param register: modbus address of the register
	:return: function taking a BatteryStatus and returning the raw value
		stored for register in status.modbus_data
	"""

	def get_value(status):
		# type: (BatteryStatus) -> int
		return status.modbus_data[register - cfg.BASE_ADDRESS]

	return get_value

def return_in_list(ts):
	"""
	Return ts unchanged (the same object, not a copy).

	NOTE(review): despite the name no list is created here; presumably
	callers already pass a list - confirm against the call sites.
	"""
	return ts

_NO_DEFAULT = object()  # sentinel, so that None can be passed as a real default


def first(ts, default=_NO_DEFAULT):
	"""
	Return the first item produced by iterating ts.

	This definition rebinds the name `first`, which is also defined earlier
	in this module, so it is the one callers get. The optional default
	keeps calls written for that earlier two-argument form working.

	:param ts: any iterable
	:param default: value returned when ts yields nothing; when omitted an
		empty ts raises StopIteration
	:raises StopIteration: when ts is empty and no default was given
	"""
	iterator = iter(ts)
	if default is _NO_DEFAULT:
		return next(iterator)
	return next(iterator, default)

def read_hex_string(register, count):
	# type: (int, int) -> Callable[[BatteryStatus], str]
	"""
	Build a reader rendering count consecutive modbus registers, starting
	at register, as space-separated 4-digit uppercase hex words,
	e.g. for count=4: DEAD BEEF DEAD BEEF.

	:param register: modbus address of the first register
	:param count: number of registers to render
	:return: function taking a BatteryStatus and returning the hex text
	"""
	first_index = register - cfg.BASE_ADDRESS
	stop_index = first_index + count
	as_hex_word = '{0:0>4X}'.format

	def get_value(status):
		# type: (BatteryStatus) -> str
		words = status.modbus_data[first_index:stop_index]
		return ' '.join(as_hex_word(word) for word in words)

	return get_value