#include <iostream>
#include <string>
#include <vector>
#include <ctime>
#include <iomanip>
#include <cassert>

#include "botan_all.h"

#include "cppcodec/base32_default_crockford.hpp"

#include "crypto.h"

// Built-in key material: init_crypto() hashes this string with SHA-256 when
// the caller passes an empty key.
#define DEFKEY "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"

// Botan algorithm name passed to get_aead() (AEAD build) or
// StreamCipher::create() (stream-cipher build).
#if defined USE_AEAD_MODE
#define DEFALGO "ChaCha20Poly1305"
#else
#define DEFALGO "ChaCha"
#endif

// 8-byte fixed IV storage, filled in by init_crypto(). `giv` is a read-only
// alias of `_giv` used by the encrypt/decrypt paths.
#if defined FIXED_IV || defined MAYBE_FIXED_IV
static std::vector<uint8_t> _giv(8);
static const std::vector<uint8_t>& giv = _giv;
#endif

// Defined in another translation unit; when true, generate_address() uses the
// fixed IV instead of a random one.
#if defined MAYBE_FIXED_IV
extern bool gNoRandomIv;
#endif

// File-local crypto state. `gcrypto` stays null until init_crypto() allocates
// it; generate_address() and decode_str() call init_crypto() lazily when it is
// still null.
static struct crypto_init {
  // Symmetric key: the SHA-256 digest computed in init_crypto().
  Botan::secure_vector<uint8_t> key;
} *gcrypto = nullptr;
// Takes ownership of the object `gcrypto` points to (see init_crypto()) so it
// is released when this global is destroyed.
std::unique_ptr<crypto_init> auto_cleanup;

// Decrypts an encoded local part back to its plaintext; defined further down.
static std::string decode_str(const std::string& mail);
// Flags defined in another translation unit:
extern bool gDebug;               // enables extra output in test_crypto()
extern bool gCheckLocalpartSize;  // generate_address() rejects results > 64 chars
extern bool gLowercase;           // generate_address() lowercases its result

// Encode a byte container as text using the encoding selected by gEncode.
// Throws std::runtime_error when gEncode names neither supported method.
template <typename T>
static std::string flex_encode(const T& a) {
  if (gEncode == encode_method::base_32)
    return base32::encode(a);
  if (gEncode == encode_method::base_58)
    return Botan::base58_check_encode(a);
  throw std::runtime_error{"unknown encoding method"};
}

// std::string overload of flex_encode: the string's bytes are encoded with the
// method selected by gEncode. Throws std::runtime_error for an unknown method.
static std::string flex_encode(const std::string& a) {
  if (gEncode == encode_method::base_32)
    return base32::encode(a);
  if (gEncode == encode_method::base_58) {
    // Botan's base58 encoder takes a raw byte pointer plus a length.
    return Botan::base58_check_encode(
        reinterpret_cast<const uint8_t*>(a.data()), a.size());
  }
  throw std::runtime_error{"unknown encoding method"};
}

// Decode text produced by flex_encode back into a byte vector, using the
// method selected by gEncode. Throws std::runtime_error for an unknown method.
template <typename T>
static std::vector<uint8_t> flex_decode(const T& a) {
  if (gEncode == encode_method::base_32)
    return base32::decode<std::vector<uint8_t> >(a);
  if (gEncode == encode_method::base_58)
    return Botan::base58_check_decode(a);
  throw std::runtime_error{"unknown encoding method"};
}

// Decode encoded text into a std::string holding the raw decoded bytes, using
// the method selected by gEncode. Throws std::runtime_error for an unknown
// method.
static std::string flex_decode_to(const std::string& a) {
  if (gEncode == encode_method::base_32)
    return base32::decode<std::string>(a);
  if (gEncode == encode_method::base_58) {
    const std::vector<uint8_t> raw = Botan::base58_check_decode(a);
    // Copy the decoded bytes one-for-one into the returned string.
    return std::string(raw.begin(), raw.end());
  }
  throw std::runtime_error{"unknown encoding method"};
}

// Decode encoded text into a Botan secure_vector, using the method selected by
// gEncode. Throws std::runtime_error for an unknown method.
static Botan::secure_vector<uint8_t> flex_decode_tosv(const std::string& a) {
  if (gEncode == encode_method::base_32)
    return base32::decode<Botan::secure_vector<uint8_t> >(a);
  if (gEncode == encode_method::base_58) {
    // base58_check_decode yields a plain vector; copy it into secure storage.
    const std::vector<uint8_t> raw = Botan::base58_check_decode(a);
    Botan::secure_vector<uint8_t> secured(raw.begin(), raw.end());
    return secured;
  }
  throw std::runtime_error{"unknown encoding method"};
}

