#!/usr/bin/python2 -u
# coding=utf-8

import os
import struct
from time import sleep

import serial
from os import system

from pymodbus.client.sync import ModbusSerialClient as Modbus
from pymodbus.exceptions import ModbusIOException
from pymodbus.pdu import ModbusResponse
from os.path import dirname, abspath
from sys import path, argv, exit

# Put the parent of this script's directory on the import path. Nothing in
# this file imports anything after this point, so presumably this is for
# modules loaded elsewhere -- TODO confirm it is still needed.
path.append(dirname(dirname(abspath(__file__))))

# Flash is written PAGE_SIZE bytes at a time, as two Modbus writes per page.
PAGE_SIZE = 0x100
# Floor division keeps HALF_PAGE an int on Python 3 too; a float here makes
# every `data[:HALF_PAGE]` slice raise TypeError.
HALF_PAGE = PAGE_SIZE // 2
WRITE_ENABLE = [1]
SERIAL_STARTER_DIR = '/opt/victronenergy/serial-starter/'
FIRMWARE_VERSION_REGISTER = 1054

ERASE_FLASH_REGISTER = 0x2084
RESET_REGISTER = 0x2087


# Trick the PyCharm type-checker into thinking the typing names used in the
# `# type:` comments below are in scope; the import never runs at runtime.
# noinspection PyUnreachableCode
if False:
    from typing import List, NoReturn, Iterable, Optional


class LockTTY(object):
    """Context manager that takes a serial port away from serial-starter.

    On entry it runs serial-starter's ``stop-tty.sh`` for the port; on exit,
    including exit through an exception, it runs ``start-tty.sh`` for the
    same port. Exceptions raised inside the ``with`` body are not suppressed.
    """

    def __init__(self, tty):
        # type: (str) -> None
        # device name as passed on the command line, e.g. 'ttyUSB0'
        self.tty = tty

    def _serial_starter(self, script):
        # type: (str) -> None
        # Runs "<SERIAL_STARTER_DIR><script> <tty>" through os.system.
        system(SERIAL_STARTER_DIR + script + ' ' + self.tty)

    def __enter__(self):
        self._serial_starter('stop-tty.sh')
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self._serial_starter('start-tty.sh')


def calc_stm32_crc_round(crc, data):
    # type: (int, int) -> int
    """Fold one data word into a running STM32-style CRC-32 value.

    XORs ``data`` into ``crc``, then clocks the result through 32 MSB-first
    shift steps with polynomial 0x04C11DB7 (no reflection, no final XOR).
    The returned value always fits in 32 bits.
    """
    register = crc ^ data
    for _step in range(32):
        carry = bool(register & 0x80000000)
        # Mask to 32 bits: Python ints would otherwise keep growing.
        register = (register << 1) & 0xFFFFFFFF
        if carry:
            register ^= 0x04C11DB7
    return register


def calc_stm32_crc(data):
    # type: (Iterable[int]) -> int
    """Return the STM32-style CRC-32 of a sequence of words.

    The register is seeded with 0xFFFFFFFF and every element of ``data`` is
    folded in, in iteration order, via calc_stm32_crc_round. An empty
    iterable yields the seed unchanged.
    """
    running = 0xFFFFFFFF
    for word in iter(data):
        running = calc_stm32_crc_round(running, word)
    return running


def init_modbus(tty):
    # type: (str) -> Modbus
    """Construct a Modbus RTU client for ``/dev/<tty>``.

    Line settings: 115200 baud, 8 data bits, odd parity, 1 stop bit, and a
    0.15 s response timeout.
    """
    device = '/dev/' + tty
    return Modbus(method='rtu',
                  port=device,
                  baudrate=115200,
                  bytesize=8,
                  parity=serial.PARITY_ODD,
                  stopbits=1,
                  timeout=0.15)  # seconds


def failed(response):
    # type: (ModbusResponse) -> bool
    """Return True when a pymodbus call did not yield a normal response.

    A Modbus exception response has a function code above 0x80. Objects with
    no ``function_code`` attribute at all -- such as the ``ModbusIOException``
    returned on a timeout -- also count as failures, instead of raising
    ``AttributeError`` as a plain ``response.function_code`` access would.
    """
    function_code = getattr(response, 'function_code', None)
    return function_code is None or function_code > 0x80


def clear_flash(modbus, slave_address):
    # type: (Modbus, int) -> bool
    """Erase the BMS flash and wait for the erase to finish.

    Writes 1 to ERASE_FLASH_REGISTER, then polls holding register 0x2085 and
    prints progress whenever its value changes. The progress formula assumes
    the register counts down from 16 to 0 (TODO confirm against BMS docs).

    Returns True once the countdown reaches 0, False if the erase command or
    any poll fails.
    """

    print('erasing flash...')

    write_response = modbus.write_registers(address=ERASE_FLASH_REGISTER, values=[1], unit=slave_address)

    if failed(write_response):
        print('erasing flash FAILED')
        return False

    # Start above any expected countdown value so the first poll always prints.
    flash_countdown = 17
    while flash_countdown > 0:
        read_response = modbus.read_holding_registers(address=0x2085, count=1, unit=slave_address)

        if failed(read_response):
            print('erasing flash FAILED')
            return False

        if read_response.registers[0] != flash_countdown:
            flash_countdown = read_response.registers[0]

            # Floor division keeps the percentage an integer on Python 3.
            percent = 100 * (16 - flash_countdown) // 16
            print('\r{0}% '.format(percent), end=' ')

    print('done!')

    return True


