/*
* String Table Functions
*
* Copyright 2002-2004, Mike McCormack for CodeWeavers
* Copyright 2007 Robert Shearman for CodeWeavers
* Copyright 2010 Hans Leidekker for CodeWeavers
*
* This library is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 2.1 of the License, or (at your option) any later version.
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this library; if not, write to the Free Software
* Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301, USA
*/
using System;
using System.Text;
using LibGSF.Input;
using static LibMSI.LibmsiTable;
using static LibMSI.MsiPriv;
namespace LibMSI
{
/// <summary>
/// One string table entry: the string text together with its two reference counters.
/// </summary>
internal class msistring
{
/// <summary>
/// Reference count added through <c>StringPersistence.StringPersistent</c>;
/// <c>SaveStringTable</c> only writes out strings for which this is non-zero.
/// </summary>
public ushort PersistentRefCount { get; set; }
/// <summary>
/// Reference count added through any other <c>StringPersistence</c> value.
/// </summary>
public ushort NonPersistentRefCount { get; set; }
/// <summary>
/// The string text; stays null until a string is stored in the entry.
/// </summary>
public string Str { get; set; }
}
internal class StringTable
{
#region Constants
/// <summary>
/// Code page value 0. Wherever this class converts text, a code page of 0 is
/// replaced by <c>Encoding.Default.CodePage</c>.
/// </summary>
private const int CP_ACP = 0;
#endregion
#region Properties
/// <summary>
/// Number of slots allocated in <see cref="Strings"/> and <see cref="Sorted"/>;
/// string ids at or above this value are rejected by <c>LookupId</c>.
/// </summary>
public int MaxCount { get; set; }
/// <summary>
/// Slot index at which <c>FindFreeEntry</c> starts its first search pass;
/// a value of 0 skips that pass.
/// </summary>
public int FreeSlot { get; set; }
/// <summary>
/// Code page used when converting strings to and from bytes;
/// 0 selects <c>Encoding.Default.CodePage</c>.
/// </summary>
public int CodePage { get; set; }
/// <summary>
/// Number of leading elements of <see cref="Sorted"/> that are in use.
/// </summary>
public int SortCount { get; set; }
/// <summary>
/// The string entries, indexed by string id. Id 0 is reserved: <c>LookupId</c>
/// answers it with the empty string without consulting this array.
/// </summary>
public msistring[] Strings { get; set; }
/// <summary>
/// String ids ordered by the value of their strings; the first
/// <see cref="SortCount"/> elements are binary-searched by the lookup functions.
/// </summary>
public int[] Sorted { get; set; }
#endregion
#region Functions
/// <summary>
/// Drops the references to both backing arrays so they can be collected.
/// The counters (MaxCount, SortCount, FreeSlot) are left untouched.
/// </summary>
public void Destroy()
{
    // The two assignments are independent of each other
    Sorted = null;
    Strings = null;
}
/// <summary>
/// Adds a string to the table, or adds to the reference count of an identical
/// string that is already present.
/// </summary>
/// <param name="data">String to add.</param>
/// <param name="len">
/// Number of characters of <paramref name="data"/> to store; a negative value
/// (or one beyond the end of the string) stores the whole string.
/// </param>
/// <param name="refcount">Reference count to add (existing string) or to start with (new string).</param>
/// <param name="persistence">Selects which of the two reference counters receives <paramref name="refcount"/>.</param>
/// <returns>
/// The id of the string; 0 when there is nothing to store (null or empty input,
/// or a zero length); -1 when no slot could be obtained.
/// </returns>
public int AddString(string data, int len, ushort refcount, StringPersistence persistence)
{
    if (string.IsNullOrEmpty(data) || data[0] == '\0')
        return 0;

    // Honour the requested length; it was previously computed but never applied
    if (len >= 0 && len < data.Length)
    {
        if (len == 0)
            return 0;

        data = data.Substring(0, len);
    }

    // An identical string is already stored: only bump its reference count
    if (IdFromStringUTF8(data, out int n) == LibmsiResult.LIBMSI_RESULT_SUCCESS)
    {
        if (persistence == StringPersistence.StringPersistent)
            Strings[n].PersistentRefCount += refcount;
        else
            Strings[n].NonPersistentRefCount += refcount;

        return n;
    }

    n = FindFreeEntry();
    if (n == -1)
        return -1;

    // Store the text as-is. A managed string needs no NUL terminator, and
    // appending one would leak into LookupId results and the saved data.
    SetEntry(n, data, refcount, persistence);
    return n;
}
/// <summary>
/// Finds the string identified by an id.
/// </summary>
/// <param name="id">Id of the string to look up.</param>
/// <returns>
/// The empty-string constant for id 0; the stored string for an id that is in
/// use; null when the id is out of range or both of its reference counts are zero.
/// </returns>
public string LookupId(int id)
{
    if (id == 0)
        return szEmpty;

    // Negative ids have no slot either; without this check they would index
    // the array out of range.
    if (id < 0 || id >= MaxCount)
        return null;

    msistring entry = Strings[id];
    if (entry == null)
        return null;

    // A slot whose two reference counts are both zero is considered unused
    if (entry.PersistentRefCount == 0 && entry.NonPersistentRefCount == 0)
        return null;

    return entry.Str;
}
/// <summary>
/// Looks up the id of a string by binary search over the sorted index.
/// </summary>
/// <param name="str">String to find in the string table</param>
/// <param name="id">Id of the string, if found; 0 otherwise</param>
/// <returns>
/// LIBMSI_RESULT_SUCCESS when the string is present,
/// LIBMSI_RESULT_INVALID_PARAMETER when it is not (or when <paramref name="str"/> is null).
/// </returns>
public LibmsiResult IdFromStringUTF8(string str, out int id)
{
    id = 0;

    // A null string can never be in the table; bail out instead of
    // dereferencing it in the comparison below.
    if (str == null)
        return LibmsiResult.LIBMSI_RESULT_INVALID_PARAMETER;

    int low = 0, high = SortCount - 1;
    while (low <= high)
    {
        int i = low + (high - low) / 2;

        // NOTE(review): string.CompareTo is culture-sensitive. It is kept here
        // because FindInsertIndex orders the index with the same comparison,
        // and the two must agree for the binary search to be valid.
        int c = str.CompareTo(Strings[Sorted[i]].Str);
        if (c < 0)
        {
            high = i - 1;
        }
        else if (c > 0)
        {
            low = i + 1;
        }
        else
        {
            id = Sorted[i];
            return LibmsiResult.LIBMSI_RESULT_SUCCESS;
        }
    }

    return LibmsiResult.LIBMSI_RESULT_INVALID_PARAMETER;
}
/// <summary>
/// Creates a fresh single-slot string table using code page CP_ACP.
/// </summary>
/// <param name="bytes_per_strref">Receives the size of a string reference, which is sizeof(ushort) for a new table.</param>
/// <returns>The new table.</returns>
public static StringTable InitStringTable(out int bytes_per_strref)
{
    StringTable table = InitStringTable(1, CP_ACP);
    bytes_per_strref = sizeof(ushort);
    return table;
}
/// <summary>
/// Builds a string table from the string pool and string data streams of a storage.
/// </summary>
/// <remarks>
/// The pool is a sequence of 4-byte entries made of two 16-bit words. Entry 0
/// holds the code page, with bit 15 of its second word flagging long string
/// references. Every other entry holds a length word and a reference-count word;
/// the string bytes themselves are stored back to back in the data stream.
/// </remarks>
/// <param name="stg">Storage to read the two streams from.</param>
/// <param name="bytes_per_strref">Receives the size of a string reference; 0 when the streams cannot be read.</param>
/// <returns>The loaded table, or null when a stream cannot be read or the code page is not accepted.</returns>
public static StringTable LoadStringTable(GsfInfile stg, out int bytes_per_strref)
{
    bytes_per_strref = 0;

    LibmsiResult r = ReadStreamData(stg, szStringPool, out byte[] pool, out int poolsize);
    if (r != LibmsiResult.LIBMSI_RESULT_SUCCESS)
        return null;

    r = ReadStreamData(stg, szStringData, out byte[] data, out int datasize);
    if (r != LibmsiResult.LIBMSI_RESULT_SUCCESS)
        return null;

    if ((poolsize > 4) && (PoolWord(pool, 1) & 0x8000) != 0)
        bytes_per_strref = LONG_STR_BYTES;
    else
        bytes_per_strref = sizeof(ushort);

    int count = poolsize / 4;

    int codepage;
    if (poolsize > 4)
        codepage = PoolWord(pool, 0) | ((PoolWord(pool, 1) & ~0x8000) << 16);
    else
        codepage = CP_ACP;

    StringTable st = InitStringTable(count, codepage);
    if (st == null)
        return null;

    int offset = 0;
    int n = 1;
    int i = 1;
    while (i < count)
    {
        ushort lengthWord = PoolWord(pool, i * 2);

        // The string reference count is always the second word
        ushort refs = PoolWord(pool, i * 2 + 1);

        // Empty entries have two zeros, still have a string id
        if (lengthWord == 0 && refs == 0)
        {
            i++;
            n++;
            continue;
        }

        int len;
        if (lengthWord == 0)
        {
            // If a string is over 64k, the previous string entry is made null
            // and the high word of the length is inserted in the null string's
            // reference count field. The following entry carries the low word
            // of the length and the real reference count, which is the layout
            // SaveStringTable writes.
            if (i + 1 >= count)
            {
                // The second half of the pair lies outside the pool
                Console.Error.WriteLine("String table corrupt?");
                break;
            }

            len = (refs << 16) + PoolWord(pool, (i + 1) * 2);
            refs = PoolWord(pool, (i + 1) * 2 + 1);
            i += 2;
        }
        else
        {
            len = lengthWord;
            i += 1;
        }

        // A negative length means the high word overflowed; comparing against
        // the remaining space avoids overflowing offset + len.
        if (len < 0 || len > datasize - offset)
        {
            Console.Error.WriteLine("String table corrupt?");
            break;
        }

        int s = st.AddString(n, data, offset, len, refs, StringPersistence.StringPersistent);
        if (s != n)
            Console.Error.WriteLine($"Failed to add string {n}");

        n++;
        offset += len;
    }

    if (datasize != offset)
        Console.Error.WriteLine($"String table load failed! ({datasize} != {offset}), please report");

    return st;
}

/// <summary>
/// Reads the 16-bit word with the given word index from the string pool.
/// </summary>
private static ushort PoolWord(byte[] pool, int wordIndex)
{
    return BitConverter.ToUInt16(pool, wordIndex * 2);
}
/// <summary>
/// Serialises the table into the string data and string pool streams of a database.
/// </summary>
/// <remarks>
/// Both streams are assembled in memory first. The pool starts with a 4-byte
/// header holding the code page (top bit set when the table has more than
/// 0xffff slots), followed by one 4-byte entry per string id: a 16-bit length
/// and the 16-bit persistent reference count. Ids without a persistent
/// reference, and strings that cannot be fetched, become all-zero entries.
/// </remarks>
/// <param name="db">Database to write the two streams to.</param>
/// <param name="bytes_per_strref">Receives LONG_STR_BYTES when the table has more than 0xffff slots, otherwise sizeof(ushort).</param>
/// <returns>LIBMSI_RESULT_SUCCESS, or LIBMSI_RESULT_FUNCTION_FAILED when the sizes do not add up or a stream cannot be written.</returns>
public LibmsiResult SaveStringTable(LibmsiDatabase db, out int bytes_per_strref)
{
    // Size both streams up front, then fill them in memory
    StringTotalSize(out int datasize, out int poolsize);
    byte[] data = new byte[datasize];
    byte[] pool = new byte[poolsize];

    // Writes one pool entry as two little-endian 16-bit words
    void PutEntry(int index, int firstWord, int secondWord)
    {
        int at = index * 4;
        pool[at] = (byte)firstWord;
        pool[at + 1] = (byte)(firstWord >> 8);
        pool[at + 2] = (byte)secondWord;
        pool[at + 3] = (byte)(secondWord >> 8);
    }

    int used = 0;

    // Header: the code page, least significant byte first
    int codepage = CodePage;
    pool[0] = (byte)(codepage & 0xff);
    pool[1] = (byte)(codepage >> 8);
    pool[2] = (byte)(codepage >> 16);
    pool[3] = (byte)(codepage >> 24);

    if (MaxCount <= 0xffff)
    {
        bytes_per_strref = sizeof(ushort);
    }
    else
    {
        pool[3] |= 0x80;
        bytes_per_strref = LONG_STR_BYTES;
    }

    int slot = 1;
    for (int stringId = 1; stringId < MaxCount; stringId++)
    {
        // Fetch the bytes only for strings that have a persistent reference
        int sz = 0;
        if (Strings[stringId].PersistentRefCount != 0)
        {
            sz = datasize - used;
            LibmsiResult fetched = StringId(stringId, ref data, ref used, ref sz);
            if (fetched != LibmsiResult.LIBMSI_RESULT_SUCCESS)
            {
                Console.Error.WriteLine("Failed to fetch string");
                sz = 0;
            }
        }

        // Nothing to store for this id: emit an all-zero entry
        if (sz == 0)
        {
            PutEntry(slot, 0, 0);
            slot++;
            continue;
        }

        if (sz >= 0x10000)
        {
            // Write a dummy entry, with the high part of the length
            // in the reference count.
            PutEntry(slot, 0, sz >> 16);
            slot++;
        }

        PutEntry(slot, sz, Strings[stringId].PersistentRefCount);
        slot++;

        used += sz;
        if (used > datasize)
        {
            Console.Error.WriteLine($"Oops overran {used} >= {datasize}");
            return LibmsiResult.LIBMSI_RESULT_FUNCTION_FAILED;
        }
    }

    if (used != datasize)
    {
        Console.Error.WriteLine($"Oops used {used} != datasize {datasize}");
        return LibmsiResult.LIBMSI_RESULT_FUNCTION_FAILED;
    }

    // Write the streams: data first, then the pool
    if (WriteStreamData(db, szStringData, data, datasize) != LibmsiResult.LIBMSI_RESULT_SUCCESS)
        return LibmsiResult.LIBMSI_RESULT_FUNCTION_FAILED;

    if (WriteStreamData(db, szStringPool, pool, poolsize) != LibmsiResult.LIBMSI_RESULT_SUCCESS)
        return LibmsiResult.LIBMSI_RESULT_FUNCTION_FAILED;

    return LibmsiResult.LIBMSI_RESULT_SUCCESS;
}
/// <summary>
/// Returns the code page currently set on the table.
/// </summary>
public int GetCodePage()
{
    return CodePage;
}
/// <summary>
/// Changes the table's code page after validating it.
/// </summary>
/// <param name="codepage">Code page to switch to.</param>
/// <returns>
/// LIBMSI_RESULT_SUCCESS when the code page was accepted and stored;
/// LIBMSI_RESULT_FUNCTION_FAILED (table unchanged) when it was rejected.
/// </returns>
public LibmsiResult SetCodePage(int codepage)
{
    // Reject unknown code pages before touching any state
    if (!ValidateCodePage(codepage))
        return LibmsiResult.LIBMSI_RESULT_FUNCTION_FAILED;

    CodePage = codepage;
    return LibmsiResult.LIBMSI_RESULT_SUCCESS;
}
#endregion
#region Utilities
/// <summary>
/// The code pages a string table accepts.
/// </summary>
private static readonly int[] s_acceptedCodePages =
{
    CP_ACP,
    37, 424, 437, 500, 737, 775, 850,
    852, 855, 856, 857, 860, 861, 862,
    863, 864, 865, 866, 869, 874, 875,
    878, 932, 936, 949, 950, 1006, 1026,
    1250, 1251, 1252, 1253, 1254, 1255,
    1256, 1257, 1258, 1361,
    10000, 10006, 10007, 10029, 10079, 10081,
    20127, 20866, 20932, 21866, 28591, 28592,
    28593, 28594, 28595, 28596, 28597, 28598,
    28599, 28600, 28603, 28604, 28605, 28606,
    65000, 65001,
};

/// <summary>
/// Tells whether a code page is one of the accepted values.
/// </summary>
/// <param name="codepage">Code page number to check.</param>
/// <returns>true when the code page is in the accepted list, false otherwise.</returns>
private static bool ValidateCodePage(int codepage)
{
    return Array.IndexOf(s_acceptedCodePages, codepage) >= 0;
}
/// <summary>
/// Allocates a string table with the given number of slots.
/// </summary>
/// <param name="entries">Number of slots to allocate; values below 1 are raised to 1.</param>
/// <param name="codepage">Code page of the table; must pass ValidateCodePage.</param>
/// <returns>The new table, or null when the code page is not accepted.</returns>
private static StringTable InitStringTable(int entries, int codepage)
{
    if (!ValidateCodePage(codepage))
        return null;

    if (entries < 1)
        entries = 1;

    // msistring is a reference type, so a freshly allocated array holds only
    // nulls. Populate every slot so the rest of the class can read the
    // reference counts without a null check.
    msistring[] strings = new msistring[entries];
    for (int i = 0; i < strings.Length; i++)
        strings[i] = new msistring();

    return new StringTable
    {
        Strings = strings,
        Sorted = new int[entries],
        MaxCount = entries,
        FreeSlot = 1,
        CodePage = codepage,
        SortCount = 0,
    };
}
/// <summary>
/// Finds an unused slot, growing the table when every slot is taken.
/// </summary>
/// <returns>
/// Index of a slot whose two reference counts are zero. The entry object at
/// that index is always allocated when this method returns.
/// </returns>
private int FindFreeEntry()
{
    // A slot is free when it has no references; a slot that was never
    // allocated is free as well and gets its entry object on the spot.
    bool TryClaim(int index)
    {
        msistring entry = Strings[index];
        if (entry == null)
        {
            Strings[index] = new msistring();
            return true;
        }

        return entry.PersistentRefCount == 0 && entry.NonPersistentRefCount == 0;
    }

    // First pass: from the free-slot hint to the end of the table
    int start = FreeSlot > 0 ? FreeSlot : 1;
    for (int i = start; i < MaxCount; i++)
    {
        if (TryClaim(i))
            return i;
    }

    // Second pass: the slots before the hint, which the first pass skipped
    int limit = Math.Min(start, MaxCount);
    for (int i = 1; i < limit; i++)
    {
        if (TryClaim(i))
            return i;
    }

    // Dynamically resize
    int oldCount = MaxCount;
    int sz = oldCount + 1 + oldCount / 2;

    msistring[] p = Strings;
    Array.Resize(ref p, sz);

    // Array.Resize leaves the added elements null; allocate their entries
    for (int i = oldCount; i < sz; i++)
        p[i] = new msistring();

    Strings = p;

    int[] s = Sorted;
    Array.Resize(ref s, sz);
    Sorted = s;

    // The first newly added slot is the one handed out
    FreeSlot = oldCount;
    MaxCount = sz;
    return FreeSlot;
}
/// <summary>
/// Binary-searches the sorted index for the position at which a string id
/// has to be inserted to keep the index ordered.
/// </summary>
/// <param name="string_id">Id of the string (already stored in Strings) to place.</param>
/// <returns>The insertion position, or -1 when an equal string is already indexed.</returns>
private int FindInsertIndex(int string_id)
{
    int lower = 0;
    int upper = SortCount - 1;

    while (lower <= upper)
    {
        int middle = (lower + upper) / 2;
        int order = Strings[string_id].Str.CompareTo(Strings[Sorted[middle]].Str);

        // Already exists
        if (order == 0)
            return -1;

        if (order < 0)
            upper = middle - 1;
        else
            lower = middle + 1;
    }

    return upper + 1;
}
/// <summary>
/// Inserts a string id into the sorted index, keeping it ordered.
/// Does nothing when an equal string is already indexed.
/// </summary>
/// <param name="string_id">Id of the string (already stored in Strings) to index.</param>
private void InsertStringSorted(int string_id)
{
    int i = FindInsertIndex(string_id);
    if (i == -1)
        return;

    // Shift the tail up by one position. Array.Copy handles overlapping
    // source and destination ranges, so no temporary array is needed.
    Array.Copy(Sorted, i, Sorted, i + 1, SortCount - i);
    Sorted[i] = string_id;
    SortCount++;
}
/// <summary>
/// Stores a string in slot <paramref name="n"/>, sets its reference counts,
/// adds it to the sorted index and advances the free-slot hint.
/// </summary>
/// <param name="n">Slot to fill.</param>
/// <param name="str">String to store; the call is a no-op when this is null.</param>
/// <param name="refcount">Initial reference count.</param>
/// <param name="persistence">Selects which counter receives <paramref name="refcount"/>; the other one is reset to 0.</param>
private void SetEntry(int n, string str, ushort refcount, StringPersistence persistence)
{
    if (str == null)
        return;

    bool persistent = persistence == StringPersistence.StringPersistent;
    msistring entry = Strings[n];

    // Exactly one of the two counters gets the reference count
    entry.PersistentRefCount = persistent ? refcount : (ushort)0;
    entry.NonPersistentRefCount = persistent ? (ushort)0 : refcount;
    entry.Str = str;

    InsertStringSorted(n);

    // The next search for a free slot starts right after this one
    if (n < MaxCount)
        FreeSlot = n + 1;
}
/// <summary>
/// Looks up the id of a string given as bytes in the table's code page.
/// </summary>
/// <param name="buffer">Buffer holding the encoded string.</param>
/// <param name="offset">
/// Index of the first byte of the string. The string runs up to the first
/// zero byte, or to the end of the buffer when there is none.
/// </param>
/// <param name="id">Receives the id; 0 for an empty string or when nothing is found.</param>
/// <returns>
/// LIBMSI_RESULT_SUCCESS when the string is empty or present in the table;
/// LIBMSI_RESULT_INVALID_PARAMETER when it is absent or <paramref name="offset"/> is out of range.
/// </returns>
private LibmsiResult IdFromString(byte[] buffer, int offset, out int id)
{
    id = 0;

    // No buffer at all counts as the empty string, which has id 0
    if (buffer == null)
        return LibmsiResult.LIBMSI_RESULT_SUCCESS;

    if (offset < 0 || offset > buffer.Length)
        return LibmsiResult.LIBMSI_RESULT_INVALID_PARAMETER;

    // Nothing left in the buffer, or an immediate terminator: empty string
    if (offset == buffer.Length || buffer[offset] == 0)
        return LibmsiResult.LIBMSI_RESULT_SUCCESS;

    // Decode only this string: stop at the terminator instead of converting
    // everything up to the end of the buffer.
    int terminator = Array.IndexOf(buffer, (byte)0, offset);
    int count = (terminator < 0 ? buffer.Length : terminator) - offset;

    int codepage = CodePage != 0 ? CodePage : Encoding.Default.CodePage;
    Encoding cpconv = Encoding.GetEncoding(codepage);
    string str = cpconv.GetString(buffer, offset, count);

    return IdFromStringUTF8(str, out id);
}
/// <summary>
/// Adds a string given as bytes in the table's code page, either at a fixed
/// id or at whatever id is free.
/// </summary>
/// <param name="n">
/// Id to store the string at when positive; otherwise an identical existing
/// string is reused or a free slot is picked.
/// </param>
/// <param name="data">Buffer holding the encoded string.</param>
/// <param name="offset">Index of the first byte of the string in <paramref name="data"/>.</param>
/// <param name="len">
/// Number of bytes in the string. A negative value means "up to the first zero
/// byte or the end of the buffer"; a value past the end of the buffer is clamped.
/// </param>
/// <param name="refcount">Reference count to add (existing string) or to start with (new string).</param>
/// <param name="persistence">Selects which of the two reference counters receives <paramref name="refcount"/>.</param>
/// <returns>
/// The id the string is stored at; 0 when there is no string to add;
/// -1 when the requested id is unusable or no slot could be obtained.
/// </returns>
private int AddString(int n, byte[] data, int offset, int len, ushort refcount, StringPersistence persistence)
{
    if (data == null || offset == data.Length)
        return 0;

    if (offset < 0 || offset > data.Length)
        return -1;

    if (data[offset] == 0)
        return 0;

    // Normalise the length to the bytes that are actually available
    int remaining = data.Length - offset;
    if (len < 0)
    {
        int terminator = Array.IndexOf(data, (byte)0, offset);
        len = (terminator < 0 ? data.Length : terminator) - offset;
    }
    else if (len > remaining)
    {
        len = remaining;
    }

    // Converts exactly this string's bytes, not the whole buffer
    string Decode()
    {
        int codepage = CodePage != 0 ? CodePage : Encoding.Default.CodePage;
        Encoding cpconv = Encoding.GetEncoding(codepage);
        return cpconv.GetString(data, offset, len);
    }

    string str;
    if (n > 0)
    {
        // A fixed id was requested: it must exist and still be unused
        if (n >= MaxCount)
            return -1;

        if (Strings[n].PersistentRefCount != 0 || Strings[n].NonPersistentRefCount != 0)
            return -1;

        str = Decode();
    }
    else
    {
        str = Decode();

        // Look up the same text that would be stored, so that an identical
        // existing string is found and only its reference count is bumped.
        if (IdFromStringUTF8(str, out n) == LibmsiResult.LIBMSI_RESULT_SUCCESS)
        {
            if (persistence == StringPersistence.StringPersistent)
                Strings[n].PersistentRefCount += refcount;
            else
                Strings[n].NonPersistentRefCount += refcount;

            return n;
        }

        n = FindFreeEntry();
        if (n == -1)
            return -1;

        if (n < 1)
        {
            Console.Error.WriteLine($"Invalid index adding {str} ({n})");
            return -1;
        }
    }

    // Allocate a new string
    SetEntry(n, str, refcount, persistence);
    return n;
}
/// <summary>
/// Converts the string with the given id to the table's code page and copies
/// the bytes into a buffer.
/// </summary>
/// <param name="id">Id of the string to retrieve</param>
/// <param name="buffer">Destination of the encoded string</param>
/// <param name="offset">Index in <paramref name="buffer"/> at which the bytes are written</param>
/// <param name="sz">
/// Number of bytes available in the buffer on input;
/// number of bytes used (or needed, when the buffer is too small) on output
/// </param>
/// <returns>
/// LIBMSI_RESULT_SUCCESS when the bytes were copied;
/// LIBMSI_RESULT_MORE_DATA when the buffer is too small;
/// LIBMSI_RESULT_FUNCTION_FAILED when the id has no string.
/// </returns>
/// <remarks>Returned string is not NUL-terminated.</remarks>
private LibmsiResult StringId(int id, ref byte[] buffer, ref int offset, ref int sz)
{
    string text = LookupId(id);
    if (text == null)
        return LibmsiResult.LIBMSI_RESULT_FUNCTION_FAILED;

    // Code page 0 stands for the system default
    int codepage = CodePage;
    if (codepage == 0)
        codepage = Encoding.Default.CodePage;

    Encoding target = Encoding.GetEncoding(codepage);
    byte[] encoded = Encoding.Convert(Encoding.UTF8, target, Encoding.UTF8.GetBytes(text));

    // Report the size in both outcomes; copy only when the bytes fit
    int available = sz;
    sz = encoded.Length;
    if (available < encoded.Length)
        return LibmsiResult.LIBMSI_RESULT_MORE_DATA;

    Array.Copy(encoded, 0, buffer, offset, encoded.Length);
    return LibmsiResult.LIBMSI_RESULT_SUCCESS;
}
/// <summary>
/// Computes the sizes of the two streams SaveStringTable is about to build.
/// </summary>
/// <param name="datasize">Receives the total number of bytes of all persistent strings, encoded in the table's code page.</param>
/// <param name="poolsize">Receives the size of the pool: a 4-byte header plus 4 bytes per entry (8 for strings longer than 0xffff bytes).</param>
private void StringTotalSize(out int datasize, out int poolsize)
{
    if (Strings[0].Str != null || Strings[0].PersistentRefCount != 0 || Strings[0].NonPersistentRefCount != 0)
        Console.Error.WriteLine("Oops. element 0 has a string");

    int codepage = CodePage != 0 ? CodePage : Encoding.Default.CodePage;

    // Resolved on first use, so a table without persistent strings never
    // needs the encoding to be available.
    Encoding cpconv = null;

    poolsize = 4;
    datasize = 0;
    int holesize = 0;
    for (int i = 1; i < MaxCount; i++)
    {
        msistring entry = Strings[i];
        if (entry.PersistentRefCount == 0)
        {
            poolsize += 4;
        }
        else if (entry.Str != null)
        {
            if (cpconv == null)
                cpconv = Encoding.GetEncoding(codepage);

            // Measure the encoded bytes with the same conversion StringId
            // uses, so the size computed here always matches what
            // SaveStringTable writes. Counting the characters of the decoded
            // UTF-8 bytes is wrong for any non-ASCII text.
            int len = Encoding.Convert(Encoding.UTF8, cpconv, Encoding.UTF8.GetBytes(entry.Str)).Length;

            datasize += len;

            // Strings over 64k take an extra dummy pool entry
            if (len > 0xffff)
                poolsize += 4;

            poolsize += holesize + 4;
            holesize = 0;
        }
        else
        {
            // NOTE(review): entries with a reference count but no string are
            // only counted once a later string follows them; trailing ones
            // are not counted at all - confirm this matches SaveStringTable.
            holesize += 4;
        }
    }
}
#endregion
}
}