#!/usr/bin/env python3

# HPI-independent functions for reading/writing/decoding DAB data packets
# from ASI8900

"""
Analyse DAB data file

Usage: dab_data.py [-v N] [-d N] [-h] <filename>

Options:
  -h, --help                        Show this help message and exit
  -v N, --verbosity=N               Verbosity of analysis output [default: 0]
  -d N, --decode=N                  Decoding level [default: 0]
"""



from collections import namedtuple, defaultdict
import logging
import struct
import time
# The MOT/MSC decoders are optional third-party packages.  When any of them
# is missing the module still works, but no payload decoders are registered.
try:
    from bitarray import bitarray
    import mot
    import msc.datagroups as msc_dg
    import msc.packets as msc_pkt
    # Keep the decoder libraries quiet: only their errors are of interest.
    mot.logger.setLevel(logging.ERROR)
    msc_dg.logger.setLevel(logging.ERROR)
    msc_pkt.logger.setLevel(logging.ERROR)
    dab_decoders_available = True
except Exception:
    # Do not use a bare except here: it would also swallow KeyboardInterrupt
    # and SystemExit raised while importing.
    # Log through a named logger rather than logging.info(): the module-level
    # helper implicitly calls basicConfig() on first use, which would turn the
    # script's own logging.basicConfig(level=...) call into a no-op.
    logging.getLogger(__name__).info('DAB decoders not available')
    dab_decoders_available = False

# Tag identifying a raw service list file.  It has to be a bytes object:
# struct's 's' format only packs bytes, and unpacking returns bytes, so a
# text string could neither be written nor compare equal when read back.
sl_magic = b'DABS'

# File header layout (native byte order and alignment):
# 4s : file tag = 'DABS'
# H : file format version
# d : timestamp
raw_file_header = struct.Struct('4sHd')


def write_file_header(f):
    """Write the service list file header to the binary stream *f*.

    The header consists of the file tag, format version 1 and the current
    time as returned by time.time().
    """
    version = 1
    now = time.time()
    f.write(raw_file_header.pack(sl_magic, version, now))

# Every record in the file starts with a type/version/length header:
# c : type
# B : version
# H : value length
tlv_header = struct.Struct('cBH')

# tlv 'S', 1, len
# I : 32 bit service id
# <service info struct>
raw_service_header = struct.Struct('I')


def write_service(f, s, sid):
    """Write one service record to the binary stream *f*.

    The record is a TLV header of type 'S', version 1, followed by the
    service id *sid* and the raw service info bytes *s*.  The TLV length
    covers the service id plus *s*.
    """
    # struct's 'c' format only accepts a bytes object of length 1.
    tipe = b'S'
    ver = 1
    length = raw_service_header.size + len(s)
    f.write(tlv_header.pack(tipe, ver, length))
    f.write(raw_service_header.pack(sid))
    f.write(s)

# tlv 'C', 1, len
# I : service id
# H : component id
# <component info struct>
raw_component_header = struct.Struct('IH')


def write_component(f, c, sid, cid):
    """Write one component record to the binary stream *f*.

    The record is a TLV header of type 'C', version 1, followed by the
    service id *sid*, the component id *cid* and the raw component info
    bytes *c*.  The TLV length covers both ids plus *c*.
    """
    # struct's 'c' format only accepts a bytes object of length 1.
    tipe = b'C'
    ver = 1
    length = raw_component_header.size + len(c)
    f.write(tlv_header.pack(tipe, ver, length))
    f.write(raw_component_header.pack(sid, cid))
    f.write(c)


def write_raw_service_list(f, sl):
    """Serialise the service list *sl* to the binary stream *f*.

    *sl* is an iterable of (sid, raw_service, components) entries, where
    components is an iterable of (cid, raw_component) pairs.  The file
    header is written first; each service record is followed directly by
    its component records.
    """
    write_file_header(f)
    for service_id, raw_service, components in sl:
        write_service(f, raw_service, service_id)
        for component_id, raw_component in components:
            write_component(f, raw_component, service_id, component_id)


