(********************************************************************)
(*                                                                  *)
(*  zstd.s7i      Zstandard compression support library             *)
(*  Copyright (C) 2020 - 2023  Thomas Mertes                        *)
(*                                                                  *)
(*  This file is part of the Seed7 Runtime Library.                 *)
(*                                                                  *)
(*  The Seed7 Runtime Library is free software; you can             *)
(*  redistribute it and/or modify it under the terms of the GNU     *)
(*  Lesser General Public License as published by the Free Software *)
(*  Foundation; either version 2.1 of the License, or (at your      *)
(*  option) any later version.                                      *)
(*                                                                  *)
(*  The Seed7 Runtime Library is distributed in the hope that it    *)
(*  will be useful, but WITHOUT ANY WARRANTY; without even the      *)
(*  implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR *)
(*  PURPOSE.  See the GNU Lesser General Public License for more    *)
(*  details.                                                        *)
(*                                                                  *)
(*  You should have received a copy of the GNU Lesser General       *)
(*  Public License along with this program; if not, write to the    *)
(*  Free Software Foundation, Inc., 51 Franklin Street,             *)
(*  Fifth Floor, Boston, MA  02110-1301, USA.                       *)
(*                                                                  *)
(********************************************************************)


include "bytedata.s7i";
include "bitdata.s7i";


# Magic number at the start of a Zstandard frame (0xFD2FB528 stored in
# little-endian byte order): the bytes 16#28 '(', 16#B5, 16#2F '/' and 16#FD.
const string: ZSTD_MAGIC is "(\16#B5;/\16#FD;";


(**
 *  Decoded fields of a Zstandard frame header.
 *  The fields are filled in by readFrameHeader.
 *)
const type: zstdFrameHeader is new struct
    # Size of the decompressed frame content. readFrameHeader stores -1
    # when the header contains no frame content size field.
    var integer: frameContentSize is 0;
    # Window size computed from the window descriptor byte (see zstdWindowSize).
    var integer: windowSize is 0;
    # Dictionary id read from the header (stays 0 when no id is stored).
    var integer: dictionaryId is 0;
    # Bit 2 of the frame header descriptor.
    var boolean: contentChecksumFlag is FALSE;
  end struct;


(**
 *  Compute the window size described by a window descriptor.
 *  The base size is 2 ** (10 + exponent). The mantissa adds that many
 *  eighths of the base size.
 *  @param exponent Exponent part of the window descriptor.
 *  @param mantissa Mantissa part of the window descriptor.
 *  @return the window size in bytes.
 *)
const func integer: zstdWindowSize (in integer: exponent, in integer: mantissa) is func
  result
    var integer: windowSize is 0;
  begin
    # Start with the base size ...
    windowSize := 1 << (10 + exponent);
    # ... and add mantissa eighths of it.
    windowSize +:= (windowSize mdiv 8) * mantissa;
  end func;


(**
 *  Read the header of a Zstandard frame.
 *  Reading starts at the frame header descriptor byte. This function does
 *  not read a magic number. All fields of ''header'' are assigned, so a
 *  header variable that is reused for several frames never keeps a
 *  window size or dictionary id from an earlier frame.
 *  @param compressed File positioned at the frame header descriptor.
 *  @param header Receives frameContentSize (-1 if the frame does not
 *         store it), windowSize, dictionaryId (0 if the frame does not
 *         store it) and contentChecksumFlag.
 *  @exception RANGE_ERROR If the dictionary id field or the frame content
 *             size field is shorter than announced by the descriptor.
 *)
const proc: readFrameHeader (inout file: compressed, inout zstdFrameHeader: header) is func
  local
    var integer: descriptor is 0;
    var integer: fcsFieldSize is 0;
    var boolean: singleSegmentFlag is FALSE;
    var integer: dictionaryIdFlag is 0;
    var integer: dictionaryIdFieldSize is 0;
    var integer: windowDescriptor is 0;
    var string: stri is "";
  begin
    descriptor := ord(getc(compressed));
    # writeln("descriptor: " <& descriptor);
    # The two highest bits select the size of the frame content size field.
    # With flag value 0 the field exists (1 byte) only for single segment frames.
    case (descriptor >> 6) mod 4 of
      when {0}: fcsFieldSize := (descriptor >> 5) mod 2;
      when {1}: fcsFieldSize := 2;
      when {2}: fcsFieldSize := 4;
      when {3}: fcsFieldSize := 8;
    end case;
    # writeln("fcsFieldSize: " <& fcsFieldSize);
    singleSegmentFlag := odd(descriptor >> 5);
    header.contentChecksumFlag := odd(descriptor >> 2);
    # writeln("contentChecksumFlag: " <& header.contentChecksumFlag);
    # The two lowest bits select a dictionary id field of 0, 1, 2 or 4 bytes.
    dictionaryIdFlag := descriptor mod 4;
    if dictionaryIdFlag = 3 then
      dictionaryIdFieldSize := 4;
    else
      dictionaryIdFieldSize := dictionaryIdFlag;
    end if;
    # Reset the optional fields, because header is an inout parameter
    # that might still hold the values of a previous frame.
    header.windowSize := 0;
    header.dictionaryId := 0;
    if not singleSegmentFlag then
      windowDescriptor := ord(getc(compressed));
      # writeln("windowDescriptor: " <& windowDescriptor);
      header.windowSize := zstdWindowSize(windowDescriptor >> 3, windowDescriptor mod 8);
      # writeln("windowSize: " <& header.windowSize);
    end if;
    if dictionaryIdFieldSize <> 0 then
      stri := gets(compressed, dictionaryIdFieldSize);
      if length(stri) <> dictionaryIdFieldSize then
        raise RANGE_ERROR;
      end if;
      header.dictionaryId := bytes2Int(stri, UNSIGNED, LE);
    end if;
    # writeln("dictionaryId: " <& header.dictionaryId);
    if fcsFieldSize <> 0 then
      stri := gets(compressed, fcsFieldSize);
      if length(stri) <> fcsFieldSize then
        raise RANGE_ERROR;
      end if;
    end if;
    case fcsFieldSize of
      when {0}: header.frameContentSize := -1;
      when {1}: header.frameContentSize :=       bytes2Int(stri, UNSIGNED, LE);
      # The two byte encoding stores the size minus 256.
      when {2}: header.frameContentSize := 256 + bytes2Int(stri, UNSIGNED, LE);
      when {4}: header.frameContentSize :=       bytes2Int(stri, UNSIGNED, LE);
      when {8}: header.frameContentSize :=       bytes2Int(stri, UNSIGNED, LE);
    end case;
    # writeln("frameContentSize: " <& header.frameContentSize);
    if singleSegmentFlag then
      # A single segment frame has no window descriptor. The Zstandard
      # format defines its window size as the frame content size.
      header.windowSize := header.frameContentSize;
    end if;
  end func;


# FSE means Finite State Entropy. It refers to an entropy codec.
# FSE encoding/decoding involves a state that is carried over between symbols.
# Decoding must be done in the opposite direction to encoding.
# Therefore, all FSE bitstreams are read from end to beginning.

# Maps a symbol number to its FSE value (the value is the probability plus one).
const type: fseValueMap is hash [integer] integer;

(**
 *  Description of an FSE distribution: the accuracy log and one value
 *  per symbol. The decoding table built from it has 1 << accuracyLog cells.
 *)
const type: fseWeightsType is new struct
    var integer: accuracyLog is 0;
    var fseValueMap: fseValues is fseValueMap.value;  # probability = value - 1
  end struct;

(**
 *  One cell of an FSE decoding table.
 *)
const type: fseTableEntry is new struct
    # Value added to the bits read from the stream to get the next state (see nextState).
    var integer: baseline is 0;
    # Number of extra bits that belong to the symbol (set by symbolTranslation
    # and repeatingFseDecodingTable).
    var integer: numberOfAdditionalBits is 0;
    # Number of bits nextState reads from the stream to compute the next state.
    var integer: numberOfBits is 0;
    # Symbol decoded when the state refers to this cell.
    var integer: symbol is 0;
  end struct;

(**
 *  An FSE decoding table together with its accuracy log.
 *  The table is indexed by the FSE state. buildDecodingTable creates it
 *  with indices 0 .. pred(1 << accuracyLog), repeatingFseDecodingTable
 *  creates a table with the single index 0.
 *)
const type: fseDecodingType is new struct
    var integer: accuracyLog is 0;
    var array fseTableEntry: decodingTable is 0 times fseTableEntry.value;
  end struct;


