(********************************************************************)
(*                                                                  *)
(*  inflate.s7i   Inflate uncompression algorithm                   *)
(*  Copyright (C) 2008, 2013, 2015, 2017, 2022, 2023  Thomas Mertes *)
(*                                                                  *)
(*  This file is part of the Seed7 Runtime Library.                 *)
(*                                                                  *)
(*  The Seed7 Runtime Library is free software; you can             *)
(*  redistribute it and/or modify it under the terms of the GNU     *)
(*  Lesser General Public License as published by the Free Software *)
(*  Foundation; either version 2.1 of the License, or (at your      *)
(*  option) any later version.                                      *)
(*                                                                  *)
(*  The Seed7 Runtime Library is distributed in the hope that it    *)
(*  will be useful, but WITHOUT ANY WARRANTY; without even the      *)
(*  implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR *)
(*  PURPOSE.  See the GNU Lesser General Public License for more    *)
(*  details.                                                        *)
(*                                                                  *)
(*  You should have received a copy of the GNU Lesser General       *)
(*  Public License along with this program; if not, write to the    *)
(*  Free Software Foundation, Inc., 51 Franklin Street,             *)
(*  Fifth Floor, Boston, MA  02110-1301, USA.                       *)
(*                                                                  *)
(********************************************************************)


include "bitdata.s7i";
include "bytedata.s7i";
include "huffman.s7i";


(**
 *  Literal/length symbol that marks the end of a compressed block.
 *  The decode loops of the fixed and the dynamic Huffman blocks stop
 *  as soon as this symbol has been read.
 *)
const integer: INFLATE_END_OF_BLOCK is 256;


(**
 *  Append a stored (non-compressed) block to ''uncompressed''.
 *  The block starts with a 16-bit little-endian length, followed by a
 *  16-bit little-endian check value and ''length'' bytes of data.
 *  @param compressedStream Stream positioned at the length field.
 *  @param uncompressed Destination to which the block data is appended.
 *  @exception RANGE_ERROR If the check value is not the one's
 *             complement of the length, or if fewer than ''length''
 *             bytes could be read.
 *)
const proc: getNonCompressedBlock (inout lsbInBitStream: compressedStream,
    inout string: uncompressed) is func
  local
    var integer: length is 0;
    var integer: nlength is 0;
    var string: uncompressedBlock is "";
  begin
    length := bytes2Int(gets(compressedStream, 2), UNSIGNED, LE);
    nlength := bytes2Int(gets(compressedStream, 2), UNSIGNED, LE);
    # Both values are unsigned 16-bit numbers. Therefore nlength is the
    # one's complement of length exactly when their sum is 16#ffff.
    if length + nlength <> 16#ffff then
      raise RANGE_ERROR;
    end if;
    uncompressedBlock := gets(compressedStream, length);
    if length(uncompressedBlock) <> length then
      # The stream delivered less data than the header announced.
      raise RANGE_ERROR;
    else
      uncompressed &:= uncompressedBlock;
    end if;
  end func;


(**
 *  Read one literal/length symbol of a fixed Huffman block.
 *  Seven bits are read first. Depending on their value the code is
 *  complete, or one or two additional bits are read.
 *  @param compressedStream Stream from which the code bits are taken.
 *  @return the decoded literal/length symbol.
 *)
const func integer: getLiteralOrLength (inout lsbInBitStream: compressedStream) is func
  result
    var integer: symbol is 0;
  local
    var integer: code is 0;
  begin
    code := reverseBits[7][getBits(compressedStream, 7)];
    if code <= 2#0010111 then
      # 7-bit codes 0000000 .. 0010111 map to the symbols 256 .. 279.
      symbol := code + 256;
    else
      # Extend the code to 8 bits.
      code := (code << 1) + getBit(compressedStream);
      if code <= 2#10111111 then
        # 8-bit codes 00110000 .. 10111111 map to the symbols 0 .. 143.
        symbol := code - 2#00110000;
      elsif code <= 2#11000111 then
        # 8-bit codes 11000000 .. 11000111 map to the symbols 280 .. 287.
        symbol := code - 2#11000000 + 280;
      else
        # Extend the code to 9 bits.
        # 9-bit codes 110010000 .. 111111111 map to the symbols 144 .. 255.
        code := (code << 1) + getBit(compressedStream);
        symbol := code - 2#110010000 + 144;
      end if;
    end if;
  end func;


(**
 *  Read the 5-bit distance code of a fixed Huffman block.
 *  The five bits are read with getBits and mapped through the
 *  ''reverseBits[5]'' table.
 *  @param compressedStream Stream from which the five bits are taken.
 *  @return the distance code, which the callers in this file pass on
 *          to decodeDistance.
 *)
