mirror of
https://github.com/stenzek/duckstation.git
synced 2026-02-10 16:34:36 +00:00
SPIRVModule: Add class for iterating/manipulating SPIR-V
This commit is contained in:
@@ -65,6 +65,8 @@ add_library(util
|
||||
shiftjis.h
|
||||
sockets.cpp
|
||||
sockets.h
|
||||
spirv_module.cpp
|
||||
spirv_module.h
|
||||
state_wrapper.cpp
|
||||
state_wrapper.h
|
||||
texture_decompress.cpp
|
||||
|
||||
333
src/util/spirv_module.cpp
Normal file
333
src/util/spirv_module.cpp
Normal file
@@ -0,0 +1,333 @@
|
||||
// SPDX-FileCopyrightText: 2019-2024 Connor McLaughlin <stenzek@gmail.com>
|
||||
// SPDX-License-Identifier: CC-BY-NC-ND-4.0
|
||||
|
||||
#define SPV_ENABLE_UTILITY_CODE
|
||||
|
||||
#include "spirv_module.h"
|
||||
|
||||
#include "common/assert.h"
|
||||
#include "common/error.h"
|
||||
|
||||
SPIRVModule::SPIRVModule(std::span<u32> module) : m_module(module)
|
||||
{
|
||||
DebugAssert(ValidateHeader(module, nullptr));
|
||||
}
|
||||
|
||||
SPIRVModule::~SPIRVModule() = default;
|
||||
|
||||
void SPIRVModule::SetBound(u32 bound)
|
||||
{
|
||||
m_module[3] = bound;
|
||||
}
|
||||
|
||||
bool SPIRVModule::SetDecoration(u32 id, u32 decoration, u32 value, Error* error)
|
||||
{
|
||||
for (auto it = begin(); it != end(); ++it)
|
||||
{
|
||||
if (it.GetOpcode() == spv::Op::OpDecorate)
|
||||
{
|
||||
const u32 target_id = it.GetOperand(0);
|
||||
const u32 decor = it.GetOperand(1);
|
||||
if (target_id == id && decor == decoration)
|
||||
{
|
||||
// Found existing decoration, update value
|
||||
if (it.GetOperandCount() < 3)
|
||||
{
|
||||
Error::SetStringView(error, "Existing decoration has no value to update");
|
||||
return false;
|
||||
}
|
||||
it.SetOperand(2, value);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
#if 0
|
||||
// Decoration not found, append new one
|
||||
size_t old_size = m_module.size();
|
||||
m_module = m_module.subspan(0, old_size + 3); // OpDecorate with 3 operands
|
||||
m_module[old_size + 0] = (3 << 16) | static_cast<u32>(spv::Op::OpDecorate);
|
||||
m_module[old_size + 1] = id;
|
||||
m_module[old_size + 2] = decoration;
|
||||
m_module[old_size + 3] = value;
|
||||
#endif
|
||||
|
||||
Error::SetStringFmt(error, "OpDecorate({}) not found for {}", decoration, id);
|
||||
return false;
|
||||
}
|
||||
|
||||
std::optional<SPIRVModule> SPIRVModule::Get(std::span<u32> module, Error* error)
|
||||
{
|
||||
if (!ValidateHeader(module, error))
|
||||
return std::nullopt;
|
||||
|
||||
return SPIRVModule(module);
|
||||
}
|
||||
|
||||
bool SPIRVModule::ValidateHeader(std::span<const u32> module, Error* error)
|
||||
{
|
||||
if (module.size() < HEADER_SIZE)
|
||||
{
|
||||
Error::SetStringView(error, "Invalid SPIR-V module: too small for header");
|
||||
return false;
|
||||
}
|
||||
|
||||
if (module[0] != spv::MagicNumber)
|
||||
{
|
||||
Error::SetStringView(error, "Invalid SPIR-V magic number");
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool SPIRVModule::SPIRVInstructionIterator::operator!=(const SPIRVInstructionIterator& other) const
|
||||
{
|
||||
return !(*this == other);
|
||||
}
|
||||
|
||||
bool SPIRVModule::SPIRVInstructionIterator::operator==(const SPIRVInstructionIterator& other) const
|
||||
{
|
||||
return m_module.data() == other.m_module.data() && m_offset == other.m_offset;
|
||||
}
|
||||
|
||||
SPIRVModule::SPIRVInstructionIterator::SPIRVInstructionIterator(std::span<u32> module)
|
||||
// Skip SPIR-V header (5 words)
|
||||
: m_module(module), m_offset(HEADER_SIZE)
|
||||
{
|
||||
DebugAssert(m_module.size() >= HEADER_SIZE);
|
||||
}
|
||||
|
||||
SPIRVModule::SPIRVInstructionIterator::SPIRVInstructionIterator(std::span<u32> module, size_t offset)
|
||||
: m_module(module), m_offset(offset)
|
||||
{
|
||||
}
|
||||
|
||||
SPIRVModule::SPIRVInstructionIterator::SPIRVInstructionIterator() : m_module(), m_offset(0)
|
||||
{
|
||||
}
|
||||
|
||||
spv::Op SPIRVModule::SPIRVInstructionIterator::GetOpcode() const
|
||||
{
|
||||
DebugAssert(IsValid());
|
||||
return static_cast<spv::Op>(m_module[m_offset] & 0xFFFF);
|
||||
}
|
||||
|
||||
u16 SPIRVModule::SPIRVInstructionIterator::GetWordCount() const
|
||||
{
|
||||
DebugAssert(IsValid());
|
||||
return static_cast<u16>(std::min<size_t>(m_module[m_offset] >> 16, m_module.size() - m_offset));
|
||||
}
|
||||
|
||||
bool SPIRVModule::SPIRVInstructionIterator::HasResult() const
|
||||
{
|
||||
return GetResultIndex() != -1;
|
||||
}
|
||||
|
||||
bool SPIRVModule::SPIRVInstructionIterator::HasResultType() const
|
||||
{
|
||||
return GetResultTypeIndex() != -1;
|
||||
}
|
||||
|
||||
u32 SPIRVModule::SPIRVInstructionIterator::GetResult() const
|
||||
{
|
||||
const int idx = GetResultIndex();
|
||||
if (idx == -1)
|
||||
Panic("Instruction has no result ID");
|
||||
|
||||
return m_module[m_offset + idx];
|
||||
}
|
||||
|
||||
u32 SPIRVModule::SPIRVInstructionIterator::GetResultType() const
|
||||
{
|
||||
const int idx = GetResultTypeIndex();
|
||||
if (idx == -1)
|
||||
Panic("Instruction has no result type ID");
|
||||
|
||||
return m_module[m_offset + idx];
|
||||
}
|
||||
|
||||
void SPIRVModule::SPIRVInstructionIterator::SetResult(u32 id)
|
||||
{
|
||||
const int idx = GetResultIndex();
|
||||
if (idx == -1)
|
||||
Panic("Instruction has no result ID");
|
||||
|
||||
m_module[m_offset + idx] = id;
|
||||
}
|
||||
|
||||
void SPIRVModule::SPIRVInstructionIterator::SetResultType(u32 id)
|
||||
{
|
||||
const int idx = GetResultTypeIndex();
|
||||
if (idx == -1)
|
||||
Panic("Instruction has no result type ID");
|
||||
|
||||
m_module[m_offset + idx] = id;
|
||||
}
|
||||
|
||||
size_t SPIRVModule::SPIRVInstructionIterator::GetOperandCount() const
|
||||
{
|
||||
DebugAssert(IsValid());
|
||||
|
||||
size_t count = GetWordCount() - 1; // Subtract opcode word
|
||||
if (HasResultType())
|
||||
count--;
|
||||
if (HasResult())
|
||||
count--;
|
||||
return count;
|
||||
}
|
||||
|
||||
u32 SPIRVModule::SPIRVInstructionIterator::GetOperand(size_t index) const
|
||||
{
|
||||
const size_t actual_index = GetOperandStartIndex() + index;
|
||||
if (actual_index >= GetWordCount())
|
||||
Panic("Operand index out of range");
|
||||
|
||||
return m_module[m_offset + actual_index];
|
||||
}
|
||||
|
||||
void SPIRVModule::SPIRVInstructionIterator::SetOperand(size_t index, u32 value)
|
||||
{
|
||||
const size_t actual_index = GetOperandStartIndex() + index;
|
||||
if (actual_index >= GetWordCount())
|
||||
Panic("Operand index out of range");
|
||||
|
||||
m_module[m_offset + actual_index] = value;
|
||||
}
|
||||
|
||||
const u32* SPIRVModule::SPIRVInstructionIterator::GetOperandPtr(size_t index) const
|
||||
{
|
||||
const size_t actual_index = GetOperandStartIndex() + index;
|
||||
if (actual_index >= GetWordCount())
|
||||
Panic("Operand index out of range");
|
||||
|
||||
return &m_module[m_offset + actual_index];
|
||||
}
|
||||
|
||||
u32* SPIRVModule::SPIRVInstructionIterator::GetOperandPtr(size_t index)
|
||||
{
|
||||
const size_t actual_index = GetOperandStartIndex() + index;
|
||||
if (actual_index >= GetWordCount())
|
||||
Panic("Operand index out of range");
|
||||
|
||||
return &m_module[m_offset + actual_index];
|
||||
}
|
||||
|
||||
std::span<const u32> SPIRVModule::SPIRVInstructionIterator::GetInstructionSpan() const
|
||||
{
|
||||
DebugAssert(IsValid());
|
||||
return m_module.subspan(m_offset, GetWordCount());
|
||||
}
|
||||
|
||||
std::span<u32> SPIRVModule::SPIRVInstructionIterator::GetInstructionSpan()
|
||||
{
|
||||
DebugAssert(IsValid());
|
||||
return m_module.subspan(m_offset, GetWordCount());
|
||||
}
|
||||
|
||||
const u32* SPIRVModule::SPIRVInstructionIterator::Data() const
|
||||
{
|
||||
DebugAssert(IsValid());
|
||||
return &m_module[m_offset];
|
||||
}
|
||||
|
||||
u32* SPIRVModule::SPIRVInstructionIterator::Data()
|
||||
{
|
||||
DebugAssert(IsValid());
|
||||
return &m_module[m_offset];
|
||||
}
|
||||
|
||||
SPIRVModule::SPIRVInstructionIterator SPIRVModule::SPIRVInstructionIterator::operator++(int)
|
||||
{
|
||||
SPIRVInstructionIterator tmp = *this;
|
||||
++(*this);
|
||||
return tmp;
|
||||
}
|
||||
|
||||
SPIRVModule::SPIRVModule::SPIRVInstructionIterator& SPIRVModule::SPIRVInstructionIterator::operator++()
|
||||
{
|
||||
if (m_offset >= m_module.size())
|
||||
Panic("Cannot increment past end");
|
||||
|
||||
m_offset += GetWordCount();
|
||||
return *this;
|
||||
}
|
||||
|
||||
SPIRVModule::SPIRVInstructionIterator SPIRVModule::SPIRVInstructionIterator::operator--(int)
|
||||
{
|
||||
SPIRVInstructionIterator tmp = *this;
|
||||
--(*this);
|
||||
return tmp;
|
||||
}
|
||||
|
||||
SPIRVModule::SPIRVModule::SPIRVInstructionIterator& SPIRVModule::SPIRVInstructionIterator::operator--()
|
||||
{
|
||||
if (m_offset <= HEADER_SIZE)
|
||||
Panic("Cannot decrement past beginning");
|
||||
|
||||
// Search backwards for instruction start
|
||||
size_t prev = m_offset - 1;
|
||||
while (prev >= HEADER_SIZE)
|
||||
{
|
||||
u16 wordCount = static_cast<u16>(m_module[prev] >> 16);
|
||||
if (wordCount > 0 && prev + wordCount == m_offset)
|
||||
{
|
||||
m_offset = prev;
|
||||
return *this;
|
||||
}
|
||||
if (prev == 0)
|
||||
break;
|
||||
--prev;
|
||||
}
|
||||
Panic("Failed to find previous instruction");
|
||||
}
|
||||
|
||||
const u32& SPIRVModule::SPIRVInstructionIterator::operator*() const
|
||||
{
|
||||
return m_module[m_offset];
|
||||
}
|
||||
|
||||
u32& SPIRVModule::SPIRVInstructionIterator::operator*()
|
||||
{
|
||||
return m_module[m_offset];
|
||||
}
|
||||
|
||||
bool SPIRVModule::SPIRVInstructionIterator::IsValid() const
|
||||
{
|
||||
return m_offset < m_module.size();
|
||||
}
|
||||
|
||||
bool SPIRVModule::SPIRVInstructionIterator::IsEnd() const
|
||||
{
|
||||
return m_offset >= m_module.size();
|
||||
}
|
||||
|
||||
int SPIRVModule::SPIRVInstructionIterator::GetResultTypeIndex() const
|
||||
{
|
||||
const spv::Op op = GetOpcode();
|
||||
bool has_result, has_result_type;
|
||||
spv::HasResultAndType(op, &has_result, &has_result_type);
|
||||
return (has_result && has_result_type) ? 1 : -1;
|
||||
}
|
||||
|
||||
int SPIRVModule::SPIRVInstructionIterator::GetResultIndex() const
|
||||
{
|
||||
spv::Op op = GetOpcode();
|
||||
bool has_result, has_result_type;
|
||||
spv::HasResultAndType(op, &has_result, &has_result_type);
|
||||
if (has_result && has_result_type)
|
||||
return 2;
|
||||
else if (op == spv::Op::OpLabel || op == spv::Op::OpString || op == spv::Op::OpExtInstImport)
|
||||
return 1;
|
||||
|
||||
return -1;
|
||||
}
|
||||
|
||||
size_t SPIRVModule::SPIRVInstructionIterator::GetOperandStartIndex() const
|
||||
{
|
||||
size_t idx = 1; // Skip opcode word
|
||||
if (HasResultType())
|
||||
idx++;
|
||||
if (HasResult())
|
||||
idx++;
|
||||
return idx;
|
||||
}
|
||||
124
src/util/spirv_module.h
Normal file
124
src/util/spirv_module.h
Normal file
@@ -0,0 +1,124 @@
|
||||
// SPDX-FileCopyrightText: 2019-2024 Connor McLaughlin <stenzek@gmail.com>
|
||||
// SPDX-License-Identifier: CC-BY-NC-ND-4.0
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "common/types.h"
|
||||
|
||||
#include <iterator>
|
||||
#include <optional>
|
||||
#include <span>
|
||||
#include <spirv.hpp>
|
||||
|
||||
class Error;
|
||||
|
||||
// Helper class for range-based iteration
|
||||
class SPIRVModule
|
||||
{
|
||||
public:
|
||||
static constexpr size_t HEADER_SIZE = 5;
|
||||
|
||||
class SPIRVInstructionIterator
|
||||
{
|
||||
public:
|
||||
using iterator_category = std::bidirectional_iterator_tag;
|
||||
using value_type = u32;
|
||||
using difference_type = std::ptrdiff_t;
|
||||
using pointer = u32*;
|
||||
using const_pointer = const u32*;
|
||||
using reference = u32&;
|
||||
|
||||
private:
|
||||
std::span<u32> m_module;
|
||||
size_t m_offset;
|
||||
|
||||
public:
|
||||
// Constructors
|
||||
SPIRVInstructionIterator();
|
||||
|
||||
SPIRVInstructionIterator(std::span<u32> module, size_t offset);
|
||||
|
||||
explicit SPIRVInstructionIterator(std::span<u32> module);
|
||||
|
||||
// Query methods
|
||||
spv::Op GetOpcode() const;
|
||||
u16 GetWordCount() const;
|
||||
|
||||
bool HasResult() const;
|
||||
bool HasResultType() const;
|
||||
|
||||
u32 GetResult() const;
|
||||
u32 GetResultType() const;
|
||||
|
||||
// Set result IDs
|
||||
void SetResult(u32 id);
|
||||
void SetResultType(u32 id);
|
||||
|
||||
// Operand access
|
||||
size_t GetOperandCount() const;
|
||||
u32 GetOperand(size_t index) const;
|
||||
void SetOperand(size_t index, u32 value);
|
||||
|
||||
// Get pointer to operands (for direct manipulation)
|
||||
u32* GetOperandPtr(size_t index);
|
||||
const u32* GetOperandPtr(size_t index) const;
|
||||
|
||||
// Get span of current instruction
|
||||
std::span<u32> GetInstructionSpan();
|
||||
std::span<const u32> GetInstructionSpan() const;
|
||||
|
||||
// Raw instruction access
|
||||
u32* Data();
|
||||
const u32* Data() const;
|
||||
|
||||
// Iterator operations
|
||||
SPIRVInstructionIterator& operator++();
|
||||
SPIRVInstructionIterator operator++(int);
|
||||
|
||||
SPIRVInstructionIterator& operator--();
|
||||
SPIRVInstructionIterator operator--(int);
|
||||
|
||||
bool operator==(const SPIRVInstructionIterator& other) const;
|
||||
bool operator!=(const SPIRVInstructionIterator& other) const;
|
||||
|
||||
u32& operator*();
|
||||
const u32& operator*() const;
|
||||
|
||||
bool IsValid() const;
|
||||
bool IsEnd() const;
|
||||
|
||||
private:
|
||||
int GetResultTypeIndex() const;
|
||||
int GetResultIndex() const;
|
||||
|
||||
size_t GetOperandStartIndex() const;
|
||||
};
|
||||
|
||||
public:
|
||||
~SPIRVModule();
|
||||
|
||||
ALWAYS_INLINE SPIRVInstructionIterator begin() { return SPIRVInstructionIterator(m_module); }
|
||||
ALWAYS_INLINE SPIRVInstructionIterator end() { return SPIRVInstructionIterator(m_module, m_module.size()); }
|
||||
|
||||
// Header access methods
|
||||
ALWAYS_INLINE u32 GetMagicNumber() const { return m_module[0]; }
|
||||
ALWAYS_INLINE u32 GetVersion() const { return m_module[1]; }
|
||||
ALWAYS_INLINE u32 GetGeneratorMagic() const { return m_module[2]; }
|
||||
ALWAYS_INLINE u32 GetBound() const { return m_module[3]; }
|
||||
ALWAYS_INLINE u32 GetSchema() const { return m_module[4]; }
|
||||
|
||||
ALWAYS_INLINE std::span<u32> GetData() const { return m_module; }
|
||||
|
||||
void SetBound(u32 bound);
|
||||
|
||||
bool SetDecoration(u32 id, u32 decoration, u32 value, Error* error);
|
||||
|
||||
static std::optional<SPIRVModule> Get(std::span<u32> module, Error* error);
|
||||
|
||||
private:
|
||||
explicit SPIRVModule(std::span<u32> module);
|
||||
|
||||
static bool ValidateHeader(std::span<const u32> module, Error* error);
|
||||
|
||||
std::span<u32> m_module;
|
||||
};
|
||||
@@ -93,6 +93,7 @@
|
||||
<ClInclude Include="shadergen.h" />
|
||||
<ClInclude Include="shiftjis.h" />
|
||||
<ClInclude Include="sockets.h" />
|
||||
<ClInclude Include="spirv_module.h" />
|
||||
<ClInclude Include="state_wrapper.h" />
|
||||
<ClInclude Include="texture_decompress.h" />
|
||||
<ClInclude Include="vulkan_builders.h" />
|
||||
@@ -203,6 +204,7 @@
|
||||
<ClCompile Include="shiftjis.cpp" />
|
||||
<ClCompile Include="page_fault_handler.cpp" />
|
||||
<ClCompile Include="sockets.cpp" />
|
||||
<ClCompile Include="spirv_module.cpp" />
|
||||
<ClCompile Include="state_wrapper.cpp" />
|
||||
<ClCompile Include="texture_decompress.cpp" />
|
||||
<ClCompile Include="vulkan_builders.cpp" />
|
||||
|
||||
@@ -77,6 +77,7 @@
|
||||
<ClInclude Include="animated_image.h" />
|
||||
<ClInclude Include="dyn_shaderc.h" />
|
||||
<ClInclude Include="dyn_spirv_cross.h" />
|
||||
<ClInclude Include="spirv_module.h" />
|
||||
</ItemGroup>
|
||||
<ItemGroup>
|
||||
<ClCompile Include="state_wrapper.cpp" />
|
||||
@@ -160,6 +161,7 @@
|
||||
<ClCompile Include="texture_decompress.cpp" />
|
||||
<ClCompile Include="opengl_context_sdl.cpp" />
|
||||
<ClCompile Include="animated_image.cpp" />
|
||||
<ClCompile Include="spirv_module.cpp" />
|
||||
</ItemGroup>
|
||||
<ItemGroup>
|
||||
<None Include="metal_shaders.metal" />
|
||||
|
||||
Reference in New Issue
Block a user