﻿/*----------------------------------------------------------------------------*/
/* Firmware Obfuscation V1.0.4                                                */
/* Written in 2019 by Rudy Tellert Elektronik                                 */
/*                                                                            */
/* To the extent possible under law, the author has dedicated all copyright   */
/* and related and neighboring rights to the software "firmware obfuscation"  */
/* and its documentation to the public domain. This software and its          */
/* documentation are distributed without any warranty.                        */
/*                                                                            */
/* See also                                                                   */
/* http://creativecommons.org/publicdomain/zero/1.0/                          */
/* ---------------------------------------------------------------------------*/

using System;

namespace Tellert.Firmware
{
    public static class Obfuscation
    {
        /// <summary>
        /// Encodes a number as a variable-length byte sequence.
        /// </summary>
        /// <remarks>
        /// The argument is cast to <c>uint</c> and encoded as follows:
        /// <list type="bullet">
        /// <item>below 0xff: a single byte holding the value;</item>
        /// <item>below 0xffff: one 0xff marker byte followed by the value as
        /// two bytes, least significant byte first;</item>
        /// <item>below 0xffffffff: three 0xff marker bytes followed by the value
        /// as four bytes, least significant byte first.</item>
        /// </list>
        /// </remarks>
        /// <param name="i">The number to encode.</param>
        /// <returns>A new array of 1, 3 or 7 bytes.</returns>
        /// <exception cref="ArgumentException">
        /// The value equals 0xffffffff after the cast (i.e. <paramref name="i"/> is -1).
        /// </exception>
        static byte[] NumToByteArray(int i)
        {
            uint value = (uint)i;

            if (value < 0xff)
            {
                return new byte[] { (byte)value };
            }

            if (value < 0xffff)
            {
                return new byte[]
                {
                    0xff,
                    (byte)(value & 0xff),
                    (byte)((value >> 8) & 0xff)
                };
            }

            if (value < 0xffffffff)
            {
                return new byte[]
                {
                    0xff,
                    0xff,
                    0xff,
                    (byte)(value & 0xff),
                    (byte)((value >> 8) & 0xff),
                    (byte)((value >> 16) & 0xff),
                    (byte)((value >> 24) & 0xff)
                };
            }

            throw new ArgumentException();
        }

