using System;
using System.Collections.Generic;
using System.IO;
using SabreTools.Models.Compression.Quantum;
using static SabreTools.Compression.Quantum.Constants;

namespace SabreTools.Compression.Quantum
{
    /// <summary>
    /// Decompressor for Quantum-compressed data
    /// </summary>
    /// <remarks>
    /// Symbols are decoded with a 16-bit arithmetic decoder driven by adaptive
    /// frequency models; extra bits for LZ match positions and lengths are read
    /// directly from the same bit stream.
    /// </remarks>
    public class Decompressor
    {
        #region Tuning Constants

        /// <summary>
        /// Amount added to each cumulative frequency at or before a decoded symbol
        /// </summary>
        private const int FrequencyIncrement = 8;

        /// <summary>
        /// Total cumulative frequency above which a model is rescaled
        /// </summary>
        private const int RescaleThreshold = 3800;

        /// <summary>
        /// Number of rescales before the first full reorder of a fresh model
        /// </summary>
        private const int InitialTimeToReorder = 4;

        /// <summary>
        /// Number of rescales between subsequent full reorders
        /// </summary>
        private const int ReorderInterval = 50;

        #endregion

        /// <summary>
        /// Internal bitstream to use for decompression
        /// </summary>
        private readonly BitStream _bitStream;

        #region Models

        /// <summary>
        /// Selector 0: literal, 64 entries, starting symbol 0
        /// </summary>
        private readonly Model _model0;

        /// <summary>
        /// Selector 1: literal, 64 entries, starting symbol 64
        /// </summary>
        private readonly Model _model1;

        /// <summary>
        /// Selector 2: literal, 64 entries, starting symbol 128
        /// </summary>
        private readonly Model _model2;

        /// <summary>
        /// Selector 3: literal, 64 entries, starting symbol 192
        /// </summary>
        private readonly Model _model3;

        /// <summary>
        /// Selector 4: LZ, 3 character matches
        /// </summary>
        private readonly Model _model4;

        /// <summary>
        /// Selector 5: LZ, 4 character matches
        /// </summary>
        private readonly Model _model5;

        /// <summary>
        /// Selector 6: LZ, 5+ character matches
        /// </summary>
        private readonly Model _model6;

        /// <summary>
        /// Selector 6 length model
        /// </summary>
        private readonly Model _model6len;

        /// <summary>
        /// Selector selector model
        /// </summary>
        private readonly Model _selector;

        #endregion

        #region Coding State

        /// <summary>
        /// Arithmetic coding state: high end of the current interval
        /// </summary>
        private ushort _high;

        /// <summary>
        /// Arithmetic coding state: low end of the current interval
        /// </summary>
        private ushort _low;

        /// <summary>
        /// Arithmetic coding state: current code value read from the stream
        /// </summary>
        private ushort _code;

        #endregion

        /// <summary>
        /// Create a new Decompressor from a byte array
        /// </summary>
        /// <param name="input">Byte array to decompress</param>
        /// <param name="windowBits">Number of bits in the sliding window (10-21)</param>
        /// <exception cref="ArgumentException">Thrown when <paramref name="input"/> is null or empty</exception>
        /// <exception cref="ArgumentOutOfRangeException">Thrown when <paramref name="windowBits"/> is outside 10-21</exception>
        public Decompressor(byte[]? input, uint windowBits)
            : this(WrapInput(input), windowBits)
        {
        }

        /// <summary>
        /// Create a new Decompressor from a Stream
        /// </summary>
        /// <param name="input">Stream to decompress; must be readable and seekable</param>
        /// <param name="windowBits">Number of bits in the sliding window (10-21)</param>
        /// <exception cref="ArgumentException">Thrown when <paramref name="input"/> is null, unreadable, or unseekable</exception>
        /// <exception cref="ArgumentOutOfRangeException">Thrown when <paramref name="windowBits"/> is outside 10-21</exception>
        public Decompressor(Stream? input, uint windowBits)
        {
            if (input == null || !input.CanRead || !input.CanSeek)
                throw new ArgumentException("Input must be a readable, seekable stream", nameof(input));

            if (windowBits < 10 || windowBits > 21)
                throw new ArgumentOutOfRangeException(nameof(windowBits));

            _bitStream = new BitStream(input);

            // Literal models: four 64-symbol banks covering byte values 0-255
            _model0 = CreateModel(0, 64);
            _model1 = CreateModel(64, 64);
            _model2 = CreateModel(128, 64);
            _model3 = CreateModel(192, 64);

            // LZ position models scale with the window size
            int maxBitLength = (int)(windowBits * 2);
            _model4 = CreateModel(0, Math.Min(maxBitLength, 24));
            _model5 = CreateModel(0, Math.Min(maxBitLength, 36));
            _model6 = CreateModel(0, maxBitLength);
            _model6len = CreateModel(0, 27);

            // Selector model chooses between the seven models above
            _selector = CreateModel(0, 7);

            _high = 0;
            _low = 0;
            _code = 0;
        }

