//
// Andrei.Ryjpv@Spasu.NET
// Tue Jun 12 10:06:10 CEST 2007
//
// Messages signing/encryption and forwarding via UDP
//
// Nice-to-have leftovers:
//
//      Verbosity (if getenv FORMSEC_VERBOSE==true)
//
//      Padding so that all encrypted lines are the same length
//
//      Encrypt the signature along with the payload
//      (for now, it is obviously encrypted with the client's private key,
//      but its position is easily detectable within the message)
//
//      Build instructions:
//
//      Prepare the openssl source code directory and compile openssl,
//      for instance:
//              wget --passive-ftp ftp://ftp.sunfreeware.com/pub/freeware/SOURCES/openssl-0.9.8e.tar.gz
//              cat openssl-0.9.8e.tar.gz | gzip -d | tar xf -
//              cd openssl-0.9.8e && ./config && make
//
//      Set the env variables, for instance
//      SSL_SOURCE=`pwd`
//      OBJLIBSAGR="$SSL_SOURCE/apps/apps.o $SSL_SOURCE/apps/app_rand.o $SSL_SOURCE/libcrypto.a"
//      INCLUDES="-I$SSL_SOURCE -I$SSL_SOURCE/include -I$SSL_SOURCE/apps"
//      Solaris# LIBS="-lnsl -lsocket -lresolv"
//      Linux  # LIBS="-ldl"
//      Cygwin # LIBS=""
//
//      cd $WHERE_YOUR_rsasig.c_IS
//      gcc $INCLUDES -o rsasig rsasig.c $OBJLIBSAGR $LIBS
//
//      If openssl has been compiled in 64-bit mode (./Configure solaris64-sparcv9-gcc),
//      rsasig must be compiled the same way, i.e. gcc -m64 -mcpu=ultrasparc -O3.
//      64-bit binaries are strongly recommended for better performance.
//
#include <sys/types.h>
#include <stdio.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <sys/socket.h>
#include <string.h>
#include <stdlib.h>
#include <unistd.h>
#include <ctype.h>
#include <limits.h>

#include <openssl/opensslconf.h>
#undef OPENSSL_NO_RSA

#include <apps.h>
#include <openssl/err.h>
#include <openssl/pem.h>
#include <openssl/rsa.h>
#include <openssl/sha.h>
#include <openssl/des.h>

// Operating modes; main() stores one of these in rsa_mode.
//   RCV - receive UDP datagrams, write them to stdout
//   SND - read lines from stdin, send them via UDP
//   ENC - read stdin, sign and encrypt, write to stdout
//   DEC - read stdin, decrypt and verify, write to stdout
//   FWD - receive UDP, sign and encrypt, send via UDP
#define RCV             1
#define SND             2
#define ENC             3
#define DEC             4
#define FWD             5

// padsize: extra bytes added to the length handed to b64enc() in Sign_And_Encrypt().
// NDIDITS and NDIDITS_AND_0 are not referenced in the visible source.
#define padsize         8
#define NDIDITS         8
#define NDIDITS_AND_0  10

// Boolean-style status returned by the processing steps (Receive, Send, Read, ...)
typedef int             result;
#define success         1
#define failure         0

// sock_server: local bind address in Set_Receive(), destination address in Set_Send().
// sock_client: filled with the sender's address by recvfrom() in Receive().
struct                  sockaddr_in sock_server, sock_client;
// Time of the datagram being forwarded; used by Receive() to build the UMID tag
static time_t           epoch;
struct tm               tms;

// si/so: input and output UDP sockets; connected: set on the first datagram in FWD mode;
// megabuflen/outbuflen/inbuflen: work buffer sizes set by Load_Structures();
// rsa_mode: one of RCV, SND, ENC, DEC, FWD
int                     si, so, connected, megabuflen, outbuflen, inbuflen, rsa_mode;
// ip/op: listen and destination port numbers; rc: scratch return code;
// ln: counter incremented by Read() and Receive().
// iaddr, oaddr, lport and i are not referenced in the visible source.
int                     iaddr, oaddr, ip, op, lport, i, rc, ln;

// Address length passed to bind(), sendto() and (as in/out argument) recvfrom()
socklen_t               socklen = sizeof(struct sockaddr_in);
// Command-line derived strings; the port pointers point into the addr:port arguments
char                    *progname, *fromport, *toport, *fromaddr, *toaddr;