(**
 *  Read the description of an FSE distribution from ''compressed''.
 *  The description consists of a 4 bit accuracy log (stored minus 5)
 *  followed by the variable length encoded values of the symbols. The bits
 *  are read with an lsbInBitStream, which is closed before the function
 *  returns. Despite its name the function is also used by initDecodeTables
 *  to read the distributions of the sequences section.
 *  @param compressed File positioned at the start of the FSE description.
 *  @return the accuracy log and a map from symbol number to value,
 *          where the probability of a symbol is its value minus one.
 *  @exception RANGE_ERROR If the probabilities do not add up to exactly
 *             1 << accuracyLog.
 *)
const func fseWeightsType: getFseCompressedHuffmanWeights (inout file: compressed) is func
  result
    var fseWeightsType: fseWeights is fseWeightsType.value;
  local
    var lsbInBitStream: huffmanWeightsBitStream is lsbInBitStream.value;
    var integer: remainingProbabilities is 0;
    var integer: currentSymbol is 0;
    var integer: bitsNeeded is 0;
    var integer: value is 0;
    var integer: thresh is 0;
    var integer: probability is 0;
    var integer: skip is 0;
    var integer: index is 0;
  begin
    huffmanWeightsBitStream := openLsbInBitStream(compressed);
    fseWeights.accuracyLog := getBits(huffmanWeightsBitStream, 4) + 5;
    remainingProbabilities := 1 << fseWeights.accuracyLog;
    while remainingProbabilities > 0 do
      # A value is stored with bitsNeeded or bitsNeeded + 1 bits.
      # Values below thresh use the short form.
      bitsNeeded := log2(succ(remainingProbabilities));
      value := getBits(huffmanWeightsBitStream, bitsNeeded);
      thresh := (1 << succ(bitsNeeded)) - 1 - (remainingProbabilities + 1);
      if value >= thresh then
        # Long form: One more bit belongs to the value.
        value +:= getBit(huffmanWeightsBitStream) << bitsNeeded;
        if value >= 1 << bitsNeeded then
          value -:= thresh;
        end if;
      end if;

      fseWeights.fseValues @:= [currentSymbol] value;
      incr(currentSymbol);

      probability := value - 1;
      if probability = -1 then
        # Counts as 1 because it will get one cell in the decoding table.
        decr(remainingProbabilities);
      else
        remainingProbabilities -:= probability;
      end if;

      if probability = 0 then
        # The next two bits specify how many symbols following have probability 0 too.
        repeat
          skip := getBits(huffmanWeightsBitStream, 2);
          for index range 1 to skip do
            # Does not count into remainingProbabilities because probability = 0
            fseWeights.fseValues @:= [currentSymbol] 1;  # value = probability + 1
            incr(currentSymbol);
          end for;
        until skip <> 3;  # As long as skip = 3 the next two bits specify how many symbols to skip.
      end if;
    end while;
    close(huffmanWeightsBitStream);

    # The loop ends with a negative value if the probabilities exceed the table size.
    if remainingProbabilities <> 0 then
      raise RANGE_ERROR;
    end if;
  end func;


(**
 *  Build an FSE decoding table from a distribution.
 *  Symbols with the probability -1 ("less than 1") get one cell each at
 *  the end of the table. The other symbols get as many cells as their
 *  probability. These cells are spread over the free part of the table
 *  with a fixed step. Finally the number of bits and the baseline of
 *  every cell are computed.
 *  @param fseWeights Accuracy log and the values (probability + 1) of the symbols.
 *  @return the decoding table with 1 << accuracyLog cells.
 *  @exception RANGE_ERROR If spreading the symbols does not end at cell 0.
 *)
const func fseDecodingType: buildDecodingTable (in fseWeightsType: fseWeights) is func
  result
    var fseDecodingType: fseDecoding is fseDecodingType.value;
  local
    var array integer: symbolList is 0 times 0;
    var array integer: stateOfSymbol is 0 times 0;
    var integer: tableSize is 0;
    var integer: step is 0;
    var integer: lastFreeCell is 0;
    var integer: symbol is 0;
    var integer: probability is 0;
    var integer: cell is 0;
    var integer: count is 0;
    var integer: index is 0;
    var integer: state is 0;
    var integer: bits is 0;
  begin
    # Both passes below visit the symbols in ascending order.
    symbolList := sort(keys(fseWeights.fseValues));
    stateOfSymbol := [0 .. length(fseWeights.fseValues)] times 0;

    tableSize := 1 << fseWeights.accuracyLog;
    step := (tableSize >> 1) + (tableSize >> 3) + 3;
    fseDecoding.accuracyLog := fseWeights.accuracyLog;
    fseDecoding.decodingTable := [0 .. pred(tableSize)] times fseTableEntry.value;
    lastFreeCell := pred(tableSize);

    # First pass: Place the symbols with the special probability -1.
    for symbol range symbolList do
      probability := pred(fseWeights.fseValues[symbol]);
      if probability = -1 then
        # Such a symbol occupies a single cell at the end of the table
        # and defines a full state reset that reads accuracyLog bits.
        fseDecoding.decodingTable[lastFreeCell].symbol := symbol;
        decr(lastFreeCell);
        stateOfSymbol[symbol] := 1;
      else
        stateOfSymbol[symbol] := probability;
      end if;
    end for;

    # Second pass: Spread the cells of all other symbols.
    for symbol range symbolList do
      probability := pred(fseWeights.fseValues[symbol]);
      if probability > 0 then
        for count range 1 to probability do
          fseDecoding.decodingTable[cell].symbol := symbol;
          repeat
            cell := (cell + step) mod tableSize;
            # Cells above lastFreeCell belong to "less than 1" symbols.
          until cell <= lastFreeCell;
        end for;
      end if;
    end for;

    if cell <> 0 then
      raise RANGE_ERROR;
    end if;

    # Every cell of a symbol gets the next state number of this symbol.
    for key index range fseDecoding.decodingTable do
      symbol := fseDecoding.decodingTable[index].symbol;
      state := stateOfSymbol[symbol];
      incr(stateOfSymbol[symbol]);
      bits := fseWeights.accuracyLog - log2(state);
      fseDecoding.decodingTable[index].numberOfBits := bits;
      fseDecoding.decodingTable[index].baseline := (state << bits) - tableSize;
    end for;
  end func;


(**
 *  Create a decoding table that always delivers the same symbol.
 *  The table consists of one cell at index 0. The accuracy log, the
 *  number of bits and the baseline keep their default value 0.
 *  @param symbol Symbol stored in the single cell.
 *  @param numberOfAdditionalBits Number of additional bits stored in the cell.
 *  @return the decoding table with one cell.
 *)
const func fseDecodingType: repeatingFseDecodingTable (in integer: symbol,
    in integer: numberOfAdditionalBits) is func
  result
    var fseDecodingType: fseDecoding is fseDecodingType.value;
  local
    var fseTableEntry: entry is fseTableEntry.value;
  begin
    entry.symbol := symbol;
    entry.numberOfAdditionalBits := numberOfAdditionalBits;
    fseDecoding.decodingTable := [0 .. 0] times entry;
  end func;


(**
 *  Replace the symbols of a decoding table by translated values.
 *  The original symbol of every cell is used as index into both arrays.
 *  @param decodingTable Table whose cells are changed in place.
 *  @param symbolTranslation Maps an original symbol to the value stored as new symbol.
 *  @param extraBits Maps an original symbol to its number of additional bits.
 *)
const proc: symbolTranslation (inout array fseTableEntry: decodingTable,
    in array integer: symbolTranslation, in array integer: extraBits) is func
  local
    var integer: pos is 0;
    var integer: code is 0;
  begin
    for pos range minIdx(decodingTable) to maxIdx(decodingTable) do
      # Remember the original symbol, because it indexes both arrays.
      code := decodingTable[pos].symbol;
      decodingTable[pos].symbol := symbolTranslation[code];
      decodingTable[pos].numberOfAdditionalBits := extraBits[code];
    end for;
  end func;


(**
 *  Get the symbol of the given FSE state without changing the state.
 *  @param fseDecoding Decoding table to look up.
 *  @param fseState Index of a cell in the decoding table.
 *  @return the symbol stored in the cell.
 *)
const func integer: peekSymbol (in fseDecodingType: fseDecoding, in integer: fseState) is func
  result
    var integer: symbol is 0;
  begin
    symbol := fseDecoding.decodingTable[fseState].symbol;
  end func;


