#include <stdio.h>
#include <string.h>
#include <stdlib.h>

#include <criterion/criterion.h>
#include <criterion/new/assert.h>

#include "libskycoin.h"
#include "skyerrors.h"
#include "skystring.h"
#include "skytest.h"
#include "base64.h"

/* Capacity, in bytes, of the scratch buffer that receives derived keys. */
#define BUFFER_SIZE 1024

/* Register the scrypt suite; setup/teardown are presumably provided by skytest.h — confirm. */
TestSuite(cipher_scrypt, .fini = teardown, .init = setup);

/*
 * A single scrypt test case.
 *
 *   password, salt  - input byte strings handed to SKY_scrypt_Key.
 *   N, r, p         - scrypt parameters forwarded unchanged to SKY_scrypt_Key.
 *   output          - expected derived key; NULL for vectors that must fail.
 *   keyLength       - number of key bytes to request.
 *   passwordLength  - explicit password length in bytes; a negative value
 *                     means "use strlen(password)" (needed when the input
 *                     contains embedded NUL bytes).
 *   saltLength      - same convention as passwordLength, for salt.
 */
typedef struct {
	char    *password;
	char    *salt;
	GoInt   N;
	GoInt   r;
	GoInt   p;
	GoUint8 *output;
	GoInt   keyLength;
	GoInt   passwordLength;
	GoInt   saltLength;
} TESTVECTOR;

