From 25fa10327ba46d15e6d796591a16fca8a52ea2eb Mon Sep 17 00:00:00 2001
From: liushuyu <liushuyu011@gmail.com>
Date: Sun, 3 Feb 2019 22:42:18 -0700
Subject: [PATCH] audio_core: hle: mf: use object proxy

---
 src/audio_core/hle/wmf_decoder.cpp       | 10 ++-----
 src/audio_core/hle/wmf_decoder_utils.cpp | 36 ++++++++++--------------
 src/audio_core/hle/wmf_decoder_utils.h   | 23 +++++++++++++++
 3 files changed, 41 insertions(+), 28 deletions(-)

diff --git a/src/audio_core/hle/wmf_decoder.cpp b/src/audio_core/hle/wmf_decoder.cpp
index a2025a85e1..7440049f4e 100644
--- a/src/audio_core/hle/wmf_decoder.cpp
+++ b/src/audio_core/hle/wmf_decoder.cpp
@@ -70,15 +70,13 @@ std::optional<BinaryResponse> WMFDecoder::Impl::Initalize(const BinaryRequest& r
     }
 
     BinaryResponse response;
-    IMFTransform* tmp = nullptr;
     std::memcpy(&response, &request, sizeof(response));
     response.unknown1 = 0x0;
 