(**
 *  Advance an FSE state.
 *  The new state is the baseline of the current cell plus the value of
 *  as many bits from ''compressedStream'' as the cell demands.
 *  @param fseDecoding Decoding table to look up.
 *  @param fseState State that is replaced by its successor.
 *  @param compressedStream Bit stream the bits are read from.
 *)
const proc: nextState (in fseDecodingType: fseDecoding, inout integer: fseState,
    inout reverseBitStream: compressedStream) is func
  local
    var fseTableEntry: entry is fseTableEntry.value;
  begin
    entry := fseDecoding.decodingTable[fseState];
    fseState := entry.baseline + getBits(compressedStream, entry.numberOfBits);
  end func;


(**
 *  Get the number of additional bits of the given FSE state.
 *  @param fseDecoding Decoding table to look up.
 *  @param fseState Index of a cell in the decoding table.
 *  @return the numberOfAdditionalBits stored in the cell.
 *)
const func integer: getAdditionalBits (in fseDecodingType: fseDecoding, in integer: fseState) is func
  result
    var integer: additionalBits is 0;
  begin
    additionalBits := fseDecoding.decodingTable[fseState].numberOfAdditionalBits;
  end func;


(**
 *  Decode the symbols of a bit stream that interleaves two FSE states.
 *  The stream starts with padding up to and including the first 1 bit.
 *  Then both initial states follow. After that the two states take turns
 *  to deliver a symbol. When the stream is used up the state whose turn
 *  comes next delivers one final symbol.
 *  @param fseDecoding Decoding table used by both states.
 *  @param compressedStream Bit stream to decode.
 *  @return the decoded symbols as array starting at index 0.
 *  @exception RANGE_ERROR If the padding takes more than 8 bits.
 *)
const func array integer: decodeInterleavedFseStreams (in fseDecodingType: fseDecoding,
    inout reverseBitStream: compressedStream) is func
  result
    var array integer: weights is [0 .. -1] times 0;
  local
    var array integer: states is 2 times 0;
    var integer: current is 1;
    var boolean: finished is FALSE;
  begin
    # Skip the padding, which ends with the first bit that is 1.
    while getBits(compressedStream, 1) <> 1 do
      noop;
    end while;
    if bitsRead(compressedStream) > 8 then
      raise RANGE_ERROR;
    end if;
    states[1] := getBits(compressedStream, fseDecoding.accuracyLog);
    states[2] := getBits(compressedStream, fseDecoding.accuracyLog);
    repeat
      weights &:= peekSymbol(fseDecoding, states[current]);
      nextState(fseDecoding, states[current], compressedStream);
      # Switch to the other state (1 becomes 2 and 2 becomes 1).
      current := 3 - current;
      if bitsStillInStream(compressedStream) < 0 then
        # The stream is used up: The other state still holds one symbol.
        weights &:= peekSymbol(fseDecoding, states[current]);
        finished := TRUE;
      end if;
    until finished;
  end func;


(**
 *  Table driven Huffman decoder.
 *  Both arrays are created by createZstdHuffmanDecoder with the indices
 *  0 .. pred(1 << maxBits). They are indexed by the decoder state,
 *  which consists of maxBits bits (see decodeSymbol).
 *)
const type: zstdHuffmanDecoder is new struct
    var integer: maxBits is 0;
    # Number of bits to read from the stream after decoding the symbol of a state.
    var array integer: numberOfBits is 0 times 0;
    # Symbol that belongs to a state.
    var array integer: symbols is 0 times 0;
  end struct;


(**
 *  Create a Huffman decoder from the code lengths of the symbols.
 *  Codes are assigned such that longer codes get the lower table
 *  positions. A code of length n occupies 1 << (maxBits - n) cells.
 *  @param maxBits Length of the longest code.
 *  @param numBits Code length of every symbol (0 for unused symbols).
 *         The symbols 0 .. pred(length(numBits)) are used as index.
 *  @param rankCount Number of symbols for every code length.
 *  @return the decoder with tables of 1 << maxBits cells.
 *  @exception RANGE_ERROR If the codes do not fill the table exactly.
 *)
const func zstdHuffmanDecoder: createZstdHuffmanDecoder (in integer: maxBits,
    in array integer: numBits, in array integer: rankCount) is func
  result
    var zstdHuffmanDecoder: decoder is zstdHuffmanDecoder.value;
  local
    var integer: tableSize is 0;
    var integer: bits is 0;
    var integer: symbol is 0;
    var integer: first is 0;
    var integer: span is 0;
    var integer: cell is 0;
    var array integer: rankStart is 0 times 0;
  begin
    tableSize := 1 << maxBits;
    decoder.maxBits := maxBits;
    decoder.symbols := [0 .. pred(tableSize)] times 0;
    decoder.numberOfBits := [0 .. pred(tableSize)] times 0;

    # Determine where the cells of every code length start and
    # store the code length in all of these cells.
    rankStart := [0 .. maxBits] times 0;
    for bits range maxBits downto 1 do
      first := rankStart[bits];
      span := rankCount[bits] * (1 << (maxBits - bits));
      rankStart[pred(bits)] := first + span;
      for cell range first to pred(first + span) do
        decoder.numberOfBits[cell] := bits;
      end for;
    end for;

    if rankStart[0] <> tableSize then
      raise RANGE_ERROR;
    end if;

    # Assign the cells of every code length to the symbols in ascending order.
    for symbol range 0 to pred(length(numBits)) do
      bits := numBits[symbol];
      if bits <> 0 then
        first := rankStart[bits];
        span := 1 << (maxBits - bits);
        for cell range first to pred(first + span) do
          decoder.symbols[cell] := symbol;
        end for;
        rankStart[bits] := first + span;
      end if;
    end for;
  end func;


(**
 *  Create a Huffman decoder from the weights of the symbols.
 *  A weight w > 0 contributes 1 << pred(w) to the weight sum. The weight
 *  of the last symbol is not part of ''weights''. It is deduced from the
 *  amount missing to the next power of two. Weights are converted to
 *  code lengths with: maxBits + 1 - weight.
 *  @param weights Weights of all symbols except the last one.
 *  @return the Huffman decoder for length(weights) + 1 symbols.
 *  @exception RANGE_ERROR If the missing amount is not a power of two.
 *)
const func zstdHuffmanDecoder: createZstdHuffmanDecoder (in array integer: weights) is func
  result
    var zstdHuffmanDecoder: decoder is zstdHuffmanDecoder.value;
  local
    var integer: weightSum is 0;
    var integer: maxBits is 0;
    var integer: missing is 0;
    var integer: lastWeight is 0;
    var integer: index is 0;
    var integer: bits is 0;
    var array integer: numBits is 0 times 0;
    var array integer: rankCount is 0 times 0;
  begin
    for key index range weights do
      if weights[index] > 0 then
        weightSum +:= 1 << pred(weights[index]);
      end if;
    end for;

    maxBits := succ(log2(weightSum));
    # The last symbol fills the gap to the next power of two.
    missing := (1 << maxBits) - weightSum;
    if missing > 0 and 1 << log2(missing) <> missing then
      # The gap is not a power of two.
      raise RANGE_ERROR;
    end if;
    lastWeight := succ(log2(missing));

    # numBits has one element more than weights (for the last symbol).
    numBits := [0 .. length(weights)] times 0;
    rankCount := [0 .. maxBits] times 0;
    for key index range weights do
      if weights[index] > 0 then
        bits := succ(maxBits) - weights[index];
      else
        bits := 0;
      end if;
      numBits[index] := bits;
      incr(rankCount[bits]);
    end for;

    if lastWeight > 0 then
      bits := succ(maxBits) - lastWeight;
    else
      bits := 0;
    end if;
    numBits[maxIdx(numBits)] := bits;
    incr(rankCount[bits]);

    decoder := createZstdHuffmanDecoder(maxBits, numBits, rankCount);
  end func;


(**
 *  Read the weights of a Huffman tree description.
 *  A header byte below 128 is the size of an FSE compressed description
 *  (FSE distribution followed by the interleaved bit stream). Otherwise
 *  the header byte minus 127 is the number of weights, which are stored
 *  directly with 4 bits per weight (high nibble first).
 *  @param compressed File positioned at the header byte of the description.
 *  @return the weights as array starting at index 0.
 *)
