/*
 * This application was made by SaEeD.
 * RSA algorithm implementation.
 * Useful website: http://islab.oregonstate.edu/koc/ece575/02Project/Mor/RSAdemo.java
 * Email: 0mega7@bsdmail.com
 *
 */

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.math.BigInteger;
import java.nio.charset.StandardCharsets;
 
public class RSA_Encrypt
{
    /** Certainty passed to {@link BigInteger#isProbablePrime(int)} when validating P and Q. */
    private static final int PRIME_CERTAINTY = 50;

    /**
     * Interactive entry point.
     *
     * <p>{@code encrypt}: reads primes P and Q and a candidate encryption exponent E from stdin,
     * bumps E upward until it is coprime to Phi = (P-1)(Q-1), prints the key material, then
     * reads one line of text and prints its ciphertext.
     *
     * <p>{@code decrypt}: reads the decryption exponent D, the modulus N and a ciphertext from
     * stdin and prints the recovered text.
     *
     * <p>Text is converted to a number via its UTF-8 bytes read as an unsigned big-endian
     * integer, so both modes agree regardless of the platform's default charset.
     *
     * @param args exactly one argument, {@code encrypt} or {@code decrypt}
     * @throws IOException if reading stdin fails or stdin ends before all values are entered
     */
    public static void main(String[] args) throws IOException
    {
        System.out.println("!!! Welcome to SaEeD RSA De/Encryption Application !!!");
        if (args.length != 1) {
            System.out.println("Usage: java RSA_Encrypt < encrypt | decrypt >");
            System.exit(-1);
        }
        BufferedReader stdin = new BufferedReader(new InputStreamReader(System.in));

        int status = 0;
        if (args[0].equals("encrypt")) {
            status = runEncryption(stdin);
        } else if (args[0].equals("decrypt")) {
            runDecryption(stdin);
        } else {
            System.out.println("[!]Unknown Command[!]");
            status = -1; // an unrecognised command is an error, like the usage error above
        }
        System.out.println("[+]Program Terminating...");
        System.exit(status);
    }

    /**
     * Runs the interactive encryption flow: reads P, Q, E and a message, then prints the keys
     * and the ciphertext.
     *
     * @return the process exit status: 0 on success, -1 if the key parameters are invalid
     */
    private static int runEncryption(BufferedReader in) throws IOException {
        System.out.print("[+]Starting Encryption operation.\n" +
                "[-]Please enter these values.\nP: ");
        BigInteger p = readNumber(in);
        System.out.print("Q: ");
        BigInteger q = readNumber(in);
        System.out.print("E: ");
        BigInteger e = readNumber(in);

        // Non-prime or equal P/Q make Phi wrong, so decryption would silently fail. P or Q == 1
        // gives Phi == 0, for which the gcd loop below would effectively never terminate.
        if (!isPrime(p) || !isPrime(q) || p.equals(q)) {
            System.out.println("[!]P and Q must be two distinct primes[!]");
            return -1;
        }
        if (e.compareTo(BigInteger.ONE) <= 0) {
            System.out.println("[!]E must be greater than 1[!]");
            return -1;
        }

        BigInteger n = p.multiply(q);
        BigInteger phi = p.subtract(BigInteger.ONE).multiply(q.subtract(BigInteger.ONE));

        // E must be invertible modulo Phi. Compare the whole gcd with ONE:
        // intValue() would keep only the gcd's low 32 bits.
        while (!phi.gcd(e).equals(BigInteger.ONE)) {
            e = e.add(BigInteger.ONE);
            System.out.println("E:= " + e.toString() + " gcd: " + phi.gcd(e));
        }
        System.out.println("[+]Correct Value for E is: " + e.toString());
        BigInteger d = e.modInverse(phi);
        System.out.println("Keys Are:\n############\nEncryption Exponent = " + e.toString() +
                "\nDecryption Exponent = " + d.toString() +
                "\nModulus = " + n.toString() + "\n############");

        System.out.print("[+]Now please enter your text.\n~> ");
        String text = readLine(in);
        System.out.println("####################\n# Message in Bytes #\n####################");
        BigInteger msg = textToNumber(text);
        System.out.println(msg.toString());
        System.out.println("###########################\nEncrypted message is:\n" +
                encrypt(msg, e, n).toString());
        return 0;
    }

    /** Runs the interactive decryption flow: reads D, N and a ciphertext, then prints the text. */
    private static void runDecryption(BufferedReader in) throws IOException {
        System.out.print("[+]Starting Decryption operation.\n" +
                "[-]Please enter these values.\nD: ");
        BigInteger d = readNumber(in);
        System.out.print("N: ");
        BigInteger n = readNumber(in);
        System.out.println("[+]Please enter the Cipher text now:");
        BigInteger cipher = readNumber(in);

        BigInteger text = decrypt(cipher, d, n);
        System.out.println("###########################\nDecrypted message is:\n" + numberToText(text));
    }