// Value of the HOST environment variable (required in FWD mode)
char                    *hostname;
char                    UMIDbuf[80]; // UMID:hostname:DATE:LineNum
// Work buffers, all megabuflen bytes long, allocated in Load_Structures()
static char             *megabuf=NULL;
// pad holds the RSA padding mode (RSA_PKCS1_PADDING); outptr points at the data Send()/Write() emit
unsigned char           *inbuf=NULL, *outbuf=NULL, *midbuf=NULL, *outptr=NULL, pad;
// Scratch pointers used by the digest comparison in Decrypt_And_Verify()
unsigned char           *p, *q, *z;
char                    *priv_key_filename, *pub_key_filename;
int                     priv_keyform=FORMAT_PEM, pub_keyform=FORMAT_PEM;
// BIO wrapped around stdin; Read() fetches lines from it
BIO                     *in = NULL;

// Keys as loaded from file, and the RSA structures extracted from them in Load_Keys()
EVP_PKEY                *priv_key = NULL, *pub_key=NULL;
RSA                     *priv_rsa = NULL, *pub_rsa = NULL;
// Byte counts of the current input, intermediate and output data
int                     inlen=0, midlen=0, outlen=0;
// RSA_size() of each key, and the buffer headroom derived from them
int                     priv_keysize, pub_keysize, bufsize;
// SHA-1 state and digest of the cleartext payload
SHA_CTX                 sha_c;
unsigned char           digest[SHA_DIGEST_LENGTH];
// Per-message DES key and the three key schedules built from it
DES_cblock              Key;
DES_key_schedule        Sh1, Sh2, Sh3;

// Fixed-size record that Sign_And_Encrypt() RSA-encrypts at the start of every
// message and Decrypt_And_Verify() recovers: payload length plus the DES key.
struct {
	char            clen[8];  // msg len in text form to keep endianness happy
	DES_cblock      key;      // per-message DES key
} Header;


// Print a numbered trace marker on stderr and flush it immediately.
void Debug(int point)
{
    fprintf(stderr, "\nDebug: %4d\n\n", point);
    fflush(stderr);
}

// Report the current errno prefixed with s, then terminate with status 1.
void diep(char *s)
{
    perror(s);
    exit(1);
}

// Print the command-line synopsis for all five modes on stderr and
// terminate with status 1. Never returns.
void usage()
{
    // One call over the concatenated literals; progname fills the five %s slots
    // (rcv, snd, enc, dec, fwd) in order.
    fprintf(stderr,
	"                                                                   \n"
	"\nSorry, wrong usage. Try one of:                                  \n"
	"                                                                   \n"
	"  Unbound-receive, write to stdout                                 \n"
	"    %s rcv port                                                    \n"
	"                                                                   \n"
	"  Read from stdin, send to addr:port                               \n"
	"    %s snd toaddr:toport                                           \n"
	"                                                                   \n"
	"  Read from stdin, sign, encrypt, write to stdout                  \n"
	"    %s enc  Client_privkey Server_pubkey                           \n"
	"                                                                   \n"
	"  Read from stdin, decrypt, verify, write to stdout                \n"
	"    %s dec  Server_privkey Client_pubkey                           \n"
	"                                                                   \n"
	"  Bound_receive, sign, encrypt, send:                              \n"
	"    %s fwd  Client_privkey Server_pubkey  fromaddr:port toaddr:port\n"
	"                                                                   \n"
	"                                                                   \n",
	progname, progname, progname, progname, progname);

    exit(1);
}


