Common: Add TransferableSharedMutex class and unit tests.

This commit is contained in:
Jordan Woyak 2025-10-19 20:31:22 -05:00
parent b0652925fa
commit 1d9e475123
4 changed files with 216 additions and 0 deletions

View File

@ -151,6 +151,7 @@ add_library(common
Timer.h
TimeUtil.cpp
TimeUtil.h
TransferableSharedMutex.h
TraversalClient.cpp
TraversalClient.h
TraversalProto.h

View File

@ -0,0 +1,92 @@
// Copyright 2025 Dolphin Emulator Project
// SPDX-License-Identifier: GPL-2.0-or-later
#pragma once
#include <atomic>
#include <cassert>
#include <cstdint>
namespace Common
{
// Behaves like `std::shared_mutex` but locks and unlocks may come from different threads.
// (std::shared_mutex requires unlock to happen on the thread that locked, which makes it
//  unusable for handing a held lock off to a worker thread; this class lifts that
//  restriction — hence "transferable".)
//
// Implementation: a single atomic counter encodes the entire lock state.
//   0                          -> unlocked
//   1 .. LAST_SHARED_LOCK_VALUE -> number of shared holders
//   EXCLUSIVE_LOCK_VALUE        -> exclusively locked (all bits set)
// Blocking uses C++20 std::atomic wait/notify.
class TransferableSharedMutex
{
public:
// Acquires the exclusive lock, blocking until the mutex is free.
void lock()
{
while (true)
{
// Expect the unlocked state (0). Acquire on success pairs with the
// release stores in unlock()/unlock_shared().
CounterType old_value{};
if (m_counter.compare_exchange_strong(old_value, EXCLUSIVE_LOCK_VALUE,
std::memory_order_acquire, std::memory_order_relaxed))
{
return;
}
// lock() or lock_shared() is already held.
// Wait for an unlock notification and try again.
// wait() returns immediately if the counter no longer equals old_value,
// so a notification arriving between the failed CAS and wait() is not lost.
m_counter.wait(old_value, std::memory_order_relaxed);
}
}
// Attempts to acquire the exclusive lock without blocking.
// May fail spuriously (compare_exchange_weak), which try_lock semantics permit.
bool try_lock()
{
CounterType old_value{};
return m_counter.compare_exchange_weak(old_value, EXCLUSIVE_LOCK_VALUE,
std::memory_order_acquire, std::memory_order_relaxed);
}
// Releases the exclusive lock. May be called from any thread.
void unlock()
{
m_counter.store(0, std::memory_order_release);
m_counter.notify_all(); // Notify potentially multiple wait()ers in lock_shared().
}
// Acquires a shared lock, blocking while the exclusive lock is held.
void lock_shared()
{
while (true)
{
auto old_value = m_counter.load(std::memory_order_relaxed);
// Try to bump the shared-holder count. This CAS loop only exits without
// returning when the counter holds EXCLUSIVE_LOCK_VALUE or is saturated.
while (old_value < LAST_SHARED_LOCK_VALUE)
{
if (m_counter.compare_exchange_strong(old_value, old_value + 1, std::memory_order_acquire,
std::memory_order_relaxed))
{
return;
}
}
// Something has gone very wrong if m_counter is saturated with shared locks.
assert(old_value != LAST_SHARED_LOCK_VALUE);
// lock() is already held.
// Wait for an unlock notification and try again.
m_counter.wait(old_value, std::memory_order_relaxed);
}
}
// Attempts to acquire a shared lock without blocking.
// May fail spuriously (compare_exchange_weak), which try_lock_shared semantics permit.
bool try_lock_shared()
{
auto old_value = m_counter.load(std::memory_order_relaxed);
return (old_value < LAST_SHARED_LOCK_VALUE) &&
m_counter.compare_exchange_weak(old_value, old_value + 1, std::memory_order_acquire,
std::memory_order_relaxed);
}
// Releases one shared lock. May be called from any thread.
void unlock_shared()
{
// Only the final shared holder (1 -> 0) needs to wake anyone: lock_shared()
// only ever waits while the exclusive lock is held, so any thread blocked
// here should be an exclusive lock()er, and only one of those can succeed.
if (m_counter.fetch_sub(1, std::memory_order_release) == 1)
m_counter.notify_one(); // Notify one of the wait()ers in lock().
}
private:
using CounterType = std::uintptr_t;
// All bits set: the counter value representing an exclusively held lock.
static constexpr auto EXCLUSIVE_LOCK_VALUE = CounterType(-1);
// Largest representable shared-holder count (one below the exclusive sentinel).
static constexpr auto LAST_SHARED_LOCK_VALUE = EXCLUSIVE_LOCK_VALUE - 1;
// 0 = unlocked, EXCLUSIVE_LOCK_VALUE = exclusive, otherwise the shared-holder count.
std::atomic<CounterType> m_counter{};
};
} // namespace Common

View File

@ -171,6 +171,7 @@
<ClInclude Include="Common\Thread.h" />
<ClInclude Include="Common\Timer.h" />
<ClInclude Include="Common\TimeUtil.h" />
<ClInclude Include="Common\TransferableSharedMutex.h" />
<ClInclude Include="Common\TraversalClient.h" />
<ClInclude Include="Common\TraversalProto.h" />
<ClInclude Include="Common\TypeUtils.h" />

View File

@ -6,9 +6,11 @@
#include <algorithm>
#include <chrono>
#include <mutex>
#include <shared_mutex>
#include <thread>
#include "Common/Mutex.h"
#include "Common/TransferableSharedMutex.h"
template <typename MutexType>
static void DoAtomicMutexTests(const char mutex_name[])
@ -100,3 +102,123 @@ TEST(Mutex, AtomicMutex)
DoAtomicMutexTests<Common::AtomicMutex>("AtomicMutex");
DoAtomicMutexTests<Common::SpinMutex>("SpinMutex");
}
// Exercises TransferableSharedMutex: cross-thread lock hand-off (lock on one thread,
// unlock on another), shared/exclusive interaction, try_lock* behavior while locks are
// held, and sanity under heavy contention from multiple threads.
TEST(Mutex, TransferableSharedMutex)
{
Common::TransferableSharedMutex work_mutex;
bool worker_done = false;
static constexpr auto SLEEP_TIME = std::chrono::microseconds{1};
// lock() on main thread, unlock() on worker thread.
// The init-capture acquires the lock on the main thread (while the std::thread is being
// constructed); the captured unique_lock is destroyed — unlocking — on the worker thread
// when the lambda returns.
std::thread thread{[&, lk = std::unique_lock{work_mutex}] {
std::this_thread::sleep_for(SLEEP_TIME);
worker_done = true;
}};
// lock() waits for the thread to unlock().
{
std::lock_guard lk{work_mutex};
// worker_done was written before the worker released the lock, so it must be visible.
EXPECT_TRUE(worker_done);
}
thread.join();
// Prevent below workers from incrementing `done_count`.
Common::TransferableSharedMutex done_mutex;
std::unique_lock done_lk{done_mutex};
// try_*() fails when holding an exclusive lock.
EXPECT_FALSE(done_mutex.try_lock());
EXPECT_FALSE(done_mutex.try_lock_shared());
static constexpr int THREAD_COUNT = 4;
static constexpr int REPEAT_COUNT = 100;
static constexpr int TOTAL_ITERATIONS = THREAD_COUNT * REPEAT_COUNT;
std::atomic<int> work_count = 0;
std::atomic<int> done_count = 0;
// Not atomic: only ever modified under work_mutex's exclusive lock.
int additional_work_count = 0;
std::atomic<int> try_lock_fail_count = 0;
std::atomic<int> try_lock_shared_fail_count = 0;
std::vector<std::thread> threads(THREAD_COUNT);
for (auto& t : threads)
{
// lock_shared() multiple times on main thread.
// As above, each shared lock is acquired here (main thread) via init-capture and
// released on the worker thread.
t = std::thread{[&, work_lk = std::shared_lock{work_mutex}]() mutable {
std::this_thread::sleep_for(SLEEP_TIME);
// try_lock() fails after lock_shared().
EXPECT_FALSE(work_mutex.try_lock());
// Main thread already holds done_mutex.
EXPECT_FALSE(done_mutex.try_lock());
EXPECT_FALSE(done_mutex.try_lock_shared());
++work_count;
// Signal work is done.
work_lk.unlock();
// lock_shared() blocks until main thread unlock()s.
{
std::shared_lock lk{done_mutex};
++done_count;
}
// Contesting all of [try_]lock[_shared] doesn't explode.
for (int i = 0; i != REPEAT_COUNT; ++i)
{
while (!work_mutex.try_lock())
{
try_lock_fail_count.fetch_add(1, std::memory_order_relaxed);
}
work_mutex.unlock();
while (!work_mutex.try_lock_shared())
{
try_lock_shared_fail_count.fetch_add(1, std::memory_order_relaxed);
}
work_mutex.unlock_shared();
{
std::lock_guard lk{work_mutex};
++additional_work_count;
}
std::shared_lock lk{work_mutex};
}
}};
}
// lock() waits for threads to unlock_shared().
{
std::lock_guard lk{work_mutex};
// Every worker must have released its shared lock — and incremented work_count
// just before doing so — for the exclusive lock above to have been granted.
EXPECT_EQ(work_count.load(std::memory_order_relaxed), THREAD_COUNT);
}
std::this_thread::sleep_for(SLEEP_TIME);
// The threads are still blocking on done_mutex.
EXPECT_EQ(done_count, 0);
done_lk.unlock();
std::ranges::for_each(threads, &std::thread::join);
// The threads finished.
EXPECT_EQ(done_count, THREAD_COUNT);
EXPECT_EQ(additional_work_count, TOTAL_ITERATIONS);
GTEST_LOG_(INFO) << "try_lock() failure %: "
<< (try_lock_fail_count * 100.0 / (TOTAL_ITERATIONS + try_lock_fail_count));
GTEST_LOG_(INFO) << "try_lock_shared() failure %: "
<< (try_lock_shared_fail_count * 100.0 /
(TOTAL_ITERATIONS + try_lock_shared_fail_count));
// Things are still sane after contesting in worker threads.
done_lk.lock();
std::lock_guard lk{work_mutex};
}