// Distributed under the Thrift Software License
//
// See accompanying file LICENSE or visit the Thrift site at:
// http://developers.apache.com/thrift/

#include <cstring>
#include <sys/socket.h>
#include <sys/poll.h>
#include <sys/types.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <netdb.h>
#include <fcntl.h>
#include <errno.h>
#include <iostream>

#include "TTransportException.h"
#include <Thrift.h>
#include "TSSLContext.h"
#include <boost/shared_ptr.hpp>

using namespace apache::thrift;
using namespace apache::thrift::transport;

namespace apache { namespace thrift { namespace transport {

      using namespace std;

      // Default PEM pass-phrase callback installed by TSSLContext.
      // Prompts on stdout and reads one whitespace-delimited word from stdin.
      //
      //   buf      - destination buffer for the pass-phrase.
      //   size     - capacity of buf in bytes, including the terminating NUL.
      //   rwflag   - ignored.
      //   userdata - ignored.
      //
      // Returns the length of the (possibly truncated) pass-phrase written to
      // buf, or 0 when buf is NULL or has no room (nothing is read then).
      int pem_passwd_cb(char *buf, int size, int rwflag, void *userdata){
        (void)rwflag;
        (void)userdata;

        // Without this guard a zero/negative size would make the terminator
        // write below land before the start of buf.
        if (buf == NULL || size <= 0) {
          return 0;
        }

        std::string password;
        std::cout << "Enter PEM Password: " << std::endl;
        std::cin >> password;

        // Copy at most size-1 characters and always NUL-terminate, since
        // strncpy does not terminate when the source fills the buffer.
        std::strncpy(buf, password.c_str(), static_cast<std::size_t>(size));
        buf[size - 1] = '\0';
        return static_cast<int>(std::strlen(buf));
      }

      /*
       * Wrapper around an OpenSSL SSL_CTX.
       * Can be passed via a shared_ptr to many client or server instances.
       *
       * The constructor initializes OpenSSL, seeds the PRNG, creates the
       * context and installs the default pass-phrase callback and the
       * SSL_VERIFY_PEER verification mode.
       * Throws TTransportException(UNKNOWN) if the context cannot be created.
       */
      TSSLContext::TSSLContext() :
        ctx_(NULL),
        useEGD_(false), // Use a EGD device
        egd_("")        // Path to the EGD device.
      {
        // Process-wide OpenSSL initialization.
        SSL_library_init();                      /* initialize library */
        SSL_load_error_strings();                /* readable error messages */

        // Prime the PRNG. (We need randomness.) The result is not checked here.
        seedRand();

        // SSLv23_method() negotiates the highest protocol version both peers
        // support. SSLv2 and SSLv3 are then disabled explicitly: both are
        // insecure (SSLv3: POODLE, prohibited by RFC 7568). SSLv3_method(),
        // used previously, is also missing from OpenSSL builds without SSLv3.
        ctx_ = SSL_CTX_new(SSLv23_method());
        if (ctx_ == NULL){
          // ERR_reason_error_string() returns NULL when no text is registered
          // for the code (or the queue is empty); never forward NULL.
          const char *reason = ERR_reason_error_string(ERR_get_error());
          GlobalOutput(reason ? reason : "TSSLContext: unknown OpenSSL error");
          throw TTransportException(TTransportException::UNKNOWN,
                                    "TSecureSocket: Unable to create ctx");
        }
        SSL_CTX_set_options(ctx_, SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3);

        // Callback for when we attempt to load a password protected
        // PEM file.
        SSL_CTX_set_default_passwd_cb(ctx_, &pem_passwd_cb);

        // void * data passed to the callback function.
        // By default, we just ask the user, so we don't need to pass anything.
        SSL_CTX_set_default_passwd_cb_userdata(ctx_, NULL);

        // Verify the peer's certificate by default.
        // This can be changed later by the application.
        SSL_CTX_set_verify(ctx_, SSL_VERIFY_PEER, NULL);
      }
      