//
// One-time initialisation: creates the stdin/stderr BIOs, loads the OpenSSL
// configuration, error strings and algorithms, seeds the RNG, fetches the
// HOST variable (FWD mode only) and allocates the four work buffers.
//
// Terminates the process via diep() or Cleanup_And_Exit() on any failure.
//
// NOTE(review): main() calls this before Load_Keys(), so priv_keysize and
// pub_keysize are still 0 when bufsize is computed below - confirm whether
// the key-size headroom was meant to be included.
//
void Load_Structures() {

    in = BIO_new_fp(stdin, BIO_NOCLOSE);
    if (!bio_err) bio_err = BIO_new_fp(stderr, BIO_NOCLOSE);
    if (!load_config(bio_err, NULL))  diep("bio_err");

    ERR_load_crypto_strings();
    OpenSSL_add_all_algorithms();
    pad = RSA_PKCS1_PADDING;

    app_RAND_load_file(NULL, bio_err, 0);

    if(rsa_mode == FWD  &&  !(hostname=getenv("HOST"))) {
        fprintf(stderr, "'HOST' variable not set in env - exiting\n");
        Cleanup_And_Exit();
    }

    bufsize = 2 * (priv_keysize + pub_keysize);
    megabuflen = 16 * 1024  + bufsize;
    outbuflen = megabuflen;
    inbuflen =  megabuflen;
    inbuf =  OPENSSL_malloc(megabuflen);
    midbuf = OPENSSL_malloc(megabuflen);
    outbuf = OPENSSL_malloc(megabuflen);
    // One spare byte: the terminator below is stored at index megabuflen,
    // which would otherwise lie one past the end of the allocation.
    megabuf = OPENSSL_malloc(megabuflen + 1);

    if (!inbuf || !midbuf || !outbuf || !megabuf) {
        fprintf(stderr, "Out of memory allocating work buffers\n");
        Cleanup_And_Exit();     // frees whatever was allocated and exits
    }
    megabuf[megabuflen]='\0';

}


//
// Load the private and public key files named on the command line and
// extract their RSA structures and modulus sizes (priv_keysize/pub_keysize).
// A key file that cannot be loaded ends in usage(); a key that is not RSA
// ends in Cleanup_And_Exit().
//
void Load_Keys() {

    priv_key = load_key(bio_err, priv_key_filename, priv_keyform, 0, NULL, NULL, "Private Key");
    if (priv_key == NULL) {
        usage();
    }

    pub_key = load_pubkey(bio_err, pub_key_filename, pub_keyform, 0, NULL, NULL, "Public Key");
    if (pub_key == NULL) {
        usage();
    }

    // get1 takes its own reference, so the EVP wrapper can be released at once
    priv_rsa = EVP_PKEY_get1_RSA(priv_key);
    EVP_PKEY_free(priv_key);
    if (priv_rsa == NULL) {
        fprintf(stderr, "Error getting private RSA key\n");
        ERR_print_errors(bio_err);
        Cleanup_And_Exit();
    }
    priv_keysize = RSA_size(priv_rsa);

    pub_rsa = EVP_PKEY_get1_RSA(pub_key);
    EVP_PKEY_free(pub_key);
    if (pub_rsa == NULL) {
        fprintf(stderr, "Error getting public RSA key\n");
        ERR_print_errors(bio_err);
        Cleanup_And_Exit();
    }
    pub_keysize = RSA_size(pub_rsa);

}


//
// Create the input UDP socket and bind it to fromport on all local addresses.
// Exits with status 1 on a bad port/address or when socket()/bind() fails.
//
// The port must be a plain decimal number in 1..65535; trailing garbage and
// out-of-range values are rejected instead of being silently truncated.
//
void Set_Receive() {
    char *end = NULL;
    long  port;

    if ((si = socket(PF_INET, SOCK_DGRAM, IPPROTO_UDP)) == -1) diep("input socket()");
    sock_server.sin_family = AF_INET;

    port = strtol(fromport, &end, 10);
    if (end == fromport || *end != '\0' || port < 1 || port > 65535) {
        // no errno to report here, so perror()/diep() would print a bogus reason
        fprintf(stderr, "Wrong IP address or port number to listen on: %s\n", fromport);
        exit(1);
    }
    ip = (int)port;
    sock_server.sin_port = htons((unsigned short)ip);
    sock_server.sin_addr.s_addr = htonl(INADDR_ANY);

    // In FWD mode fromaddr is only validated: it is parsed into sock_client,
    // which Receive() overwrites with each sender's address, while the socket
    // itself stays bound to INADDR_ANY.
    if (rsa_mode == FWD && inet_aton(fromaddr, &(sock_client.sin_addr)) == 0) {
        fprintf(stderr, "Wrong IP address for binding: %s\n", fromaddr);
        exit(1);
    }

    if (bind(si, (struct sockaddr *) &sock_server, socklen) == -1) diep("Input bind()");
    connected = 0; // will be set in the loop, for local address binding in FWD mode
}


