Merge pull request #3032 from ReinUsesLisp/simplify-control-flow-brx
shader/control_flow: Abstract repeated code chunks in BRX tracking
This commit is contained in:
		
						commit
						b6ae48966d
					
				| @ -16,7 +16,9 @@ | ||||
| #include "video_core/shader/shader_ir.h" | ||||
| 
 | ||||
| namespace VideoCommon::Shader { | ||||
| 
 | ||||
| namespace { | ||||
| 
 | ||||
| using Tegra::Shader::Instruction; | ||||
| using Tegra::Shader::OpCode; | ||||
| 
 | ||||
| @ -68,15 +70,15 @@ struct CFGRebuildState { | ||||
|     const ProgramCode& program_code; | ||||
|     ConstBufferLocker& locker; | ||||
|     u32 start{}; | ||||
|     std::vector<BlockInfo> block_info{}; | ||||
|     std::list<u32> inspect_queries{}; | ||||
|     std::list<Query> queries{}; | ||||
|     std::unordered_map<u32, u32> registered{}; | ||||
|     std::set<u32> labels{}; | ||||
|     std::map<u32, u32> ssy_labels{}; | ||||
|     std::map<u32, u32> pbk_labels{}; | ||||
|     std::unordered_map<u32, BlockStack> stacks{}; | ||||
|     ASTManager* manager; | ||||
|     std::vector<BlockInfo> block_info; | ||||
|     std::list<u32> inspect_queries; | ||||
|     std::list<Query> queries; | ||||
|     std::unordered_map<u32, u32> registered; | ||||
|     std::set<u32> labels; | ||||
|     std::map<u32, u32> ssy_labels; | ||||
|     std::map<u32, u32> pbk_labels; | ||||
|     std::unordered_map<u32, BlockStack> stacks; | ||||
|     ASTManager* manager{}; | ||||
| }; | ||||
| 
 | ||||
| enum class BlockCollision : u32 { None, Found, Inside }; | ||||
| @ -109,7 +111,7 @@ BlockInfo& CreateBlockInfo(CFGRebuildState& state, u32 start, u32 end) { | ||||
| } | ||||
| 
 | ||||
| Pred GetPredicate(u32 index, bool negated) { | ||||
|     return static_cast<Pred>(index + (negated ? 8 : 0)); | ||||
|     return static_cast<Pred>(static_cast<u64>(index) + (negated ? 8ULL : 0ULL)); | ||||
| } | ||||
| 
 | ||||
| /**
 | ||||
| @ -136,15 +138,13 @@ struct BranchIndirectInfo { | ||||
|     s32 relative_position{}; | ||||
| }; | ||||
| 
 | ||||
| std::optional<BranchIndirectInfo> TrackBranchIndirectInfo(const CFGRebuildState& state, | ||||
|                                                           u32 start_address, u32 current_position) { | ||||
|     const u32 shader_start = state.start; | ||||
|     u32 pos = current_position; | ||||
|     BranchIndirectInfo result{}; | ||||
|     u64 track_register = 0; | ||||
// Location of a constant-buffer read: which buffer and at what offset.
// TrackLDC fills it from the cbuf36 fields of the matched LD_C instruction.
| struct BufferInfo { | ||||
// Value of instr.cbuf36.index for the tracked LD_C.
|     u32 index; | ||||
// Value of instr.cbuf36.GetOffset() for the tracked LD_C, narrowed to u32.
|     u32 offset; | ||||
| }; | ||||
| 
 | ||||
|     // Step 0 Get BRX Info
 | ||||
|     const Instruction instr = {state.program_code[pos]}; | ||||
| std::optional<std::pair<s32, u64>> GetBRXInfo(const CFGRebuildState& state, u32& pos) { | ||||
|     const Instruction instr = state.program_code[pos]; | ||||
|     const auto opcode = OpCode::Decode(instr); | ||||
|     if (opcode->get().GetId() != OpCode::Id::BRX) { | ||||
|         return std::nullopt; | ||||
| @ -152,86 +152,94 @@ std::optional<BranchIndirectInfo> TrackBranchIndirectInfo(const CFGRebuildState& | ||||
|     if (instr.brx.constant_buffer != 0) { | ||||
|         return std::nullopt; | ||||
|     } | ||||
|     track_register = instr.gpr8.Value(); | ||||
|     result.relative_position = instr.brx.GetBranchExtend(); | ||||
|     pos--; | ||||
|     bool found_track = false; | ||||
|     --pos; | ||||
|     return std::make_pair(instr.brx.GetBranchExtend(), instr.gpr8.Value()); | ||||
| } | ||||
| 
 | ||||
|     // Step 1 Track LDC
 | ||||
|     while (pos >= shader_start) { | ||||
|         if (IsSchedInstruction(pos, shader_start)) { | ||||
|             pos--; | ||||
|             continue; | ||||
|         } | ||||
|         const Instruction instr = {state.program_code[pos]}; | ||||
|         const auto opcode = OpCode::Decode(instr); | ||||
|         if (opcode->get().GetId() == OpCode::Id::LD_C) { | ||||
|             if (instr.gpr0.Value() == track_register && | ||||
|                 instr.ld_c.type.Value() == Tegra::Shader::UniformType::Single) { | ||||
|                 result.buffer = instr.cbuf36.index.Value(); | ||||
|                 result.offset = static_cast<u32>(instr.cbuf36.GetOffset()); | ||||
|                 track_register = instr.gpr8.Value(); | ||||
|                 pos--; | ||||
|                 found_track = true; | ||||
|                 break; | ||||
|             } | ||||
|         } | ||||
|         pos--; | ||||
|     } | ||||
| 
 | ||||
|     if (!found_track) { | ||||
|         return std::nullopt; | ||||
|     } | ||||
|     found_track = false; | ||||
| 
 | ||||
|     // Step 2 Track SHL
 | ||||
|     while (pos >= shader_start) { | ||||
|         if (IsSchedInstruction(pos, shader_start)) { | ||||
|             pos--; | ||||
| template <typename Result, typename TestCallable, typename PackCallable> | ||||
| // requires std::predicate<TestCallable, Instruction, const OpCode::Matcher&>
 | ||||
| // requires std::invocable<PackCallable, Instruction, const OpCode::Matcher&>
 | ||||
| std::optional<Result> TrackInstruction(const CFGRebuildState& state, u32& pos, TestCallable test, | ||||
|                                        PackCallable pack) { | ||||
|     for (; pos >= state.start; --pos) { | ||||
|         if (IsSchedInstruction(pos, state.start)) { | ||||
|             continue; | ||||
|         } | ||||
|         const Instruction instr = state.program_code[pos]; | ||||
|         const auto opcode = OpCode::Decode(instr); | ||||
|         if (opcode->get().GetId() == OpCode::Id::SHL_IMM) { | ||||
|             if (instr.gpr0.Value() == track_register) { | ||||
|                 track_register = instr.gpr8.Value(); | ||||
|                 pos--; | ||||
|                 found_track = true; | ||||
|                 break; | ||||
|             } | ||||
|         } | ||||
|         pos--; | ||||
|     } | ||||
| 
 | ||||
|     if (!found_track) { | ||||
|         return std::nullopt; | ||||
|     } | ||||
|     found_track = false; | ||||
| 
 | ||||
|     // Step 3 Track IMNMX
 | ||||
|     while (pos >= shader_start) { | ||||
|         if (IsSchedInstruction(pos, shader_start)) { | ||||
|             pos--; | ||||
|         if (!opcode) { | ||||
|             continue; | ||||
|         } | ||||
|         const Instruction instr = state.program_code[pos]; | ||||
|         const auto opcode = OpCode::Decode(instr); | ||||
|         if (opcode->get().GetId() == OpCode::Id::IMNMX_IMM) { | ||||
|             if (instr.gpr0.Value() == track_register) { | ||||
|                 track_register = instr.gpr8.Value(); | ||||
|                 result.entries = instr.alu.GetSignedImm20_20() + 1; | ||||
|                 pos--; | ||||
|                 found_track = true; | ||||
|                 break; | ||||
|             } | ||||
|         if (test(instr, opcode->get())) { | ||||
|             --pos; | ||||
|             return std::make_optional(pack(instr, opcode->get())); | ||||
|         } | ||||
|         pos--; | ||||
|     } | ||||
|     return std::nullopt; | ||||
| } | ||||
| 
 | ||||
|     if (!found_track) { | ||||
// Searches, via TrackInstruction, for the LD_C instruction that wrote the register consumed by BRX.
//
// state:                supplies the program code and the search bounds used by TrackInstruction.
// pos:                  in/out search cursor; passed by reference to TrackInstruction, which moves it.
// brx_tracked_register: register index the BRX reads (gpr8 of the BRX instruction).
//
// Returns a pair of {constant buffer location, next register to track}, or whatever empty
// optional TrackInstruction yields when no matching instruction is found.
| std::optional<std::pair<BufferInfo, u64>> TrackLDC(const CFGRebuildState& state, u32& pos, | ||||
|                                                    u64 brx_tracked_register) { | ||||
|     return TrackInstruction<std::pair<BufferInfo, u64>>( | ||||
|         state, pos, | ||||
// Test: an LD_C of type Single whose destination (gpr0) is the register BRX branches on.
|         [brx_tracked_register](auto instr, const auto& opcode) { | ||||
|             return opcode.GetId() == OpCode::Id::LD_C && | ||||
|                    instr.gpr0.Value() == brx_tracked_register && | ||||
|                    instr.ld_c.type.Value() == Tegra::Shader::UniformType::Single; | ||||
|         }, | ||||
// Pack: record the constant buffer index/offset and hand back gpr8 as the next register to
// follow. The opcode argument is not used here.
|         [](auto instr, const auto& opcode) { | ||||
|             const BufferInfo info = {static_cast<u32>(instr.cbuf36.index.Value()), | ||||
|                                      static_cast<u32>(instr.cbuf36.GetOffset())}; | ||||
|             return std::make_pair(info, instr.gpr8.Value()); | ||||
|         }); | ||||
| } | ||||
| 
 | ||||
// Searches, via TrackInstruction, for the SHL_IMM instruction that wrote the register used by
// the previously tracked LD_C.
//
// state:                supplies the program code and the search bounds used by TrackInstruction.
// pos:                  in/out search cursor; passed by reference to TrackInstruction.
// ldc_tracked_register: register index returned by TrackLDC (gpr8 of the LD_C).
//
// Returns gpr8 of the matched SHL_IMM (the next register to track), or an empty optional when
// TrackInstruction finds no match.
| std::optional<u64> TrackSHLRegister(const CFGRebuildState& state, u32& pos, | ||||
|                                     u64 ldc_tracked_register) { | ||||
|     return TrackInstruction<u64>(state, pos, | ||||
// Test: an SHL_IMM whose destination (gpr0) is the register the LD_C read from.
|                                  [ldc_tracked_register](auto instr, const auto& opcode) { | ||||
|                                      return opcode.GetId() == OpCode::Id::SHL_IMM && | ||||
|                                             instr.gpr0.Value() == ldc_tracked_register; | ||||
|                                  }, | ||||
// Pack: only the source register (gpr8) is needed by the caller.
|                                  [](auto instr, const auto&) { return instr.gpr8.Value(); }); | ||||
| } | ||||
| 
 | ||||
// Searches, via TrackInstruction, for the IMNMX_IMM instruction that wrote the register used by
// the previously tracked SHL_IMM, and extracts a count from its immediate.
//
// state:                supplies the program code and the search bounds used by TrackInstruction.
// pos:                  in/out search cursor; passed by reference to TrackInstruction.
// shl_tracked_register: register index returned by TrackSHLRegister (gpr8 of the SHL_IMM).
//
// Returns the instruction's signed 20-bit immediate plus one, converted to u32; the caller
// (TrackBranchIndirectInfo) stores this as the number of entries of the indirect branch.
// Returns an empty optional when TrackInstruction finds no match.
| std::optional<u32> TrackIMNMXValue(const CFGRebuildState& state, u32& pos, | ||||
|                                    u64 shl_tracked_register) { | ||||
|     return TrackInstruction<u32>(state, pos, | ||||
// Test: an IMNMX_IMM whose destination (gpr0) is the register the SHL_IMM read from.
|                                  [shl_tracked_register](auto instr, const auto& opcode) { | ||||
|                                      return opcode.GetId() == OpCode::Id::IMNMX_IMM && | ||||
|                                             instr.gpr0.Value() == shl_tracked_register; | ||||
|                                  }, | ||||
// Pack: immediate + 1. NOTE(review): presumably the immediate is the clamp's maximum index,
// so +1 gives the table size. Confirm against the shader ISA.
|                                  [](auto instr, const auto&) { | ||||
|                                      return static_cast<u32>(instr.alu.GetSignedImm20_20() + 1); | ||||
|                                  }); | ||||
| } | ||||
| 
 | ||||
| std::optional<BranchIndirectInfo> TrackBranchIndirectInfo(const CFGRebuildState& state, u32 pos) { | ||||
|     const auto brx_info = GetBRXInfo(state, pos); | ||||
|     if (!brx_info) { | ||||
|         return std::nullopt; | ||||
|     } | ||||
|     return result; | ||||
|     const auto [relative_position, brx_tracked_register] = *brx_info; | ||||
| 
 | ||||
|     const auto ldc_info = TrackLDC(state, pos, brx_tracked_register); | ||||
|     if (!ldc_info) { | ||||
|         return std::nullopt; | ||||
|     } | ||||
|     const auto [buffer_info, ldc_tracked_register] = *ldc_info; | ||||
| 
 | ||||
|     const auto shl_tracked_register = TrackSHLRegister(state, pos, ldc_tracked_register); | ||||
|     if (!shl_tracked_register) { | ||||
|         return std::nullopt; | ||||
|     } | ||||
| 
 | ||||
|     const auto entries = TrackIMNMXValue(state, pos, *shl_tracked_register); | ||||
|     if (!entries) { | ||||
|         return std::nullopt; | ||||
|     } | ||||
| 
 | ||||
|     return BranchIndirectInfo{buffer_info.index, buffer_info.offset, *entries, relative_position}; | ||||
| } | ||||
| 
 | ||||
| std::pair<ParseResult, ParseInfo> ParseCode(CFGRebuildState& state, u32 address) { | ||||
| @ -420,30 +428,30 @@ std::pair<ParseResult, ParseInfo> ParseCode(CFGRebuildState& state, u32 address) | ||||
|             break; | ||||
|         } | ||||
|         case OpCode::Id::BRX: { | ||||
|             auto tmp = TrackBranchIndirectInfo(state, address, offset); | ||||
|             if (tmp) { | ||||
|                 auto result = *tmp; | ||||
|                 std::vector<CaseBranch> branches{}; | ||||
|                 s32 pc_target = offset + result.relative_position; | ||||
|                 for (u32 i = 0; i < result.entries; i++) { | ||||
|                     auto k = state.locker.ObtainKey(result.buffer, result.offset + i * 4); | ||||
|                     if (!k) { | ||||
|                         return {ParseResult::AbnormalFlow, parse_info}; | ||||
|                     } | ||||
|                     u32 value = *k; | ||||
|                     u32 target = static_cast<u32>((value >> 3) + pc_target); | ||||
|                     insert_label(state, target); | ||||
|                     branches.emplace_back(value, target); | ||||
|                 } | ||||
|                 parse_info.end_address = offset; | ||||
|                 parse_info.branch_info = MakeBranchInfo<MultiBranch>( | ||||
|                     static_cast<u32>(instr.gpr8.Value()), std::move(branches)); | ||||
| 
 | ||||
|                 return {ParseResult::ControlCaught, parse_info}; | ||||
|             } else { | ||||
|             const auto tmp = TrackBranchIndirectInfo(state, offset); | ||||
|             if (!tmp) { | ||||
|                 LOG_WARNING(HW_GPU, "BRX Track Unsuccesful"); | ||||
|                 return {ParseResult::AbnormalFlow, parse_info}; | ||||
|             } | ||||
|             return {ParseResult::AbnormalFlow, parse_info}; | ||||
| 
 | ||||
|             const auto result = *tmp; | ||||
|             const s32 pc_target = offset + result.relative_position; | ||||
|             std::vector<CaseBranch> branches; | ||||
|             for (u32 i = 0; i < result.entries; i++) { | ||||
|                 auto key = state.locker.ObtainKey(result.buffer, result.offset + i * 4); | ||||
|                 if (!key) { | ||||
|                     return {ParseResult::AbnormalFlow, parse_info}; | ||||
|                 } | ||||
|                 u32 value = *key; | ||||
|                 u32 target = static_cast<u32>((value >> 3) + pc_target); | ||||
|                 insert_label(state, target); | ||||
|                 branches.emplace_back(value, target); | ||||
|             } | ||||
|             parse_info.end_address = offset; | ||||
|             parse_info.branch_info = MakeBranchInfo<MultiBranch>( | ||||
|                 static_cast<u32>(instr.gpr8.Value()), std::move(branches)); | ||||
| 
 | ||||
|             return {ParseResult::ControlCaught, parse_info}; | ||||
|         } | ||||
|         default: | ||||
|             break; | ||||
|  | ||||
		Loading…
	
		Reference in New Issue
	
	Block a user
	 bunnei
						bunnei