Compare commits

...

11 Commits

Author SHA1 Message Date
Leonard Hecker
0090e2882e wip 2026-03-18 22:39:20 +01:00
Leonard Hecker
78f70111ae wip 2026-03-18 20:29:53 +01:00
Leonard Hecker
98399a6913 wip 2026-03-18 18:47:39 +01:00
Leonard Hecker
f5cdcd967e wip 2026-03-18 00:37:54 +01:00
Leonard Hecker
51d353baf3 wip 2026-03-17 21:27:15 +01:00
Leonard Hecker
89c83fb605 wip 2026-03-16 20:45:16 +01:00
Leonard Hecker
778c19f4c0 wip 2026-03-13 23:08:41 +01:00
Leonard Hecker
b0b32a356e wip 2026-03-13 21:34:57 +01:00
Leonard Hecker
88289a0858 wip 2026-03-13 20:20:56 +01:00
Leonard Hecker
8dc8e36764 wip 2026-03-13 20:08:03 +01:00
Leonard Hecker
aad1bde1c9 wip 2026-03-12 16:25:50 +01:00
25 changed files with 4785 additions and 85 deletions

View File

@@ -10,6 +10,12 @@
<Platform Name="x86" />
</Configurations>
<Folder Name="/Conhost/">
<Project Path="src/conpty/conpty-test/conpty-test.vcxproj" Id="10715020-e347-4d4e-a8f2-d11e4bced6ec">
<BuildType Solution="AuditMode|*" Project="Release" />
<BuildType Solution="Fuzzing|*" Project="Release" />
<Platform Solution="*|ARM64" Project="x64" />
</Project>
<Project Path="src/conpty/conpty.vcxproj" Id="23a66bb9-dccf-420c-b1a1-fa9ecfe7db65" />
<Project Path="src/host/exe/Host.EXE.vcxproj" Id="9cbd7dfa-1754-4a9d-93d7-857a9d17cb1b">
<BuildDependency Project="src/buffer/out/lib/bufferout.vcxproj" />
<BuildDependency Project="src/host/proxy/Host.Proxy.vcxproj" />

View File

@@ -95,7 +95,7 @@
<!-- For ALL build types-->
<PropertyGroup Label="Configuration">
<PlatformToolset>v143</PlatformToolset>
<PlatformToolset>v145</PlatformToolset>
<CharacterSet>Unicode</CharacterSet>
<LinkIncremental>false</LinkIncremental>
<PreferredToolArchitecture>x64</PreferredToolArchitecture>

409
src/conpty/InputBuffer.cpp Normal file
View File

@@ -0,0 +1,409 @@
#include "pch.h"
#include "InputBuffer.h"
void InputBuffer::write(std::string_view text)
{
m_buf.append(text);
}
bool InputBuffer::hasData() const noexcept
{
return m_readPos < m_buf.size() || m_recordReadPos < m_records.size();
}
size_t InputBuffer::pendingEventCount() const noexcept
{
const auto queued = (m_records.size() > m_recordReadPos) ? (m_records.size() - m_recordReadPos) : 0;
const auto rawBytes = (m_buf.size() > m_readPos) ? (m_buf.size() - m_readPos) : 0;
return queued + rawBytes;
}
size_t InputBuffer::readRawText(char* dst, size_t dstCapacity)
{
const auto available = m_buf.size() - m_readPos;
const auto toCopy = std::min(available, dstCapacity);
if (toCopy > 0)
{
memcpy(dst, m_buf.data() + m_readPos, toCopy);
m_readPos += toCopy;
compact();
}
return toCopy;
}
size_t InputBuffer::readInputRecords(INPUT_RECORD* dst, size_t maxRecords, bool peek)
{
// Parse any new data from the raw buffer into m_records.
if (m_readPos < m_buf.size())
parseTokensToRecords();
const auto available = m_records.size() - m_recordReadPos;
const auto toCopy = std::min(available, maxRecords);
if (toCopy > 0)
{
memcpy(dst, m_records.data() + m_recordReadPos, toCopy * sizeof(INPUT_RECORD));
if (!peek)
{
m_recordReadPos += toCopy;
if (m_recordReadPos == m_records.size())
{
m_records.clear();
m_recordReadPos = 0;
}
}
}
if (!peek)
compact();
return toCopy;
}
void InputBuffer::flush()
{
m_buf.clear();
m_readPos = 0;
m_records.clear();
m_recordReadPos = 0;
}
void InputBuffer::compact()
{
if (m_readPos > 4096 || (m_readPos > 0 && m_readPos == m_buf.size()))
{
m_buf.erase(0, m_readPos);
m_readPos = 0;
}
}
// ============================================================================
// Token → INPUT_RECORD conversion
// ============================================================================
void InputBuffer::parseTokensToRecords()
{
const auto input = std::string_view{ m_buf.data() + m_readPos, m_buf.size() - m_readPos };
auto stream = m_parser.parse(input);
VtToken token;
while (stream.next(token))
{
switch (token.type)
{
case VtToken::Text:
handleText(token.payload);
break;
case VtToken::Ctrl:
handleCtrl(token.ch);
break;
case VtToken::Esc:
handleEsc(token.ch);
break;
case VtToken::SS3:
handleSs3(token.ch);
break;
case VtToken::Csi:
handleCsi(*token.csi);
break;
default:
// Osc/Dcs — irrelevant for input records, skip.
break;
}
}
// Advance m_readPos by what the parser consumed.
m_readPos += stream.offset();
}
void InputBuffer::handleText(std::string_view text)
{
// Decode UTF-8 codepoints and emit key events.
const auto* bytes = reinterpret_cast<const uint8_t*>(text.data());
size_t i = 0;
while (i < text.size())
{
uint32_t cp;
size_t seqLen;
const auto b = bytes[i];
if (b < 0x80) { cp = b; seqLen = 1; }
else if ((b & 0xE0) == 0xC0) { cp = b & 0x1F; seqLen = 2; }
else if ((b & 0xF0) == 0xE0) { cp = b & 0x0F; seqLen = 3; }
else if ((b & 0xF8) == 0xF0) { cp = b & 0x07; seqLen = 4; }
else { cp = 0xFFFD; seqLen = 1; i++; continue; }
if (i + seqLen > text.size()) break;
for (size_t j = 1; j < seqLen; j++)
{
const auto cont = bytes[i + j];
if ((cont & 0xC0) != 0x80) { cp = 0xFFFD; break; }
cp = (cp << 6) | (cont & 0x3F);
}
i += seqLen;
// Emit one or two INPUT_RECORDs (surrogate pair for supplementary plane).
if (cp <= 0xFFFF)
{
ParsedKey key;
key.ch = static_cast<wchar_t>(cp);
key.vk = LOBYTE(VkKeyScanW(key.ch));
if (key.vk == 0xFF) key.vk = 0;
key.scanCode = vkToScanCode(key.vk);
emitKey(key);
}
else if (cp <= 0x10FFFF)
{
ParsedKey key;
key.ch = static_cast<wchar_t>(0xD800 + ((cp - 0x10000) >> 10));
emitKey(key);
key.ch = static_cast<wchar_t>(0xDC00 + ((cp - 0x10000) & 0x3FF));
emitKey(key);
}
}
}
void InputBuffer::handleCtrl(char ch)
{
ParsedKey key;
switch (static_cast<uint8_t>(ch))
{
case '\r':
key.vk = VK_RETURN;
key.scanCode = vkToScanCode(VK_RETURN);
key.ch = L'\r';
break;
case '\n':
key.vk = VK_RETURN;
key.scanCode = vkToScanCode(VK_RETURN);
key.ch = L'\n';
key.modifiers = LEFT_CTRL_PRESSED;
break;
case '\t':
key.vk = VK_TAB;
key.scanCode = vkToScanCode(VK_TAB);
key.ch = L'\t';
break;
case '\b':
key.vk = VK_BACK;
key.scanCode = vkToScanCode(VK_BACK);
key.ch = L'\b';
break;
case 0x7F:
key.vk = VK_BACK;
key.scanCode = vkToScanCode(VK_BACK);
key.ch = L'\b';
break;
case 0x00:
key.vk = VK_SPACE;
key.scanCode = vkToScanCode(VK_SPACE);
key.ch = L'\0';
key.modifiers = LEFT_CTRL_PRESSED;
break;
default:
if (ch >= 0x01 && ch <= 0x1A)
{
key.ch = static_cast<wchar_t>(ch);
key.vk = static_cast<WORD>('A' + ch - 1);
key.scanCode = vkToScanCode(key.vk);
key.modifiers = LEFT_CTRL_PRESSED;
}
break;
}
emitKey(key);
}
void InputBuffer::handleEsc(char ch)
{
ParsedKey key;
if (ch == '\0')
{
// Bare ESC (timeout or end of buffer).
key.vk = VK_ESCAPE;
key.scanCode = vkToScanCode(VK_ESCAPE);
key.ch = 0x1B;
}
else if (ch >= 0x20 && ch <= 0x7E)
{
// Alt+character.
key.ch = static_cast<wchar_t>(ch);
key.vk = LOBYTE(VkKeyScanW(key.ch));
key.scanCode = vkToScanCode(key.vk);
key.modifiers = LEFT_ALT_PRESSED;
}
else
{
// Unexpected — emit bare ESC.
key.vk = VK_ESCAPE;
key.scanCode = vkToScanCode(VK_ESCAPE);
key.ch = 0x1B;
}
emitKey(key);
}
void InputBuffer::handleSs3(char ch)
{
static constexpr WORD keypadLut[] = {
VK_UP, // A
VK_DOWN, // B
VK_RIGHT,// C
VK_LEFT, // D
0, // E
VK_END, // F
0, // G
VK_HOME, // H
};
ParsedKey key;
key.modifiers = ENHANCED_KEY;
if (ch >= 'A' && ch <= 'H')
{
key.vk = keypadLut[ch - 'A'];
if (key.vk == 0) return;
key.scanCode = vkToScanCode(key.vk);
}
else if (ch >= 'P' && ch <= 'S')
{
key.vk = static_cast<WORD>(VK_F1 + (ch - 'P'));
key.scanCode = vkToScanCode(key.vk);
}
else
{
return; // Unknown SS3.
}
emitKey(key);
}
void InputBuffer::handleCsi(const VtCsi& csi)
{
// Win32 Input Mode: CSI Vk ; Sc ; Uc ; Kd ; Cs ; Rc _
if (csi.finalByte == '_' && csi.paramCount >= 4)
{
ParsedKey key;
key.vk = static_cast<WORD>(csi.params[0]);
key.scanCode = static_cast<WORD>(csi.params[1]);
key.ch = static_cast<wchar_t>(csi.params[2]);
key.keyDown = csi.params[3] != 0;
key.modifiers = (csi.paramCount >= 5) ? static_cast<DWORD>(csi.params[4]) : 0;
key.repeatCount = (csi.paramCount >= 6) ? static_cast<WORD>(csi.params[5]) : 1;
if (key.repeatCount == 0) key.repeatCount = 1;
key.isW32IM = true;
emitKey(key);
return;
}
// Cursor keys: CSI [1;mod] A-H
static constexpr WORD keypadLut[] = {
VK_UP, VK_DOWN, VK_RIGHT, VK_LEFT, 0, VK_END, 0, VK_HOME,
};
if (csi.finalByte >= 'A' && csi.finalByte <= 'H')
{
const auto vk = keypadLut[csi.finalByte - 'A'];
if (vk == 0) return;
ParsedKey key;
key.vk = vk;
key.scanCode = vkToScanCode(vk);
key.modifiers = ENHANCED_KEY;
if (csi.paramCount >= 2)
key.modifiers |= vtModifierToControlKeyState(csi.params[1]);
emitKey(key);
return;
}
// Shift+Tab: CSI Z
if (csi.finalByte == 'Z')
{
ParsedKey key;
key.vk = VK_TAB;
key.scanCode = vkToScanCode(VK_TAB);
key.ch = L'\t';
key.modifiers = SHIFT_PRESSED;
emitKey(key);
return;
}
// Generic keys: CSI {num} [;mod] ~
if (csi.finalByte == '~' && csi.paramCount >= 1)
{
static constexpr struct { uint16_t param; WORD vk; } genericLut[] = {
{ 1, VK_HOME }, { 2, VK_INSERT }, { 3, VK_DELETE },
{ 4, VK_END }, { 5, VK_PRIOR }, { 6, VK_NEXT },
{ 15, VK_F5 }, { 17, VK_F6 }, { 18, VK_F7 },
{ 19, VK_F8 }, { 20, VK_F9 }, { 21, VK_F10 },
{ 23, VK_F11 }, { 24, VK_F12 },
};
for (const auto& entry : genericLut)
{
if (csi.params[0] == entry.param)
{
ParsedKey key;
key.vk = entry.vk;
key.scanCode = vkToScanCode(entry.vk);
key.modifiers = ENHANCED_KEY;
if (csi.paramCount >= 2)
key.modifiers |= vtModifierToControlKeyState(csi.params[1]);
emitKey(key);
return;
}
}
}
// Unrecognised CSI — discard.
}
void InputBuffer::emitKey(const ParsedKey& key)
{
if (key.isW32IM)
{
INPUT_RECORD rec{};
rec.EventType = KEY_EVENT;
rec.Event.KeyEvent.bKeyDown = key.keyDown ? TRUE : FALSE;
rec.Event.KeyEvent.wRepeatCount = key.repeatCount;
rec.Event.KeyEvent.wVirtualKeyCode = key.vk;
rec.Event.KeyEvent.wVirtualScanCode = key.scanCode;
rec.Event.KeyEvent.uChar.UnicodeChar = key.ch;
rec.Event.KeyEvent.dwControlKeyState = key.modifiers;
m_records.push_back(rec);
return;
}
INPUT_RECORD down{};
down.EventType = KEY_EVENT;
down.Event.KeyEvent.bKeyDown = TRUE;
down.Event.KeyEvent.wRepeatCount = key.repeatCount;
down.Event.KeyEvent.wVirtualKeyCode = key.vk;
down.Event.KeyEvent.wVirtualScanCode = key.scanCode;
down.Event.KeyEvent.uChar.UnicodeChar = key.ch;
down.Event.KeyEvent.dwControlKeyState = key.modifiers;
m_records.push_back(down);
INPUT_RECORD up = down;
up.Event.KeyEvent.bKeyDown = FALSE;
m_records.push_back(up);
}
WORD InputBuffer::vkToScanCode(WORD vk)
{
return static_cast<WORD>(MapVirtualKeyW(vk, MAPVK_VK_TO_VSC));
}
DWORD InputBuffer::vtModifierToControlKeyState(uint16_t vtMod)
{
if (vtMod <= 1) return 0;
const auto flags = vtMod - 1;
DWORD state = 0;
if (flags & 1) state |= SHIFT_PRESSED;
if (flags & 2) state |= LEFT_ALT_PRESSED;
if (flags & 4) state |= LEFT_CTRL_PRESSED;
return state;
}

85
src/conpty/InputBuffer.h Normal file
View File

@@ -0,0 +1,85 @@
#pragma once
#include <string>
#include <string_view>
#include <vector>
#include "VtParser.h"
// InputBuffer: Text-based input buffer with an integrated VT parser.
//
// The hosting terminal writes UTF-8 VT sequences via write(). Consumers
// (console API handlers) dequeue data in one of two forms:
//
// 1. As raw VT text (ENABLE_VIRTUAL_TERMINAL_INPUT is set):
// readRawText() returns the raw bytes 1:1.
//
// 2. As INPUT_RECORDs (ENABLE_VIRTUAL_TERMINAL_INPUT is NOT set):
// readInputRecords() uses VtParser to tokenize sequences on-demand
// and converts them to key events:
// - Win32InputMode (W32IM, CSI ... _) for full INPUT_RECORD fidelity
// - Standard VT520 cursor/function keys (CSI A-D, CSI ~, SS3)
// - Plain text (each codepoint -> key down + key up pair)
//
// Thread safety: NOT thread-safe. All calls must be serialized by the caller.
class InputBuffer
{
public:
InputBuffer() = default;
// Append UTF-8 text to the input buffer.
void write(std::string_view text);
// True if there is any data available for reading.
bool hasData() const noexcept;
// Returns the number of pending INPUT_RECORDs (rough estimate).
size_t pendingEventCount() const noexcept;
// Read raw VT text (for ENABLE_VIRTUAL_TERMINAL_INPUT mode).
size_t readRawText(char* dst, size_t dstCapacity);
// Generate INPUT_RECORDs from the buffer (for legacy mode).
size_t readInputRecords(INPUT_RECORD* dst, size_t maxRecords, bool peek = false);
// Discard all buffered data.
void flush();
private:
struct ParsedKey
{
WORD vk = 0;
WORD scanCode = 0;
wchar_t ch = 0;
DWORD modifiers = 0;
WORD repeatCount = 1;
bool keyDown = true;
bool isW32IM = false;
};
// Convert VtTokens into INPUT_RECORDs.
void parseTokensToRecords();
// Token interpretation helpers.
void handleText(std::string_view text);
void handleCtrl(char ch);
void handleEsc(char ch);
void handleSs3(char ch);
void handleCsi(const VtCsi& csi);
void emitKey(const ParsedKey& key);
static WORD vkToScanCode(WORD vk);
static DWORD vtModifierToControlKeyState(uint16_t vtMod);
void compact();
std::string m_buf;
size_t m_readPos = 0;
VtParser m_parser;
std::vector<INPUT_RECORD> m_records;
size_t m_recordReadPos = 0;
};

455
src/conpty/Server.cpp Normal file
View File