void Set_Send() {
    if ((so = socket(PF_INET, SOCK_DGRAM, IPPROTO_UDP)) == -1) diep("socket");
    sock_server.sin_family = AF_INET;
    if ( !(op = atoi(toport)) ) diep("Wrong IP address or port number to send to");
    sock_server.sin_port = htons(op);
    if (inet_aton(toaddr, &(sock_server.sin_addr)) == 0) diep("Wrong IP address for sending");
}


result Receive() {

    ln++;
    inlen = recvfrom(si, inbuf, inbuflen, 0, (struct sockaddr *) &sock_client, &socklen);
    if (inlen > 0) {
	if(rsa_mode == FWD  && !connected) { // bind to local address
	    connected = 1;

	    //
	    // Switch it off for now, as HUPing syslogd breaks the connection -
	    // therefore we shall listen in unbound mode only
	    //
	    // if (connect(si, (struct sockaddr *) &sock_client, socklen) == -1) {
	    //     perror("connect()");
	    //    Cleanup_And_Exit();
	    // }

	}
	if (inbuf[inlen-1] == '\n') inbuf[--inlen] = '\0';
	if(rsa_mode == FWD) {
	    int ll;
	    epoch=time(0);
	    if( localtime_r(&epoch, &tms) == NULL) {
		perror("localtime_r()");
		Cleanup_And_Exit();
	    }
	    sprintf(UMIDbuf, " UMID:%s:%02d%02d%02d%02d%02d%02d:%07d",
		hostname, tms.tm_year-100, tms.tm_mon+1, tms.tm_mday, tms.tm_hour, tms.tm_min, tms.tm_sec, ln);
	    ll=strlen(UMIDbuf);
	    strncpy(inbuf+inlen, UMIDbuf, inbuflen-inlen-ll-2);
	    inlen += ll;
	    // ;; fprintf(stderr, "Forwarding line %s\n", inbuf);
	    // ;; fflush(stderr);
	}
	outptr = inbuf; outlen = inlen; // may be reset by further processing
	return(success);
    } else {
	perror("receive()");
	Cleanup_And_Exit();
    }
}


result Send() { return( (sendto(so, outptr, outlen, 0, (struct sockaddr *) &sock_server, socklen) == outlen) ); }


// Fetch the next line from the stdin BIO into inbuf, dropping a trailing
// newline. Sets outptr/outlen to the line. Returns failure when BIO_gets()
// delivers nothing (end of input or error).
result Read() {
    ln++;

    inlen = BIO_gets(in, inbuf, inbuflen);
    if (inlen <= 0) {
        return(failure);
    }

    if (inbuf[inlen-1] == '\n') {
        inlen--;
        inbuf[inlen] = '\0';
    }

    // may be reset by further processing
    outptr = inbuf;
    outlen = inlen;
    return(success);
}


result Write() {
    int ol;
    if(outlen <= 0) {
	ERR_print_errors(bio_err);
	fprintf(stderr, "ERROR - Output empty\n");
	return(failure);
    }
    ol=write(1, outptr, outlen) ;
    // if (rsa_mode == FWD || rsa_mode == ENC) write(1,"\n",1);
    write(1,"\n",1);
    return(ol == outlen);
}