const func array integer: readZstdHuffmanTreeWeights (inout file: compressed) is func
  result
    var array integer: weights is [0 .. -1] times 0;
  local
    var integer: headerByte is 0;
    var fseDecodingType: fseDecoding is fseDecodingType.value;
    var integer: startPos is 0;
    var integer: bitStreamLength is 0;
    var reverseBitStream: compressedStream is reverseBitStream.value;
    var integer: numberOfWeights is 0;
    var integer: index is 0;
    var integer: aByte is 0;
  begin
    headerByte := ord(getc(compressed));
    if headerByte >= 128 then
      # Direct representation of huffman weights.
      numberOfWeights := headerByte - 127;
      weights := [0 .. pred(numberOfWeights)] times 0;
      for index range 0 to pred(numberOfWeights) do
        if odd(index) then
          # The second weight of a byte is in the low nibble.
          weights[index] := aByte mod 16;
        else
          aByte := ord(getc(compressed));
          weights[index] := aByte >> 4;
        end if;
      end for;
    else
      # FSE compressed weights: The bit stream occupies what is left of
      # headerByte bytes after the FSE distribution has been read.
      startPos := tell(compressed);
      fseDecoding := buildDecodingTable(getFseCompressedHuffmanWeights(compressed));
      bitStreamLength := headerByte - (tell(compressed) - startPos);
      compressedStream := reverseBitStream(compressed, bitStreamLength);
      weights := decodeInterleavedFseStreams(fseDecoding, compressedStream);
    end if;
  end func;


(**
 *  Decode one Huffman symbol and advance the decoder state.
 *  The state selects the symbol and the number of bits to read. The
 *  bits read are shifted into the state, which is kept at maxBits bits.
 *  @param huffmanDecoder Decoder tables.
 *  @param state Current state, replaced by the next state.
 *  @param compressedStream Bit stream the bits are read from.
 *  @return the decoded symbol as character.
 *)
const func char: decodeSymbol (in zstdHuffmanDecoder: huffmanDecoder,
    inout integer: state, inout reverseBitStream: compressedStream) is func
  result
    var char: symbol is ' ';
  local
    var integer: width is 0;
  begin
    symbol := char(huffmanDecoder.symbols[state]);
    width := huffmanDecoder.numberOfBits[state];
    state := (state << width) + getBits(compressedStream, width);
    # Drop the bits that have been consumed by the symbol.
    state := state mod (1 << huffmanDecoder.maxBits);
  end func;


(**
 *  Decode a Huffman compressed bit stream.
 *  The stream starts with padding up to and including the first 1 bit.
 *  Symbols are decoded until the stream has been overrun by maxBits bits,
 *  which are the bits of the initial state.
 *  @param huffmanDecoder Decoder tables.
 *  @param compressedStream Bit stream to decode.
 *  @return the decoded characters.
 *  @exception RANGE_ERROR If the padding takes more than 8 bits or if
 *             the stream does not end exactly after the last symbol.
 *)
const func string: decodeStream (in zstdHuffmanDecoder: huffmanDecoder,
    inout reverseBitStream: compressedStream) is func
  result
    var string: uncompressed is "";
  local
    var integer: state is 0;
  begin
    # Skip the padding, which ends with the first bit that is 1.
    while getBits(compressedStream, 1) <> 1 do
      noop;
    end while;
    if bitsRead(compressedStream) > 8 then
      raise RANGE_ERROR;
    end if;
    state := getBits(compressedStream, huffmanDecoder.maxBits);
    while bitsStillInStream(compressedStream) > -huffmanDecoder.maxBits do
      uncompressed &:= decodeSymbol(huffmanDecoder, state, compressedStream);
    end while;
    if bitsStillInStream(compressedStream) <> -huffmanDecoder.maxBits then
      raise RANGE_ERROR;
    end if;
  end func;


# Literals block types (the two lowest bits of the literals section header).
# Raw literals are stored verbatim, RLE literals repeat a single byte and
# the other two types are Huffman compressed. A treeless block reuses the
# Huffman decoder kept in the block state instead of reading a new tree.
const integer: ZSTD_RAW_LITERALS_BLOCK        is 0;
const integer: ZSTD_RLE_LITERALS_BLOCK        is 1;
const integer: ZSTD_COMPRESSED_LITERALS_BLOCK is 2;
const integer: ZSTD_TREELESS_LITERALS_BLOCK   is 3;

(**
 *  Decoded header of a literals section (see readZstdLiteralsSectionHeader).
 *)
const type: zstdLiteralSectionHeader is new struct
    # One of the ZSTD_..._LITERALS_BLOCK constants.
    var integer: literalsBlockType is 0;
    # Number of Huffman streams (1 or 4). Stays 0 for raw and RLE blocks.
    var integer: numStreams is 0;
    # Size of the literals after decoding.
    var integer: regeneratedSize is 0;
    # Size of the compressed literals. Stays 0 for raw and RLE blocks.
    var integer: compressedSize is 0;
  end struct;


(**
 *  Read the header of a literals section.
 *  The two lowest bits of the first byte are the literals block type and
 *  the next two bits are the size format. Raw and RLE blocks store just
 *  the regenerated size (5, 12 or 20 bits). The Huffman compressed block
 *  types store the regenerated size followed by the compressed size
 *  (each with 10, 14 or 18 bits) and determine the number of streams.
 *  @param compressed File positioned at the first byte of the header.
 *  @return the decoded header.
 *  @exception RANGE_ERROR If the file ends inside of the size fields.
 *)
const func zstdLiteralSectionHeader: readZstdLiteralsSectionHeader (inout file: compressed) is func
  result
    var zstdLiteralSectionHeader: header is zstdLiteralSectionHeader.value;
  local
    var integer: aByte is 0;
    var integer: sizeFormat is 0;
    var integer: bytesToRead is 0;
    var integer: sizeBits is 0;
    var string: stri is "";
    var integer: number is 0;
  begin
    aByte := ord(getc(compressed));
    header.literalsBlockType := aByte mod 4;
    sizeFormat := (aByte >> 2) mod 4;
    if header.literalsBlockType in {ZSTD_RAW_LITERALS_BLOCK, ZSTD_RLE_LITERALS_BLOCK} then
      # Only regeneratedSize is decoded.
      case sizeFormat of
        when {0, 2}:
          # Only one bit of sizeFormat is used. The size has 5 bits.
          header.regeneratedSize := aByte >> 3;
        when {1, 3}:
          # The size continues in one (12 bit size) or two (20 bit size) bytes.
          if sizeFormat = 1 then
            bytesToRead := 1;
          else
            bytesToRead := 2;
          end if;
          stri := gets(compressed, bytesToRead);
          if length(stri) <> bytesToRead then
            # Without this check a truncated field would silently give a wrong size.
            raise RANGE_ERROR;
          end if;
          header.regeneratedSize := (aByte >> 4) + (bytes2Int(stri, UNSIGNED, LE) << 4);
      end case;
    else
      # Decode compressedSize and regeneratedSize.
      bytesToRead := max(2, succ(sizeFormat));
      stri := gets(compressed, bytesToRead);
      if length(stri) <> bytesToRead then
        raise RANGE_ERROR;
      end if;
      # Both sizes are packed into the bits above the low 4 bits of the first byte.
      number := (aByte >> 4) + (bytes2Int(stri, UNSIGNED, LE) << 4);
      case sizeFormat of
        when {0, 1}: sizeBits := 10;
        when {2}:    sizeBits := 14;
        when {3}:    sizeBits := 18;
      end case;
      if sizeFormat = 0 then
        header.numStreams := 1;
      else
        header.numStreams := 4;
      end if;
      header.regeneratedSize := number mod (1 << sizeBits);
      header.compressedSize := number >> sizeBits;
    end if;
  end func;


(**
 *  State of a sequences section: The number of sequences and the FSE
 *  decoding tables for literal lengths, offsets and match lengths.
 *  The fields are set by initDecodeTables.
 *)
const type: zstdSequencesSectionType is new struct
    var integer: numberOfSequences is 0;
    var fseDecodingType: literalLengthsFseDecoding is fseDecodingType.value;
    var fseDecodingType: offsetsFseDecoding is fseDecodingType.value;
    var fseDecodingType: matchLengthsFseDecoding is fseDecodingType.value;
  end struct;

(**
 *  State that is passed from block to block while blocks are decoded.
 *)
const type: zstdBlockStateType is new struct
    # Offset history with the initial values 1, 4 and 8.
    var array integer: offsetHistory is [] (1, 4, 8);
    # Huffman decoder of the last compressed literals block.
    # It is reused by treeless literals blocks.
    var zstdHuffmanDecoder: decoder is zstdHuffmanDecoder.value;
    # Decoding tables of the sequences section. In repeat mode
    # initDecodeTables keeps a table from an earlier block.
    var zstdSequencesSectionType: sequencesSection is zstdSequencesSectionType.value;
  end struct;


