/* Inexact source code package.
 *
 * Written in 2019 by <ben@hackade.org>.
 *
 * To the extent possible under law, the author has dedicated all copyright
 * and related and neighboring rights to this software to the public domain
 * worldwide. This software is distributed without any warranty.
 *
 * You should have received a copy of the CC0 Public Domain Dedication along with
 * this software. If not, see <http://creativecommons.org/publicdomain/zero/1.0/>.
 */

#include "inexact.h"
#include <errno.h>
#include <fcntl.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/stat.h>  // for chmod only
#include <unistd.h>
#include "argon2.h"
#include "base64.h"
#include "chacha20_drng.h"
#include "curve25519.h"
#include "norx_inexact.h"
#include "readpassphrase.h"
#include "sha3.h"

/*
 * Generate an X25519 key pair and store it, base64 + b64t encoded, in files.
 *
 * no_password != 0: the private key is drawn from the ChaCha20 DRNG and the
 *                   secret key file holds private key || public key.
 * no_password == 0: a random 32-byte salt is drawn from the DRNG, the user is
 *                   asked for a password twice, the private key is derived as
 *                   argon2id(password, salt) and the secret key file holds
 *                   salt || public key (the private key itself is not stored).
 *
 * The public key file holds the public key. Each file ends with a newline.
 * The secret key file is chmod'ed to owner read/write before anything is
 * written to it.
 *
 * Returns 0 on success, 1 on failure (a message is printed to stdout).
 */
int generate_keys(const char *seckey_filename, const char *pubkey_filename,
                  int no_password) {
    unsigned char *privatekey_b64 = NULL;
    unsigned char *publickey_b64 = NULL;
    unsigned char *privatekey_b64t = NULL;
    unsigned char *publickey_b64t = NULL;
    unsigned char salt[32] = {0};
    unsigned char privatekey_buffer[64] = {0};

    char password[256] = {0};
    char *password_out = NULL;
    char password_verif[256] = {0};
    char *password_verif_out = NULL;
    size_t password_len = 0;
    size_t password_verif_len = 0;

    const uint32_t t_cost = 3;
    const uint32_t m_cost = (1 << 12);
    const uint32_t parallelism = 1;

    size_t private_base64_len = 0;
    size_t private_b64t_len = 0;
    size_t public_base64_len = 0;
    size_t public_b64t_len = 0;

    FILE *fs = NULL;
    FILE *fp = NULL;
    int close_res = 0;

    /* One DRNG handle shared by both branches, so cleanup always sees it. */
    struct chacha20_drng *drng = NULL;

    int ret = 0;
    int a2res = 0;
    int exitcode = 1;

    curve25519_key privatekey;
    curve25519_key publickey;

    memset(privatekey, 0, sizeof(curve25519_key));
    memset(publickey, 0, sizeof(curve25519_key));

    ret = drng_chacha20_init(&drng);
    if (ret) {
        printf("Chacha20 allocation failed: %d\n", ret);
        goto exit;
    }

    /* Secret (or private) key */

    if (no_password) {
        if (drng_chacha20_get(drng, privatekey, sizeof(curve25519_key))) {
            printf("Getting random numbers failed\n");
            goto exit;
        }

        curve25519_donna_basepoint(publickey, privatekey);

        memcpy(privatekey_buffer, privatekey, sizeof(curve25519_key));
        memcpy(privatekey_buffer + sizeof(curve25519_key), publickey,
               sizeof(curve25519_key));
    } else {
        if (drng_chacha20_get(drng, salt, sizeof(salt))) {
            printf("Getting random numbers failed\n");
            goto exit;
        }

        password_out =
            readpassphrase("Password : ", password, sizeof(password), 0);
        if (password_out != password) {
            printf("password input failed.\n");
            goto exit;
        }
        password_len = strlen(password);

        password_verif_out =
            readpassphrase("Verifying, please re-enter : ", password_verif,
                           sizeof(password_verif), 0);
        if (password_verif_out != password_verif) {
            printf("password input failed.\n");
            goto exit;
        }
        password_verif_len = strlen(password_verif);

        if (password_len != password_verif_len) {
            printf("Mismatch.\n");
            goto exit;
        }

        if (memcmp(password, password_verif, password_len) != 0) {
            printf("Mismatch.\n");
            goto exit;
        }

        a2res = argon2id_hash_raw(t_cost, m_cost, parallelism, password,
                                  password_len, salt, sizeof(salt), privatekey,
                                  sizeof(curve25519_key));
        if (a2res != ARGON2_OK) {
            printf("argon2 failed.\n");
            goto exit;
        }

        curve25519_donna_basepoint(publickey, privatekey);
        /* Only salt || public key are stored; the private key is re-derived
         * from the password when the file is loaded. */
        memcpy(privatekey_buffer, salt, sizeof(salt));
        memcpy(privatekey_buffer + sizeof(salt), publickey,
               sizeof(curve25519_key));
    }

    privatekey_b64 = base64_encode(privatekey_buffer, sizeof(privatekey_buffer),
                                   &private_base64_len);
    if (privatekey_b64 == NULL) {
        printf("base64 encoding failed.\n");
        goto exit;
    }

    privatekey_b64t =
        b64t_encode(privatekey_b64, private_base64_len, &private_b64t_len);
    if (privatekey_b64t == NULL) {
        printf("b64t encoding failed.\n");
        goto exit;
    }

    fs = fopen(seckey_filename, "wb");
    if (fs == NULL) {
        printf("secret key file opening failed: %s.\n", strerror(errno));
        goto exit;
    }

    /* Tighten permissions before any secret byte reaches the file, so a chmod
     * failure never leaves key material in a loosely-permissioned file.
     * NOTE(review): creating the file with open(..., O_CREAT, 0600) would
     * also close the window between fopen() and chmod(). */
    if (chmod(seckey_filename, S_IRUSR | S_IWUSR) != 0) {
        printf("secret key file chmod failed: %s.\n", strerror(errno));
        goto exit;
    }

    if (fwrite(privatekey_b64t, 1, private_b64t_len, fs) != private_b64t_len ||
        fwrite("\n", 1, 1, fs) != 1) {
        printf("secret key file writing failed: %s.\n", strerror(errno));
        goto exit;
    }

    /* fclose() flushes the stdio buffer: a failure here means a short file. */
    close_res = fclose(fs);
    fs = NULL;
    if (close_res != 0) {
        printf("secret key file writing failed: %s.\n", strerror(errno));
        goto exit;
    }

    /* Public key */

    publickey_b64 =
        base64_encode(publickey, sizeof(curve25519_key), &public_base64_len);
    if (publickey_b64 == NULL) {
        printf("base64 encoding failed.\n");
        goto exit;
    }

    publickey_b64t =
        b64t_encode(publickey_b64, public_base64_len, &public_b64t_len);
    if (publickey_b64t == NULL) {
        printf("b64t encoding failed.\n");
        goto exit;
    }

    fp = fopen(pubkey_filename, "wb");
    if (fp == NULL) {
        printf("public key file opening failed: %s.\n", strerror(errno));
        goto exit;
    }

    if (fwrite(publickey_b64t, 1, public_b64t_len, fp) != public_b64t_len ||
        fwrite("\n", 1, 1, fp) != 1) {
        printf("public key file writing failed: %s.\n", strerror(errno));
        goto exit;
    }

    close_res = fclose(fp);
    fp = NULL;
    if (close_res != 0) {
        printf("public key file writing failed: %s.\n", strerror(errno));
        goto exit;
    }

    exitcode = 0;

exit:
    if (drng != NULL) {
        drng_chacha20_destroy(drng);
    }

    /* Wipe key material; heap buffers are still NULL on early failures. */
    if (privatekey_b64 != NULL) {
        memset(privatekey_b64, 0, private_base64_len);
    }
    if (privatekey_b64t != NULL) {
        memset(privatekey_b64t, 0, private_b64t_len);
    }
    if (publickey_b64 != NULL) {
        memset(publickey_b64, 0, public_base64_len);
    }
    if (publickey_b64t != NULL) {
        memset(publickey_b64t, 0, public_b64t_len);
    }
    memset(privatekey, 0, sizeof(curve25519_key));
    memset(publickey, 0, sizeof(curve25519_key));
    memset(salt, 0, sizeof(salt));
    memset(privatekey_buffer, 0, sizeof(privatekey_buffer));
    memset(password, 0, sizeof(password));
    memset(password_verif, 0, sizeof(password_verif));

    free(privatekey_b64);
    free(publickey_b64);
    free(privatekey_b64t);
    free(publickey_b64t);

    if (fs != NULL) {
        fclose(fs);
    }
    if (fp != NULL) {
        fclose(fp);
    }

    return exitcode;
}

