/* Inexact source code package.
 *
 * Written in 2019 by <ben@hackade.org>.
 *
 * To the extent possible under law, the author have dedicated all copyright
 * and related and neighboring rights to this software to the public domain
 * worldwide. This software is distributed without any warranty.
 *
 * You should have received a copy of the CC0 Public Domain Dedication along with
 * this software. If not, see <http://creativecommons.org/publicdomain/zero/1.0/>.
 */

#include <errno.h>
#include <getopt.h>
#include <limits.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include "argtable3.h"
#include "inexact.h"
#include "tests.h"

/* Option descriptors used by main(). They are file-scope so that helpers in
 * this file can inspect the parsed options; main() assigns each one while
 * building its argtable. */

/* flag (literal) options */
struct arg_lit *help;
struct arg_lit *version;
struct arg_lit *gen;
struct arg_lit *dencrypt;
struct arg_lit *ddecrypt;
struct arg_lit *test;
struct arg_lit *base64;
struct arg_lit *weak;
struct arg_lit *nopassword;
struct arg_lit *symmetric;

/* file-name options */
struct arg_file *seckey;
struct arg_file *pubkey;
struct arg_file *infile;
struct arg_file *outfile;

/* integer options */
struct arg_int *taglen;
struct arg_int *noncelen;
struct arg_int *cipherlen;

/* argtable terminator, also collects parse errors */
struct arg_end *end;