const func integer: getDistance (inout lsbInBitStream: compressedStream) is func
  result
    var integer: distanceCode is 0;
  local
    var integer: rawBits is 0;
  begin
    rawBits := getBits(compressedStream, 5);
    distanceCode := reverseBits[5][rawBits];
  end func;


(**
 *  Convert a length code into the length of a back-reference.
 *  Codes up to 264 and the code 285 determine the length directly.
 *  The codes 265 .. 284 need 1 to 5 extra bits from the stream.
 *  @param compressedStream Stream from which extra bits are taken.
 *  @param code Length code (the callers in this file pass symbols
 *         above INFLATE_END_OF_BLOCK).
 *  @return the length that corresponds to ''code''.
 *  @exception RANGE_ERROR If ''code'' is greater than 285.
 *)
const func integer: decodeLength (inout lsbInBitStream: compressedStream,
    in integer: code) is func
  result
    var integer: length is 0;
  begin
    # The groups are tested from the highest codes downwards.
    if code > 285 then
      raise RANGE_ERROR;
    elsif code = 285 then
      length := 258;
    elsif code >= 281 then
      length := 131 + (code - 281) * 32 + getBits(compressedStream, 5);
    elsif code >= 277 then
      length := 67 + (code - 277) * 16 + getBits(compressedStream, 4);
    elsif code >= 273 then
      length := 35 + (code - 273) * 8 + getBits(compressedStream, 3);
    elsif code >= 269 then
      length := 19 + (code - 269) * 4 + getBits(compressedStream, 2);
    elsif code >= 265 then
      length := 11 + (code - 265) * 2 + getBit(compressedStream);
    else
      # No extra bits: The code 257 stands for the length 3.
      length := code - 254;
    end if;
  end func;


(**
 *  Convert a distance code into the distance of a back-reference.
 *  The codes 0 .. 3 stand for the distances 1 .. 4. Every following
 *  pair of codes uses one extra bit more than the pair before
 *  (1 extra bit for the codes 4 and 5, 13 extra bits for 28 and 29).
 *  @param compressedStream Stream from which extra bits are taken.
 *  @param code Distance code in the range 0 .. 29.
 *  @return the distance that corresponds to ''code''.
 *  @exception RANGE_ERROR If ''code'' is not in the range 0 .. 29.
 *)
const func integer: decodeDistance (inout lsbInBitStream: compressedStream,
    in integer: code) is func
  result
    var integer: distance is 0;
  local
    var integer: extraBits is 0;
  begin
    if code < 0 or code > 29 then
      raise RANGE_ERROR;
    elsif code <= 3 then
      distance := succ(code);
    else
      # A code pair (2*k, 2*k+1) uses k-1 extra bits. The smallest
      # distance of the even code is 2**k + 1 and the odd code starts
      # 2**(k-1) above that. Both cases are covered by
      # 1 + ((2 + code mod 2) << (k-1)).
      extraBits := pred(code >> 1);
      distance := succ((2 + code mod 2) << extraBits);
      if extraBits = 1 then
        distance +:= getBit(compressedStream);
      else
        distance +:= getBits(compressedStream, extraBits);
      end if;
    end if;
  end func;


(**
 *  Decode a block that was compressed with the fixed Huffman codes.
 *  Symbols are read until INFLATE_END_OF_BLOCK is found. Literals are
 *  appended to ''uncompressed''. A length/distance pair appends a copy
 *  of ''length'' characters that start ''distance'' characters before
 *  the current end of ''uncompressed''.
 *  @param compressedStream Stream positioned after the block header.
 *  @param uncompressed Data decoded so far. It is the source of the
 *         back-references and the decoded block is appended to it.
 *  @exception RANGE_ERROR If a length or distance code is invalid, or
 *             if a back-reference starts before the first character
 *             of ''uncompressed''.
 *)
const proc: decodeFixedHuffmanCodes (inout lsbInBitStream: compressedStream,
    inout string: uncompressed) is func
  local
    var integer: literalOrLength is 0;
    var integer: length is 0;
    var integer: distance is 0;
    var integer: number is 0;
    var integer: nextPos is 0;
  begin
    literalOrLength := getLiteralOrLength(compressedStream);
    while literalOrLength <> INFLATE_END_OF_BLOCK do
      if literalOrLength < INFLATE_END_OF_BLOCK then
        uncompressed &:= chr(literalOrLength);
      else
        length := decodeLength(compressedStream, literalOrLength);
        distance := getDistance(compressedStream);
        distance := decodeDistance(compressedStream, distance);
        if distance > length(uncompressed) then
          # The back-reference would start before the first character.
          # Reject it before ''uncompressed'' is changed.
          raise RANGE_ERROR;
        elsif length > distance then
          # Source and destination overlap: Characters written by this
          # copy are read again, so the copy is done one by one.
          nextPos := succ(length(uncompressed));
          uncompressed &:= "\0;" mult length;
          for number range nextPos to nextPos + length - 1 do
            uncompressed @:= [number] uncompressed[number - distance];
          end for;
        else
          # The source lies completely in the existing data.
          uncompressed &:= uncompressed[succ(length(uncompressed)) - distance fixLen length];
        end if;
      end if;
      literalOrLength := getLiteralOrLength(compressedStream);
    end while;
  end func;


