From 6c512f4bffde6bd8e4dbc74ed27cc84cd7fffadb Mon Sep 17 00:00:00 2001
From: ameerj <52414509+ameerj@users.noreply.github.com>
Date: Wed, 14 Apr 2021 00:32:18 -0400
Subject: [PATCH] spirv: Implement alpha test

---
 .../backend/spirv/emit_spirv_special.cpp      | 45 +++++++++++++++++++
 src/shader_recompiler/profile.h               | 15 ++++++-
 .../renderer_vulkan/vk_pipeline_cache.cpp     | 36 +++++++++++++++
 3 files changed, 95 insertions(+), 1 deletion(-)

diff --git a/src/shader_recompiler/backend/spirv/emit_spirv_special.cpp b/src/shader_recompiler/backend/spirv/emit_spirv_special.cpp
index 7af29e4dd7..8bb94f5461 100644
--- a/src/shader_recompiler/backend/spirv/emit_spirv_special.cpp
+++ b/src/shader_recompiler/backend/spirv/emit_spirv_special.cpp
@@ -37,6 +37,48 @@ Id DefaultVarying(EmitContext& ctx, u32 num_components, u32 element, Id zero, Id
     }
     throw InvalidArgument("Bad element");
 }
+
+Id ComparisonFunction(EmitContext& ctx, CompareFunction comparison, Id operand_1, Id operand_2) {
+    switch (comparison) {
+    case CompareFunction::Never:
+        return ctx.false_value;
+    case CompareFunction::Less:
+        return ctx.OpFOrdLessThan(ctx.U1, operand_1, operand_2);
+    case CompareFunction::Equal:
+        return ctx.OpFOrdEqual(ctx.U1, operand_1, operand_2);
+    case CompareFunction::LessThanEqual:
+        return ctx.OpFOrdLessThanEqual(ctx.U1, operand_1, operand_2);
+    case CompareFunction::Greater:
+        return ctx.OpFOrdGreaterThan(ctx.U1, operand_1, operand_2);
+    case CompareFunction::NotEqual:
+        return ctx.OpFOrdNotEqual(ctx.U1, operand_1, operand_2);
+    case CompareFunction::GreaterThanEqual:
+        return ctx.OpFOrdGreaterThanEqual(ctx.U1, operand_1, operand_2);
+    case CompareFunction::Always:
+        return ctx.true_value;
+    }
+    throw InvalidArgument("Comparison function {}", comparison);
+}
+
+void AlphaTest(EmitContext& ctx) {
+    const auto comparison{*ctx.profile.alpha_test_func};
+    if (comparison == CompareFunction::Always) {
+        return;
+    }
+    const Id type{ctx.F32[1]};
+    const Id rt0_color{ctx.OpLoad(ctx.F32[4], ctx.frag_color[0])};
+    const Id alpha{ctx.OpCompositeExtract(type, rt0_color, 3u)};
+
+    const Id true_label{ctx.OpLabel()};
+    const Id discard_label{ctx.OpLabel()};
+    const Id alpha_reference{ctx.Constant(ctx.F32[1], ctx.profile.alpha_test_reference)};
+    const Id condition{ComparisonFunction(ctx, comparison, alpha, alpha_reference)};
+
+    ctx.OpBranchConditional(condition, true_label, discard_label);
+    ctx.AddLabel(discard_label);
+    ctx.OpKill();
+    ctx.AddLabel(true_label);
+}
 } // Anonymous namespace
 
 void EmitPrologue(EmitContext& ctx) {
@@ -68,6 +110,9 @@ void EmitEpilogue(EmitContext& ctx) {
     if (ctx.stage == Stage::VertexB && ctx.profile.convert_depth_mode) {
         ConvertDepthMode(ctx);
     }
+    if (ctx.stage == Stage::Fragment) {
+        AlphaTest(ctx);
+    }
 }
 
 void EmitEmitVertex(EmitContext& ctx, const IR::Value& stream) {
diff --git a/src/shader_recompiler/profile.h b/src/shader_recompiler/profile.h
index 5ecae71b95..c26017d75f 100644
--- a/src/shader_recompiler/profile.h
+++ b/src/shader_recompiler/profile.h
@@ -5,8 +5,8 @@
 #pragma once
 
 #include <array>
-#include <vector>
 #include <optional>
+#include <vector>
 
 #include "common/common_types.h"
 
@@ -27,6 +27,17 @@ enum class InputTopology {
     TrianglesAdjacency,
 };
 
+enum class CompareFunction {
+    Never,
+    Less,
+    Equal,
+    LessThanEqual,
+    Greater,
+    NotEqual,
+    GreaterThanEqual,
+    Always,
+};
+
 struct TransformFeedbackVarying {
     u32 buffer{};
     u32 stride{};
@@ -66,6 +77,8 @@ struct Profile {
     InputTopology input_topology{};
 
     std::optional<float> fixed_state_point_size;
+    std::optional<CompareFunction> alpha_test_func;
+    float alpha_test_reference{};
 
     std::vector<TransformFeedbackVarying> xfb_varyings;
 };
diff --git a/src/video_core/renderer_vulkan/vk_pipeline_cache.cpp b/src/video_core/renderer_vulkan/vk_pipeline_cache.cpp
index de52d0f306..80f196d0e7 100644
--- a/src/video_core/renderer_vulkan/vk_pipeline_cache.cpp
+++ b/src/video_core/renderer_vulkan/vk_pipeline_cache.cpp
@@ -492,6 +492,37 @@ private:
     u32 read_lowest{};
     u32 read_highest{};
 };
+
+Shader::CompareFunction MaxwellToCompareFunction(Maxwell::ComparisonOp comparison) {
+    switch (comparison) {
+    case Maxwell::ComparisonOp::Never:
+    case Maxwell::ComparisonOp::NeverOld:
+        return Shader::CompareFunction::Never;
+    case Maxwell::ComparisonOp::Less:
+    case Maxwell::ComparisonOp::LessOld:
+        return Shader::CompareFunction::Less;
+    case Maxwell::ComparisonOp::Equal:
+    case Maxwell::ComparisonOp::EqualOld:
+        return Shader::CompareFunction::Equal;
+    case Maxwell::ComparisonOp::LessEqual:
+    case Maxwell::ComparisonOp::LessEqualOld:
+        return Shader::CompareFunction::LessThanEqual;
+    case Maxwell::ComparisonOp::Greater:
+    case Maxwell::ComparisonOp::GreaterOld:
+        return Shader::CompareFunction::Greater;
+    case Maxwell::ComparisonOp::NotEqual:
+    case Maxwell::ComparisonOp::NotEqualOld:
+        return Shader::CompareFunction::NotEqual;
+    case Maxwell::ComparisonOp::GreaterEqual:
+    case Maxwell::ComparisonOp::GreaterEqualOld:
+        return Shader::CompareFunction::GreaterThanEqual;
+    case Maxwell::ComparisonOp::Always:
+    case Maxwell::ComparisonOp::AlwaysOld:
+        return Shader::CompareFunction::Always;
+    }
+    UNIMPLEMENTED_MSG("Unimplemented comparison op={}", comparison);
+    return {};
+}
 } // Anonymous namespace
 
 void PipelineCache::LoadDiskResources(u64 title_id, std::stop_token stop_loading,
@@ -1016,6 +1047,11 @@ Shader::Profile PipelineCache::MakeProfile(const GraphicsPipelineCacheKey& key,
         }
         profile.convert_depth_mode = gl_ndc;
         break;
+    case Shader::Stage::Fragment:
+        profile.alpha_test_func = MaxwellToCompareFunction(
+            key.state.UnpackComparisonOp(key.state.alpha_test_func.Value()));
+        profile.alpha_test_reference = Common::BitCast<float>(key.state.alpha_test_ref);
+        break;
     default:
         break;
     }