def read_raw_service_list(f):
    """Read a raw service list from the binary stream *f*.

    Returns a tuple of (sid, raw_service, [(cid, raw_component), ...])
    entries in file order.  Records of unknown type are skipped with a
    warning, and reading stops at the first incomplete record header.

    Raises ValueError if the file tag does not match sl_magic.
    """
    fh = raw_file_header.unpack(f.read(raw_file_header.size))
    # unpack() always returns the tag as bytes, so compare as bytes.
    magic = sl_magic if isinstance(sl_magic, bytes) else sl_magic.encode('ascii')
    if fh[0] != magic:
        raise ValueError('Not a dab service file')

    sl = []
    current = None  # most recently read service entry, owner of components
    while True:
        hdr = f.read(tlv_header.size)
        if len(hdr) < tlv_header.size:
            break
        tipe, ver, length = tlv_header.unpack(hdr)
        # struct's 'c' format unpacks to a bytes object of length 1.
        if tipe == b'S':
            logging.info('Service version {} length {}'.format(ver, length))
            sid = raw_service_header.unpack(f.read(raw_service_header.size))[0]
            value = f.read(length - raw_service_header.size)
            current = (sid, value, [])
            sl.append(current)
        elif tipe == b'C':
            logging.info('Component version {} length {}'.format(ver, length))
            sid, cid = raw_component_header.unpack(f.read(raw_component_header.size))
            value = f.read(length - raw_component_header.size)
            if current is None:
                # A component with no service before it has no owner.
                logging.warning('Skipping component {:#X} of service {:#X}: no preceding service record'.format(cid, sid))
                continue
            if sid != current[0]:
                logging.warning('Component sid {:#X} != Service id {:#X}'.format(sid, current[0]))
            current[2].append((cid, value))
        else:
            logging.warning('Skipping unknown data type {} version {} length {}'.format(tipe, ver, length))
            # Consume the value so the next header is read from the right offset.
            f.read(length)

    return tuple(sl)


def test_write_raw_service_list():
    """Round-trip a small service list through a file on disk.

    Two records of unknown type are appended after the service list to
    check that the reader skips them.  The file 'test.services' is created
    in the current directory.
    """
    sl = (
        (1, b'service 1', [(11, b's1 c1'), (12, b's1 c2')]),
        (2, b'service 2', [(21, b's2 c1')]),
    )
    with open('test.services', 'wb') as f:
        write_raw_service_list(f, sl)
        u = b'unknown data'
        # The record type must be bytes for struct's 'c' format.
        f.write(tlv_header.pack(b'X', 2, len(u)))
        f.write(u)
        f.write(tlv_header.pack(b'Z', 3, 0))

    with open('test.services', 'rb') as f:
        slr = read_raw_service_list(f)

    assert sl == slr

# Display names for the data type and PAD data subtype values that
# decode_packets extracts from the packet header's data_type field.
DATA_TYPE_STR = {
    0: 'unspec data',
    1: 'PAD data',
    2: 'DLS/DL+ over PAD',
    3: 'Audio',
}
PADDATA_SUBTYPE_STR = {
    0: 'unspec data',
    1: 'TMC',
    5: 'TDC/TPEG',
    60: 'MOT',
}

# Binary layout of the header at the start of every data packet; the
# namedtuple fields are listed in the same order as the struct members.
packet_header = struct.Struct('BBBBIIHHHH')
PacketHeader = namedtuple(
    'PacketHeader',
    [
        'irq',
        'buff_count',
        'srv_state',
        'data_type',
        'serv_id',
        'comp_id',
        'uatype',
        'byte_count',
        'seg_num',
        'num_segs',
    ],
)


def data_subtype_string(data_type, data_subtype):
    """Return a display name for *data_subtype*.

    Subtypes are only looked up for data type 1; every other data type
    yields 'unspec data'.
    """
    if data_type == 1:
        return PADDATA_SUBTYPE_STR.get(data_subtype, 'unknown data subtype')
    return 'unspec data'

# The two prefix bytes at the start of a DLS payload.
dls_prefix = struct.Struct('BB')