-    if (!MFDecoderInit(&tmp)) {
+    if (!MFDecoderInit(Amp(transform))) {
         LOG_CRITICAL(Audio_DSP, "Can't init decoder");
         return response;
     }
-    transform.reset(tmp);
 
     HRESULT hr = transform->GetStreamIDs(1, &in_stream_id, 1, &out_stream_id);
     if (hr == E_NOTIMPL) {
@@ -108,9 +106,6 @@ MFOutputState WMFDecoder::Impl::DecodingLoop(ADTSData adts_header,
     MFOutputState output_status = MFOutputState::OK;
     char* output_buffer = nullptr;
     DWORD output_len = 0;
-    DWORD tmp = 0;
-    // IMFSample* output_tmp = nullptr;
-    IMFMediaBuffer* mdbuf = nullptr;
     unique_mfptr<IMFSample> output;
 
     while (true) {
@@ -150,7 +145,8 @@ MFOutputState WMFDecoder::Impl::DecodingLoop(ADTSData adts_header,
         if (output_status == MFOutputState::HaveMoreData)
             continue;
 
-        if (output_status == MFOutputState::NeedMoreInput) // according to MS document, this is not an error (?!)
+        // according to MS document, this is not an error (?!)
+        if (output_status == MFOutputState::NeedMoreInput)
             return MFOutputState::NeedMoreInput;
 
         return MFOutputState::FatalError; // return on other status
diff --git a/src/audio_core/hle/wmf_decoder_utils.cpp b/src/audio_core/hle/wmf_decoder_utils.cpp
index 26d1905a7d..8ecfae1af0 100644
--- a/src/audio_core/hle/wmf_decoder_utils.cpp
+++ b/src/audio_core/hle/wmf_decoder_utils.cpp
@@ -15,6 +15,7 @@ void ReportError(std::string msg, HRESULT hr) {
                   nullptr, hr,
                   // hardcode to use en_US because if any user had problems with this
                   // we can help them w/o translating anything
+                  // default is to use the language currently active on the operating system
                   MAKELANGID(LANG_ENGLISH, SUBLANG_ENGLISH_US), (LPSTR)&err, 0, nullptr);
     if (err != nullptr) {
         LOG_CRITICAL(Audio_DSP, "{}: {}", msg, err);
@@ -79,24 +80,20 @@ void MFDeInit(IMFTransform* transform) {
 
 unique_mfptr<IMFSample> CreateSample(void* data, DWORD len, DWORD alignment, LONGLONG duration) {
     HRESULT hr = S_OK;
-    IMFMediaBuffer* buf_tmp = nullptr;
     unique_mfptr<IMFMediaBuffer> buf;
-    IMFSample* sample_tmp = nullptr;
     unique_mfptr<IMFSample> sample;
 
-    hr = MFCreateSample(&sample_tmp);
+    hr = MFCreateSample(Amp(sample));
     if (FAILED(hr)) {
         ReportError("Unable to allocate a sample", hr);
         return nullptr;
     }
-    sample.reset(sample_tmp);
     // Yes, the argument for alignment is the actual alignment - 1
-    hr = MFCreateAlignedMemoryBuffer(len, alignment - 1, &buf_tmp);
+    hr = MFCreateAlignedMemoryBuffer(len, alignment - 1, Amp(buf));
     if (FAILED(hr)) {
         ReportError("Unable to allocate a memory buffer for sample", hr);
         return nullptr;
     }
-    buf.reset(buf_tmp);
     if (data) {
         BYTE* buffer;
         // lock the MediaBuffer
@@ -116,6 +113,7 @@ unique_mfptr<IMFSample> CreateSample(void* data, DWORD len, DWORD alignment, LON
     sample->AddBuffer(buf.get());
     hr = sample->SetSampleDuration(duration);
     if (FAILED(hr)) {
+        // MFT will take a guess for you in this case
         ReportError("Unable to set sample duration, but continuing anyway", hr);
     }
 
@@ -125,11 +123,11 @@ unique_mfptr<IMFSample> CreateSample(void* data, DWORD len, DWORD alignment, LON
 bool SelectInputMediaType(IMFTransform* transform, int in_stream_id, const ADTSData& adts,
                           UINT8* user_data, UINT32 user_data_len, GUID audio_format) {
     HRESULT hr = S_OK;
-    IMFMediaType* t;
+    unique_mfptr<IMFMediaType> t;
 
     // actually you can get rid of the whole block of searching and filtering mess
     // if you know the exact parameters of your media stream
-    hr = MFCreateMediaType(&t);
+    hr = MFCreateMediaType(Amp(t));
     if (FAILED(hr)) {
         ReportError("Unable to create an empty MediaType", hr);
         return false;
@@ -146,7 +144,7 @@ bool SelectInputMediaType(IMFTransform* transform, int in_stream_id, const ADTSD
     t->SetUINT32(MF_MT_AAC_AUDIO_PROFILE_LEVEL_INDICATION, 254);
     t->SetUINT32(MF_MT_AUDIO_BLOCK_ALIGNMENT, 1);
     t->SetBlob(MF_MT_USER_DATA, user_data, user_data_len);
-    hr = transform->SetInputType(in_stream_id, t, 0);
+    hr = transform->SetInputType(in_stream_id, t.get(), 0);
     if (FAILED(hr)) {
         ReportError("failed to select input types for MFT", hr);
         return false;
@@ -158,15 +156,13 @@ bool SelectInputMediaType(IMFTransform* transform, int in_stream_id, const ADTSD
 bool SelectOutputMediaType(IMFTransform* transform, int out_stream_id, GUID audio_format) {
     HRESULT hr = S_OK;
     UINT32 tmp;
-    IMFMediaType* type;
-    unique_mfptr<IMFMediaType> t;
+    unique_mfptr<IMFMediaType> type;
 
-    // If you know what you need and what you are doing, you can specify the condition instead of
+    // If you know what you need and what you are doing, you can specify the conditions instead of
     // searching but it's better to use search since MFT may or may not support your output
     // parameters
     for (DWORD i = 0;; i++) {
-        hr = transform->GetOutputAvailableType(out_stream_id, i, &type);
-        t.reset(type);
+        hr = transform->GetOutputAvailableType(out_stream_id, i, Amp(type));
         if (hr == MF_E_NO_MORE_TYPES || hr == E_NOTIMPL) {
             return true;
         }
@@ -175,19 +171,19 @@ bool SelectOutputMediaType(IMFTransform* transform, int out_stream_id, GUID audi
             return false;
         }
 
-        hr = t->GetUINT32(MF_MT_AUDIO_BITS_PER_SAMPLE, &tmp);
+        hr = type->GetUINT32(MF_MT_AUDIO_BITS_PER_SAMPLE, &tmp);
 
         if (FAILED(hr))
             continue;
         // select PCM-16 format
         if (tmp == 32) {
-            hr = t->SetUINT32(MF_MT_AUDIO_BLOCK_ALIGNMENT, 1);
+            hr = type->SetUINT32(MF_MT_AUDIO_BLOCK_ALIGNMENT, 1);
             if (FAILED(hr)) {
                 ReportError("failed to set MF_MT_AUDIO_BLOCK_ALIGNMENT for MFT on output stream",
                             hr);
                 return false;
             }
-            hr = transform->SetOutputType(out_stream_id, t.get(), 0);
+            hr = transform->SetOutputType(out_stream_id, type.get(), 0);
             if (FAILED(hr)) {
                 ReportError("failed to select output types for MFT", hr);
                 return false;
@@ -331,7 +327,6 @@ std::tuple<MFOutputState, unique_mfptr<IMFSample>> ReceiveSample(IMFTransform* t
 
 int CopySampleToBuffer(IMFSample* sample, void** output, DWORD* len) {
     unique_mfptr<IMFMediaBuffer> buffer;
-    IMFMediaBuffer* tmp;
     HRESULT hr = S_OK;
     BYTE* data;
 
@@ -341,14 +336,13 @@ int CopySampleToBuffer(IMFSample* sample, void** output, DWORD* len) {
         return -1;
     }
 
-    hr = sample->ConvertToContiguousBuffer(&tmp);
+    hr = sample->ConvertToContiguousBuffer(Amp(buffer));
     if (FAILED(hr)) {
         ReportError("Failed to get sample buffer", hr);
         return -1;
     }
-    buffer.reset(tmp);
 
-    hr = tmp->Lock(&data, nullptr, nullptr);
+    hr = buffer->Lock(&data, nullptr, nullptr);
     if (FAILED(hr)) {
         ReportError("Failed to lock the buffer", hr);
         return -1;
diff --git a/src/audio_core/hle/wmf_decoder_utils.h b/src/audio_core/hle/wmf_decoder_utils.h
index 07ed76662d..79514c19ce 100644
--- a/src/audio_core/hle/wmf_decoder_utils.h
+++ b/src/audio_core/hle/wmf_decoder_utils.h
@@ -31,6 +31,29 @@ struct MFRelease {
 template <typename T>
 using unique_mfptr = std::unique_ptr<T, MFRelease<T>>;
 
+template <typename SmartPtr, typename RawPtr>
+class AmpImpl {
+public:
+    AmpImpl(SmartPtr& smart_ptr) : smart_ptr(smart_ptr) {}
+    ~AmpImpl() {
+        smart_ptr.reset(raw_ptr);
+    }
+
+    operator RawPtr*() {
+        return &raw_ptr;
+    }
+
+private:
+    SmartPtr& smart_ptr;
+    RawPtr raw_ptr;
+};
+
+template <typename SmartPtr>
+auto Amp(SmartPtr& smart_ptr) {
+    return AmpImpl<SmartPtr, decltype(smart_ptr.get())>(smart_ptr);
+}
+
+// convient function for formatting error messages
 void ReportError(std::string msg, HRESULT hr);
 
 // exported functions