(**
 *  Read and decode a Huffman compressed literals block.
 *  A compressed literals block starts with a Huffman tree description,
 *  from which a new decoder is created and stored in ''blockState''.
 *  A treeless literals block uses the decoder already stored there.
 *  The literals are stored in one stream or in four streams. With four
 *  streams a jump table of 6 bytes holds the sizes of the first three.
 *  @param compressed File positioned behind the literals section header.
 *  @param blockState Block state with the Huffman decoder to use or replace.
 *  @param header Header of the literals section.
 *  @return the decoded literals.
 *  @exception RANGE_ERROR If a treeless block is found although no
 *             Huffman decoder has been created before, or if the jump
 *             table is incomplete.
 *)
const func string: readCompressedLiteralsBlock (inout file: compressed,
    inout zstdBlockStateType: blockState, in zstdLiteralSectionHeader: header) is func
  result
    var string: literals is "";
  local
    var integer: compressedSize is 0;
    var integer: posBefore is 0;
    var array integer: weights is [0 .. -1] times 0;
    var array integer: streamSize is 4 times 0;
    var array reverseBitStream: compressedStream is 4 times reverseBitStream.value;
    var string: stri is "";
    var integer: index is 0;
  begin
    compressedSize := header.compressedSize;
    if header.literalsBlockType = ZSTD_COMPRESSED_LITERALS_BLOCK then
      # The tree description counts as part of the compressed size.
      posBefore := tell(compressed);
      weights := readZstdHuffmanTreeWeights(compressed);
      compressedSize -:= tell(compressed) - posBefore;
      blockState.decoder := createZstdHuffmanDecoder(weights);
    elsif header.literalsBlockType = ZSTD_TREELESS_LITERALS_BLOCK and
        length(blockState.decoder.symbols) = 0 then
      # There is no Huffman decoder from a previous block that could be reused.
      raise RANGE_ERROR;
    end if;
    if header.numStreams = 1 then
      compressedStream[1] := reverseBitStream(compressed, compressedSize);
      literals := decodeStream(blockState.decoder, compressedStream[1]);
    elsif header.numStreams = 4 then
      stri := gets(compressed, 6);
      if length(stri) <> 6 then
        raise RANGE_ERROR;
      end if;
      streamSize[1] := bytes2Int(stri[1 fixLen 2], UNSIGNED, LE);
      streamSize[2] := bytes2Int(stri[3 fixLen 2], UNSIGNED, LE);
      streamSize[3] := bytes2Int(stri[5 fixLen 2], UNSIGNED, LE);
      compressedSize -:= 6;
      # The fourth stream takes the rest of the compressed data.
      streamSize[4] := compressedSize - streamSize[1] - streamSize[2] - streamSize[3];
      # All four streams are fetched from the file before decoding starts.
      for index range 1 to 4 do
        compressedStream[index] := reverseBitStream(compressed, streamSize[index]);
      end for;
      for index range 1 to 4 do
        literals &:= decodeStream(blockState.decoder, compressedStream[index]);
      end for;
    end if;
  end func;


(**
 *  Read a literals section and return its literals.
 *  Raw literals are copied from the file, RLE literals repeat one
 *  character regeneratedSize times and the other block types are
 *  decoded with readCompressedLiteralsBlock.
 *  @param compressed File positioned at the literals section header.
 *  @param blockState Block state that holds the Huffman decoder.
 *  @return the literals of the section.
 *  @exception RANGE_ERROR If the file ends inside of raw literals.
 *)
const func string: readLiteralsSection (inout file: compressed,
    inout zstdBlockStateType: blockState) is func
  result
    var string: literals is "";
  local
    var zstdLiteralSectionHeader: header is zstdLiteralSectionHeader.value;
  begin
    header := readZstdLiteralsSectionHeader(compressed);
    case header.literalsBlockType of
      when {ZSTD_RAW_LITERALS_BLOCK}:
        literals := gets(compressed, header.regeneratedSize);
        if length(literals) <> header.regeneratedSize then
          # Truncated data would otherwise be accepted as shorter literals.
          raise RANGE_ERROR;
        end if;
      when {ZSTD_RLE_LITERALS_BLOCK}:
        literals := str(getc(compressed)) mult header.regeneratedSize;
      when {ZSTD_COMPRESSED_LITERALS_BLOCK, ZSTD_TREELESS_LITERALS_BLOCK}:
        literals := readCompressedLiteralsBlock(compressed, blockState, header);
    end case;
  end func;


# Symbol compression modes of the sequences section. The mode decides how
# initDecodeTables obtains a decoding table: from a predefined distribution,
# from a single repeated byte, from an FSE description in the file or by
# keeping the table of a previous block.
const integer: ZSTD_PREDEFINED_MODE     is 0;
const integer: ZSTD_RLE_MODE            is 1;
const integer: ZSTD_FSE_COMPRESSED_MODE is 2;
const integer: ZSTD_REPEAT_MODE         is 3;

(**
 *  Decoded header of a sequences section (see readZstdSequencesSectionHeader).
 *  The three mode fields hold one of the ZSTD_..._MODE constants. They
 *  keep the value 0 when the section contains no sequences.
 *)
const type: zstdSequencesSectionHeader is new struct
    var integer: numberOfSequences is 0;
    var integer: literalLengthsMode is 0;
    var integer: offsetsMode is 0;
    var integer: matchLengthsMode is 0;
  end struct;


(**
 *  Read the header of a sequences section.
 *  The number of sequences is stored in one to three bytes: A first byte
 *  below 128 is the number itself, a first byte below 255 is combined
 *  with a second byte and the first byte 255 is followed by a two byte
 *  little endian number to which 16#7f00 is added. If there are sequences
 *  a byte with the three symbol compression modes follows.
 *  @param compressed File positioned at the first byte of the header.
 *  @return the decoded header.
 *  @exception RANGE_ERROR If a byte value is outside of 0 .. 255 or if
 *             the two byte number is incomplete.
 *)
const func zstdSequencesSectionHeader: readZstdSequencesSectionHeader (inout file: compressed) is func
  result
    var zstdSequencesSectionHeader: header is zstdSequencesSectionHeader.value;
  local
    var integer: byte0 is 0;
    var integer: byte1 is 0;
    var string: stri is "";
    var integer: symbolCompressionModes is 0;
  begin
    byte0 := ord(getc(compressed));
    # NOTE(review): The checks for negative values assume that getc delivers
    # a character with a negative ordinal at the end of the file - confirm.
    if byte0 < 0 or byte0 > 255 then
      raise RANGE_ERROR;
    elsif byte0 < 128 then
      # This includes byte0 = 0, which means: No sequences.
      header.numberOfSequences := byte0;
    elsif byte0 < 255 then
      byte1 := ord(getc(compressed));
      if byte1 < 0 then
        raise RANGE_ERROR;
      end if;
      header.numberOfSequences := ((byte0 - 128) << 8) + byte1;
    else
      stri := gets(compressed, 2);
      if length(stri) <> 2 then
        raise RANGE_ERROR;
      end if;
      header.numberOfSequences := bytes2Int(stri, UNSIGNED, LE) + 16#7f00;
    end if;
    if header.numberOfSequences <> 0 then
      symbolCompressionModes := ord(getc(compressed));
      if symbolCompressionModes < 0 then
        raise RANGE_ERROR;
      end if;
      # Two bits per mode, starting with the highest bits of the byte.
      header.literalLengthsMode := symbolCompressionModes >> 6;
      header.offsetsMode := (symbolCompressionModes >> 4) mod 4;
      header.matchLengthsMode := (symbolCompressionModes >> 2) mod 4;
    end if;
  end func;


# Holds one integer state for each of literal lengths, offsets and match lengths.
# NOTE(review): presumably these are the FSE states used while decoding
# sequences - the code that uses this type is outside of this excerpt.
const type: zstdSequenceState is new struct
    var integer: literalLengthsState is 0;
    var integer: offsetsState is 0;
    var integer: matchLengthsState is 0;
  end struct;

# Holds a match length, a literal length and an offset.
# NOTE(review): presumably this describes one decoded sequence - the code
# that uses this type is outside of this excerpt.
const type: zstdSequenceType is new struct
    var integer: matchLength is 0;
    var integer: literalLength is 0;
    var integer: offset is 0;
  end struct;