(**
 *  Decode a block that was compressed with dynamic Huffman codes.
 *  First the sizes of the tables and the code lengths of the code
 *  length alphabet are read. With them the code lengths of the
 *  literal/length and the distance alphabet are decoded and the two
 *  decoders are created. Afterwards symbols are decoded until
 *  INFLATE_END_OF_BLOCK is found.
 *  @param compressedStream Stream positioned after the block header.
 *  @param uncompressed Data decoded so far. It is the source of the
 *         back-references and the decoded block is appended to it.
 *  @exception RANGE_ERROR If the repeat code 16 occurs before any code
 *             length was read, if a length or distance code is
 *             invalid, or if a back-reference starts before the first
 *             character of ''uncompressed''.
 *)
const proc: decodeDynamicHuffmanCodes (inout lsbInBitStream: compressedStream,
    inout string: uncompressed) is func
  local
    # Order in which the code lengths of the code length alphabet
    # are stored in the stream.
    const array integer: mapToOrderedLengths is []
        (16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15);
    var integer: literalOrLengthTableSize is 0;  # hlit + 257
    var integer: distanceTableSize is 0;         # hdist + 1
    var integer: combinedDataTableSize is 0;     # hclen + 4
    var array integer:     combinedDataCodeLengths is [0 .. 18] times 0;
    var lsbHuffmanDecoder: combinedDataDecoder is lsbHuffmanDecoder.value;
    var array integer:     combinedData is [0 .. -1] times 0;
    var array integer:     literalOrLengthCodeLengths is [0 .. -1] times 0;
    var lsbHuffmanDecoder: literalOrLengthDecoder is lsbHuffmanDecoder.value;
    var array integer:     distanceCodeLengths is [0 .. -1] times 0;
    var lsbHuffmanDecoder: distanceDecoder is lsbHuffmanDecoder.value;
    var integer: number is 0;
    var integer: combinedDataElement is 0;
    var integer: factor is 0;
    var integer: literalOrLength is 0;
    var integer: length is 0;
    var integer: distance is 0;
    var integer: nextPos is 0;
  begin
    literalOrLengthTableSize := getBits(compressedStream, 5) + 257;
    distanceTableSize        := getBits(compressedStream, 5) + 1;
    combinedDataTableSize    := getBits(compressedStream, 4) + 4;
    # Code lengths that are not present in the stream stay zero.
    for number range 1 to combinedDataTableSize do
      combinedDataCodeLengths[mapToOrderedLengths[number]] := getBits(compressedStream, 3);
    end for;
    combinedDataDecoder := createLsbHuffmanDecoder(combinedDataCodeLengths);
    # The code lengths of both alphabets are stored as one sequence.
    number := 1;
    while number <= literalOrLengthTableSize + distanceTableSize do
      combinedDataElement := getHuffmanSymbol(compressedStream, combinedDataDecoder);
      if combinedDataElement <= 15 then
        # A code length is stored directly.
        combinedData &:= combinedDataElement;
        incr(number);
      elsif combinedDataElement = 16 then
        # Repeat the previous code length 3 to 6 times.
        if maxIdx(combinedData) < 0 then
          # The array starts at index 0, so it is still empty and
          # there is no previous code length that could be repeated.
          raise RANGE_ERROR;
        end if;
        factor := getBits(compressedStream, 2) + 3;
        combinedData &:= factor times
            combinedData[maxIdx(combinedData)];
        number +:= factor;
      elsif combinedDataElement = 17 then
        # Repeat the code length zero 3 to 10 times.
        factor := getBits(compressedStream, 3) + 3;
        combinedData &:= factor times 0;
        number +:= factor;
      else # combinedDataElement = 18
        # Repeat the code length zero 11 to 138 times.
        factor := getBits(compressedStream, 7) + 11;
        combinedData &:= factor times 0;
        number +:= factor;
      end if;
    end while;
    # Split the sequence into the code lengths of the two alphabets.
    literalOrLengthCodeLengths := combinedData[.. pred(literalOrLengthTableSize)];
    distanceCodeLengths        := combinedData[literalOrLengthTableSize .. ];
    literalOrLengthDecoder := createLsbHuffmanDecoder(literalOrLengthCodeLengths);
    distanceDecoder        := createLsbHuffmanDecoder(distanceCodeLengths);
    literalOrLength := getHuffmanSymbol(compressedStream, literalOrLengthDecoder);
    while literalOrLength <> INFLATE_END_OF_BLOCK do
      if literalOrLength < INFLATE_END_OF_BLOCK then
        uncompressed &:= chr(literalOrLength);
      else
        length := decodeLength(compressedStream, literalOrLength);
        distance := getHuffmanSymbol(compressedStream, distanceDecoder);
        distance := decodeDistance(compressedStream, distance);
        if distance > length(uncompressed) then
          # The back-reference would start before the first character.
          # Reject it before ''uncompressed'' is changed.
          raise RANGE_ERROR;
        elsif length > distance then
          # Source and destination overlap: Characters written by this
          # copy are read again, so the copy is done one by one.
          nextPos := succ(length(uncompressed));
          uncompressed &:= "\0;" mult length;
          for number range nextPos to nextPos + length - 1 do
            uncompressed @:= [number] uncompressed[number - distance];
          end for;
        else
          # The source lies completely in the existing data.
          uncompressed &:= uncompressed[succ(length(uncompressed)) - distance fixLen length];
        end if;
      end if;
      literalOrLength := getHuffmanSymbol(compressedStream, literalOrLengthDecoder);
    end while;
  end func;