/*
 * Load a secret key file written by generate_keys().
 *
 * The file is b64t-decoded then base64-decoded and must yield 64 bytes, the
 * second half being the public key. If the first half is the private key of
 * that public key the file is unprotected. Otherwise the first half is an
 * argon2id salt: the user is prompted for the password, the private key is
 * re-derived and must reproduce the stored public key.
 *
 * keyfile: path of the secret key file
 * skey:    receives the private key (sizeof(curve25519_key) bytes) on success
 * pkey:    if non-NULL, receives the matching public key on success
 *
 * Returns 0 on success, 1 on failure (a message is printed to stdout).
 */
int get_seckey(const char *keyfile, unsigned char *skey, unsigned char *pkey) {
    unsigned char *base64_decoded = NULL;
    unsigned char *b64t_decoded = NULL;

    size_t base64_decoded_len = 0;
    size_t b64t_decoded_len = 0;
    size_t password_len = 0;
    size_t sz = 0;
    size_t readed = 0;
    long file_len = 0;

    FILE *fs = NULL;

    /* Generously sized for an encoded 64-byte blob plus line breaks. Larger
     * files are rejected instead of overflowing this stack buffer. */
    unsigned char file_data[1024] = {0};
    unsigned char salt[32] = {0};
    char password[256] = {0};
    char *password_out = NULL;
    curve25519_key pubkey;
    curve25519_key pubkey_from_secret;
    curve25519_key seckey;

    const uint32_t t_cost = 3;
    const uint32_t m_cost = (1 << 12);
    const uint32_t parallelism = 1;

    int password_protected_flag = 0;
    int a2res = 0;
    int exitcode = 1;

    memset(pubkey, 0, sizeof(curve25519_key));
    memset(pubkey_from_secret, 0, sizeof(curve25519_key));
    memset(seckey, 0, sizeof(curve25519_key));

    fs = fopen(keyfile, "rb");
    if (fs == NULL) {
        printf("key file opening failed: %s.\n", strerror(errno));
        goto exit;
    }

    if (fseek(fs, 0L, SEEK_END) != 0) {
        printf("seek to end file '%s' failed: %s.\n", keyfile, strerror(errno));
        goto exit;
    }

    /* ftell() returns a long; -1 signals an error. */
    file_len = ftell(fs);
    if (file_len < 0) {
        printf("tell file '%s' failed: %s.\n", keyfile, strerror(errno));
        goto exit;
    }
    if ((unsigned long)file_len > sizeof(file_data)) {
        printf("key file '%s' is too large.\n", keyfile);
        goto exit;
    }
    sz = (size_t)file_len;

    if (fseek(fs, 0L, SEEK_SET) != 0) {
        printf("seek file to begin'%s' failed: %s.\n", keyfile,
               strerror(errno));
        goto exit;
    }

    readed = fread(file_data, 1, sz, fs);
    if (readed != sz) {
        printf("read file '%s' failed: %s.\n", keyfile, strerror(errno));
        goto exit;
    }

    b64t_decoded = b64t_decode(file_data, readed, &b64t_decoded_len);
    if (b64t_decoded == NULL) {
        printf("b64t decoding failed.\n");
        goto exit;
    }

    base64_decoded =
        base64_decode(b64t_decoded, b64t_decoded_len, &base64_decoded_len);
    if (base64_decoded == NULL) {
        printf("base64 decoding failed.\n");
        goto exit;
    }

    if (base64_decoded_len != 64) {
        printf("decoded size mismatch.\n");
        goto exit;
    }

    memcpy(seckey, base64_decoded, sizeof(curve25519_key));
    memcpy(pubkey, base64_decoded + sizeof(curve25519_key),
           sizeof(curve25519_key));
    curve25519_donna_basepoint(pubkey_from_secret, seckey);

    /* If the first half does not produce the stored public key, it is a salt
     * and the private key has to be derived from a password. */
    password_protected_flag =
        (memcmp(pubkey, pubkey_from_secret, sizeof(curve25519_key)) != 0);
    if (password_protected_flag) {
        memcpy(salt, base64_decoded, sizeof(salt));
        password_out =
            readpassphrase("Password : ", password, sizeof(password), 0);
        if (password_out != password) {
            printf("password input failed.\n");
            goto exit;
        }
        password_len = strlen(password);

        a2res = argon2id_hash_raw(t_cost, m_cost, parallelism, password,
                                  password_len, salt, sizeof(salt), seckey,
                                  sizeof(curve25519_key));
        if (a2res != ARGON2_OK) {
            printf("argon2 failed.\n");
            goto exit;
        }
        curve25519_donna_basepoint(pubkey_from_secret, seckey);
        if (memcmp(pubkey_from_secret, pubkey, sizeof(curve25519_key)) != 0) {
            printf("Bad password\n");
            goto exit;
        }
    }

    memcpy(skey, seckey, sizeof(curve25519_key));
    if (pkey != NULL) {
        memcpy(pkey, pubkey_from_secret, sizeof(curve25519_key));
    }
    exitcode = 0;

exit:
    memset(file_data, 0, sizeof(file_data));
    if (base64_decoded != NULL) {
        memset(base64_decoded, 0, base64_decoded_len);
    }
    if (b64t_decoded != NULL) {
        memset(b64t_decoded, 0, b64t_decoded_len);
    }
    memset(salt, 0, sizeof(salt));
    memset(password, 0, sizeof(password));
    memset(seckey, 0, sizeof(curve25519_key));
    memset(pubkey, 0, sizeof(curve25519_key));
    memset(pubkey_from_secret, 0, sizeof(curve25519_key));

    if (fs != NULL) {
        fclose(fs);
    }

    free(base64_decoded);
    free(b64t_decoded);

    return exitcode;
}

