(********************************************************************)
(*                                                                  *)
(*  aes_gcm.s7i   AES cipher with Galois Counter Mode (GCM)         *)
(*  Copyright (C) 2023, 2024  Thomas Mertes                         *)
(*                                                                  *)
(*  This file is part of the Seed7 Runtime Library.                 *)
(*                                                                  *)
(*  The Seed7 Runtime Library is free software; you can             *)
(*  redistribute it and/or modify it under the terms of the GNU     *)
(*  Lesser General Public License as published by the Free Software *)
(*  Foundation; either version 2.1 of the License, or (at your      *)
(*  option) any later version.                                      *)
(*                                                                  *)
(*  The Seed7 Runtime Library is distributed in the hope that it    *)
(*  will be useful, but WITHOUT ANY WARRANTY; without even the      *)
(*  implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR *)
(*  PURPOSE.  See the GNU Lesser General Public License for more    *)
(*  details.                                                        *)
(*                                                                  *)
(*  You should have received a copy of the GNU Lesser General       *)
(*  Public License along with this program; if not, write to the    *)
(*  Free Software Foundation, Inc., 51 Franklin Street,             *)
(*  Fifth Floor, Boston, MA  02110-1301, USA.                       *)
(*                                                                  *)
(********************************************************************)

include "bin32.s7i";
include "bin64.s7i";
include "bigint.s7i";
include "aes.s7i";


# Table with 16 entries (indices 0 .. 15) of 64-bit values.
# It holds one 64-bit half of each of the 16 precalculated 128-bit values
# that computeFactorH() derives from H and that gcmMult() looks up.
const type: halfHTableType is array [0 .. 15] bin64;

# Precalculated tables for the multiplication with the GHASH factor H.
# Both tables are indexed by a 4-bit value (one nibble of the other factor).
# The tables are filled by computeFactorH() and read by gcmMult().
const type: factorHType is new struct
    var halfHTableType: lowHalfHTable  is halfHTableType.value;  # Lower 64 bits of every entry
    var halfHTableType: highHalfHTable is halfHTableType.value;  # Upper 64 bits of every entry
  end struct;


