// SPDX-License-Identifier: GPL-2.0
/*
 * Copyright (C) 2021 MediaTek Inc. All Rights Reserved.
 *
 * Author: Weijie Gao <weijie.gao@mediatek.com>
 *
 * Generic data upgrading command
 */

#include <command.h>
#include <env.h>
#include <image.h>
#include <linux/types.h>
#include <linux/mtd/mtd.h>
#include <linux/sizes.h>
#include <jffs2/jffs2.h>
#include <div64.h>
#include <environment.h>
#include <aes.h>
#include <u-boot/crc.h>

#include "load_data.h"
#include "colored_print.h"
#include "upgrade_helper.h"

/*
 * Tag that decrypt_image() expects right after the image_head_t header.
 * It is compared byte-for-byte against the image, so the spelling is part
 * of the image format and must not be "corrected" here alone.
 */
#define ENCRPTED_IMG_TAG  "encrpted_img"
/*
 * AES key and CBC initialisation vector passed to the AES helpers by
 * decrypt_image().
 * NOTE(review): the key literal is 33 characters long and the IV literal 17;
 * presumably only the first AES256_KEY_LENGTH / AES_BLOCK_LENGTH bytes are
 * consumed by the AES routines - confirm against the image packing tool.
 * NOTE(review): the key material is hard-coded in the bootloader source, so
 * it only obfuscates the image; it is not a secret.
 */
#define AES_KEY_256  "he9-4+M!)d6=m~we1,q2a3d1n&2*Z^%8$"
#define AES_NOR_IV  "J%1iQl8$=lm-;8AE@"

/*
 * Local prototypes for the AES helpers.
 * NOTE(review): decrypt_image() calls aes_cbc_decrypt_blocks(), which is not
 * declared here, while aes_cbc_encrypt_blocks() is declared but never called
 * in this file - presumably <aes.h> already provides all of them; confirm and
 * drop these duplicates if so.
 */
extern void aes_expand_key(u8 *key, u32 key_size, u8 *expkey);

extern void aes_cbc_encrypt_blocks(u32 key_size, u8 *key_exp, u8 *iv, u8 *src, u8 *dst,
			    u32 num_aes_blocks);
/*
 * Header located at the very beginning of an encrypted firmware image.
 *
 * The structure is packed and its members add up to exactly 512 bytes.
 * When a member is added, shrink 'reserved' by the same amount so the
 * total size stays at 512 bytes.
 */
typedef struct {
	char model[32];			/* model name */
	char region[32];		/* region */
	char version[64];		/* version */
	char dateTime[64];		/* date */
	unsigned int productHwModel;	/* product hardware model */
	/*
	 * Model index:
	 *   0      - do not change the model during an NMRP upgrade
	 *   others - change the model to the entry with this index
	 */
	char modelIndex;
	char hwIdNum;			/* number of entries in the hw id list */
	char modelNum;			/* number of entries in the model list */
	char reserved0[13];		/* reserved */
	char modelHwInfo[200];		/* hw id list followed by model list */
	char reserved[100];		/* reserved, pads the header to 512 bytes */
} __attribute__((__packed__)) image_head_t;

/*
 * Trailer of an encrypted firmware image: the 4 checksum bytes that
 * decrypt_image() expects after the encrypted payload.
 */
typedef struct {
	char checkSum[4];	/* checksum */
} __attribute__((__packed__)) image_tail_t;

/*
 * Table of upgradable parts and its number of entries. Both are (re)filled
 * through board_upgrade_data_parts() at the start of every command handler
 * in this file before they are used.
 */
static const struct data_part_entry *upgrade_parts;
static u32 num_parts;

/*
 * Ask the user whether the part's post-upgrade action should be executed.
 *
 * @dpe: part being upgraded
 *
 * Returns the answer given to the confirmation prompt for reboot, boot and
 * custom actions. A custom action without a prompt string is always accepted,
 * and any other action type yields false.
 */
static bool prompt_post_action(const struct data_part_entry *dpe)
{
	const char *question;

	if (dpe->post_action == UPGRADE_ACTION_REBOOT) {
		question = "Reboot after upgrading? (Y/n):";
	} else if (dpe->post_action == UPGRADE_ACTION_BOOT) {
		question = "Run image after upgrading? (Y/n):";
	} else if (dpe->post_action == UPGRADE_ACTION_CUSTOM) {
		/* Nothing to ask: the custom action runs unconditionally */
		if (!dpe->custom_action_prompt)
			return true;

		question = dpe->custom_action_prompt;
	} else {
		return false;
	}

	return confirm_yes(question);
}