/*
 * Load a public key file written by generate_keys().
 *
 * The file is b64t-decoded then base64-decoded and must yield exactly
 * sizeof(curve25519_key) bytes, which are copied to pkey.
 *
 * Returns 0 on success, 1 on failure (a message is printed to stdout).
 */
int get_pubkey(const char *keyfile, unsigned char *pkey) {
    unsigned char *base64_decoded = NULL;
    unsigned char *b64t_decoded = NULL;

    size_t b64t_decoded_len = 0;
    size_t base64_decoded_len = 0;
    size_t sz = 0;
    size_t readed = 0;
    long file_len = 0;

    FILE *fs = NULL;

    /* Generously sized for an encoded 32-byte key plus the trailing newline
     * written by generate_keys(). Larger files are rejected instead of
     * overflowing this stack buffer. */
    unsigned char file_data[1024] = {0};

    int exitcode = 1;

    fs = fopen(keyfile, "rb");
    if (fs == NULL) {
        printf("key file opening failed: %s.\n", strerror(errno));
        goto exit;
    }

    if (fseek(fs, 0L, SEEK_END) != 0) {
        printf("seek to end file '%s' failed: %s.\n", keyfile, strerror(errno));
        goto exit;
    }

    /* ftell() returns a long; -1 signals an error. */
    file_len = ftell(fs);
    if (file_len < 0) {
        printf("tell file '%s' failed: %s.\n", keyfile, strerror(errno));
        goto exit;
    }
    if ((unsigned long)file_len > sizeof(file_data)) {
        printf("key file '%s' is too large.\n", keyfile);
        goto exit;
    }
    sz = (size_t)file_len;

    if (fseek(fs, 0L, SEEK_SET) != 0) {
        printf("seek file to begin'%s' failed: %s.\n", keyfile,
               strerror(errno));
        goto exit;
    }

    readed = fread(file_data, 1, sz, fs);
    if (readed != sz) {
        printf("read file '%s' failed: %s.\n", keyfile, strerror(errno));
        goto exit;
    }

    b64t_decoded = b64t_decode(file_data, readed, &b64t_decoded_len);
    if (b64t_decoded == NULL) {
        printf("b64t decoding failed.\n");
        goto exit;
    }

    base64_decoded =
        base64_decode(b64t_decoded, b64t_decoded_len, &base64_decoded_len);
    if (base64_decoded == NULL) {
        printf("base64 decoding failed.\n");
        goto exit;
    }

    if (base64_decoded_len != sizeof(curve25519_key)) {
        printf("decoded size mismatch.\n");
        goto exit;
    }

    memcpy(pkey, base64_decoded, sizeof(curve25519_key));

    exitcode = 0;

exit:
    memset(file_data, 0, sizeof(file_data));
    if (b64t_decoded != NULL) {
        memset(b64t_decoded, 0, b64t_decoded_len);
    }
    if (base64_decoded != NULL) {
        memset(base64_decoded, 0, base64_decoded_len);
    }

    if (fs != NULL) {
        fclose(fs);
    }

    free(b64t_decoded);
    free(base64_decoded);

    return exitcode;
}