result Sign_And_Encrypt() {

    // Fill in the DES key, append the (int)inlen, encrypt both together with a server's pubkey
    // and write to megabuf, then base64encode to outbuf
    RAND_seed(Key, sizeof Key);
    DES_random_key(&Key);
    if(rc=DES_set_key_checked(&Key, &Sh1)) fprintf(stderr, "Key Check failed on encoding - %d\n", rc);
    if(rc=DES_set_key_checked(&Key, &Sh2)) fprintf(stderr, "Key Check failed on encoding - %d\n", rc);
    if(rc=DES_set_key_checked(&Key, &Sh3)) fprintf(stderr, "Key Check failed on encoding - %d\n", rc);

    bzero(Header.clen, sizeof Header.clen);
    sprintf(Header.clen, "%d", inlen);
    bcopy((unsigned char *)&Key, (unsigned char *)&Header.key, sizeof Key);
    bcopy((unsigned char *)&Header, outbuf, sizeof Header);

    outlen = RSA_public_encrypt(sizeof Header, outbuf, megabuf, pub_rsa, pad);

    if(outlen != pub_keysize) {
	fprintf(stderr, "Symkey encryption failed\n");
	return(failure);
    }

    // calculate the digest (cleartext + (int)inlen) and sign it with client's privkey
    SHA1_Init(&sha_c);
    SHA1_Update(&sha_c, inbuf, inlen);
    SHA1_Final(&(digest[0]),&sha_c);

    //tlen = RSA_private_encrypt(SHA_DIGEST_LENGTH, digest, megabuf+pub_keysize, priv_rsa, pad);
    outlen = RSA_private_encrypt(SHA_DIGEST_LENGTH, digest, midbuf,  priv_rsa, pad);

    if(outlen != priv_keysize) {
	fprintf(stderr, "Signing failed\n");
	Cleanup_And_Exit();
    }

    //
    // append cleartext payload to the signed hash, and symmetrically encrypt'em together,
    // writing right after the symmetrically encrypted password.
    // This way, we don't need to salt the payload, as the signed hash is already salted
    // (asymm encrypted hash + cleartext payload) -> symm encrypt -> append to symm.encr.passwd
    //
    bcopy(inbuf, midbuf+priv_keysize, inlen);
    DES_ede3_cbc_encrypt(midbuf, megabuf+pub_keysize, priv_keysize+inlen, &Sh1, &Sh2, &Sh3, &Key, DES_ENCRYPT);

    outlen = b64enc(megabuf, outbuf, priv_keysize+pub_keysize+inlen+padsize, outbuflen);
    outptr = outbuf;
    return(outlen > 0);

}


result Decrypt_And_Verify() {

    // b64 decode inbuf to megabuf, but only if it at least contains the public key-encrypted hash
    if (inlen < (pub_keysize * 8 / 6)) return(failure);
    inlen = b64dec(inbuf, megabuf, inlen, megabuflen);

    if (inlen <= 0 ) return(failure);

    // the input structure is:
    //      asymmetrically encrypted inlen+symkey structure,
    //      signed hash
    //      symmetrically encrypted (signed hash + payload)
    //
    // Decrypt the DES key+inlen using server's privkey, extract inlen
    midlen = RSA_private_decrypt(priv_keysize, megabuf, midbuf, priv_rsa, pad);

    if(midlen != sizeof Header) {
	fprintf(stderr, "Symmetric key decryption failed for line %d - %d b expected, %d received\n", ln, sizeof Header, midlen);
	return(failure);
    }

    // Fetch the inlen and symkey from the Header structure
    bcopy(midbuf, (unsigned char *)&Header, sizeof Header);
    bcopy((unsigned char *)&Header.key, (unsigned char *)&Key, sizeof Key);
    inlen = atoi(Header.clen);

    if (!inlen) return(failure);
    if (inlen < 0  || ((inlen+priv_keysize+pub_keysize)*8/6) > megabuflen) {
	fprintf(stderr, "Wrong input lenth decoded in line %d - %d bytes, maxbuf is %d b\n", ln, inlen, megabuflen);
	return(failure);
    }

    outlen = inlen;

    // Make the key schedules and decrypt the payload
    if(rc=DES_set_key_checked(&Key, &Sh1)) fprintf(stderr, "Key Check failed on decoding - %d\n", rc);
    if(rc=DES_set_key_checked(&Key, &Sh2)) fprintf(stderr, "Key Check failed on decoding - %d\n", rc);
    if(rc=DES_set_key_checked(&Key, &Sh3)) fprintf(stderr, "Key Check failed on decoding - %d\n", rc);

    DES_ede3_cbc_encrypt(megabuf+priv_keysize, midbuf, pub_keysize+inlen, &Sh1, &Sh2, &Sh3, &Key, DES_DECRYPT);
    // bcopy(midbuf+pub_keysize, outbuf, outlen);
    outptr = midbuf+pub_keysize;


    // Decrypt the hash using client's pubkey
    midlen = RSA_public_decrypt(pub_keysize, midbuf, inbuf, pub_rsa, pad);

    if(midlen != SHA_DIGEST_LENGTH) {
	fprintf(stderr, "Wrong digest length - %d expected, %d received\n", SHA_DIGEST_LENGTH, midlen);
	return(failure);
    }

    SHA1_Init(&sha_c);
    SHA1_Update(&sha_c, outptr, outlen);
    SHA1_Final(&(digest[0]),&sha_c);

    p = inbuf; q = digest; z = digest + SHA_DIGEST_LENGTH;
    while(q < z) {
	if(*p++ != *q++) {
	    fprintf(stderr, "Wrong digest match at position %d, line %d, %d b long:\n", q-digest, ln, outlen);
	    fflush(stderr);
	    write(2, "\n==>", 4);
	    write(2, outptr, outlen);
	    write(2, "<==\n", 4);
	    return(failure);
	}
    }

    return(success);
}