/*
 * Run the post-upgrade action of a part.
 *
 * @dpe:  part that has just been upgraded
 * @data: start of the data that was written
 * @size: size of the data in bytes
 *
 * The part's do_post_action hook, when present, is always called first.
 * Its result is returned only for a custom action; for the other action
 * types the result is dropped and the built-in action ("reset" or
 * "mtkboardboot") is run instead.
 *
 * Returns the hook's result for a custom action with a hook, the result of
 * run_command() for reboot/boot actions, CMD_RET_SUCCESS otherwise.
 */
static int do_post_action(const struct data_part_entry *dpe, const void *data,
			  size_t size)
{
	int result = CMD_RET_SUCCESS;

	if (dpe->do_post_action) {
		int hook_result = dpe->do_post_action(dpe->priv, dpe, data,
						      size);

		if (dpe->post_action == UPGRADE_ACTION_CUSTOM)
			return hook_result;
	}

	if (dpe->post_action == UPGRADE_ACTION_REBOOT) {
		printf("Rebooting ...\n\n");
		result = run_command("reset", 0);
	} else if (dpe->post_action == UPGRADE_ACTION_BOOT) {
		result = run_command("mtkboardboot", 0);
	}

	return result;
}

/*
 * Interactively let the user pick one entry of upgrade_parts.
 *
 * All parts are listed with their index, then a single key is read from the
 * console (carriage return / newline re-display the prompt). The index is
 * the distance of that key from '0', so only one keystroke is ever
 * evaluated.
 *
 * Returns the selected entry, or NULL when the key does not map to a valid
 * index.
 */
static const struct data_part_entry *select_part(void)
{
	u32 i;
	char c;

	printf("\n");
	cprintln(PROMPT, "Available parts to be upgraded:");

	/* The index is unsigned, so it is printed with %u */
	for (i = 0; i < num_parts; i++)
		printf("    %u - %s\n", i, upgrade_parts[i].name);

	while (1) {
		printf("\n");
		cprint(PROMPT, "Select a part:");
		printf(" ");

		c = getchar();
		if (c == '\r' || c == '\n')
			continue;

		/* Echo the accepted key */
		printf("%c\n", c);
		break;
	}

	/*
	 * Keys below '0' would wrap around in the unsigned subtraction, so
	 * they are rejected explicitly before the range check.
	 */
	i = c - '0';
	if (c < '0' || i >= num_parts) {
		cprintln(ERROR, "*** Invalid selection! ***");
		return NULL;
	}

	return &upgrade_parts[i];
}

static const struct data_part_entry *find_part(const char *abbr)
{
	u32 i;

	if (!abbr)
		return NULL;

	for (i = 0; i < num_parts; i++) {
		if (!strcmp(upgrade_parts[i].abbr, abbr))
			return &upgrade_parts[i];
	}

	cprintln(ERROR, "*** Invalid upgrading part! ***");

	return NULL;
}

/*
 * Handler of the "mtkupgrade" command.
 *
 * @cmdtp: command table entry (unused)
 * @flag:  command flags (unused)
 * @argc:  number of arguments; without a part argument the part is selected
 *         interactively
 * @argv:  argv[1], when present, is the abbreviation of the part to upgrade
 *
 * Sequence: pick the part, ask whether the post action should be run, load
 * the data to the configured load address, run the part's optional validate
 * hook, write the data and finally run the post action if it was accepted.
 *
 * Returns CMD_RET_SUCCESS on success (or the post action's result),
 * CMD_RET_FAILURE on any error.
 */