# Base value of every literal length code (indexed by the codes 0 .. 35).
# It is used together with zstdLiteralLengthExtraBits by symbolTranslation.
const array integer: zstdLiteralLengthBaseValueTranslation is [0] (
    0, 1, 2, 3, 4, 5, 6, 7,
    8, 9, 10, 11, 12, 13, 14, 15,
    16, 18, 20, 22, 24, 28, 32, 40,
    48, 64, 16#80, 16#100, 16#200, 16#400, 16#800,
    16#1000, 16#2000, 16#4000, 16#8000, 16#10000);


# Number of additional bits of every literal length code (indexed by the codes 0 .. 35).
const array integer: zstdLiteralLengthExtraBits is [0] (
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1,
    1, 1, 2, 2, 3, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16);


# Base value of every match length code (indexed by the codes 0 .. 52).
# It is used together with zstdMatchLengthsExtraBits by symbolTranslation.
const array integer: zstdMatchLengthBaseValueTranslation is [0] (
    3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
    17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30,
    31, 32, 33, 34, 35, 37, 39, 41, 43, 47, 51, 59, 67, 83,
    99, 131, 259, 515, 1027, 2051, 4099, 8195, 16387, 32771, 65539);


# Number of additional bits of every match length code (indexed by the codes 0 .. 52).
const array integer: zstdMatchLengthsExtraBits is [0] (
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1,
    2, 2, 3, 3, 4, 4, 5, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16);


(**
 *  Build the decoding table for literal lengths in predefined mode.
 *  The table is built from the default distribution with accuracy log 6.
 *  Afterwards the symbols are translated to base values and the number
 *  of additional bits is added.
 *  @return the decoding table for literal lengths.
 *)
const func fseDecodingType: buildZstdLiteralLengthsTable is func
  result
    var fseDecodingType: fseDecoding is fseDecodingType.value;
  local
    const integer: literalLengthDefaultAccuracyLog is 6;
    const array integer: literalLengthDefaultDistributions is [0] (
        4, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 2, 1, 1, 1, 1, 1,
        -1, -1, -1, -1);
    var fseWeightsType: distribution is fseWeightsType.value;
    var integer: code is 0;
  begin
    distribution.accuracyLog := literalLengthDefaultAccuracyLog;
    for code range minIdx(literalLengthDefaultDistributions) to
        maxIdx(literalLengthDefaultDistributions) do
      # The map stores values, which are probabilities plus one.
      distribution.fseValues @:= [code] succ(literalLengthDefaultDistributions[code]);
    end for;
    fseDecoding := buildDecodingTable(distribution);
    symbolTranslation(fseDecoding.decodingTable,
        zstdLiteralLengthBaseValueTranslation, zstdLiteralLengthExtraBits);
  end func;


(**
 *  Build the decoding table for offsets in predefined mode.
 *  The table is built from the default distribution with accuracy log 5.
 *  The symbols of the table are not translated.
 *  @return the decoding table for offsets.
 *)
const func fseDecodingType: buildZstdOffsetTable is func
  result
    var fseDecodingType: fseDecoding is fseDecodingType.value;
  local
    const integer: offsetDefaultAccuracyLog is 5;
    const array integer: offsetDefaultDistribution is [0] (
        1, 1, 1, 1, 1, 1, 2, 2, 2, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1);
    var fseWeightsType: distribution is fseWeightsType.value;
    var integer: code is 0;
  begin
    distribution.accuracyLog := offsetDefaultAccuracyLog;
    for code range minIdx(offsetDefaultDistribution) to
        maxIdx(offsetDefaultDistribution) do
      # The map stores values, which are probabilities plus one.
      distribution.fseValues @:= [code] succ(offsetDefaultDistribution[code]);
    end for;
    fseDecoding := buildDecodingTable(distribution);
  end func;


(**
 *  Build the decoding table for match lengths in predefined mode.
 *  The table is built from the default distribution with accuracy log 6.
 *  Afterwards the symbols are translated to base values and the number
 *  of additional bits is added.
 *  @return the decoding table for match lengths.
 *)
const func fseDecodingType: buildZstdMatchLengthsTable is func
  result
    var fseDecodingType: fseDecoding is fseDecodingType.value;
  local
    const integer: matchLengthDefaultAccuracyLog is 6;
    const array integer: matchLengthDefaultDistribution is [0] (
        1, 4, 3, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1);
    var fseWeightsType: distribution is fseWeightsType.value;
    var integer: code is 0;
  begin
    distribution.accuracyLog := matchLengthDefaultAccuracyLog;
    for code range minIdx(matchLengthDefaultDistribution) to
        maxIdx(matchLengthDefaultDistribution) do
      # The map stores values, which are probabilities plus one.
      distribution.fseValues @:= [code] succ(matchLengthDefaultDistribution[code]);
    end for;
    fseDecoding := buildDecodingTable(distribution);
    symbolTranslation(fseDecoding.decodingTable,
        zstdMatchLengthBaseValueTranslation, zstdMatchLengthsExtraBits);
  end func;


# Set up the FSE decoding tables for literal lengths, offsets and match
# lengths according to the modes from the sequences section header.
# For each of the three tables the mode selects one of:
#  - predefined: The table is built from the default distribution.
#  - RLE: One byte is read from ''compressed'' and a table that always
#    delivers the corresponding symbol is used.
#  - FSE compressed: The distribution is read from ''compressed''.
#  - repeat: The table already stored in ''blockState'' is kept.
#    RANGE_ERROR is raised if this table is empty.
const proc: initDecodeTables (inout file: compressed,
    inout zstdBlockStateType: blockState, in zstdSequencesSectionHeader: header) is func
  local
    var integer: rleSymbol is 0;
    var fseWeightsType: weights is fseWeightsType.value;
  begin
    blockState.sequencesSection.numberOfSequences := header.numberOfSequences;

    # Table for the literal lengths.
    if header.literalLengthsMode = ZSTD_PREDEFINED_MODE then
      blockState.sequencesSection.literalLengthsFseDecoding := buildZstdLiteralLengthsTable;
    elsif header.literalLengthsMode = ZSTD_RLE_MODE then
      rleSymbol := ord(getc(compressed));
      blockState.sequencesSection.literalLengthsFseDecoding :=
          repeatingFseDecodingTable(zstdLiteralLengthBaseValueTranslation[rleSymbol],
                                    zstdLiteralLengthExtraBits[rleSymbol]);
    elsif header.literalLengthsMode = ZSTD_FSE_COMPRESSED_MODE then
      weights := getFseCompressedHuffmanWeights(compressed);
      blockState.sequencesSection.literalLengthsFseDecoding := buildDecodingTable(weights);
      symbolTranslation(blockState.sequencesSection.literalLengthsFseDecoding.decodingTable,
          zstdLiteralLengthBaseValueTranslation, zstdLiteralLengthExtraBits);
    elsif header.literalLengthsMode = ZSTD_REPEAT_MODE then
      if length(blockState.sequencesSection.literalLengthsFseDecoding.decodingTable) = 0 then
        raise RANGE_ERROR;
      end if;
    end if;

    # Table for the offsets (the symbols are used without translation).
    if header.offsetsMode = ZSTD_PREDEFINED_MODE then
      blockState.sequencesSection.offsetsFseDecoding := buildZstdOffsetTable;
    elsif header.offsetsMode = ZSTD_RLE_MODE then
      rleSymbol := ord(getc(compressed));
      blockState.sequencesSection.offsetsFseDecoding :=
          repeatingFseDecodingTable(rleSymbol, 0);
    elsif header.offsetsMode = ZSTD_FSE_COMPRESSED_MODE then
      weights := getFseCompressedHuffmanWeights(compressed);
      blockState.sequencesSection.offsetsFseDecoding := buildDecodingTable(weights);
    elsif header.offsetsMode = ZSTD_REPEAT_MODE then
      if length(blockState.sequencesSection.offsetsFseDecoding.decodingTable) = 0 then
        raise RANGE_ERROR;
      end if;
    end if;

    # Table for the match lengths.
    if header.matchLengthsMode = ZSTD_PREDEFINED_MODE then
      blockState.sequencesSection.matchLengthsFseDecoding := buildZstdMatchLengthsTable;
    elsif header.matchLengthsMode = ZSTD_RLE_MODE then
      rleSymbol := ord(getc(compressed));
      blockState.sequencesSection.matchLengthsFseDecoding :=
          repeatingFseDecodingTable(zstdMatchLengthBaseValueTranslation[rleSymbol],
                                    zstdMatchLengthsExtraBits[rleSymbol]);
    elsif header.matchLengthsMode = ZSTD_FSE_COMPRESSED_MODE then
      weights := getFseCompressedHuffmanWeights(compressed);
      blockState.sequencesSection.matchLengthsFseDecoding := buildDecodingTable(weights);
      symbolTranslation(blockState.sequencesSection.matchLengthsFseDecoding.decodingTable,
          zstdMatchLengthBaseValueTranslation, zstdMatchLengthsExtraBits);
    elsif header.matchLengthsMode = ZSTD_REPEAT_MODE then
      if length(blockState.sequencesSection.matchLengthsFseDecoding.decodingTable) = 0 then
        raise RANGE_ERROR;
      end if;
    end if;
  end func;


# Decode one sequence (offset value, match length and literal length)
# with the current FSE states.
# The additional bits are read from ''compressedStream'' in the order:
# offset, match length, literal length. The FSE states are not changed.
const func zstdSequenceType: decodeSequence (in zstdSequencesSectionType: sequencesSection,
    inout reverseBitStream: compressedStream, inout zstdSequenceState: state) is func
  result
    var zstdSequenceType: sequence is zstdSequenceType.value;
  local
    var integer: offsetBits is 0;
    var integer: extraBits is 0;
  begin
    # The offset symbol is the number of bits to read.
    # The offset value is 2 ** offsetBits plus the bits read.
    offsetBits := peekSymbol(sequencesSection.offsetsFseDecoding, state.offsetsState);
    sequence.offset := (1 << offsetBits) + getBits(compressedStream, offsetBits);

    # Match length: Base value from the table plus optional extra bits.
    sequence.matchLength := peekSymbol(sequencesSection.matchLengthsFseDecoding,
                                       state.matchLengthsState);
    extraBits := getAdditionalBits(sequencesSection.matchLengthsFseDecoding,
                                   state.matchLengthsState);
    if extraBits <> 0 then
      sequence.matchLength +:= getBits(compressedStream, extraBits);
    end if;

    # Literal length: Base value from the table plus optional extra bits.
    sequence.literalLength := peekSymbol(sequencesSection.literalLengthsFseDecoding,
                                         state.literalLengthsState);
    extraBits := getAdditionalBits(sequencesSection.literalLengthsFseDecoding,
                                   state.literalLengthsState);
    if extraBits <> 0 then
      sequence.literalLength +:= getBits(compressedStream, extraBits);
    end if;
  end func;


# Decode all sequences of a block and append the resulting data to
# ''uncompressed''.
# The bit stream is read backwards: First the padding (zero bits followed
# by a single 1 bit) is skipped. Then the initial FSE states are read in
# the order literal lengths, offsets, match lengths. For every sequence
# the literals are copied from ''literals'' and the match is copied from
# the data already present in ''uncompressed''. Literals that remain after
# the last sequence are appended at the end.
# The repeat offset history in ''blockState'' is updated for every match.
# RANGE_ERROR is raised if the padding is longer than a byte, if a
# repeat offset evaluates to zero, or if bits remain in the stream
# after the last sequence.
const proc: decodeSequences (in zstdSequencesSectionType: sequencesSection,
    inout reverseBitStream: compressedStream, in string: literals,
    inout zstdBlockStateType: blockState, inout string: uncompressed) is func
  local
    var integer: aBit is 0;
    var zstdSequenceState: state is zstdSequenceState.value;
    var integer: index is 0;
    var zstdSequenceType: sequence is zstdSequenceType.value;
    var integer: literalPos is 1;
    var integer: offset is 0;
    var integer: nextPos is 0;
    var integer: number is 0;
  begin
    # Skip the padding up to and including the first 1 bit.
    repeat
      aBit := getBits(compressedStream, 1);
    until aBit = 1;
    if bitsRead(compressedStream) > 8 then
      # The 1 bit that ends the padding must be in the first byte read.
      raise RANGE_ERROR;
    else
      state.literalLengthsState := getBits(compressedStream, sequencesSection.literalLengthsFseDecoding.accuracyLog);
      state.offsetsState        := getBits(compressedStream, sequencesSection.offsetsFseDecoding.accuracyLog);
      state.matchLengthsState   := getBits(compressedStream, sequencesSection.matchLengthsFseDecoding.accuracyLog);

      for index range 1 to sequencesSection.numberOfSequences do
        sequence := decodeSequence(sequencesSection, compressedStream, state);
        if sequence.literalLength > 0 then
          uncompressed &:= literals[literalPos fixLen sequence.literalLength];
          literalPos +:= sequence.literalLength;
        end if;

        if sequence.matchLength > 0 then
          if sequence.offset <= 3 then
            # Offset values 1 to 3 refer to the repeat offset history.
            if sequence.literalLength = 0 then
              # Without literals the repeat offsets are shifted by one:
              # 1 -> history[2], 2 -> history[3], 3 -> history[1] - 1.
              case sequence.offset of
                when {0}: raise RANGE_ERROR;
                when {1}: offset := blockState.offsetHistory[2];
                          blockState.offsetHistory[2] := blockState.offsetHistory[1];
                when {2}: offset := blockState.offsetHistory[3];
                          blockState.offsetHistory[3] := blockState.offsetHistory[2];
                          blockState.offsetHistory[2] := blockState.offsetHistory[1];
                when {3}: offset := pred(blockState.offsetHistory[1]);
                          if offset = 0 then
                            # An offset of zero is not valid: The data is corrupt.
                            raise RANGE_ERROR;
                          end if;
                          blockState.offsetHistory[3] := blockState.offsetHistory[2];
                          blockState.offsetHistory[2] := blockState.offsetHistory[1];
              end case;
              blockState.offsetHistory[1] := offset;
            else
              offset := blockState.offsetHistory[sequence.offset];
              # Move the offset used to the front of the history.
              case sequence.offset of
                when {0}: raise RANGE_ERROR;
                when {1}: noop;
                when {2}: blockState.offsetHistory[2] := blockState.offsetHistory[1];
                          blockState.offsetHistory[1] := offset;
                when {3}: blockState.offsetHistory[3] := blockState.offsetHistory[2];
                          blockState.offsetHistory[2] := blockState.offsetHistory[1];
                          blockState.offsetHistory[1] := offset;
              end case;
            end if;
          else
            # A new offset is pushed to the front of the history.
            offset := sequence.offset - 3;
            blockState.offsetHistory[3] := blockState.offsetHistory[2];
            blockState.offsetHistory[2] := blockState.offsetHistory[1];
            blockState.offsetHistory[1] := offset;
          end if;
          if sequence.matchLength > offset then
            # The match overlaps with the data it produces.
            # Therefore it is copied character by character.
            nextPos := succ(length(uncompressed));
            uncompressed &:= "\0;" mult sequence.matchLength;
            for number range nextPos to nextPos + sequence.matchLength - 1 do
              uncompressed @:= [number] uncompressed[number - offset];
            end for;
          else # hopefully length(uncompressed) >= offset holds
            uncompressed &:= uncompressed[succ(length(uncompressed)) - offset fixLen sequence.matchLength];
          end if;
        end if;

        # Don't update on the last index.
        if index < sequencesSection.numberOfSequences then
          nextState(sequencesSection.literalLengthsFseDecoding, state.literalLengthsState,
                    compressedStream);
          nextState(sequencesSection.matchLengthsFseDecoding, state.matchLengthsState,
                    compressedStream);
          nextState(sequencesSection.offsetsFseDecoding, state.offsetsState,
                    compressedStream);
        end if;
      end for;

      # Literals not consumed by a sequence belong to the end of the block.
      uncompressed &:= literals[literalPos ..];

      if bitsStillInStream(compressedStream) <> 0 then
        raise RANGE_ERROR;
      end if;
    end if;
  end func;


# Read a compressed block of ''blockSize'' bytes from ''compressed'' and
# append the decompressed data to ''uncompressed''.
# The block consists of a literals section followed by a sequences section.
# Without sequences the literals are the decompressed data of the block.
const proc: readCompressedBlock (inout file: compressed, in integer: blockSize,
    inout zstdBlockStateType: blockState, inout string: uncompressed) is func
  local
    var integer: startPos is 0;
    var string: literals is "";
    var zstdSequencesSectionHeader: sectionHeader is zstdSequencesSectionHeader.value;
    var integer: bytesConsumed is 0;
    var reverseBitStream: bitStream is reverseBitStream.value;
  begin
    startPos := tell(compressed);
    literals := readLiteralsSection(compressed, blockState);
    sectionHeader := readZstdSequencesSectionHeader(compressed);
    if sectionHeader.numberOfSequences <> 0 then
      initDecodeTables(compressed, blockState, sectionHeader);
      # The rest of the block is the bit stream with the sequences.
      bytesConsumed := tell(compressed) - startPos;
      bitStream := reverseBitStream(compressed, blockSize - bytesConsumed);
      decodeSequences(blockState.sequencesSection, bitStream, literals, blockState, uncompressed);
    else
      uncompressed &:= literals;
    end if;
  end func;


# Read one block from ''compressed'' and append its decompressed data
# to ''uncompressed''.
# The block header consists of 3 bytes (little endian):
#  - bit 0: Flag that marks the last block.
#  - bits 1 and 2: Block type (0: raw, 1: RLE, 2: compressed, 3: reserved).
#  - remaining bits: Block size.
# Returns TRUE if the block read was marked as last block.
# RANGE_ERROR is raised if the block header is incomplete or
# if the reserved block type 3 is used (the data is corrupt).
const func boolean: zstdBlock (inout file: compressed, inout zstdBlockStateType: blockState,
    inout string: uncompressed) is func
  result
    var boolean: lastBlock is FALSE;
  local
    var string: stri is "";
    var integer: blockSize is 0;
    var integer: blockType is 0;
  begin
    stri := gets(compressed, 3);
    if length(stri) = 3 then
      blockSize := bytes2Int(stri, UNSIGNED, LE);
      lastBlock := odd(blockSize);
      # writeln("lastBlock: " <& lastBlock);
      blockSize >>:= 1;
      blockType := blockSize mod 4;
      # writeln("blockType: " <& blockType);
      blockSize >>:= 2;
      # writeln("blockSize: " <& blockSize);
      case blockType of
        when {0}: uncompressed &:= gets(compressed, blockSize);
        when {1}: uncompressed &:= str(getc(compressed)) mult blockSize;
        when {2}: readCompressedBlock(compressed, blockSize, blockState, uncompressed);
        otherwise: # Block type 3 is reserved and must not be silently ignored.
          raise RANGE_ERROR;
      end case;
    else
      raise RANGE_ERROR;
    end if;
  end func;


(**
 *  [[file|File]] implementation type to decompress a Zstandard file.
 *  Zstandard is a file format used for compression.
 *  The decompressed data is collected block by block in ''uncompressed''.
 *  Blocks are decompressed on demand by the functions that read from
 *  the file or determine its length.
 *)
const type: zstdFile is sub null_file struct
    # The file from which the compressed blocks are read.
    var file: compressed is STD_NULL;
    # Set from the result of zstdBlock: TRUE after the last block has been read.
    var boolean: finished is FALSE;
    # Decoder state handed to zstdBlock for every block.
    var zstdBlockStateType: blockState is zstdBlockStateType.value;
    # The data decompressed so far.
    var string: uncompressed is "";
    # Index in ''uncompressed'' of the next character to be read (starting with 1).
    var integer: position is 1;
  end struct;

type_implements_interface(zstdFile, file);


(**
 *  Open a Zstandard file for reading (decompression).
 *  Zstandard is a file format used for compression. Reading from
 *  the file delivers decompressed data. Writing is not supported.
 *  The magic number and the frame header are read from ''compressed''
 *  when the file is opened.
 *  @return the file opened, or [[null_file#STD_NULL|STD_NULL]]
 *          if the file is not in Zstandard format.
 *)
const func file: openZstdFile (inout file: compressed) is func
  result
    var file: newFile is STD_NULL;
  local
    var zstdFrameHeader: header is zstdFrameHeader.value;
    var zstdFile: zstdStream is zstdFile.value;
  begin
    if gets(compressed, length(ZSTD_MAGIC)) = ZSTD_MAGIC then
      readFrameHeader(compressed, header);
      zstdStream.compressed := compressed;
      newFile := toInterface(zstdStream);
    end if;
  end func;


(**
 *  Close a ''zstdFile''.
 *  Closing a ''zstdFile'' does nothing.
 *)
const proc: close (in zstdFile: aFile) is func
  begin
    noop;
  end func;


(**
 *  Read a character from a ''zstdFile''.
 *  Blocks are decompressed until a character is available at the
 *  current position, or until the last block has been read.
 *  @return the character read, or EOF at the end of the file.
 *)
const func char: getc (inout zstdFile: inFile) is func
  result
    var char: nextChar is EOF;
  begin
    while not inFile.finished and
        inFile.position > length(inFile.uncompressed) do
      inFile.finished := zstdBlock(inFile.compressed, inFile.blockState, inFile.uncompressed);
    end while;
    if inFile.position <= length(inFile.uncompressed) then
      nextChar := inFile.uncompressed[inFile.position];
      incr(inFile.position);
    end if;
  end func;


(**
 *  Read a string with maximum length from a ''zstdFile''.
 *  Blocks are decompressed until ''maxLength'' characters are available
 *  at the current position, or until the last block has been read.
 *  @return the string read. The string is shorter than ''maxLength''
 *          if the end of the file has been reached.
 *  @exception RANGE_ERROR The parameter ''maxLength'' is negative.
 *)
const func string: gets (inout zstdFile: inFile, in integer: maxLength) is func
  result
    var string: striRead is "";
  local
    var integer: available is 0;
  begin
    if maxLength < 0 then
      raise RANGE_ERROR;
    elsif maxLength > 0 then
      # Number of decompressed characters from the current position onward.
      available := succ(length(inFile.uncompressed) - inFile.position);
      while maxLength > available and not inFile.finished do
        inFile.finished := zstdBlock(inFile.compressed, inFile.blockState, inFile.uncompressed);
        available := succ(length(inFile.uncompressed) - inFile.position);
      end while;
      if maxLength > available then
        # Not enough data: Deliver everything that is left.
        striRead := inFile.uncompressed[inFile.position ..];
        inFile.position := succ(length(inFile.uncompressed));
      else
        striRead := inFile.uncompressed[inFile.position fixLen maxLength];
        inFile.position +:= maxLength;
      end if;
    end if;
  end func;


(**
 *  Determine the end-of-file indicator.
 *  The indicator is set if the last block has been decompressed and
 *  the current position is beyond the end of the decompressed data.
 *  @return TRUE if the end-of-file indicator is set, FALSE otherwise.
 *)
const func boolean: eof (in zstdFile: inFile) is func
  result
    var boolean: atEnd is FALSE;
  begin
    atEnd := inFile.position > length(inFile.uncompressed) and inFile.finished;
  end func;


(**
 *  Determine if at least one character can be read successfully.
 *  This function allows a file to be handled like an iterator.
 *  Blocks are decompressed until a character is available at the
 *  current position, or until the last block has been read.
 *  @return FALSE if ''getc'' would return EOF, TRUE otherwise.
 *)
const func boolean: hasNext (inout zstdFile: inFile) is func
  result
    var boolean: charAvailable is FALSE;
  begin
    while not inFile.finished and
        inFile.position > length(inFile.uncompressed) do
      inFile.finished := zstdBlock(inFile.compressed, inFile.blockState, inFile.uncompressed);
    end while;
    charAvailable := length(inFile.uncompressed) >= inFile.position;
  end func;


(**
 *  Obtain the length of a file.
 *  The file length is measured in bytes.
 *  To determine the length all remaining blocks are decompressed.
 *  @return the length of the decompressed data.
 *)
const func integer: length (inout zstdFile: aFile) is func
  result
    var integer: fileLength is 0;
  begin
    if not aFile.finished then
      repeat
        aFile.finished := zstdBlock(aFile.compressed, aFile.blockState, aFile.uncompressed);
      until aFile.finished;
    end if;
    fileLength := length(aFile.uncompressed);
  end func;


(**
 *  Determine if the file ''aFile'' is seekable.
 *  If a file is seekable the functions ''seek'' and ''tell''
 *  can be used to set and obtain the current file position.
 *  @return TRUE, since a ''zstdFile'' is seekable.
 *)
const boolean: seekable (in zstdFile: aFile) is TRUE;


(**
 *  Set the current file position.
 *  The file position is measured in bytes from the start of the file.
 *  The first byte in the file has the position 1.
 *  No data is decompressed by this function.
 *  @exception RANGE_ERROR The file position is negative or zero.
 *)
const proc: seek (inout zstdFile: aFile, in integer: position) is func
  begin
    if position >= 1 then
      aFile.position := position;
    else
      raise RANGE_ERROR;
    end if;
  end func;


(**
 *  Obtain the current file position.
 *  The file position is measured in bytes from the start of the file.
 *  The first byte in the file has the position 1.
 *  @return the current file position.
 *)
const func integer: tell (in zstdFile: aFile) is func
  result
    var integer: currentPos is 0;
  begin
    currentPos := aFile.position;
  end func;