      // Releases the SSL_CTX owned by this instance.
      //
      // ERR_free_strings() is deliberately NOT called here. The error-string
      // table is process-global state, loaded by every TSSLContext
      // constructor. Freeing it when one context is destroyed would strip the
      // strings from every other live context, whose ERR_reason_error_string()
      // lookups would then return NULL. (It is a no-op in OpenSSL >= 1.1.0.)
      TSSLContext::~TSSLContext(){
        if (ctx_ != NULL) {
          SSL_CTX_free(ctx_);
          ctx_ = NULL;
        }
      }

      void TSSLContext::loadTrustStore(std::string& trustStoreFile){
	if(!ctx_){
	  throw TTransportException(TTransportException::UNKNOWN, 
				    "TSSLContext: context not created.");
	}
	if(! SSL_CTX_load_verify_locations(ctx_, trustStoreFile.c_str(), 
					   NULL)){
	  GlobalOutput(ERR_reason_error_string(ERR_get_error()));
	  throw TTransportException(TTransportException::UNKNOWN, 
				    "TSSLContext: Unable to load trust store");
	}
      }
      void TSSLContext::loadCertChainFile(std::string& certChainFile){
	if(!ctx_){
	  throw TTransportException(TTransportException::UNKNOWN, 
				    "TSSLContext: context not created.");
	}
	if(SSL_CTX_use_certificate_chain_file(ctx_, 
					      certChainFile.c_str()) < 1){
	  GlobalOutput(ERR_reason_error_string(ERR_get_error()));
	  throw TTransportException(TTransportException::UNKNOWN, 
				    "TSSLContext: Unable to load server keys");
	}
      }      
      void TSSLContext::loadPrivateKey(std::string& privateKeyFile){
	if(!ctx_){
	  throw TTransportException(TTransportException::UNKNOWN, 
				    "TSSLContext: context not created.");
	}
	if(SSL_CTX_use_PrivateKey_file(ctx_, 
				       privateKeyFile.c_str(), 
				       SSL_FILETYPE_PEM) < 1){
	  GlobalOutput(ERR_reason_error_string(ERR_get_error()));
	  throw TTransportException(TTransportException::UNKNOWN, 
				    "TSSLContext: Unable to load priv. key.");
	}
      }

      // Ensures OpenSSL's PRNG is seeded.
      //
      // Without EGD: only checks RAND_status(). It returns SEED_WITH_URANDOM
      // if OpenSSL already has enough entropy, else SEED_FAILED.
      // With EGD (useEGD_): calls RAND_egd(egd_) up to
      // TIMES_TO_TRY_ADDING_RAND times until RAND_status() reports success.
      // It then returns SEED_WITH_EGD or SEED_FAILED.
      int TSSLContext::seedRand(){
        if (useEGD_) {
          for (int attempt = 0;
               !RAND_status() && attempt < TIMES_TO_TRY_ADDING_RAND;
               ++attempt) {
            RAND_egd(egd_.c_str());
          }
          // Re-check once more: did the EGD seeding succeed?
          if (RAND_status()) {
            return SEED_WITH_EGD;
          }
          return SEED_FAILED;
        }

        // Entropy gathered from a random device (like /dev/urandom) is
        // already enough when RAND_status() reports success.
        if (RAND_status()) {
          return SEED_WITH_URANDOM;
        }
        return SEED_FAILED;
      }

      // Adding caller-supplied entropy is not implemented: every call throws
      // TTransportException(UNKNOWN) and the function never returns.
      int TSSLContext::addRand(void *rand){
        (void)rand; // parameter intentionally unused
        const char *const notSupported = "TSSLContext: addRand() not supported.";
        throw TTransportException(TTransportException::UNKNOWN, notSupported);
      }

      // Returns the underlying OpenSSL context (may be used to create SSL
      // objects). The pointer stays owned by this TSSLContext, which frees it
      // in its destructor, so callers must not free it themselves.
      SSL_CTX * TSSLContext::getSSLCTX(){
        SSL_CTX *context = this->ctx_;
        return context;
      }

}}} // apache::thrift::transport