static int do_mtkupgrade(struct cmd_tbl *cmdtp, int flag, int argc,
			 char *const argv[])
{
	const struct data_part_entry *dpe = NULL;
	ulong data_load_addr;
	size_t data_size = 0;
	bool do_action;

	board_upgrade_data_parts(&upgrade_parts, &num_parts);

	if (!upgrade_parts || !num_parts) {
		printf("mtkupgrade is not configured!\n");
		return CMD_RET_FAILURE;
	}

	if (argc < 2)
		dpe = select_part();
	else
		dpe = find_part(argv[1]);

	if (!dpe)
		return CMD_RET_FAILURE;

	printf("\n");
	cprintln(PROMPT, "*** Upgrading %s ***", dpe->name);
	printf("\n");

	do_action = prompt_post_action(dpe);

	/* Set load address */
#if defined(CONFIG_SYS_LOAD_ADDR)
	data_load_addr = CONFIG_SYS_LOAD_ADDR;
#elif defined(CONFIG_LOADADDR)
	data_load_addr = CONFIG_LOADADDR;
#else
	/* Without a configured address data_load_addr would be used unset */
	printf("mtkupgrade: load address is not configured!\n");
	return CMD_RET_FAILURE;
#endif

	/* Load data */
	if (load_data(data_load_addr, &data_size, dpe->env_name))
		return CMD_RET_FAILURE;

	printf("\n");
	/* data_size is a size_t, hence %zu for the decimal form */
	cprintln(PROMPT, "*** Loaded %zu (0x%zx) bytes at 0x%08lx ***",
		 data_size, data_size, data_load_addr);
	printf("\n");

	/* Publish the address of the loaded data as the image load address */
	image_load_addr = data_load_addr;

	/* Validate data */
	if (dpe->validate) {
		if (dpe->validate(dpe->priv, dpe, (void *)data_load_addr,
				  data_size))
			return CMD_RET_FAILURE;
	}

	/* Write data */
	if (dpe->write(dpe->priv, dpe, (void *)data_load_addr, data_size))
		return CMD_RET_FAILURE;

	printf("\n");
	cprintln(PROMPT, "*** %s upgrade completed! ***", dpe->name);

	if (do_action) {
		puts("\n");
		return do_post_action(dpe, (void *)data_load_addr, data_size);
	}

	return CMD_RET_SUCCESS;
}

/*
 * Register the "mtkupgrade" shell command with do_mtkupgrade() as handler.
 * The optional <part> argument is matched against the abbreviations of the
 * board's upgrade parts; without it the part is selected interactively.
 */
U_BOOT_CMD(mtkupgrade, 2, 0, do_mtkupgrade,
	   "MTK firmware/bootloader upgrading utility",
	   "mtkupgrade [<part>]\n"
	   "part    - upgrade data part\n"
);

//extern void led_time_tick(ulong times);


/*
 * Validate and decrypt the encrypted firmware image found at the load
 * address.
 *
 * Image layout as parsed here:
 *   image_head_t | ENCRPTED_IMG_TAG | be32 image size | be32 block size |
 *   encrypted payload (image size rounded up to AES_BLOCK_LENGTH) |
 *   be32 checksum
 *
 * The total size of the loaded file is taken from the "filesize" environment
 * variable. The checksum covers everything from the load address up to the
 * end of the encrypted payload. The payload is decrypted in chunks of
 * "block size" bytes, every chunk being handed to aes_cbc_decrypt_blocks()
 * with the same IV, and the plain data is written starting at the load
 * address, i.e. it overwrites the header.
 *
 * Environment side effects:
 *   decrypt_result / filesize_result - set to "bad" on entry and to "good"
 *                                      on success
 *   filesize                         - set to the plain image size on
 *                                      success
 *   fenv_model                       - set to the model selected by the
 *                                      header's modelIndex, if any
 *
 * Returns 0 on success, 1 on any error.
 */