void init_crypto(const std::string& key) {
  if (!gcrypto)
    gcrypto = new crypto_init;
  std::unique_ptr<Botan::HashFunction> h(Botan::HashFunction::create("SHA-256"));
  if (key.empty())
    h->update(reinterpret_cast<const uint8_t*>(DEFKEY), strlen(DEFKEY));
  else
    h->update(key);
  gcrypto->key = h->final();
#if defined FIXED_IV || defined MAYBE_FIXED_IV
  std::unique_ptr<Botan::HashFunction> hiv(Botan::HashFunction::create("SHA-384"));
  if (key.empty())
    h->update(reinterpret_cast<const uint8_t*>(DEFKEY), strlen(DEFKEY));
  else
    h->update(key);
  Botan::secure_vector<uint8_t> iv = h->final();
  for (int i = 0; i < giv.size(); ++i)
    _giv[i] = iv[i];
#endif
  auto_cleanup.reset(gcrypto);
}

void test_crypto() {
  std::cout << "Botan version " << Botan::version_string() << std::endl;
  // Check version is 2.16.0 (change when upgrading botan)
  static_assert(BOTAN_VERSION_MAJOR == 2 && BOTAN_VERSION_MINOR == 19 &&
                BOTAN_VERSION_PATCH == 0, "Botan version mismatch");
  // Check linked-in version is the compiled version
  assert(Botan::runtime_version_check(BOTAN_VERSION_MAJOR, BOTAN_VERSION_MINOR,
                                      BOTAN_VERSION_PATCH).empty());

  std::vector<std::string> p{"\x61", "\x61\x53", "\x61\x53\x74",
      "\x61\x53\x74\x6c"};

  for (std::vector<std::string>::iterator it = p.begin(); it != p.end(); ++it) {
    std::string enc = flex_encode(*it);
    std::cout << enc << std::endl;
    std::string dec = flex_decode_to(enc);
    assert(dec == *it);
  }
  auto exp = std::chrono::system_clock::now() + std::chrono::hours(72);
  std::string plaintext = "test+mark:" + date2s(exp);
  Botan::secure_vector<uint8_t> pt(plaintext.data(),
                          plaintext.data() + plaintext.length());

  #if defined USE_AEAD_MODE
  std::unique_ptr<Botan::AEAD_Mode> cipher(Botan::get_aead(DEFALGO, Botan::ENCRYPTION));
  assert(cipher);
  #else
  std::unique_ptr<Botan::StreamCipher> cipher(Botan::StreamCipher::create(DEFALGO));
  assert(cipher);
  #endif

  //generate fresh nonce (IV)
  #if defined FIXED_IV
  const std::vector<uint8_t>& iv = giv;
  #else
  std::unique_ptr<Botan::RandomNumberGenerator> rng(new Botan::AutoSeeded_RNG);
  std::vector<uint8_t> iv(8);
  rng->randomize(iv.data(), iv.size());
  #endif

  //set key and IV
  cipher->set_key(gcrypto->key);
  #if defined USE_AEAD_MODE
  cipher->set_ad(iv);
  #else
  cipher->set_iv(iv.data(), iv.size());
  #endif

  // Encrypt
  std::cout << "plaintext len: " << pt.size() << std::endl;
  #if defined USE_AEAD_MODE
  cipher->start(iv);
  cipher->finish(pt);
  #else
  cipher->encipher(pt);
  #endif
  std::cout << "ciphertext len: " << pt.size() << std::endl;
  std::string mail = flex_encode(iv) + "-" + flex_encode(pt);
  std::cout << mail << std::endl;
  std::cout << "final len: " << mail.size() << std::endl;

  // Decrypt
  auto pos = mail.find('-');
  assert(pos != std::string::npos);
  std::string _iv = flex_decode_to(mail.substr(0, pos));
  std::string _mail = mail.substr(pos + 1);
  #if defined USE_AEAD_MODE
  Botan::secure_vector<uint8_t> ct = flex_decode_tosv(_mail);
  std::unique_ptr<Botan::AEAD_Mode> dcipher(Botan::get_aead(DEFALGO, Botan::DECRYPTION));
  assert(dcipher);
  #else
  std::vector<uint8_t> ct = flex_decode(_mail);
  std::unique_ptr<Botan::StreamCipher> dcipher(Botan::StreamCipher::create(DEFALGO));
  assert(dcipher);
  #endif

  dcipher->set_key(gcrypto->key);
  #if defined USE_AEAD_MODE
  std::vector<uint8_t> nonce(_iv.begin(), _iv.end());
  std::vector<uint8_t> ad(_iv.begin(), _iv.end());
  dcipher->set_ad(ad);
  dcipher->start(nonce);
  dcipher->finish(ct);
  #else
  std::vector<uint8_t> __iv(_iv.begin(), _iv.end());
  dcipher->set_iv(__iv.data(), __iv.size());
  dcipher->decrypt(ct);
  #endif
  std::cout << std::string(ct.begin(), ct.end()) << std::endl;

  assert(std::string(ct.begin(), ct.end()) == plaintext);

  // Test functions
  std::string mail2 = generate_address("test+mark", exp);
  if (gDebug)
    std::cout << mail2 << " (" << mail2.size() << ")" << std::endl;
  std::string dec = decode_str(mail2);
  if (gDebug)
    std::cout << dec << " (" << dec.size() << ")" << std::endl;
  assert(dec.find("test+mark") != std::string::npos);
  std::string plaintext2 = "test+mark:" + date2s(exp);
  assert(plaintext2 == dec);
  auto _dec = decode(mail2 + "@example.com");
  assert(_dec.date_valid());

  // Test dates
  auto today = std::chrono::system_clock::now();
  auto yesterday = std::chrono::system_clock::now() - std::chrono::hours(24);
  std::string mail3 = generate_address("bastien", today);
  std::string mail4 = generate_address("bastien", yesterday);
  auto ok = decode(mail3 + "@example.com");
  assert(ok.date_valid());
  auto nok = decode(mail4 + "@example.com");
  assert(!nok.date_valid());
}

