SPIRVModule: Add class for iterating/manipulating SPIR-V

This commit is contained in:
Stenzek
2025-10-27 21:44:12 +10:00
parent eb0a8890e2
commit fc618b8b62
5 changed files with 463 additions and 0 deletions

View File

@@ -65,6 +65,8 @@ add_library(util
shiftjis.h
sockets.cpp
sockets.h
spirv_module.cpp
spirv_module.h
state_wrapper.cpp
state_wrapper.h
texture_decompress.cpp

333
src/util/spirv_module.cpp Normal file
View File

@@ -0,0 +1,333 @@
// SPDX-FileCopyrightText: 2019-2024 Connor McLaughlin <stenzek@gmail.com>
// SPDX-License-Identifier: CC-BY-NC-ND-4.0
#define SPV_ENABLE_UTILITY_CODE
#include "spirv_module.h"
#include "common/assert.h"
#include "common/error.h"
SPIRVModule::SPIRVModule(std::span<u32> module) : m_module(module)
{
DebugAssert(ValidateHeader(module, nullptr));
}
SPIRVModule::~SPIRVModule() = default;
void SPIRVModule::SetBound(u32 bound)
{
m_module[3] = bound;
}
bool SPIRVModule::SetDecoration(u32 id, u32 decoration, u32 value, Error* error)
{
for (auto it = begin(); it != end(); ++it)
{
if (it.GetOpcode() == spv::Op::OpDecorate)
{
const u32 target_id = it.GetOperand(0);
const u32 decor = it.GetOperand(1);
if (target_id == id && decor == decoration)
{
// Found existing decoration, update value
if (it.GetOperandCount() < 3)
{
Error::SetStringView(error, "Existing decoration has no value to update");
return false;
}
it.SetOperand(2, value);
return true;
}
}
}
#if 0
// Decoration not found, append new one
size_t old_size = m_module.size();
m_module = m_module.subspan(0, old_size + 3); // OpDecorate with 3 operands
m_module[old_size + 0] = (3 << 16) | static_cast<u32>(spv::Op::OpDecorate);
m_module[old_size + 1] = id;
m_module[old_size + 2] = decoration;
m_module[old_size + 3] = value;
#endif
Error::SetStringFmt(error, "OpDecorate({}) not found for {}", decoration, id);
return false;
}
std::optional<SPIRVModule> SPIRVModule::Get(std::span<u32> module, Error* error)
{
if (!ValidateHeader(module, error))
return std::nullopt;
return SPIRVModule(module);
}
bool SPIRVModule::ValidateHeader(std::span<const u32> module, Error* error)
{
if (module.size() < HEADER_SIZE)
{
Error::SetStringView(error, "Invalid SPIR-V module: too small for header");
return false;
}
if (module[0] != spv::MagicNumber)
{
Error::SetStringView(error, "Invalid SPIR-V magic number");
return false;
}
return true;
}
bool SPIRVModule::SPIRVInstructionIterator::operator!=(const SPIRVInstructionIterator& other) const
{
return !(*this == other);
}
bool SPIRVModule::SPIRVInstructionIterator::operator==(const SPIRVInstructionIterator& other) const
{
return m_module.data() == other.m_module.data() && m_offset == other.m_offset;
}
SPIRVModule::SPIRVInstructionIterator::SPIRVInstructionIterator(std::span<u32> module)
// Skip SPIR-V header (5 words)
: m_module(module), m_offset(HEADER_SIZE)
{
DebugAssert(m_module.size() >= HEADER_SIZE);
}
SPIRVModule::SPIRVInstructionIterator::SPIRVInstructionIterator(std::span<u32> module, size_t offset)
: m_module(module), m_offset(offset)
{
}
SPIRVModule::SPIRVInstructionIterator::SPIRVInstructionIterator() : m_module(), m_offset(0)
{
}
spv::Op SPIRVModule::SPIRVInstructionIterator::GetOpcode() const
{
DebugAssert(IsValid());
return static_cast<spv::Op>(m_module[m_offset] & 0xFFFF);
}
u16 SPIRVModule::SPIRVInstructionIterator::GetWordCount() const
{
DebugAssert(IsValid());
return static_cast<u16>(std::min<size_t>(m_module[m_offset] >> 16, m_module.size() - m_offset));
}
bool SPIRVModule::SPIRVInstructionIterator::HasResult() const
{
return GetResultIndex() != -1;
}
bool SPIRVModule::SPIRVInstructionIterator::HasResultType() const
{
return GetResultTypeIndex() != -1;
}
u32 SPIRVModule::SPIRVInstructionIterator::GetResult() const
{
const int idx = GetResultIndex();
if (idx == -1)
Panic("Instruction has no result ID");
return m_module[m_offset + idx];
}
u32 SPIRVModule::SPIRVInstructionIterator::GetResultType() const
{
const int idx = GetResultTypeIndex();
if (idx == -1)
Panic("Instruction has no result type ID");
return m_module[m_offset + idx];
}
void SPIRVModule::SPIRVInstructionIterator::SetResult(u32 id)
{
const int idx = GetResultIndex();
if (idx == -1)
Panic("Instruction has no result ID");
m_module[m_offset + idx] = id;
}
void SPIRVModule::SPIRVInstructionIterator::SetResultType(u32 id)
{
const int idx = GetResultTypeIndex();
if (idx == -1)
Panic("Instruction has no result type ID");
m_module[m_offset + idx] = id;
}
size_t SPIRVModule::SPIRVInstructionIterator::GetOperandCount() const
{
DebugAssert(IsValid());
size_t count = GetWordCount() - 1; // Subtract opcode word
if (HasResultType())
count--;
if (HasResult())
count--;
return count;
}
u32 SPIRVModule::SPIRVInstructionIterator::GetOperand(size_t index) const
{
const size_t actual_index = GetOperandStartIndex() + index;
if (actual_index >= GetWordCount())
Panic("Operand index out of range");
return m_module[m_offset + actual_index];
}
void SPIRVModule::SPIRVInstructionIterator::SetOperand(size_t index, u32 value)
{
const size_t actual_index = GetOperandStartIndex() + index;
if (actual_index >= GetWordCount())
Panic("Operand index out of range");
m_module[m_offset + actual_index] = value;
}
const u32* SPIRVModule::SPIRVInstructionIterator::GetOperandPtr(size_t index) const
{
const size_t actual_index = GetOperandStartIndex() + index;
if (actual_index >= GetWordCount())
Panic("Operand index out of range");
return &m_module[m_offset + actual_index];
}
u32* SPIRVModule::SPIRVInstructionIterator::GetOperandPtr(size_t index)
{
const size_t actual_index = GetOperandStartIndex() + index;
if (actual_index >= GetWordCount())
Panic("Operand index out of range");
return &m_module[m_offset + actual_index];
}
std::span<const u32> SPIRVModule::SPIRVInstructionIterator::GetInstructionSpan() const
{
DebugAssert(IsValid());
return m_module.subspan(m_offset, GetWordCount());
}
std::span<u32> SPIRVModule::SPIRVInstructionIterator::GetInstructionSpan()
{
DebugAssert(IsValid());
return m_module.subspan(m_offset, GetWordCount());
}
const u32* SPIRVModule::SPIRVInstructionIterator::Data() const
{
DebugAssert(IsValid());
return &m_module[m_offset];
}
u32* SPIRVModule::SPIRVInstructionIterator::Data()
{
DebugAssert(IsValid());
return &m_module[m_offset];
}
SPIRVModule::SPIRVInstructionIterator SPIRVModule::SPIRVInstructionIterator::operator++(int)
{
SPIRVInstructionIterator tmp = *this;
++(*this);
return tmp;
}
SPIRVModule::SPIRVModule::SPIRVInstructionIterator& SPIRVModule::SPIRVInstructionIterator::operator++()
{
if (m_offset >= m_module.size())
Panic("Cannot increment past end");
m_offset += GetWordCount();
return *this;
}
SPIRVModule::SPIRVInstructionIterator SPIRVModule::SPIRVInstructionIterator::operator--(int)
{
SPIRVInstructionIterator tmp = *this;
--(*this);
return tmp;
}
SPIRVModule::SPIRVModule::SPIRVInstructionIterator& SPIRVModule::SPIRVInstructionIterator::operator--()
{
if (m_offset <= HEADER_SIZE)
Panic("Cannot decrement past beginning");
// Search backwards for instruction start
size_t prev = m_offset - 1;
while (prev >= HEADER_SIZE)
{
u16 wordCount = static_cast<u16>(m_module[prev] >> 16);
if (wordCount > 0 && prev + wordCount == m_offset)
{
m_offset = prev;
return *this;
}
if (prev == 0)
break;
--prev;
}
Panic("Failed to find previous instruction");
}
const u32& SPIRVModule::SPIRVInstructionIterator::operator*() const
{
return m_module[m_offset];
}
u32& SPIRVModule::SPIRVInstructionIterator::operator*()
{
return m_module[m_offset];
}
bool SPIRVModule::SPIRVInstructionIterator::IsValid() const
{
return m_offset < m_module.size();
}
bool SPIRVModule::SPIRVInstructionIterator::IsEnd() const
{
return m_offset >= m_module.size();
}
int SPIRVModule::SPIRVInstructionIterator::GetResultTypeIndex() const
{
const spv::Op op = GetOpcode();
bool has_result, has_result_type;
spv::HasResultAndType(op, &has_result, &has_result_type);
return (has_result && has_result_type) ? 1 : -1;
}
int SPIRVModule::SPIRVInstructionIterator::GetResultIndex() const
{
spv::Op op = GetOpcode();
bool has_result, has_result_type;
spv::HasResultAndType(op, &has_result, &has_result_type);
if (has_result && has_result_type)
return 2;
else if (op == spv::Op::OpLabel || op == spv::Op::OpString || op == spv::Op::OpExtInstImport)
return 1;
return -1;
}
size_t SPIRVModule::SPIRVInstructionIterator::GetOperandStartIndex() const
{
size_t idx = 1; // Skip opcode word
if (HasResultType())
idx++;
if (HasResult())
idx++;
return idx;
}

124
src/util/spirv_module.h Normal file
View File

@@ -0,0 +1,124 @@
// SPDX-FileCopyrightText: 2019-2024 Connor McLaughlin <stenzek@gmail.com>
// SPDX-License-Identifier: CC-BY-NC-ND-4.0
#pragma once
#include "common/types.h"
#include <iterator>
#include <optional>
#include <span>
#include <spirv.hpp>
class Error;
// Helper class for range-based iteration
class SPIRVModule
{
public:
static constexpr size_t HEADER_SIZE = 5;
class SPIRVInstructionIterator
{
public:
using iterator_category = std::bidirectional_iterator_tag;
using value_type = u32;
using difference_type = std::ptrdiff_t;
using pointer = u32*;
using const_pointer = const u32*;
using reference = u32&;
private:
std::span<u32> m_module;
size_t m_offset;
public:
// Constructors
SPIRVInstructionIterator();
SPIRVInstructionIterator(std::span<u32> module, size_t offset);
explicit SPIRVInstructionIterator(std::span<u32> module);
// Query methods
spv::Op GetOpcode() const;
u16 GetWordCount() const;
bool HasResult() const;
bool HasResultType() const;
u32 GetResult() const;
u32 GetResultType() const;
// Set result IDs
void SetResult(u32 id);
void SetResultType(u32 id);
// Operand access
size_t GetOperandCount() const;
u32 GetOperand(size_t index) const;
void SetOperand(size_t index, u32 value);
// Get pointer to operands (for direct manipulation)
u32* GetOperandPtr(size_t index);
const u32* GetOperandPtr(size_t index) const;
// Get span of current instruction
std::span<u32> GetInstructionSpan();
std::span<const u32> GetInstructionSpan() const;
// Raw instruction access
u32* Data();
const u32* Data() const;
// Iterator operations
SPIRVInstructionIterator& operator++();
SPIRVInstructionIterator operator++(int);
SPIRVInstructionIterator& operator--();
SPIRVInstructionIterator operator--(int);
bool operator==(const SPIRVInstructionIterator& other) const;
bool operator!=(const SPIRVInstructionIterator& other) const;
u32& operator*();
const u32& operator*() const;
bool IsValid() const;
bool IsEnd() const;
private:
int GetResultTypeIndex() const;
int GetResultIndex() const;
size_t GetOperandStartIndex() const;
};
public:
~SPIRVModule();
ALWAYS_INLINE SPIRVInstructionIterator begin() { return SPIRVInstructionIterator(m_module); }
ALWAYS_INLINE SPIRVInstructionIterator end() { return SPIRVInstructionIterator(m_module, m_module.size()); }
// Header access methods
ALWAYS_INLINE u32 GetMagicNumber() const { return m_module[0]; }
ALWAYS_INLINE u32 GetVersion() const { return m_module[1]; }
ALWAYS_INLINE u32 GetGeneratorMagic() const { return m_module[2]; }
ALWAYS_INLINE u32 GetBound() const { return m_module[3]; }
ALWAYS_INLINE u32 GetSchema() const { return m_module[4]; }
ALWAYS_INLINE std::span<u32> GetData() const { return m_module; }
void SetBound(u32 bound);
bool SetDecoration(u32 id, u32 decoration, u32 value, Error* error);
static std::optional<SPIRVModule> Get(std::span<u32> module, Error* error);
private:
explicit SPIRVModule(std::span<u32> module);
static bool ValidateHeader(std::span<const u32> module, Error* error);
std::span<u32> m_module;
};

View File

@@ -93,6 +93,7 @@
<ClInclude Include="shadergen.h" />
<ClInclude Include="shiftjis.h" />
<ClInclude Include="sockets.h" />
<ClInclude Include="spirv_module.h" />
<ClInclude Include="state_wrapper.h" />
<ClInclude Include="texture_decompress.h" />
<ClInclude Include="vulkan_builders.h" />
@@ -203,6 +204,7 @@
<ClCompile Include="shiftjis.cpp" />
<ClCompile Include="page_fault_handler.cpp" />
<ClCompile Include="sockets.cpp" />
<ClCompile Include="spirv_module.cpp" />
<ClCompile Include="state_wrapper.cpp" />
<ClCompile Include="texture_decompress.cpp" />
<ClCompile Include="vulkan_builders.cpp" />

View File

@@ -77,6 +77,7 @@
<ClInclude Include="animated_image.h" />
<ClInclude Include="dyn_shaderc.h" />
<ClInclude Include="dyn_spirv_cross.h" />
<ClInclude Include="spirv_module.h" />
</ItemGroup>
<ItemGroup>
<ClCompile Include="state_wrapper.cpp" />
@@ -160,6 +161,7 @@
<ClCompile Include="texture_decompress.cpp" />
<ClCompile Include="opengl_context_sdl.cpp" />
<ClCompile Include="animated_image.cpp" />
<ClCompile Include="spirv_module.cpp" />
</ItemGroup>
<ItemGroup>
<None Include="metal_shaders.metal" />