/*----------------------------------------------------------------------------*/
/* Firmware Obfuscation V1.0.4                                                */
/* Written in 2019 by Rudy Tellert Elektronik                                 */
/*                                                                            */
/* To the extent possible under law, the author has dedicated all copyright   */
/* and related and neighboring rights to the software "firmware obfuscation"  */
/* and its documentation to the public domain. This software and its          */
/* documentation are distributed without any warranty.                        */
/*                                                                            */
/* See also                                                                   */
/* http://creativecommons.org/publicdomain/zero/1.0/                          */
/* ---------------------------------------------------------------------------*/

#include "bootldr.h"

#ifdef _MSC_VER
#pragma warning( disable : 4312 )
#endif

/*
 * Capacity of the decoded payload buffer.
 *  - Tiny bootloader: large enough for a map_t, but never smaller than a
 *    hash plus a data_t.
 *  - Otherwise: three decoded bytes for every four of BASE64_SIZE, rounded up
 *    (BASE64_SIZE is defined elsewhere; presumably the maximum number of
 *    encoded characters per block -- confirm in bootldr.h).
 */
#ifdef USE_TINY_BOOTLOADER
#define MIN_DATA_SIZE (HASH_LENGTH+sizeof(data_t))
#define DATA_SIZE MAX(sizeof(map_t), MIN_DATA_SIZE)
#else
#define DATA_SIZE (((BASE64_SIZE)*3+3)/4)
#endif

/* Transfer state shared by main() and SerWriteCommError(). */
static Byte data[DATA_SIZE];    /* decoded payload bytes                      */
static unsigned int dataSize;   /* number of valid bytes currently in data[]  */
static unsigned int blockIndex; /* position inside the current 4-char group   */
static unsigned int padding;    /* padding characters seen in this payload    */

/*
 * Report a communication error for command `ch` and discard any payload that
 * has been collected so far.
 *
 * ch: command character passed on to SerWriteError().
 * no: error number supplied by the caller; it is not used by this
 *     implementation.
 */
static void SerWriteCommError(char ch, int no)
{
    (void)no; /* accepted for the callers' sake, intentionally ignored here */

    /* Drop the partially received payload before answering. */
    padding = 0;
    blockIndex = 0;
    dataSize = 0;

    SerWriteError(ch);
}

/*
 * MapLock()/MapUnlock() bracket the processing of one block in main().
 * With REUSE_RX_BUFFER, main() builds the map inside rxBuffer, so the pair
 * sets/clears rxBufferLocked (consumed elsewhere -- presumably by the receive
 * path) and flushes the receive queue on lock. Without REUSE_RX_BUFFER both
 * calls expand to nothing.
 */
#ifdef REUSE_RX_BUFFER
/* Mark the receive buffer as locked and discard anything still queued. */
static void MapLock(void)
{
    rxBufferLocked = TRUE;
    QueueFlush(&rxQueue);
}

/* Clear the lock flag again. */
static void MapUnlock(void)
{
    rxBufferLocked = FALSE;
}
#else
#define MapLock()
#define MapUnlock()
#endif

/*
 * Bootloader entry point.
 *
 * Flow, as implemented below:
 *  1. Initialise flash, the key schedule, the timer and the serial port
 *     (9600 baud).
 *  2. If the serial line becomes active, look for the 20-byte activation key.
 *     After a complete key, repeatedly send '>' and wait for an 'a'; when it
 *     arrives, acknowledge it and enter the command session (StartSession).
 *  3. Otherwise (CheckAppInfo) validate the installed application through its
 *     up to three checksummed areas and, when HWINFO_FIXED_ADDRESS is
 *     defined, call its entry point. Without a valid application the function
 *     spins forever (HWINFO_FIXED_ADDRESS) or returns.
 *
 * Command session:
 *  - Characters ' '..'\x60' are payload: each one carries a 6-bit value
 *    (character minus ' '); the value 0x40 marks padding. Four characters are
 *    decoded into three bytes appended to data[].
 *  - Bytes at or above 0xeb (the value range of the activation key bytes) are
 *    answered with a '>' prompt.
 *  - Every other byte is a command handled by the switch below. prevCh holds
 *    the previously accepted command so that two-step commands
 *    (prepare + execute) can be validated; failed commands store 0 instead.
 *
 * BEGIN_TIMER()/END_TIMER(x) come from bootldr.h; they are used here as the
 * head and tail of a timed loop (note the `break` statements inside) --
 * presumably x is a duration in milliseconds, confirm in the header.
 */
