Merge pull request #2784 from ReinUsesLisp/smem
shader_ir: Implement shared memory
This commit is contained in:
		
						commit
						b31880dc5e
					
				| @ -325,6 +325,7 @@ public: | |||||||
|         DeclareRegisters(); |         DeclareRegisters(); | ||||||
|         DeclarePredicates(); |         DeclarePredicates(); | ||||||
|         DeclareLocalMemory(); |         DeclareLocalMemory(); | ||||||
|  |         DeclareSharedMemory(); | ||||||
|         DeclareInternalFlags(); |         DeclareInternalFlags(); | ||||||
|         DeclareInputAttributes(); |         DeclareInputAttributes(); | ||||||
|         DeclareOutputAttributes(); |         DeclareOutputAttributes(); | ||||||
| @ -499,6 +500,13 @@ private: | |||||||
|         code.AddNewLine(); |         code.AddNewLine(); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|  |     void DeclareSharedMemory() { | ||||||
|  |         if (stage != ProgramType::Compute) { | ||||||
|  |             return; | ||||||
|  |         } | ||||||
|  |         code.AddLine("shared uint {}[];", GetSharedMemory()); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|     void DeclareInternalFlags() { |     void DeclareInternalFlags() { | ||||||
|         for (u32 flag = 0; flag < static_cast<u32>(InternalFlag::Amount); flag++) { |         for (u32 flag = 0; flag < static_cast<u32>(InternalFlag::Amount); flag++) { | ||||||
|             const auto flag_code = static_cast<InternalFlag>(flag); |             const auto flag_code = static_cast<InternalFlag>(flag); | ||||||
| @ -881,6 +889,12 @@ private: | |||||||
|                 Type::Uint}; |                 Type::Uint}; | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|  |         if (const auto smem = std::get_if<SmemNode>(&*node)) { | ||||||
|  |             return { | ||||||
|  |                 fmt::format("{}[{} >> 2]", GetSharedMemory(), Visit(smem->GetAddress()).AsUint()), | ||||||
|  |                 Type::Uint}; | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|         if (const auto internal_flag = std::get_if<InternalFlagNode>(&*node)) { |         if (const auto internal_flag = std::get_if<InternalFlagNode>(&*node)) { | ||||||
|             return {GetInternalFlag(internal_flag->GetFlag()), Type::Bool}; |             return {GetInternalFlag(internal_flag->GetFlag()), Type::Bool}; | ||||||
|         } |         } | ||||||
| @ -1286,6 +1300,11 @@ private: | |||||||
|             target = { |             target = { | ||||||
|                 fmt::format("{}[{} >> 2]", GetLocalMemory(), Visit(lmem->GetAddress()).AsUint()), |                 fmt::format("{}[{} >> 2]", GetLocalMemory(), Visit(lmem->GetAddress()).AsUint()), | ||||||
|                 Type::Uint}; |                 Type::Uint}; | ||||||
|  |         } else if (const auto smem = std::get_if<SmemNode>(&*dest)) { | ||||||
|  |             ASSERT(stage == ProgramType::Compute); | ||||||
|  |             target = { | ||||||
|  |                 fmt::format("{}[{} >> 2]", GetSharedMemory(), Visit(smem->GetAddress()).AsUint()), | ||||||
|  |                 Type::Uint}; | ||||||
|         } else if (const auto gmem = std::get_if<GmemNode>(&*dest)) { |         } else if (const auto gmem = std::get_if<GmemNode>(&*dest)) { | ||||||
|             const std::string real = Visit(gmem->GetRealAddress()).AsUint(); |             const std::string real = Visit(gmem->GetRealAddress()).AsUint(); | ||||||
|             const std::string base = Visit(gmem->GetBaseAddress()).AsUint(); |             const std::string base = Visit(gmem->GetBaseAddress()).AsUint(); | ||||||
| @ -2175,6 +2194,10 @@ private: | |||||||
|         return "lmem_" + suffix; |         return "lmem_" + suffix; | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|  |     std::string GetSharedMemory() const { | ||||||
|  |         return fmt::format("smem_{}", suffix); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|     std::string GetInternalFlag(InternalFlag flag) const { |     std::string GetInternalFlag(InternalFlag flag) const { | ||||||
|         constexpr std::array InternalFlagNames = {"zero_flag", "sign_flag", "carry_flag", |         constexpr std::array InternalFlagNames = {"zero_flag", "sign_flag", "carry_flag", | ||||||
|                                                   "overflow_flag"}; |                                                   "overflow_flag"}; | ||||||
|  | |||||||
| @ -35,7 +35,7 @@ u32 GetUniformTypeElementsCount(Tegra::Shader::UniformType uniform_type) { | |||||||
|         return 1; |         return 1; | ||||||
|     } |     } | ||||||
| } | } | ||||||
| } // namespace
 | } // Anonymous namespace
 | ||||||
| 
 | 
 | ||||||
| u32 ShaderIR::DecodeMemory(NodeBlock& bb, u32 pc) { | u32 ShaderIR::DecodeMemory(NodeBlock& bb, u32 pc) { | ||||||
|     const Instruction instr = {program_code[pc]}; |     const Instruction instr = {program_code[pc]}; | ||||||
| @ -106,16 +106,17 @@ u32 ShaderIR::DecodeMemory(NodeBlock& bb, u32 pc) { | |||||||
|         } |         } | ||||||
|         break; |         break; | ||||||
|     } |     } | ||||||
|     case OpCode::Id::LD_L: { |     case OpCode::Id::LD_L: | ||||||
|         LOG_DEBUG(HW_GPU, "LD_L cache management mode: {}", |         LOG_DEBUG(HW_GPU, "LD_L cache management mode: {}", static_cast<u64>(instr.ld_l.unknown)); | ||||||
|                   static_cast<u64>(instr.ld_l.unknown.Value())); |         [[fallthrough]]; | ||||||
| 
 |     case OpCode::Id::LD_S: { | ||||||
|         const auto GetLmem = [&](s32 offset) { |         const auto GetMemory = [&](s32 offset) { | ||||||
|             ASSERT(offset % 4 == 0); |             ASSERT(offset % 4 == 0); | ||||||
|             const Node immediate_offset = Immediate(static_cast<s32>(instr.smem_imm) + offset); |             const Node immediate_offset = Immediate(static_cast<s32>(instr.smem_imm) + offset); | ||||||
|             const Node address = Operation(OperationCode::IAdd, NO_PRECISE, GetRegister(instr.gpr8), |             const Node address = Operation(OperationCode::IAdd, NO_PRECISE, GetRegister(instr.gpr8), | ||||||
|                                            immediate_offset); |                                            immediate_offset); | ||||||
|             return GetLocalMemory(address); |             return opcode->get().GetId() == OpCode::Id::LD_S ? GetSharedMemory(address) | ||||||
|  |                                                              : GetLocalMemory(address); | ||||||
|         }; |         }; | ||||||
| 
 | 
 | ||||||
|         switch (instr.ldst_sl.type.Value()) { |         switch (instr.ldst_sl.type.Value()) { | ||||||
| @ -135,14 +136,16 @@ u32 ShaderIR::DecodeMemory(NodeBlock& bb, u32 pc) { | |||||||
|                     return 0; |                     return 0; | ||||||
|                 } |                 } | ||||||
|             }(); |             }(); | ||||||
|             for (u32 i = 0; i < count; ++i) |             for (u32 i = 0; i < count; ++i) { | ||||||
|                 SetTemporary(bb, i, GetLmem(i * 4)); |                 SetTemporary(bb, i, GetMemory(i * 4)); | ||||||
|             for (u32 i = 0; i < count; ++i) |             } | ||||||
|  |             for (u32 i = 0; i < count; ++i) { | ||||||
|                 SetRegister(bb, instr.gpr0.Value() + i, GetTemporary(i)); |                 SetRegister(bb, instr.gpr0.Value() + i, GetTemporary(i)); | ||||||
|  |             } | ||||||
|             break; |             break; | ||||||
|         } |         } | ||||||
|         default: |         default: | ||||||
|             UNIMPLEMENTED_MSG("LD_L Unhandled type: {}", |             UNIMPLEMENTED_MSG("{} Unhandled type: {}", opcode->get().GetName(), | ||||||
|                               static_cast<u32>(instr.ldst_sl.type.Value())); |                               static_cast<u32>(instr.ldst_sl.type.Value())); | ||||||
|         } |         } | ||||||
|         break; |         break; | ||||||
| @ -209,27 +212,34 @@ u32 ShaderIR::DecodeMemory(NodeBlock& bb, u32 pc) { | |||||||
| 
 | 
 | ||||||
|         break; |         break; | ||||||
|     } |     } | ||||||
|     case OpCode::Id::ST_L: { |     case OpCode::Id::ST_L: | ||||||
|         LOG_DEBUG(HW_GPU, "ST_L cache management mode: {}", |         LOG_DEBUG(HW_GPU, "ST_L cache management mode: {}", | ||||||
|                   static_cast<u64>(instr.st_l.cache_management.Value())); |                   static_cast<u64>(instr.st_l.cache_management.Value())); | ||||||
| 
 |         [[fallthrough]]; | ||||||
|         const auto GetLmemAddr = [&](s32 offset) { |     case OpCode::Id::ST_S: { | ||||||
|  |         const auto GetAddress = [&](s32 offset) { | ||||||
|             ASSERT(offset % 4 == 0); |             ASSERT(offset % 4 == 0); | ||||||
|             const Node immediate = Immediate(static_cast<s32>(instr.smem_imm) + offset); |             const Node immediate = Immediate(static_cast<s32>(instr.smem_imm) + offset); | ||||||
|             return Operation(OperationCode::IAdd, NO_PRECISE, GetRegister(instr.gpr8), immediate); |             return Operation(OperationCode::IAdd, NO_PRECISE, GetRegister(instr.gpr8), immediate); | ||||||
|         }; |         }; | ||||||
| 
 | 
 | ||||||
|  |         const auto set_memory = opcode->get().GetId() == OpCode::Id::ST_L | ||||||
|  |                                     ? &ShaderIR::SetLocalMemory | ||||||
|  |                                     : &ShaderIR::SetSharedMemory; | ||||||
|  | 
 | ||||||
|         switch (instr.ldst_sl.type.Value()) { |         switch (instr.ldst_sl.type.Value()) { | ||||||
|         case Tegra::Shader::StoreType::Bits128: |         case Tegra::Shader::StoreType::Bits128: | ||||||
|             SetLocalMemory(bb, GetLmemAddr(12), GetRegister(instr.gpr0.Value() + 3)); |             (this->*set_memory)(bb, GetAddress(12), GetRegister(instr.gpr0.Value() + 3)); | ||||||
|             SetLocalMemory(bb, GetLmemAddr(8), GetRegister(instr.gpr0.Value() + 2)); |             (this->*set_memory)(bb, GetAddress(8), GetRegister(instr.gpr0.Value() + 2)); | ||||||
|  |             [[fallthrough]]; | ||||||
|         case Tegra::Shader::StoreType::Bits64: |         case Tegra::Shader::StoreType::Bits64: | ||||||
|             SetLocalMemory(bb, GetLmemAddr(4), GetRegister(instr.gpr0.Value() + 1)); |             (this->*set_memory)(bb, GetAddress(4), GetRegister(instr.gpr0.Value() + 1)); | ||||||
|  |             [[fallthrough]]; | ||||||
|         case Tegra::Shader::StoreType::Bits32: |         case Tegra::Shader::StoreType::Bits32: | ||||||
|             SetLocalMemory(bb, GetLmemAddr(0), GetRegister(instr.gpr0)); |             (this->*set_memory)(bb, GetAddress(0), GetRegister(instr.gpr0)); | ||||||
|             break; |             break; | ||||||
|         default: |         default: | ||||||
|             UNIMPLEMENTED_MSG("ST_L Unhandled type: {}", |             UNIMPLEMENTED_MSG("{} unhandled type: {}", opcode->get().GetName(), | ||||||
|                               static_cast<u32>(instr.ldst_sl.type.Value())); |                               static_cast<u32>(instr.ldst_sl.type.Value())); | ||||||
|         } |         } | ||||||
|         break; |         break; | ||||||
|  | |||||||
| @ -206,12 +206,13 @@ class PredicateNode; | |||||||
| class AbufNode; | class AbufNode; | ||||||
| class CbufNode; | class CbufNode; | ||||||
| class LmemNode; | class LmemNode; | ||||||
|  | class SmemNode; | ||||||
| class GmemNode; | class GmemNode; | ||||||
| class CommentNode; | class CommentNode; | ||||||
| 
 | 
 | ||||||
| using NodeData = | using NodeData = | ||||||
|     std::variant<OperationNode, ConditionalNode, GprNode, ImmediateNode, InternalFlagNode, |     std::variant<OperationNode, ConditionalNode, GprNode, ImmediateNode, InternalFlagNode, | ||||||
|                  PredicateNode, AbufNode, CbufNode, LmemNode, GmemNode, CommentNode>; |                  PredicateNode, AbufNode, CbufNode, LmemNode, SmemNode, GmemNode, CommentNode>; | ||||||
| using Node = std::shared_ptr<NodeData>; | using Node = std::shared_ptr<NodeData>; | ||||||
| using Node4 = std::array<Node, 4>; | using Node4 = std::array<Node, 4>; | ||||||
| using NodeBlock = std::vector<Node>; | using NodeBlock = std::vector<Node>; | ||||||
| @ -583,6 +584,19 @@ private: | |||||||
|     Node address; |     Node address; | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
|  | /// Shared memory node
 | ||||||
|  | class SmemNode final { | ||||||
|  | public: | ||||||
|  |     explicit SmemNode(Node address) : address{std::move(address)} {} | ||||||
|  | 
 | ||||||
|  |     const Node& GetAddress() const { | ||||||
|  |         return address; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  | private: | ||||||
|  |     Node address; | ||||||
|  | }; | ||||||
|  | 
 | ||||||
| /// Global memory node
 | /// Global memory node
 | ||||||
| class GmemNode final { | class GmemNode final { | ||||||
| public: | public: | ||||||
|  | |||||||
| @ -137,6 +137,10 @@ Node ShaderIR::GetLocalMemory(Node address) { | |||||||
|     return MakeNode<LmemNode>(std::move(address)); |     return MakeNode<LmemNode>(std::move(address)); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | Node ShaderIR::GetSharedMemory(Node address) { | ||||||
|  |     return MakeNode<SmemNode>(std::move(address)); | ||||||
|  | } | ||||||
|  | 
 | ||||||
| Node ShaderIR::GetTemporary(u32 id) { | Node ShaderIR::GetTemporary(u32 id) { | ||||||
|     return GetRegister(Register::ZeroIndex + 1 + id); |     return GetRegister(Register::ZeroIndex + 1 + id); | ||||||
| } | } | ||||||
| @ -378,6 +382,11 @@ void ShaderIR::SetLocalMemory(NodeBlock& bb, Node address, Node value) { | |||||||
|         Operation(OperationCode::Assign, GetLocalMemory(std::move(address)), std::move(value))); |         Operation(OperationCode::Assign, GetLocalMemory(std::move(address)), std::move(value))); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | void ShaderIR::SetSharedMemory(NodeBlock& bb, Node address, Node value) { | ||||||
|  |     bb.push_back( | ||||||
|  |         Operation(OperationCode::Assign, GetSharedMemory(std::move(address)), std::move(value))); | ||||||
|  | } | ||||||
|  | 
 | ||||||
| void ShaderIR::SetTemporary(NodeBlock& bb, u32 id, Node value) { | void ShaderIR::SetTemporary(NodeBlock& bb, u32 id, Node value) { | ||||||
|     SetRegister(bb, Register::ZeroIndex + 1 + id, std::move(value)); |     SetRegister(bb, Register::ZeroIndex + 1 + id, std::move(value)); | ||||||
| } | } | ||||||
|  | |||||||
| @ -208,6 +208,8 @@ private: | |||||||
|     Node GetInternalFlag(InternalFlag flag, bool negated = false); |     Node GetInternalFlag(InternalFlag flag, bool negated = false); | ||||||
|     /// Generates a node representing a local memory address
 |     /// Generates a node representing a local memory address
 | ||||||
|     Node GetLocalMemory(Node address); |     Node GetLocalMemory(Node address); | ||||||
|  |     /// Generates a node representing a shared memory address
 | ||||||
|  |     Node GetSharedMemory(Node address); | ||||||
|     /// Generates a temporary, internally it uses a post-RZ register
 |     /// Generates a temporary, internally it uses a post-RZ register
 | ||||||
|     Node GetTemporary(u32 id); |     Node GetTemporary(u32 id); | ||||||
| 
 | 
 | ||||||
| @ -217,8 +219,10 @@ private: | |||||||
|     void SetPredicate(NodeBlock& bb, u64 dest, Node src); |     void SetPredicate(NodeBlock& bb, u64 dest, Node src); | ||||||
|     /// Sets an internal flag. src value must be a bool-evaluated node
 |     /// Sets an internal flag. src value must be a bool-evaluated node
 | ||||||
|     void SetInternalFlag(NodeBlock& bb, InternalFlag flag, Node value); |     void SetInternalFlag(NodeBlock& bb, InternalFlag flag, Node value); | ||||||
|     /// Sets a local memory address. address and value must be a number-evaluated node
 |     /// Sets a local memory address with a value.
 | ||||||
|     void SetLocalMemory(NodeBlock& bb, Node address, Node value); |     void SetLocalMemory(NodeBlock& bb, Node address, Node value); | ||||||
|  |     /// Sets a shared memory address with a value.
 | ||||||
|  |     void SetSharedMemory(NodeBlock& bb, Node address, Node value); | ||||||
|     /// Sets a temporary. Internally it uses a post-RZ register
 |     /// Sets a temporary. Internally it uses a post-RZ register
 | ||||||
|     void SetTemporary(NodeBlock& bb, u32 id, Node value); |     void SetTemporary(NodeBlock& bb, u32 id, Node value); | ||||||
| 
 | 
 | ||||||
|  | |||||||
		Loading…
	
		Reference in New Issue
	
	Block a user
	 bunnei
						bunnei