@@ -0,0 +1,455 @@
#include "pch.h"
#include "Server.h"
#include <cassert>
#include <wil/nt_result_macros.h>
#define ProcThreadAttributeConsoleReference 10
#define PROC_THREAD_ATTRIBUTE_CONSOLE_REFERENCE \
ProcThreadAttributeValue(10, FALSE, TRUE, FALSE)
#pragma warning(disable : 4100 4189)
HRESULT WINAPI PtyCreateServer(REFIID riid, void** server)
try
{
if (server == nullptr)
{
return E_POINTER;
}
*server = nullptr;
if (riid == __uuidof(IPtyServer))
{
*server = static_cast<IPtyServer*>(new Server());
return S_OK;
}
return E_NOINTERFACE;
}
CATCH_RETURN()
Server::Server()
{
m_server = createHandle(nullptr, L"\\Device\\ConDrv\\Server", false, false);
m_inputAvailableEvent.create(wil::EventOptions::ManualReset);
CD_IO_SERVER_INFORMATION info{
.InputAvailableEvent = m_inputAvailableEvent.get(),
};
THROW_IF_FAILED(ioctl(IOCTL_CONDRV_SET_SERVER_INFORMATION, &info, sizeof(CD_IO_SERVER_INFORMATION), nullptr, 0));
}
#pragma region IUnknown
HRESULT Server::QueryInterface(const IID& riid, void** ppvObject)
{
if (ppvObject == nullptr)
{
return E_POINTER;
}
if (riid == __uuidof(IPtyServer) || riid == __uuidof(IUnknown))
{
*ppvObject = static_cast<IPtyServer*>(this);
AddRef();
return S_OK;
}
*ppvObject = nullptr;
return E_NOINTERFACE;
}
ULONG Server::AddRef()
{
return m_refCount.fetch_add(1, std::memory_order_relaxed) + 1;
}
ULONG Server::Release()
{
const auto count = m_refCount.fetch_sub(1, std::memory_order_relaxed) - 1;
if (count == 0)
{
delete this;
}
return count;
}
HRESULT Server::SetHost(IPtyHost* host)
{
m_host = host;
return S_OK;
}
#pragma endregion
#pragma region IPtyServer
HRESULT Server::WriteUTF8(PTY_UTF8_STRING input)
try
{
m_input.write({ input.data, input.length });
m_inputAvailableEvent.SetEvent();
drainPendingInputReads();
return S_OK;
}
CATCH_RETURN()
HRESULT Server::WriteUTF16(PTY_UTF16_STRING input)
try
{
// TODO
return S_OK;
}
CATCH_RETURN()
HRESULT Server::Run()
{
CD_IO_COMPLETE res{};
NTSTATUS status = STATUS_NO_RESPONSE;
while (true)
{
{
void* in = nullptr;
DWORD inLen = 0;
if (status != STATUS_NO_RESPONSE)
{
res.Identifier = m_req.Descriptor.Identifier;
res.IoStatus.Status = status;
res.Write.Data = const_cast<uint8_t*>(m_resData.data());
res.Write.Size = static_cast<ULONG>(m_resData.size());
in = &res;
inLen = sizeof(res);
}
status = ioctl(IOCTL_CONDRV_READ_IO, in, inLen, &m_req, sizeof(m_req));
m_resData = {};
m_resBuffer.clear();
}
if (!NT_SUCCESS(status))
{
if (status == STATUS_PIPE_DISCONNECTED)
{
return S_OK;
}
return HRESULT_FROM_NT(status);
}
try
{
switch (m_req.Descriptor.Function)
{
case CONSOLE_IO_CONNECT:
status = handleConnect();
break;
case CONSOLE_IO_DISCONNECT:
status = handleDisconnect();
break;
case CONSOLE_IO_CREATE_OBJECT:
status = handleCreateObject();
break;
case CONSOLE_IO_CLOSE_OBJECT:
status = handleCloseObject();
break;
case CONSOLE_IO_RAW_WRITE:
status = handleRawWrite();
break;
case CONSOLE_IO_RAW_READ:
status = handleRawRead();
break;
case CONSOLE_IO_USER_DEFINED:
status = handleUserDefined();
break;
case CONSOLE_IO_RAW_FLUSH:
status = handleRawFlush();
break;
default:
status = STATUS_UNSUCCESSFUL;
break;
}
}
catch (...)
{
status = wil::StatusFromCaughtException();
}
}
}
HRESULT Server::CreateProcessW(
LPCWSTR lpApplicationName,
LPWSTR lpCommandLine,
LPSECURITY_ATTRIBUTES lpProcessAttributes,
LPSECURITY_ATTRIBUTES lpThreadAttributes,
BOOL bInheritHandles,
DWORD dwCreationFlags,
LPVOID lpEnvironment,
LPCWSTR lpCurrentDirectory,
LPPROCESS_INFORMATION lpProcessInformation)
try
{
const auto proc = GetCurrentProcess();
uint64_t attrListBuffer[16];
STARTUPINFOEX si{};
si.StartupInfo.cb = sizeof(STARTUPINFOEXW);
si.lpAttributeList = reinterpret_cast<PPROC_THREAD_ATTRIBUTE_LIST>(&attrListBuffer[0]);
auto listSize = sizeof(attrListBuffer);
THROW_IF_WIN32_BOOL_FALSE(InitializeProcThreadAttributeList(si.lpAttributeList, 2, 0, &listSize));
const auto cleanup = wil::scope_exit([&] {
DeleteProcThreadAttributeList(si.lpAttributeList);
});
std::array<unique_nthandle, 4> handles{
createHandle(m_server.get(), L"\\Reference", false, true),
createHandle(m_server.get(), L"\\Input", true, true),
createHandle(m_server.get(), L"\\Output", true, true),
nullptr,
};
THROW_IF_WIN32_BOOL_FALSE(DuplicateHandle(proc, handles[2].get(), proc, handles[3].addressof(), 0, TRUE, DUPLICATE_SAME_ACCESS));
THROW_IF_WIN32_BOOL_FALSE(UpdateProcThreadAttribute(si.lpAttributeList, 0, PROC_THREAD_ATTRIBUTE_CONSOLE_REFERENCE, handles[0].addressof(), sizeof(HANDLE), nullptr, nullptr));
// bInheritHandles=TRUE is required in order to use STARTF_USESTDHANDLES.
// We can fake bInheritHandles=FALSE anyway, by using PROC_THREAD_ATTRIBUTE_HANDLE_LIST.
if (!bInheritHandles)
{
// NOTE: UpdateProcThreadAttribute doesn't copy the handle values!
// The given lpValue pointers have to be valid until the call to CreateProcessW!
THROW_IF_WIN32_BOOL_FALSE(UpdateProcThreadAttribute(si.lpAttributeList, 0, PROC_THREAD_ATTRIBUTE_HANDLE_LIST, handles[1].addressof(), 3 * sizeof(HANDLE), nullptr, nullptr));
}
si.StartupInfo.dwFlags |= STARTF_USESTDHANDLES;
si.StartupInfo.hStdInput = handles[1].get();
si.StartupInfo.hStdOutput = handles[2].get();
si.StartupInfo.hStdError = handles[3].get();
return ::CreateProcessW(
lpApplicationName,
lpCommandLine,
lpProcessAttributes,
lpThreadAttributes,
TRUE,
dwCreationFlags | EXTENDED_STARTUPINFO_PRESENT,
lpEnvironment,
lpCurrentDirectory,
&si.StartupInfo,
lpProcessInformation);
}
CATCH_RETURN()
#pragma endregion
unique_nthandle Server::createHandle(HANDLE parent, const wchar_t* typeName, bool inherit, bool synchronous)
{
UNICODE_STRING name;
RtlInitUnicodeString(&name, typeName);
ULONG attrFlags = OBJ_CASE_INSENSITIVE;
WI_SetFlagIf(attrFlags, OBJ_INHERIT, inherit);
OBJECT_ATTRIBUTES attr;
InitializeObjectAttributes(&attr, &name, attrFlags, parent, nullptr);
ULONG options = 0;
WI_SetFlagIf(options, FILE_SYNCHRONOUS_IO_NONALERT, synchronous);
HANDLE handle;
IO_STATUS_BLOCK ioStatus;
THROW_IF_NTSTATUS_FAILED(NtCreateFile(
/* FileHandle */ &handle,
/* DesiredAccess */ FILE_GENERIC_READ | FILE_GENERIC_WRITE,
/* ObjectAttributes */ &attr,
/* IoStatusBlock */ &ioStatus,
/* AllocationSize */ nullptr,
/* FileAttributes */ 0,
/* ShareAccess */ FILE_SHARE_READ | FILE_SHARE_WRITE | FILE_SHARE_DELETE,
/* CreateDisposition */ FILE_CREATE,
/* CreateOptions */ options,
/* EaBuffer */ nullptr,
/* EaLength */ 0));
return unique_nthandle{ handle };
}
NTSTATUS Server::ioctl(DWORD code, void* in, DWORD inLen, void* out, DWORD outLen) const
{
assert((in == nullptr) == (inLen == 0));
assert((out == nullptr) == (outLen == 0));
IO_STATUS_BLOCK iosb;
auto status = NtDeviceIoControlFile(m_server.get(), nullptr, nullptr, nullptr, &iosb, code, in, inLen, out, outLen);
if (status == STATUS_PENDING)
{
// Operation must complete before iosb is destroyed
status = NtWaitForSingleObject(m_server.get(), FALSE, nullptr);
if (NT_SUCCESS(status))
{
status = iosb.Status;
}
}
return status;
}
// Reads [part of] the input payload of the current message from the driver.
// Analogous to the OG ReadMessageInput() in csrutil.cpp.
//
// For CONSOLE_IO_CONNECT, offset is 0 and the payload is a CONSOLE_SERVER_MSG.
// For CONSOLE_IO_USER_DEFINED, offset would typically be past the message header.
void Server::readInput(ULONG offset, void* buffer, ULONG size)
{
CD_IO_OPERATION op{};
op.Identifier = m_req.Descriptor.Identifier;
op.Buffer.Offset = offset;
op.Buffer.Data = buffer;
op.Buffer.Size = size;
THROW_IF_NTSTATUS_FAILED(ioctl(IOCTL_CONDRV_READ_INPUT, &op, sizeof(op), nullptr, 0));
}
// Reads the trailing input payload (after the API descriptor struct) for
// USER_DEFINED messages. Analogous to OG GetInputBuffer() in csrutil.cpp.
std::vector<uint8_t> Server::readTrailingInput()
{
const auto readOffset = m_req.msgHeader.ApiDescriptorSize + sizeof(CONSOLE_MSG_HEADER);
if (readOffset > m_req.Descriptor.InputSize)
{
THROW_NTSTATUS(STATUS_UNSUCCESSFUL);
}
const auto size = m_req.Descriptor.InputSize - readOffset;
std::vector<uint8_t> buf(size);
if (size > 0)
{
readInput(static_cast<ULONG>(readOffset), buf.data(), static_cast<ULONG>(size));
}
return buf;
}
// Writes data back to the client's output buffer for the current message.
// Analogous to the IOCTL_CONDRV_WRITE_OUTPUT call in the OG ReleaseMessageBuffers() (csrutil.cpp).
//
// The driver matches the Identifier to the pending IO and copies data into
// the client's buffer at the specified offset.
void Server::writeOutput(ULONG offset, const void* buffer, ULONG size)
{
CD_IO_OPERATION op{};
op.Identifier = m_req.Descriptor.Identifier;
op.Buffer.Offset = offset;
op.Buffer.Data = const_cast<void*>(buffer);
op.Buffer.Size = size;
THROW_IF_NTSTATUS_FAILED(ioctl(IOCTL_CONDRV_WRITE_OUTPUT, &op, sizeof(op), nullptr, 0));
}
// Completes a message with the given completion descriptor.
// Analogous to the OG ConsoleComplete() in csrutil.cpp.
//
// This sends the reply out-of-band (via IOCTL_CONDRV_COMPLETE_IO) rather than
// piggybacking on the next IOCTL_CONDRV_READ_IO call. Used when the reply
// carries write data (e.g. CD_CONNECTION_INFORMATION for CONNECT).
void Server::completeIo(CD_IO_COMPLETE& completion)
{
THROW_IF_NTSTATUS_FAILED(ioctl(IOCTL_CONDRV_COMPLETE_IO, &completion, sizeof(completion), nullptr, 0));
}
// Validates a handle against expected type and access mask.
// Analogous to OG DereferenceIoHandle() in handle.cpp.
// Returns nullptr if not found, wrong type, or insufficient access.
Handle* Server::findHandle(ULONG_PTR obj, ULONG type, ACCESS_MASK access)
{
auto ptr = reinterpret_cast<Handle*>(obj);
for (auto& h : m_handles)
{
if (h.get() == ptr && (h->handleType & type) && (h->access & access) == access)
{
return ptr;
}
}
return nullptr;
}
// Validates an output handle and temporarily activates the associated screen
// buffer on the host if it differs from the currently active one.
// All IPtyHost calls that operate on screen buffer content (GetScreenBufferInfo,
// SetScreenBufferInfo, ReadBuffer, WriteUTF8/16) operate on the active buffer,
// so we ensure the right one is activated before each API call.
//
// Returns nullptr if the handle is invalid. The caller must check for nullptr
// and return STATUS_INVALID_HANDLE.
Handle* Server::activateOutputBuffer(ACCESS_MASK requiredAccess)
{
auto* h = findHandle(m_req.Descriptor.Object, CONSOLE_OUTPUT_HANDLE, requiredAccess);
if (!h)
return nullptr;
// If the handle's buffer differs from the active one, ask the host to switch.
if (h->screenBuffer != m_activeScreenBuffer && m_host)
{
m_host->ActivateBuffer(h->screenBuffer);
m_activeScreenBuffer = h->screenBuffer;
}
return h;
}
// VT output helpers.
// These accumulate VT sequences into m_vtBuf. Call vtFlush() to send them.
void Server::vtFlush()
{
if (!m_vtBuf.empty() && m_host)
{
m_host->WriteUTF8({ m_vtBuf.data(), m_vtBuf.size() });
m_vtBuf.clear();
}
}
void Server::vtAppend(std::string_view sv)
{
m_vtBuf.append(sv);
}
void Server::vtAppendFmt(const char* fmt, ...)
{
char buf[256];
va_list args;
va_start(args, fmt);
const auto n = vsnprintf(buf, sizeof(buf), fmt, args);
va_end(args);
if (n > 0)
m_vtBuf.append(buf, static_cast<size_t>(n));
}
void Server::vtAppendUTF16(std::wstring_view str)
{
if (str.empty())
return;
const auto len = static_cast<int>(str.size());
const auto utf8Len = WideCharToMultiByte(CP_UTF8, 0, str.data(), len, nullptr, 0, nullptr, nullptr);
if (utf8Len <= 0)
return;
const auto offset = m_vtBuf.size();
m_vtBuf.resize(offset + utf8Len);
WideCharToMultiByte(CP_UTF8, 0, str.data(), len, m_vtBuf.data() + offset, utf8Len, nullptr, nullptr);
}
void Server::vtAppendTitle(std::wstring_view title)
{
vtAppend("\x1b]0;");
// Strip C0/C1 control characters to prevent OSC injection.
std::wstring safe;
safe.reserve(title.size());
for (const auto ch : title)
{
if (ch >= 0x20 && ch != 0x7F)
safe += ch;
}
vtAppendUTF16(safe);
vtAppend("\x1b\\");
}

319
src/conpty/Server.h Normal file
View File

@@ -0,0 +1,319 @@
#pragma once
#include <conpty.h>
#include "InputBuffer.h"
using unique_nthandle = wil::unique_any_handle_null<decltype(&::NtClose), ::NtClose>;
// Mirrors the payload of IOCTL_CONDRV_READ_IO.
// Unlike the OG CONSOLE_API_MSG (which has Complete/State/IoStatus before Descriptor),
// this struct starts at Descriptor because that's where the driver output begins.
struct CONSOLE_API_MSG
{
CD_IO_DESCRIPTOR Descriptor;
union
{
struct
{
CD_CREATE_OBJECT_INFORMATION CreateObject;
CONSOLE_CREATESCREENBUFFER_MSG CreateScreenBuffer;
};
struct
{
CONSOLE_MSG_HEADER msgHeader;
union
{
CONSOLE_MSG_BODY_L1 consoleMsgL1;
CONSOLE_MSG_BODY_L2 consoleMsgL2;
CONSOLE_MSG_BODY_L3 consoleMsgL3;
} u;
};
};
};
// Handle type flags, from the OG server.h.
// These are internal to the console server, not part of the condrv protocol.
#define CONSOLE_INPUT_HANDLE 0x00000001
#define CONSOLE_OUTPUT_HANDLE 0x00000002
struct Server;
// Handle tracking data, analogous to the OG CONSOLE_HANDLE_DATA.
// In the OG, handles are raw pointers to CONSOLE_HANDLE_DATA which contain
// share mode/access tracking and a pointer to the underlying object
// (INPUT_INFORMATION or SCREEN_INFORMATION).
//
// For output handles, `screenBuffer` is the opaque buffer ID returned by
// IPtyHost::CreateBuffer. NULL means the main (default) screen buffer.
struct Handle
{
ULONG handleType = 0; // CONSOLE_INPUT_HANDLE or CONSOLE_OUTPUT_HANDLE
ACCESS_MASK access = 0;
ULONG shareMode = 0;
void* screenBuffer = nullptr; // Only for CONSOLE_OUTPUT_HANDLE. NULL = main buffer.
};
// Per-client tracking data, analogous to the OG CONSOLE_PROCESS_HANDLE.
struct Client
{
DWORD processId = 0;
ULONG processGroupId = 0;
bool rootProcess = false;
ULONG_PTR inputHandle = 0;
ULONG_PTR outputHandle = 0;
};
// Console aliases, keyed by executable name.
// Each executable has a map of source → target alias pairs.
// OG: ALIAS_LIST_ENTRY chain in cmdline.cpp.
struct AliasStore
{
// Case-insensitive exe name → (case-insensitive source → target).
std::unordered_map<std::wstring, std::unordered_map<std::wstring, std::wstring>> exes;
void add(std::wstring_view exe, std::wstring_view source, std::wstring_view target);
void remove(std::wstring_view exe, std::wstring_view source);
const std::wstring* find(std::wstring_view exe, std::wstring_view source) const;
void expunge(std::wstring_view exe);
};
// Per-exe command history buffer.
// OG: COMMAND_HISTORY in cmdline.cpp / commandHistory.h.
struct CommandHistory
{
std::vector<std::wstring> commands;
ULONG maxCommands = 50;
bool allowDuplicates = true;
void add(std::wstring_view cmd);
void clear();
};
struct CommandHistoryStore
{
std::unordered_map<std::wstring, CommandHistory> exes;
ULONG defaultBufferSize = 50;
ULONG numberOfBuffers = 4;
DWORD flags = 0; // HISTORY_NO_DUP_FLAG = 1
CommandHistory& getOrCreate(std::wstring_view exe);
void expunge(std::wstring_view exe);
};
// A pending IO request that couldn't be completed immediately.
//
// Analogous to the OG CONSOLE_WAIT_BLOCK. When new input arrives (or output
// is unpaused), the server walks its pending-IO queues and calls `retry()`.
// If retry() returns true, the IO was satisfied and the block is removed.
// If false, it stays queued for a future attempt.
//
// retry() is responsible for calling writeOutput() and completeIo() itself.
struct PendingIO
{
LUID identifier{}; // ConDrv message identifier, for completeIo/writeOutput.
ULONG_PTR process = 0; // Descriptor.Process, for cleanup on disconnect.
// Retry callback. Called when conditions change (new input, output unpaused).
// Returns true if the IO was completed, false if it should stay queued.
// Signature: bool retry(Server& server)
std::function<bool(Server&)> retry;
};
struct Server : IPtyServer
{
Server();
virtual ~Server() = default;
#pragma region IUnknown
HRESULT QueryInterface(const IID& riid, void** ppvObject) override;
ULONG AddRef() override;
ULONG Release() override;
#pragma endregion
#pragma region IPtyServer
HRESULT SetHost(IPtyHost* host) override;
HRESULT WriteUTF8(PTY_UTF8_STRING input) override;
HRESULT WriteUTF16(PTY_UTF16_STRING input) override;
HRESULT Run() override;
HRESULT CreateProcessW(
LPCWSTR lpApplicationName,
LPWSTR lpCommandLine,
LPSECURITY_ATTRIBUTES lpProcessAttributes,
LPSECURITY_ATTRIBUTES lpThreadAttributes,
BOOL bInheritHandles,
DWORD dwCreationFlags,
LPVOID lpEnvironment,
LPCWSTR lpCurrentDirectory,
LPPROCESS_INFORMATION lpProcessInformation) override;
#pragma endregion
// Positive NTSTATUS sentinel: "no piggyback reply for this iteration."
// NTSTATUS is a signed LONG, so status <= 0 catches both success (0) and
// errors (negative), while this positive value skips the piggyback path.
static constexpr NTSTATUS STATUS_NO_RESPONSE = 1;
static unique_nthandle createHandle(HANDLE parent, const wchar_t* typeName, bool inherit, bool synchronous);
NTSTATUS ioctl(DWORD code, void* in, DWORD inLen, void* out, DWORD outLen) const;
// ConDrv communication helpers.
// These operate on m_req (the current message being processed).
void readInput(ULONG offset, void* buffer, ULONG size);
std::vector<uint8_t> readTrailingInput();
void writeOutput(ULONG offset, const void* buffer, ULONG size);
void completeIo(CD_IO_COMPLETE& completion);
// Handle validation. Returns nullptr if the handle doesn't exist,
// doesn't match the expected type, or lacks the required access.
Handle* findHandle(ULONG_PTR obj, ULONG type, ACCESS_MASK access);
// Message handlers.
// All handlers read from m_req and return NTSTATUS:
// - STATUS_SUCCESS / error → piggyback reply on next READ_IO
// - STATUS_NO_RESPONSE → handler already replied (completeIo) or deferred
NTSTATUS handleConnect();
NTSTATUS handleDisconnect();
NTSTATUS handleCreateObject();
NTSTATUS handleCloseObject();
NTSTATUS handleRawWrite();
NTSTATUS handleRawRead();
NTSTATUS handleUserDefined();
NTSTATUS handleRawFlush();
NTSTATUS handleUserDeprecatedApi();
NTSTATUS handleUserL1GetConsoleCP();
NTSTATUS handleUserL1GetConsoleMode();
NTSTATUS handleUserL1SetConsoleMode();
NTSTATUS handleUserL1GetNumberOfConsoleInputEvents();
NTSTATUS handleUserL1GetConsoleInput();
NTSTATUS handleUserL1ReadConsole();
NTSTATUS handleUserL1WriteConsole();
NTSTATUS handleUserL1GetConsoleLangId();
NTSTATUS handleUserL2FillConsoleOutput();
NTSTATUS handleUserL2GenerateConsoleCtrlEvent();
NTSTATUS handleUserL2SetConsoleActiveScreenBuffer();
NTSTATUS handleUserL2FlushConsoleInputBuffer();
NTSTATUS handleUserL2SetConsoleCP();
NTSTATUS handleUserL2GetConsoleCursorInfo();
NTSTATUS handleUserL2SetConsoleCursorInfo();
NTSTATUS handleUserL2GetConsoleScreenBufferInfo();
NTSTATUS handleUserL2SetConsoleScreenBufferInfo();
NTSTATUS handleUserL2SetConsoleScreenBufferSize();
NTSTATUS handleUserL2SetConsoleCursorPosition();
NTSTATUS handleUserL2GetLargestConsoleWindowSize();
NTSTATUS handleUserL2ScrollConsoleScreenBuffer();
NTSTATUS handleUserL2SetConsoleTextAttribute();
NTSTATUS handleUserL2SetConsoleWindowInfo();
NTSTATUS handleUserL2ReadConsoleOutputString();
NTSTATUS handleUserL2WriteConsoleInput();
NTSTATUS handleUserL2WriteConsoleOutput();
NTSTATUS handleUserL2WriteConsoleOutputString();
NTSTATUS handleUserL2ReadConsoleOutput();
NTSTATUS handleUserL2GetConsoleTitle();
NTSTATUS handleUserL2SetConsoleTitle();
NTSTATUS handleUserL3GetConsoleMouseInfo();
NTSTATUS handleUserL3GetConsoleFontSize();
NTSTATUS handleUserL3GetConsoleCurrentFont();
NTSTATUS handleUserL3SetConsoleDisplayMode();
NTSTATUS handleUserL3GetConsoleDisplayMode();
NTSTATUS handleUserL3AddConsoleAlias();
NTSTATUS handleUserL3GetConsoleAlias();
NTSTATUS handleUserL3GetConsoleAliasesLength();
NTSTATUS handleUserL3GetConsoleAliasExesLength();
NTSTATUS handleUserL3GetConsoleAliases();
NTSTATUS handleUserL3GetConsoleAliasExes();
NTSTATUS handleUserL3ExpungeConsoleCommandHistory();
NTSTATUS handleUserL3SetConsoleNumberOfCommands();
NTSTATUS handleUserL3GetConsoleCommandHistoryLength();
NTSTATUS handleUserL3GetConsoleCommandHistory();
NTSTATUS handleUserL3GetConsoleWindow();
NTSTATUS handleUserL3GetConsoleSelectionInfo();
NTSTATUS handleUserL3GetConsoleProcessList();
NTSTATUS handleUserL3GetConsoleHistory();
NTSTATUS handleUserL3SetConsoleHistory();
NTSTATUS handleUserL3SetConsoleCurrentFont();
// Complete pending IOs (called when state changes make progress possible).
void drainPendingInputReads();
void drainPendingOutputWrites();
void cancelPendingIOs(ULONG_PTR process);
// Helpers to pend a read or write IO. Both return STATUS_NO_RESPONSE.
NTSTATUS pendRead(std::function<bool(Server&)> retry);
NTSTATUS pendWrite(std::function<bool(Server&)> retry);
// Helper to complete a single pending IO with status + information.
void completePendingIo(const LUID& identifier, NTSTATUS status, ULONG_PTR information = 0);
// Client lookup by opaque handle value (the raw Client pointer cast to ULONG_PTR).
Client* findClient(ULONG_PTR handle);
// Handle management.
ULONG_PTR allocateHandle(ULONG handleType, ACCESS_MASK access, ULONG shareMode, void* screenBuffer = nullptr);
void freeHandle(ULONG_PTR handle);
// Resolve the output handle from the current message and ensure the
// correct buffer is temporarily activated on the host.
// Returns the Handle pointer, or nullptr (STATUS_INVALID_HANDLE) on failure.
Handle* activateOutputBuffer(ACCESS_MASK requiredAccess);
// VT output helpers.
void vtFlush();
void vtAppend(std::string_view sv);
void vtAppendFmt(_Printf_format_string_ const char* fmt, ...);
void vtAppendUTF16(std::wstring_view str);
void vtAppendTitle(std::wstring_view title);
std::atomic<ULONG> m_refCount{ 1 };
unique_nthandle m_server;
wil::com_ptr<IPtyHost> m_host;
wil::unique_event m_inputAvailableEvent;
// Per-message state. Set by Run() before dispatching, read by handlers.
CONSOLE_API_MSG m_req{};
// Piggyback response. m_resData points to either m_req.u (zero-copy, set by
// handleUserDefined for most APIs) or m_resBuffer.data() (for bulk responses).
std::span<const uint8_t> m_resData;
std::vector<uint8_t> m_resBuffer;
bool m_initialized = false;
bool m_outputPaused = false;
std::vector<std::unique_ptr<Client>> m_clients;
std::vector<std::unique_ptr<Handle>> m_handles;
std::deque<PendingIO> m_pendingReads;
std::deque<PendingIO> m_pendingWrites;
// Console state — code pages.
UINT m_inputCP = CP_UTF8;
UINT m_outputCP = CP_UTF8;
// Console state — mode flags (global, not per-buffer).
DWORD m_inputMode = ENABLE_PROCESSED_INPUT | ENABLE_LINE_INPUT | ENABLE_ECHO_INPUT;
DWORD m_outputMode = ENABLE_PROCESSED_OUTPUT | ENABLE_WRAP_AT_EOL_OUTPUT;
// Console state — title (global, not per-buffer).
std::wstring m_title;
std::wstring m_originalTitle;
// The currently active screen buffer handle. NULL = main buffer.
void* m_activeScreenBuffer = nullptr;
// Alias and command history storage (global, not per-buffer).
AliasStore m_aliases;
CommandHistoryStore m_history;
// VT output accumulation buffer.
std::string m_vtBuf;
// Input buffer with integrated VT parser.
InputBuffer m_input;
};

View File