/*
 * Encrypt data for the owner of pubkey and return it base64 encoded.
 *
 * Layout before base64 (and optional b64t) encoding:
 *   [salt(32) || pubkey(32)]    only when salt is non-NULL
 *   part0: NORX(params) + tag0  9 bytes; params = rand_len (4 bytes, big
 *                               endian) || tag1_len (1 byte); associated data
 *                               is header0 = big-endian length of
 *                               part0 || part1
 *   part1: rand || NORX(data) + tag1, with params as associated data
 *
 *   shared = X25519(seckey, pubkey)
 *   key1   = sha3-256(sha3-256(params || rand) || shared)
 *   key0   = sha3-256(sha3-256(part1) || shared)
 *
 * rand_len must be >= 4 and fit in 32 bits; tag1_len must be 4, 8, 16, 24
 * or 32. These are the values decrypt_data() accepts.
 *
 * Returns a malloc'ed buffer of *out_encrypted_len bytes that the caller
 * frees, or NULL on failure (with *out_encrypted_len set to 0).
 */
unsigned char *encrypt_data(const unsigned char *seckey,
                            const unsigned char *pubkey,
                            const unsigned char *salt,
                            const unsigned char *data,
                            size_t data_len,            // in bytes
                            size_t rand_len,            // in bytes
                            size_t tag1_len,            // in bytes
                            int base64_transformation,  // 0 or 1
                            size_t *out_encrypted_len) {
    unsigned char *rand_buf = NULL;
    unsigned char *encrypted1 = NULL;
    unsigned char *part1 = NULL;
    unsigned char *encrypted = NULL;
    unsigned char *encrypted_b64 = NULL;
    unsigned char *encrypted_b64t = NULL;
    unsigned char *out = NULL;

    size_t encrypted1_len = 0;
    size_t norx_encrypted1_len = 0;
    size_t tag1_len_bits = tag1_len * 8;
    size_t part1_len = 0;
    size_t part0_len = 0;
    size_t norx_params_encrypted_len = 0;
    size_t encrypted_b64_len = 0;
    size_t encrypted_b64t_len = 0;
    size_t encrypted_len = 0;
    size_t overhead = 0;

    const size_t shared_secret_len = 32;
    unsigned char shared_secret[32] = {0};

    const size_t nonce1_len = 32;
    const size_t tag0_len = 4;
    const size_t tag0_len_bits = tag0_len * 8;
    const size_t params_len = 5;
    const size_t params_encrypted_len = 9;   // params_len + tag0_len
    const size_t nonce0_len = 32;
    const size_t header0_len = 4;
    const size_t symmetric_header_len = 64;  // salt (32) + public key (32)
    const size_t rand_len_min = 4;           // minimum accepted on decryption

    const uint8_t *key1 = NULL;
    const uint8_t *nonce1 = NULL;
    unsigned char params[5] = {0};
    unsigned char params_encrypted[9] = {0};
    const uint8_t *key0 = NULL;
    const uint8_t *nonce0 = NULL;
    unsigned char header0[4] = {0};

    /* Declared up front so the cleanup path can always wipe them. */
    sha3_context nc1, kc1, nc0, kc0;
    struct chacha20_drng *drng = NULL;
    int ret = 0;

    *out_encrypted_len = 0;

    /* Reject parameters that cannot be encoded in params or that
     * decrypt_data() would refuse. */
    if (tag1_len != 4 && tag1_len != 8 && tag1_len != 16 && tag1_len != 24 &&
        tag1_len != 32) {
        printf("bad auth tag len.\n");
        goto exit;
    }
    if (rand_len < rand_len_min || (uint64_t)rand_len > UINT32_MAX) {
        printf("bad random data len.\n");
        goto exit;
    }
    part0_len = params_len + tag0_len;
    overhead = symmetric_header_len + part0_len + tag1_len;
    if (rand_len > SIZE_MAX - overhead ||
        data_len > SIZE_MAX - overhead - rand_len) {
        printf("message too large.\n");
        goto exit;
    }

    // generate shared secret from DH X25519
    curve25519_donna(shared_secret, seckey, pubkey);

    // generate part 1 random data
    rand_buf = malloc(rand_len);
    if (rand_buf == NULL) {
        printf("malloc failed.\n");
        goto exit;
    }

    ret = drng_chacha20_init(&drng);
    if (ret) {
        printf("Chacha20 allocation failed: %d\n", ret);
        goto exit;
    }
    if (drng_chacha20_get(drng, rand_buf, rand_len)) {
        printf("Getting random numbers failed\n");
        goto exit;
    }

    // generate part 0 message: rand_len (big endian) || tag1_len
    params[0] = (unsigned char)((rand_len >> 24) & 0xFF);
    params[1] = (unsigned char)((rand_len >> 16) & 0xFF);
    params[2] = (unsigned char)((rand_len >> 8) & 0xFF);
    params[3] = (unsigned char)(rand_len & 0xFF);
    params[4] = (unsigned char)tag1_len;

    // calculate part lengths
    encrypted1_len = tag1_len + data_len;
    part1_len = encrypted1_len + rand_len;
    encrypted_len = part0_len + part1_len;

    // generate part 0 header: big-endian length of part0 || part1
    header0[0] = (encrypted_len >> 24) & 0xFF;
    header0[1] = (encrypted_len >> 16) & 0xFF;
    header0[2] = (encrypted_len >> 8) & 0xFF;
    header0[3] = encrypted_len & 0xFF;

    // nonce 1: sha3-256(params || rand)
    sha3_Init256(&nc1);
    sha3_Update(&nc1, params, params_len);
    sha3_Update(&nc1, rand_buf, rand_len);
    nonce1 = sha3_Finalize(&nc1);

    // key 1: sha3-256(nonce1 || shared secret)
    sha3_Init256(&kc1);
    sha3_Update(&kc1, nonce1, nonce1_len);
    sha3_Update(&kc1, shared_secret, shared_secret_len);
    key1 = sha3_Finalize(&kc1);

    // encrypt message data
    encrypted1 = malloc(encrypted1_len);
    if (encrypted1 == NULL) {
        printf("malloc failed.\n");
        goto exit;
    }

    norx_aead_encrypt(encrypted1, &norx_encrypted1_len, params, params_len,
                      data, data_len, NULL, 0, nonce1, key1, tag1_len_bits);
    if (norx_encrypted1_len != encrypted1_len) {
        printf("Norx encryption failed.\n");
        goto exit;
    }

    // full part 1 buffer: rand || encrypted1
    part1 = malloc(part1_len);
    if (part1 == NULL) {
        printf("malloc failed.\n");
        goto exit;
    }
    memcpy(part1, rand_buf, rand_len);
    memcpy(part1 + rand_len, encrypted1, encrypted1_len);

    // nonce 0: sha3-256(rand || encrypted1)
    sha3_Init256(&nc0);
    sha3_Update(&nc0, part1, part1_len);
    nonce0 = sha3_Finalize(&nc0);

    // key 0: sha3-256(nonce0 || shared secret)
    sha3_Init256(&kc0);
    sha3_Update(&kc0, nonce0, nonce0_len);
    sha3_Update(&kc0, shared_secret, shared_secret_len);
    key0 = sha3_Finalize(&kc0);

    // encrypt params (part 0 message)
    norx_aead_encrypt(params_encrypted, &norx_params_encrypted_len, header0,
                      header0_len, params, params_len, NULL, 0, nonce0, key0,
                      tag0_len_bits);
    if (params_encrypted_len != norx_params_encrypted_len) {
        printf("Norx encryption failed.\n");
        goto exit;
    }

    // symmetric case: prepend salt || public key
    if (salt != NULL) {
        encrypted_len += symmetric_header_len;
    }

    encrypted = malloc(encrypted_len);
    if (encrypted == NULL) {
        printf("malloc failed.\n");
        goto exit;
    }

    if (salt != NULL) {
        memcpy(encrypted, salt, 32);
        memcpy(encrypted + 32, pubkey, 32);
        memcpy(encrypted + symmetric_header_len, params_encrypted,
               params_encrypted_len);
        memcpy(encrypted + symmetric_header_len + params_encrypted_len, part1,
               part1_len);
    } else {
        memcpy(encrypted, params_encrypted, params_encrypted_len);
        memcpy(encrypted + params_encrypted_len, part1, part1_len);
    }

    // encode full message in base64
    encrypted_b64 = base64_encode(encrypted, encrypted_len, &encrypted_b64_len);
    if (encrypted_b64 == NULL || encrypted_b64_len == 0 ||
        encrypted_b64_len < encrypted_len) {
        printf("Base64 encoding failed.\n");
        free(encrypted_b64);
        goto exit;
    }

    // transform base64 encoding
    if (base64_transformation) {
        encrypted_b64t =
            b64t_encode(encrypted_b64, encrypted_b64_len, &encrypted_b64t_len);
        memset(encrypted_b64, 0, encrypted_b64_len);
        free(encrypted_b64);
        encrypted_b64 = NULL;
        if (encrypted_b64t == NULL || encrypted_b64t_len == 0) {
            printf("b64t encoding failed.\n");
            free(encrypted_b64t);
            goto exit;
        }
        out = encrypted_b64t;
        *out_encrypted_len = encrypted_b64t_len;
    } else {
        out = encrypted_b64;
        *out_encrypted_len = encrypted_b64_len;
    }

exit:
    if (drng != NULL) {
        drng_chacha20_destroy(drng);
    }

    /* Heap buffers may still be NULL here while their lengths are already
     * set, so every wipe is guarded. */
    if (rand_buf != NULL) {
        memset(rand_buf, 0, rand_len);
    }
    if (encrypted1 != NULL) {
        memset(encrypted1, 0, encrypted1_len);
    }
    if (part1 != NULL) {
        memset(part1, 0, part1_len);
    }
    if (encrypted != NULL) {
        memset(encrypted, 0, encrypted_len);
    }
    memset(params, 0, params_len);
    memset(params_encrypted, 0, params_encrypted_len);
    memset(shared_secret, 0, shared_secret_len);
    memset(header0, 0, header0_len);

    memset(&kc0, 0, sizeof(sha3_context));
    memset(&kc1, 0, sizeof(sha3_context));
    memset(&nc0, 0, sizeof(sha3_context));
    memset(&nc1, 0, sizeof(sha3_context));

    free(rand_buf);
    free(encrypted1);
    free(part1);
    free(encrypted);

    return out;
}