int decrypt_image(void)
{
    ulong head_size = 0;    /* header + tag + the two be32 size fields */
    ulong offset = 0;
    ulong file_size = 0;
    ulong payload_max = 0;
    unsigned char *src_addr = NULL;
    unsigned char *dst_addr = NULL;
    size_t data_load_addr;
    image_head_t *image_head = NULL;
    char *image_tag = NULL;
    char *fenv_hwid = NULL;
    char *fenv_model = NULL;
    //char *fenv_region = NULL;
    ulong image_size = 0;
    ulong encrypted_size = 0;
    ulong block_size = 0;
    ulong checksum = 0;
    ulong curr_checksum = 0;
    unsigned char key_exp[AES256_EXPAND_KEY_LENGTH] = {0};
    unsigned int aes_blocks = 0;
    /* NUL-terminated copy of the header's hw id / model list */
    char hw_info[sizeof(image_head->modelHwInfo) + 1];

    env_set("decrypt_result", "bad");
    env_set("filesize_result", "bad");

    head_size = sizeof(image_head_t) + strlen(ENCRPTED_IMG_TAG) + 4 * 2;

    file_size = env_get_hex("filesize", 0);
    if (file_size < head_size)
    {
        printf("Image head not found!\n");
        return 1;
    }

    /* Set load address */
#if defined(CONFIG_SYS_LOAD_ADDR)
    data_load_addr = CONFIG_SYS_LOAD_ADDR;
#elif defined(CONFIG_LOADADDR)
    data_load_addr = CONFIG_LOADADDR;
#else
    /* Without a configured address data_load_addr would be used unset */
    printf("Load address is not configured!\n");
    return 1;
#endif

    src_addr = (unsigned char *)data_load_addr;
    dst_addr = (unsigned char *)data_load_addr;

    image_head = (image_head_t *)src_addr;
    src_addr += sizeof(image_head_t);

    image_tag = (char *)src_addr;
    if (strncmp(image_tag, ENCRPTED_IMG_TAG, strlen(ENCRPTED_IMG_TAG)))
    {
        printf("Encrpted tag not found!\n");
        return 1;
    }
    src_addr += strlen(ENCRPTED_IMG_TAG);

    /*
     * The header comes from the image and its text fields are not
     * guaranteed to be NUL-terminated: work on a bounded, terminated copy
     * of modelHwInfo and limit every %s below to the size of its field.
     */
    memcpy(hw_info, image_head->modelHwInfo, sizeof(image_head->modelHwInfo));
    hw_info[sizeof(image_head->modelHwInfo)] = '\0';

    printf("Image is encrypted\n");
    printf("model: %.*s\n", (int)sizeof(image_head->model), image_head->model);
    printf("region: %.*s\n", (int)sizeof(image_head->region), image_head->region);
    printf("version: %.*s\n", (int)sizeof(image_head->version), image_head->version);
    printf("dateTime: %.*s\n", (int)sizeof(image_head->dateTime), image_head->dateTime);
    //printf("CONFIG_NETGEAR_MODLE_NAME: %s\n", CONFIG_NETGEAR_MODLE_NAME);
    printf("productHwModel: %d\n", image_head->productHwModel);
    printf("modelIndex: %d\n", image_head->modelIndex);
    printf("hwIdNum: %d\n", image_head->hwIdNum);
    printf("modelNum: %d\n", image_head->modelNum);
    printf("modelHwInfo: %s\n", hw_info);

    /*
     * A non-zero modelIndex selects one entry of the model list, which
     * follows the hw id list inside modelHwInfo (entries separated by ';').
     * NOTE(review): fenv_model is changed here before the hw id, model and
     * checksum checks below have passed - confirm this is intended.
     */
    if (image_head->modelIndex != 0 && image_head->modelIndex <= image_head->modelNum)
    {
        char loop = 0;
        char delims[] = ";";
        char *result = NULL;
        /* strtok() modifies its input, so tokenize a private copy */
        char hw_info_tokens[sizeof(hw_info)];

        memcpy(hw_info_tokens, hw_info, sizeof(hw_info));
        result = strtok(hw_info_tokens, delims);
        //skip hw id
        while (result != NULL && loop < image_head->hwIdNum)
        {
            printf("hwid is \"%s\"\n", result);
            loop++;
            result = strtok(NULL, delims);
        }
        //model list
        loop = 0;
        while (result != NULL && loop < image_head->modelNum)
        {
            loop++;
            printf("model[%d] is \"%s\"\n", loop, result);

            if (loop == image_head->modelIndex)
            {
                env_set("fenv_model", result);
                break;
            }
            result = strtok(NULL, delims);
        }
    }

    /* The device's hw id must appear somewhere in the image's list */
    fenv_hwid = env_get("fenv_hw_id");
    printf("fenv_hwid: %s\n", fenv_hwid);
    if (!fenv_hwid || !strstr(hw_info, fenv_hwid))
    {
        printf("Image hw id not match!\n");
        return 1;
    }

    /* Same for the device's model name */
    fenv_model = env_get("fenv_model");
    printf("fenv_model: %s\n", fenv_model);
    if (!fenv_model || !strstr(hw_info, fenv_model))
    {
        printf("Image model not match!\n");
        return 1;
    }
#if 0
    fenv_region = env_get("fenv_region");
    if (!fenv_region || strcmp(fenv_region, image_head->region))
    {
        printf("Image region not match!\n");
        return 1;
    }
#endif

    image_size = ntohl(*(uint *)src_addr);
    src_addr += sizeof(uint);
    printf("size: 0x%lx\n", image_size);

    block_size = ntohl(*(uint *)src_addr);
    src_addr += sizeof(uint);
    printf("block size: 0x%lx\n", block_size);

    /* A zero block size would never advance the decrypt loop below */
    if (!block_size || block_size % AES_BLOCK_LENGTH)
    {
        printf("Image block size not times of AES_BLOCK_LENGTH!\n");
        return 1;
    }

    /*
     * Bound the size taken from the header by what was really loaded
     * BEFORE rounding it up, so that a huge value cannot wrap around and
     * slip through the completeness check.
     */
    if (file_size - head_size < sizeof(image_tail_t))
    {
        printf("Image incomplete!\n");
        return 1;
    }
    payload_max = file_size - head_size - sizeof(image_tail_t);

    if (image_size > payload_max)
    {
        printf("Image incomplete!\n");
        return 1;
    }

    encrypted_size = DIV_ROUND_UP(image_size, AES_BLOCK_LENGTH) * AES_BLOCK_LENGTH;
    if (encrypted_size > payload_max)
    {
        printf("Image incomplete!\n");
        return 1;
    }

    /* The checksum is stored right behind the encrypted payload */
    checksum = ntohl(*(uint *)(src_addr + encrypted_size));
    printf("checksum: 0x%lx\n", checksum);

    curr_checksum = crc32_no_comp(0, (const unsigned char *)data_load_addr,
                                  head_size + encrypted_size);
    printf("curr_checksum: 0x%lx\n", curr_checksum);
    if (curr_checksum != checksum)
    {
        printf("Image checksum error!\n");
        return 1;
    }

    printf("Decrypt image...\n");

    aes_expand_key((u8 *)AES_KEY_256, AES256_KEY_LENGTH, key_exp);
    for (offset = 0; offset < encrypted_size; offset += block_size)
    {
        /* The last chunk may be shorter than the nominal block size */
        if (block_size > encrypted_size - offset)
        {
            block_size = encrypted_size - offset;
        }

        aes_blocks = DIV_ROUND_UP(block_size, AES_BLOCK_LENGTH);
        aes_cbc_decrypt_blocks(AES256_KEY_LENGTH, key_exp, (u8 *)AES_NOR_IV, (u8 *)src_addr, (u8 *)dst_addr, aes_blocks);

        src_addr += block_size;
        dst_addr += block_size;
        //led_time_tick(get_timer(0));
    }
    printf("Decrypt finish\n");

    /* From here on "filesize" describes the plain image at the load address */
    env_set_hex("filesize", image_size);

    env_set("filesize_result", "good");

    env_set("decrypt_result", "good");

    return 0;
}