@@ -0,0 +1,247 @@
#include "pch.h"
#include "Server.h"
// Handles CONSOLE_IO_CONNECT messages.
//
// Protocol (from the OG ConsoleHandleConnectionRequest in srvinit.cpp):
// 1. Read the CONSOLE_SERVER_MSG from the client's input payload.
// 2. Validate string lengths and null termination.
// 3. Allocate per-process tracking data.
// 4. Mark the first connection as the root process.
// 5. Allocate IO handles for input and output.
// 6. Reply with CD_CONNECTION_INFORMATION via IOCTL_CONDRV_COMPLETE_IO.
//
// Returns STATUS_SUCCESS if the reply was sent via completeIo.
// Returns a failure NTSTATUS if the caller should reply inline with that status.
NTSTATUS Server::handleConnect()
{
// 1. Read the CONSOLE_SERVER_MSG payload from the client.
CONSOLE_SERVER_MSG data{};
readInput(0, &data, sizeof(data));
// 2. Validate that strings are within the buffers and null-terminated.
if ((data.ApplicationNameLength > (sizeof(data.ApplicationName) - sizeof(WCHAR))) ||
(data.TitleLength > (sizeof(data.Title) - sizeof(WCHAR))) ||
(data.CurrentDirectoryLength > (sizeof(data.CurrentDirectory) - sizeof(WCHAR))) ||
(data.ApplicationName[data.ApplicationNameLength / sizeof(WCHAR)] != UNICODE_NULL) ||
(data.Title[data.TitleLength / sizeof(WCHAR)] != UNICODE_NULL) ||
(data.CurrentDirectory[data.CurrentDirectoryLength / sizeof(WCHAR)] != UNICODE_NULL))
{
THROW_NTSTATUS(STATUS_INVALID_BUFFER_SIZE);
}
// 3. Allocate per-process tracking data.
const auto processId = static_cast<DWORD>(m_req.Descriptor.Process);
auto client = std::make_unique<Client>();
client->processId = processId;
client->processGroupId = data.ProcessGroupId;
// 4. The first connection is the root process (console owner).
client->rootProcess = !m_initialized;
// 5. Allocate IO handles for input and output.
client->inputHandle = allocateHandle(CONSOLE_INPUT_HANDLE, GENERIC_READ | GENERIC_WRITE, FILE_SHARE_READ | FILE_SHARE_WRITE);
client->outputHandle = allocateHandle(CONSOLE_OUTPUT_HANDLE, GENERIC_READ | GENERIC_WRITE, FILE_SHARE_READ | FILE_SHARE_WRITE);
if (!m_initialized)
{
m_initialized = true;
// Capture the initial title from the client's startup info.
m_title.assign(data.Title, data.TitleLength / sizeof(WCHAR));
m_originalTitle = m_title;
// NOTE: Screen buffer size/window size from CONSOLE_SERVER_MSG are
// NOT stored here — they live on the host. The host should be
// initialized with appropriate defaults before Run() is called.
}
auto clientPtr = client.get();
m_clients.push_back(std::move(client));
// 6. Build the reply with connection information.
CD_CONNECTION_INFORMATION connInfo{};
connInfo.Process = reinterpret_cast<ULONG_PTR>(clientPtr);
connInfo.Input = clientPtr->inputHandle;
connInfo.Output = clientPtr->outputHandle;
CD_IO_COMPLETE completion{};
completion.Identifier = m_req.Descriptor.Identifier;
completion.IoStatus.Status = STATUS_SUCCESS;
completion.IoStatus.Information = sizeof(CD_CONNECTION_INFORMATION);
completion.Write.Data = &connInfo;
completion.Write.Size = sizeof(CD_CONNECTION_INFORMATION);
completeIo(completion);
return STATUS_NO_RESPONSE;
}
// Handles CONSOLE_IO_DISCONNECT messages.
//
// Protocol (from the OG ConsoleClientDisconnectRoutine / RemoveConsole in srvinit.cpp):
// 1. Look up the process from Descriptor.Process (the opaque value we set in CONNECT).
// 2. Free per-process data (handles, command history, etc.).
//
// The caller always replies with STATUS_SUCCESS inline.
NTSTATUS Server::handleDisconnect()
{
auto client = findClient(m_req.Descriptor.Process);
if (!client)
{
return STATUS_SUCCESS;
}
// Cancel any pending IOs from this client.
cancelPendingIOs(m_req.Descriptor.Process);
// Free the client's IO handles, mirroring OG FreeProcessData which calls
// ConsoleCloseHandle on InputHandle and OutputHandle.
if (client->inputHandle)
{
freeHandle(client->inputHandle);
}
if (client->outputHandle)
{
freeHandle(client->outputHandle);
}
std::erase_if(m_clients, [client](const auto& c) { return c.get() == client; });
return STATUS_SUCCESS;
}
// Handle management.
//
// Analogous to OG AllocateIoHandle (handle.cpp). The OG creates a CONSOLE_HANDLE_DATA
// with share/access tracking and a pointer to the underlying console object.
// We create a lightweight Handle and return its pointer cast to ULONG_PTR.
ULONG_PTR Server::allocateHandle(ULONG handleType, ACCESS_MASK access, ULONG shareMode, void* screenBuffer)
{
auto h = std::make_unique<Handle>();
h->handleType = handleType;
h->access = access;
h->shareMode = shareMode;
h->screenBuffer = screenBuffer;
auto ptr = reinterpret_cast<ULONG_PTR>(h.get());
m_handles.push_back(std::move(h));
return ptr;
}
// Analogous to OG ConsoleCloseHandle → FreeConsoleHandle (handle.cpp).
void Server::freeHandle(ULONG_PTR handle)
{
auto ptr = reinterpret_cast<Handle*>(handle);
// If this is an output handle with a non-null screen buffer (i.e. not the
// main buffer), check if any other handle still references it. If not,
// tell the host to release it.
if (ptr && (ptr->handleType & CONSOLE_OUTPUT_HANDLE) && ptr->screenBuffer)
{
const auto buf = ptr->screenBuffer;
bool otherRef = false;
for (auto& h : m_handles)
{
if (h.get() != ptr && h->screenBuffer == buf)
{
otherRef = true;
break;
}
}
if (!otherRef && m_host)
{
m_host->ReleaseBuffer(buf);
if (m_activeScreenBuffer == buf)
m_activeScreenBuffer = nullptr;
}
}
std::erase_if(m_handles, [ptr](const auto& h) { return h.get() == ptr; });
}
// Handles CONSOLE_IO_CREATE_OBJECT messages.
//
// Protocol (from OG ConsoleCreateObject in srvinit.cpp):
// 1. Read CD_CREATE_OBJECT_INFORMATION from the message (already in msg.CreateObject).
// 2. Resolve CD_IO_OBJECT_TYPE_GENERIC based on DesiredAccess.
// 3. Allocate a handle of the appropriate type.
// 4. Reply via completeIo with the handle value in IoStatus.Information.
NTSTATUS Server::handleCreateObject()
{
auto& info = m_req.CreateObject;
// Resolve generic object type based on desired access, matching OG behavior.
if (info.ObjectType == CD_IO_OBJECT_TYPE_GENERIC)
{
if ((info.DesiredAccess & (GENERIC_READ | GENERIC_WRITE)) == GENERIC_READ)
{
info.ObjectType = CD_IO_OBJECT_TYPE_CURRENT_INPUT;
}
else if ((info.DesiredAccess & (GENERIC_READ | GENERIC_WRITE)) == GENERIC_WRITE)
{
info.ObjectType = CD_IO_OBJECT_TYPE_CURRENT_OUTPUT;
}
}
ULONG_PTR handle = 0;
switch (info.ObjectType)
{
case CD_IO_OBJECT_TYPE_CURRENT_INPUT:
handle = allocateHandle(CONSOLE_INPUT_HANDLE, info.DesiredAccess, info.ShareMode);
break;
case CD_IO_OBJECT_TYPE_CURRENT_OUTPUT:
handle = allocateHandle(CONSOLE_OUTPUT_HANDLE, info.DesiredAccess, info.ShareMode);
break;
case CD_IO_OBJECT_TYPE_NEW_OUTPUT:
{
// Create a new screen buffer via the host.
void* screenBuffer = nullptr;
const auto hr = m_host->CreateBuffer(&screenBuffer);
if (FAILED(hr))
THROW_NTSTATUS(STATUS_NO_MEMORY);
handle = allocateHandle(CONSOLE_OUTPUT_HANDLE, info.DesiredAccess, info.ShareMode, screenBuffer);
break;
}
default:
THROW_NTSTATUS(STATUS_INVALID_PARAMETER);
}
// Reply with the handle value in IoStatus.Information.
// The driver stores this and echoes it back in Descriptor.Object for future IO.
CD_IO_COMPLETE completion{};
completion.Identifier = m_req.Descriptor.Identifier;
completion.IoStatus.Status = STATUS_SUCCESS;
completion.IoStatus.Information = handle;
completeIo(completion);
return STATUS_NO_RESPONSE;
}
// Handles CONSOLE_IO_CLOSE_OBJECT messages.
//
// Protocol (from OG SrvCloseHandle in stream.cpp):
// 1. Descriptor.Object contains the opaque handle value.
// 2. Close/free the handle.
//
// The caller replies with STATUS_SUCCESS inline.
NTSTATUS Server::handleCloseObject()
{
freeHandle(m_req.Descriptor.Object);
return STATUS_SUCCESS;
}
Client* Server::findClient(ULONG_PTR handle)
{
auto ptr = reinterpret_cast<Client*>(handle);
for (auto& c : m_clients)
{
if (c.get() == ptr)
{
return ptr;
}
}
return nullptr;
}

142
src/conpty/Server.msg.cpp Normal file
View File

@@ -0,0 +1,142 @@
#include "pch.h"
#include "Server.h"
#define CONSOLE_API_STRUCT(Routine, Struct) { &Routine, sizeof(Struct) }
#define CONSOLE_API_NO_PARAMETER(Routine) { &Routine, 0 }
struct ApiDescriptor
{
NTSTATUS(Server::*routine)();
size_t requiredSize;
};
struct ApiDescriptorLayer
{
const ApiDescriptor* descriptors;
size_t count;
};
static constexpr ApiDescriptor s_ptyServerUserL1[] = {
CONSOLE_API_STRUCT(Server::handleUserL1GetConsoleCP, CONSOLE_GETCP_MSG),
CONSOLE_API_STRUCT(Server::handleUserL1GetConsoleMode, CONSOLE_MODE_MSG),
CONSOLE_API_STRUCT(Server::handleUserL1SetConsoleMode, CONSOLE_MODE_MSG),
CONSOLE_API_STRUCT(Server::handleUserL1GetNumberOfConsoleInputEvents, CONSOLE_GETNUMBEROFINPUTEVENTS_MSG),
CONSOLE_API_STRUCT(Server::handleUserL1GetConsoleInput, CONSOLE_GETCONSOLEINPUT_MSG),
CONSOLE_API_STRUCT(Server::handleUserL1ReadConsole, CONSOLE_READCONSOLE_MSG),
CONSOLE_API_STRUCT(Server::handleUserL1WriteConsole, CONSOLE_WRITECONSOLE_MSG),
CONSOLE_API_NO_PARAMETER(Server::handleUserDeprecatedApi), // SrvConsoleNotifyLastClose
CONSOLE_API_STRUCT(Server::handleUserL1GetConsoleLangId, CONSOLE_LANGID_MSG),
CONSOLE_API_STRUCT(Server::handleUserDeprecatedApi, CONSOLE_MAPBITMAP_MSG),
};
static constexpr ApiDescriptor s_ptyServerUserL2[] = {
CONSOLE_API_STRUCT(Server::handleUserL2FillConsoleOutput, CONSOLE_FILLCONSOLEOUTPUT_MSG),
CONSOLE_API_STRUCT(Server::handleUserL2GenerateConsoleCtrlEvent, CONSOLE_CTRLEVENT_MSG),
CONSOLE_API_NO_PARAMETER(Server::handleUserL2SetConsoleActiveScreenBuffer),
CONSOLE_API_NO_PARAMETER(Server::handleUserL2FlushConsoleInputBuffer),
CONSOLE_API_STRUCT(Server::handleUserL2SetConsoleCP, CONSOLE_SETCP_MSG),
CONSOLE_API_STRUCT(Server::handleUserL2GetConsoleCursorInfo, CONSOLE_GETCURSORINFO_MSG),
CONSOLE_API_STRUCT(Server::handleUserL2SetConsoleCursorInfo, CONSOLE_SETCURSORINFO_MSG),
CONSOLE_API_STRUCT(Server::handleUserL2GetConsoleScreenBufferInfo, CONSOLE_SCREENBUFFERINFO_MSG),
CONSOLE_API_STRUCT(Server::handleUserL2SetConsoleScreenBufferInfo, CONSOLE_SCREENBUFFERINFO_MSG),
CONSOLE_API_STRUCT(Server::handleUserL2SetConsoleScreenBufferSize, CONSOLE_SETSCREENBUFFERSIZE_MSG),
CONSOLE_API_STRUCT(Server::handleUserL2SetConsoleCursorPosition, CONSOLE_SETCURSORPOSITION_MSG),
CONSOLE_API_STRUCT(Server::handleUserL2GetLargestConsoleWindowSize, CONSOLE_GETLARGESTWINDOWSIZE_MSG),
CONSOLE_API_STRUCT(Server::handleUserL2ScrollConsoleScreenBuffer, CONSOLE_SCROLLSCREENBUFFER_MSG),
CONSOLE_API_STRUCT(Server::handleUserL2SetConsoleTextAttribute, CONSOLE_SETTEXTATTRIBUTE_MSG),
CONSOLE_API_STRUCT(Server::handleUserL2SetConsoleWindowInfo, CONSOLE_SETWINDOWINFO_MSG),
CONSOLE_API_STRUCT(Server::handleUserL2ReadConsoleOutputString, CONSOLE_READCONSOLEOUTPUTSTRING_MSG),
CONSOLE_API_STRUCT(Server::handleUserL2WriteConsoleInput, CONSOLE_WRITECONSOLEINPUT_MSG),
CONSOLE_API_STRUCT(Server::handleUserL2WriteConsoleOutput, CONSOLE_WRITECONSOLEOUTPUT_MSG),
CONSOLE_API_STRUCT(Server::handleUserL2WriteConsoleOutputString, CONSOLE_WRITECONSOLEOUTPUTSTRING_MSG),
CONSOLE_API_STRUCT(Server::handleUserL2ReadConsoleOutput, CONSOLE_READCONSOLEOUTPUT_MSG),
CONSOLE_API_STRUCT(Server::handleUserL2GetConsoleTitle, CONSOLE_GETTITLE_MSG),
CONSOLE_API_STRUCT(Server::handleUserL2SetConsoleTitle, CONSOLE_SETTITLE_MSG),
};
static constexpr ApiDescriptor s_ptyServerUserL3[] = {
CONSOLE_API_STRUCT(Server::handleUserDeprecatedApi, CONSOLE_GETNUMBEROFFONTS_MSG),
CONSOLE_API_STRUCT(Server::handleUserL3GetConsoleMouseInfo, CONSOLE_GETMOUSEINFO_MSG),
CONSOLE_API_STRUCT(Server::handleUserDeprecatedApi, CONSOLE_GETFONTINFO_MSG),
CONSOLE_API_STRUCT(Server::handleUserL3GetConsoleFontSize, CONSOLE_GETFONTSIZE_MSG),
CONSOLE_API_STRUCT(Server::handleUserL3GetConsoleCurrentFont, CONSOLE_CURRENTFONT_MSG),
CONSOLE_API_STRUCT(Server::handleUserDeprecatedApi, CONSOLE_SETFONT_MSG),
CONSOLE_API_STRUCT(Server::handleUserDeprecatedApi, CONSOLE_SETICON_MSG),
CONSOLE_API_STRUCT(Server::handleUserDeprecatedApi, CONSOLE_INVALIDATERECT_MSG),
CONSOLE_API_STRUCT(Server::handleUserDeprecatedApi, CONSOLE_VDM_MSG),
CONSOLE_API_STRUCT(Server::handleUserDeprecatedApi, CONSOLE_SETCURSOR_MSG),
CONSOLE_API_STRUCT(Server::handleUserDeprecatedApi, CONSOLE_SHOWCURSOR_MSG),
CONSOLE_API_STRUCT(Server::handleUserDeprecatedApi, CONSOLE_MENUCONTROL_MSG),
CONSOLE_API_STRUCT(Server::handleUserDeprecatedApi, CONSOLE_SETPALETTE_MSG),
CONSOLE_API_STRUCT(Server::handleUserL3SetConsoleDisplayMode, CONSOLE_SETDISPLAYMODE_MSG),
CONSOLE_API_STRUCT(Server::handleUserDeprecatedApi, CONSOLE_REGISTERVDM_MSG),
CONSOLE_API_STRUCT(Server::handleUserDeprecatedApi, CONSOLE_GETHARDWARESTATE_MSG),
CONSOLE_API_STRUCT(Server::handleUserDeprecatedApi, CONSOLE_SETHARDWARESTATE_MSG),
CONSOLE_API_STRUCT(Server::handleUserL3GetConsoleDisplayMode, CONSOLE_GETDISPLAYMODE_MSG),
CONSOLE_API_STRUCT(Server::handleUserL3AddConsoleAlias, CONSOLE_ADDALIAS_MSG),
CONSOLE_API_STRUCT(Server::handleUserL3GetConsoleAlias, CONSOLE_GETALIAS_MSG),
CONSOLE_API_STRUCT(Server::handleUserL3GetConsoleAliasesLength, CONSOLE_GETALIASESLENGTH_MSG),
CONSOLE_API_STRUCT(Server::handleUserL3GetConsoleAliasExesLength, CONSOLE_GETALIASEXESLENGTH_MSG),
CONSOLE_API_STRUCT(Server::handleUserL3GetConsoleAliases, CONSOLE_GETALIASES_MSG),
CONSOLE_API_STRUCT(Server::handleUserL3GetConsoleAliasExes, CONSOLE_GETALIASEXES_MSG),
CONSOLE_API_STRUCT(Server::handleUserL3ExpungeConsoleCommandHistory, CONSOLE_EXPUNGECOMMANDHISTORY_MSG),
CONSOLE_API_STRUCT(Server::handleUserL3SetConsoleNumberOfCommands, CONSOLE_SETNUMBEROFCOMMANDS_MSG),
CONSOLE_API_STRUCT(Server::handleUserL3GetConsoleCommandHistoryLength, CONSOLE_GETCOMMANDHISTORYLENGTH_MSG),
CONSOLE_API_STRUCT(Server::handleUserL3GetConsoleCommandHistory, CONSOLE_GETCOMMANDHISTORY_MSG),
CONSOLE_API_STRUCT(Server::handleUserDeprecatedApi, CONSOLE_SETKEYSHORTCUTS_MSG),
CONSOLE_API_STRUCT(Server::handleUserDeprecatedApi, CONSOLE_SETMENUCLOSE_MSG),
CONSOLE_API_STRUCT(Server::handleUserDeprecatedApi, CONSOLE_GETKEYBOARDLAYOUTNAME_MSG),
CONSOLE_API_STRUCT(Server::handleUserL3GetConsoleWindow, CONSOLE_GETCONSOLEWINDOW_MSG),
CONSOLE_API_STRUCT(Server::handleUserDeprecatedApi, CONSOLE_CHAR_TYPE_MSG),
CONSOLE_API_STRUCT(Server::handleUserDeprecatedApi, CONSOLE_LOCAL_EUDC_MSG),
CONSOLE_API_STRUCT(Server::handleUserDeprecatedApi, CONSOLE_CURSOR_MODE_MSG),
CONSOLE_API_STRUCT(Server::handleUserDeprecatedApi, CONSOLE_CURSOR_MODE_MSG),
CONSOLE_API_STRUCT(Server::handleUserDeprecatedApi, CONSOLE_REGISTEROS2_MSG),
CONSOLE_API_STRUCT(Server::handleUserDeprecatedApi, CONSOLE_SETOS2OEMFORMAT_MSG),
CONSOLE_API_STRUCT(Server::handleUserDeprecatedApi, CONSOLE_NLS_MODE_MSG),
CONSOLE_API_STRUCT(Server::handleUserDeprecatedApi, CONSOLE_NLS_MODE_MSG),
CONSOLE_API_STRUCT(Server::handleUserL3GetConsoleSelectionInfo, CONSOLE_GETSELECTIONINFO_MSG),
CONSOLE_API_STRUCT(Server::handleUserL3GetConsoleProcessList, CONSOLE_GETCONSOLEPROCESSLIST_MSG),
CONSOLE_API_STRUCT(Server::handleUserL3GetConsoleHistory, CONSOLE_HISTORY_MSG),
CONSOLE_API_STRUCT(Server::handleUserL3SetConsoleHistory, CONSOLE_HISTORY_MSG),
CONSOLE_API_STRUCT(Server::handleUserL3SetConsoleCurrentFont, CONSOLE_CURRENTFONT_MSG),
};
static constexpr ApiDescriptorLayer s_ptyServerUserLayers[] = {
{ s_ptyServerUserL1, std::size(s_ptyServerUserL1) },
{ s_ptyServerUserL2, std::size(s_ptyServerUserL2) },
{ s_ptyServerUserL3, std::size(s_ptyServerUserL3) },
};
NTSTATUS Server::handleUserDefined()
{
const auto layer = (m_req.msgHeader.ApiNumber >> 24) - 1;
const auto idx = m_req.msgHeader.ApiNumber & 0xffffff;
if (layer >= std::size(s_ptyServerUserLayers) || idx >= s_ptyServerUserLayers[layer].count)
{
THROW_NTSTATUS(STATUS_ILLEGAL_FUNCTION);
}
const auto& descriptor = s_ptyServerUserLayers[layer].descriptors[idx];
if (m_req.Descriptor.InputSize < sizeof(CONSOLE_MSG_HEADER) ||
m_req.msgHeader.ApiDescriptorSize > sizeof(m_req.u) ||
m_req.msgHeader.ApiDescriptorSize > (m_req.Descriptor.InputSize - sizeof(CONSOLE_MSG_HEADER)) ||
m_req.msgHeader.ApiDescriptorSize < descriptor.requiredSize)
{
THROW_NTSTATUS(STATUS_ILLEGAL_FUNCTION);
}
// Pre-configure the response span to point at m_req.u (zero-copy).
// Handlers that write results into m_req.u have them sent back automatically.
// Mirrors the OG ConsoleDispatchRequest setting Complete.Write before calling the API.
m_resData = { reinterpret_cast<const uint8_t*>(&m_req.u), m_req.msgHeader.ApiDescriptorSize };
return (this->*descriptor.routine)();
}
NTSTATUS Server::handleUserDeprecatedApi()
{
return STATUS_UNSUCCESSFUL;
}

View File