//
// Release the RSA keys and the work buffers, then terminate the process.
// Never returns.
//
// The return type is spelled out because implicit int is not valid C99;
// it stays int to agree with the implicit declarations that the call
// sites earlier in this file rely on.
//
// NOTE(review): the exit status is 0 even on the error paths that call
// this - confirm whether callers of the program depend on that.
//
int Cleanup_And_Exit() {
    RSA_free(priv_rsa);
    RSA_free(pub_rsa);
    if(inbuf)   OPENSSL_free(inbuf);
    if(midbuf)  OPENSSL_free(midbuf);
    if(outbuf)  OPENSSL_free(outbuf);
    if(megabuf) OPENSSL_free(megabuf);
    exit(0);
}


// Base64 alphabet: the character at index i encodes the 6-bit value i
char A[]="ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";

//
// Decode base64 text.
//
//   in    encoded input, ilen bytes
//   out   destination buffer, at most mlen bytes are written
//
// Characters outside the alphabet (including '=' and newlines) are skipped.
// Decoding stops at the end of the input or once out is full; leftover bits
// that do not make up a whole byte are dropped.
// Returns the number of bytes stored in out.
//
int b64dec(unsigned char *in, unsigned char *out, int ilen, int mlen)
{
    unsigned char sextet_of[256];       // character -> 6-bit value, 64 = not in alphabet
    unsigned int  window = 0;           // most recent bits, newest in the low end
    int           pending = 0;          // bits in window not yet written out
    int           idx;
    unsigned char *src = in,  *src_end = in + ilen;
    unsigned char *dst = out, *dst_end = out + mlen;

    for (idx = 0; idx < 256; idx++) {
        sextet_of[idx] = 64;
    }
    for (idx = 0; A[idx] != '\0'; idx++) {
        sextet_of[(unsigned char)A[idx]] = (unsigned char)idx;
    }

    for (; src < src_end && dst < dst_end; src++) {
        unsigned char value = sextet_of[*src];

        if (value > 63) {
            continue;
        }
        window = ((window << 6) | value) & 0xFFFFu;
        pending += 6;
        if (pending >= CHAR_BIT) {
            pending -= CHAR_BIT;
            *dst++ = (unsigned char)(window >> pending);
        }
    }

    return((int)(dst - out));
}


//
// Encode binary data as base64 with '=' padding and no line breaks.
//
//   in    input data, ilen bytes
//   out   destination buffer of mlen bytes (not NUL-terminated)
//
// Returns the number of characters written, 0 for empty input, or -1 when
// the complete encoding (4 characters per started group of 3 input bytes)
// does not fit into mlen bytes. Nothing is written in that case, so out is
// never overrun and never left holding a cut-off, undecodable line.
//
int b64enc(unsigned char *in, unsigned char *out, int ilen, int mlen)
{
    // Same table as the file-level A[], kept local so the encoder is self-contained
    static const char alphabet[] =
        "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
    unsigned char *q = out;
    int groups, pos, rest;

    if (ilen <= 0) return(0);

    groups = ilen / 3 + (ilen % 3 != 0);
    if (groups > mlen / 4) return(-1);

    // whole 3-byte groups -> 4 characters each
    for (pos = 0; ilen - pos >= 3; pos += 3) {
        unsigned int bits = ((unsigned int)in[pos] << 16)
                          | ((unsigned int)in[pos + 1] << 8)
                          |  (unsigned int)in[pos + 2];

        *q++ = alphabet[(bits >> 18) & 63];
        *q++ = alphabet[(bits >> 12) & 63];
        *q++ = alphabet[(bits >>  6) & 63];
        *q++ = alphabet[ bits        & 63];
    }

    // trailing 1 or 2 bytes -> 2 or 3 characters plus padding
    rest = ilen - pos;
    if (rest == 1) {
        unsigned int bits = (unsigned int)in[pos] << 16;

        *q++ = alphabet[(bits >> 18) & 63];
        *q++ = alphabet[(bits >> 12) & 63];
        *q++ = '=';
        *q++ = '=';
    } else if (rest == 2) {
        unsigned int bits = ((unsigned int)in[pos] << 16)
                          | ((unsigned int)in[pos + 1] << 8);

        *q++ = alphabet[(bits >> 18) & 63];
        *q++ = alphabet[(bits >> 12) & 63];
        *q++ = alphabet[(bits >>  6) & 63];
        *q++ = '=';
    }

    return((int)(q - out));
}