/*
 * Derive a password-based ("symmetric") X25519 key pair.
 *
 * A random 32-byte salt is drawn from the ChaCha20 DRNG into salt_out. The
 * user is asked for a password twice, and the private key is derived as
 * argon2id(password, salt) into seckey_out. pubkey_out receives the matching
 * public key. seckey_out and pubkey_out must hold sizeof(curve25519_key)
 * bytes; salt_out must hold 32 bytes.
 *
 * Returns 0 on success, 1 on failure (a message is printed to stdout).
 */
int get_symmetrickeys(unsigned char *salt_out, unsigned char *seckey_out,
                      unsigned char *pubkey_out) {
    /* NULL-initialized so cleanup never touches an indeterminate handle when
     * drng_chacha20_init() fails. */
    struct chacha20_drng *drng = NULL;

    const uint32_t t_cost = 3;
    const uint32_t m_cost = (1 << 12);
    const uint32_t parallelism = 1;

    char password[256] = {0};
    char *password_out = NULL;
    char password_verif[256] = {0};
    char *password_verif_out = NULL;
    size_t password_len = 0;
    size_t password_verif_len = 0;

    int ret = 0;
    int a2res = 0;
    int exitcode = 1;

    ret = drng_chacha20_init(&drng);
    if (ret) {
        printf("Chacha20 allocation failed: %d\n", ret);
        goto exit;
    }
    if (drng_chacha20_get(drng, salt_out, 32)) {
        printf("Getting random numbers failed\n");
        goto exit;
    }

    password_out = readpassphrase("Password : ", password, sizeof(password), 0);
    if (password_out != password) {
        printf("password input failed.\n");
        goto exit;
    }
    password_len = strlen(password);

    password_verif_out =
        readpassphrase("Verifying, please re-enter : ", password_verif,
                       sizeof(password_verif), 0);
    if (password_verif_out != password_verif) {
        printf("password input failed.\n");
        goto exit;
    }
    password_verif_len = strlen(password_verif);

    if (password_len != password_verif_len) {
        printf("Mismatch.\n");
        goto exit;
    }

    if (memcmp(password, password_verif, password_len) != 0) {
        printf("Mismatch.\n");
        goto exit;
    }

    a2res =
        argon2id_hash_raw(t_cost, m_cost, parallelism, password, password_len,
                          salt_out, 32, seckey_out, sizeof(curve25519_key));
    if (a2res != ARGON2_OK) {
        printf("argon2 failed.\n");
        /* Don't leave whatever argon2 may have written in the caller's key. */
        memset(seckey_out, 0, sizeof(curve25519_key));
        goto exit;
    }

    curve25519_donna_basepoint(pubkey_out, seckey_out);

    exitcode = 0;

exit:
    if (drng != NULL) {
        drng_chacha20_destroy(drng);
    }
    memset(password, 0, sizeof(password));
    memset(password_verif, 0, sizeof(password_verif));

    return exitcode;
}