int main(int argc, char *argv[]) {
    /* the global arg_xxx structs are initialised within the argtable */
    void *argtable[] = {
        help        = arg_litn("h", "help", 0, 1, "display this help and exit"),
        version     = arg_litn("v", "version", 0, 1, "display version info and exit"),
        gen         = arg_litn("g", "genkeys", 0, 1, "generate keys"),
        dencrypt    = arg_litn("e", "encrypt", 0, 1, "encrypt data"),
        ddecrypt    = arg_litn("d", "decrypt", 0, 1, "decrypt data"),
        symmetric   = arg_litn("s", "symmetric", 0, 1, "symmetric encryption with password"),
        seckey      = arg_filen("k", "seckey", "secretkey", 0, 1, "secret key file"),
        pubkey      = arg_filen("p", "pubkey", "publickey", 0, 1, "public key iles"),
        taglen      = arg_intn("t", "taglen", "<64,128,192,256>", 0, 1, "authentication message tag length in bits (default: 256)"),
        noncelen    = arg_intn("n", "noncelen", "<n>", 0, 1, "random nonce length in bytes (default: 32, must be >= 16)"),
        cipherlen   = arg_intn("c", "cipherlen", "<n>", 0, 1, "set random nonce length for <n> bytes output ciphertext size"),
        base64      = arg_litn(NULL, "base64", 0, 1, "use base64 format without transformation"),
        test        = arg_litn(NULL, "test", 0, 1, "test crypto and encoding internal functions"),
        nopassword  = arg_litn(NULL, "no-password", 0, 1, "generate secret key without password"),
        weak        = arg_litn("w", "weak", 0, 1, "use weak length for nonce and auth tag (-n 4 -t 32)"),
        infile      = arg_filen("i", "input-file", "<infile>", 0, 1, "input file (default: stdin)"),
        outfile     = arg_filen("o", "output-file", "<outfile>", 0, 1, "output file (default: stdout)"),
        end = arg_end(20),
    };

    int exitcode = 0;
    const char progname[] = "inexact";
    const char ver[] = "beta 1.0";
    FILE *fo = NULL;

    int nerrors;
    nerrors = arg_parse(argc, argv, argtable);

    /* special case: '--help' takes precedence over error reporting */
    if (help->count > 0) {
        printf("Usage: %s", progname);
        arg_print_syntax(stdout, argtable, "\n");
        printf(
            "INadvisable EXperimental Asymmetric Crypto Tool, by "
            "<ben@hackade.org>.\n\n");
        arg_print_glossary(stdout, argtable, "  %-25s %s\n");
        exitcode = 0;
        goto exit;
    }

    /* If the parser returned any errors then display them and exit */
    if (nerrors > 0) {
        /* Display the error details contained in the arg_end struct.*/
        arg_print_errors(stdout, end, progname);
        printf("Try '%s --help' for more information.\n", progname);
        exitcode = 1;
        goto exit;
    }

    /* check if an action is specified */
    int action_count = gen->count + ddecrypt->count + dencrypt->count +
                       version->count + help->count + test->count;
    if (action_count == 0) {
        printf("Missing parameters.\n");
        printf("Try '%s --help' for more information.\n", progname);
        exitcode = 1;
        goto exit;
    }

    if (version->count == 1) {
        printf("version: %s\n", ver);
        exitcode = 0;
        goto exit;
    }

    if (test->count > 0) {
        exitcode = test_all();
        goto exit;
    }

    /* check if more than one action is specified */
    if (action_count > 1) {
        printf("Invalid options.\n");
        printf("Try '%s --help' for more information.\n", progname);
        exitcode = 1;
        goto exit;
    }

    if (symmetric->count == 1) {
        if (gen->count == 1) {
            printf("Invalid options.\n");
            printf("Try '%s --help' for more information.\n", progname);
            exitcode = 1;
            goto exit;
        }
        if (pubkey->count == 1) {
            printf("Invalid options.\n");
            printf("Try '%s --help' for more information.\n", progname);
            exitcode = 1;
            goto exit;
        }
        if (seckey->count == 1) {
            printf("Invalid options.\n");
            printf("Try '%s --help' for more information.\n", progname);
            exitcode = 1;
            goto exit;
        }
        if (nopassword->count == 1) {
            printf("Invalid options.\n");
            printf("Try '%s --help' for more information.\n", progname);
            exitcode = 1;
            goto exit;
        }
    }

    /* check if a secret key is specified */
    if (seckey->count == 0 && symmetric->count == 0) {
        printf("Missing secret key file operand.\n");
        printf("Try '%s --help' for more information.\n", progname);
        exitcode = 1;
        goto exit;
    }

    /* check if public key is specified */
    if (pubkey->count == 0 && symmetric->count == 0) {
        printf("Missing public key file operand.\n");
        printf("Try '%s --help' for more information.\n", progname);
        exitcode = 1;
        goto exit;
    }

    int nopassword_flag = (nopassword->count == 1);

    /* generate action */
    if (gen->count == 1) {
        if (base64->count == 1) {
            printf("Invalid options.\n");
            printf("Try '%s --help' for more information.\n", progname);
            exitcode = 1;
            goto exit;
        }
        if (infile->count == 1) {
            printf("Invalid options.\n");
            printf("Try '%s --help' for more information.\n", progname);
            exitcode = 1;
            goto exit;
        }
        if (outfile->count == 1) {
            printf("Invalid options.\n");
            printf("Try '%s --help' for more information.\n", progname);
            exitcode = 1;
            goto exit;
        }
        if (taglen->count == 1) {
            printf("Invalid options.\n");
            printf("Try '%s --help' for more information.\n", progname);
            exitcode = 1;
            goto exit;
        }
        if (noncelen->count == 1) {
            printf("Invalid options.\n");
            printf("Try '%s --help' for more information.\n", progname);
            exitcode = 1;
            goto exit;
        }
        if (access(seckey->filename[0], F_OK) != -1) {
            char ch;
            printf("Overwrite '%s' ? ", seckey->filename[0]);
            int res = scanf("%c", &ch);
            if ((ch != 'Y' && ch != 'y') || res == 0) {
                exitcode = 1;
                goto exit;
            }
        }
        if (access(pubkey->filename[0], F_OK) != -1) {
            char ch;
            printf("Overwrite '%s' ? ", pubkey->filename[0]);
            int res = scanf(" %c", &ch);
            if ((ch != 'Y' && ch != 'y') || res == 0) {
                exitcode = 1;
                goto exit;
            }
        }

        return generate_keys(seckey->filename[0], pubkey->filename[0],
                             nopassword_flag);
    }

    /* read input infile data */
    unsigned char *data = NULL;
    int data_len = 0;

    if (infile->count == 0) {
        /* read stdin */
        int cap = 4096, len = 0;
        data = malloc(cap * sizeof(unsigned char));
        if (data == NULL) {
            printf("malloc %d bytes failed.\n", cap);
            exitcode = 1;
            goto exit;
        }
        int c = 0;
        do {
            c = fgetc(stdin);
            data[len] = c;
            if (++len == cap) {
                cap *= 2;
                data = realloc(data, cap * sizeof(unsigned char));
                if (data == NULL) {
                    printf("realloc %d bytes failed.\n", cap);
                    exitcode = 1;
                    goto exit;
                }
            }
        } while (!feof(stdin));
        fclose(stdin);
        data = realloc(data, len * sizeof(unsigned char));
        if (data == NULL) {
            printf("realloc %d bytes failed.\n", cap);
            exitcode = 1;
            goto exit;
        }
        data[len - 1] = '\0';
        data_len = len - 1;
    } else if (access(infile->filename[0], F_OK) == -1) {
        printf("Input file '%s' not found.\n", infile->filename[0]);
        exitcode = 1;
        goto exit;
    } else {
        FILE *fi = fopen(infile->filename[0], "rb");
        if (fi == NULL) {
            printf("open file '%s' failed: %s.\n", infile->filename[0],
                   strerror(errno));
            exitcode = 1;
            goto exit;
        }
        int rsb = fseek(fi, 0L, SEEK_END);
        if (rsb != 0) {
            printf("seek to end file '%s' failed: %s.\n", infile->filename[0],
                   strerror(errno));
            fclose(fi);
            exitcode = 1;
            goto exit;
        }
        data_len = ftell(fi);
        if (data_len == -1) {
            printf("tell file '%s' failed: %s.\n", infile->filename[0],
                   strerror(errno));
            fclose(fi);
            exitcode = 1;
            goto exit;
        }
        int rse = fseek(fi, 0L, SEEK_SET);
        if (rse != 0) {
            printf("seek file to begin'%s' failed: %s.\n", infile->filename[0],
                   strerror(errno));
            fclose(fi);
            exitcode = 1;
            goto exit;
        }
        data = malloc(data_len);
        if (data == NULL) {
            printf("malloc %d bytes failed.\n", data_len);
            fclose(fi);
            exitcode = 1;
            goto exit;
        }
        size_t readed = fread(data, 1, data_len, fi);
        if (readed != data_len) {
            printf("read file '%s' failed: %s.\n", infile->filename[0],
                   strerror(errno));
            fclose(fi);
            exitcode = 1;
            goto exit;
        }
        fclose(fi);
    }

    /* check size limitation for input data (32 bits) */
    if (data_len > 0xbfffffff) {
        printf("Size of data too big.\n");
        exitcode = 1;
        goto exit;
    }

    int auth_tag_len = 32;
    if (taglen->count == 1) {
        if (ddecrypt->count == 1) {
            printf("Invalid options.\n");
            printf("Try '%s --help' for more information.\n", progname);
            exitcode = 1;
            goto exit;
        }
        if (dencrypt->count == 0) {
            printf("Invalid options.\n");
            printf("Try '%s --help' for more information.\n", progname);
            exitcode = 1;
            goto exit;
        }
        int tag_len_value = taglen->ival[0];
        if (tag_len_value != 64 && tag_len_value != 128 &&
            tag_len_value != 192 && tag_len_value != 256) {
            printf("Invalid options.\n");
            printf("Try '%s --help' for more information.\n", progname);
            exitcode = 1;
            goto exit;
        } else {
            auth_tag_len = tag_len_value / 8;
        }
    }

    /* output file */
    if (outfile->count == 0) {
        fo = stdout;
    } else {
        fo = fopen(outfile->filename[0], "wb");
        if (fo == NULL) {
            printf("open file '%s' failed: %s.\n", outfile->filename[0],
                   strerror(errno));
            exitcode = 1;
            goto exit;
        }
    }

    const int rand_nonce_len_min = 16;
    int rand_nonce_len = 32;
    if (noncelen->count == 1) {
        if (ddecrypt->count == 1) {
            printf("Invalid options.\n");
            printf("Try '%s --help' for more information.\n", progname);
            exitcode = 1;
            goto exit;
        }
        if (dencrypt->count == 0) {
            printf("Invalid options.\n");
            printf("Try '%s --help' for more information.\n", progname);
            exitcode = 1;
            goto exit;
        }
        if (cipherlen->count == 1) {
            printf("Invalid options.\n");
            printf("Try '%s --help' for more information.\n", progname);
            exitcode = 1;
            goto exit;
        }
        int noncelen_value = noncelen->ival[0];
        if (noncelen_value < rand_nonce_len_min) {
            printf("Invalid options.\n");
            printf("Try '%s --help' for more information.\n", progname);
            exitcode = 1;
            goto exit;
        } else {
            rand_nonce_len = noncelen_value;
        }
    }

    int base64_transformation = (base64->count == 0);
    if (cipherlen->count == 1) {
        if (dencrypt->count == 0 || noncelen->count == 1 || base64->count == 1) {
            printf("Invalid options.\n");
            printf("Try '%s --help' for more information.\n", progname);
            exitcode = 1;
            goto exit;
        }
        size_t cipherlen_b64_value = cipherlen->ival[0];
        size_t cipherlen_value = cipherlen_b64_value * 3 / 4;

        size_t part0_len = 9;
        size_t part1_len = data_len + auth_tag_len + rand_nonce_len;
        size_t total_encrypted_len = part0_len + part1_len;
        size_t total_encrypted_len_without_rand1 =
            total_encrypted_len - rand_nonce_len;

        if (symmetric->count == 1) {
            total_encrypted_len_without_rand1 = 
                total_encrypted_len_without_rand1 + 64;
        }

        size_t rand_nonce_len_needed =
            cipherlen_value - total_encrypted_len_without_rand1;

        if (total_encrypted_len_without_rand1 > cipherlen_value ||
            rand_nonce_len_needed < rand_nonce_len_min) {
            unsigned long total_cipherlen_b64_min =
                (total_encrypted_len_without_rand1 + rand_nonce_len_min) * 4 /
                3;
            printf("Insufficient nonce length, cipherlen must be > %lu\n",
                   total_cipherlen_b64_min);
            exitcode = 2;
            goto exit;
        } else {
            rand_nonce_len = rand_nonce_len_needed;
        }
    }

    if (weak->count == 1) {
        if (dencrypt->count == 0 || noncelen->count == 1 || taglen->count == 1 ||
            cipherlen->count == 1) {
            printf("Invalid options.\n");
            printf("Try '%s --help' for more information.\n", progname);
            exitcode = 1;
            goto exit;
        }
        rand_nonce_len = 4;
        auth_tag_len = 4;
    }

    if (dencrypt->count == 1) {
        unsigned char secretkey[32] = {0};
        unsigned char publickey[32] = {0};
        unsigned char salt[32] = {0};
        unsigned char *psalt = NULL;

        if (symmetric->count == 1) {
            if (get_symmetrickeys(salt, secretkey, publickey) != 0) {
                printf("Symmetric keys generation failed.\n");
                exitcode = 1;
                goto exit;
            }
            psalt = &salt[0];
        } else {
            if (access(seckey->filename[0], F_OK) == -1) {
                printf("Secret key file '%s' not found.\n",
                       seckey->filename[0]);
                exitcode = 1;
                goto exit;
            }
            if (get_seckey(seckey->filename[0], secretkey, NULL) == 1) {
                printf("File '%s': invalid key.\n", seckey->filename[0]);
                exitcode = 1;
                goto exit;
            }
            if (access(pubkey->filename[0], F_OK) == -1) {
                printf("Public key file '%s' not found.\n",
                       pubkey->filename[0]);
                exitcode = 1;
                goto exit;
            }
            if (get_pubkey(pubkey->filename[0], publickey) == 1) {
                printf("File '%s': invalid key.\n", pubkey->filename[0]);
                exitcode = 1;
                goto exit;
            }
        }
        size_t encrypted_len = 0;
        unsigned char *encrypted = encrypt_data(
            secretkey, publickey, psalt, data, data_len, rand_nonce_len,
            auth_tag_len, base64_transformation, &encrypted_len);
        if (encrypted == NULL) {
            printf("Encryption failed.\n");
            exitcode = 1;
            goto exit;
        }
        fwrite(encrypted, 1, encrypted_len, fo);
        if (base64_transformation) {
            fwrite("\n", 1, 1, fo);
        }
        fflush(fo);
        memset(encrypted, 0, encrypted_len);
        memset(secretkey, 0, 32);
        memset(publickey, 0, 32);
        memset(data, 0, data_len);
        free(encrypted);
    }

    if (ddecrypt->count == 1) {
        unsigned char secretkey[32] = {0};
        unsigned char publickey[32] = {0};

        if (base64->count == 1) {
            printf("Invalid options.\n");
            printf("Try '%s --help' for more information.\n", progname);
            exitcode = 1;
            goto exit;
        }
        if (symmetric->count == 1) {
            if (check_get_symmetrickeys(data, data_len, secretkey, publickey) !=
                0) {
                exitcode = 1;
                goto exit;
            }
        } else {
            if (access(seckey->filename[0], F_OK) == -1) {
                printf("Secret key file '%s' not found.\n",
                       seckey->filename[0]);
                exitcode = 1;
                goto exit;
            }
            if (get_seckey(seckey->filename[0], secretkey, NULL) == 1) {
                printf("File '%s': invalid key.\n", seckey->filename[0]);
                exitcode = 1;
                goto exit;
            }
            if (access(pubkey->filename[0], F_OK) == -1) {
                printf("Public key file '%s' not found.\n",
                       pubkey->filename[0]);
                exitcode = 1;
                goto exit;
            }
            if (get_pubkey(pubkey->filename[0], publickey) == 1) {
                printf("File '%s': invalid key.\n", pubkey->filename[0]);
                exitcode = 1;
                goto exit;
            }
        }

        size_t decrypted_len = 0;
        unsigned char *decrypted =
            decrypt_data(secretkey, publickey, data, data_len,
                         (symmetric->count == 1), &decrypted_len);

        if (decrypted == NULL || decrypted_len == 0) {
            printf("Decryption failed.\n");
            exitcode = 1;
            goto exit;
        }
        fwrite(decrypted, 1, decrypted_len, fo);
        fflush(fo);
        memset(decrypted, 0, decrypted_len);
        memset(secretkey, 0, 32);
        memset(publickey, 0, 32);
        memset(data, 0, data_len);
        free(decrypted);
    }

exit:
    if (fo != NULL) fclose(fo);

    /* deallocate each non-null entry in argtable[] */
    arg_freetable(argtable, sizeof(argtable) / sizeof(argtable[0]));
    return exitcode;
}