// Convenience overload: formats `date` with date2s() and forwards to the
// string-date overload of generate_address().
//
// The MAYBE_FIXED_IV guards use `#if defined`, matching every other test of
// that macro in this file, so the signature does not depend on the macro's
// value and an empty `#define MAYBE_FIXED_IV` still compiles.
std::string generate_address(const std::string& username,
                             const std::chrono::system_clock::time_point& date
                             #if defined MAYBE_FIXED_IV
                             , bool aNoRandomIv
                             #endif
  ) {
  #if defined MAYBE_FIXED_IV
  return generate_address(username, date2s(date), aNoRandomIv);
  #else
  return generate_address(username, date2s(date));
  #endif
}

std::string generate_address(const std::string& username,
                             const std::string& date
                             #if MAYBE_FIXED_IV
                             , bool aNoRandomIv
                             #endif
  ) {
  if (!gcrypto)
    init_crypto();

  std::string plaintext(username + ":" + date);
  Botan::secure_vector<uint8_t> pt(plaintext.data(),
    plaintext.data() + plaintext.length());
  #if defined USE_AEAD_MODE
  std::unique_ptr<Botan::AEAD_Mode>
    cipher(Botan::get_aead(DEFALGO, Botan::ENCRYPTION));
  #else
  std::unique_ptr<Botan::StreamCipher>
    cipher(Botan::StreamCipher::create(DEFALGO));
  #endif
  if (!cipher)
    throw std::runtime_error{"Cannot load cipher"};

  //generate fresh nonce (IV)
  #if defined FIXED_IV
  const std::vector<uint8_t>& iv = giv;
  #else
  std::unique_ptr<Botan::RandomNumberGenerator> rng(new Botan::AutoSeeded_RNG);
  std::vector<uint8_t> iv(8);
  #if MAYBE_FIXED_IV
  if (gNoRandomIv || aNoRandomIv)
    std::copy(giv.begin(), giv.end(), iv.begin());
  else
  #endif
    rng->randomize(iv.data(), iv.size());
  #endif

  //set key and IV
  cipher->set_key(gcrypto->key);
  #if defined USE_AEAD_MODE
  cipher->set_ad(iv);
  #else
  cipher->set_iv(iv.data(), iv.size());
  #endif

  // Encrypt
  #if defined USE_AEAD_MODE
  cipher->start(iv);
  cipher->finish(pt);
  #else
  cipher->encipher(pt);
  #endif
  #if defined FIXED_IV
  std::string ret = flex_encode(pt);
  #elif defined MAYBE_FIXED_IV
  std::string ret = (gNoRandomIv || aNoRandomIv) ? flex_encode(pt) :
    flex_encode(iv) + "-" + flex_encode(pt);
  #else
  std::string ret = flex_encode(iv) + "-" + flex_encode(pt);
  #endif
  if (ret.size() > 64 && gCheckLocalpartSize)
    throw std::runtime_error{"Local part too long (" + std::to_string(ret.size()) + ")"};
  if (gLowercase)
    std::transform(ret.begin(), ret.end(), ret.begin(), ::tolower);
  return ret;
}