##
#  Precalculate the lookup tables used by gcmMult() from the hash key ''h''.
#  The first 8 bytes of ''h'' are taken as the upper and the next 8 bytes as
#  the lower 64 bits of a 128-bit value (both big-endian).
#  Slot 8 of the tables receives this value and slot 0 receives zero.
#  The slots 4, 2 and 1 are obtained by repeatedly shifting the 128-bit
#  value one bit to the right. Whenever a set bit is shifted out, the
#  constant 16#e1 is XORed into the topmost byte.
#  All remaining slots are the XOR of two slots that have been filled before.
#  @param h The hash key (at least 16 bytes are read from it).
#  @return the precalculated tables for the factor H.
#
const func factorHType: computeFactorH (in string: h) is func
  result
    var factorHType: factorH is factorHType.value;
  local
    var bin64: highBits is bin64(0);
    var bin64: lowBits is bin64(0);
    var bin64: reduction is bin64(0);
    var integer: slot is 0;
    var integer: target is 0;
  begin
    highBits := bin64(h[1 fixLen 8], BE);
    lowBits  := bin64(h[9 fixLen 8], BE);

    # Slot 0 corresponds to 0 in GF(2**128).
    factorH.highHalfHTable[0] := bin64(0);
    factorH.lowHalfHTable[0]  := bin64(0);
    # Slot 8 = 2#1000 corresponds to 1 in GF(2**128).
    factorH.highHalfHTable[8] := highBits;
    factorH.lowHalfHTable[8]  := lowBits;

    # Fill the slots with exactly one bit set by halving the previous value.
    for slot range [](4, 2, 1) do
      if (lowBits & bin64(1)) <> bin64(0) then
        # The bit that drops out at the right must be folded back in.
        reduction := bin64(16#e1000000) << 32;
      else
        reduction := bin64(0);
      end if;
      lowBits  := (highBits << 63) | (lowBits >> 1);
      highBits := (highBits >> 1) >< reduction;
      factorH.highHalfHTable[slot] := highBits;
      factorH.lowHalfHTable[slot]  := lowBits;
    end for;

    # Every other slot combines a single-bit slot with a smaller slot.
    for slot range [](2, 4, 8) do
      highBits := factorH.highHalfHTable[slot];
      lowBits  := factorH.lowHalfHTable[slot];
      for target range succ(slot) to pred(2 * slot) do
        factorH.highHalfHTable[target] := highBits >< factorH.highHalfHTable[target - slot];
        factorH.lowHalfHTable[target]  := lowBits >< factorH.lowHalfHTable[target - slot];
      end for;
    end for;
  end func;


# Correction values used by gcmMult() when the 128-bit accumulator is
# shifted right by 4 bits. The table is indexed by the 4 bits that have
# been shifted out of the low half (0 .. 15). The selected entry is XORed
# into the high half. Only the topmost 16 bits of an entry are non-zero.
# NOTE(review): The entries look like the multiples of the reduction
# constant that computeFactorH() uses (16#e1 in the topmost byte) - confirm.
# The literals with a trailing underscore are bigInteger literals.
# Presumably they are needed because these values exceed the integer range.
const array bin64: last4 is [0] (
    bin64(16#0000000000000000),  bin64(16#1c20000000000000),
    bin64(16#3840000000000000),  bin64(16#2460000000000000),
    bin64(16#7080000000000000),  bin64(16#6ca0000000000000),
    bin64(16#48c0000000000000),  bin64(16#54e0000000000000),
    bin64(16#e100000000000000_), bin64(16#fd20000000000000_),
    bin64(16#d940000000000000_), bin64(16#c560000000000000_),
    bin64(16#9180000000000000_), bin64(16#8da0000000000000_),
    bin64(16#a9c0000000000000_), bin64(16#b5e0000000000000_));


##
#  Multiply a 128-bit vector with the precalculated factor H.
#  The vector is consumed in 4-bit steps, starting with the low nibble of
#  byte 16 and ending with the high nibble of byte 1. The low nibble of
#  byte 16 selects the initial table entry. Each of the following 31
#  nibbles shifts the accumulator right by 4 bits, applies the correction
#  from ''last4'' for the bits shifted out, and XORs in the table entry
#  selected by the nibble.
#  @param factor1 The 128-bit vector (16 bytes are read from it).
#  @param factorH The tables created with computeFactorH().
#  @return the product as string of 16 bytes (big-endian).
#
const func string: gcmMult (in string: factor1,               # 128-bit factor 1 vector
                            in factorHType: factorH) is func  # Precalculated factor H
  result
    var string: product is "";  # 128-bit product vector
  local
    var integer: step is 0;
    var integer: byteValue is 0;
    var integer: nibble is 0;
    var integer: shiftedOut is 0;
    var bin64: highHalf is bin64(0);
    var bin64: lowHalf is bin64(0);
  begin
    # The low nibble of the last byte selects the start value.
    nibble := ord(factor1[16]) mod 16;
    highHalf := factorH.highHalfHTable[nibble];
    lowHalf := factorH.lowHalfHTable[nibble];

    # Odd steps take the high nibble, even steps the low nibble of the
    # byte at position 16 - (step div 2).
    for step range 1 to 31 do
      byteValue := ord(factor1[16 - (step div 2)]);
      if odd(step) then
        nibble := byteValue >> 4;
      else
        nibble := byteValue mod 16;
      end if;

      shiftedOut := ord(lowHalf & bin64(16#f));
      lowHalf := highHalf << 60 | lowHalf >> 4;
      highHalf := highHalf >> 4;
      highHalf ><:= last4[shiftedOut];
      highHalf ><:= factorH.highHalfHTable[nibble];
      lowHalf ><:= factorH.lowHalfHTable[nibble];
    end for;
    product := bytes(highHalf, BE, 8) & bytes(lowHalf, BE, 8);
  end func;


(**
 *  [[cipher|cipherState]] implementation type describing the state of an AES GCM cipher.
 *  The data is encrypted / decrypted with the AES (Advanced Encryption
 *  Standard) cipher and Galois Counter Mode (GCM).
 *  A state is created with ''setAesGcmKey''. The functions ''encode'' and
 *  ''decode'' update ''nonce'', ''encodedNonce1'' and ''computedMac''
 *  while they process a message.
 *)
const type: aesGcmState is new struct
    var string: salt is "";                         # Comes from the initialization vector
    var string: nonce is " " mult 16;               # Current salt|explicit_nonce|counter value
    var aesState: aesCipherState is aesState.value; # Only encryptionSubKey and rounds are used
    var factorHType: factorH is factorHType.value;  # Precalculated factor H for GHASH operation
    var integer: sequenceNumber is 0;               # Set by initAead, used to initialize the computedMac
    var string: recordTypeAndVersion is "";         # Set by initAead, used to initialize the computedMac
    var string: encodedNonce1 is "";                # Encoded nonce with a counter of 1
    var string: computedMac is "\0;" mult 16;       # The computed AEAD authentication tag (=MAC)
    var string: mac is "";                          # MAC appended to the encrypted data (set by decode)
  end struct;


# Register aesGcmState as implementation of the cipherState interface.
type_implements_interface(aesGcmState, cipherState);


(**
 *  Block size used by the AES GCM cipher.
 *  AES is the Advanced Encryption Standard cipher.
 *  GCM is the Galois Counter Mode.
 *  The value is always 0.
 *  NOTE(review): Presumably 0 tells callers that the data needs no padding
 *  to a block boundary (encode and decode accept data of any length) - confirm.
 *  @return the block size used by the AES GCM cipher (0).
 *)
const func integer: blockSize (AES_GCM) is 0;


(**
 *  Set key and initialization vector for the AES GCM cipher.
 *  AES is the Advanced Encryption Standard cipher.
 *  GCM is the Galois Counter Mode.
 *  The hash key for the authentication tag is obtained by encrypting a
 *  block of 16 zero bytes with the AES key. The lookup tables for the
 *  multiplication with this hash key are precalculated here.
 *  @param aesKey The key to be used for AES GCM.
 *  @param initializationVector The salt used to create the nonce.
 *  @return the AES GCM cipher state.
 *)
const func aesGcmState: setAesGcmKey (in string: aesKey,
    in string: initializationVector) is func
  result
    var aesGcmState: state is aesGcmState.value;
  local
    const string: zeroBlock is "\0;" mult 16;
    var string: hashKey is "";
  begin
    state.salt := initializationVector;
    state.aesCipherState := setAesKey(aesKey, "");
    hashKey := encodeAesBlock(state.aesCipherState.encryptionSubKey,
                              state.aesCipherState.rounds, zeroBlock);
    state.factorH := computeFactorH(hashKey);
  end func;


(**
 *  Set key and initialization vector for the AES GCM cipher.
 *  AES is the Advanced Encryption Standard cipher.
 *  GCM is the Galois Counter Mode.
 *  The state is created with ''setAesGcmKey'' and converted to the
 *  ''cipherState'' interface.
 *  @param cipherKey The key to be used for AES GCM.
 *  @param initializationVector The salt used to create the nonce.
 *  @return the initial ''cipherState'' of a AES GCM cipher.
 *)
const func cipherState: setCipherKey (AES_GCM, in string: cipherKey,
    in string: initializationVector) is
  return toInterface(setAesGcmKey(cipherKey, initializationVector));


(**
 *  Initialize the authenticated encryption with associated data (AEAD).
 *  The two values are just stored in the ''state''. The next ''encode''
 *  or ''decode'' uses them, together with the length of the data, as start
 *  value of state.computedMac (MAC stands for message authentication code).
 *  @param state The AES GCM cipher state that receives the values.
 *  @param recordTypeAndVersion The record type and version (3 bytes).
 *  @param sequenceNumber The sequence number (stored with 8 bytes in the MAC).
 *)
const proc: initAead (inout aesGcmState: state, in string: recordTypeAndVersion,
    in integer: sequenceNumber) is func
  begin
    state.sequenceNumber := sequenceNumber;
    state.recordTypeAndVersion := recordTypeAndVersion;
  end func;


(**
 *  Obtain the computed MAC of data that has been decrypted with the AES GCM cipher.
 *  The value is state.computedMac, which ''encode'' and ''decode''
 *  compute while they process data. A new state holds 16 zero bytes.
 *  After a successful decryption getComputedMac and getMac should return
 *  the same value.
 *  @return the computed AEAD authentication tag of the ''state''.
 *)
const func string: getComputedMac (in aesGcmState: state) is
  return state.computedMac;


(**
 *  Obtain the MAC that is appended to the encrypted data of the AES GCM cipher.
 *  The value is state.mac, which ''decode'' takes from the last 16 bytes
 *  of the data it processes. A new state holds an empty string.
 *  After a successful decryption getComputedMac and getMac should return
 *  the same value.
 *  @return the MAC that has been received with the encrypted data.
 *)
const func string: getMac (in aesGcmState: state) is
  return state.mac;


##
#  Increment the counter of the salt|explicit_nonce|counter nonce.
#  The counter occupies the bytes 13 to 16 of the nonce (big-endian).
#  A carry never propagates into byte 12: If all four counter bytes
#  are '\255;' the counter wraps around to zero.
#
const proc: incrementGcmNonceCounter (inout string: nonce) is func
  local
    var integer: pos is 16;
    var boolean: carry is TRUE;
  begin
    while carry and pos >= 13 do
      if nonce[pos] = '\255;' then
        # This byte overflows: Reset it and carry on with the next byte.
        nonce @:= [pos] '\0;';
        decr(pos);
      else
        nonce @:= [pos] succ(nonce[pos]);
        carry := FALSE;
      end if;
    end while;
  end func;


##
#  Encrypt ''plainText'' in counter mode and add the result to the MAC.
#  The data is processed in chunks of up to 16 bytes. For every chunk the
#  counter in state.nonce is incremented and the nonce is encrypted with
#  AES. The chunk is XORed with the encrypted nonce. The encrypted bytes
#  are XORed into state.computedMac, which is then multiplied with the
#  factor H.
#  @param state The cipher state (nonce and computedMac are changed).
#  @param plainText The data to be encrypted.
#  @return the encrypted data (with the same length as ''plainText'').
#
const func string: aesGcmEncode (inout aesGcmState: state,
    in string: plainText) is func
  result
    var string: encoded is "";
  local
    var integer: remaining is 0;     # Number of bytes that still need to be processed
    var integer: chunkStart is 1;    # Position of the current chunk in plainText
    var integer: chunkSize is 0;     # Size of the current chunk (at most 16 bytes)
    var integer: offset is 0;        # Zero based position inside the current chunk
    var string: keyStream is "";     # Encrypted nonce that is XORed with the chunk
    var bin32: cipherByte is bin32(0);
  begin
    # writeln("aesGcmEncode plainText  (length=" <& length(plainText) <& "): " <& literal(plainText));
    remaining := length(plainText);
    while remaining > 0 do
      if remaining < 16 then
        chunkSize := remaining;
      else
        chunkSize := 16;
      end if;

      incrementGcmNonceCounter(state.nonce);
      keyStream := encodeAesBlock(state.aesCipherState.encryptionSubKey,
                                  state.aesCipherState.rounds, state.nonce);

      for offset range 0 to pred(chunkSize) do
        cipherByte := bin32(keyStream[succ(offset)]) >< bin32(plainText[chunkStart + offset]);
        encoded &:= char(ord(cipherByte));
        # The authentication hash (computedMac) is computed with the encoded data.
        state.computedMac @:= [succ(offset)]
            char(ord(bin32(state.computedMac[succ(offset)]) >< cipherByte));
      end for;
      state.computedMac := gcmMult(state.computedMac, state.factorH);

      chunkStart +:= chunkSize;
      remaining  -:= chunkSize;
    end while;
    # writeln("aesGcmEncode encoded (length=" <& length(encoded) <& "): " <& hex(encoded));
  end func;


##
#  Decrypt ''encoded'' in counter mode and add the encrypted data to the MAC.
#  The data is processed in chunks of up to 16 bytes. For every chunk the
#  counter in state.nonce is incremented and the nonce is encrypted with
#  AES. The encrypted bytes are XORed into state.computedMac, which is
#  then multiplied with the factor H. The plain text is the XOR of the
#  chunk with the encrypted nonce.
#  @param state The cipher state (nonce and computedMac are changed).
#  @param encoded The encrypted data (without explicit nonce and MAC).
#  @return the decrypted data (with the same length as ''encoded'').
#
const func string: aesGcmDecode (inout aesGcmState: state,
    in string: encoded) is func
  result
    var string: plainText is "";
  local
    var integer: remaining is 0;     # Number of bytes that still need to be processed
    var integer: chunkStart is 1;    # Position of the current chunk in encoded
    var integer: chunkSize is 0;     # Size of the current chunk (at most 16 bytes)
    var integer: offset is 0;        # Zero based position inside the current chunk
    var string: keyStream is "";     # Encrypted nonce that is XORed with the chunk
    var bin32: cipherByte is bin32(0);
  begin
    # writeln("aesGcmDecode encoded  (length=" <& length(encoded) <& "): " <& hex(encoded));
    remaining := length(encoded);
    while remaining > 0 do
      if remaining < 16 then
        chunkSize := remaining;
      else
        chunkSize := 16;
      end if;

      incrementGcmNonceCounter(state.nonce);
      keyStream := encodeAesBlock(state.aesCipherState.encryptionSubKey,
                                  state.aesCipherState.rounds, state.nonce);

      for offset range 0 to pred(chunkSize) do
        cipherByte := bin32(encoded[chunkStart + offset]);
        # The authentication hash (computedMac) is computed with the encoded data.
        state.computedMac @:= [succ(offset)]
            char(ord(bin32(state.computedMac[succ(offset)]) >< cipherByte));
        plainText &:= char(ord(bin32(keyStream[succ(offset)]) >< cipherByte));
      end for;
      state.computedMac := gcmMult(state.computedMac, state.factorH);

      chunkStart +:= chunkSize;
      remaining  -:= chunkSize;
    end while;
    # writeln("aesGcmDecode plainText (length=" <& length(plainText) <& "): " <& literal(plainText));
  end func;


##
#  Start the computation of the MAC for a new message.
#  The current nonce (which has a counter of 1 when this is called by
#  encode or decode) is encrypted and kept in state.encodedNonce1 for
#  finalizeComputedMac(). The additional AEAD data consists of the
#  sequence number (8 bytes), the record type and version and the length
#  of the plain text (2 bytes). It is followed by 3 zero bytes and
#  multiplied with the factor H to get the start value of state.computedMac.
#  @param state The cipher state (encodedNonce1 and computedMac are set).
#  @param plainTextLength Length of the plain text in bytes.
#
const proc: initializeComputedMac (inout aesGcmState: state, in integer: plainTextLength) is func
  local
    var string: aeadBlock is "";
  begin
    state.encodedNonce1 := encodeAesBlock(state.aesCipherState.encryptionSubKey,
                                          state.aesCipherState.rounds, state.nonce);
    aeadBlock := bytes(state.sequenceNumber, UNSIGNED, BE, 8);
    aeadBlock &:= state.recordTypeAndVersion;  # 3 bytes
    aeadBlock &:= bytes(plainTextLength, UNSIGNED, BE, 2);
    aeadBlock &:= "\0;" mult 3;
    state.computedMac := gcmMult(aeadBlock, state.factorH);
  end func;


##
#  Finish the computation of the MAC in state.computedMac.
#  A block with the length of the additional AEAD data (13 bytes) and the
#  length of the plain text, both in bits and stored with 8 bytes each,
#  is XORed into state.computedMac. The result is multiplied with the
#  factor H and finally XORed with state.encodedNonce1.
#  @param state The cipher state (computedMac is changed).
#  @param plainTextLength Length of the plain text in bytes.
#
const proc: finalizeComputedMac (inout aesGcmState: state, in integer: plainTextLength) is func
  local
    const integer: aeadDataLength is 13;  # Length of additional AEAD data
    var string: bitLengths is "";
    var integer: pos is 0;
    var bin32: mixedByte is bin32(0);
  begin
    bitLengths := bytes(aeadDataLength * 8, UNSIGNED, BE, 8) &
                  bytes(plainTextLength * 8, UNSIGNED, BE, 8);
    # Add the block with the lengths to the hash.
    for pos range 1 to 16 do
      mixedByte := bin32(state.computedMac[pos]) >< bin32(bitLengths[pos]);
      state.computedMac @:= [pos] char(ord(mixedByte));
    end for;
    state.computedMac := gcmMult(state.computedMac, state.factorH);
    # Mask the hash with the encoded nonce that has a counter of 1.
    for pos range 1 to 16 do
      mixedByte := bin32(state.computedMac[pos]) >< bin32(state.encodedNonce1[pos]);
      state.computedMac @:= [pos] char(ord(mixedByte));
    end for;
    # writeln("computedMac: " <& hex(state.computedMac));
  end func;


(**
 *  Encode a string with the AES GCM cipher.
 *  A random explicit nonce of 8 bytes is created for every call.
 *  The result consists of the explicit nonce, the encrypted data and
 *  the computed MAC (state.computedMac).
 *  @param state The AES GCM cipher state (it is changed by the function).
 *  @param plainText The data to be encrypted.
 *  @return the encoded string.
 *)
const func string: encode (inout aesGcmState: state, in string: plainText) is func
  result
    var string: encoded is "";
  local
    var string: explicitNonce is "";
  begin
    explicitNonce := bytes(rand(bin64), BE, 8);
    # The counter part of the nonce starts with 1 (not 0).
    state.nonce := state.salt & explicitNonce & "\0;\0;\0;\1;";
    # writeln("complete nonce: " <& hex(state.nonce));
    initializeComputedMac(state, length(plainText));
    encoded := explicitNonce & aesGcmEncode(state, plainText);
    finalizeComputedMac(state, length(plainText));
    encoded &:= state.computedMac;
  end func;


(**
 *  Decode a string with the AES GCM cipher.
 *  The ''encoded'' data consists of the explicit nonce (8 bytes), the
 *  encrypted data and the MAC (16 bytes). The MAC from ''encoded'' is
 *  stored in the state (see ''getMac'') and the MAC computed from the
 *  encrypted data is available with ''getComputedMac''. The caller must
 *  compare them to authenticate the data.
 *  A message with empty plain text (24 bytes, as ''encode'' creates it
 *  for "") is processed like any other message. If ''encoded'' is shorter
 *  than 24 bytes nothing is decoded and the stored MAC is set to "".
 *  This way a comparison with the computed MAC cannot succeed with
 *  values left over from a previous message.
 *  @param state The AES GCM cipher state (it is changed by the function).
 *  @param encoded The explicit nonce, the encrypted data and the MAC.
 *  @return the decoded string, or "" if ''encoded'' is too short.
 *)
const func string: decode (inout aesGcmState: state, in string: encoded) is func
  result
    var string: plainText is "";
  local
    const integer: explicitNonceLength is 8;
    const integer: macLength is 16;
    var integer: cipherTextLength is 0;
  begin
    cipherTextLength := length(encoded) - explicitNonceLength - macLength;
    if cipherTextLength >= 0 then
      state.nonce := state.salt &
                     encoded[.. explicitNonceLength] &  # explicit nonce
                     "\0;\0;\0;\1;";  # start "counting" from 1 (not 0)
      # writeln("complete nonce: " <& hex(state.nonce));
      initializeComputedMac(state, cipherTextLength);
      # The 'len' substring is also well defined for a length of 0.
      plainText := aesGcmDecode(state,
          encoded[succ(explicitNonceLength) len cipherTextLength]);
      finalizeComputedMac(state, length(plainText));
      state.mac := encoded[length(encoded) - macLength + 1 ..];
      # writeln("mac: " <& hex(state.mac));
    else
      # Too short to contain explicit nonce and MAC: Drop the stale MAC.
      state.mac := "";
    end if;
  end func;