        /// <summary>
        /// Computes the hash for one numbered segment of a MAC.
        /// </summary>
        /// <remarks>
        /// The following fields are fed to the hash, in this order:
        /// the encoded segment number; then either the encoded key length and
        /// the key, or (when <c>Hash.UseMacWithPermutation</c> is set, in which
        /// case the hash is first initialised with the key) the encoded
        /// <c>Permutation128.BlockLength</c> and the hash's <c>Permutation0</c>;
        /// the encoded nonce length and the nonce; the encoded type; and finally
        /// either the encoded message length and the message (non-empty nonce)
        /// or the encoded device number (empty nonce).
        /// When <c>Hash.IsDoubleHashingRequired</c> is set, the result is hashed
        /// a second time, prefixed with a single zero byte.
        /// </remarks>
        /// <param name="seg">Segment number.</param>
        /// <param name="key">Key; <c>null</c> is treated as empty.</param>
        /// <param name="nonce">Nonce; <c>null</c> is treated as empty.</param>
        /// <param name="msg">Message; <c>null</c> is treated as empty.</param>
        /// <param name="type">Type value mixed into the hash.</param>
        /// <param name="devNo">Device number; only hashed when the nonce is empty.</param>
        /// <returns>The hash value of the segment.</returns>
        /// <exception cref="ArgumentException">
        /// The nonce is empty but the message is not, or the nonce is non-empty
        /// and <paramref name="devNo"/> is non-zero, or one of the numbers
        /// cannot be encoded.
        /// </exception>
        static byte[] MacSegment(int seg, byte[] key, byte[] nonce, byte[] msg, int type, int devNo)
        {
            key = key ?? new byte[0];
            nonce = nonce ?? new byte[0];
            msg = msg ?? new byte[0];

            bool hasNonce = nonce.Length != 0;
            if (!hasNonce && msg.Length != 0)
            {
                throw new ArgumentException();
            }
            if (hasNonce && devNo != 0)
            {
                throw new ArgumentException();
            }

            // All length/number fields are encoded up front, before any hash object exists.
            byte[] zeroPrefix = { 0 };
            byte[] segField = NumToByteArray(seg);
            byte[] keyLenField = NumToByteArray(key.Length);
            byte[] permLenField = NumToByteArray(Permutation128.BlockLength);
            byte[] nonceLenField = NumToByteArray(nonce.Length);
            byte[] typeField = NumToByteArray(type);
            byte[] msgLenField = NumToByteArray(msg.Length);
            byte[] devNoField = NumToByteArray(devNo);

            byte[] digest;
            using (Hash inner = new Hash())
            {
                if (Hash.UseMacWithPermutation)
                {
                    inner.Init(key);
                }

                inner.TransformBlock(segField, 0, segField.Length);

                if (Hash.UseMacWithPermutation)
                {
                    inner.TransformBlock(permLenField, 0, permLenField.Length);
                    inner.TransformBlock(inner.Permutation0, 0, inner.Permutation0.Length);
                }
                else
                {
                    inner.TransformBlock(keyLenField, 0, keyLenField.Length);
                    inner.TransformBlock(key, 0, key.Length);
                }

                inner.TransformBlock(nonceLenField, 0, nonceLenField.Length);
                inner.TransformBlock(nonce, 0, nonce.Length);
                inner.TransformBlock(typeField, 0, typeField.Length);

                // The final block is the message (preceded by its length) when a
                // nonce is present, otherwise the encoded device number.
                byte[] tail = devNoField;
                if (hasNonce)
                {
                    inner.TransformBlock(msgLenField, 0, msgLenField.Length);
                    tail = msg;
                }
                digest = inner.TransformFinalBlock(tail, 0, tail.Length);
            }

            if (!Hash.IsDoubleHashingRequired)
            {
                return digest;
            }

            using (Hash outer = new Hash())
            {
                if (Hash.UseMacWithPermutation)
                {
                    outer.Init(key);
                }
                outer.TransformBlock(zeroPrefix, 0, zeroPrefix.Length);
                return outer.TransformFinalBlock(digest, 0, digest.Length);
            }
        }

        /// <summary>
        /// Computes a MAC of the requested bit length by concatenating the
        /// hashes of consecutively numbered segments (starting at 1) and
        /// truncating the last one as needed.
        /// </summary>
        /// <param name="key">Key; <c>null</c> is treated as empty.</param>
        /// <param name="nonce">Nonce; <c>null</c> is treated as empty.</param>
        /// <param name="msg">Message; <c>null</c> is treated as empty.</param>
        /// <param name="typeOrMsgNo">Type or message number mixed into every segment.</param>
        /// <param name="len">Length of the MAC in bits; must be a non-negative multiple of 8.</param>
        /// <param name="devNo">Device number; must be 0 when a nonce is given.</param>
        /// <returns>A new array of <paramref name="len"/> / 8 bytes.</returns>
        /// <exception cref="ArgumentOutOfRangeException"><paramref name="len"/> is negative.</exception>
        /// <exception cref="ArgumentException">
        /// <paramref name="len"/> is not a multiple of 8, or the remaining
        /// arguments are rejected while a segment is computed.
        /// </exception>
        public static byte[] Mac(byte[] key, byte[] nonce, byte[] msg, int typeOrMsgNo = 0, int len = 128, int devNo = 0)
        {
            // A negative multiple of 8 used to slip through the modulo check and
            // fail later with an OverflowException from the array allocation.
            if (len < 0)
            {
                throw new ArgumentOutOfRangeException("len", "The MAC length must not be negative.");
            }
            if (len % 8 != 0)
            {
                throw new ArgumentException("The MAC length must be a multiple of 8 bits.", "len");
            }

            byte[] result = new byte[len / 8];
            int pos = 0;
            for (int seg = 1; pos < result.Length; seg++)
            {
                byte[] segment = MacSegment(seg, key, nonce, msg, typeOrMsgNo, devNo);
                int n = Math.Min(result.Length - pos, segment.Length);
                Array.Copy(segment, 0, result, pos, n);
                pos += n;
            }

            return result;
        }