@@ -0,0 +1,383 @@
#include "pch.h"
#include "Server.h"
// L1: GetConsoleCP / GetConsoleOutputCP
// OG: SrvGetConsoleCP in getset.cpp — no handle validation.
// Reads a->Output to decide input vs output CP, returns a->CodePage.
NTSTATUS Server::handleUserL1GetConsoleCP()
{
auto& a = m_req.u.consoleMsgL1.GetConsoleCP;
a.CodePage = a.Output ? m_outputCP : m_inputCP;
return STATUS_SUCCESS;
}
// L1: GetConsoleMode
// OG: SrvGetConsoleMode in getset.cpp
// DereferenceIoHandle(obj, INPUT|OUTPUT, GENERIC_READ)
// Returns a->Mode with the handle's mode flags.
NTSTATUS Server::handleUserL1GetConsoleMode()
{
auto& a = m_req.u.consoleMsgL1.GetConsoleMode;
auto* h = findHandle(m_req.Descriptor.Object, CONSOLE_INPUT_HANDLE | CONSOLE_OUTPUT_HANDLE, GENERIC_READ);
if (!h)
{
return STATUS_INVALID_HANDLE;
}
a.Mode = (h->handleType & CONSOLE_INPUT_HANDLE) ? m_inputMode : m_outputMode;
return STATUS_SUCCESS;
}
// L1: SetConsoleMode
// OG: SrvSetConsoleMode in getset.cpp
// DereferenceIoHandle(obj, INPUT|OUTPUT, GENERIC_WRITE)
// Reads a->Mode, validates flags, applies to handle's buffer.
NTSTATUS Server::handleUserL1SetConsoleMode()
{
auto& a = m_req.u.consoleMsgL1.SetConsoleMode;
auto* h = findHandle(m_req.Descriptor.Object, CONSOLE_INPUT_HANDLE | CONSOLE_OUTPUT_HANDLE, GENERIC_WRITE);
if (!h)
{
return STATUS_INVALID_HANDLE;
}
if (h->handleType & CONSOLE_INPUT_HANDLE)
{
m_inputMode = a.Mode;
}
else
{
const auto oldMode = m_outputMode;
m_outputMode = a.Mode;
// Emit DECAWM when the wrap-at-EOL flag changes.
if ((oldMode ^ m_outputMode) & ENABLE_WRAP_AT_EOL_OUTPUT)
{
vtAppendFmt("\x1b[?7%c", (m_outputMode & ENABLE_WRAP_AT_EOL_OUTPUT) ? 'h' : 'l');
vtFlush();
}
}
return STATUS_SUCCESS;
}
// L1: GetNumberOfConsoleInputEvents
// OG: SrvGetConsoleNumberOfInputEvents in getset.cpp
// DereferenceIoHandle(obj, INPUT, GENERIC_READ)
// Returns a->ReadyEvents.
NTSTATUS Server::handleUserL1GetNumberOfConsoleInputEvents()
{
auto& a = m_req.u.consoleMsgL1.GetNumberOfConsoleInputEvents;
auto* h = findHandle(m_req.Descriptor.Object, CONSOLE_INPUT_HANDLE, GENERIC_READ);
if (!h)
{
return STATUS_INVALID_HANDLE;
}
a.ReadyEvents = static_cast<ULONG>(m_input.pendingEventCount());
return STATUS_SUCCESS;
}
// L1: GetConsoleInput (ReadConsoleInput / PeekConsoleInput)
// OG: SrvGetConsoleInput in directio.cpp
// DereferenceIoHandle(obj, INPUT, GENERIC_READ)
// Reads a->Flags (CONSOLE_READ_NOREMOVE = peek), a->Unicode.
// Writes INPUT_RECORD array to output buffer, returns a->NumRecords.
// Can block (ReplyPending) if no data and CONSOLE_READ_NOWAIT not set.
NTSTATUS Server::handleUserL1GetConsoleInput()
{
auto& a = m_req.u.consoleMsgL1.GetConsoleInput;
auto* h = findHandle(m_req.Descriptor.Object, CONSOLE_INPUT_HANDLE, GENERIC_READ);
if (!h)
{
return STATUS_INVALID_HANDLE;
}
const bool peek = (a.Flags & CONSOLE_READ_NOREMOVE) != 0;
const bool nowait = (a.Flags & CONSOLE_READ_NOWAIT) != 0;
const auto outputOffset = m_req.msgHeader.ApiDescriptorSize;
const auto maxOutputBytes = (m_req.Descriptor.OutputSize > outputOffset) ? (m_req.Descriptor.OutputSize - outputOffset) : 0u;
const auto maxRecords = maxOutputBytes / sizeof(INPUT_RECORD);
// Try to satisfy immediately.
if (maxRecords > 0 && m_input.hasData())
{
std::vector<INPUT_RECORD> records(maxRecords);
const auto n = m_input.readInputRecords(records.data(), maxRecords, peek);
if (n > 0)
{
writeOutput(outputOffset, records.data(), static_cast<ULONG>(n * sizeof(INPUT_RECORD)));
a.NumRecords = static_cast<ULONG>(n);
if (!peek && !m_input.hasData())
m_inputAvailableEvent.ResetEvent();
return STATUS_SUCCESS;
}
}
// No data. If NOWAIT or peek, return immediately with 0 records.
if (nowait || peek)
{
a.NumRecords = 0;
return STATUS_SUCCESS;
}
// Block: defer until input arrives.
// Capture the ID and parameters for the retry lambda.
const auto id = m_req.Descriptor.Identifier;
return pendRead([id, outputOffset, maxRecords](Server& self) -> bool {
if (!self.m_input.hasData())
return false;
std::vector<INPUT_RECORD> records(maxRecords);
const auto n = self.m_input.readInputRecords(records.data(), maxRecords, false);
if (n == 0)
return false;
// Write the records to the client's output buffer.
CD_IO_OPERATION op{};
op.Identifier = id;
op.Buffer.Offset = outputOffset;
op.Buffer.Data = records.data();
op.Buffer.Size = static_cast<ULONG>(n * sizeof(INPUT_RECORD));
THROW_IF_NTSTATUS_FAILED(self.ioctl(IOCTL_CONDRV_WRITE_OUTPUT, &op, sizeof(op), nullptr, 0));
// Build the reply. We need to write back the API descriptor with NumRecords,
// then complete the IO.
CONSOLE_GETCONSOLEINPUT_MSG reply{};
reply.NumRecords = static_cast<ULONG>(n);
CD_IO_COMPLETE completion{};
completion.Identifier = id;
completion.IoStatus.Status = STATUS_SUCCESS;
completion.Write.Data = &reply;
completion.Write.Size = sizeof(reply);
self.completeIo(completion);
return true;
});
}
// Helper: Try to read raw text from m_input, convert encoding, and write to
// the client's output buffer at `outputOffset`. Returns bytes written to client,
// or 0 if no data is available.
static ULONG tryReadRawTextToClient(
Server& self,
const LUID& identifier,
ULONG outputOffset,
ULONG maxOutputBytes,
bool unicode,
UINT inputCP)
{
if (maxOutputBytes == 0 || !self.m_input.hasData())
return 0;
// Read up to maxOutputBytes of raw UTF-8 from the input buffer.
std::vector<char> raw(maxOutputBytes);
const auto n = self.m_input.readRawText(raw.data(), maxOutputBytes);
if (n == 0)
return 0;
ULONG bytesWritten = 0;
if (unicode)
{
// UTF-8 → UTF-16.
const auto wideLen = MultiByteToWideChar(CP_UTF8, 0, raw.data(), static_cast<int>(n), nullptr, 0);
if (wideLen > 0)
{
const auto maxChars = maxOutputBytes / sizeof(WCHAR);
std::vector<wchar_t> wide(wideLen);
MultiByteToWideChar(CP_UTF8, 0, raw.data(), static_cast<int>(n), wide.data(), wideLen);
const auto outChars = std::min(static_cast<size_t>(wideLen), maxChars);
bytesWritten = static_cast<ULONG>(outChars * sizeof(WCHAR));
CD_IO_OPERATION op{};
op.Identifier = identifier;
op.Buffer.Offset = outputOffset;
op.Buffer.Data = wide.data();
op.Buffer.Size = bytesWritten;
THROW_IF_NTSTATUS_FAILED(self.ioctl(IOCTL_CONDRV_WRITE_OUTPUT, &op, sizeof(op), nullptr, 0));
}
}
else
{
// UTF-8 → UTF-16 → output code page.
const auto wideLen = MultiByteToWideChar(CP_UTF8, 0, raw.data(), static_cast<int>(n), nullptr, 0);
if (wideLen > 0)
{
std::vector<wchar_t> wide(wideLen);
MultiByteToWideChar(CP_UTF8, 0, raw.data(), static_cast<int>(n), wide.data(), wideLen);
const auto ansiLen = WideCharToMultiByte(inputCP, 0, wide.data(), wideLen, nullptr, 0, nullptr, nullptr);
if (ansiLen > 0)
{
const auto outBytes = std::min(static_cast<size_t>(ansiLen), static_cast<size_t>(maxOutputBytes));
std::vector<char> ansi(ansiLen);
WideCharToMultiByte(inputCP, 0, wide.data(), wideLen, ansi.data(), ansiLen, nullptr, nullptr);
bytesWritten = static_cast<ULONG>(outBytes);
CD_IO_OPERATION op{};
op.Identifier = identifier;
op.Buffer.Offset = outputOffset;
op.Buffer.Data = ansi.data();
op.Buffer.Size = bytesWritten;
THROW_IF_NTSTATUS_FAILED(self.ioctl(IOCTL_CONDRV_WRITE_OUTPUT, &op, sizeof(op), nullptr, 0));
}
}
}
return bytesWritten;
}
// L1: ReadConsole
// OG: SrvReadConsole in stream.cpp
// DereferenceIoHandle(obj, INPUT, GENERIC_READ)
// Reads a->Unicode, a->ProcessControlZ, a->ExeNameLength, a->InitialNumBytes, a->CtrlWakeupMask.
// Writes text to output buffer (GetAugmentedOutputBuffer, factor=2 for ANSI->Unicode).
// Returns a->NumBytes, a->ControlKeyState.
// Can block (ReplyPending) waiting for user input.
NTSTATUS Server::handleUserL1ReadConsole()
{
auto& a = m_req.u.consoleMsgL1.ReadConsole;
auto* h = findHandle(m_req.Descriptor.Object, CONSOLE_INPUT_HANDLE, GENERIC_READ);
if (!h)
{
return STATUS_INVALID_HANDLE;
}
// TODO: Implement cooked read (line editing) when ENABLE_LINE_INPUT is set.
// For now, return raw text from the input buffer.
const auto outputOffset = m_req.msgHeader.ApiDescriptorSize;
const auto maxOutputBytes = (m_req.Descriptor.OutputSize > outputOffset) ? (m_req.Descriptor.OutputSize - outputOffset) : 0u;
const bool unicode = a.Unicode != 0;
const auto inputCP = m_inputCP;
// Try to satisfy immediately.
const auto bytesWritten = tryReadRawTextToClient(
*this, m_req.Descriptor.Identifier, outputOffset, maxOutputBytes, unicode, inputCP);
if (bytesWritten > 0)
{
a.NumBytes = bytesWritten;
a.ControlKeyState = 0;
if (!m_input.hasData())
m_inputAvailableEvent.ResetEvent();
return STATUS_SUCCESS;
}
// No data — block until input arrives.
const auto id = m_req.Descriptor.Identifier;
return pendRead([id, outputOffset, maxOutputBytes, unicode, inputCP](Server& self) -> bool {
const auto bytesWritten = tryReadRawTextToClient(
self, id, outputOffset, maxOutputBytes, unicode, inputCP);
if (bytesWritten == 0)
return false; // Still no data.
// Build reply with NumBytes.
CONSOLE_READCONSOLE_MSG reply{};
reply.NumBytes = bytesWritten;
reply.ControlKeyState = 0;
reply.Unicode = unicode ? TRUE : FALSE;
CD_IO_COMPLETE completion{};
completion.Identifier = id;
completion.IoStatus.Status = STATUS_SUCCESS;
completion.Write.Data = &reply;
completion.Write.Size = sizeof(reply);
self.completeIo(completion);
return true;
});
}
// Helper: Convert payload to UTF-8 and write to VT output.
static void writePayloadAsVt(Server& self, const std::vector<uint8_t>& payload, bool unicode, UINT outputCP)
{
if (unicode)
{
self.vtAppendUTF16({ reinterpret_cast<const wchar_t*>(payload.data()), payload.size() / sizeof(wchar_t) });
}
else
{
const auto len = static_cast<int>(payload.size());
if (len > 0)
{
const auto wideLen = MultiByteToWideChar(outputCP, 0, reinterpret_cast<const char*>(payload.data()), len, nullptr, 0);
if (wideLen > 0)
{
std::wstring wide(wideLen, L'\0');
MultiByteToWideChar(outputCP, 0, reinterpret_cast<const char*>(payload.data()), len, wide.data(), wideLen);
self.vtAppendUTF16(wide);
}
}
}
self.vtFlush();
}
// L1: WriteConsole
// OG: SrvWriteConsole in stream.cpp → DoWriteConsole in _stream.cpp.
// If output is paused (CONSOLE_SUSPENDED), the write is deferred on the
// OutputQueue. When output is unpaused, UnblockWriteConsole wakes all
// pending writes. See also handleRawWrite for the same pattern.
NTSTATUS Server::handleUserL1WriteConsole()
{
auto& a = m_req.u.consoleMsgL1.WriteConsole;
auto* h = findHandle(m_req.Descriptor.Object, CONSOLE_OUTPUT_HANDLE, GENERIC_WRITE);
if (!h)
{
return STATUS_INVALID_HANDLE;
}
auto payload = std::make_shared<std::vector<uint8_t>>(readTrailingInput());
const auto payloadSize = static_cast<ULONG>(payload->size());
const bool unicode = a.Unicode != 0;
const auto outputCP = m_outputCP;
if (m_outputPaused)
{
const auto id = m_req.Descriptor.Identifier;
return pendWrite([payload, payloadSize, unicode, outputCP, id](Server& self) -> bool {
if (self.m_outputPaused)
return false;
writePayloadAsVt(self, *payload, unicode, outputCP);
CONSOLE_WRITECONSOLE_MSG reply{};
reply.NumBytes = payloadSize;
reply.Unicode = unicode ? TRUE : FALSE;
CD_IO_COMPLETE completion{};
completion.Identifier = id;
completion.IoStatus.Status = STATUS_SUCCESS;
completion.Write.Data = &reply;
completion.Write.Size = sizeof(reply);
self.completeIo(completion);
return true;
});
}
writePayloadAsVt(*this, *payload, unicode, outputCP);
a.NumBytes = payloadSize;
return STATUS_SUCCESS;
}
// L1: GetConsoleLangId
// OG: SrvGetConsoleLangId in srvinit.cpp — no handle validation.
// Returns a->LangId based on output code page.
NTSTATUS Server::handleUserL1GetConsoleLangId()
{
auto& a = m_req.u.consoleMsgL1.GetConsoleLangId;
// TODO: Derive from actual output code page.
a.LangId = MAKELANGID(LANG_ENGLISH, SUBLANG_ENGLISH_US);
return STATUS_SUCCESS;
}

View File