def decode_dls(pd):
    """Return a one-line description of the DLS payload *pd*.

    The description lists the toggle bit, command flag, command and charset
    fields taken from the two prefix bytes.  When the command flag is clear
    the message text (without its final byte) is appended.
    """
    first, second = dls_prefix.unpack_from(pd)
    toggle = bool(first & 0x80)
    is_command = bool(first & 0x10)
    command = first & 0x0F
    charset = (second >> 4) & 0x0F
    text = 'T:{0}, C:{1}, command:0b{2:04b} charset:0b{3:04b}'.format(toggle, is_command, command, charset)
    if is_command:
        return text
    # strings are null terminated, drop the last character
    return text + ' msg:' + printable(pd[2:-1])

def data_header(s):
    """Unpack the packet header at the start of *s* into a PacketHeader."""
    return PacketHeader._make(packet_header.unpack_from(s))


def decode_packet(p):
    """Split the raw packet *p* into its header and payload.

    Returns (hdr, payload), where hdr is a PacketHeader and payload is
    everything after the fixed-size header.  A mismatch between the
    header's byte_count and the actual payload length is logged as an
    error, but the packet is still returned.
    """
    hdr = data_header(p)
    payload = p[packet_header.size:]
    if len(payload) != hdr.byte_count:
        # Argument order matches the message: header value first, then
        # the number of bytes actually present.
        logging.error('Header byte count {} != returned {}'.format(hdr.byte_count, len(payload)))
    return hdr, payload


# A packet file is a sequence of records, each laid out as
# 8 byte double timestamp 'd'
# 2 byte data length 'H'
# data bytes [length]

file_pkt_hdr = struct.Struct('dH')


def write_packet(f, ts, pd):
    """Append one packet record to the binary stream *f*.

    The record is the timestamp *ts* and the length of *pd*, packed with
    file_pkt_hdr, followed by the packet data *pd* itself.
    """
    f.write(file_pkt_hdr.pack(ts, len(pd)))
    f.write(pd)

def write_packet_file(fn, packets):
    """Write *packets* to the file named *fn*.

    *packets* is an iterable of (timestamp, data) pairs, the same shape
    that read_packet_file returns.  An existing file is overwritten.
    """
    logging.info('Writing packets to {}'.format(fn))
    with open(fn, 'wb') as f:
        # write_packet takes the timestamp and the data as separate arguments.
        for ts, pd in packets:
            write_packet(f, ts, pd)


def read_packet_file(fn):
    """Read every packet record from the file named *fn*.

    Returns a list of (timestamp, data) pairs in file order.  Reading
    stops at the first incomplete record header.
    """
    logging.info('Reading packets from {}'.format(fn))
    header_size = file_pkt_hdr.size
    records = []
    with open(fn, 'rb') as stream:
        raw_hdr = stream.read(header_size)
        while len(raw_hdr) >= header_size:
            timestamp, length = file_pkt_hdr.unpack(raw_hdr)
            records.append((timestamp, stream.read(length)))
            raw_hdr = stream.read(header_size)

    return records



def hex_str(s):
    """Return a hex dump of *s* as a single multi-line string.

    *s* may be a bytes-like object or a text string.  Each output line
    covers up to 16 values and holds the offset, two groups of eight
    two-digit hex values, the printable-ASCII rendering ('.' for anything
    outside 32..126) and a trailing newline.  Empty input gives ''.
    """
    # Iterating bytes yields ints on Python 3 while text yields characters;
    # normalise both to integer codes up front.
    codes = [c if isinstance(c, int) else ord(c) for c in s]

    def to_hex(chunk, width):
        # Pad short groups with blank cells so the ASCII column lines up.
        cells = ['{:02x}'.format(c) for c in chunk]
        cells.extend(['  '] * (width - len(chunk)))
        return ' '.join(cells)

    lines = []
    for ofs in range(0, len(codes), 16):
        row = codes[ofs:ofs + 16]
        lines.append(' '.join((
            '0x{:04x} '.format(ofs),
            to_hex(row[:8], 8),
            ' ',
            to_hex(row[8:], 8),
            ''.join(chr(c) if 32 <= c < 127 else '.' for c in row),
            '\n',
        )))
    return ''.join(lines)