        /// <summary>
        /// Computes a keyed hash of a message.
        /// </summary>
        /// <remarks>
        /// When <c>Hash.UseMacWithPermutation</c> is set, the hash is initialised
        /// with the key; otherwise the key bytes are hashed in front of the message.
        /// </remarks>
        /// <param name="key">
        /// Key. When the key is hashed as a prefix, <c>null</c> is treated as
        /// empty; when it is handed to <c>Hash.Init</c>, it is passed on unchanged.
        /// </param>
        /// <param name="msg">Message; <c>null</c> is treated as empty.</param>
        /// <returns>The hash value.</returns>
        public static byte[] KeyedHash(byte[] key, byte[] msg)
        {
            // Treat a missing message like an empty one, as Mac and Ctr do,
            // instead of failing with a NullReferenceException.
            if (msg == null) msg = new byte[0];

            using (Hash h = new Hash())
            {
                if (Hash.UseMacWithPermutation)
                {
                    h.Init(key);
                }
                else
                {
                    if (key == null) key = new byte[0];
                    h.TransformBlock(key, 0, key.Length);
                }

                return h.TransformFinalBlock(msg, 0, msg.Length);
            }
        }

        /// <summary>
        /// Obfuscates or de-obfuscates a message by XOR-ing it with a key stream
        /// produced from consecutive counter blocks.
        /// </summary>
        /// <remarks>
        /// Each 16-byte counter block consists of the block counter (4 bytes,
        /// least significant byte first), <paramref name="msgNo"/> (4 bytes,
        /// least significant byte first) and the first 8 bytes of
        /// <paramref name="nonce"/>. The block is run through
        /// <c>Permutation128.ForwardTransform</c> and the result is XOR-ed onto
        /// the message. When <c>Permutation128.UsePermutation128</c> is not set,
        /// the message is returned unmodified.
        /// </remarks>
        /// <param name="key">Key passed to <c>Permutation128.Init</c>.</param>
        /// <param name="nonce">Nonce; at least 8 bytes are required when the message is not empty.</param>
        /// <param name="msg">Message; <c>null</c> is treated as empty.</param>
        /// <param name="msgNo">Message number mixed into every counter block.</param>
        /// <returns>
        /// A new array of the same length as the message. The input array is
        /// never returned, so the caller may modify the result in place.
        /// </returns>
        public static byte[] Ctr(byte[] key, byte[] nonce, byte[] msg, int msgNo = 0)
        {
            if (msg == null) msg = new byte[0];
            byte[] result = new byte[msg.Length];

            if (!Permutation128.UsePermutation128)
            {
                // Return a copy rather than the caller's array: in-place helpers
                // applied to the result must not alter the input.
                Array.Copy(msg, 0, result, 0, msg.Length);
                return result;
            }

            using (Permutation128 p = new Permutation128())
            {
                p.Init(key);
                byte[] counterBlock = new byte[16];
                int pos = 0;
                for (int blockNo = 0; pos < msg.Length; blockNo++)
                {
                    // The whole block is rebuilt for every iteration; it is not
                    // visible here whether ForwardTransform leaves its input intact.
                    Array.Copy(BitConverter.GetBytes(blockNo), 0, counterBlock, 0, 4);
                    if (!BitConverter.IsLittleEndian) Array.Reverse(counterBlock, 0, 4);
                    Array.Copy(BitConverter.GetBytes(msgNo), 0, counterBlock, 4, 4);
                    if (!BitConverter.IsLittleEndian) Array.Reverse(counterBlock, 4, 4);
                    Array.Copy(nonce, 0, counterBlock, 8, 8);

                    byte[] keyStream = p.ForwardTransform(counterBlock);
                    int n = Math.Min(Permutation128.BlockLength, msg.Length - pos);
                    for (int j = 0; j < n; j++, pos++)
                    {
                        result[pos] = (byte)(keyStream[j] ^ msg[pos]);
                    }
                }
            }

            return result;
        }