        /// <summary>
        /// Validate a byte array input and wrap it in a seekable stream
        /// </summary>
        /// <param name="input">Byte array to wrap</param>
        /// <returns>MemoryStream over <paramref name="input"/></returns>
        /// <exception cref="ArgumentException">Thrown when <paramref name="input"/> is null or empty</exception>
        private static Stream WrapInput(byte[]? input)
        {
            if (input == null || input.Length == 0)
                throw new ArgumentException("Input must be a non-empty byte array", nameof(input));

            return new MemoryStream(input);
        }

        /// <summary>
        /// Process the stream and return the decompressed output
        /// </summary>
        /// <returns>Byte array representing the decompressed data</returns>
        /// <exception cref="InvalidDataException">Thrown when an LZ match refers to data before the start of the output</exception>
        public byte[] Process()
        {
            // Initialize the coding state: full interval, first 16 bits as the code
            _high = 0xFFFF;
            _low = 0x0000;
            _code = (ushort)(_bitStream.ReadBitsMSB(16) ?? 0);

            var output = new List<byte>();
            while (_bitStream.Position < _bitStream.Length)
            {
                int selector = GetSymbol(_selector);
                switch (selector)
                {
                    // Literal selectors
                    case 0:
                        output.Add((byte)GetSymbol(_model0));
                        break;
                    case 1:
                        output.Add((byte)GetSymbol(_model1));
                        break;
                    case 2:
                        output.Add((byte)GetSymbol(_model2));
                        break;
                    case 3:
                        output.Add((byte)GetSymbol(_model3));
                        break;

                    // LZ selectors
                    case 4:
                        CopyMatch(output, ReadMatchOffset(_model4), 3);
                        break;
                    case 5:
                        CopyMatch(output, ReadMatchOffset(_model5), 4);
                        break;
                    case 6:
                    {
                        // The length is encoded before the offset for this selector
                        int length = ReadMatchLength();
                        int offset = ReadMatchOffset(_model6);
                        CopyMatch(output, offset, length);
                        break;
                    }

                    default:
                        throw new ArgumentOutOfRangeException();
                }

                // TODO: Add MS-CAB specific padding
                // TODO: Add Cinematronics specific checksum
            }

            return output.ToArray();
        }

        /// <summary>
        /// Decode an LZ match offset using the given position model
        /// </summary>
        /// <param name="model">Position model to decode the slot from</param>
        /// <returns>Backwards distance of the match, at least 1</returns>
        private int ReadMatchOffset(Model model)
        {
            int slot = GetSymbol(model);
            int extra = (int)(_bitStream.ReadBitsMSB(PositionExtraBits[slot]) ?? 0);
            return PositionSlot[slot] + extra + 1;
        }

        /// <summary>
        /// Decode the length of a selector 6 (5+ character) match
        /// </summary>
        /// <returns>Match length, at least 5</returns>
        private int ReadMatchLength()
        {
            int slot = GetSymbol(_model6len);
            int extra = (int)(_bitStream.ReadBitsMSB(LengthExtraBits[slot]) ?? 0);
            return LengthSlot[slot] + extra + 5;
        }

        /// <summary>
        /// Append a copy of previously decompressed data to the output
        /// </summary>
        /// <param name="output">Decompressed data so far</param>
        /// <param name="offset">Distance back from the end of the output</param>
        /// <param name="length">Number of bytes to copy</param>
        /// <exception cref="InvalidDataException">Thrown when the offset reaches before the start of the output</exception>
        private static void CopyMatch(List<byte> output, int offset, int length)
        {
            int copyIndex = output.Count - offset;
            if (copyIndex < 0)
                throw new InvalidDataException("LZ match offset points before the start of the output");

            // Byte-at-a-time so overlapping matches re-read bytes appended by this copy
            for (int i = 0; i < length; i++)
            {
                output.Add(output[copyIndex + i]);
            }
        }

        /// <summary>
        /// Create and initialize a model based on the start symbol and length
        /// </summary>
        /// <param name="start">First symbol value in the model</param>
        /// <param name="length">Number of symbols in the model</param>
        /// <returns>Model with strictly decreasing initial cumulative frequencies</returns>
        private static Model CreateModel(ushort start, int length)
        {
            // One extra trailing entry acts as a sentinel with cumulative frequency 0,
            // so Symbols[i + 1] and Symbols[Entries] are always in bounds.
            var symbols = new ModelSymbol[length + 1];
            for (int i = 0; i <= length; i++)
            {
                symbols[i] = new ModelSymbol
                {
                    Symbol = (ushort)(start + i),
                    CumulativeFrequency = (ushort)(length - i),
                };
            }

            return new Model
            {
                Entries = length,
                Symbols = symbols,
                TimeToReorder = InitialTimeToReorder,
            };
        }