    /** Reads one line, failing with an IOException (instead of returning null) at end of input. */
    private static String readLine(BufferedReader in) throws IOException {
        String line = in.readLine();
        if (line == null) {
            throw new IOException("Unexpected end of input");
        }
        return line;
    }

    /** Reads one line and parses it as a decimal integer, ignoring surrounding whitespace. */
    private static BigInteger readNumber(BufferedReader in) throws IOException {
        return new BigInteger(readLine(in).trim());
    }

    /** True if {@code x} is positive and (with overwhelming probability) prime. */
    private static boolean isPrime(BigInteger x) {
        // isProbablePrime tests the absolute value, so reject non-positive inputs explicitly.
        return x.signum() > 0 && x.isProbablePrime(PRIME_CERTAINTY);
    }

    /**
     * Encodes text as a non-negative integer: its UTF-8 bytes read as an unsigned big-endian
     * number. The unsigned constructor keeps non-ASCII text (leading byte >= 0x80) from becoming
     * negative, and maps the empty string to zero instead of throwing.
     */
    private static BigInteger textToNumber(String text) {
        return new BigInteger(1, text.getBytes(StandardCharsets.UTF_8));
    }

    /** Inverse of {@link #textToNumber(String)}: decodes the integer's magnitude bytes as UTF-8. */
    private static String numberToText(BigInteger value) {
        byte[] bytes = value.toByteArray();
        // toByteArray() prepends a 0x00 sign byte when the top bit is set (and encodes zero as
        // a single 0x00); that byte is not part of the message.
        int start = (bytes[0] == 0) ? 1 : 0;
        return new String(bytes, start, bytes.length - start, StandardCharsets.UTF_8);
    }

    /**
     * Encrypts an arbitrarily large message by splitting it into chunks smaller than {@code n}.
     *
     * <p>Each value to be encrypted must be less than n, so the message is cut into chunks of
     * {@code n.bitLength() - 1} bits, starting from the least significant end. Each chunk is
     * encrypted with {@code modPow} and stored in a slot of {@code n.bitLength()} bits in the
     * result. A full slot is needed because an encrypted chunk can be as large as n - 1.
     *
     * @param m the message; must be non-negative
     * @param e the encryption exponent
     * @param n the modulus; must be greater than 1
     * @return the concatenated encrypted chunks
     * @throws IllegalArgumentException if {@code m} is negative or {@code n <= 1}
     */
    public static BigInteger encrypt(BigInteger m, BigInteger e, BigInteger n) {
        checkOperands(m, n, "message");
        int slotBits = n.bitLength();
        int chunkBits = slotBits - 1; // >= 1 because n > 1, so the loop always makes progress
        BigInteger chunkMask = BigInteger.ONE.shiftLeft(chunkBits).subtract(BigInteger.ONE);

        BigInteger c = BigInteger.ZERO;
        int i = 0;
        while (m.compareTo(chunkMask) > 0) {
            c = m.and(chunkMask).modPow(e, n).shiftLeft(i * slotBits).or(c);
            m = m.shiftRight(chunkBits);
            i++;
        }
        return m.modPow(e, n).shiftLeft(i * slotBits).or(c);
    }

    /**
     * Decrypts a value produced by {@link #encrypt(BigInteger, BigInteger, BigInteger)}.
     *
     * <p>This is the mirror of encryption. The ciphertext is read in chunks of
     * {@code n.bitLength()} bits. If the key is correct, each decrypted chunk fits in
     * {@code n.bitLength() - 1} bits, the original chunk size, and is placed back at that
     * granularity.
     *
     * @param c the ciphertext; must be non-negative
     * @param d the decryption exponent
     * @param n the modulus; must be greater than 1
     * @return the recovered message
     * @throws IllegalArgumentException if {@code c} is negative or {@code n <= 1}
     */
    public static BigInteger decrypt(BigInteger c, BigInteger d, BigInteger n) {
        checkOperands(c, n, "ciphertext");
        int slotBits = n.bitLength();
        int chunkBits = slotBits - 1;
        BigInteger slotMask = BigInteger.ONE.shiftLeft(slotBits).subtract(BigInteger.ONE);

        BigInteger m = BigInteger.ZERO;
        int i = 0;
        while (c.compareTo(slotMask) > 0) {
            m = c.and(slotMask).modPow(d, n).shiftLeft(i * chunkBits).or(m);
            c = c.shiftRight(slotBits);
            i++;
        }
        return c.modPow(d, n).shiftLeft(i * chunkBits).or(m);
    }

    /** Rejects a modulus that cannot hold a 1-bit chunk and negative operands. */
    private static void checkOperands(BigInteger value, BigInteger n, String what) {
        if (n.compareTo(BigInteger.ONE) <= 0) {
            throw new IllegalArgumentException("modulus must be greater than 1: " + n);
        }
        if (value.signum() < 0) {
            throw new IllegalArgumentException(what + " must be non-negative: " + value);
        }
    }
}