@@ -0,0 +1,613 @@
#include "pch.h"
#include "Server.h"
// Helper: call an IPtyHost method that returns HRESULT, and convert failure
// to an NTSTATUS return from the calling handler.
#define HOST_CALL(expr) \
do { \
const auto _hr = (expr); \
if (FAILED(_hr)) return STATUS_UNSUCCESSFUL; \
} while (0)
// L2: FillConsoleOutput
// OG: SrvFillConsoleOutput in directio.cpp
// DereferenceIoHandle(obj, OUTPUT, GENERIC_WRITE)
// Reads a->WriteCoord, a->ElementType, a->Element, a->Length.
// Returns a->Length (actual count filled).
NTSTATUS Server::handleUserL2FillConsoleOutput()
{
auto& a = m_req.u.consoleMsgL2.FillConsoleOutput;
auto* h = activateOutputBuffer(GENERIC_WRITE);
if (!h)
return STATUS_INVALID_HANDLE;
// TODO: Read existing cells via ReadBuffer, modify per ElementType, WriteUTF16 back.
// This requires the host to support cell-level read/write.
a.Length = 0;
return STATUS_SUCCESS;
}
// L2: GenerateConsoleCtrlEvent
// OG: SrvGenerateConsoleCtrlEvent in getset.cpp — no handle validation.
// Reads a->CtrlEvent, a->ProcessGroupId.
NTSTATUS Server::handleUserL2GenerateConsoleCtrlEvent()
{
auto& a = m_req.u.consoleMsgL2.GenerateConsoleCtrlEvent;
// TODO: Dispatch ctrl event (CTRL_C_EVENT / CTRL_BREAK_EVENT)
// to processes in a->ProcessGroupId.
(void)a.CtrlEvent;
(void)a.ProcessGroupId;
return STATUS_SUCCESS;
}
// L2: SetConsoleActiveScreenBuffer
// OG: SrvSetConsoleActiveScreenBuffer in getset.cpp
// DereferenceIoHandle(obj, OUTPUT|GRAPHICS_OUTPUT, GENERIC_WRITE)
NTSTATUS Server::handleUserL2SetConsoleActiveScreenBuffer()
{
auto* h = findHandle(m_req.Descriptor.Object, CONSOLE_OUTPUT_HANDLE, GENERIC_WRITE);
if (!h)
return STATUS_INVALID_HANDLE;
if (h->screenBuffer != m_activeScreenBuffer && m_host)
{
const auto hr = m_host->ActivateBuffer(h->screenBuffer);
if (FAILED(hr))
return STATUS_UNSUCCESSFUL;
m_activeScreenBuffer = h->screenBuffer;
}
return STATUS_SUCCESS;
}
// L2: FlushConsoleInputBuffer
// OG: SrvFlushConsoleInputBuffer in getset.cpp
// DereferenceIoHandle(obj, INPUT, GENERIC_WRITE)
NTSTATUS Server::handleUserL2FlushConsoleInputBuffer()
{
auto* h = findHandle(m_req.Descriptor.Object, CONSOLE_INPUT_HANDLE, GENERIC_WRITE);
if (!h)
return STATUS_INVALID_HANDLE;
m_input.flush();
m_inputAvailableEvent.ResetEvent();
return STATUS_SUCCESS;
}
// L2: SetConsoleCP / SetConsoleOutputCP
// OG: SrvSetConsoleCP in getset.cpp — no handle validation.
// Reads a->CodePage, a->Output.
NTSTATUS Server::handleUserL2SetConsoleCP()
{
auto& a = m_req.u.consoleMsgL2.SetConsoleCP;
if (!IsValidCodePage(a.CodePage))
return STATUS_INVALID_PARAMETER;
if (a.Output)
m_outputCP = a.CodePage;
else
m_inputCP = a.CodePage;
return STATUS_SUCCESS;
}
// L2: GetConsoleCursorInfo
// OG: SrvGetConsoleCursorInfo in getset.cpp
// DereferenceIoHandle(obj, OUTPUT, GENERIC_READ)
// Returns a->CursorSize, a->Visible.
NTSTATUS Server::handleUserL2GetConsoleCursorInfo()
{
auto& a = m_req.u.consoleMsgL2.GetConsoleCursorInfo;
auto* h = activateOutputBuffer(GENERIC_READ);
if (!h)
return STATUS_INVALID_HANDLE;
PTY_SCREEN_BUFFER_INFO info{};
HOST_CALL(m_host->GetScreenBufferInfo(&info));
a.CursorSize = info.CursorSize;
a.Visible = info.CursorVisible;
return STATUS_SUCCESS;
}
// L2: SetConsoleCursorInfo
// OG: SrvSetConsoleCursorInfo in getset.cpp
// DereferenceIoHandle(obj, OUTPUT, GENERIC_WRITE)
// Reads a->CursorSize, a->Visible.
NTSTATUS Server::handleUserL2SetConsoleCursorInfo()
{
auto& a = m_req.u.consoleMsgL2.SetConsoleCursorInfo;
auto* h = activateOutputBuffer(GENERIC_WRITE);
if (!h)
return STATUS_INVALID_HANDLE;
if (a.CursorSize < 1 || a.CursorSize > 100)
return STATUS_INVALID_PARAMETER;
ULONG cursorSize = a.CursorSize;
BOOLEAN cursorVisible = a.Visible;
PTY_SCREEN_BUFFER_INFO_CHANGE change{};
change.CursorSize = &cursorSize;
change.CursorVisible = &cursorVisible;
HOST_CALL(m_host->SetScreenBufferInfo(&change));
return STATUS_SUCCESS;
}
// L2: GetConsoleScreenBufferInfo (GetConsoleScreenBufferInfoEx)
// OG: SrvGetConsoleScreenBufferInfo in getset.cpp
// DereferenceIoHandle(obj, OUTPUT, GENERIC_READ)
NTSTATUS Server::handleUserL2GetConsoleScreenBufferInfo()
{
auto& a = m_req.u.consoleMsgL2.GetConsoleScreenBufferInfo;
auto* h = activateOutputBuffer(GENERIC_READ);
if (!h)
return STATUS_INVALID_HANDLE;
PTY_SCREEN_BUFFER_INFO info{};
HOST_CALL(m_host->GetScreenBufferInfo(&info));
a.Size = info.Size;
a.CursorPosition = info.CursorPosition;
a.ScrollPosition = { info.Window.Left, info.Window.Top };
a.Attributes = info.Attributes;
a.CurrentWindowSize = {
static_cast<SHORT>(info.Window.Right - info.Window.Left + 1),
static_cast<SHORT>(info.Window.Bottom - info.Window.Top + 1),
};
a.MaximumWindowSize = info.MaximumWindowSize;
a.PopupAttributes = info.PopupAttributes;
a.FullscreenSupported = info.FullscreenSupported;
memcpy(a.ColorTable, info.ColorTable, sizeof(a.ColorTable));
return STATUS_SUCCESS;
}
// L2: SetConsoleScreenBufferInfoEx
// OG: SrvSetConsoleScreenBufferInfo in getset.cpp
// DereferenceIoHandle(obj, OUTPUT, GENERIC_WRITE)
NTSTATUS Server::handleUserL2SetConsoleScreenBufferInfo()
{
auto& a = m_req.u.consoleMsgL2.SetConsoleScreenBufferInfo;
auto* h = activateOutputBuffer(GENERIC_WRITE);
if (!h)
return STATUS_INVALID_HANDLE;
WORD attributes = a.Attributes;
WORD popupAttributes = a.PopupAttributes;
SMALL_RECT window = { 0, 0, static_cast<SHORT>(a.CurrentWindowSize.X - 1), static_cast<SHORT>(a.CurrentWindowSize.Y - 1) };
PTY_SCREEN_BUFFER_INFO_CHANGE change{};
if (a.Size.X > 0 && a.Size.Y > 0)
change.Size = &a.Size;
change.Attributes = &attributes;
change.PopupAttributes = &popupAttributes;
change.Window = &window;
change.ColorTable = a.ColorTable;
HOST_CALL(m_host->SetScreenBufferInfo(&change));
return STATUS_SUCCESS;
}
// L2: SetConsoleScreenBufferSize
// OG: SrvSetConsoleScreenBufferSize in getset.cpp
// DereferenceIoHandle(obj, OUTPUT, GENERIC_WRITE)
NTSTATUS Server::handleUserL2SetConsoleScreenBufferSize()
{
auto& a = m_req.u.consoleMsgL2.SetConsoleScreenBufferSize;
auto* h = activateOutputBuffer(GENERIC_WRITE);
if (!h)
return STATUS_INVALID_HANDLE;
if (a.Size.X <= 0 || a.Size.Y <= 0)
return STATUS_INVALID_PARAMETER;
PTY_SCREEN_BUFFER_INFO_CHANGE change{};
change.Size = &a.Size;
HOST_CALL(m_host->SetScreenBufferInfo(&change));
return STATUS_SUCCESS;
}
// L2: SetConsoleCursorPosition
// OG: SrvSetConsoleCursorPosition in getset.cpp
// DereferenceIoHandle(obj, OUTPUT, GENERIC_WRITE)
NTSTATUS Server::handleUserL2SetConsoleCursorPosition()
{
auto& a = m_req.u.consoleMsgL2.SetConsoleCursorPosition;
auto* h = activateOutputBuffer(GENERIC_WRITE);
if (!h)
return STATUS_INVALID_HANDLE;
PTY_SCREEN_BUFFER_INFO_CHANGE change{};
change.CursorPosition = &a.CursorPosition;
HOST_CALL(m_host->SetScreenBufferInfo(&change));
return STATUS_SUCCESS;
}
// L2: GetLargestConsoleWindowSize
// OG: SrvGetLargestConsoleWindowSize in getset.cpp
// DereferenceIoHandle(obj, OUTPUT, GENERIC_WRITE) — yes, WRITE for a getter.
NTSTATUS Server::handleUserL2GetLargestConsoleWindowSize()
{
auto& a = m_req.u.consoleMsgL2.GetLargestConsoleWindowSize;
auto* h = activateOutputBuffer(GENERIC_WRITE);
if (!h)
return STATUS_INVALID_HANDLE;
PTY_SCREEN_BUFFER_INFO info{};
HOST_CALL(m_host->GetScreenBufferInfo(&info));
a.Size = info.MaximumWindowSize;
return STATUS_SUCCESS;
}
// L2: ScrollConsoleScreenBuffer
// OG: SrvScrollConsoleScreenBuffer in getset.cpp
// DereferenceIoHandle(obj, OUTPUT, GENERIC_WRITE)
NTSTATUS Server::handleUserL2ScrollConsoleScreenBuffer()
{
auto& a = m_req.u.consoleMsgL2.ScrollConsoleScreenBuffer;
auto* h = activateOutputBuffer(GENERIC_WRITE);
if (!h)
return STATUS_INVALID_HANDLE;
// TODO: Read source rectangle via ReadBuffer, write to destination,
// fill vacated area with a->Fill. Needs host ReadBuffer + WriteUTF16 support.
(void)a;
return STATUS_SUCCESS;
}
// L2: SetConsoleTextAttribute
// OG: SrvSetConsoleTextAttribute in getset.cpp
// DereferenceIoHandle(obj, OUTPUT, GENERIC_WRITE)
NTSTATUS Server::handleUserL2SetConsoleTextAttribute()
{
auto& a = m_req.u.consoleMsgL2.SetConsoleTextAttribute;
auto* h = activateOutputBuffer(GENERIC_WRITE);
if (!h)
return STATUS_INVALID_HANDLE;
PTY_SCREEN_BUFFER_INFO_CHANGE change{};
change.Attributes = &a.Attributes;
HOST_CALL(m_host->SetScreenBufferInfo(&change));
return STATUS_SUCCESS;
}
// L2: SetConsoleWindowInfo
// OG: SrvSetConsoleWindowInfo in getset.cpp
// DereferenceIoHandle(obj, OUTPUT, GENERIC_WRITE)
NTSTATUS Server::handleUserL2SetConsoleWindowInfo()
{
auto& a = m_req.u.consoleMsgL2.SetConsoleWindowInfo;
auto* h = activateOutputBuffer(GENERIC_WRITE);
if (!h)
return STATUS_INVALID_HANDLE;
SMALL_RECT window = a.Window;
if (!a.Absolute)
{
// Relative mode: get current window position first.
PTY_SCREEN_BUFFER_INFO info{};
HOST_CALL(m_host->GetScreenBufferInfo(&info));
window.Left += info.Window.Left;
window.Top += info.Window.Top;
window.Right += info.Window.Right;
window.Bottom += info.Window.Bottom;
}
PTY_SCREEN_BUFFER_INFO_CHANGE change{};
change.Window = &window;
HOST_CALL(m_host->SetScreenBufferInfo(&change));
return STATUS_SUCCESS;
}
// L2: ReadConsoleOutputString (ReadConsoleOutputCharacter / ReadConsoleOutputAttribute)
// OG: SrvReadConsoleOutputString in directio.cpp
// DereferenceIoHandle(obj, OUTPUT, GENERIC_READ)
NTSTATUS Server::handleUserL2ReadConsoleOutputString()
{
auto& a = m_req.u.consoleMsgL2.ReadConsoleOutputString;
auto* h = activateOutputBuffer(GENERIC_READ);
if (!h)
return STATUS_INVALID_HANDLE;
const auto outputOffset = m_req.msgHeader.ApiDescriptorSize;
const auto maxOutputBytes = (m_req.Descriptor.OutputSize > outputOffset) ? (m_req.Descriptor.OutputSize - outputOffset) : 0u;
// Compute how many cells we can read based on the output capacity and string type.
ULONG maxCells = 0;
switch (a.StringType)
{
case CONSOLE_ASCII:
maxCells = maxOutputBytes; // 1 byte per cell
break;
case CONSOLE_REAL_UNICODE:
maxCells = maxOutputBytes / sizeof(WCHAR); // 2 bytes per cell
break;
case CONSOLE_ATTRIBUTE:
maxCells = maxOutputBytes / sizeof(WORD); // 2 bytes per cell
break;
default:
return STATUS_INVALID_PARAMETER;
}
if (maxCells == 0)
{
a.NumRecords = 0;
return STATUS_SUCCESS;
}
// Read cells from the host.
std::vector<PTY_CHAR_INFO> cells(maxCells);
const auto hr = m_host->ReadBuffer(a.ReadCoord, maxCells, cells.data());
if (FAILED(hr))
{
a.NumRecords = 0;
return STATUS_SUCCESS;
}
// Convert to the requested string type.
if (a.StringType == CONSOLE_REAL_UNICODE)
{
std::vector<WCHAR> chars(maxCells);
for (ULONG i = 0; i < maxCells; i++)
chars[i] = cells[i].Char;
writeOutput(outputOffset, chars.data(), maxCells * sizeof(WCHAR));
}
else if (a.StringType == CONSOLE_ASCII)
{
// Convert each Unicode char to the output code page.
std::vector<WCHAR> wchars(maxCells);
for (ULONG i = 0; i < maxCells; i++)
wchars[i] = cells[i].Char;
const auto ansiLen = WideCharToMultiByte(m_outputCP, 0, wchars.data(), maxCells, nullptr, 0, nullptr, nullptr);
if (ansiLen > 0)
{
std::vector<char> ansi(ansiLen);
WideCharToMultiByte(m_outputCP, 0, wchars.data(), maxCells, ansi.data(), ansiLen, nullptr, nullptr);
const auto toWrite = std::min(static_cast<ULONG>(ansiLen), maxOutputBytes);
writeOutput(outputOffset, ansi.data(), toWrite);
// For ASCII, the cell count may differ from byte count due to DBCS.
// We return the number of cells consumed (= maxCells).
}
}
else if (a.StringType == CONSOLE_ATTRIBUTE)
{
std::vector<WORD> attrs(maxCells);
for (ULONG i = 0; i < maxCells; i++)
attrs[i] = cells[i].Attributes;
writeOutput(outputOffset, attrs.data(), maxCells * sizeof(WORD));
}
a.NumRecords = maxCells;
return STATUS_SUCCESS;
}
// L2: WriteConsoleInput
// OG: SrvWriteConsoleInput in directio.cpp
// DereferenceIoHandle(obj, INPUT, GENERIC_WRITE)
NTSTATUS Server::handleUserL2WriteConsoleInput()
{
auto& a = m_req.u.consoleMsgL2.WriteConsoleInput;
auto* h = findHandle(m_req.Descriptor.Object, CONSOLE_INPUT_HANDLE, GENERIC_WRITE);
if (!h)
return STATUS_INVALID_HANDLE;
auto payload = readTrailingInput();
auto numRecords = static_cast<ULONG>(payload.size() / sizeof(INPUT_RECORD));
// TODO: Handle Unicode vs ANSI conversion if !a->Unicode.
const auto* records = reinterpret_cast<const INPUT_RECORD*>(payload.data());
for (ULONG i = 0; i < numRecords; i++)
{
const auto& rec = records[i];
if (rec.EventType == KEY_EVENT)
{
const auto& ke = rec.Event.KeyEvent;
char buf[128];
const auto n = snprintf(buf, sizeof(buf), "\x1b[%u;%u;%u;%u;%lu;%u_",
ke.wVirtualKeyCode,
ke.wVirtualScanCode,
static_cast<unsigned>(ke.uChar.UnicodeChar),
ke.bKeyDown ? 1u : 0u,
ke.dwControlKeyState,
ke.wRepeatCount);
if (n > 0)
m_input.write({ buf, static_cast<size_t>(n) });
}
// TODO: Handle MOUSE_EVENT, WINDOW_BUFFER_SIZE_EVENT, etc.
}
if (numRecords > 0 && m_input.hasData())
{
m_inputAvailableEvent.SetEvent();
drainPendingInputReads();
}
a.NumRecords = numRecords;
return STATUS_SUCCESS;
}
// L2: WriteConsoleOutput
// OG: SrvWriteConsoleOutput in directio.cpp
// DereferenceIoHandle(obj, OUTPUT, GENERIC_WRITE)
NTSTATUS Server::handleUserL2WriteConsoleOutput()
{
auto& a = m_req.u.consoleMsgL2.WriteConsoleOutput;
auto* h = activateOutputBuffer(GENERIC_WRITE);
if (!h)
return STATUS_INVALID_HANDLE;
auto payload = readTrailingInput();
// TODO: Write CHAR_INFO grid from payload into screen buffer at a->CharRegion.
// This requires a WriteBuffer-like host callback or per-row WriteUTF16.
(void)payload;
a.CharRegion = {};
return STATUS_SUCCESS;
}
// L2: WriteConsoleOutputString (WriteConsoleOutputCharacter / WriteConsoleOutputAttribute)
// OG: SrvWriteConsoleOutputString in directio.cpp
// DereferenceIoHandle(obj, OUTPUT, GENERIC_WRITE)
NTSTATUS Server::handleUserL2WriteConsoleOutputString()
{
auto& a = m_req.u.consoleMsgL2.WriteConsoleOutputString;
auto* h = activateOutputBuffer(GENERIC_WRITE);
if (!h)
return STATUS_INVALID_HANDLE;
auto payload = readTrailingInput();
// TODO: Write chars/attrs from payload to screen buffer at a->WriteCoord.
// Requires per-cell write support on the host.
(void)payload;
a.NumRecords = 0;
return STATUS_SUCCESS;
}
// L2: ReadConsoleOutput
// OG: SrvReadConsoleOutput in directio.cpp
// DereferenceIoHandle(obj, OUTPUT, GENERIC_READ)
NTSTATUS Server::handleUserL2ReadConsoleOutput()
{
auto& a = m_req.u.consoleMsgL2.ReadConsoleOutput;
auto* h = activateOutputBuffer(GENERIC_READ);
if (!h)
return STATUS_INVALID_HANDLE;
const auto outputOffset = m_req.msgHeader.ApiDescriptorSize;
const auto maxOutputBytes = (m_req.Descriptor.OutputSize > outputOffset) ? (m_req.Descriptor.OutputSize - outputOffset) : 0u;
const SHORT width = a.CharRegion.Right - a.CharRegion.Left + 1;
const SHORT height = a.CharRegion.Bottom - a.CharRegion.Top + 1;
if (width <= 0 || height <= 0)
{
a.CharRegion = {};
return STATUS_SUCCESS;
}
const auto totalCells = static_cast<ULONG>(width) * height;
const auto maxCells = maxOutputBytes / sizeof(CHAR_INFO);
if (maxCells == 0)
{
a.CharRegion = {};
return STATUS_SUCCESS;
}
// Read row by row from the host.
std::vector<CHAR_INFO> result(totalCells);
for (SHORT row = 0; row < height; row++)
{
COORD readPos = { a.CharRegion.Left, static_cast<SHORT>(a.CharRegion.Top + row) };
std::vector<PTY_CHAR_INFO> cells(width);
const auto hr = m_host->ReadBuffer(readPos, width, cells.data());
if (FAILED(hr))
break;
for (SHORT col = 0; col < width; col++)
{
auto& ci = result[row * width + col];
ci.Char.UnicodeChar = cells[col].Char;
ci.Attributes = cells[col].Attributes;
}
}
const auto cellsToWrite = std::min(totalCells, static_cast<ULONG>(maxCells));
if (cellsToWrite > 0)
writeOutput(outputOffset, result.data(), cellsToWrite * sizeof(CHAR_INFO));
a.CharRegion = {
a.CharRegion.Left,
a.CharRegion.Top,
static_cast<SHORT>(a.CharRegion.Left + width - 1),
static_cast<SHORT>(a.CharRegion.Top + height - 1),
};
return STATUS_SUCCESS;
}
// L2: GetConsoleTitle
// OG: SrvGetConsoleTitle in cmdline.cpp — no handle validation.
NTSTATUS Server::handleUserL2GetConsoleTitle()
{
auto& a = m_req.u.consoleMsgL2.GetConsoleTitle;
const auto& title = a.Original ? m_originalTitle : m_title;
const auto outputOffset = m_req.msgHeader.ApiDescriptorSize;
const auto outputCapacity = (m_req.Descriptor.OutputSize > outputOffset) ? (m_req.Descriptor.OutputSize - outputOffset) : 0u;
if (a.Unicode)
{
const auto bytes = static_cast<ULONG>(title.size() * sizeof(WCHAR));
const auto bytesToWrite = std::min(bytes, outputCapacity);
if (bytesToWrite > 0)
writeOutput(outputOffset, title.data(), bytesToWrite);
a.TitleLength = bytes;
}
else
{
const auto ansiLen = WideCharToMultiByte(m_outputCP, 0, title.data(), static_cast<int>(title.size()), nullptr, 0, nullptr, nullptr);
const auto bytesToWrite = std::min(static_cast<ULONG>(std::max(ansiLen, 0)), outputCapacity);
if (bytesToWrite > 0)
{
std::string ansi(ansiLen, '\0');
WideCharToMultiByte(m_outputCP, 0, title.data(), static_cast<int>(title.size()), ansi.data(), ansiLen, nullptr, nullptr);
writeOutput(outputOffset, ansi.data(), bytesToWrite);
}
a.TitleLength = static_cast<ULONG>(std::max(ansiLen, 0));
}
return STATUS_SUCCESS;
}
// L2: SetConsoleTitle
// OG: SrvSetConsoleTitle in cmdline.cpp — no handle validation.
NTSTATUS Server::handleUserL2SetConsoleTitle()
{
auto& a = m_req.u.consoleMsgL2.SetConsoleTitle;
auto payload = readTrailingInput();
if (a.Unicode)
{
m_title.assign(reinterpret_cast<const wchar_t*>(payload.data()), payload.size() / sizeof(wchar_t));
}
else
{
const auto len = static_cast<int>(payload.size());
const auto wideLen = MultiByteToWideChar(m_outputCP, 0, reinterpret_cast<const char*>(payload.data()), len, nullptr, 0);
if (wideLen > 0)
{
m_title.resize(wideLen);
MultiByteToWideChar(m_outputCP, 0, reinterpret_cast<const char*>(payload.data()), len, m_title.data(), wideLen);
}
else
{
m_title.clear();
}
}
// Emit the title to the host as an OSC sequence.
vtAppendTitle(m_title);
vtFlush();
return STATUS_SUCCESS;
}

View File

@@ -0,0 +1,474 @@
#include "pch.h"
#include "Server.h"
#define HOST_CALL(expr) \
do { \
const auto _hr = (expr); \
if (FAILED(_hr)) return STATUS_UNSUCCESSFUL; \
} while (0)
// Helper: extract a wide string from a byte payload at the given offset.
static std::wstring extractWideString(const std::vector<uint8_t>& payload, size_t offset, USHORT byteLen)
{
if (offset + byteLen > payload.size() || byteLen % sizeof(WCHAR) != 0)
return {};
return { reinterpret_cast<const wchar_t*>(payload.data() + offset), byteLen / sizeof(WCHAR) };
}
// ============================================================================
// AliasStore implementation
// ============================================================================
void AliasStore::add(std::wstring_view exe, std::wstring_view source, std::wstring_view target)
{
auto& map = exes[std::wstring(exe)];
if (target.empty())
map.erase(std::wstring(source)); // Empty target = remove alias.
else
map[std::wstring(source)] = std::wstring(target);
// Clean up empty exe entries.
if (map.empty())
exes.erase(std::wstring(exe));
}
void AliasStore::remove(std::wstring_view exe, std::wstring_view source)
{
auto it = exes.find(std::wstring(exe));
if (it != exes.end())
{
it->second.erase(std::wstring(source));
if (it->second.empty())
exes.erase(it);
}
}
const std::wstring* AliasStore::find(std::wstring_view exe, std::wstring_view source) const
{
auto exeIt = exes.find(std::wstring(exe));
if (exeIt == exes.end())
return nullptr;
auto srcIt = exeIt->second.find(std::wstring(source));
if (srcIt == exeIt->second.end())
return nullptr;
return &srcIt->second;
}
void AliasStore::expunge(std::wstring_view exe)
{
exes.erase(std::wstring(exe));
}
// ============================================================================
// CommandHistoryStore implementation
// ============================================================================
void CommandHistory::add(std::wstring_view cmd)
{
if (!allowDuplicates)
{
// Remove any existing duplicate before adding.
std::erase_if(commands, [&](const auto& c) { return c == cmd; });
}
commands.emplace_back(cmd);
if (commands.size() > maxCommands)
commands.erase(commands.begin());
}
void CommandHistory::clear()
{
commands.clear();
}
CommandHistory& CommandHistoryStore::getOrCreate(std::wstring_view exe)
{
auto& h = exes[std::wstring(exe)];
if (h.maxCommands == 0)
{
h.maxCommands = defaultBufferSize;
h.allowDuplicates = !(flags & 1); // HISTORY_NO_DUP_FLAG = 1
}
return h;
}
void CommandHistoryStore::expunge(std::wstring_view exe)
{
auto it = exes.find(std::wstring(exe));
if (it != exes.end())
it->second.clear();
}
// ============================================================================
// L3 handler implementations
// ============================================================================
NTSTATUS Server::handleUserL3GetConsoleMouseInfo()
{
auto& a = m_req.u.consoleMsgL3.GetConsoleMouseInfo;
a.NumButtons = GetSystemMetrics(SM_CMOUSEBUTTONS);
return STATUS_SUCCESS;
}
NTSTATUS Server::handleUserL3GetConsoleFontSize()
{
auto& a = m_req.u.consoleMsgL3.GetConsoleFontSize;
auto* h = activateOutputBuffer(GENERIC_READ);
if (!h)
return STATUS_INVALID_HANDLE;
PTY_SCREEN_BUFFER_INFO info{};
HOST_CALL(m_host->GetScreenBufferInfo(&info));
// The OG only supports font index 0 in modern builds.
(void)a.FontIndex;
a.FontSize = info.FontSize;
return STATUS_SUCCESS;
}
NTSTATUS Server::handleUserL3GetConsoleCurrentFont()
{
auto& a = m_req.u.consoleMsgL3.GetCurrentConsoleFont;
auto* h = activateOutputBuffer(GENERIC_READ);
if (!h)
return STATUS_INVALID_HANDLE;
PTY_SCREEN_BUFFER_INFO info{};
HOST_CALL(m_host->GetScreenBufferInfo(&info));
a.FontIndex = info.FontIndex;
a.FontSize = info.FontSize;
a.FontFamily = info.FontFamily;
a.FontWeight = info.FontWeight;
memcpy(a.FaceName, info.FaceName, sizeof(a.FaceName));
return STATUS_SUCCESS;
}
NTSTATUS Server::handleUserL3SetConsoleDisplayMode()
{
auto& a = m_req.u.consoleMsgL3.SetConsoleDisplayMode;
auto* h = activateOutputBuffer(GENERIC_WRITE);
if (!h)
return STATUS_INVALID_HANDLE;
// Fullscreen is not supported. Return current buffer dimensions.
PTY_SCREEN_BUFFER_INFO info{};
HOST_CALL(m_host->GetScreenBufferInfo(&info));
a.ScreenBufferDimensions = info.Size;
return STATUS_SUCCESS;
}
NTSTATUS Server::handleUserL3GetConsoleDisplayMode()
{
auto& a = m_req.u.consoleMsgL3.GetConsoleDisplayMode;
a.ModeFlags = 0; // Always windowed.
return STATUS_SUCCESS;
}
// ============================================================================
// Alias handlers
// ============================================================================
NTSTATUS Server::handleUserL3AddConsoleAlias()
{
auto& a = m_req.u.consoleMsgL3.AddConsoleAliasW;
if (a.SourceLength > USHRT_MAX || a.TargetLength > USHRT_MAX || a.ExeLength > USHRT_MAX)
return STATUS_INVALID_PARAMETER;
// TODO: Handle !a.Unicode (ANSI) by converting to Unicode first.
auto payload = readTrailingInput();
const auto totalLen = static_cast<size_t>(a.SourceLength) + a.TargetLength + a.ExeLength;
if (payload.size() < totalLen)
return STATUS_INVALID_PARAMETER;
auto source = extractWideString(payload, 0, a.SourceLength);
auto target = extractWideString(payload, a.SourceLength, a.TargetLength);
auto exe = extractWideString(payload, a.SourceLength + a.TargetLength, a.ExeLength);
m_aliases.add(exe, source, target);
return STATUS_SUCCESS;
}
NTSTATUS Server::handleUserL3GetConsoleAlias()
{
auto& a = m_req.u.consoleMsgL3.GetConsoleAliasW;
// TODO: Handle !a.Unicode (ANSI).
auto payload = readTrailingInput();
const auto totalLen = static_cast<size_t>(a.SourceLength) + a.ExeLength;
if (payload.size() < totalLen)
return STATUS_INVALID_PARAMETER;
auto source = extractWideString(payload, 0, a.SourceLength);
auto exe = extractWideString(payload, a.SourceLength, a.ExeLength);
const auto* target = m_aliases.find(exe, source);
if (!target)
{
a.TargetLength = 0;
return STATUS_SUCCESS;
}
const auto bytes = static_cast<USHORT>(target->size() * sizeof(WCHAR));
const auto outputOffset = m_req.msgHeader.ApiDescriptorSize;
const auto outputCapacity = (m_req.Descriptor.OutputSize > outputOffset) ? (m_req.Descriptor.OutputSize - outputOffset) : 0u;
const auto toWrite = std::min(static_cast<ULONG>(bytes), outputCapacity);
if (toWrite > 0)
writeOutput(outputOffset, target->data(), toWrite);
a.TargetLength = bytes;
return STATUS_SUCCESS;
}
NTSTATUS Server::handleUserL3GetConsoleAliasesLength()
{
auto& a = m_req.u.consoleMsgL3.GetConsoleAliasesLengthW;
// TODO: Handle !a.Unicode (ANSI).
auto payload = readTrailingInput();
auto exe = extractWideString(payload, 0, static_cast<USHORT>(payload.size()));
ULONG totalBytes = 0;
auto exeIt = m_aliases.exes.find(exe);
if (exeIt != m_aliases.exes.end())
{
for (const auto& [src, tgt] : exeIt->second)
{
// Format: "source=target\0" (in WCHARs)
totalBytes += static_cast<ULONG>((src.size() + 1 + tgt.size() + 1) * sizeof(WCHAR));
}
}
a.AliasesLength = totalBytes;
return STATUS_SUCCESS;
}
NTSTATUS Server::handleUserL3GetConsoleAliasExesLength()
{
auto& a = m_req.u.consoleMsgL3.GetConsoleAliasExesLengthW;
// TODO: Handle !a.Unicode (ANSI).
ULONG totalBytes = 0;
for (const auto& [exe, _] : m_aliases.exes)
{
totalBytes += static_cast<ULONG>((exe.size() + 1) * sizeof(WCHAR)); // "exe\0"
}
a.AliasExesLength = totalBytes;
return STATUS_SUCCESS;
}
NTSTATUS Server::handleUserL3GetConsoleAliases()
{
auto& a = m_req.u.consoleMsgL3.GetConsoleAliasesW;
// TODO: Handle !a.Unicode (ANSI).
auto payload = readTrailingInput();
auto exe = extractWideString(payload, 0, static_cast<USHORT>(payload.size()));
const auto outputOffset = m_req.msgHeader.ApiDescriptorSize;
const auto outputCapacity = (m_req.Descriptor.OutputSize > outputOffset) ? (m_req.Descriptor.OutputSize - outputOffset) : 0u;
std::wstring buf;
auto exeIt = m_aliases.exes.find(exe);
if (exeIt != m_aliases.exes.end())
{
for (const auto& [src, tgt] : exeIt->second)
{
buf += src;
buf += L'=';
buf += tgt;
buf += L'\0';
}
}
const auto bytes = static_cast<ULONG>(buf.size() * sizeof(WCHAR));
const auto toWrite = std::min(bytes, outputCapacity);
if (toWrite > 0)
writeOutput(outputOffset, buf.data(), toWrite);
a.AliasesBufferLength = bytes;
return STATUS_SUCCESS;
}
NTSTATUS Server::handleUserL3GetConsoleAliasExes()
{
auto& a = m_req.u.consoleMsgL3.GetConsoleAliasExesW;
// TODO: Handle !a.Unicode (ANSI).
const auto outputOffset = m_req.msgHeader.ApiDescriptorSize;
const auto outputCapacity = (m_req.Descriptor.OutputSize > outputOffset) ? (m_req.Descriptor.OutputSize - outputOffset) : 0u;
std::wstring buf;
for (const auto& [exe, _] : m_aliases.exes)
{
buf += exe;
buf += L'\0';
}
const auto bytes = static_cast<ULONG>(buf.size() * sizeof(WCHAR));
const auto toWrite = std::min(bytes, outputCapacity);
if (toWrite > 0)
writeOutput(outputOffset, buf.data(), toWrite);
a.AliasExesBufferLength = bytes;
return STATUS_SUCCESS;
}
// ============================================================================
// Command history handlers
// ============================================================================
NTSTATUS Server::handleUserL3ExpungeConsoleCommandHistory()
{
// TODO: Handle !a.Unicode (ANSI).
auto payload = readTrailingInput();
auto exe = extractWideString(payload, 0, static_cast<USHORT>(payload.size()));
m_history.expunge(exe);
return STATUS_SUCCESS;
}
NTSTATUS Server::handleUserL3SetConsoleNumberOfCommands()
{
auto& a = m_req.u.consoleMsgL3.SetConsoleNumberOfCommandsW;
// TODO: Handle !a.Unicode (ANSI).
auto payload = readTrailingInput();
auto exe = extractWideString(payload, 0, static_cast<USHORT>(payload.size()));
auto& hist = m_history.getOrCreate(exe);
hist.maxCommands = a.NumCommands;
// Trim if current history exceeds new limit.
while (hist.commands.size() > hist.maxCommands)
hist.commands.erase(hist.commands.begin());
return STATUS_SUCCESS;
}
NTSTATUS Server::handleUserL3GetConsoleCommandHistoryLength()
{
auto& a = m_req.u.consoleMsgL3.GetConsoleCommandHistoryLengthW;
// TODO: Handle !a.Unicode (ANSI).
auto payload = readTrailingInput();
auto exe = extractWideString(payload, 0, static_cast<USHORT>(payload.size()));
ULONG totalBytes = 0;
auto it = m_history.exes.find(exe);
if (it != m_history.exes.end())
{
for (const auto& cmd : it->second.commands)
totalBytes += static_cast<ULONG>((cmd.size() + 1) * sizeof(WCHAR)); // "cmd\0"
}
a.CommandHistoryLength = totalBytes;
return STATUS_SUCCESS;
}
NTSTATUS Server::handleUserL3GetConsoleCommandHistory()
{
auto& a = m_req.u.consoleMsgL3.GetConsoleCommandHistoryW;
// TODO: Handle !a.Unicode (ANSI).
auto payload = readTrailingInput();
auto exe = extractWideString(payload, 0, static_cast<USHORT>(payload.size()));
const auto outputOffset = m_req.msgHeader.ApiDescriptorSize;
const auto outputCapacity = (m_req.Descriptor.OutputSize > outputOffset) ? (m_req.Descriptor.OutputSize - outputOffset) : 0u;
std::wstring buf;
auto it = m_history.exes.find(exe);
if (it != m_history.exes.end())
{
for (const auto& cmd : it->second.commands)
{
buf += cmd;
buf += L'\0';
}
}
const auto bytes = static_cast<ULONG>(buf.size() * sizeof(WCHAR));
const auto toWrite = std::min(bytes, outputCapacity);
if (toWrite > 0)
writeOutput(outputOffset, buf.data(), toWrite);
a.CommandBufferLength = bytes;
return STATUS_SUCCESS;
}
// ============================================================================
// Window, selection, process, history settings, font
// ============================================================================
NTSTATUS Server::handleUserL3GetConsoleWindow()
{
auto& a = m_req.u.consoleMsgL3.GetConsoleWindow;
HWND hwnd = nullptr;
if (m_host)
m_host->GetConsoleWindow(&hwnd);
a.hwnd = hwnd;
return STATUS_SUCCESS;
}
NTSTATUS Server::handleUserL3GetConsoleSelectionInfo()
{
auto& a = m_req.u.consoleMsgL3.GetConsoleSelectionInfo;
// Selection is not supported in PTY mode.
a.SelectionInfo = {};
return STATUS_SUCCESS;
}
NTSTATUS Server::handleUserL3GetConsoleProcessList()
{
auto& a = m_req.u.consoleMsgL3.GetConsoleProcessList;
const auto capacity = a.dwProcessCount;
std::vector<DWORD> pids;
pids.reserve(m_clients.size());
for (auto& c : m_clients)
pids.push_back(c->processId);
a.dwProcessCount = static_cast<DWORD>(pids.size());
if (capacity >= pids.size() && !pids.empty())
{
const auto writeOffset = m_req.msgHeader.ApiDescriptorSize;
writeOutput(writeOffset, pids.data(), static_cast<ULONG>(pids.size() * sizeof(DWORD)));
}
return STATUS_SUCCESS;
}
NTSTATUS Server::handleUserL3GetConsoleHistory()
{
auto& a = m_req.u.consoleMsgL3.GetConsoleHistory;
a.HistoryBufferSize = m_history.defaultBufferSize;
a.NumberOfHistoryBuffers = m_history.numberOfBuffers;
a.dwFlags = m_history.flags;
return STATUS_SUCCESS;
}
NTSTATUS Server::handleUserL3SetConsoleHistory()
{
auto& a = m_req.u.consoleMsgL3.SetConsoleHistory;
m_history.defaultBufferSize = a.HistoryBufferSize;
m_history.numberOfBuffers = a.NumberOfHistoryBuffers;
m_history.flags = a.dwFlags;
return STATUS_SUCCESS;
}
NTSTATUS Server::handleUserL3SetConsoleCurrentFont()
{
auto& a = m_req.u.consoleMsgL3.SetCurrentConsoleFont;
auto* h = activateOutputBuffer(GENERIC_WRITE);
if (!h)
return STATUS_INVALID_HANDLE;
// TODO: Font changes via SetScreenBufferInfo once we add font fields to the change struct.
(void)a;
return STATUS_SUCCESS;
}