(**
 *  Read the header of one DEFLATE block and decode the block.
 *  The header consists of one bit for ''bfinal'' and two bits for the
 *  block type. The block types 0, 1 and 2 are decoded by
 *  getNonCompressedBlock, decodeFixedHuffmanCodes and
 *  decodeDynamicHuffmanCodes respectively.
 *  @param compressedStream Stream positioned at the block header.
 *  @param uncompressed Destination for the decoded data.
 *  @param bfinal Set to TRUE if the header bit read for it is odd.
 *  @exception RANGE_ERROR If the block type is not 0, 1 or 2.
 *)
const proc: processCompressedBlock (inout lsbInBitStream: compressedStream,
    inout string: uncompressed, inout boolean: bfinal) is func
  local
    var integer: blockType is 0;
  begin
    bfinal := odd(getBit(compressedStream));
    blockType := getBits(compressedStream, 2);
    if blockType = 0 then
      getNonCompressedBlock(compressedStream, uncompressed);
    elsif blockType = 1 then
      decodeFixedHuffmanCodes(compressedStream, uncompressed);
    elsif blockType = 2 then
      decodeDynamicHuffmanCodes(compressedStream, uncompressed);
    else
      raise RANGE_ERROR;
    end if;
  end func;


(**
 *  Decompress a file that was compressed with DEFLATE.
 *  DEFLATE combines the LZ77 algorithm with Huffman coding.
 *  Blocks are read from ''compressed'' until the block marked as
 *  final has been processed. Afterwards the bit stream is closed.
 *  @param compressed File from which the DEFLATE data is read.
 *  @return the uncompressed string.
 *  @exception RANGE_ERROR If ''compressed'' is not in DEFLATE format.
 *)
const func string: inflate (inout file: compressed) is func
  result
    var string: uncompressed is "";
  local
    var lsbInBitStream: bitStream is lsbInBitStream.value;
    var boolean: lastBlockDone is FALSE;
  begin
    bitStream := openLsbInBitStream(compressed);
    # lastBlockDone starts as FALSE, so at least one block is processed.
    while not lastBlockDone do
      processCompressedBlock(bitStream, uncompressed, lastBlockDone);
    end while;
    close(bitStream);
  end func;


(**
 *  Decompress a string that was compressed with DEFLATE.
 *  DEFLATE combines the LZ77 algorithm with Huffman coding.
 *  Blocks are read from ''compressed'' until the block marked as
 *  final has been processed.
 *  @param compressed String that contains the DEFLATE data.
 *  @return the uncompressed string.
 *  @exception RANGE_ERROR If ''compressed'' is not in DEFLATE format.
 *)
const func string: inflate (in string: compressed) is func
  result
    var string: uncompressed is "";
  local
    var lsbInBitStream: bitStream is lsbInBitStream.value;
    var boolean: lastBlockDone is FALSE;
  begin
    bitStream := openLsbInBitStream(compressed);
    # lastBlockDone starts as FALSE, so at least one block is processed.
    while not lastBlockDone do
      processCompressedBlock(bitStream, uncompressed, lastBlockDone);
    end while;
  end func;