/*
 * Handler of the "writeimg" command.
 *
 * @cmdtp: command table entry (unused)
 * @flag:  command flags (unused)
 * @argc:  number of arguments (unused)
 * @argv:  arguments (unused)
 *
 * Decrypts the encrypted image that is already present at the load address
 * with decrypt_image() and writes the result to the upgrade part whose
 * abbreviation is "fw". The size to write is read back from the "filesize"
 * environment variable, which decrypt_image() updates on success.
 *
 * Returns CMD_RET_SUCCESS on success, CMD_RET_FAILURE on any error.
 */
static int do_write_img(struct cmd_tbl *cmdtp, int flag, int argc, char *const argv[])
{
    size_t data_load_addr;
    uint32_t data_size = 0;
    const struct data_part_entry *dpe = NULL;

    board_upgrade_data_parts(&upgrade_parts, &num_parts);

    if (!upgrade_parts || !num_parts) {
        printf("mtkupgrade is not configured!\n");
        return CMD_RET_FAILURE;
    }

    dpe = find_part("fw");

    if (!dpe)
        return CMD_RET_FAILURE;

    cprintln(PROMPT, "*** Upgrading %s ***", dpe->name);
    printf("\n");

/* Set load address */
#if defined(CONFIG_SYS_LOAD_ADDR)
    data_load_addr = CONFIG_SYS_LOAD_ADDR;
#elif defined(CONFIG_LOADADDR)
    data_load_addr = CONFIG_LOADADDR;
#else
    /* Without a configured address data_load_addr would be used unset */
    printf("writeimg: load address is not configured!\n");
    return CMD_RET_FAILURE;
#endif

    if (0 != decrypt_image())
        return CMD_RET_FAILURE;

    /* decrypt_image() stored the plain image size in "filesize" */
    data_size = env_get_hex("filesize", 0);
    /* data_size is unsigned, hence %u */
    printf("filesize  = %u\n", data_size);
    printf("\n");

    /* Write data */
    if (dpe->write(dpe->priv, dpe, (void *)data_load_addr, data_size))
        return CMD_RET_FAILURE;

    printf("\n");
    cprintln(PROMPT, "*** %s upgrade completed! ***", dpe->name);
    printf("\n");
    return CMD_RET_SUCCESS;
}

/*
 * Register the "writeimg" shell command with do_write_img() as handler.
 * NOTE(review): do_write_img() never looks at its arguments, so the
 * argument limit of 2 given here looks larger than needed - confirm.
 */
U_BOOT_CMD(writeimg, 2, 0, do_write_img,
    "write firmware",
    "write firmware\n"
);