//
// Entry point: pick the operating mode from the first argument, check that
// the number of remaining arguments is exactly what that mode needs, set up
// keys/sockets/buffers and run the per-line processing loop until the input
// ends or a step fails. Any argument problem ends in usage().
//
int main(int argc, char **argv)
{
    char *mode;
    int  nargs;         // arguments after the program name, mode word included

    progname = argv[0];
    nargs = argc - 1;
    if (nargs == 0) usage();
    mode = argv[1];

    ln = 0; // increments in Read() and Receive();

    if (strncmp(mode, "fwd", 3) == 0) {
        // receive, sign, encrypt, send:
        //   fwd Client_privkey Server_pubkey fromaddr:port toaddr:port
        rsa_mode = FWD;
        if (nargs != 5) usage();

        priv_key_filename = argv[2];
        pub_key_filename  = argv[3];

        // split "addr:port" in place
        fromaddr = argv[4];
        fromport = strchr(fromaddr, ':');
        if (fromport == NULL) usage();
        *fromport++ = '\0';

        toaddr = argv[5];
        toport = strchr(toaddr, ':');
        if (toport == NULL) usage();
        *toport++ = '\0';

        Load_Structures();
        Load_Keys();
        Set_Receive();
        Set_Send();

        // a message that fails to encrypt is skipped; a failed send is fatal
        while (Receive()) {
            if (Sign_And_Encrypt() && !Send()) {
                Cleanup_And_Exit();
            }
        }

    } else if (strncmp(mode, "enc", 3) == 0) {
        // read, sign, encrypt, write:  enc Client_privkey Server_pubkey
        rsa_mode = ENC;
        if (nargs != 3) usage();

        priv_key_filename = argv[2];
        pub_key_filename  = argv[3];

        Load_Structures();
        Load_Keys();

        while (Read()) {
            if (Sign_And_Encrypt() && !Write()) {
                Cleanup_And_Exit();
            }
        }

    } else if (strncmp(mode, "dec", 3) == 0) {
        // read, decrypt, verify, write:  dec Server_privkey Client_pubkey
        rsa_mode = DEC;
        if (nargs != 3) usage();

        priv_key_filename = argv[2];
        pub_key_filename  = argv[3];

        Load_Structures();
        Load_Keys();

        while (Read()) {
            if (Decrypt_And_Verify() && !Write()) {
                Cleanup_And_Exit();
            }
        }

    } else if (strncmp(mode, "snd", 3) == 0) {
        // read from stdin, send to addr:port:  snd toaddr:toport
        rsa_mode = SND;
        if (nargs != 2) usage();

        toaddr = argv[2];
        toport = strchr(toaddr, ':');
        if (toport == NULL) usage();
        *toport++ = '\0';

        Load_Structures();
        Set_Send();

        while (Read()) {
            if (!Send()) {
                Cleanup_And_Exit();
            }
        }

    } else if (strncmp(mode, "rcv", 3) == 0) {
        // unbound-receive, write to stdout:  rcv port
        rsa_mode = RCV;
        if (nargs != 2) usage();

        fromport = argv[2];

        Load_Structures();
        Set_Receive();

        while (Receive()) {
            if (!Write()) {
                Cleanup_And_Exit();
            }
        }

    } else {

        usage();

    }

    Cleanup_And_Exit();
}