        /// <summary>
        /// Creates an 8-byte nonce from the current UTC time and an index.
        /// </summary>
        /// <remarks>
        /// Layout: byte 0 holds the millisecond part of the current time scaled
        /// to 0..255; bytes 1..4 hold four bytes of the time stamp's
        /// <c>Ticks</c>; bytes 5..6 hold <paramref name="index"/>, least
        /// significant byte first; byte 7 stays zero.
        /// NOTE(review): the type of <c>Utc.Ticks</c> is not visible here. If it
        /// is wider than four bytes, a big-endian host would pick different
        /// bytes than a little-endian one. Confirm.
        /// </remarks>
        /// <param name="index">Index stored in bytes 5..6.</param>
        /// <param name="nonceFileName">
        /// When not <c>null</c>, the time stamp is obtained from
        /// <c>Utc.GetNonce</c> with this file name instead of the current time.
        /// </param>
        /// <returns>A new 8-byte array.</returns>
        public static byte[] CreateNonce(ushort index = 0, string nonceFileName = null)
        {
            DateTime now = DateTime.UtcNow;
            Utc stamp = new Utc(now);
            if (nonceFileName != null)
            {
                stamp = Utc.GetNonce(nonceFileName);
            }

            byte[] tickBytes = BitConverter.GetBytes(stamp.Ticks);
            byte[] indexBytes = BitConverter.GetBytes(index);

            byte[] result = new byte[8];
            // The millisecond byte always comes from the current time, even when
            // the time stamp was taken from the nonce file.
            result[0] = (byte)((now.Millisecond * 255) / 999);
            Array.Copy(tickBytes, 0, result, 1, 4);
            Array.Copy(indexBytes, 0, result, 5, 2);

            if (!BitConverter.IsLittleEndian)
            {
                Array.Reverse(result, 1, 4);
                Array.Reverse(result, 5, 2);
            }

            return result;
        }

        /// <summary>
        /// Creates a nonce with <see cref="CreateNonce"/> and runs it through
        /// the forward transform of a <c>Permutation64</c> initialised with the
        /// given key.
        /// </summary>
        /// <param name="key">Key passed to <c>Permutation64.Init</c>.</param>
        /// <param name="index">Index forwarded to <see cref="CreateNonce"/>.</param>
        /// <param name="nonceFileName">File name forwarded to <see cref="CreateNonce"/>.</param>
        /// <returns>The transformed nonce.</returns>
        public static byte[] CreateObfuscatedNonce(byte[] key, ushort index = 0, string nonceFileName = null)
        {
            byte[] plainNonce = CreateNonce(index, nonceFileName);

            using (Permutation64 permutation = new Permutation64())
            {
                permutation.Init(key);
                return permutation.ForwardTransform(plainNonce);
            }
        }

