(********************************************************************)
(*                                                                  *)
(*  pkcs1.s7i     RSA public-key cryptography standard #1 support.  *)
(*  Copyright (C) 2013 - 2015, 2019 - 2021, 2023  Thomas Mertes     *)
(*                                                                  *)
(*  This file is part of the Seed7 Runtime Library.                 *)
(*                                                                  *)
(*  The Seed7 Runtime Library is free software; you can             *)
(*  redistribute it and/or modify it under the terms of the GNU     *)
(*  Lesser General Public License as published by the Free Software *)
(*  Foundation; either version 2.1 of the License, or (at your      *)
(*  option) any later version.                                      *)
(*                                                                  *)
(*  The Seed7 Runtime Library is distributed in the hope that it    *)
(*  will be useful, but WITHOUT ANY WARRANTY; without even the      *)
(*  implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR *)
(*  PURPOSE.  See the GNU Lesser General Public License for more    *)
(*  details.                                                        *)
(*                                                                  *)
(*  You should have received a copy of the GNU Lesser General       *)
(*  Public License along with this program; if not, write to the    *)
(*  Free Software Foundation, Inc., 51 Franklin Street,             *)
(*  Fifth Floor, Boston, MA  02110-1301, USA.                       *)
(*                                                                  *)
(********************************************************************)


include "bytedata.s7i";
include "msgdigest.s7i";
include "hmac.s7i";


(**
 *  Type to describe an RSA key.
 *  The same structure is used for public keys (the exponent is e)
 *  and for private keys (the exponent is d).
 *)
const type: rsaKey is new struct
    var bigInteger: modulus  is 0_;  # n
    var bigInteger: exponent is 0_;  # public: e, private: d
    var integer: modulusLen is 0;    # k denotes the length in octets of the RSA modulus n.
  end struct;

# n  The RSA modulus, a positive integer
# e  The RSA public exponent, a positive integer
# d  The RSA private exponent, a positive integer


(**
 *  Type to describe a pair of RSA keys (private key and public key).
 *  Both elements are of type rsaKey and start out as rsaKey.value.
 *)
const type: rsaKeyPair is new struct
    var rsaKey: publicKey is rsaKey.value;
    var rsaKey: privateKey is rsaKey.value;
  end struct;


(**
 *  Create an RSA key (private or public) from ''modulus'' and ''exponent''.
 *  The length of the modulus in octets is computed from the bit length
 *  of ''modulus'' and stored in the key.
 *  @param modulus The RSA modulus n.
 *  @param exponent The public exponent e or the private exponent d.
 *  @return the new RSA key.
 *)
const func rsaKey: rsaKey (in bigInteger: modulus, in bigInteger: exponent) is func
  result
    var rsaKey: newKey is rsaKey.value;
  local
    var integer: modulusBits is 0;
  begin
    modulusBits := bitLength(modulus);
    newKey.modulus := modulus;
    newKey.exponent := exponent;
    # Round the number of bits up to whole octets.
    newKey.modulusLen := (modulusBits + 7) mdiv 8;
  end func;


(**
 *  Convert an RSA key to a [[string]] with the literal of the key.
 *  @return a string of the form "rsaKey(<modulus>_, <exponent>_)".
 *)
const func string: literal (in rsaKey: aKey) is func
  result
    var string: keyLiteral is "";
  begin
    keyLiteral := "rsaKey(" & str(aKey.modulus) & "_, ";
    keyLiteral &:= str(aKey.exponent) & "_)";
  end func;


(**
 *  Combine ''publicKey'' and ''privateKey'' to an RSA key pair.
 *  The keys are copied unchanged into the pair.
 *  @return the new RSA key pair.
 *)
const func rsaKeyPair: rsaKeyPair (in rsaKey: publicKey, in rsaKey: privateKey) is func
  result
    var rsaKeyPair: newKeyPair is rsaKeyPair.value;
  begin
    newKeyPair.privateKey := privateKey;
    newKeyPair.publicKey := publicKey;
  end func;


(**
 *  Create an RSA key pair from ''modulus'', ''publicExponent'' and ''privateExponent''.
 *  Both keys of the pair share the same modulus (and therefore the
 *  same modulus length).
 *  @return the new RSA key pair.
 *)
const func rsaKeyPair: rsaKeyPair (in bigInteger: modulus,
    in bigInteger: publicExponent, in bigInteger: privateExponent) is func
  result
    var rsaKeyPair: newKeyPair is rsaKeyPair.value;
  begin
    # The rsaKey constructor sets modulus, exponent and modulusLen.
    newKeyPair.publicKey := rsaKey(modulus, publicExponent);
    newKeyPair.privateKey := rsaKey(modulus, privateExponent);
  end func;


(**
 *  Convert an RSA key pair to a [[string]] with the literal of the pair.
 *  @return a string of the form "rsaKeyPair(<publicKey>, <privateKey>)",
 *          where both keys are written with their own literal function.
 *)
const func string: literal (in rsaKeyPair: aKeyPair) is func
  result
    var string: pairLiteral is "";
  begin
    pairLiteral := "rsaKeyPair(" & literal(aKeyPair.publicKey);
    pairLiteral &:= ", " & literal(aKeyPair.privateKey) & ")";
  end func;


(**
 *  Predefined RSA key pair built from a fixed modulus, the public
 *  exponent 65537 and a fixed private exponent.
 *  NOTE(review): The private exponent is part of this source file, so
 *  this key pair cannot provide secrecy. Presumably it is meant as a
 *  default or test key - confirm before using it anywhere else.
 *)
const rsaKeyPair: stdRsaKeyPair is rsaKeyPair(
    24879323998282062445559239376138627672251789324639932255760945448830518737190739384640974288529252990629999154107088744522353230056093690582734642121734324688998783187082956551041534076711835155928707829063256162504640988273175086145464721272768789724523888480071692284791450208699922038033347977493657812564984507698611204922456314960345290535639893058558496084848086047929946200129560169331200998008734185791703119413114661834404705794384334141821974803140053993070324815699000969159963681641979473501237806530434553142494356819076895646395931159040168081404923835928757739530245969373700156294158081824152324534461_,
    65537_,
    13088626772857606374994147660260732180049394881287449598152583688371128141673593733366670911392214849793873855002581835202125435492530910232563666189681493608607352285338757891981781465383991523965240833886856981107389901791087944521771406381777047043992471840577258748417529339084119078189629851351546974405368864243464335513384244053469543096484265307474308910575994412661107393352461835904945528516443603610779331782072425524550819129133717693244132152897414694106718806022258591593987246936431975237846760817339494993145835504324086601210101710742420357193277289205188098444617372451223173144294880256099265243073_);


(**
 *  Probabilistic primality test (Miller-Rabin).
 *  Candidates below 4 and even candidates are decided directly.
 *  For all other candidates up to ''count'' rounds with random bases
 *  from the range 2 .. primeCandidate-2 are done. The test stops at the
 *  first round that proves ''primeCandidate'' to be composite.
 *  Unlike a plain Fermat test this test is not fooled by Carmichael
 *  numbers and it never wastes a round on the bases 1 and
 *  primeCandidate-1 (which every odd number passes).
 *  @param primeCandidate The number to be tested.
 *  @param count The maximum number of test rounds. With ''count'' <= 0
 *         every odd candidate >= 5 is accepted.
 *  @return FALSE if ''primeCandidate'' is definitely not prime, and
 *          TRUE if it is prime or passed all test rounds.
 *)
const func boolean: isProbablyPrime (in bigInteger: primeCandidate, in var integer: count) is func
  result
    var boolean: isProbablyPrime is TRUE;
  local
    var bigInteger: candidateMinusOne is 0_;
    var bigInteger: oddFactor is 0_;
    var integer: powerOfTwo is 0;
    var bigInteger: base is 0_;
    var bigInteger: power is 0_;
    var integer: squarings is 0;
  begin
    if primeCandidate < 4_ then
      # Only 2 and 3 are prime in this range.
      isProbablyPrime := primeCandidate >= 2_;
    elsif not odd(primeCandidate) then
      isProbablyPrime := FALSE;
    else
      # Write primeCandidate - 1 as oddFactor * 2 ** powerOfTwo.
      candidateMinusOne := pred(primeCandidate);
      powerOfTwo := lowestSetBit(candidateMinusOne);
      oddFactor := candidateMinusOne >> powerOfTwo;
      while isProbablyPrime and count > 0 do
        base := rand(2_, pred(candidateMinusOne));
        power := modPow(base, oddFactor, primeCandidate);
        if power <> 1_ and power <> candidateMinusOne then
          # Square up to powerOfTwo - 1 times and look for -1 (mod primeCandidate).
          squarings := 1;
          while squarings < powerOfTwo and power <> candidateMinusOne and power <> 1_ do
            power := (power * power) mod primeCandidate;
            incr(squarings);
          end while;
          # Reaching 1 without passing -1 exposes a nontrivial square root of 1.
          isProbablyPrime := power = candidateMinusOne;
        end if;
        decr(count);
      end while;
    end if;
  end func;


(**
 *  Search for a probable prime with exactly ''binDigits'' binary digits.
 *  The start value is a random odd number with the two most significant
 *  bits set. This way the product of two such primes always has
 *  2 * binDigits binary digits. From the start value upward odd numbers
 *  are tested with isProbablyPrime. If the search leaves the range it
 *  continues at the lower end of the range.
 *  @param binDigits The number of binary digits of the prime (>= 2).
 *  @param count The number of test rounds used by isProbablyPrime.
 *  @return a number that passed the primality test.
 *  @exception RANGE_ERROR If ''binDigits'' is less than 2.
 *)
const func bigInteger: getProbablyPrime (in integer: binDigits, in integer: count) is func
  result
    var bigInteger: probablyPrime is 0_;
  local
    var bigInteger: lowerLimit is 0_;
    var bigInteger: upperLimit is 0_;
  begin
    if binDigits < 2 then
      raise RANGE_ERROR;
    end if;
    # lowerLimit is binary 1100...0 and upperLimit is binary 1111...1.
    lowerLimit := 3_ * 2_ ** (binDigits - 2);
    upperLimit := 2_ ** binDigits - 1_;
    probablyPrime := rand(lowerLimit, upperLimit);
    if not odd(probablyPrime) then
      # upperLimit is odd, so the incremented value stays in the range.
      incr(probablyPrime);
    end if;
    while not isProbablyPrime(probablyPrime, count) do
      probablyPrime +:= 2_;
      if probablyPrime > upperLimit then
        # Wrap around to the first odd number of the range.
        probablyPrime := lowerLimit;
        if not odd(probablyPrime) then
          incr(probablyPrime);
        end if;
      end if;
    end while;
  end func;


(**
 *  Generate a new RSA keyPair (private key and public key).
 *  Two different probable primes p and q with keyLength mdiv 2 binary
 *  digits are searched. The search is repeated until ''exponent'' is
 *  coprime to (p-1)*(q-1), because otherwise no private exponent exists.
 *  @param keyLength The requested length of the modulus in bits.
 *  @param exponent The public exponent e (an odd number >= 3,
 *         e.g. 65537_).
 *  @return a key pair with a common modulus p*q, the public exponent
 *          ''exponent'' and the corresponding private exponent.
 *  @exception RANGE_ERROR If ''keyLength'' is less than 16 or if
 *             ''exponent'' is even or less than 3.
 *)
const func rsaKeyPair: genRsaKeyPair (in integer: keyLength, in bigInteger: exponent) is func
  result
    var rsaKeyPair: keyPair is rsaKeyPair.value;
  local
    const integer: numTests is 10;
    const integer: minKeyLength is 16;
    var bigInteger: p is 0_;
    var bigInteger: q is 0_;
    var bigInteger: modulus is 0_;         # n
    var bigInteger: phiOfModulus is 0_;    # φ(n)
  begin
    # (p-1)*(q-1) is even. Therefore an even exponent can never be
    # coprime to it and the search below would not terminate.
    # Very short keys do not leave room for two different primes.
    if keyLength < minKeyLength or exponent < 3_ or not odd(exponent) then
      raise RANGE_ERROR;
    end if;
    repeat
      p := getProbablyPrime(keyLength mdiv 2, numTests);
      q := getProbablyPrime(keyLength mdiv 2, numTests);
      phiOfModulus := pred(p) * pred(q);
      # With p = q the value (p-1)*(q-1) is not φ(n) and decryption fails.
    until p <> q and gcd(exponent, phiOfModulus) = 1_;
    modulus := p * q;
    keyPair.publicKey := rsaKey(modulus, exponent);
    keyPair.privateKey := rsaKey(modulus, modInverse(exponent, phiOfModulus));
  end func;


(**
 *  Convert a nonnegative [[bigint|bigInteger]] to an octet [[string]] of a specified length (I2OSP).
 *  The conversion is delegated to the function ''bytes'' with the
 *  parameters UNSIGNED and BE (big-endian).
 *  @param number The number to be converted.
 *  @param length The requested length of the result in octets.
 *  @return an octet [[string]] with the big-endian representation.
 *  @exception RANGE_ERROR If the result does not fit into ''length''.
 *)
const func string: int2Octets (in bigInteger: number, in integer: length) is func
  result
    var string: octets is "";
  begin
    octets := bytes(number, UNSIGNED, BE, length);
  end func;


(**
 *  Convert an octet [[string]] to a nonnegative [[bigint|bigInteger]] (OS2IP).
 *  The conversion is delegated to the function ''bytes2BigInt'' with
 *  the parameters UNSIGNED and BE (big-endian).
 *  @param stri The octet string to be converted.
 *  @return a nonnegative bigInteger created from the big-endian bytes.
 *  @exception RANGE_ERROR If characters beyond '\255;' are present.
 *)
const func bigInteger: octets2int (in string: stri) is func
  result
    var bigInteger: number is 0_;
  begin
    number := bytes2BigInt(stri, UNSIGNED, BE);
  end func;


(**
 *  Encodes a [[string]] with the EME-OAEP encoding.
 *  SHA-1 is used as hash function and mgf1Sha1 as mask generation
 *  function. The data block consists of the hash of ''label'', a
 *  padding of zero octets, the octet 1 and the ''message''. The data
 *  block is masked with a mask derived from a random seed and the seed
 *  is masked with a mask derived from the masked data block.
 *  @param message The message to be encoded.
 *  @param label The label associated with the message.
 *  @param modulusLen The length of the RSA modulus in octets.
 *  @return the EME-OAEP encoded [[string]].
 *)
const func string: emeOaepEncoding (in string: message, in string: label, in integer: modulusLen) is func
  result
    var string: encodedMessage is "";
  local
    const integer: hLen is 20;  # Length (in bytes) of the sha1 output.
    var integer: paddingLen is 0;
    var string: dataBlock is "";
    var string: seed is "";
    var string: maskedDataBlock is "";
    var string: maskedSeed is "";
  begin
    paddingLen := modulusLen - length(message) - 2 * hLen - 2;
    dataBlock := sha1(label) & ("\0;" mult paddingLen) & "\1;" & message;
    # The seed is a random octet string with the length hLen.
    seed := int2Octets(rand(0_, 2_ ** (hLen * 8) - 1_), hLen);
    maskedDataBlock := dataBlock >< mgf1Sha1(seed, modulusLen - hLen - 1);
    maskedSeed := seed >< mgf1Sha1(maskedDataBlock, hLen);
    encodedMessage := "\0;" & maskedSeed & maskedDataBlock;
  end func;


(**
 *  Decodes an EME-OAEP encoded [[string]].
 *  The seed and the data block are unmasked with mgf1Sha1. Afterwards
 *  the structure of the encoded message is verified: The first octet
 *  must be zero, the data block must start with the SHA-1 hash of
 *  ''label'' and the hash must be followed by zero octets (possibly
 *  none) and the octet 1. The rest of the data block is the message.
 *  @param encodedMessage The EME-OAEP encoded message.
 *  @param label The label associated with the message.
 *  @param modulusLen The length of the RSA modulus in octets.
 *  @return the decoded string.
 *  @exception RANGE_ERROR If ''encodedMessage'' is not in the
 *                         EME-OAEP format.
 *)
const func string: emeOaepDecoding (in string: encodedMessage, in string: label, in integer: modulusLen) is func
  result
    var string: message is "";
  local
    const integer: hLen is 20;  # Length (in bytes) of the sha1 output.
    var string: maskedSeed is "";
    var string: maskedDb is "";
    var string: seed is "";
    var string: db is "";
    var integer: index is 0;
  begin
    # An encoded message consists of 1 + hLen + (modulusLen - hLen - 1) octets.
    if length(encodedMessage) <> modulusLen or modulusLen < 2 * hLen + 2 then
      raise RANGE_ERROR;
    end if;
    maskedSeed := encodedMessage[2 len hLen];
    maskedDb := encodedMessage[hLen + 2 ..];  # length: modulusLen - hLen - 1;
    seed := maskedSeed >< mgf1Sha1(maskedDb, hLen);
    db := maskedDb >< mgf1Sha1(seed, modulusLen - hLen - 1);
    # The leading octet must be zero and the label hash must match.
    if encodedMessage[1] <> '\0;' or db[.. hLen] <> sha1(label) then
      raise RANGE_ERROR;
    end if;
    # Skip the zero padding. Only zero octets are allowed as padding.
    index := succ(hLen);
    while index <= length(db) and db[index] = '\0;' do
      incr(index);
    end while;
    # The padding must be terminated by the octet 1.
    if index > length(db) or db[index] <> '\1;' then
      raise RANGE_ERROR;
    end if;
    message := db[succ(index) ..];
  end func;


(**
 *  Encodes a [[string]] with the EME-PKCS1-v1_5 encoding.
 *  The result consists of the octets 0 and 2, a padding of random
 *  nonzero octets, a zero octet and the ''message''. The padding has
 *  the length modulusLen - length(message) - 3 (no padding is added
 *  if this value is not positive).
 *  @param message The message to be encoded.
 *  @param modulusLen The length of the RSA modulus in octets.
 *  @return the EME-PKCS1-v1_5 encoded [[string]].
 *)
const func string: emePkcs1V15Encoding (in string: message, in integer: modulusLen) is func
  result
    var string: encodedMessage is "";
  local
    var integer: paddingLen is 0;
    var string: padding is "";
  begin
    paddingLen := modulusLen - length(message) - 3;
    while length(padding) < paddingLen do
      # A zero octet would terminate the padding, so it is excluded.
      padding &:= rand('\1;', '\255;');
    end while;
    encodedMessage := "\0;\2;" & padding & "\0;" & message;
  end func;


(**
 *  Decodes an EME-PKCS1-v1_5 encoded [[string]].
 *  The encoded message must start with the octets 0 and 2, followed
 *  by at least 8 nonzero padding octets and a zero octet. Everything
 *  after this zero octet is the message. The minimum padding length
 *  is the same that emsaPkcs1V15Decoding enforces.
 *  @param encodedMessage The EME-PKCS1-v1_5 encoded message.
 *  @param modulusLen The length of the RSA modulus in octets
 *         (currently not used by the decoding).
 *  @return the decoded string.
 *  @exception RANGE_ERROR If ''encodedMessage'' is not in the
 *                         EME-PKCS1-v1_5 format.
 *)
const func string: emePkcs1V15Decoding (in string: encodedMessage, in integer: modulusLen) is func
  result
    var string: message is "";
  local
    var integer: separatorPos is 0;
  begin
    if not startsWith(encodedMessage, "\0;\2;") then
      # writeln("encodedMessage: " <& hex(encodedMessage));
      raise RANGE_ERROR;
    else
      # Position of the zero octet relative to the start of the padding.
      separatorPos := pos(encodedMessage[3 ..], '\0;');
      # A missing separator (0) or less than 8 padding octets is an error.
      if separatorPos < 9 then
        raise RANGE_ERROR;
      else
        message := encodedMessage[separatorPos + 3 ..];
      end if;
    end if;
  end func;


(**
 *  Encodes a [[string]] with the EMSA-PKCS1-v1_5 encoding.
 *  The result consists of the octets 0 and 1, a padding of '\255;'
 *  octets, a zero octet and the ''message''. The padding has the
 *  length modulusLen - length(message) - 3.
 *  @param message The message to be encoded.
 *  @param modulusLen The length of the RSA modulus in octets.
 *  @return the EMSA-PKCS1-v1_5 encoded [[string]].
 *)
const func string: emsaPkcs1V15Encoding (in string: message, in integer: modulusLen) is func
  result
    var string: encodedMessage is "";
  local
    var integer: paddingLen is 0;
  begin
    paddingLen := modulusLen - length(message) - 3;
    encodedMessage := "\0;\1;" & ("\255;" mult paddingLen);
    encodedMessage &:= "\0;" & message;
  end func;


(**
 *  Decodes an EMSA-PKCS1-v1_5 encoded [[string]].
 *  The encoded message must start with the octets 0 and 1, followed
 *  by at least 8 padding octets with the value '\255;' and a zero
 *  octet. Everything after this zero octet is the message.
 *  @param encodedMessage The EMSA-PKCS1-v1_5 encoded message.
 *  @return the decoded string.
 *  @exception RANGE_ERROR If ''encodedMessage'' is not in the
 *                         EMSA-PKCS1-v1_5 format.
 *)
const func string: emsaPkcs1V15Decoding (in string: encodedMessage) is func
  result
    var string: message is "";
  local
    var integer: separatorPos is 0;
    var integer: paddingLen is 0;
  begin
    if startsWith(encodedMessage, "\0;\1;") then
      # Position of the zero octet relative to the start of the padding.
      separatorPos := pos(encodedMessage[3 ..], '\0;');
      paddingLen := pred(separatorPos);
      if separatorPos >= 9 and
          encodedMessage[3 len paddingLen] = "\255;" mult paddingLen then
        message := encodedMessage[separatorPos + 3 ..];
      else
        raise RANGE_ERROR;
      end if;
    else
      raise RANGE_ERROR;
    end if;
  end func;


(**
 *  Encrypts a [[bigint|bigInteger]] with the RSAES encryption.
 *  The encryption primitive RSAEP computes
 *  message ** exponent mod modulus with the exponent and the modulus
 *  of ''encryptionKey''.
 *  @param encryptionKey The key used for the encryption.
 *  @param message The message representative, an integer between
 *         0 and modulus - 1.
 *  @return the RSAES encrypted [[bigint|bigInteger]].
 *  @exception RANGE_ERROR If ''message'' is negative or too big for
 *                         the RSAEP encryption.
 *)
const func bigInteger: rsaEncrypt (in rsaKey: encryptionKey, in bigInteger: message) is func
  result
    var bigInteger: ciphertext is 0_;
  begin
    # The message representative must be in the range 0 .. modulus - 1.
    if message < 0_ or message >= encryptionKey.modulus then
      raise RANGE_ERROR;
    else
      ciphertext := modPow(message, encryptionKey.exponent, encryptionKey.modulus);
    end if;
  end func;


(**
 *  Decrypts a [[bigint|bigInteger]] with the RSADP decryption.
 *  The decryption primitive RSADP computes
 *  ciphertext ** exponent mod modulus with the exponent and the
 *  modulus of ''decryptionKey''.
 *  @param decryptionKey The key used for the decryption.
 *  @param ciphertext The ciphertext representative, an integer between
 *         0 and modulus - 1.
 *  @return the RSADP decrypted [[bigint|bigInteger]].
 *  @exception RANGE_ERROR If ''ciphertext'' is negative or too big for
 *                         the RSADP decryption.
 *)
const func bigInteger: rsaDecrypt (in rsaKey: decryptionKey, in bigInteger: ciphertext) is func
  result
    var bigInteger: message is 0_;
  begin
    # writeln("rsaDecrypt: modulus    = " <& decryptionKey.modulus radix 16);
    # writeln("rsaDecrypt: exponent   = " <& decryptionKey.exponent radix 16);
    # writeln("rsaDecrypt: ciphertext = " <& ciphertext radix 16);
    # The ciphertext representative must be in the range 0 .. modulus - 1.
    if ciphertext < 0_ or ciphertext >= decryptionKey.modulus then
      raise RANGE_ERROR;
    else
      message := modPow(ciphertext, decryptionKey.exponent, decryptionKey.modulus);
    end if;
  end func;


(**
 *  Encrypts a [[string]] of bytes with the RSAES-OAEP encryption.
 *  The ''message'' is encoded with emeOaepEncoding, converted to a
 *  number, encrypted with rsaEncrypt and converted back to an octet
 *  string with the length of the modulus.
 *  @param encryptionKey The key used for the encryption.
 *  @param message The message to be encrypted. At most
 *         modulusLen - 42 bytes are possible.
 *  @param label The label associated with the message.
 *  @return the RSAES-OAEP encrypted [[string]] of bytes.
 *  @exception RANGE_ERROR If the length of the ''message'' is too long
 *                         to be encrypted.
 *  @exception RANGE_ERROR If ''message'' contains characters beyond '\255;'.
 *)
const func string: rsaesOaepEncrypt (in rsaKey: encryptionKey, in string: message,
    in string: label) is func
  result
    var string: encryptedMessage is "";
  local
    const integer: hLen is 20;  # Length (in bytes) of the sha1 output.
    var integer: maxMessageLen is 0;
    var bigInteger: plainNumber is 0_;
    var bigInteger: cipherNumber is 0_;
  begin
    maxMessageLen := encryptionKey.modulusLen - 2 * hLen - 2;
    if length(message) <= maxMessageLen then
      plainNumber := octets2int(emeOaepEncoding(message, label, encryptionKey.modulusLen));
      cipherNumber := rsaEncrypt(encryptionKey, plainNumber);
      encryptedMessage := int2Octets(cipherNumber, encryptionKey.modulusLen);
    else
      raise RANGE_ERROR;
    end if;
  end func;


(**
 *  Decrypts a [[string]] of bytes with the RSAES-OAEP decryption.
 *  The ''ciphertext'' is converted to a number, decrypted with
 *  rsaDecrypt, converted back to an octet string with the length of
 *  the modulus and decoded with emeOaepDecoding.
 *  @param decryptionKey The key used for the decryption.
 *  @param ciphertext The encrypted message. Its length must be equal
 *         to the length of the modulus in octets.
 *  @param label The label associated with the message.
 *  @return the RSAES-OAEP decrypted [[string]] of bytes.
 *  @exception RANGE_ERROR If the length of the ''ciphertext'' differs
 *                         from the length of the modulus or if the
 *                         modulus is shorter than 42 octets.
 *  @exception RANGE_ERROR If ''ciphertext'' contains characters beyond '\255;'.
 *)
const func string: rsaesOaepDecrypt (in rsaKey: decryptionKey, in string: ciphertext,
    in string: label) is func
  result
    var string: message is "";
  local
    const integer: hLen is 20;  # Length (in bytes) of the sha1 output.
    var bigInteger: cipherNumber is 0_;
    var bigInteger: plainNumber is 0_;
  begin
    if length(ciphertext) = decryptionKey.modulusLen and
        decryptionKey.modulusLen >= 2 * hLen + 2 then
      cipherNumber := octets2int(ciphertext);
      plainNumber := rsaDecrypt(decryptionKey, cipherNumber);
      message := emeOaepDecoding(int2Octets(plainNumber, decryptionKey.modulusLen),
                                 label, decryptionKey.modulusLen);
    else
      # writeln("length(ciphertext): " <& length(ciphertext));
      # writeln("modulusLen: " <& decryptionKey.modulusLen);
      raise RANGE_ERROR;
    end if;
  end func;


(**
 *  Encrypts a [[string]] of bytes with the RSAES-PKCS1-V1_5 encryption.
 *  The ''message'' is encoded with emePkcs1V15Encoding, converted to a
 *  number, encrypted with rsaEncrypt and converted back to an octet
 *  string with the length of the modulus.
 *  The encoding adds 3 fixed octets and at least 8 padding octets.
 *  Therefore a message can have at most modulusLen - 11 bytes.
 *  @param encryptionKey The key used for the encryption.
 *  @param message The message to be encrypted.
 *  @return the RSAES-PKCS1-V1_5 encrypted [[string]] of bytes.
 *  @exception RANGE_ERROR If the length of the ''message'' is too long
 *                         to be encrypted.
 *  @exception RANGE_ERROR If ''message'' contains characters beyond '\255;'.
 *)
const func string: rsaesPkcs1V15Encrypt (in rsaKey: encryptionKey, in string: message) is func
  result
    var string: encryptedMessage is "";
  local
    const integer: encodingOverhead is 11;  # 3 fixed octets + at least 8 padding octets.
    var string: encodedMessage is "";
  begin
    if length(message) > encryptionKey.modulusLen - encodingOverhead then
      # writeln("length(message): " <& length(message));
      # writeln("modulusLen: " <& encryptionKey.modulusLen);
      raise RANGE_ERROR;
    else
      encodedMessage := emePkcs1V15Encoding(message, encryptionKey.modulusLen);
      encryptedMessage := int2Octets(rsaEncrypt(encryptionKey, octets2int(encodedMessage)),
                                     encryptionKey.modulusLen);
    end if;
  end func;


(**
 *  Decrypts a [[string]] of bytes with the RSAES-PKCS1-V1_5 decryption.
 *  The ''ciphertext'' is converted to a number, decrypted with
 *  rsaDecrypt, converted back to an octet string with the length of
 *  the modulus and decoded with emePkcs1V15Decoding.
 *  @param decryptionKey The key used for the decryption.
 *  @param ciphertext The encrypted message. Its length must be equal
 *         to the length of the modulus in octets.
 *  @return the RSAES-PKCS1-V1_5 decrypted [[string]] of bytes.
 *  @exception RANGE_ERROR If the length of the ''ciphertext'' differs
 *                         from the length of the modulus or if the
 *                         modulus is shorter than 11 octets.
 *  @exception RANGE_ERROR If ''ciphertext'' contains characters beyond '\255;'.
 *)
const func string: rsaesPkcs1V15Decrypt (in rsaKey: decryptionKey, in string: ciphertext) is func
  result
    var string: message is "";
  local
    const integer: encodingOverhead is 11;  # 3 fixed octets + at least 8 padding octets.
    var string: encodedMessage is "";
  begin
    if length(ciphertext) <> decryptionKey.modulusLen or
        decryptionKey.modulusLen < encodingOverhead then
      # writeln("length(ciphertext): " <& length(ciphertext));
      # writeln("modulusLen: " <& decryptionKey.modulusLen);
      raise RANGE_ERROR;
    else
      encodedMessage := int2Octets(rsaDecrypt(decryptionKey, octets2int(ciphertext)),
                                   decryptionKey.modulusLen);
      message := emePkcs1V15Decoding(encodedMessage, decryptionKey.modulusLen);
    end if;
  end func;


(**
 *  Encrypts a [[string]] of bytes with the RSASSA-PKCS1-V1_5 encryption.
 *  The ''message'' is encoded with emsaPkcs1V15Encoding, converted to
 *  a number, encrypted with rsaEncrypt and converted back to an octet
 *  string with the length of the modulus.
 *  The encoding adds 3 fixed octets and at least 8 padding octets.
 *  Therefore a message can have at most modulusLen - 11 bytes.
 *  @param encryptionKey The key used for the encryption.
 *  @param message The message to be encrypted.
 *  @return the RSASSA-PKCS1-V1_5 encrypted [[string]] of bytes.
 *  @exception RANGE_ERROR If the length of the ''message'' is too long
 *                         to be encrypted.
 *  @exception RANGE_ERROR If ''message'' contains characters beyond '\255;'.
 *)
const func string: rsassaPkcs1V15Encrypt (in rsaKey: encryptionKey, in string: message) is func
  result
    var string: encryptedMessage is "";
  local
    const integer: encodingOverhead is 11;  # 3 fixed octets + at least 8 padding octets.
    var string: encodedMessage is "";
  begin
    if length(message) > encryptionKey.modulusLen - encodingOverhead then
      # writeln("length(message): " <& length(message));
      # writeln("modulusLen: " <& encryptionKey.modulusLen);
      raise RANGE_ERROR;
    else
      encodedMessage := emsaPkcs1V15Encoding(message, encryptionKey.modulusLen);
      encryptedMessage := int2Octets(rsaEncrypt(encryptionKey, octets2int(encodedMessage)),
                                     encryptionKey.modulusLen);
    end if;
  end func;


(**
 *  Decrypts a [[string]] of bytes with the RSASSA-PKCS1-V1_5 decryption.
 *  The ''ciphertext'' is converted to a number, decrypted with
 *  rsaDecrypt, converted back to an octet string with the length of
 *  the modulus and decoded with emsaPkcs1V15Decoding.
 *  @param decryptionKey The key used for the decryption.
 *  @param ciphertext The encrypted message. Its length must be equal
 *         to the length of the modulus in octets.
 *  @return the RSASSA-PKCS1-V1_5 decrypted [[string]] of bytes.
 *  @exception RANGE_ERROR If the length of the ''ciphertext'' differs
 *                         from the length of the modulus or if the
 *                         modulus is shorter than 11 octets.
 *  @exception RANGE_ERROR If ''ciphertext'' contains characters beyond '\255;'.
 *)
const func string: rsassaPkcs1V15Decrypt (in rsaKey: decryptionKey, in string: ciphertext) is func
  result
    var string: message is "";
  local
    const integer: encodingOverhead is 11;  # 3 fixed octets + at least 8 padding octets.
    var string: encodedMessage is "";
  begin
    if length(ciphertext) <> decryptionKey.modulusLen or
        decryptionKey.modulusLen < encodingOverhead then
      # writeln("length(ciphertext): " <& length(ciphertext));
      # writeln("modulusLen: " <& decryptionKey.modulusLen);
      raise RANGE_ERROR;
    else
      encodedMessage := int2Octets(rsaDecrypt(decryptionKey, octets2int(ciphertext)),
                                   decryptionKey.modulusLen);
      message := emsaPkcs1V15Decoding(encodedMessage);
    end if;
  end func;