Merge pull request #3979 from ReinUsesLisp/thread-group
shader/other: Implement thread comparisons (NV_shader_thread_group)
This commit is contained in:
		
						commit
						487dd05170
					
				@ -2309,6 +2309,18 @@ private:
 | 
			
		||||
        return {"gl_SubGroupInvocationARB", Type::Uint};
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    template <const std::string_view& comparison>
 | 
			
		||||
    Expression ThreadMask(Operation) {
 | 
			
		||||
        if (device.HasWarpIntrinsics()) {
 | 
			
		||||
            return {fmt::format("gl_Thread{}MaskNV", comparison), Type::Uint};
 | 
			
		||||
        }
 | 
			
		||||
        if (device.HasShaderBallot()) {
 | 
			
		||||
            return {fmt::format("uint(gl_SubGroup{}MaskARB)", comparison), Type::Uint};
 | 
			
		||||
        }
 | 
			
		||||
        LOG_ERROR(Render_OpenGL, "Thread mask intrinsics are required by the shader");
 | 
			
		||||
        return {"0U", Type::Uint};
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    Expression ShuffleIndexed(Operation operation) {
 | 
			
		||||
        std::string value = VisitOperand(operation, 0).AsFloat();
 | 
			
		||||
 | 
			
		||||
@ -2337,6 +2349,12 @@ private:
 | 
			
		||||
        static constexpr std::string_view NotEqual = "!=";
 | 
			
		||||
        static constexpr std::string_view GreaterEqual = ">=";
 | 
			
		||||
 | 
			
		||||
        static constexpr std::string_view Eq = "Eq";
 | 
			
		||||
        static constexpr std::string_view Ge = "Ge";
 | 
			
		||||
        static constexpr std::string_view Gt = "Gt";
 | 
			
		||||
        static constexpr std::string_view Le = "Le";
 | 
			
		||||
        static constexpr std::string_view Lt = "Lt";
 | 
			
		||||
 | 
			
		||||
        static constexpr std::string_view Add = "Add";
 | 
			
		||||
        static constexpr std::string_view Min = "Min";
 | 
			
		||||
        static constexpr std::string_view Max = "Max";
 | 
			
		||||
@ -2554,6 +2572,11 @@ private:
 | 
			
		||||
        &GLSLDecompiler::VoteEqual,
 | 
			
		||||
 | 
			
		||||
        &GLSLDecompiler::ThreadId,
 | 
			
		||||
        &GLSLDecompiler::ThreadMask<Func::Eq>,
 | 
			
		||||
        &GLSLDecompiler::ThreadMask<Func::Ge>,
 | 
			
		||||
        &GLSLDecompiler::ThreadMask<Func::Gt>,
 | 
			
		||||
        &GLSLDecompiler::ThreadMask<Func::Le>,
 | 
			
		||||
        &GLSLDecompiler::ThreadMask<Func::Lt>,
 | 
			
		||||
        &GLSLDecompiler::ShuffleIndexed,
 | 
			
		||||
 | 
			
		||||
        &GLSLDecompiler::MemoryBarrierGL,
 | 
			
		||||
 | 
			
		||||
@ -515,6 +515,16 @@ private:
 | 
			
		||||
    void DeclareCommon() {
 | 
			
		||||
        thread_id =
 | 
			
		||||
            DeclareInputBuiltIn(spv::BuiltIn::SubgroupLocalInvocationId, t_in_uint, "thread_id");
 | 
			
		||||
        thread_masks[0] =
 | 
			
		||||
            DeclareInputBuiltIn(spv::BuiltIn::SubgroupEqMask, t_in_uint4, "thread_eq_mask");
 | 
			
		||||
        thread_masks[1] =
 | 
			
		||||
            DeclareInputBuiltIn(spv::BuiltIn::SubgroupGeMask, t_in_uint4, "thread_ge_mask");
 | 
			
		||||
        thread_masks[2] =
 | 
			
		||||
            DeclareInputBuiltIn(spv::BuiltIn::SubgroupGtMask, t_in_uint4, "thread_gt_mask");
 | 
			
		||||
        thread_masks[3] =
 | 
			
		||||
            DeclareInputBuiltIn(spv::BuiltIn::SubgroupLeMask, t_in_uint4, "thread_le_mask");
 | 
			
		||||
        thread_masks[4] =
 | 
			
		||||
            DeclareInputBuiltIn(spv::BuiltIn::SubgroupLtMask, t_in_uint4, "thread_lt_mask");
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    void DeclareVertex() {
 | 
			
		||||
@ -2175,6 +2185,13 @@ private:
 | 
			
		||||
        return {OpLoad(t_uint, thread_id), Type::Uint};
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    template <std::size_t index>
 | 
			
		||||
    Expression ThreadMask(Operation) {
 | 
			
		||||
        // TODO(Rodrigo): Handle devices with different warp sizes
 | 
			
		||||
        const Id mask = thread_masks[index];
 | 
			
		||||
        return {OpLoad(t_uint, AccessElement(t_in_uint, mask, 0)), Type::Uint};
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    Expression ShuffleIndexed(Operation operation) {
 | 
			
		||||
        const Id value = AsFloat(Visit(operation[0]));
 | 
			
		||||
        const Id index = AsUint(Visit(operation[1]));
 | 
			
		||||
@ -2639,6 +2656,11 @@ private:
 | 
			
		||||
        &SPIRVDecompiler::Vote<&Module::OpSubgroupAllEqualKHR>,
 | 
			
		||||
 | 
			
		||||
        &SPIRVDecompiler::ThreadId,
 | 
			
		||||
        &SPIRVDecompiler::ThreadMask<0>, // Eq
 | 
			
		||||
        &SPIRVDecompiler::ThreadMask<1>, // Ge
 | 
			
		||||
        &SPIRVDecompiler::ThreadMask<2>, // Gt
 | 
			
		||||
        &SPIRVDecompiler::ThreadMask<3>, // Le
 | 
			
		||||
        &SPIRVDecompiler::ThreadMask<4>, // Lt
 | 
			
		||||
        &SPIRVDecompiler::ShuffleIndexed,
 | 
			
		||||
 | 
			
		||||
        &SPIRVDecompiler::MemoryBarrierGL,
 | 
			
		||||
@ -2763,6 +2785,7 @@ private:
 | 
			
		||||
    Id workgroup_id{};
 | 
			
		||||
    Id local_invocation_id{};
 | 
			
		||||
    Id thread_id{};
 | 
			
		||||
    std::array<Id, 5> thread_masks{}; // eq, ge, gt, le, lt
 | 
			
		||||
 | 
			
		||||
    VertexIndices in_indices;
 | 
			
		||||
    VertexIndices out_indices;
 | 
			
		||||
 | 
			
		||||
@ -109,6 +109,27 @@ u32 ShaderIR::DecodeOther(NodeBlock& bb, u32 pc) {
 | 
			
		||||
                return Operation(OperationCode::WorkGroupIdY);
 | 
			
		||||
            case SystemVariable::CtaIdZ:
 | 
			
		||||
                return Operation(OperationCode::WorkGroupIdZ);
 | 
			
		||||
            case SystemVariable::EqMask:
 | 
			
		||||
            case SystemVariable::LtMask:
 | 
			
		||||
            case SystemVariable::LeMask:
 | 
			
		||||
            case SystemVariable::GtMask:
 | 
			
		||||
            case SystemVariable::GeMask:
 | 
			
		||||
                uses_warps = true;
 | 
			
		||||
                switch (instr.sys20) {
 | 
			
		||||
                case SystemVariable::EqMask:
 | 
			
		||||
                    return Operation(OperationCode::ThreadEqMask);
 | 
			
		||||
                case SystemVariable::LtMask:
 | 
			
		||||
                    return Operation(OperationCode::ThreadLtMask);
 | 
			
		||||
                case SystemVariable::LeMask:
 | 
			
		||||
                    return Operation(OperationCode::ThreadLeMask);
 | 
			
		||||
                case SystemVariable::GtMask:
 | 
			
		||||
                    return Operation(OperationCode::ThreadGtMask);
 | 
			
		||||
                case SystemVariable::GeMask:
 | 
			
		||||
                    return Operation(OperationCode::ThreadGeMask);
 | 
			
		||||
                default:
 | 
			
		||||
                    UNREACHABLE();
 | 
			
		||||
                    return Immediate(0u);
 | 
			
		||||
                }
 | 
			
		||||
            default:
 | 
			
		||||
                UNIMPLEMENTED_MSG("Unhandled system move: {}",
 | 
			
		||||
                                  static_cast<u32>(instr.sys20.Value()));
 | 
			
		||||
 | 
			
		||||
@ -226,6 +226,11 @@ enum class OperationCode {
 | 
			
		||||
    VoteEqual,    /// (bool) -> bool
 | 
			
		||||
 | 
			
		||||
    ThreadId,       /// () -> uint
 | 
			
		||||
    ThreadEqMask,   /// () -> uint
 | 
			
		||||
    ThreadGeMask,   /// () -> uint
 | 
			
		||||
    ThreadGtMask,   /// () -> uint
 | 
			
		||||
    ThreadLeMask,   /// () -> uint
 | 
			
		||||
    ThreadLtMask,   /// () -> uint
 | 
			
		||||
    ShuffleIndexed, /// (uint value, uint index) -> uint
 | 
			
		||||
 | 
			
		||||
    MemoryBarrierGL, /// () -> void
 | 
			
		||||
 | 
			
		||||
		Loading…
	
		Reference in New Issue
	
	Block a user