221
src/conpty/Server.raw.cpp Normal file
View File

@@ -0,0 +1,221 @@
#include "pch.h"
#include "Server.h"
// ============================================================================
// Pending IO infrastructure
// ============================================================================
// Queue a read that will be retried when new input arrives.
NTSTATUS Server::pendRead(std::function<bool(Server&)> retry)
{
PendingIO pending;
pending.identifier = m_req.Descriptor.Identifier;
pending.process = m_req.Descriptor.Process;
pending.retry = std::move(retry);
m_pendingReads.push_back(std::move(pending));
return STATUS_NO_RESPONSE;
}
// Queue a write that will be retried when output is unpaused.
NTSTATUS Server::pendWrite(std::function<bool(Server&)> retry)
{
PendingIO pending;
pending.identifier = m_req.Descriptor.Identifier;
pending.process = m_req.Descriptor.Process;
pending.retry = std::move(retry);
m_pendingWrites.push_back(std::move(pending));
return STATUS_NO_RESPONSE;
}
// Send a completion for a specific pending IO.
void Server::completePendingIo(const LUID& identifier, NTSTATUS status, ULONG_PTR information)
{
CD_IO_COMPLETE completion{};
completion.Identifier = identifier;
completion.IoStatus.Status = status;
completion.IoStatus.Information = information;
completeIo(completion);
}
// Walk the pending-read queue and retry each. Remove any that succeed.
// Called when new input arrives (WriteInput, WriteConsoleInput, etc.).
//
// Analogous to OG WakeUpReadersWaitingForData → ConsoleNotifyWait.
// In the OG, only ONE reader is woken per input arrival (fSatisfyAll=FALSE).
// We follow the same pattern: stop after the first successful retry.
void Server::drainPendingInputReads()
{
auto it = m_pendingReads.begin();
while (it != m_pendingReads.end())
{
if (it->retry(*this))
{
it = m_pendingReads.erase(it);
break; // OG: only satisfy one reader at a time.
}
else
{
++it;
}
}
if (!m_input.hasData())
m_inputAvailableEvent.ResetEvent();
}
// Walk the pending-write queue and retry each. Remove any that succeed.
// Called when output is unpaused.
//
// Analogous to OG UnblockWriteConsole → ConsoleNotifyWait(OutputQueue, TRUE).
// In the OG, ALL pending writes are retried (fSatisfyAll=TRUE).
void Server::drainPendingOutputWrites()
{
auto it = m_pendingWrites.begin();
while (it != m_pendingWrites.end())
{
if (it->retry(*this))
{
it = m_pendingWrites.erase(it);
}
else
{
++it;
}
}
}
// Cancel all pending IOs for a disconnecting process.
//
// Analogous to OG FreeProcessData which calls ConsoleNotifyWaitBlock
// with fThreadDying=TRUE for each wait block owned by the process.
void Server::cancelPendingIOs(ULONG_PTR process)
{
auto cancelMatching = [&](std::deque<PendingIO>& queue) {
auto it = queue.begin();
while (it != queue.end())
{
if (it->process == process)
{
// Best-effort: complete with STATUS_CANCELLED.
CD_IO_COMPLETE completion{};
completion.Identifier = it->identifier;
completion.IoStatus.Status = STATUS_CANCELLED;
ioctl(IOCTL_CONDRV_COMPLETE_IO, &completion, sizeof(completion), nullptr, 0);
it = queue.erase(it);
}
else
{
++it;
}
}
};
cancelMatching(m_pendingReads);
cancelMatching(m_pendingWrites);
}
// ============================================================================
// Raw IO handlers
// ============================================================================
// Handles CONSOLE_IO_RAW_WRITE.
//
// OG path: ConsoleIoThread RAW_WRITE → SrvWriteConsole → DoWriteConsole.
// If output is paused (CONSOLE_SUSPENDED), the write is deferred on the
// OutputQueue and retried when UnblockWriteConsole is called.
NTSTATUS Server::handleRawWrite()
{
const auto size = m_req.Descriptor.InputSize;
// Read the payload upfront — the driver expects us to consume it.
auto buffer = std::make_shared<std::vector<uint8_t>>(size);
if (size > 0)
{
readInput(0, buffer->data(), size);
}
if (m_outputPaused)
{
// Capture the data and identifier for the retry lambda.
const auto id = m_req.Descriptor.Identifier;
return pendWrite([buffer, id](Server& self) -> bool {
if (self.m_outputPaused)
return false; // Still paused.
self.m_host->WriteUTF8({ reinterpret_cast<const char*>(buffer->data()), buffer->size() });
self.completePendingIo(id, STATUS_SUCCESS, static_cast<ULONG_PTR>(buffer->size()));
return true;
});
}
m_host->WriteUTF8({ reinterpret_cast<const char*>(buffer->data()), buffer->size() });
return STATUS_SUCCESS;
}
// Handles CONSOLE_IO_RAW_READ.
//
// OG path: ConsoleIoThread RAW_READ → SrvReadConsole → ReadChars → GetChar
// → ReadInputBuffer. If the input buffer is empty, a wait block is created
// on the ReadWaitQueue. When input arrives, WakeUpReadersWaitingForData
// walks the queue and retries each pending read.
NTSTATUS Server::handleRawRead()
{
const auto maxBytes = m_req.Descriptor.OutputSize;
// Try to satisfy immediately.
if (m_input.hasData() && maxBytes > 0)
{
std::vector<char> buf(maxBytes);
const auto n = m_input.readRawText(buf.data(), maxBytes);
if (n > 0)
{
writeOutput(0, buf.data(), static_cast<ULONG>(n));
if (!m_input.hasData())
m_inputAvailableEvent.ResetEvent();
// Complete out-of-band (the response carries write data).
completePendingIo(m_req.Descriptor.Identifier, STATUS_SUCCESS, static_cast<ULONG_PTR>(n));
return STATUS_NO_RESPONSE;
}
}
// No data — defer until input arrives.
const auto id = m_req.Descriptor.Identifier;
return pendRead([maxBytes, id](Server& self) -> bool {
if (!self.m_input.hasData())
return false;
std::vector<char> buf(maxBytes);
const auto n = self.m_input.readRawText(buf.data(), maxBytes);
if (n == 0)
return false;
// Write data back to the client's read buffer.
CD_IO_OPERATION op{};
op.Identifier = id;
op.Buffer.Offset = 0;
op.Buffer.Data = buf.data();
op.Buffer.Size = static_cast<ULONG>(n);
THROW_IF_NTSTATUS_FAILED(self.ioctl(IOCTL_CONDRV_WRITE_OUTPUT, &op, sizeof(op), nullptr, 0));
self.completePendingIo(id, STATUS_SUCCESS, static_cast<ULONG_PTR>(n));
return true;
});
}
// Handles CONSOLE_IO_RAW_FLUSH.
//
// OG path: SrvFlushConsoleInputBuffer → FlushInputBuffer.
// Clears the input buffer and resets the input-available event.
NTSTATUS Server::handleRawFlush()
{
auto* h = findHandle(m_req.Descriptor.Object, CONSOLE_INPUT_HANDLE, GENERIC_WRITE);
if (!h)
return STATUS_INVALID_HANDLE;
m_input.flush();
m_inputAvailableEvent.ResetEvent();
return STATUS_SUCCESS;
}

300
src/conpty/VtParser.cpp Normal file
View File

@@ -0,0 +1,300 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
#include "pch.h"
#include "VtParser.h"
bool VtParser::hasEscTimeout() const noexcept
{
return m_state == State::Esc;
}
VtParser::Stream VtParser::parse(std::string_view input) noexcept
{
return Stream{ this, input };
}
// Decode one UTF-8 codepoint from the input at the current offset.
// Advances m_off past the consumed bytes. Returns U+0000 at end-of-input.
char32_t VtParser::Stream::nextChar()
{
const auto* bytes = reinterpret_cast<const uint8_t*>(m_input.data());
const auto len = m_input.size();
if (m_off >= len)
return U'\0';
const auto b0 = bytes[m_off];
// ASCII fast path.
if (b0 < 0x80)
{
m_off++;
return static_cast<char32_t>(b0);
}
// Determine sequence length and initial bits.
size_t seqLen;
char32_t cp;
if ((b0 & 0xE0) == 0xC0) { seqLen = 2; cp = b0 & 0x1F; }
else if ((b0 & 0xF0) == 0xE0) { seqLen = 3; cp = b0 & 0x0F; }
else if ((b0 & 0xF8) == 0xF0) { seqLen = 4; cp = b0 & 0x07; }
else { m_off++; return U'\xFFFD'; } // Invalid lead byte.
if (m_off + seqLen > len)
{
// Incomplete codepoint at end of input — consume what we have.
m_off = len;
return U'\xFFFD';
}
for (size_t i = 1; i < seqLen; i++)
{
const auto cont = bytes[m_off + i];
if ((cont & 0xC0) != 0x80)
{
m_off += i;
return U'\xFFFD';
}
cp = (cp << 6) | (cont & 0x3F);
}
m_off += seqLen;
return cp;
}
bool VtParser::Stream::next(VtToken& out)
{
const auto* bytes = reinterpret_cast<const uint8_t*>(m_input.data());
const auto len = m_input.size();
// If the previous input ended with an escape character, and we're called
// with empty input (timeout fired), return the bare ESC.
if (len == 0 && m_parser->m_state == State::Esc)
{
m_parser->m_state = State::Ground;
out.type = VtToken::Esc;
out.ch = '\0';
return true;
}
while (m_off < len)
{
switch (m_parser->m_state)
{
case State::Ground:
{
const auto b = bytes[m_off];
if (b == 0x1B)
{
m_parser->m_state = State::Esc;
m_off++;
break; // Continue the outer loop to process the Esc state.
}
if (b < 0x20 || b == 0x7F)
{
m_off++;
out.type = VtToken::Ctrl;
out.ch = static_cast<char>(b);
return true;
}
// Bulk scan printable text (>= 0x20, != 0x7F, != 0x1B).
const auto beg = m_off;
do {
m_off++;
} while (m_off < len && bytes[m_off] >= 0x20 && bytes[m_off] != 0x7F && bytes[m_off] != 0x1B);
out.type = VtToken::Text;
out.payload = m_input.substr(beg, m_off - beg);
return true;
}
case State::Esc:
{
const auto ch = nextChar();
switch (ch)
{
case '[':
m_parser->m_state = State::Csi;
m_parser->m_csi.privateByte = '\0';
m_parser->m_csi.finalByte = '\0';
// Clear only params that were used last time.
while (m_parser->m_csi.paramCount > 0)
{
m_parser->m_csi.paramCount--;
m_parser->m_csi.params[m_parser->m_csi.paramCount] = 0;
}
break;
case ']':
m_parser->m_state = State::Osc;
break;
case 'O':
m_parser->m_state = State::Ss3;
break;
case 'P':
m_parser->m_state = State::Dcs;
break;
default:
m_parser->m_state = State::Ground;
out.type = VtToken::Esc;
// Truncate to char. For the sequences we care about this is always ASCII.
out.ch = static_cast<char>(ch);
return true;
}
break;
}
case State::Ss3:
{
m_parser->m_state = State::Ground;
const auto ch = nextChar();
out.type = VtToken::SS3;
out.ch = static_cast<char>(ch);
return true;
}
case State::Csi:
{
for (;;)
{
// Parse parameter digits.
if (m_parser->m_csi.paramCount < std::size(m_parser->m_csi.params))
{
auto& dst = m_parser->m_csi.params[m_parser->m_csi.paramCount];
while (m_off < len && bytes[m_off] >= '0' && bytes[m_off] <= '9')
{
const uint32_t add = bytes[m_off] - '0';
const uint32_t value = static_cast<uint32_t>(dst) * 10 + add;
dst = static_cast<uint16_t>(std::min(value, static_cast<uint32_t>(UINT16_MAX)));
m_off++;
}
}
else
{
// Overflow: skip digits.
while (m_off < len && bytes[m_off] >= '0' && bytes[m_off] <= '9')
m_off++;
}
// Need more data?
if (m_off >= len)
return false;
const auto c = bytes[m_off];
m_off++;
if (c >= 0x40 && c <= 0x7E)
{
// Final byte.
m_parser->m_state = State::Ground;
m_parser->m_csi.finalByte = static_cast<char>(c);
if (m_parser->m_csi.paramCount != 0 || m_parser->m_csi.params[0] != 0)
m_parser->m_csi.paramCount++;
out.type = VtToken::Csi;
out.csi = &m_parser->m_csi;
return true;
}
if (c == ';')
{
m_parser->m_csi.paramCount++;
}
else if (c >= '<' && c <= '?')
{
m_parser->m_csi.privateByte = static_cast<char>(c);
}
// else: intermediate bytes (0x20-0x2F) or unknown — silently skip.
}
}
case State::Osc:
case State::Dcs:
{
const auto beg = m_off;
std::string_view data;
bool partial;
for (;;)
{
// Scan for BEL (0x07) or ESC (0x1B) — potential terminators.
while (m_off < len && bytes[m_off] != 0x07 && bytes[m_off] != 0x1B)
m_off++;
data = m_input.substr(beg, m_off - beg);
partial = m_off >= len;
if (partial)
break;
const auto c = bytes[m_off];
m_off++;
if (c == 0x1B)
{
// ESC might start ST (ESC \). Check next byte.
if (m_off >= len)
{
// At end of input — save state for next chunk.
m_parser->m_state = (m_parser->m_state == State::Osc) ? State::OscEsc : State::DcsEsc;
partial = true;
break;
}
if (bytes[m_off] != '\\')
continue; // False alarm, not ST.
m_off++; // Consume the backslash.
}
// BEL or ESC \ — sequence is complete.
break;
}
const auto wasOsc = (m_parser->m_state == State::Osc);
if (!partial)
m_parser->m_state = State::Ground;
out.type = wasOsc ? VtToken::Osc : VtToken::Dcs;
out.payload = data;
out.partial = partial;
return true;
}
case State::OscEsc:
case State::DcsEsc:
{
// Previous chunk ended with ESC inside an OSC/DCS.
// Check if this chunk starts with '\' to complete the ST.
if (bytes[m_off] == '\\')
{
const auto wasOsc = (m_parser->m_state == State::OscEsc);
m_parser->m_state = State::Ground;
m_off++;
out.type = wasOsc ? VtToken::Osc : VtToken::Dcs;
out.payload = {};
out.partial = false;
return true;
}
else
{
// False alarm — the ESC was not a string terminator.
// Return it as partial payload and resume the string state.
const auto wasOsc = (m_parser->m_state == State::OscEsc);
m_parser->m_state = wasOsc ? State::Osc : State::Dcs;
out.type = wasOsc ? VtToken::Osc : VtToken::Dcs;
out.payload = "\x1b";
out.partial = true;
return true;
}
}
} // switch
} // while
return false;
}

134
src/conpty/VtParser.h Normal file
View File

@@ -0,0 +1,134 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
// A reusable VT tokenizer, ported from the Rust implementation in Microsoft Edit.
//
// The parser produces tokens from a UTF-8 byte stream. It handles chunked input
// correctly — if a sequence is split across two feed() calls, the parser buffers
// the incomplete prefix and completes it on the next call.
//
// Usage:
// VtParser parser;
// VtParser::Stream stream = parser.parse(input);
// VtToken token;
// while (stream.next(token)) {
// switch (token.type) { ... }
// }
#pragma once
#include <cstdint>
#include <string_view>
// A single CSI sequence, parsed for your convenience.
struct VtCsi
{
// The parameters of the CSI sequence.
uint16_t params[32]{};
// The number of parameters stored in params[].
size_t paramCount = 0;
// The private byte, if any. '\0' if none.
// The private byte is the first character right after the ESC [ sequence.
// It is usually a '?' or '<'.
char privateByte = '\0';
// The final byte of the CSI sequence.
// This is the last character of the sequence, e.g. 'm' or 'H'.
char finalByte = '\0';
};
struct VtToken
{
enum Type : uint8_t
{
// A bunch of text. Doesn't contain any control characters.
Text,
// A single control character, like backspace or return.
Ctrl,
// We encountered ESC x and this contains x (in `ch`).
Esc,
// We encountered ESC O x and this contains x (in `ch`).
SS3,
// A CSI sequence started with ESC [. See `csi`.
Csi,
// An OSC sequence started with ESC ]. May be partial (chunked).
Osc,
// A DCS sequence started with ESC P. May be partial (chunked).
Dcs,
};
Type type = Text;
// For Ctrl: the control byte itself.
// For Esc/SS3: the character after ESC / ESC O.
char ch = '\0';
// For Csi: pointer to the parser's Csi struct. Valid until the next next() call.
const VtCsi* csi = nullptr;
// For Text/Osc/Dcs: the string payload (points into the input buffer, zero-copy).
std::string_view payload;
// For Osc/Dcs: true if the sequence is incomplete (split across chunks).
bool partial = false;
};
class VtParser
{
public:
class Stream;
VtParser() = default;
// Returns true if the parser is in the middle of an ESC sequence,
// meaning the caller should apply a timeout before the next parse() call.
// If the timeout fires, call parse("") to flush the bare ESC.
bool hasEscTimeout() const noexcept;
// Begin parsing the given input. Returns a Stream that yields tokens.
// The returned Stream borrows from both `this` and `input` — do not
// modify either while the Stream is alive.
Stream parse(std::string_view input) noexcept;
private:
enum class State : uint8_t
{
Ground,
Esc,
Ss3,
Csi,
Osc,
Dcs,
OscEsc,
DcsEsc,
};
State m_state = State::Ground;
VtCsi m_csi;
};
// An iterator that yields VtTokens from a single parse() call.
// This is a "lending iterator" — the token references data owned by
// the parser and the input string_view.
class VtParser::Stream
{
public:
Stream(VtParser* parser, std::string_view input) noexcept
: m_parser(parser), m_input(input) {}
// The input being parsed.
std::string_view input() const noexcept { return m_input; }
// Current byte offset into the input.
size_t offset() const noexcept { return m_off; }
// True if all input has been consumed.
bool done() const noexcept { return m_off >= m_input.size(); }
// Get the next token. Returns false when no more complete tokens
// can be extracted (remaining bytes are an incomplete sequence).
bool next(VtToken& out);
// Decode and consume one UTF-8 codepoint. Returns '\0' at end.
char32_t nextChar();
private:
VtParser* m_parser;
std::string_view m_input;
size_t m_off = 0;
};

View File

@@ -0,0 +1,32 @@
<?xml version="1.0" encoding="utf-8"?>
<Project DefaultTargets="Build" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
<PropertyGroup Label="Globals">
<VCProjectVersion>18.0</VCProjectVersion>
<Keyword>Win32Proj</Keyword>
<ProjectGuid>{10715020-e347-4d4e-a8f2-d11e4bced6ec}</ProjectGuid>
<ProjectName>conptytest</ProjectName>
<RootNamespace>conptytest</RootNamespace>
<WindowsTargetPlatformVersion>10.0</WindowsTargetPlatformVersion>
</PropertyGroup>
<Import Project="$(SolutionDir)src\common.build.pre.props" />
<Import Project="$(SolutionDir)src\common.nugetversions.props" />
<ItemDefinitionGroup>
<ClCompile>
<AdditionalIncludeDirectories>$(OutDir)\conpty;%(AdditionalIncludeDirectories)</AdditionalIncludeDirectories>
<PrecompiledHeader>NotUsing</PrecompiledHeader>
</ClCompile>
<Link>
<SubSystem>Windows</SubSystem>
</Link>
</ItemDefinitionGroup>
<ItemGroup>
<ClCompile Include="main.cpp" />
</ItemGroup>
<ItemGroup>
<ProjectReference Include="..\conpty.vcxproj">
<Project>{23a66bb9-dccf-420c-b1a1-fa9ecfe7db65}</Project>
</ProjectReference>
</ItemGroup>
<Import Project="$(SolutionDir)src\common.build.post.props" />
<Import Project="$(SolutionDir)src\common.nugetversions.targets" />
</Project>

View File