# Character codes that are rendered as-is.
printable_chars = tuple(range(32, 127))


def printable(s):
    """Return *s* as text with non-printable values escaped.

    *s* may be a bytes-like object or a text string.  Values whose code is
    in printable_chars are kept; everything else is rendered as a literal
    backslash escape of the form \\xNN with upper-case hex digits.
    """
    parts = []
    for c in s:
        # Iterating bytes yields ints on Python 3, text yields characters.
        code = c if isinstance(c, int) else ord(c)
        if code in printable_chars:
            parts.append(chr(code))
        else:
            parts.append('\\x{:02X}'.format(code))

    return ''.join(parts)

def packet_reader(file_path):
    """Yield (timestamp, data) pairs from the packet file *file_path*.

    Records are read lazily, one per iteration.  The generator finishes
    at the first incomplete record header.
    """
    logging.info('Reading packets from: {}'.format(file_path))
    header_size = file_pkt_hdr.size
    with open(file_path, 'rb') as stream:
        raw_hdr = stream.read(header_size)
        while len(raw_hdr) >= header_size:
            timestamp, length = file_pkt_hdr.unpack(raw_hdr)
            yield timestamp, stream.read(length)
            raw_hdr = stream.read(header_size)

def decode_packets(pkt_reader, decoder_map, verbose):
    """Decode every packet from *pkt_reader* and dispatch its payload.

    *pkt_reader* yields (timestamp, raw_packet) pairs.  *decoder_map* maps
    (data_type, data_subtype, uatype) keys to callables that are invoked
    with the payload of each matching packet.  *verbose* controls console
    output: above 0 prints a summary line per packet, above 1 also prints
    the escaped payload, above 2 prints a hex dump instead.

    Returns a dict mapping (serv_id, comp_id) to
    (packet count, total payload bytes).
    """
    totals = {}
    for timestamp, raw in pkt_reader:
        header, payload = decode_packet(raw)

        # Running totals per service/component pair.
        key = (header.serv_id, header.comp_id)
        packets_seen, bytes_seen = totals.get(key, (0, 0))
        totals[key] = (packets_seen + 1, bytes_seen + len(payload))

        # The top two bits of the header field hold the data type, the
        # low six bits the subtype.
        data_type = header.data_type >> 6
        data_subtype = header.data_type & 0x3F

        if verbose > 0:
            print('{0:3.3f} {1} data-type:{2} (0b{3:02b}) data-subtype:{4} (0b{5:06b})'.format(
                timestamp, header, DATA_TYPE_STR[data_type], data_type,
                data_subtype_string(data_type, data_subtype), data_subtype))
        if verbose > 2:
            print(hex_str(payload))
        elif verbose > 1:
            print('{}\n'.format(printable(payload)))

        handler = decoder_map.get((data_type, data_subtype, header.uatype))
        if handler:
            handler(payload)
    return totals

def packet_decoder():
    """Placeholder for a packet decoder; not implemented, returns None."""
    return None

def datagroup_decoder(check_crc=True):
    """Generator that turns bit sequences into MSC data groups.

    Prime it with next(), then send() a bit sequence; each send() returns
    the result of msc_dg.Datagroup.frombits for those bits.  *check_crc*
    is passed through to frombits.
    """
    decoded = None
    while True:
        bits = yield decoded
        decoded = msc_dg.Datagroup.frombits(bits, check_crc=check_crc)

def mot_decoder():
    """Generator that collects MOT data groups into complete objects.

    Prime it with next(), then send() a data group; each send() returns
    the list of objects compiled with mot.compile_object for every cached
    transport id that mot.is_complete reports as complete.  Sending a
    falsy value returns an empty list.
    """
    o = []
    cache = mot.Cache()
    while True:
        dg = yield o
        if not dg:
            o = []
            continue
        # TODO: check MOT segment size
        segment = dg.get_data()
        # Indexing bytes yields ints on Python 3 but one-character strings
        # on Python 2; accept both when reading the two size bytes.
        hi, lo = (b if isinstance(b, int) else ord(b) for b in (segment[0], segment[1]))
        # NOTE(review): only the low 4 bits of the first byte are used for
        # the size - confirm the field width against the MOT specification.
        mot_segment_sz = ((hi & 0x0F) << 8) + lo
        assert mot_segment_sz == len(segment)-2, 'mot_segment_sz:%d len(segment)-2:%d' % (mot_segment_sz, len(segment)-2)
        t_id = dg.get_transport_id()
        items = cache.get(t_id, [])
        if dg not in items:
            items.append(dg)
        # Keep the segments of one transport id ordered by type, then index.
        items = sorted(items, key=lambda x: (x.get_type(), x.segment_index))
        cache[t_id] = items
        o = [mot.compile_object(t, cache) for t in list(cache.keys()) if mot.is_complete(t, cache)]