/*
 * Re-derive the password-based key pair for a message produced by
 * encrypt_data() with a salt.
 *
 * data is the b64t-encoded message. After decoding, its first 32 bytes are
 * the argon2id salt and the next 32 bytes are the expected public key. The
 * user is prompted for the password, and the private key is derived with
 * argon2id. Its public key must match the one embedded in the message.
 *
 * Returns 0 on success, with seckey_out and pubkey_out filled
 * (sizeof(curve25519_key) bytes each). Returns 1 on failure; seckey_out is
 * zeroed if derivation failed or the password was wrong.
 */
int check_get_symmetrickeys(const unsigned char *data, const size_t data_len,
                            unsigned char *seckey_out,
                            unsigned char *pubkey_out) {
    const uint32_t t_cost = 3;
    const uint32_t m_cost = (1 << 12);
    const uint32_t parallelism = 1;
    const size_t salt_len = 32;

    char password[256] = {0};
    char *password_out = NULL;
    size_t password_len = 0;

    unsigned char *data_b64t_decoded = NULL;
    unsigned char *encrypted = NULL;
    size_t data_b64t_decoded_len = 0;
    size_t encrypted_len = 0;

    unsigned char salt[32] = {0};

    curve25519_key pubkey_from_secret;
    int a2res = 0;
    int exitcode = 1;

    memset(pubkey_from_secret, 0, sizeof(curve25519_key));

    data_b64t_decoded = b64t_decode(data, data_len, &data_b64t_decoded_len);
    if (data_b64t_decoded == NULL) {
        printf("b64t decoding failed.\n");
        goto exit;
    }

    encrypted =
        base64_decode(data_b64t_decoded, data_b64t_decoded_len, &encrypted_len);
    if (encrypted == NULL) {
        printf("base64 decoding failed.\n");
        goto exit;
    }

    /* The message must at least carry salt || public key before we read
     * them. */
    if (encrypted_len < salt_len + sizeof(curve25519_key)) {
        printf("Size mismatch.\n");
        goto exit;
    }

    memcpy(salt, encrypted, salt_len);
    memcpy(pubkey_from_secret, encrypted + salt_len, sizeof(curve25519_key));

    password_out = readpassphrase("Password : ", password, sizeof(password), 0);
    if (password_out != password) {
        printf("password input failed.\n");
        goto exit;
    }
    password_len = strlen(password);

    a2res =
        argon2id_hash_raw(t_cost, m_cost, parallelism, password, password_len,
                          salt, salt_len, seckey_out, sizeof(curve25519_key));
    if (a2res != ARGON2_OK) {
        printf("argon2 failed.\n");
        memset(seckey_out, 0, sizeof(curve25519_key));
        goto exit;
    }

    /* A wrong password yields a key whose public key differs from the one
     * stored in the message. */
    curve25519_donna_basepoint(pubkey_out, seckey_out);
    if (memcmp(pubkey_out, pubkey_from_secret, sizeof(curve25519_key)) != 0) {
        printf("Wrong password\n");
        memset(seckey_out, 0, sizeof(curve25519_key));
        memset(pubkey_out, 0, sizeof(curve25519_key));
        goto exit;
    }

    exitcode = 0;

exit:
    memset(password, 0, sizeof(password));
    memset(salt, 0, sizeof(salt));
    memset(pubkey_from_secret, 0, sizeof(curve25519_key));
    if (encrypted != NULL) {
        memset(encrypted, 0, encrypted_len);
    }
    if (data_b64t_decoded != NULL) {
        memset(data_b64t_decoded, 0, data_b64t_decoded_len);
    }

    free(encrypted);
    free(data_b64t_decoded);

    return exitcode;
}