@@ -0,0 +1,25 @@
<?xml version="1.0" encoding="utf-8"?>
<Project ToolsVersion="4.0" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
<ItemGroup>
<Filter Include="Source Files">
<UniqueIdentifier>{4FC737F1-C7A5-4376-A066-2A32D752A2FF}</UniqueIdentifier>
<Extensions>cpp;c;cc;cxx;c++;cppm;ixx;def;odl;idl;hpj;bat;asm;asmx</Extensions>
</Filter>
<Filter Include="Header Files">
<UniqueIdentifier>{93995380-89BD-4b04-88EB-625FBE52EBFB}</UniqueIdentifier>
<Extensions>h;hh;hpp;hxx;h++;hm;inl;inc;ipp;xsd</Extensions>
</Filter>
<Filter Include="Resource Files">
<UniqueIdentifier>{67DA6AB6-F800-4c08-8B7A-83BB121AAD01}</UniqueIdentifier>
<Extensions>rc;ico;cur;bmp;dlg;rc2;rct;bin;rgs;gif;jpg;jpeg;jpe;resx;tiff;tif;png;wav;mfcribbon-ms</Extensions>
</Filter>
</ItemGroup>
<ItemGroup>
<Natvis Include="$(SolutionDir)tools\ConsoleTypes.natvis" />
</ItemGroup>
<ItemGroup>
<ClCompile Include="main.cpp">
<Filter>Source Files</Filter>
</ClCompile>
</ItemGroup>
</Project>

View File

@@ -0,0 +1,589 @@
// A minimal Win32 terminal — Windows 95 conhost style.
// Fixed-size char32_t grid, GDI rendering, no Unicode shaping, no scrollback.
// Implements IPtyHost and uses VtParser to interpret output from the server.
#define NOMINMAX
#define WIN32_LEAN_AND_MEAN
#include <Windows.h>
#include <wil/com.h>
#include <wil/resource.h>
#include <algorithm>
#include <atomic>
#include <string>
#include <thread>
#include <conpty.h>
#include "../VtParser.h"
// ============================================================================
// Terminal grid
// ============================================================================
static constexpr SHORT COLS = 120;
static constexpr SHORT ROWS = 30;
static constexpr int CELL_W = 8;
static constexpr int CELL_H = 16;
struct Cell
{
char32_t ch = ' ';
WORD attr = FOREGROUND_RED | FOREGROUND_GREEN | FOREGROUND_BLUE;
};
struct TermState
{
Cell grid[ROWS][COLS]{};
SHORT cursorX = 0;
SHORT cursorY = 0;
WORD currentAttr = FOREGROUND_RED | FOREGROUND_GREEN | FOREGROUND_BLUE;
bool cursorVisible = true;
std::wstring title = L"conpty-test";
COLORREF colorTable[16] = {
0x000000, 0x800000, 0x008000, 0x808000, 0x000080, 0x800080, 0x008080, 0xC0C0C0,
0x808080, 0xFF0000, 0x00FF00, 0xFFFF00, 0x0000FF, 0xFF00FF, 0x00FFFF, 0xFFFFFF,
};
void scrollUp()
{
memmove(&grid[0], &grid[1], sizeof(Cell) * COLS * (ROWS - 1));
for (SHORT x = 0; x < COLS; x++)
grid[ROWS - 1][x] = Cell{};
}
void advanceCursor()
{
cursorX++;
if (cursorX >= COLS)
{
cursorX = 0;
cursorY++;
if (cursorY >= ROWS)
{
cursorY = ROWS - 1;
scrollUp();
}
}
}
void linefeed()
{
cursorY++;
if (cursorY >= ROWS)
{
cursorY = ROWS - 1;
scrollUp();
}
}
void putChar(char32_t ch)
{
if (cursorX < COLS && cursorY < ROWS)
{
grid[cursorY][cursorX] = { ch, currentAttr };
advanceCursor();
}
}
void eraseDisplay(int mode)
{
const Cell blank = { ' ', currentAttr };
if (mode == 0)
{
for (SHORT x = cursorX; x < COLS; x++) grid[cursorY][x] = blank;
for (SHORT y = cursorY + 1; y < ROWS; y++)
for (SHORT x = 0; x < COLS; x++) grid[y][x] = blank;
}
else if (mode == 1)
{
for (SHORT y = 0; y < cursorY; y++)
for (SHORT x = 0; x < COLS; x++) grid[y][x] = blank;
for (SHORT x = 0; x <= cursorX && x < COLS; x++) grid[cursorY][x] = blank;
}
else if (mode == 2 || mode == 3)
{
for (SHORT y = 0; y < ROWS; y++)
for (SHORT x = 0; x < COLS; x++)
grid[y][x] = blank;
}
}
void eraseLine(int mode)
{
const Cell blank = { ' ', currentAttr };
SHORT start = 0, end = COLS;
if (mode == 0) start = cursorX;
else if (mode == 1) end = cursorX + 1;
for (SHORT x = start; x < end; x++)
grid[cursorY][x] = blank;
}
};
// ============================================================================
// Globals
// ============================================================================
static TermState g_term;
static VtParser g_vtParser;
static HWND g_hwnd = nullptr;
static wil::com_ptr<IPtyServer> g_server;
static CRITICAL_SECTION g_lock;
static void invalidate()
{
if (g_hwnd)
InvalidateRect(g_hwnd, nullptr, FALSE);
}
// ============================================================================
// VT output interpreter
// ============================================================================
static COLORREF attrToFg(WORD attr)
{
return g_term.colorTable[attr & 0x0F];
}
static COLORREF attrToBg(WORD attr)
{
return g_term.colorTable[(attr >> 4) & 0x0F];
}
static void parseSGR(const VtCsi& csi)
{
if (csi.paramCount == 0)
{
g_term.currentAttr = FOREGROUND_RED | FOREGROUND_GREEN | FOREGROUND_BLUE;
return;
}
// VT→Console color index: VT uses RGB bit order, Console uses BGR.
static constexpr WORD vtToConsole[] = { 0, 4, 2, 6, 1, 5, 3, 7 };
for (size_t i = 0; i < csi.paramCount; i++)
{
const auto p = csi.params[i];
switch (p)
{
case 0: g_term.currentAttr = FOREGROUND_RED | FOREGROUND_GREEN | FOREGROUND_BLUE; break;
case 1: g_term.currentAttr |= FOREGROUND_INTENSITY; break;
case 7: g_term.currentAttr |= COMMON_LVB_REVERSE_VIDEO; break;
case 22: g_term.currentAttr &= ~FOREGROUND_INTENSITY; break;
case 27: g_term.currentAttr &= ~COMMON_LVB_REVERSE_VIDEO; break;
case 39: g_term.currentAttr = (g_term.currentAttr & ~0x0F) | 0x07; break;
case 49: g_term.currentAttr &= ~0xF0; break;
default:
if (p >= 30 && p <= 37)
g_term.currentAttr = (g_term.currentAttr & ~0x07) | vtToConsole[p - 30];
else if (p >= 40 && p <= 47)
g_term.currentAttr = (g_term.currentAttr & ~0x70) | (vtToConsole[p - 40] << 4);
else if (p >= 90 && p <= 97)
g_term.currentAttr = (g_term.currentAttr & ~0x0F) | vtToConsole[p - 90] | FOREGROUND_INTENSITY;
else if (p >= 100 && p <= 107)
g_term.currentAttr = (g_term.currentAttr & ~0xF0) | (vtToConsole[p - 100] << 4) | BACKGROUND_INTENSITY;
break;
}
}
}
static void handleCsiOutput(const VtCsi& csi)
{
const auto p0 = (csi.paramCount >= 1) ? csi.params[0] : 0;
const auto p1 = (csi.paramCount >= 2) ? csi.params[1] : 0;
switch (csi.finalByte)
{
case 'A': g_term.cursorY = std::max<SHORT>(0, g_term.cursorY - std::max<SHORT>(1, (SHORT)p0)); break;
case 'B': g_term.cursorY = std::min<SHORT>(ROWS - 1, g_term.cursorY + std::max<SHORT>(1, (SHORT)p0)); break;
case 'C': g_term.cursorX = std::min<SHORT>(COLS - 1, g_term.cursorX + std::max<SHORT>(1, (SHORT)p0)); break;
case 'D': g_term.cursorX = std::max<SHORT>(0, g_term.cursorX - std::max<SHORT>(1, (SHORT)p0)); break;
case 'H':
case 'f':
g_term.cursorY = std::clamp<SHORT>(p0 > 0 ? (SHORT)(p0 - 1) : 0, 0, ROWS - 1);
g_term.cursorX = std::clamp<SHORT>(p1 > 0 ? (SHORT)(p1 - 1) : 0, 0, COLS - 1);
break;
case 'J': g_term.eraseDisplay(p0); break;
case 'K': g_term.eraseLine(p0); break;
case 'm': parseSGR(csi); break;
case 'h':
if (csi.privateByte == '?' && p0 == 25) g_term.cursorVisible = true;
break;
case 'l':
if (csi.privateByte == '?' && p0 == 25) g_term.cursorVisible = false;
break;
default: break;
}
}
static void processOutput(std::string_view text)
{
auto stream = g_vtParser.parse(text);
VtToken token;
while (stream.next(token))
{
switch (token.type)
{
case VtToken::Text:
{
const auto* bytes = reinterpret_cast<const uint8_t*>(token.payload.data());
size_t i = 0;
while (i < token.payload.size())
{
uint32_t cp;
size_t seqLen;
const auto b = bytes[i];
if (b < 0x80) { cp = b; seqLen = 1; }
else if ((b & 0xE0) == 0xC0) { cp = b & 0x1F; seqLen = 2; }
else if ((b & 0xF0) == 0xE0) { cp = b & 0x0F; seqLen = 3; }
else if ((b & 0xF8) == 0xF0) { cp = b & 0x07; seqLen = 4; }
else { i++; continue; }
if (i + seqLen > token.payload.size()) break;
for (size_t j = 1; j < seqLen; j++)
cp = (cp << 6) | (bytes[i + j] & 0x3F);
i += seqLen;
g_term.putChar(cp);
}
break;
}
case VtToken::Ctrl:
switch (token.ch)
{
case '\r': g_term.cursorX = 0; break;
case '\n': g_term.linefeed(); break;
case '\b': if (g_term.cursorX > 0) g_term.cursorX--; break;
case '\t': g_term.cursorX = std::min<SHORT>(COLS - 1, (g_term.cursorX + 8) & ~7); break;
case '\a': MessageBeep(MB_OK); break;
default: break;
}
break;
case VtToken::Csi:
handleCsiOutput(*token.csi);
break;
case VtToken::Osc:
if (!token.partial && token.payload.size() >= 2 &&
(token.payload[0] == '0' || token.payload[0] == '2') && token.payload[1] == ';')
{
auto data = token.payload.substr(2);
const auto wLen = MultiByteToWideChar(CP_UTF8, 0, data.data(), (int)data.size(), nullptr, 0);
if (wLen > 0)
{
g_term.title.resize(wLen);
MultiByteToWideChar(CP_UTF8, 0, data.data(), (int)data.size(), g_term.title.data(), wLen);
if (g_hwnd) SetWindowTextW(g_hwnd, g_term.title.c_str());
}
}
break;
default: break;
}
}
invalidate();
}
// ============================================================================
// IPtyHost implementation
// ============================================================================
struct TestHost : IPtyHost
{
HRESULT QueryInterface(const IID& riid, void** ppvObject) override
{
if (!ppvObject) return E_POINTER;
if (riid == __uuidof(IPtyHost) || riid == __uuidof(IUnknown))
{
*ppvObject = static_cast<IPtyHost*>(this);
AddRef();
return S_OK;
}
*ppvObject = nullptr;
return E_NOINTERFACE;
}
ULONG AddRef() override { return m_refCount.fetch_add(1) + 1; }
ULONG Release() override
{
const auto c = m_refCount.fetch_sub(1) - 1;
if (c == 0) delete this;
return c;
}
void WriteUTF8(PTY_UTF8_STRING text) override
{
EnterCriticalSection(&g_lock);
processOutput({ text.data, text.length });
LeaveCriticalSection(&g_lock);
}
void WriteUTF16(PTY_UTF16_STRING text) override
{
const auto len = WideCharToMultiByte(CP_UTF8, 0, text.data, (int)text.length, nullptr, 0, nullptr, nullptr);
if (len > 0)
{
std::string utf8(len, '\0');
WideCharToMultiByte(CP_UTF8, 0, text.data, (int)text.length, utf8.data(), len, nullptr, nullptr);
EnterCriticalSection(&g_lock);
processOutput(utf8);
LeaveCriticalSection(&g_lock);
}
}
HRESULT CreateBuffer(void** buffer) override { *buffer = nullptr; return E_NOTIMPL; }
HRESULT ReleaseBuffer(void*) override { return S_OK; }
HRESULT ActivateBuffer(void*) override { return S_OK; }
HRESULT GetScreenBufferInfo(PTY_SCREEN_BUFFER_INFO* info) override
{
EnterCriticalSection(&g_lock);
*info = {};
info->Size = { COLS, ROWS };
info->CursorPosition = { g_term.cursorX, g_term.cursorY };
info->Attributes = g_term.currentAttr;
info->Window = { 0, 0, COLS - 1, ROWS - 1 };
info->MaximumWindowSize = { COLS, ROWS };
info->CursorSize = 25;
info->CursorVisible = g_term.cursorVisible;
info->FontSize = { CELL_W, CELL_H };
info->FontFamily = FF_MODERN | FIXED_PITCH;
info->FontWeight = FW_NORMAL;
wcscpy_s(info->FaceName, L"Terminal");
memcpy(info->ColorTable, g_term.colorTable, sizeof(g_term.colorTable));
LeaveCriticalSection(&g_lock);
return S_OK;
}
HRESULT SetScreenBufferInfo(const PTY_SCREEN_BUFFER_INFO_CHANGE* change) override
{
EnterCriticalSection(&g_lock);
if (change->CursorPosition)
{
g_term.cursorX = std::clamp(change->CursorPosition->X, SHORT(0), SHORT(COLS - 1));
g_term.cursorY = std::clamp(change->CursorPosition->Y, SHORT(0), SHORT(ROWS - 1));
}
if (change->Attributes) g_term.currentAttr = *change->Attributes;
if (change->CursorVisible) g_term.cursorVisible = *change->CursorVisible != 0;
if (change->ColorTable) memcpy(g_term.colorTable, change->ColorTable, sizeof(g_term.colorTable));
invalidate();
LeaveCriticalSection(&g_lock);
return S_OK;
}
HRESULT ReadBuffer(COORD pos, LONG count, PTY_CHAR_INFO* infos) override
{
EnterCriticalSection(&g_lock);
for (LONG i = 0; i < count; i++)
{
SHORT x = pos.X + static_cast<SHORT>(i);
SHORT y = pos.Y;
while (x >= COLS && y < ROWS) { x -= COLS; y++; }
if (y >= 0 && y < ROWS && x >= 0 && x < COLS)
{
infos[i].Char = static_cast<WCHAR>(g_term.grid[y][x].ch <= 0xFFFF ? g_term.grid[y][x].ch : L'?');
infos[i].Attributes = g_term.grid[y][x].attr;
}
else
{
infos[i].Char = L' ';
infos[i].Attributes = 0x07;
}
}
LeaveCriticalSection(&g_lock);
return S_OK;
}
HRESULT GetConsoleWindow(HWND* hwnd) override { *hwnd = g_hwnd; return S_OK; }
private:
std::atomic<ULONG> m_refCount{ 1 };
};
// ============================================================================
// GDI rendering
// ============================================================================
static void paint(HWND hwnd, HDC hdc)
{
EnterCriticalSection(&g_lock);
RECT rc;
GetClientRect(hwnd, &rc);
const auto bgBrush = CreateSolidBrush(g_term.colorTable[0]);
FillRect(hdc, &rc, bgBrush);
DeleteObject(bgBrush);
const auto font = CreateFontW(
CELL_H, CELL_W, 0, 0, FW_NORMAL, FALSE, FALSE, FALSE,
DEFAULT_CHARSET, OUT_DEFAULT_PRECIS, CLIP_DEFAULT_PRECIS,
NONANTIALIASED_QUALITY, FIXED_PITCH | FF_MODERN, L"Terminal");
const auto oldFont = SelectObject(hdc, font);
SetBkMode(hdc, OPAQUE);
for (SHORT y = 0; y < ROWS; y++)
{
for (SHORT x = 0; x < COLS; x++)
{
const auto& cell = g_term.grid[y][x];
COLORREF fg, bg;
if (cell.attr & COMMON_LVB_REVERSE_VIDEO)
{
fg = attrToBg(cell.attr);
bg = attrToFg(cell.attr);
}
else
{
fg = attrToFg(cell.attr);
bg = attrToBg(cell.attr);
}
SetTextColor(hdc, fg);
SetBkColor(hdc, bg);
wchar_t wch = static_cast<wchar_t>(cell.ch <= 0xFFFF ? cell.ch : L'?');
if (wch < 0x20) wch = L' ';
TextOutW(hdc, x * CELL_W, y * CELL_H, &wch, 1);
}
}
if (g_term.cursorVisible)
{
RECT cur = { g_term.cursorX * CELL_W, g_term.cursorY * CELL_H + CELL_H - 2,
g_term.cursorX * CELL_W + CELL_W, g_term.cursorY * CELL_H + CELL_H };
const auto curBrush = CreateSolidBrush(g_term.colorTable[7]);
FillRect(hdc, &cur, curBrush);
DeleteObject(curBrush);
}
SelectObject(hdc, oldFont);
DeleteObject(font);
LeaveCriticalSection(&g_lock);
}
// ============================================================================
// Window procedure
// ============================================================================
static LRESULT CALLBACK WndProc(HWND hwnd, UINT msg, WPARAM wParam, LPARAM lParam)
{
switch (msg)
{
case WM_PAINT:
{
PAINTSTRUCT ps;
const auto hdc = BeginPaint(hwnd, &ps);
paint(hwnd, hdc);
EndPaint(hwnd, &ps);
return 0;
}
case WM_CHAR:
{
wchar_t wch = static_cast<wchar_t>(wParam);
char utf8[4];
const auto len = WideCharToMultiByte(CP_UTF8, 0, &wch, 1, utf8, sizeof(utf8), nullptr, nullptr);
if (len > 0 && g_server)
{
PTY_UTF8_STRING input{ utf8, static_cast<SIZE_T>(len) };
g_server->WriteUTF8(input);
}
return 0;
}
case WM_KEYDOWN:
{
const char* seq = nullptr;
switch (wParam)
{
case VK_UP: seq = "\x1b[A"; break;
case VK_DOWN: seq = "\x1b[B"; break;
case VK_RIGHT: seq = "\x1b[C"; break;
case VK_LEFT: seq = "\x1b[D"; break;
case VK_HOME: seq = "\x1b[H"; break;
case VK_END: seq = "\x1b[F"; break;
case VK_INSERT: seq = "\x1b[2~"; break;
case VK_DELETE: seq = "\x1b[3~"; break;
case VK_PRIOR: seq = "\x1b[5~"; break;
case VK_NEXT: seq = "\x1b[6~"; break;
case VK_F1: seq = "\x1bOP"; break;
case VK_F2: seq = "\x1bOQ"; break;
case VK_F3: seq = "\x1bOR"; break;
case VK_F4: seq = "\x1bOS"; break;
case VK_F5: seq = "\x1b[15~"; break;
case VK_F6: seq = "\x1b[17~"; break;
case VK_F7: seq = "\x1b[18~"; break;
case VK_F8: seq = "\x1b[19~"; break;
case VK_F9: seq = "\x1b[20~"; break;
case VK_F10: seq = "\x1b[21~"; break;
case VK_F11: seq = "\x1b[23~"; break;
case VK_F12: seq = "\x1b[24~"; break;
default: return DefWindowProcW(hwnd, msg, wParam, lParam);
}
if (seq && g_server)
{
PTY_UTF8_STRING input{ seq, strlen(seq) };
g_server->WriteUTF8(input);
}
return 0;
}
case WM_DESTROY:
PostQuitMessage(0);
return 0;
default:
return DefWindowProcW(hwnd, msg, wParam, lParam);
}
}
// ============================================================================
// Entry point
// ============================================================================
int WINAPI wWinMain(HINSTANCE hInstance, HINSTANCE, LPWSTR, int nCmdShow)
{
InitializeCriticalSection(&g_lock);
WNDCLASSEXW wc{};
wc.cbSize = sizeof(wc);
wc.style = CS_HREDRAW | CS_VREDRAW;
wc.lpfnWndProc = WndProc;
wc.hInstance = hInstance;
wc.hCursor = LoadCursor(nullptr, IDC_ARROW);
wc.lpszClassName = L"ConPtyTestWindow";
RegisterClassExW(&wc);
RECT wr = { 0, 0, COLS * CELL_W, ROWS * CELL_H };
AdjustWindowRect(&wr, WS_OVERLAPPEDWINDOW, FALSE);
g_hwnd = CreateWindowExW(
0, L"ConPtyTestWindow", L"conpty-test",
WS_OVERLAPPEDWINDOW & ~(WS_THICKFRAME | WS_MAXIMIZEBOX),
CW_USEDEFAULT, CW_USEDEFAULT,
wr.right - wr.left, wr.bottom - wr.top,
nullptr, nullptr, hInstance, nullptr);
ShowWindow(g_hwnd, nCmdShow);
UpdateWindow(g_hwnd);
THROW_IF_FAILED(PtyCreateServer(IID_PPV_ARGS(g_server.addressof())));
THROW_IF_FAILED(g_server->SetHost(new TestHost()));
wil::unique_process_information pi;
THROW_IF_FAILED(g_server->CreateProcessW(
nullptr, _wcsdup(L"cmd.exe"),
nullptr, nullptr, FALSE, 0, nullptr, nullptr,
pi.addressof()));
// Run the console server on a background thread.
// It blocks in its message loop until all clients disconnect.
std::thread serverThread([&] {
g_server->Run();
PostMessage(g_hwnd, WM_CLOSE, 0, 0);
});
serverThread.detach();
MSG msg;
while (GetMessageW(&msg, nullptr, 0, 0))
{
TranslateMessage(&msg);
DispatchMessageW(&msg);
}
DeleteCriticalSection(&g_lock);
return 0;
}

115
src/conpty/conpty.idl Normal file
View File