        /// <summary>
        /// Get the next symbol from a model
        /// </summary>
        /// <param name="model">Model to decode from; updated afterwards</param>
        /// <returns>Decoded symbol value</returns>
        private int GetSymbol(Model model)
        {
            var symbols = model.Symbols!;
            ushort total = symbols[0]!.CumulativeFrequency;
            int target = GetFrequency(total);

            // Cumulative frequencies are non-increasing; the decoded symbol is the one
            // just before the first entry whose cumulative frequency is <= target.
            int i = 1;
            while (i < model.Entries && symbols[i]!.CumulativeFrequency > target)
            {
                i++;
            }

            int sym = symbols[i - 1]!.Symbol;
            GetCode(symbols[i - 1]!.CumulativeFrequency, symbols[i]!.CumulativeFrequency, total);
            UpdateModel(model, i);
            return sym;
        }

        /// <summary>
        /// Narrow the coding interval to the decoded symbol and renormalize
        /// </summary>
        /// <param name="prevFrequency">Cumulative frequency of the decoded symbol</param>
        /// <param name="currentFrequency">Cumulative frequency of the following entry</param>
        /// <param name="totalFrequency">Total cumulative frequency of the model</param>
        private void GetCode(int prevFrequency, int currentFrequency, int totalFrequency)
        {
            // Interval width; 0x10000 when the interval spans the full 16-bit range,
            // so it must not be truncated to 16 bits.
            uint range = (uint)((_high - _low) & 0xFFFF) + 1;

            // Both bounds are computed from the old low value
            _high = (ushort)(_low + (prevFrequency * range) / totalFrequency - 1);
            _low = (ushort)(_low + (currentFrequency * range) / totalFrequency);

            while (true)
            {
                if ((_low & 0x8000) != (_high & 0x8000))
                {
                    if ((_low & 0x4000) != 0 && (_high & 0x4000) == 0)
                    {
                        // Underflow case
                        _code ^= 0x4000;
                        _low &= 0x3FFF;
                        _high |= 0x4000;
                    }
                    else
                    {
                        break;
                    }
                }

                _low <<= 1;
                _high = (ushort)((_high << 1) | 1);

                // Shift in the next bit, or 0 once the stream is exhausted
                _code = (ushort)((_code << 1) | (_bitStream.ReadBit() ?? 0));
            }
        }

        /// <summary>
        /// Update the model after a decode step
        /// </summary>
        /// <param name="model">Model to update</param>
        /// <param name="lastUpdated">Number of leading entries whose cumulative frequency is incremented</param>
        private static void UpdateModel(Model model, int lastUpdated)
        {
            var symbols = model.Symbols!;

            for (int i = 0; i < lastUpdated; i++)
            {
                symbols[i]!.CumulativeFrequency += FrequencyIncrement;
            }

            // Only rescale once the total grows past the threshold
            if (symbols[0]!.CumulativeFrequency <= RescaleThreshold)
                return;

            if (--model.TimeToReorder > 0)
            {
                // Halve cumulative frequencies, keeping them strictly decreasing;
                // the sentinel entry (cumulative frequency 0) bounds the last one.
                for (int i = model.Entries - 1; i >= 0; i--)
                {
                    var sym = symbols[i]!;
                    sym.CumulativeFrequency >>= 1;

                    ushort next = symbols[i + 1]!.CumulativeFrequency;
                    if (sym.CumulativeFrequency <= next)
                        sym.CumulativeFrequency = (ushort)(next + 1);
                }
            }
            else
            {
                // Convert cumulative frequencies to individual frequencies, halved and
                // rounded up so no symbol drops to zero. Ascending order means
                // symbols[i + 1] is still cumulative when it is subtracted.
                for (int i = 0; i < model.Entries; i++)
                {
                    var sym = symbols[i]!;
                    sym.CumulativeFrequency -= symbols[i + 1]!.CumulativeFrequency;
                    sym.CumulativeFrequency++;
                    sym.CumulativeFrequency >>= 1;
                }

                // Sort by frequency, decreasing, with an in-place exchange sort
                // (ordering of equal frequencies depends on this exact algorithm)
                for (int i = 0; i < model.Entries - 1; i++)
                {
                    for (int j = i + 1; j < model.Entries; j++)
                    {
                        if (symbols[i]!.CumulativeFrequency < symbols[j]!.CumulativeFrequency)
                        {
                            var temp = symbols[i];
                            symbols[i] = symbols[j];
                            symbols[j] = temp;
                        }
                    }
                }

                // Convert frequencies back to cumulative frequencies (sentinel adds 0)
                for (int i = model.Entries - 1; i >= 0; i--)
                {
                    symbols[i]!.CumulativeFrequency += symbols[i + 1]!.CumulativeFrequency;
                }

                model.TimeToReorder = ReorderInterval;
            }
        }

        /// <summary>
        /// Scale the current code value into the model's frequency range
        /// </summary>
        /// <param name="totalFrequency">Total cumulative frequency of the model</param>
        /// <returns>Target cumulative frequency used to locate the decoded symbol</returns>
        private ushort GetFrequency(ushort totalFrequency)
        {
            ulong range = (ulong)(((_high - _low) & 0xFFFF) + 1);
            ulong frequency = (ulong)((_code - _low + 1) * totalFrequency - 1) / range;
            return (ushort)(frequency & 0xFFFF);
        }
    }
}