Test(cipher_scrypt, TestKey){
	GoUint8 g1[] = {
			0x48, 0x2c, 0x85, 0x8e, 0x22, 0x90, 0x55, 0xe6, 0x2f,
			0x41, 0xe0, 0xec, 0x81, 0x9a, 0x5e, 0xe1, 0x8b, 0xdb,
			0x87, 0x25, 0x1a, 0x53, 0x4f, 0x75, 0xac, 0xd9, 0x5a,
			0xc5, 0xe5, 0xa, 0xa1, 0x5f,
		};
	GoUint8 g2[] = {
			0x88, 0xbd, 0x5e, 0xdb, 0x52, 0xd1, 0xdd, 0x0, 0x18,
			0x87, 0x72, 0xad, 0x36, 0x17, 0x12, 0x90, 0x22, 0x4e,
			0x74, 0x82, 0x95, 0x25, 0xb1, 0x8d, 0x73, 0x23, 0xa5,
			0x7f, 0x91, 0x96, 0x3c, 0x37,
		};
	GoUint8 g3[] = {
			0xc3, 0xf1, 0x82, 0xee, 0x2d, 0xec, 0x84, 0x6e, 0x70,
			0xa6, 0x94, 0x2f, 0xb5, 0x29, 0x98, 0x5a, 0x3a, 0x09,
			0x76, 0x5e, 0xf0, 0x4c, 0x61, 0x29, 0x23, 0xb1, 0x7f,
			0x18, 0x55, 0x5a, 0x37, 0x07, 0x6d, 0xeb, 0x2b, 0x98,
			0x30, 0xd6, 0x9d, 0xe5, 0x49, 0x26, 0x51, 0xe4, 0x50,
			0x6a, 0xe5, 0x77, 0x6d, 0x96, 0xd4, 0x0f, 0x67, 0xaa,
			0xee, 0x37, 0xe1, 0x77, 0x7b, 0x8a, 0xd5, 0xc3, 0x11,
			0x14, 0x32, 0xbb, 0x3b, 0x6f, 0x7e, 0x12, 0x64, 0x40,
			0x18, 0x79, 0xe6, 0x41, 0xae,
		};
	GoUint8 g4[] = {
			0x48, 0xb0, 0xd2, 0xa8, 0xa3, 0x27, 0x26, 0x11, 0x98,
			0x4c, 0x50, 0xeb, 0xd6, 0x30, 0xaf, 0x52,
		};
	GoUint8 g5[] = {
			0x77, 0xd6, 0x57, 0x62, 0x38, 0x65, 0x7b, 0x20, 0x3b,
			0x19, 0xca, 0x42, 0xc1, 0x8a, 0x04, 0x97, 0xf1, 0x6b,
			0x48, 0x44, 0xe3, 0x07, 0x4a, 0xe8, 0xdf, 0xdf, 0xfa,
			0x3f, 0xed, 0xe2, 0x14, 0x42, 0xfc, 0xd0, 0x06, 0x9d,
			0xed, 0x09, 0x48, 0xf8, 0x32, 0x6a, 0x75, 0x3a, 0x0f,
			0xc8, 0x1f, 0x17, 0xe8, 0xd3, 0xe0, 0xfb, 0x2e, 0x0d,
			0x36, 0x28, 0xcf, 0x35, 0xe2, 0x0c, 0x38, 0xd1, 0x89,
			0x06,
		};
	GoUint8 g6[] = {
			0xfd, 0xba, 0xbe, 0x1c, 0x9d, 0x34, 0x72, 0x00, 0x78,
			0x56, 0xe7, 0x19, 0x0d, 0x01, 0xe9, 0xfe, 0x7c, 0x6a,
			0xd7, 0xcb, 0xc8, 0x23, 0x78, 0x30, 0xe7, 0x73, 0x76,
			0x63, 0x4b, 0x37, 0x31, 0x62, 0x2e, 0xaf, 0x30, 0xd9,
			0x2e, 0x22, 0xa3, 0x88, 0x6f, 0xf1, 0x09, 0x27, 0x9d,
			0x98, 0x30, 0xda, 0xc7, 0x27, 0xaf, 0xb9, 0x4a, 0x83,
			0xee, 0x6d, 0x83, 0x60, 0xcb, 0xdf, 0xa2, 0xcc, 0x06,
			0x40,
		};
	GoUint8 g7[] = {
			0x70, 0x23, 0xbd, 0xcb, 0x3a, 0xfd, 0x73, 0x48, 0x46,
			0x1c, 0x06, 0xcd, 0x81, 0xfd, 0x38, 0xeb, 0xfd, 0xa8,
			0xfb, 0xba, 0x90, 0x4f, 0x8e, 0x3e, 0xa9, 0xb5, 0x43,
			0xf6, 0x54, 0x5d, 0xa1, 0xf2, 0xd5, 0x43, 0x29, 0x55,
			0x61, 0x3f, 0x0f, 0xcf, 0x62, 0xd4, 0x97, 0x05, 0x24,
			0x2a, 0x9a, 0xf9, 0xe6, 0x1e, 0x85, 0xdc, 0x0d, 0x65,
			0x1e, 0x40, 0xdf, 0xcf, 0x01, 0x7b, 0x45, 0x57, 0x58,
			0x87,
		};
	TESTVECTOR good_ones[] = {
		{"password", "salt", 2, 10, 10, g1, sizeof(g1) / sizeof(GoUint8), -1, -1},
		{"password", "salt", 16, 100, 100, g2, sizeof(g2) / sizeof(GoUint8), -1, -1}, 
		{"this is a long \000 password", 
		 "and this is a long \000 salt",
					16384, 8, 1, g3, sizeof(g3) / sizeof(GoUint8), 25, 25},
		{"p", "s", 2, 1, 1, g4, sizeof(g4) / sizeof(GoUint8), -1, -1},
		{"", "", 16, 1, 1, g5, sizeof(g5) / sizeof(GoUint8), -1, -1},
		{"password", "NaCl", 1024, 8, 16, g6, sizeof(g6) / sizeof(GoUint8), -1, -1},
		{"pleaseletmein", "SodiumChloride", 16384, 8, 1, g7, sizeof(g7) / sizeof(GoUint8), -1, -1},
	};
	
	GoInt32 maxInt = (GoInt32)(~((GoUint32)0) >> 1);
	
	TESTVECTOR bad_ones[] = {
		{"p", "s", 0, 1, 1, NULL, -1, -1},                    // N == 0
		{"p", "s", 1, 1, 1, NULL, -1, -1},                    // N == 1
		{"p", "s", 7, 8, 1, NULL, -1, -1},                    // N is not power of 2
		{"p", "s", 16, maxInt / 2, maxInt / 2, NULL, -1, -1}, // p * r too large
	};
	
	GoUint32 errcode;
	GoSlice password;
	GoSlice salt;
	GoUint8 buffer[BUFFER_SIZE];
	coin__UxArray key = {buffer, 0, BUFFER_SIZE};
	
	int good_ones_count = sizeof(good_ones) / sizeof(good_ones[0]);
	int bad_ones_count = sizeof(bad_ones) / sizeof(bad_ones[0]);
	for(int i = 0; i < good_ones_count; i++){
		password.data = good_ones[i].password;
		if( good_ones[i].passwordLength < 0)
			password.len = strlen(good_ones[i].password);
		else
			password.len = good_ones[i].passwordLength;
		password.cap = password.len;
		
		salt.data = good_ones[i].salt;
		if( good_ones[i].saltLength < 0)
			salt.len = strlen(good_ones[i].salt);
		else
			salt.len = good_ones[i].saltLength;
		salt.cap = salt.len;
		
		errcode = SKY_scrypt_Key(password, salt, 
			good_ones[i].N, good_ones[i].r, good_ones[i].p, 
			good_ones[i].keyLength, &key);
		cr_assert(errcode == SKY_OK, "SKY_scrypt_Key failed");
		cr_assert(good_ones[i].keyLength == key.len, "SKY_scrypt_Key failed, incorrect generated key length.");
		int equal = 1;
		for(int j = 0; j < key.len; j++){
			if( ((GoUint8*)key.data)[j] != good_ones[i].output[j]){
				equal = 0;
			}
		}
		cr_assert(equal == 1, "SKY_scrypt_Key failed. Invalid key generated.");
	}
	
	for(int i = 0; i < bad_ones_count; i++){
		password.data = bad_ones[i].password;
		if( bad_ones[i].passwordLength < 0)
			password.len = strlen(bad_ones[i].password);
		else
			password.len = bad_ones[i].passwordLength;
		password.cap = password.len;
		
		salt.data = bad_ones[i].salt;
		if( bad_ones[i].saltLength < 0)
			salt.len = strlen(bad_ones[i].salt);
		else
			salt.len = bad_ones[i].saltLength;
		salt.cap = salt.len;
		
		errcode = SKY_scrypt_Key(password, salt, 
			bad_ones[i].N, bad_ones[i].r, bad_ones[i].p, 
			bad_ones[i].keyLength, &key);
		cr_assert(errcode == SKY_ERROR, "SKY_scrypt_Key didn\'t failed");
	}
}