/*
 * Decrypt a message produced by encrypt_data().
 *
 * Layout after b64t + base64 decoding:
 *   [salt(32) || public key(32)]   only when symmetric_flag is non-zero
 *   part0: NORX(params) + tag0     9 bytes; params = rand_len (4 bytes, big
 *                                  endian) || tag1_len (1 byte); associated
 *                                  data is the big-endian length of
 *                                  part0 || part1
 *   part1: rand || NORX(message) + tag1, with params as associated data
 *
 *   shared = X25519(seckey, pubkey)
 *   key0   = sha3-256(sha3-256(part1) || shared)
 *   key1   = sha3-256(sha3-256(params || rand) || shared)
 *
 * All lengths read from the (attacker-controllable) message are validated
 * before use.
 *
 * Returns a malloc'ed clear text buffer that the caller frees, and stores
 * its length in *data_len_out. On any failure, including authentication
 * failure, returns NULL and sets *data_len_out to 0.
 */
unsigned char *decrypt_data(const unsigned char *seckey,
                            const unsigned char *pubkey,
                            const unsigned char *data, size_t data_len,
                            int symmetric_flag, size_t *data_len_out) {
    unsigned char *data_b64t_decoded = NULL;
    unsigned char *decoded = NULL;     /* owned buffer from base64_decode() */
    unsigned char *encrypted = NULL;   /* view into decoded past any prefix */
    unsigned char *part1 = NULL;       /* view: rand || encrypted1 */
    unsigned char *rand1 = NULL;       /* view: random prefix of part1 */
    unsigned char *encrypted1 = NULL;  /* view: NORX ciphertext + tag1 */
    unsigned char *message = NULL;

    const size_t shared_secret_len = 32;
    unsigned char shared_secret[32] = {0};

    size_t data_b64t_decoded_len = 0;
    size_t decoded_len = 0;
    size_t encrypted_len = 0;
    size_t part1_len = 0;
    size_t rand_len = 0;
    size_t tag1_len = 0;
    size_t tag1_len_bits = 0;
    size_t encrypted1_len = 0;
    size_t message_len = 0;
    size_t norx_message_len = 0;
    size_t norx_params_len = 0;

    const size_t symmetric_header_len = 64;  // salt (32) + public key (32)
    const size_t tag1_len_min = 4;
    const size_t rand_len_min = 4;
    const size_t nonce1_len = 32;
    const size_t tag0_len = 4;
    const size_t tag0_len_bits = tag0_len * 8;
    const size_t params_len = 5;
    const size_t part0_len = 9;  // params_len + tag0_len
    const size_t nonce0_len = 32;
    const size_t header0_len = 4;

    unsigned char header0[4] = {0};
    unsigned char encrypted0[9] = {0};
    unsigned char params[5] = {0};
    const uint8_t *nonce0 = NULL;
    const uint8_t *key0 = NULL;
    const uint8_t *nonce1 = NULL;
    const uint8_t *key1 = NULL;

    /* Declared up front so the cleanup path can always wipe them. */
    sha3_context nc0, kc0, nc1, kc1;

    int n0 = 0;
    int n1 = 0;
    int success = 0;

    *data_len_out = 0;

    // shared secret DH X25519
    curve25519_donna(shared_secret, seckey, pubkey);

    // undo the base64 transformation
    data_b64t_decoded = b64t_decode(data, data_len, &data_b64t_decoded_len);
    if (data_b64t_decoded == NULL) {
        printf("b64t decoding failed.\n");
        goto exit;
    }

    // decode base64
    decoded =
        base64_decode(data_b64t_decoded, data_b64t_decoded_len, &decoded_len);
    if (decoded == NULL) {
        printf("base64 decoding failed.\n");
        goto exit;
    }

    /* Skip the salt/public-key prefix through a view; `decoded` keeps
     * ownership so cleanup always frees the right pointer. */
    encrypted = decoded;
    encrypted_len = decoded_len;
    if (symmetric_flag) {
        if (encrypted_len < symmetric_header_len) {
            printf("Size mismatch.\n");
            goto exit;
        }
        encrypted += symmetric_header_len;
        encrypted_len -= symmetric_header_len;
    }
    if (encrypted_len < part0_len + tag1_len_min + rand_len_min) {
        printf("Size mismatch.\n");
        goto exit;
    }

    part1 = encrypted + part0_len;
    part1_len = encrypted_len - part0_len;

    // part 0 associated data: big-endian length of part0 || part1
    header0[0] = (encrypted_len >> 24) & 0xFF;
    header0[1] = (encrypted_len >> 16) & 0xFF;
    header0[2] = (encrypted_len >> 8) & 0xFF;
    header0[3] = encrypted_len & 0xFF;

    // nonce 0: sha3-256(part1)
    sha3_Init256(&nc0);
    sha3_Update(&nc0, part1, part1_len);
    nonce0 = sha3_Finalize(&nc0);

    // key 0: sha3-256(nonce0 || shared secret)
    sha3_Init256(&kc0);
    sha3_Update(&kc0, nonce0, nonce0_len);
    sha3_Update(&kc0, shared_secret, shared_secret_len);
    key0 = sha3_Finalize(&kc0);

    // decrypt part 0 (params)
    memcpy(encrypted0, encrypted, part0_len);

    n0 = norx_aead_decrypt(params, &norx_params_len, header0, header0_len,
                           encrypted0, part0_len, NULL, 0, nonce0, key0,
                           tag0_len_bits);
    if (n0 != 0 || norx_params_len != params_len) {
        printf("Norx0 decryption failed.\n");
        goto exit;
    }

    // part 1 encryption params; widen before shifting to avoid signed overflow
    rand_len = ((size_t)params[0] << 24) | ((size_t)params[1] << 16) |
               ((size_t)params[2] << 8) | (size_t)params[3];
    tag1_len = params[4];
    tag1_len_bits = tag1_len * 8;

    if (tag1_len != 4 && tag1_len != 8 && tag1_len != 16 && tag1_len != 24 &&
        tag1_len != 32) {
        printf("bad auth tag len.\n");
        goto exit;
    }
    /* Check tag1_len first so part1_len - tag1_len cannot wrap around. */
    if (tag1_len > part1_len || rand_len > part1_len - tag1_len ||
        rand_len < rand_len_min) {
        printf("size mismatch.\n");
        goto exit;
    }

    rand1 = part1;

    // nonce 1: sha3-256(params || rand)
    sha3_Init256(&nc1);
    sha3_Update(&nc1, params, params_len);
    sha3_Update(&nc1, rand1, rand_len);
    nonce1 = sha3_Finalize(&nc1);

    // key 1: sha3-256(nonce1 || shared secret)
    sha3_Init256(&kc1);
    sha3_Update(&kc1, nonce1, nonce1_len);
    sha3_Update(&kc1, shared_secret, shared_secret_len);
    key1 = sha3_Finalize(&kc1);

    // encrypted message follows the random data
    encrypted1 = part1 + rand_len;
    encrypted1_len = part1_len - rand_len;

    message_len = encrypted1_len - tag1_len;
    /* malloc(0) may legitimately return NULL; allocate at least one byte so
     * an empty message is not reported as an allocation failure. */
    message = malloc(message_len > 0 ? message_len : 1);
    if (message == NULL) {
        printf("malloc failed.\n");
        goto exit;
    }

    n1 = norx_aead_decrypt(message, &norx_message_len, params, params_len,
                           encrypted1, encrypted1_len, NULL, 0, nonce1, key1,
                           tag1_len_bits);
    if (n1 != 0 || norx_message_len != message_len) {
        printf("Norx1 decryption failed.\n");
        goto exit;
    }

    *data_len_out = message_len;
    success = 1;

exit:
    /* Never hand unauthenticated clear text back to the caller. */
    if (!success && message != NULL) {
        memset(message, 0, message_len);
        free(message);
        message = NULL;
    }

    if (data_b64t_decoded != NULL) {
        memset(data_b64t_decoded, 0, data_b64t_decoded_len);
    }
    if (decoded != NULL) {
        memset(decoded, 0, decoded_len);
    }
    memset(shared_secret, 0, shared_secret_len);
    memset(encrypted0, 0, part0_len);
    memset(params, 0, params_len);
    memset(header0, 0, header0_len);

    memset(&kc0, 0, sizeof(sha3_context));
    memset(&kc1, 0, sizeof(sha3_context));
    memset(&nc0, 0, sizeof(sha3_context));
    memset(&nc1, 0, sizeof(sha3_context));

    free(data_b64t_decoded);
    free(decoded);

    return message;
}