void main(void)
{
    const application_info_t *appInfo = NULL;
    Byte ch;
    Byte prevCh;
    Bool denyFurtherBlocks;
    Bool error;
    static Byte block[4];
    Byte *dataBuffer;
    unsigned blockCounter;
    static Byte nonce[NONCE_LENGTH];
#ifndef REUSE_DATA
#ifndef REUSE_RX_BUFFER
#define mapDataSize ((DATA_SIZE) + 0x60) /* additional overhead due to mem_block_t */
    static Byte mapData[mapDataSize];
#else
#define mapDataSize sizeof(rxBuffer)
    Byte *mapData = rxBuffer;
#endif
#define map ((map_t*)mapData)
#else
#define map ((map_t*)data)
#endif
    static char strBlockCounter[9];
#ifdef USE_ERROR_TEST
    Byte testCounter = 0;
#endif

    FlashInit();
    InitPermutation128KeySchedule();
    TimerInit();
    SerInit(9600);

    /* Give the host a short window to show activity on the serial line. */
    BEGIN_TIMER();
    if (SerIsActive()) break;
    END_TIMER(200);
    if (SerIsActive()) {
        /* Activation key; the cast makes the char-literal -> Byte conversion explicit. */
        const Byte *key = (const Byte *)"\xed\xf1\xfe\xf2\xf5\xeb\xfc\xf3\xf6\xf8"
            "\xfd\xf9\xef\xfb\xfa\xec\xf0\xf4\xf7\xee";
        const Byte *keyEnd = key + 20;
        const Byte *p = key;
        Byte b;

        /* Match the key byte by byte; any mismatch restarts the match. */
        BEGIN_TIMER();
        if (SerGet(&b)) {
            if (b == *p) {
                if (++p == keyEnd) break;
            }
            else {
                p = key;
            }
        }
        END_TIMER(2000);

        if (p == keyEnd) {
            /* Key complete: prompt with '>' and wait for the 'a' answer.
               Because of the pre-decrement the body runs 39 times. */
            Byte attempts;
            for (attempts = 40; --attempts; ) {
                SerPutWithoutChecksum('>');
                BEGIN_TIMER();
                while (SerGet(&b)) {
                    if (b == 'a') {
                        SerWriteOk('a');
                        goto StartSession;
                    }
                }
                END_TIMER(100);
            }
        }
    }

CheckAppInfo:
    /* An area with size 0 is skipped; every other area must match its checksum. */
    appInfo = GetAppInfo();
    if (appInfo == NULL) goto InvalidApp;
    if (appInfo->Area1Size && GetChecksum32((const void *)appInfo->Area1Begin, appInfo->Area1Size) != appInfo->Area1Checksum) goto InvalidApp;
    if (appInfo->Area2Size && GetChecksum32((const void *)appInfo->Area2Begin, appInfo->Area2Size) != appInfo->Area2Checksum) goto InvalidApp;
    if (appInfo->Area3Size && GetChecksum32((const void *)appInfo->Area3Begin, appInfo->Area3Size) != appInfo->Area3Checksum) goto InvalidApp;

/* StartApp: */
    /* Close resources */
    SerClose();
    TimerClose();
    
    /* Clear RAM [insert code here] */
    
    /* Disable interrupt [insert code here] */

#ifdef HWINFO_FIXED_ADDRESS
    ((entry_point_t)(appInfo->EntryPoint))();
#endif

    /* Without HWINFO_FIXED_ADDRESS execution falls through and returns. */
InvalidApp:
#ifdef HWINFO_FIXED_ADDRESS
    for (;;) ;
#else
    return;
#endif

StartSession:

    /* New session */
    dataBuffer = data;
    dataSize = 0;
    denyFurtherBlocks = FALSE;
    blockCounter = 0;
    blockIndex = 0;
    padding = 0;
    appInfo = GetAppInfo();

    prevCh = 0; /* initialize previously received character */

    for (;;) {
        if (SerGet(&ch)) {
            switch (ch) {
            case CREATE_NEW_SESSION:
                /* Same reset as at StartSession, followed by an acknowledge. */
                dataBuffer = data;
                dataSize = 0;
                denyFurtherBlocks = FALSE;
                blockCounter = 0;
                blockIndex = 0;
                padding = 0;
                appInfo = GetAppInfo();
                SerWriteOk(CREATE_NEW_SESSION);
                break;
            case GET_BOOTLOADER_INFO:
                /* "TBL 3" + variant letter (T/S) + highest supported baud rate letter (A..D) */
                SerWriteOkLine(GET_BOOTLOADER_INFO, "TBL 3"
#ifdef USE_TINY_BOOTLOADER
                    "T"
#else
                    "S"
#endif
#if defined(SUPPORT_BAUDRATE_230400)
                    "D"
#elif defined(SUPPORT_BAUDRATE_115200)
                    "C"
#elif defined(SUPPORT_BAUDRATE_57600)
                    "B"
#else
                    "A"
#endif
                );
                break;
            case GET_HARDWARE_INFO:
                SerWriteOkLine(GET_HARDWARE_INFO, hwInfo.Info);
                break;
            case GET_APP_NAME:
                /* Answer "<name> <version>" or a bare OK when no application info exists. */
                if (appInfo == NULL) appInfo = GetAppInfo();
                if (appInfo) {
                    SerPut('r');
                    SerPut(GET_APP_NAME);
                    SerWrite(appInfo->Name);
                    SerPut(' ');
                    SerWrite(appInfo->VersionStr);
                    SerPut('\n');
                }
                else {
                    SerWriteOk(GET_APP_NAME);
                }
                break;
            case GET_APP_INFO:
                if (appInfo == NULL) appInfo = GetAppInfo();
                if (appInfo) {
                    SerWriteOkLine(GET_APP_INFO, appInfo->Info);
                }
                else {
                    SerWriteOk(GET_APP_INFO);
                }
                break;
            case GET_CONFIG_INFO: {
                const config_info_t *configInfo;
                if (appInfo == NULL) appInfo = GetAppInfo();
                configInfo = GetConfigInfo(appInfo);
                if (configInfo) {
                    /* Report the description of the config record that was
                       just looked up (and null-checked), not a field of
                       appInfo, which may still be NULL here. */
                    SerWriteOkLine(GET_CONFIG_INFO, configInfo->Desc);
                }
                else {
                    SerWriteOk(GET_CONFIG_INFO);
                }
                break; }
            case GET_DATA_INFO: {
                const char *dataInfo;
                if (appInfo == NULL) appInfo = GetAppInfo();
                dataInfo = GetDataInfo(appInfo);
                if (dataInfo) {
                    SerWriteOkLine(GET_DATA_INFO, dataInfo);
                }
                else {
                    SerWriteOk(GET_DATA_INFO);
                }
                break;  }
            case PREPARE_BLOCK_MODE_FOR_COMPARISON:
            case PREPARE_BLOCK_MODE_FOR_WRITE:
            BlockModeStart:
                /* Also reached via goto after a successfully processed block:
                   clear the payload buffer and acknowledge `ch`. */
                dataSize = 0;
                blockIndex = 0;
                padding = 0;
                SerWriteOk(ch);
                break;
            case COMPARE_OBFUSCATED_BLOCK:
            case COMPARE_NON_OBFUSCATED_BLOCK:
                /* Allowed only directly after the matching prepare command or
                   after a previous block of the same kind. */
                if ((prevCh != PREPARE_BLOCK_MODE_FOR_COMPARISON && prevCh != ch) || denyFurtherBlocks) {
                    SerWriteCommError(ch, ERR_PREPARE_COMPARISON_MISSING);
                    ch = 0;
                    break;
                }
                goto BlockMode;
            case WRITE_OBFUSCATED_BLOCK:
            case WRITE_NON_OBFUSCATED_BLOCK:
                if ((prevCh != PREPARE_BLOCK_MODE_FOR_WRITE && prevCh != ch) || denyFurtherBlocks) {
                    SerWriteCommError(ch, ERR_PREPARE_WRITE_MISSING);
                    ch = 0;
                    break;
                }
            BlockMode:
                MapLock();
                /* Remove the bytes that were decoded from padding characters. */
                dataSize = (dataSize >= padding) ? dataSize - padding : 0;
                if (ch == COMPARE_NON_OBFUSCATED_BLOCK || ch == WRITE_NON_OBFUSCATED_BLOCK) {
                    /* Layout: [checksum (CHECKSUM_LENGTH)] [payload] */
                    dataBuffer = data + CHECKSUM_LENGTH;
                    if (dataSize < CHECKSUM_LENGTH) {
                        SerWriteCommError(ch, ERR_BUFFER_TOO_SMALL);
                        MapUnlock();
                        ch = 0;
                        break;
                    }
                    dataSize -= CHECKSUM_LENGTH;
#ifdef SUPPORT_BIG_ENDIAN
                    if (!IsLittleEndian()) ReverseArray(data, sizeof(UInt));
#endif
                    if (GetChecksum32(dataBuffer, dataSize) != *(UInt*)data) {
                        SerWriteCommError(ch, ERR_CHECKSUM);
                        MapUnlock();
                        ch = 0;
                        break;
                    }
#ifndef REUSE_DATA
                    if (!GetMap(dataBuffer, dataSize, mapData, mapDataSize, MAP_UNOBFUSCATED, appInfo, blockCounter)) {
#else
                    if (!GetMap(dataBuffer, dataSize, data, DATA_SIZE, MAP_UNOBFUSCATED, appInfo, blockCounter)) {
#endif
                        /* A block that cannot be mapped ends the transfer. */
                        denyFurtherBlocks = TRUE;
                        SerWriteCommError(ch, ERR_MAP);
                        MapUnlock();
                        ch = 0;
                        break;
                    }
                }
                else {
                    /* Layout, first block:  [nonce (NONCE_LENGTH)] [tag (TAG_LENGTH)] [payload]
                       Layout, later blocks: [tag (TAG_LENGTH)] [payload]
                       The nonce of the first block is kept for all later blocks. */
                    const Byte *tag;
                    static MacArgs mArgs;
                    /* NOTE(review): GetMac() is asked for TAG_LENGTH bytes --
                       confirm that TAG_LENGTH <= sizeof(mac). */
                    static char mac[16];
                    dataBuffer = data;
                    tag = (const Byte*)dataBuffer;
                    if (blockCounter == 0) {
                        const Byte *rxNonce = (const Byte *)dataBuffer;
                        tag = rxNonce + NONCE_LENGTH;
                        if (dataSize < TAG_LENGTH + NONCE_LENGTH) {
                            SerWriteCommError(ch, ERR_BUFFER_TOO_SMALL);
                            MapUnlock();
                            ch = 0;
                            break;
                        }
                        memcpy(nonce, rxNonce, NONCE_LENGTH);
                        dataBuffer += TAG_LENGTH + NONCE_LENGTH;
                        dataSize -= TAG_LENGTH + NONCE_LENGTH;
                    }
                    else {
                        if (dataSize < TAG_LENGTH) {
                            SerWriteCommError(ch, ERR_BUFFER_TOO_SMALL);
                            MapUnlock();
                            ch = 0;
                            break;
                        }
                        dataBuffer += TAG_LENGTH;
                        dataSize -= TAG_LENGTH;
                    }
                    /* Verify the tag over the still-obfuscated payload before
                       touching it. */
                    zeromem(&mArgs, sizeof(MacArgs));
                    mArgs.Key = (Byte*)hwInfo.SigningKey;
                    mArgs.KeySize = 16;
                    mArgs.Msg = dataBuffer;
                    mArgs.MsgSize = dataSize;
                    mArgs.Nonce = nonce;
                    mArgs.NonceSize = NONCE_LENGTH;
                    mArgs.Type = blockCounter;
                    if (GetMac(&mArgs, mac, TAG_LENGTH) != TAG_LENGTH || !CompareChecksum(tag, mac, TAG_LENGTH)) {
                        SerWriteCommError(ch, ERR_CHECKSUM);
                        MapUnlock();
                        ch = 0;
                        break;
                    }
                    /* Transform the payload in place (presumably CTR-mode
                       de-obfuscation -- see GetCtr()). */
                    GetCtr(hwInfo.ObfuscationKey, 7, nonce, NONCE_LENGTH, dataBuffer, dataSize, blockCounter);
#ifndef REUSE_DATA
                    if (!GetMap(dataBuffer, dataSize, mapData, mapDataSize, MAP_OBFUSCATED, appInfo, blockCounter)) {
#else
                    if (!GetMap(dataBuffer, dataSize, data, DATA_SIZE, MAP_OBFUSCATED, appInfo, blockCounter)) {
#endif
                        denyFurtherBlocks = TRUE;
                        SerWriteCommError(ch, ERR_OBFUSCATED_MAP);
                        MapUnlock();
                        ch = 0;
                        break;
                    }
                }
                error = FALSE;
                if (ch == COMPARE_OBFUSCATED_BLOCK || ch == COMPARE_NON_OBFUSCATED_BLOCK) {
                    /* compare data */
                    unsigned i;
                    for (i = 0; i < map->MemoryCount; i++) {
                        memory_item_t *mi = map->MemoryItems + i;
                        if (!FlashCompare(mi->DataItems, mi->Address, (UInt)mi->DataCount)) {
                            error = TRUE;
                            break;
                        }
                    }
                }
                else {
                    /* erase/program data */
                    unsigned i, k; /* declared first: keeps the block valid C89 */
                    FlashEnter();
                    if (!FlashEraseRegions(map->EraseRegions)) {
                        error = TRUE;
                    }
                    for (i = 0; i < map->EraseCount; i++) {
                        erase_item_t *ei = map->EraseItems + i;
                        if (ei->Length == 0) continue;
                        if (!FlashErase(ei->Address, ei->Length)) {
                            error = TRUE;
                        }
                    }
                    for (i = 0; i < map->MemoryCount; i++) {
                        memory_item_t *mi = map->MemoryItems + i;
                        if (mi->DataCount == 0) continue;
                        /* NOTE(review): always writes full 256-byte chunks,
                           also when DataCount is not a multiple of 256 --
                           confirm that DataItems is padded accordingly. */
                        for (k = 0; k < mi->DataCount; k += 256) {
                            if (!FlashWrite(mi->DataItems + k, mi->Address + k, 256)) {
                                error = TRUE;
                                break; /* leaves only the chunk loop; later items are still written */
                            }
                        }
                    }
                    FlashExit();
                    /* Read back and verify everything that was programmed. */
                    for (i = 0; i < map->MemoryCount; i++) {
                        memory_item_t *mi = map->MemoryItems + i;
                        if (!FlashCompare(mi->DataItems, mi->Address, (UInt)mi->DataCount)) {
                            error = TRUE;
                            break;
                        }
                    }
                }
#ifdef USE_ERROR_TEST
                /* Test aid: force an error for two out of every three blocks. */
                if (testCounter < 2) error = TRUE;
                if (++testCounter > 2) testCounter = 0;
#endif
                if (error) {
                    SerWriteCommError(ch, ERR_FLASH);
                    MapUnlock();
                    ch = 0;
                    break;
                }
                if (map->LastBlock) denyFurtherBlocks = TRUE;
                blockCounter++;
                MapUnlock();
                goto BlockModeStart;
            case GET_BLOCK_COUNTER: 
                SerPut('r');
                SerPut(GET_BLOCK_COUNTER);
                if (GetHexString(blockCounter, strBlockCounter, sizeof(strBlockCounter))) {
                    SerWrite(strBlockCounter);
                }
                SerPut('\n');
                ch = prevCh; /* the query must not disturb the prepare/execute sequence */
                break; 
            case CHANGE_BAUDRATE: {
                /* The new rate is selected by the preceding PREPARE_BAUDRATE_* command. */
                unsigned baudRate = 0;
                switch (prevCh) {
                case PREPARE_BAUDRATE_9600:
                    baudRate = 9600; 
                    break;
#ifdef SUPPORT_BAUDRATE_57600
                case PREPARE_BAUDRATE_57600: 
                    baudRate = 57600; 
                    break;
#endif
#ifdef SUPPORT_BAUDRATE_115200
                case PREPARE_BAUDRATE_115200: 
                    baudRate = 115200; 
                    break;
#endif
#ifdef SUPPORT_BAUDRATE_230400
                case PREPARE_BAUDRATE_230400:
                    baudRate = 230400; 
                    break;
#endif
                }
                if (baudRate == 0) {
                    SerWriteCommError(ch, ERR_BAUDRATE);
                    break;
                }
                /* Acknowledge at the old rate, let the answer leave the wire,
                   then switch. */
                SerPut('r');
                SerPut(CHANGE_BAUDRATE);
                SerPut(prevCh);
                SerPut('\n');
                SerWaitForTx();
                /* 3 Chars require 3.125 ms @ 9600 baud */
                BEGIN_TIMER();
                END_TIMER(100); /* wait a bit longer than 3.125 ms (for WIN32 port driver) */
                SerSetBaudrate(baudRate);
                break; }
            case PREPARE_BAUDRATE_9600:
                SerWriteOk(PREPARE_BAUDRATE_9600);
                break;
#ifdef SUPPORT_BAUDRATE_57600
            case PREPARE_BAUDRATE_57600:
                SerWriteOk(PREPARE_BAUDRATE_57600);
                break;
#endif
#ifdef SUPPORT_BAUDRATE_115200
            case PREPARE_BAUDRATE_115200: 
                SerWriteOk(PREPARE_BAUDRATE_115200);
                break;
#endif
#ifdef SUPPORT_BAUDRATE_230400
            case PREPARE_BAUDRATE_230400: 
                SerWriteOk(PREPARE_BAUDRATE_230400);
                break;
#endif
            case UPDATE_APP_INFO: /* update appInfo */
                appInfo = GetAppInfo();
                SerWriteOk(UPDATE_APP_INFO);
                break;
            case RUN_APP:
                /* Requires PREPARE_TO_RUN_APP as the directly preceding command. */
                if (prevCh == PREPARE_TO_RUN_APP) {
                    SerWriteOk(RUN_APP);
                    goto CheckAppInfo;
                }
                else {
                    SerWriteCommError(RUN_APP, ERR_PREPARE_RUN_MISSING);
                }
                break;
            case PREPARE_TO_RUN_APP:
                SerWriteOk(PREPARE_TO_RUN_APP);
                break;
            default:
                if (ch >= ' ' && ch <= '\x60') { /* base64 or padding character */
                    ch -= ' ';
                    if (ch == '\x40') {
                        /* padding contributes zero bits and is counted so the
                           surplus bytes can be removed in BlockMode */
                        ch = '\0';
                        padding++;
                    }
                    block[blockIndex] = (Byte)ch;
                    if (++blockIndex == 4) blockIndex = 0;
                    /* A complete group of four is decoded only while three
                       more bytes still fit into data[]. */
                    if (blockIndex == 0 && dataSize < DATA_SIZE - 2) {
                        data[dataSize++] = (block[0] << 2) | (block[1] >> 4);
                        data[dataSize++] = (block[1] << 4) | (block[2] >> 2);
                        data[dataSize++] = (block[2] << 6) | block[3];
                    }
                    continue; /* payload characters do not change prevCh */
                }
                /* Compare as Byte: where plain char is signed, '\xeb' is a
                   negative int and `ch >= '\xeb'` would be true for every
                   byte, making the error branch below unreachable. */
                else if (ch >= (Byte)'\xeb') { /* echo prompt */
                    SerPutWithoutChecksum('>');
                    continue;
                }
                else {
                    /* Unknown byte: drain pending input, then report the error. */
                    Byte b;
                    while (SerGet(&b));
                    SerWriteCommError(ch, ERR_INVALID_CHAR);
                }
                break;
            }
            prevCh = ch;
        }
    }
}