        /// <summary>
        /// Obfuscates a buffer in place with a salt-dependent byte stream that
        /// is chained over the original (not yet obfuscated) bytes.
        /// </summary>
        /// <remarks>
        /// Calling <see cref="Remove"/> with the same salt restores the
        /// original content. An empty buffer is left untouched.
        /// </remarks>
        /// <param name="data">Buffer to obfuscate in place.</param>
        /// <param name="salt">Start value of the byte stream.</param>
        /// <exception cref="ArgumentNullException"><paramref name="data"/> is <c>null</c>.</exception>
        public static void Add(byte[] data, uint salt)
        {
            if (data == null) throw new ArgumentNullException("data");
            if (data.Length == 0) return;

            // The generator relies on 32-bit wrap-around; make that explicit so
            // the result does not depend on the project's overflow checking.
            unchecked
            {
                uint state = (uint)(3000001321 * salt + 9000000101);

                // Only the preceding original byte is needed for chaining, so it
                // is remembered instead of cloning the whole buffer.
                byte previousPlain = data[0];
                data[0] ^= (byte)state;

                for (int i = 1; i < data.Length; i++)
                {
                    state = (uint)(3000001321 * (state + previousPlain) + 9000000101);
                    previousPlain = data[i];
                    data[i] ^= (byte)(state / 2);
                }
            }
        }

        /// <summary>
        /// Reverses <see cref="Add"/> in place: regenerates the salt-dependent
        /// byte stream, chaining over the bytes that have already been restored.
        /// </summary>
        /// <remarks>An empty buffer is left untouched.</remarks>
        /// <param name="data">Buffer to restore in place.</param>
        /// <param name="salt">The salt that was used for obfuscation.</param>
        /// <exception cref="ArgumentNullException"><paramref name="data"/> is <c>null</c>.</exception>
        public static void Remove(byte[] data, uint salt)
        {
            if (data == null) throw new ArgumentNullException("data");
            if (data.Length == 0) return;

            // The generator relies on 32-bit wrap-around; make that explicit so
            // the result does not depend on the project's overflow checking.
            unchecked
            {
                uint state = (uint)(3000001321 * salt + 9000000101);
                data[0] ^= (byte)state;

                for (int i = 1; i < data.Length; i++)
                {
                    // data[i - 1] has already been restored by the previous step.
                    state = (uint)(3000001321 * (state + data[i - 1]) + 9000000101);
                    data[i] ^= (byte)(state / 2);
                }
            }
        }
    }

    /// <summary>
    /// Simple 32-bit checksum over a byte sequence.
    /// </summary>
    public static class Checksum
    {
        /// <summary>
        /// Computes the checksum over <paramref name="len"/> bytes of
        /// <paramref name="arr"/>, starting at <paramref name="offset"/>.
        /// </summary>
        /// <remarks>
        /// Starting from 1, every byte b updates the state to
        /// state + (state &lt;&lt; 8) + b + 1, with 32-bit wrap-around.
        /// If <paramref name="len"/> is zero or negative, no byte is read and
        /// the result is 1. The range is not validated up front; reading
        /// outside the array raises <see cref="IndexOutOfRangeException"/>.
        /// </remarks>
        /// <param name="arr">Source bytes.</param>
        /// <param name="offset">Index of the first byte to include.</param>
        /// <param name="len">Number of bytes to include.</param>
        /// <returns>The checksum.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="arr"/> is <c>null</c>.</exception>
        public static uint Compute(byte[] arr, int offset, int len)
        {
            if (arr == null) throw new ArgumentNullException("arr");

            uint state = 1;

            // The checksum relies on 32-bit wrap-around; make that explicit so
            // the result does not depend on the project's overflow checking.
            unchecked
            {
                for (int i = 0; i < len; i++)
                {
                    state += (state << 8) + arr[offset + i] + 1u;
                }
            }

            return state;
        }

        /// <summary>
        /// Computes the checksum over the whole array.
        /// </summary>
        /// <param name="arr">Source bytes.</param>
        /// <returns>The checksum; 1 for an empty array.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="arr"/> is <c>null</c>.</exception>
        public static uint Compute(byte[] arr)
        {
            if (arr == null) throw new ArgumentNullException("arr");

            return Compute(arr, 0, arr.Length);
        }
    }
}