@@ -0,0 +1,115 @@
import "unknwnbase.idl";
typedef struct _STARTUPINFOEXW* LPSTARTUPINFOEXW;
typedef struct _STARTUPINFOW* LPSTARTUPINFOW;
typedef struct _PROCESS_INFORMATION* LPPROCESS_INFORMATION;
typedef struct PTY_UTF8_STRING
{
[size_is(length)] const char* data;
SIZE_T length;
} PTY_UTF8_STRING;
typedef struct PTY_UTF16_STRING
{
[size_is(length)] const wchar_t* data;
SIZE_T length;
} PTY_UTF16_STRING;
// Screen buffer info. Mirrors CONSOLE_SCREEN_BUFFER_INFOEX.
// Passed between Server and Host for Get/SetConsoleScreenBufferInfoEx.
typedef struct PTY_SCREEN_BUFFER_INFO
{
COORD Size;
COORD CursorPosition;
WORD Attributes;
SMALL_RECT Window; // In the OG, this is the viewport rect within the buffer.
COORD MaximumWindowSize;
WORD PopupAttributes;
BOOLEAN FullscreenSupported;
COLORREF ColorTable[16];
ULONG CursorSize; // 1..100
BOOLEAN CursorVisible;
// Font info.
ULONG FontIndex;
COORD FontSize;
ULONG FontFamily;
ULONG FontWeight;
WCHAR FaceName[32]; // LF_FACESIZE
} PTY_SCREEN_BUFFER_INFO;
// Partial update to screen buffer info. Non-null pointers indicate which
// fields should be changed. Mirrors the spec's CONSRV_INFO_CHANGE.
typedef struct PTY_SCREEN_BUFFER_INFO_CHANGE
{
[unique] COORD* Size;
[unique] COORD* CursorPosition;
[unique] WORD* Attributes;
[unique] SMALL_RECT* Window;
[unique] WORD* PopupAttributes;
[unique] COLORREF* ColorTable; // Always 16 entries if non-null.
[unique] ULONG* CursorSize;
[unique] BOOLEAN* CursorVisible;
} PTY_SCREEN_BUFFER_INFO_CHANGE;
// A single cell in a console screen buffer.
// Binary-compatible with CHAR_INFO.
typedef struct PTY_CHAR_INFO
{
WCHAR Char;
WORD Attributes;
} PTY_CHAR_INFO;
[uuid(e9b4897e-19f8-4833-af28-f777deeba7e6), object, local, pointer_default(unique)]
interface IPtyHost : IUnknown
{
// Output: the server sends VT or text to the host for rendering.
void WriteUTF8(PTY_UTF8_STRING text);
void WriteUTF16(PTY_UTF16_STRING text);
// Screen buffer lifecycle.
// Returns an opaque buffer ID. NULL = main buffer (never explicitly created).
HRESULT CreateBuffer([out] void** buffer);
HRESULT ReleaseBuffer([in] void* buffer);
HRESULT ActivateBuffer([in] void* buffer);
// Query the full state of the currently active screen buffer.
HRESULT GetScreenBufferInfo([out] PTY_SCREEN_BUFFER_INFO* info);
// Apply a partial change to the currently active screen buffer.
HRESULT SetScreenBufferInfo([in] const PTY_SCREEN_BUFFER_INFO_CHANGE* change);
// Read cells from the active buffer. The host fills `infos` with
// `count` cells starting at column pos.X on row pos.Y.
// Coordinates are buffer-relative (0-indexed).
HRESULT ReadBuffer([in] COORD pos, [in] LONG count, [out, size_is(count)] PTY_CHAR_INFO* infos);
// Returns the console window handle. Used for GetConsoleWindow.
HRESULT GetConsoleWindow([out] HWND* hwnd);
}
[uuid(9a727f67-09bd-429b-8f73-766a718070f0), object, local, pointer_default(unique)]
interface IPtyServer : IUnknown
{
HRESULT SetHost(IPtyHost* host);
HRESULT WriteUTF8(PTY_UTF8_STRING input);
HRESULT WriteUTF16(PTY_UTF16_STRING input);
HRESULT Run();
// Identical to CreateProcessW(), except for the absence of lpStartupInfo.
HRESULT CreateProcessW(
[in] LPCWSTR lpApplicationName, // NOTE: optional
[in, out] LPWSTR lpCommandLine, // NOTE: optional
[in] LPSECURITY_ATTRIBUTES lpProcessAttributes, // NOTE: optional
[in] LPSECURITY_ATTRIBUTES lpThreadAttributes, // NOTE: optional
[in] BOOL bInheritHandles,
[in] DWORD dwCreationFlags,
[in] LPVOID lpEnvironment, // NOTE: optional
[in] LPCWSTR lpCurrentDirectory, // NOTE: optional
[out] LPPROCESS_INFORMATION lpProcessInformation
);
}
cpp_quote("HRESULT WINAPI PtyCreateServer(REFIID riid, void** server);")

61
src/conpty/conpty.vcxproj Normal file
View File

@@ -0,0 +1,61 @@
<?xml version="1.0" encoding="utf-8"?>
<Project DefaultTargets="Build" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
<PropertyGroup Label="Globals">
<VCProjectVersion>18.0</VCProjectVersion>
<Keyword>Win32Proj</Keyword>
<ProjectGuid>{23a66bb9-dccf-420c-b1a1-fa9ecfe7db65}</ProjectGuid>
<ProjectName>conpty</ProjectName>
<RootNamespace>conpty</RootNamespace>
<WindowsTargetPlatformVersion>10.0</WindowsTargetPlatformVersion>
<ConfigurationType>StaticLibrary</ConfigurationType>
<TargetName>$(ProjectName)_v2</TargetName>
</PropertyGroup>
<Import Project="$(SolutionDir)src\common.build.pre.props" />
<Import Project="$(SolutionDir)src\common.nugetversions.props" />
<PropertyGroup>
<OutDir>$(SolutionDir)bin\$(Platform)\$(Configuration)\$(ProjectName)\</OutDir>
</PropertyGroup>
<ItemDefinitionGroup>
<ClCompile>
<AdditionalIncludeDirectories>$(OutDir);%(AdditionalIncludeDirectories)</AdditionalIncludeDirectories>
<PrecompiledHeaderFile>pch.h</PrecompiledHeaderFile>
</ClCompile>
</ItemDefinitionGroup>
<ItemGroup>
<ClCompile Include="pch.cpp">
<PrecompiledHeader>Create</PrecompiledHeader>
</ClCompile>
<ClCompile Include="InputBuffer.cpp" />
<ClCompile Include="Server.cpp" />
<ClCompile Include="Server.handles.cpp" />
<ClCompile Include="Server.msg.cpp" />
<ClCompile Include="Server.msg.l1.cpp" />
<ClCompile Include="Server.msg.l2.cpp" />
<ClCompile Include="Server.msg.l3.cpp" />
<ClCompile Include="Server.raw.cpp" />
<ClCompile Include="VtParser.cpp" />
</ItemGroup>
<ItemGroup>
<Midl Include="conpty.idl">
<!--
AKA: /out $(OutDir) /h %(Filename).h /dlldata NUL /iid NUL /proxy NUL /char unsigned /cstruct_out /prefix all VTbl_
...but MSVC gets antsy about that. It already passes some of these on its own.
-->
<OutputDirectory>$(OutDir)</OutputDirectory>
<HeaderFileName>%(Filename).h</HeaderFileName>
<DllDataFileName>nul</DllDataFileName>
<ProxyFileName>nul</ProxyFileName>
<DefaultCharType>Unsigned</DefaultCharType>
<InterfaceIdentifierFileName>nul</InterfaceIdentifierFileName>
<AdditionalOptions>/cstruct_out /prefix all VTbl_ %(AdditionalOptions)</AdditionalOptions>
</Midl>
</ItemGroup>
<ItemGroup>
<ClInclude Include="pch.h" />
<ClInclude Include="InputBuffer.h" />
<ClInclude Include="Server.h" />
<ClInclude Include="VtParser.h" />
</ItemGroup>
<Import Project="$(SolutionDir)src\common.build.post.props" />
<Import Project="$(SolutionDir)src\common.nugetversions.targets" />
</Project>

View File

@@ -0,0 +1,73 @@
<?xml version="1.0" encoding="utf-8"?>
<Project ToolsVersion="4.0" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
<ItemGroup>
<Filter Include="Source Files">
<UniqueIdentifier>{4FC737F1-C7A5-4376-A066-2A32D752A2FF}</UniqueIdentifier>
<Extensions>cpp;c;cc;cxx;c++;cppm;ixx;def;odl;idl;hpj;bat;asm;asmx</Extensions>
</Filter>
<Filter Include="Header Files">
<UniqueIdentifier>{93995380-89BD-4b04-88EB-625FBE52EBFB}</UniqueIdentifier>
<Extensions>h;hh;hpp;hxx;h++;hm;inl;inc;ipp;xsd</Extensions>
</Filter>
<Filter Include="Resource Files">
<UniqueIdentifier>{67DA6AB6-F800-4c08-8B7A-83BB121AAD01}</UniqueIdentifier>
<Extensions>rc;ico;cur;bmp;dlg;rc2;rct;bin;rgs;gif;jpg;jpeg;jpe;resx;tiff;tif;png;wav;mfcribbon-ms</Extensions>
</Filter>
</ItemGroup>
<ItemGroup>
<Natvis Include="$(SolutionDir)tools\ConsoleTypes.natvis" />
<Natvis Include="$(MSBuildThisFileDirectory)..\..\natvis\wil.natvis" />
<Natvis Include="$(MSBuildThisFileDirectory)..\..\natvis\wil.natstepfilter" />
</ItemGroup>
<ItemGroup>
<ClCompile Include="Server.cpp">
<Filter>Source Files</Filter>
</ClCompile>
<ClCompile Include="Server.handles.cpp">
<Filter>Source Files</Filter>
</ClCompile>
<ClCompile Include="pch.cpp">
<Filter>Source Files</Filter>
</ClCompile>
<ClCompile Include="Server.raw.cpp">
<Filter>Source Files</Filter>
</ClCompile>
<ClCompile Include="Server.msg.l1.cpp">
<Filter>Source Files</Filter>
</ClCompile>
<ClCompile Include="Server.msg.l2.cpp">
<Filter>Source Files</Filter>
</ClCompile>
<ClCompile Include="Server.msg.l3.cpp">
<Filter>Source Files</Filter>
</ClCompile>
<ClCompile Include="Server.msg.cpp">
<Filter>Source Files</Filter>
</ClCompile>
<ClCompile Include="InputBuffer.cpp">
<Filter>Source Files</Filter>
</ClCompile>
<ClCompile Include="VtParser.cpp">
<Filter>Source Files</Filter>
</ClCompile>
</ItemGroup>
<ItemGroup>
<Midl Include="conpty.idl">
<Filter>Source Files</Filter>
</Midl>
</ItemGroup>
<ItemGroup>
<ClInclude Include="Server.h">
<Filter>Header Files</Filter>
</ClInclude>
<ClInclude Include="pch.h">
<Filter>Header Files</Filter>
</ClInclude>
<ClInclude Include="InputBuffer.h">
<Filter>Header Files</Filter>
</ClInclude>
<ClInclude Include="VtParser.h">
<Filter>Header Files</Filter>
</ClInclude>
</ItemGroup>
</Project>

1
src/conpty/pch.cpp Normal file
View File

@@ -0,0 +1 @@
#include "pch.h"

29
src/conpty/pch.h Normal file
View File

@@ -0,0 +1,29 @@
#pragma once
#define NOMINMAX
#define UMDF_USING_NTSTATUS
#define WIN32_LEAN_AND_MEAN
#include <ntstatus.h>
#include <Windows.h>
#include <winioctl.h>
#include <winternl.h>
#include <ntcon.h>
// NOTE: These headers depend on Windows/winternl being included first.
#include <condrv.h>
#include <conmsgl1.h>
#include <conmsgl2.h>
#include <conmsgl3.h>
#include <array>
#include <atomic>
#include <span>
#include <cassert>
#include <deque>
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include <wil/com.h>

View File

@@ -1156,6 +1156,8 @@ void AtlasEngine::_mapComplex(IDWriteFontFace2* mappedFontFace, u32 idx, u32 len
const size_t col2 = _api.bufferLineColumn[a.textPosition + i];
const auto fg = colors[col1 << shift];
// TODO: Instead of aligning each DWrite-cluster to the cell grid,
// we should align each grapheme cluster to the cell grid.
const auto expectedAdvance = (col2 - col1) * _p.s->font->cellSize.x;
f32 actualAdvance = 0;
for (auto j = prevCluster; j < nextCluster; ++j)

View File

@@ -33,14 +33,16 @@
wil::unique_handle ServerHandle;
RETURN_IF_NTSTATUS_FAILED(DeviceHandle::CreateServerHandle(ServerHandle.addressof(), FALSE));
RETURN_IF_NTSTATUS_FAILED(Entrypoints::StartConsoleForServerHandle(ServerHandle.get(), args));
ServerHandle.release();
return S_OK;
wil::unique_handle ReferenceHandle;
RETURN_IF_NTSTATUS_FAILED(DeviceHandle::CreateClientHandle(ReferenceHandle.addressof(),
ServerHandle.get(),
L"\\Reference",
FALSE));
RETURN_IF_NTSTATUS_FAILED(Entrypoints::StartConsoleForServerHandle(ServerHandle.get(), args));
// If we get to here, we have transferred ownership of the server handle to the console, so release it.
// Keep a copy of the value so we can open the client handles even though we're no longer the owner.
const auto hServer = ServerHandle.release();

View File

@@ -1,12 +1,11 @@
#Requires -Version 7
# The project's root directory.
$script:OpenConsoleFallbackRoot="$PSScriptRoot\.."
$script:OpenConsoleFallbackRoot = "$PSScriptRoot\.."
#.SYNOPSIS
# Finds the root of the current Terminal checkout.
function Find-OpenConsoleRoot
{
function Find-OpenConsoleRoot {
$root = (git rev-parse --show-toplevel 2>$null)
If ($?) {
return $root
@@ -18,11 +17,10 @@ function Find-OpenConsoleRoot
# Finds and imports a module that should be local to the project
#.PARAMETER ModuleName
# The name of the module to import
function Import-LocalModule
{
function Import-LocalModule {
[CmdletBinding()]
param(
[parameter(Mandatory=$true, Position=0)]
[parameter(Mandatory = $true, Position = 0)]
[string]$Name
)
@@ -32,8 +30,7 @@ function Import-LocalModule
$local = $null -eq (Get-Module -Name $Name)
if (-not $local)
{
if (-not $local) {
return
}
@@ -49,7 +46,8 @@ function Import-LocalModule
Write-Verbose "Saving $Name to $modules_root"
Save-Module -InputObject $module -Path $modules_root
Import-Module "$modules_root\$Name\$version\$Name.psd1"
} else {
}
else {
Write-Verbose "$Name already downloaded"
$versions = Get-ChildItem "$modules_root\$Name" | Sort-Object
@@ -60,8 +58,7 @@ function Import-LocalModule
#.SYNOPSIS
# Grabs all environment variable set after vcvarsall.bat is called and pulls
# them into the Powershell environment.
function Set-MsbuildDevEnvironment
{
function Set-MsbuildDevEnvironment {
[CmdletBinding()]
param(
[switch]$Prerelease
@@ -74,9 +71,9 @@ function Set-MsbuildDevEnvironment
Write-Verbose 'Searching for VC++ instances'
$vsinfo = `
Get-VSSetupInstance -All -Prerelease:$Prerelease `
| Select-VSSetupInstance `
-Latest -Product * `
-Require 'Microsoft.VisualStudio.Component.VC.Tools.x86.x64'
| Select-VSSetupInstance `
-Latest -Product * `
-Require 'Microsoft.VisualStudio.Component.VC.Tools.x86.x64'
$vspath = $vsinfo.InstallationPath
@@ -114,20 +111,19 @@ function Set-MsbuildDevEnvironment
#
#.PARAMETER $TaefArgs
# Any arguments to path to Taef.
function Invoke-TaefInNewWindow()
{
function Invoke-TaefInNewWindow() {
[CmdletBinding()]
Param (
[parameter(Mandatory=$true)]
[parameter(Mandatory = $true)]
[string]$OpenConsolePath,
[parameter(Mandatory=$true)]
[parameter(Mandatory = $true)]
[string]$TaefPath,
[parameter(Mandatory=$true)]
[parameter(Mandatory = $true)]
[string]$TestDll,
[parameter(Mandatory=$false)]
[parameter(Mandatory = $false)]
[string[]]$TaefArgs
)
@@ -160,28 +156,27 @@ function Invoke-TaefInNewWindow()
#.PARAMETER Configuration
# The configuration of the OpenConsole tests to run. Can be "Debug" or
# "Release". Defaults to "Debug".
function Invoke-OpenConsoleTests()
{
function Invoke-OpenConsoleTests() {
[CmdletBinding()]
Param (
[parameter(Mandatory=$false)]
[parameter(Mandatory = $false)]
[switch]$AllTests,
[parameter(Mandatory=$false)]
[parameter(Mandatory = $false)]
[switch]$FTOnly,
[parameter(Mandatory=$false)]
[parameter(Mandatory = $false)]
[ValidateSet('host', 'interactivityWin32', 'terminal', 'adapter', 'feature', 'uia', 'textbuffer', 'til', 'types', 'terminalCore', 'terminalApp', 'localTerminalApp', 'unitSettingsModel', 'unitControl', 'winconpty')]
[string]$Test,
[parameter(Mandatory=$false)]
[parameter(Mandatory = $false)]
[string[]]$TaefArgs,
[parameter(Mandatory=$false)]
[parameter(Mandatory = $false)]
[ValidateSet('x64', 'x86')]
[string]$Platform = "x64",
[parameter(Mandatory=$false)]
[parameter(Mandatory = $false)]
[ValidateSet('Debug', 'Release')]
[string]$Configuration = "Debug"
@@ -189,14 +184,12 @@ function Invoke-OpenConsoleTests()
$root = Find-OpenConsoleRoot
if (($AllTests -and $FTOnly) -or ($AllTests -and $Test) -or ($FTOnly -and $Test))
{
if (($AllTests -and $FTOnly) -or ($AllTests -and $Test) -or ($FTOnly -and $Test)) {
Write-Host "Invalid combination of flags" -ForegroundColor Red
return
}
$OpenConsolePlatform = $Platform
if ($Platform -eq 'x86')
{
if ($Platform -eq 'x86') {
$OpenConsolePlatform = 'Win32'
}
$OpenConsolePath = "$root\bin\$OpenConsolePlatform\$Configuration\OpenConsole.exe"
@@ -207,57 +200,46 @@ function Invoke-OpenConsoleTests()
# check if WinAppDriver needs to be started
$WinAppDriverExe = $null
if ($AllTests -or $FtOnly -or $Test -eq "uia")
{
if ($AllTests -or $FtOnly -or $Test -eq "uia") {
$WinAppDriverExe = [Diagnostics.Process]::Start("$root\dep\WinAppDriver\WinAppDriver.exe")
}
# select tests to run
if ($AllTests)
{
if ($AllTests) {
$TestsToRun = $TestConfig.tests.test
}
elseif ($FTOnly)
{
elseif ($FTOnly) {
$TestsToRun = $TestConfig.tests.test | Where-Object { $_.type -eq "ft" }
}
elseif ($Test)
{
elseif ($Test) {
$TestsToRun = $TestConfig.tests.test | Where-Object { $_.name -eq $Test }
}
else
{
else {
# run unit tests by default
$TestsToRun = $TestConfig.tests.test | Where-Object { $_.type -eq "unit" }
}
# run selected tests
foreach ($t in $TestsToRun)
{
foreach ($t in $TestsToRun) {
$currentTaefExe = $TaefExePath
if ($t.isolatedTaef -eq "true")
{
if ($t.isolatedTaef -eq "true") {
$currentTaefExe = (Join-Path (Split-Path (Join-Path $BinDir $t.binary)) "te.exe")
}
if ($t.type -eq "unit")
{
if ($t.type -eq "unit") {
& $currentTaefExe "$BinDir\$($t.binary)" $TaefArgs
}
elseif ($t.type -eq "ft")
{
elseif ($t.type -eq "ft") {
Invoke-TaefInNewWindow -OpenConsolePath $OpenConsolePath -TaefPath $currentTaefExe -TestDll "$BinDir\$($t.binary)" -TaefArgs $TaefArgs
}
else
{
else {
Write-Host "Invalid test type $t.type for test: $t.name" -ForegroundColor Red
return
}
}
# stop running WinAppDriver if it was launched
if ($WinAppDriverExe)
{
if ($WinAppDriverExe) {
Stop-Process -Id $WinAppDriverExe.Id
}
}
@@ -265,8 +247,7 @@ function Invoke-OpenConsoleTests()
#.SYNOPSIS
# Builds OpenConsole.slnx using msbuild. Any arguments get passed on to msbuild.
function Invoke-OpenConsoleBuild()
{
function Invoke-OpenConsoleBuild() {
$root = Find-OpenConsoleRoot
& "$root\dep\nuget\nuget.exe" restore "$root\OpenConsole.slnx"
& "$root\dep\nuget\nuget.exe" restore "$root\dep\nuget\packages.config"
@@ -283,18 +264,16 @@ function Invoke-OpenConsoleBuild()
#.PARAMETER Configuration
# The configuration of the OpenConsole executable to launch. Can be "Debug" or
# "Release". Defaults to "Debug".
function Start-OpenConsole()
{
function Start-OpenConsole() {
[CmdletBinding()]
Param (
[parameter(Mandatory=$false)]
[parameter(Mandatory = $false)]
[string]$Platform = "x64",
[parameter(Mandatory=$false)]
[parameter(Mandatory = $false)]
[string]$Configuration = "Debug"
)
if ($Platform -like "x86")
{
if ($Platform -like "x86") {
$Platform = "Win32"
}
& "$(Find-OpenConsoleRoot)\bin\$Platform\$Configuration\OpenConsole.exe"
@@ -310,18 +289,16 @@ function Start-OpenConsole()
#.PARAMETER Configuration
# The configuration of the OpenConsole executable to launch. Can be "Debug" or
# "Release". Defaults to "Debug".
function Debug-OpenConsole()
{
function Debug-OpenConsole() {
[CmdletBinding()]
Param (
[parameter(Mandatory=$false)]
[parameter(Mandatory = $false)]
[string]$Platform = "x64",
[parameter(Mandatory=$false)]
[parameter(Mandatory = $false)]
[string]$Configuration = "Debug"
)
if ($Platform -like "x86")
{
if ($Platform -like "x86") {
$Platform = "Win32"
}
$process = [Diagnostics.Process]::Start("$(Find-OpenConsoleRoot)\bin\$Platform\$Configuration\OpenConsole.exe")
@@ -336,10 +313,10 @@ function Debug-OpenConsole()
function Invoke-ClangFormat {
[CmdletBinding()]
Param (
[Parameter(Mandatory=$true,ValueFromPipeline=$true)]
[Parameter(Mandatory = $true, ValueFromPipeline = $true)]
[string[]]$Path,
[Parameter(Mandatory=$false)]
[Parameter(Mandatory = $false)]
[string]$ClangFormatPath = "clang-format" # (whichever one is in $PATH)
)
@@ -349,16 +326,17 @@ function Invoke-ClangFormat {
}
Process {
ForEach($_ in $Path) {
ForEach ($_ in $Path) {
$Paths += Get-Item $_ -ErrorAction Stop | Select -Expand FullName
}
}
End {
For($i = [int]0; $i -Lt $Paths.Length; $i += $BatchSize) {
For ($i = [int]0; $i -Lt $Paths.Length; $i += $BatchSize) {
Try {
& $ClangFormatPath -i $Paths[$i .. ($i + $BatchSize - 1)]
} Catch {
}
Catch {
Write-Error $_
}
}
@@ -412,23 +390,28 @@ function Invoke-CodeFormat() {
[CmdletBinding()]
Param (
[parameter(Mandatory=$false)]
[switch]$IgnoreXaml
[parameter(Mandatory = $false)]
[switch]$IgnoreXaml,
[Parameter(Mandatory = $false)]
[string]$ClangFormatPath
)
$clangFormatPath = & 'C:\Program Files (x86)\Microsoft Visual Studio\Installer\vswhere.exe' -latest -find "**\x64\bin\clang-format.exe"
If ([String]::IsNullOrEmpty($clangFormatPath)) {
Write-Error "No Visual Studio-supplied version of clang-format could be found."
if (!$ClangFormatPath) {
$ClangFormatPath = & 'C:\Program Files (x86)\Microsoft Visual Studio\Installer\vswhere.exe' -latest -find "**\x64\bin\clang-format.exe"
If ([String]::IsNullOrEmpty($ClangFormatPath)) {
Write-Error "No Visual Studio-supplied version of clang-format could be found."
}
}
$root = Find-OpenConsoleRoot
Get-ChildItem -Recurse "$root\src" -Include *.cpp, *.hpp, *.h |
Where FullName -NotLike "*Generated Files*" |
Invoke-ClangFormat -ClangFormatPath $clangFormatPath
Get-ChildItem -Recurse "$root\src" -Include *.cpp, *.hpp, *.h
| Where-Object FullName -NotLike "*Generated Files*"
| Invoke-ClangFormat -ClangFormatPath $ClangFormatPath
if (-Not $IgnoreXaml) {
Invoke-XamlFormat
}
}
Export-ModuleMember -Function Set-MsbuildDevEnvironment,Invoke-OpenConsoleTests,Invoke-OpenConsoleBuild,Start-OpenConsole,Debug-OpenConsole,Invoke-CodeFormat,Invoke-XamlFormat,Test-XamlFormat
Export-ModuleMember -Function Set-MsbuildDevEnvironment, Invoke-OpenConsoleTests, Invoke-OpenConsoleBuild, Start-OpenConsole, Debug-OpenConsole, Invoke-CodeFormat, Invoke-XamlFormat, Test-XamlFormat