static std::string decode_str(const std::string& mail) {
  if (!gcrypto)
    init_crypto();

  auto pos = mail.find('-');
  if (pos == std::string::npos) {
    #if !defined FIXED_IV && !defined MAYBE_FIXED_IV
    throw std::runtime_error{"format error (missing -)"};
    #endif
  }
  #if defined FIXED_IV || defined MAYBE_FIXED_IV
  std::string _iv = (pos == std::string::npos) ? "" :
    flex_decode_to(mail.substr(0, pos));
  std::string _mail =  (pos == std::string::npos) ? mail :
    mail.substr(pos + 1);
  #else
  std::string _iv = flex_decode_to(mail.substr(0, pos));
  std::string _mail = mail.substr(pos + 1);
  #endif
  Botan::secure_vector<uint8_t> ct =
    flex_decode_tosv(_mail);
  #if defined USE_AEAD_MODE
  std::unique_ptr<Botan::AEAD_Mode> dcipher(Botan::get_aead(DEFALGO, Botan::DECRYPTION));
  #else
  std::unique_ptr<Botan::StreamCipher> dcipher(Botan::StreamCipher::create(DEFALGO));
  #endif
  if (!dcipher)
    throw std::runtime_error{"Cannot load cipher"};
  dcipher->set_key(gcrypto->key);
  #if defined FIXED_IV
  const std::vector<uint8_t>& __iv = giv;
  #elif defined MAYBE_FIXED_IV
  std::vector<uint8_t> __iv;
  if (_iv.empty())
    __iv = giv;
  else
    __iv = std::vector<uint8_t>(_iv.begin(), _iv.end());
  #else
  std::vector<uint8_t> __iv(_iv.begin(), _iv.end());
  #endif
  #if defined USE_AEAD_MODE
  dcipher->set_ad(__iv);
  dcipher->start(__iv);
  dcipher->finish(ct);
  #else
  dcipher->set_iv(__iv.data(), __iv.size());
  dcipher->decrypt(ct);
  #endif
  return std::string(ct.begin(), ct.end());
}

// Decode a full address "<local part>@<domain>".
//
// The local part is decrypted with decode_str() and split at its first ':'
// into username and date string. The date string is parsed with dateFmt, and
// the resulting time point is advanced by 23:59:59 (presumably so a date-only
// value stays valid through the end of that day -- confirm against dateFmt).
//
// Throws std::runtime_error when '@' or ':' is missing, or when the date
// cannot be parsed or converted. Errors from decode_str() propagate unchanged.
decoded_address decode(const std::string& address) {
  const std::string::size_type at = address.find('@');
  if (at == std::string::npos)
    throw std::runtime_error{"format error (missing @)"};

  decoded_address res;
  res.domain = address.substr(at + 1);

  const std::string plaintext = decode_str(address.substr(0, at));
  const std::string::size_type colon = plaintext.find(':');
  if (colon == std::string::npos)
    throw std::runtime_error{"plaintext error (missing :)"};
  res.username = plaintext.substr(0, colon);
  res.sdate = plaintext.substr(colon + 1);

  std::tm tm = {};
  std::stringstream ss(res.sdate);
  ss >> std::get_time(&tm, dateFmt);
  if (ss.fail())
    throw std::runtime_error{"plaintext error (cannot parse date)"};
  // Let mktime() work out whether DST applies. Leaving tm_isdst at 0 makes it
  // treat the value as standard time, which shifts the result by an hour while
  // DST is in effect.
  tm.tm_isdst = -1;
  const std::time_t start = std::mktime(&tm);
  if (start == static_cast<std::time_t>(-1))
    throw std::runtime_error{"plaintext error (cannot convert date)"};
  res.date = std::chrono::system_clock::from_time_t(start);
  res.date += std::chrono::hours{23} + std::chrono::minutes{59}
            + std::chrono::seconds{59};
  return res;
}