# noinspection PyShadowingBuiltins
def bytes_to_words(bytes):
    # type: (bytes) -> List[int]
    """Split a byte string into big-endian unsigned 16-bit words.

    The input length must be even; otherwise ``struct.error`` is raised.
    The parameter keeps its original (builtin-shadowing) name so the
    signature is unchanged.
    """
    # Floor division: true division would put a float into the format
    # expression and raise TypeError on Python 3.
    return list(struct.unpack('>%dH' % (len(bytes) // 2), bytes))


def send_half_page_1(modbus, slave_address, data, page):
    # type: (Modbus, int, bytes, int) -> None
    """Send the first half of a flash page.

    Writes the page number followed by the first HALF_PAGE bytes of ``data``
    (as 16-bit words) to consecutive registers starting at 0x2000.

    Raises Exception if the write is reported as failed.
    """
    payload = [page]
    payload.extend(bytes_to_words(data[:HALF_PAGE]))

    response = modbus.write_registers(0x2000, payload, unit=slave_address)
    if not failed(response):
        return

    raise Exception("Failed to write page " + str(page))


def send_half_page_2(modbus, slave_address, data, page):
    # type: (Modbus, int, bytes, int) -> None
    """Send the second half of a flash page plus its CRC and write-enable.

    Starting at register 0x2041 -- for a full 256-byte page that is right
    after the 1 + 64 registers written from 0x2000 by the first half -- it
    writes three things in order:

    - the bytes of ``data`` after HALF_PAGE, as 16-bit words;
    - the two CRC words from calc_crc;
    - WRITE_ENABLE, presumably the BMS's trigger to program the page.

    Raises Exception if the write is reported as failed.
    """
    second_half = bytes_to_words(data[HALF_PAGE:])
    crc_words = calc_crc(page, data)
    response = modbus.write_registers(0x2041, second_half + crc_words + WRITE_ENABLE, unit=slave_address)

    if failed(response):
        raise Exception("Failed to write page " + str(page))


def get_fw_name(fw_path):
    # type: (str) -> str
    """Derive the firmware name from a firmware file path.

    Takes the last '/'-separated component and drops everything from its
    first '.' onward, e.g. 'fw/A08C.bin' -> 'A08C'.
    """
    file_name = fw_path.rsplit('/', 1)[-1]
    stem, _, _ = file_name.partition('.')
    return stem


def upload_fw(modbus, slave_id, fw_path, fw_name):
    # type: (Modbus, int, str, str) -> None
    """Upload the firmware image at ``fw_path`` to the BMS page by page.

    The file is sent in PAGE_SIZE chunks. A trailing partial page is padded
    with 0xFF bytes and uploaded, rather than silently dropped. Pages made
    up entirely of 0xFF are skipped. ``fw_name`` is only used in the
    progress message.

    Raises Exception (from send_half_page_1/2) if a page write fails.
    """

    with open(fw_path, "rb") as f:

        size = os.fstat(f.fileno()).st_size
        # Round up so a short final page is included; floor division keeps
        # n_pages an int (range() rejects floats on Python 3).
        n_pages = (size + PAGE_SIZE - 1) // PAGE_SIZE

        print('uploading firmware ' + fw_name + ' to BMS ...')

        for page in range(n_pages):

            # Pad a short final page with 0xFF (the value is_page_empty treats
            # as blank) so both half-page writes get full-size data.
            page_data = f.read(PAGE_SIZE).ljust(PAGE_SIZE, b'\xff')

            percent = 100 * (page + 1) // n_pages
            msg = 'page ' + str(page + 1) + '/' + str(n_pages) + ' ' + str(percent) + '%'
            print('\r{0} '.format(msg), end=' ')

            if is_page_empty(page_data):
                continue

            send_half_page_1(modbus, slave_id, page_data, page)
            send_half_page_2(modbus, slave_id, page_data, page)

    # Terminate the progress line, like clear_flash does.
    print('done!')


def is_page_empty(page):
    # type: (bytes) -> bool
    """Return True if every byte of ``page`` is 0xFF (also True for b'').

    upload_fw skips such pages instead of writing them.
    """
    # Use a bytes literal: the page comes from a file opened in binary mode,
    # and bytes.count() rejects a str argument on Python 3.
    return page.count(b'\xff') == len(page)


def reset_bms(modbus, slave_id):
    # type: (Modbus, int) -> bool
    """Ask the BMS to reset by writing 1 to RESET_REGISTER.

    The BMS resets before it can answer, so a ModbusIOException (timeout) is
    the expected outcome and counts as success. Any other result is treated
    as a failed reset. Prints the outcome and returns it.
    """

    print('resetting BMS...')

    reply = modbus.write_registers(RESET_REGISTER, [1], unit=slave_id)

    # Anything other than a timeout means the device answered, i.e. did not reset.
    if not isinstance(reply, ModbusIOException):
        print('FAILED to reset battery!')
        return False

    print('done')
    return True


def calc_crc(page, data):
    # type: (int, bytes) -> List[int]
    """Compute a page's CRC as two big-endian 16-bit words.

    The CRC (calc_stm32_crc) covers the page number followed by ``data``
    read as 16-bit words. The 32-bit result is returned as [high, low].
    """
    words = [page]
    words.extend(bytes_to_words(data))
    packed = struct.pack('>L', calc_stm32_crc(words))
    high, low = struct.unpack('>HH', packed)
    return [high, low]


def identify_battery(modbus, slave_id):
    # type: (Modbus, int) -> Optional[str]
    """Read and report a battery's firmware version.

    Reads input register FIRMWARE_VERSION_REGISTER and formats it as four
    upper-case hex digits (e.g. 'A08C'). That is the form verify_firmware
    compares against the name derived from the firmware file.

    Returns the version string, or None if the battery could not be read.
    """

    target = 'battery #' + str(slave_id) + ' at ' + modbus.port

    print('contacting ' + target + ' ...')

    try:
        response = modbus.read_input_registers(address=FIRMWARE_VERSION_REGISTER, count=1, unit=slave_id)
        # Error results (e.g. a timeout) presumably have no `registers`
        # attribute; the resulting AttributeError is handled below.
        fw = '{0:0>4X}'.format(response.registers[0])
    except Exception:
        # Not a bare `except:` -- Ctrl-C and SystemExit must still abort the
        # script instead of being reported as a communication failure.
        print('failed to communicate with ' + target + ' !')
        return None

    print('found battery with firmware ' + fw)

    return fw


def print_usage():
    """Print the command-line synopsis followed by an example invocation."""
    script = __file__
    for line in ('Usage:   ' + script + ' <serial device> <battery id> <firmware>',
                 'Example: ' + script + ' ttyUSB0 2 A08C.bin'):
        print(line)


def parse_cmdline_args(argv):
    # type: (List[str]) -> (str, int, str, str)
    """Parse ``<serial device> <battery id> <firmware>`` from ``argv``.

    ``argv`` is the argument list without the script name. Returns
    (tty, battery_id, fw_path, fw_name): battery_id is converted to int, and
    fw_name comes from get_fw_name(fw_path). Extra arguments are ignored.

    On a missing argument or a non-integer battery ID, prints a message and
    the usage text, then exits with status 1.
    """

    def fail_with(msg):
        print(msg)
        print_usage()
        exit(1)

    if len(argv) < 1:
        fail_with('missing argument for tty device')

    if len(argv) < 2:
        fail_with('missing argument for battery ID')

    if len(argv) < 3:
        fail_with('missing argument for firmware')

    # Report a bad ID like the other argument errors instead of letting the
    # ValueError escape as a traceback.
    try:
        battery_id = int(argv[1])
    except ValueError:
        fail_with('invalid battery ID: ' + argv[1])

    return argv[0], battery_id, argv[2], get_fw_name(argv[2])


def verify_firmware(modbus, battery_id, fw_name):
    # type: (Modbus, int, str) -> None
    """Re-read the battery's firmware version and report whether it is ``fw_name``.

    Prints 'SUCCESS' on a match. Otherwise prints a failure message, plus
    the expected and actual versions when a version could be read at all.
    """

    reported = identify_battery(modbus, battery_id)

    if reported != fw_name:
        print('FAILED to verify uploaded firmware!')
        if reported is not None:
            print('expected firmware version ' + fw_name + ' but got ' + reported)
        return

    print('SUCCESS')


def wait_for_bms_reboot():
    # type: () -> None
    """Block for 20 seconds while the BMS reboots, printing a countdown."""

    print('waiting for BMS to reboot...')

    remaining = 20
    while remaining > 0:
        print('\r{0} '.format(remaining), end=' ')
        sleep(1)
        remaining -= 1

    print('0')


def main(argv):
    # type: (List[str]) -> None
    """Flash firmware onto a battery: identify, erase, upload, reset, verify.

    ``argv`` is the command line without the script name. The serial port
    is held via LockTTY for the whole run. Its start-tty.sh step runs on exit
    from the ``with`` block, including early returns and exceptions.
    """

    tty, battery_id, fw_path, fw_name = parse_cmdline_args(argv)

    with LockTTY(tty), init_modbus(tty) as modbus:

        if identify_battery(modbus, battery_id) is None:
            return

        # Stop if the erase failed rather than writing pages on top of
        # flash that was not erased.
        if not clear_flash(modbus, battery_id):
            return

        upload_fw(modbus, battery_id, fw_path, fw_name)

        if not reset_bms(modbus, battery_id):
            return

        wait_for_bms_reboot()

        verify_firmware(modbus, battery_id, fw_name)


# Only run when executed as a script, so importing this module does not
# start a flashing session.
if __name__ == '__main__':
    main(argv[1:])