From a94af8ea62abb481b356813be2a3dd7aabf69c7f Mon Sep 17 00:00:00 2001 From: Wunk Date: Tue, 11 Jul 2023 09:21:37 -0700 Subject: [PATCH] shader_jit: Add optimizations up to `x86-64-v4` (#6668) --- .../shader/shader_jit_x64_compiler.cpp | 220 +++++++++++++----- 1 file changed, 157 insertions(+), 63 deletions(-) diff --git a/src/video_core/shader/shader_jit_x64_compiler.cpp b/src/video_core/shader/shader_jit_x64_compiler.cpp index 85681ab83f..6a30d4e23e 100644 --- a/src/video_core/shader/shader_jit_x64_compiler.cpp +++ b/src/video_core/shader/shader_jit_x64_compiler.cpp @@ -338,15 +338,39 @@ void JitShader::Compile_SanitizedMul(Xmm src1, Xmm src2, Xmm scratch) { // where neither source was, this NaN was generated by a 0 * inf multiplication, and so the // result should be transformed to 0 to match PICA fp rules. + if (host_caps.has(Cpu::tAVX512F | Cpu::tAVX512VL | Cpu::tAVX512DQ)) { + vmulps(scratch, src1, src2); + + // Mask of any NaN values found in the result + const Xbyak::Opmask zero_mask = k1; + vcmpunordps(zero_mask, scratch, scratch); + + // Mask of any non-NaN inputs producing NaN results + vcmpordps(zero_mask | zero_mask, src1, src2); + + knotb(zero_mask, zero_mask); + vmovaps(src1 | zero_mask | T_z, scratch); + + return; + } + // Set scratch to mask of (src1 != NaN and src2 != NaN) - movaps(scratch, src1); - cmpordps(scratch, src2); + if (host_caps.has(Cpu::tAVX)) { + vcmpordps(scratch, src1, src2); + } else { + movaps(scratch, src1); + cmpordps(scratch, src2); + } mulps(src1, src2); // Set src2 to mask of (result == NaN) - movaps(src2, src1); - cmpunordps(src2, src2); + if (host_caps.has(Cpu::tAVX)) { + vcmpunordps(src2, src2, src1); + } else { + movaps(src2, src1); + cmpunordps(src2, src2); + } // Clear components where scratch != src2 (i.e. 
if result is NaN where neither source was NaN) xorps(scratch, src2); @@ -406,13 +430,20 @@ void JitShader::Compile_DP3(Instruction instr) { Compile_SanitizedMul(SRC1, SRC2, SCRATCH); - movaps(SRC2, SRC1); - shufps(SRC2, SRC2, _MM_SHUFFLE(1, 1, 1, 1)); + if (host_caps.has(Cpu::tAVX)) { + vshufps(SRC3, SRC1, SRC1, _MM_SHUFFLE(2, 2, 2, 2)); + vshufps(SRC2, SRC1, SRC1, _MM_SHUFFLE(1, 1, 1, 1)); + vshufps(SRC1, SRC1, SRC1, _MM_SHUFFLE(0, 0, 0, 0)); + } else { + movaps(SRC2, SRC1); + shufps(SRC2, SRC2, _MM_SHUFFLE(1, 1, 1, 1)); - movaps(SRC3, SRC1); - shufps(SRC3, SRC3, _MM_SHUFFLE(2, 2, 2, 2)); + movaps(SRC3, SRC1); + shufps(SRC3, SRC3, _MM_SHUFFLE(2, 2, 2, 2)); + + shufps(SRC1, SRC1, _MM_SHUFFLE(0, 0, 0, 0)); + } - shufps(SRC1, SRC1, _MM_SHUFFLE(0, 0, 0, 0)); addps(SRC1, SRC2); addps(SRC1, SRC3); @@ -589,9 +620,15 @@ void JitShader::Compile_MOV(Instruction instr) { void JitShader::Compile_RCP(Instruction instr) { Compile_SwizzleSrc(instr, 1, instr.common.src1, SRC1); - // TODO(bunnei): RCPSS is a pretty rough approximation, this might cause problems if Pica - // performs this operation more accurately. This should be checked on hardware. - rcpss(SRC1, SRC1); + if (host_caps.has(Cpu::tAVX512F | Cpu::tAVX512VL)) { + // Accurate to 14 bits of precision rather than 12 bits of rcpss + vrcp14ss(SRC1, SRC1, SRC1); + } else { + // TODO(bunnei): RCPSS is a pretty rough approximation, this might cause problems if Pica + // performs this operation more accurately. This should be checked on hardware. + rcpss(SRC1, SRC1); + } + shufps(SRC1, SRC1, _MM_SHUFFLE(0, 0, 0, 0)); // XYWZ -> XXXX Compile_DestEnable(instr, SRC1); @@ -600,9 +637,15 @@ void JitShader::Compile_RCP(Instruction instr) { void JitShader::Compile_RSQ(Instruction instr) { Compile_SwizzleSrc(instr, 1, instr.common.src1, SRC1); - // TODO(bunnei): RSQRTSS is a pretty rough approximation, this might cause problems if Pica - // performs this operation more accurately. This should be checked on hardware. 
- rsqrtss(SRC1, SRC1); + if (host_caps.has(Cpu::tAVX512F | Cpu::tAVX512VL)) { + // Accurate to 14 bits of precision rather than 12 bits of rsqrtss + vrsqrt14ss(SRC1, SRC1, SRC1); + } else { + // TODO(bunnei): RSQRTSS is a pretty rough approximation, this might cause problems if Pica + // performs this operation more accurately. This should be checked on hardware. + rsqrtss(SRC1, SRC1); + } + shufps(SRC1, SRC1, _MM_SHUFFLE(0, 0, 0, 0)); // XYWZ -> XXXX Compile_DestEnable(instr, SRC1); @@ -1050,32 +1093,47 @@ Xbyak::Label JitShader::CompilePrelude_Log2() { jp(input_is_nan); jae(input_out_of_range); - // Split input - movd(eax, SRC1); - mov(edx, eax); - and_(eax, 0x7f800000); - and_(edx, 0x007fffff); - movss(SCRATCH, xword[rip + c0]); // Preload c0. - or_(edx, 0x3f800000); - movd(SRC1, edx); - // SRC1 now contains the mantissa of the input. - mulss(SCRATCH, SRC1); - shr(eax, 23); - sub(eax, 0x7f); - cvtsi2ss(SCRATCH2, eax); - // SCRATCH2 now contains the exponent of the input. + // Split input: SRC1=MANT[1,2) SCRATCH2=Exponent + if (host_caps.has(Cpu::tAVX512F | Cpu::tAVX512VL)) { + vgetexpss(SCRATCH2, SRC1, SRC1); + vgetmantss(SRC1, SRC1, SRC1, 0x0'0); + } else { + movd(eax, SRC1); + mov(edx, eax); + and_(eax, 0x7f800000); + and_(edx, 0x007fffff); + or_(edx, 0x3f800000); + movd(SRC1, edx); + // SRC1 now contains the mantissa of the input. + shr(eax, 23); + sub(eax, 0x7f); + cvtsi2ss(SCRATCH2, eax); + // SCRATCH2 now contains the exponent of the input. 
+ } + + movss(SCRATCH, xword[rip + c0]); // Complete computation of polynomial - addss(SCRATCH, xword[rip + c1]); - mulss(SCRATCH, SRC1); - addss(SCRATCH, xword[rip + c2]); - mulss(SCRATCH, SRC1); - addss(SCRATCH, xword[rip + c3]); - mulss(SCRATCH, SRC1); - subss(SRC1, ONE); - addss(SCRATCH, xword[rip + c4]); - mulss(SCRATCH, SRC1); - addss(SCRATCH2, SCRATCH); + if (host_caps.has(Cpu::tFMA)) { + vfmadd213ss(SCRATCH, SRC1, xword[rip + c1]); + vfmadd213ss(SCRATCH, SRC1, xword[rip + c2]); + vfmadd213ss(SCRATCH, SRC1, xword[rip + c3]); + vfmadd213ss(SCRATCH, SRC1, xword[rip + c4]); + subss(SRC1, ONE); + vfmadd231ss(SCRATCH2, SCRATCH, SRC1); + } else { + mulss(SCRATCH, SRC1); + addss(SCRATCH, xword[rip + c1]); + mulss(SCRATCH, SRC1); + addss(SCRATCH, xword[rip + c2]); + mulss(SCRATCH, SRC1); + addss(SCRATCH, xword[rip + c3]); + mulss(SCRATCH, SRC1); + subss(SRC1, ONE); + addss(SCRATCH, xword[rip + c4]); + mulss(SCRATCH, SRC1); + addss(SCRATCH2, SCRATCH); + } // Duplicate result across vector xorps(SRC1, SRC1); // break dependency chain @@ -1122,33 +1180,69 @@ Xbyak::Label JitShader::CompilePrelude_Exp2() { // Handle edge cases ucomiss(SRC1, SRC1); jp(ret_label); - // Clamp to maximum range since we shift the value directly into the exponent. - minss(SRC1, xword[rip + input_max]); - maxss(SRC1, xword[rip + input_min]); - // Decompose input - movss(SCRATCH, SRC1); - movss(SCRATCH2, xword[rip + c0]); // Preload c0. - subss(SCRATCH, xword[rip + half]); - cvtss2si(eax, SCRATCH); - cvtsi2ss(SCRATCH, eax); - // SCRATCH now contains input rounded to the nearest integer. - add(eax, 0x7f); - subss(SRC1, SCRATCH); - // SRC1 contains input - round(input), which is in [-0.5, 0.5). - mulss(SCRATCH2, SRC1); - shl(eax, 23); - movd(SCRATCH, eax); - // SCRATCH contains 2^(round(input)). 
+ // Decompose input: + // SCRATCH=2^round(input) + // SRC1=input-round(input) [-0.5, 0.5) + if (host_caps.has(Cpu::tAVX512F | Cpu::tAVX512VL)) { + // input - 0.5 + vsubss(SCRATCH, SRC1, xword[rip + half]); + + // trunc(input - 0.5) + vrndscaless(SCRATCH2, SCRATCH, SCRATCH, _MM_FROUND_TRUNC); + + // SCRATCH = 1 * 2^(trunc(input - 0.5)) + vscalefss(SCRATCH, ONE, SCRATCH2); + + // SRC1 = input-trunc(input - 0.5) + vsubss(SRC1, SRC1, SCRATCH2); + } else { + // Clamp to maximum range since we shift the value directly into the exponent. + minss(SRC1, xword[rip + input_max]); + maxss(SRC1, xword[rip + input_min]); + + if (host_caps.has(Cpu::tAVX)) { + vsubss(SCRATCH, SRC1, xword[rip + half]); + } else { + movss(SCRATCH, SRC1); + subss(SCRATCH, xword[rip + half]); + } + + if (host_caps.has(Cpu::tSSE41)) { + roundss(SCRATCH, SCRATCH, _MM_FROUND_TRUNC); + cvtss2si(eax, SCRATCH); + } else { + cvtss2si(eax, SCRATCH); + cvtsi2ss(SCRATCH, eax); + } + // SCRATCH now contains input rounded to the nearest integer. + add(eax, 0x7f); + subss(SRC1, SCRATCH); + // SRC1 contains input - round(input), which is in [-0.5, 0.5). + shl(eax, 23); + movd(SCRATCH, eax); + // SCRATCH contains 2^(round(input)). + } // Complete computation of polynomial. 
- addss(SCRATCH2, xword[rip + c1]); - mulss(SCRATCH2, SRC1); - addss(SCRATCH2, xword[rip + c2]); - mulss(SCRATCH2, SRC1); - addss(SCRATCH2, xword[rip + c3]); - mulss(SRC1, SCRATCH2); - addss(SRC1, xword[rip + c4]); + movss(SCRATCH2, xword[rip + c0]); + + if (host_caps.has(Cpu::tFMA)) { + vfmadd213ss(SCRATCH2, SRC1, xword[rip + c1]); + vfmadd213ss(SCRATCH2, SRC1, xword[rip + c2]); + vfmadd213ss(SCRATCH2, SRC1, xword[rip + c3]); + vfmadd213ss(SRC1, SCRATCH2, xword[rip + c4]); + } else { + mulss(SCRATCH2, SRC1); + addss(SCRATCH2, xword[rip + c1]); + mulss(SCRATCH2, SRC1); + addss(SCRATCH2, xword[rip + c2]); + mulss(SCRATCH2, SRC1); + addss(SCRATCH2, xword[rip + c3]); + mulss(SRC1, SCRATCH2); + addss(SRC1, xword[rip + c4]); + } + mulss(SRC1, SCRATCH); // Duplicate result across vector