def sls_decoder():
    """Placeholder for a slideshow decoder; not implemented, returns None."""
    return None

def dls_dump_app(verbose):
    """Return a payload handler that prints the decoded DLS data.

    The returned callable takes one payload, prints its decode_dls
    description followed by a blank line, and returns None.  *verbose* is
    accepted but not used.
    """
    def dump(payload):
        print('dsl:', decode_dls(payload))
        print()

    return dump

def epg_decoder():
    """Placeholder for an EPG decoder; not implemented, returns None."""
    return None

def sls_dump_app(dg_flag, verbose):
    """Return a payload handler that prints decoded slideshow objects.

    The returned callable feeds each payload through the data group decoder
    and the MOT decoder, then prints name, type, size and parameters of
    every object the MOT decoder returns.  When *verbose* is truthy the
    object body is hex-dumped as well.

    *dg_flag* says whether payloads carry data groups.  Decoding without
    data groups is not implemented yet, so such payloads are ignored.
    """
    # Both decoders are generators; prime them so they accept send().
    dg_gen = datagroup_decoder()
    next(dg_gen)
    mot_gen = mot_decoder()
    next(mot_gen)

    def f_(payload):
        if not dg_flag:
            # FIXME: use packet decoder + data group decoder
            # Return early: there is no data group to hand to the MOT
            # decoder (this path used to fail on an unassigned variable).
            return
        ba = bitarray()
        ba.frombytes(payload)
        dg = dg_gen.send(ba)
        if not dg:
            return
        o_list = mot_gen.send(dg)
        if not o_list:
            return
        for o in o_list:
            mot_o_body = o.get_body()
            print('Slide Image: {name} {type} ({size} bytes)'.format(name=o.get_name(), type=o.get_type(), size=len(mot_o_body)))
            print('Slide Parameters:')
            for p in o.get_parameters():
                print('\t', repr(p))
            if verbose:
                print('Slide body:')
                print(hex_str(mot_o_body))
            print()
            #with open(o.get_name(), 'wb') as outfile:
            #    outfile.write(mot_o_body)
    return f_

def main():
    """Command line entry point: analyse the packet file named in argv.

    Options are parsed with docopt from the module docstring.  Payload
    decoders are only registered when the optional DAB decoder packages
    were imported successfully.
    """
    from docopt import docopt

    opts = docopt(__doc__)
    verbose = int(opts['--verbosity'])
    if dab_decoders_available:
        # FIXME: get DG flag from X-PAD UA info or component ID in case of packet mode
        decoder_map = {
            (1, 60, 2): sls_dump_app(True, verbose),
            (2, 0, 0): dls_dump_app(verbose),
        }
    else:
        decoder_map = {}
    packets = packet_reader(opts['<filename>'])
    decode_packets(packets, decoder_map, verbose)

def test():
    """Run the module's self-test (the service list write/read round trip)."""
    test_write_raw_service_list()

def rrsl():
    """Read 'service_list.raw' from the current directory.

    Returns the parsed service list (see read_raw_service_list), so the
    helper is useful from an interactive session instead of discarding
    what it read.
    """
    with open('service_list.raw', 'rb') as f:
        return read_raw_service_list(f)

if __name__ == '__main__':
    # Log at DEBUG level so the module's informational messages are shown.
    logging.basicConfig(level=logging.DEBUG)
    # Development entry points; uncomment by hand when needed.
    #rrsl()
    #test()
    main()
