// File: citra/src/core/hle/kernel/address_arbiter.cpp


// Copyright 2018 yuzu emulator team
// Licensed under GPLv2 or any later version
// Refer to the license.txt file included.
#include <algorithm>
#include <vector>
#include "common/assert.h"
#include "common/common_types.h"
#include "core/core.h"
#include "core/core_cpu.h"
#include "core/hle/kernel/address_arbiter.h"
#include "core/hle/kernel/errors.h"
#include "core/hle/kernel/object.h"
#include "core/hle/kernel/process.h"
#include "core/hle/kernel/scheduler.h"
#include "core/hle/kernel/thread.h"
#include "core/hle/result.h"
#include "core/memory.h"
namespace Kernel {
namespace {
2018-06-22 03:05:34 +00:00
// Wake up num_to_wake (or all) threads in a vector.
void WakeThreads(const std::vector<std::shared_ptr<Thread>>& waiting_threads, s32 num_to_wake) {
2019-03-29 21:13:00 +00:00
auto& system = Core::System::GetInstance();
2018-06-22 03:05:34 +00:00
// Only process up to 'target' threads, unless 'target' is <= 0, in which case process
// them all.
std::size_t last = waiting_threads.size();
if (num_to_wake > 0) {
last = std::min(last, static_cast<std::size_t>(num_to_wake));
}
2018-06-22 03:05:34 +00:00
// Signal the waiting threads.
for (std::size_t i = 0; i < last; i++) {
ASSERT(waiting_threads[i]->GetStatus() == ThreadStatus::WaitArb);
2018-06-22 03:05:34 +00:00
waiting_threads[i]->SetWaitSynchronizationResult(RESULT_SUCCESS);
waiting_threads[i]->SetArbiterWaitAddress(0);
2018-06-22 03:05:34 +00:00
waiting_threads[i]->ResumeFromWait();
2019-04-02 13:22:53 +00:00
system.PrepareReschedule(waiting_threads[i]->GetProcessorID());
2018-06-22 03:05:34 +00:00
}
}
} // Anonymous namespace
2018-06-22 03:05:34 +00:00
// Stores the system the arbiter uses for memory access (system.Memory()) and scheduling
// (CurrentScheduler / GlobalScheduler / PrepareReschedule).
AddressArbiter::AddressArbiter(Core::System& system) : system(system) {}

// No explicit cleanup is performed here.
AddressArbiter::~AddressArbiter() = default;
// Entry point for signalling an address. Forwards to the handler for the requested
// SignalType; `value` is used only by the two conditional variants.
// Returns ERR_INVALID_ENUM_VALUE for an unrecognised `type`.
ResultCode AddressArbiter::SignalToAddress(VAddr address, SignalType type, s32 value,
                                           s32 num_to_wake) {
    if (type == SignalType::Signal) {
        return SignalToAddressOnly(address, num_to_wake);
    }
    if (type == SignalType::IncrementAndSignalIfEqual) {
        return IncrementAndSignalToAddressIfEqual(address, value, num_to_wake);
    }
    if (type == SignalType::ModifyByWaitingCountAndSignalIfEqual) {
        return ModifyByWaitingCountAndSignalToAddressIfEqual(address, value, num_to_wake);
    }

    // Anything else is not a known signal type.
    return ERR_INVALID_ENUM_VALUE;
}
// Unconditionally wakes up to `num_to_wake` threads waiting on `address`
// (all of them when `num_to_wake` <= 0), highest priority first. Always succeeds.
ResultCode AddressArbiter::SignalToAddressOnly(VAddr address, s32 num_to_wake) {
    // The temporary vector lives until the end of this full-expression, i.e. for the whole call.
    WakeThreads(GetThreadsWaitingOnAddress(address), num_to_wake);
    return RESULT_SUCCESS;
}
// If the 32-bit word at `address` equals `value`, increments it by one and then wakes
// up to `num_to_wake` waiters (all when `num_to_wake` <= 0).
// Returns:
//   ERR_INVALID_ADDRESS_STATE if `address` is not a valid virtual address,
//   ERR_INVALID_STATE         if the word does not hold `value`,
//   otherwise the result of SignalToAddressOnly.
ResultCode AddressArbiter::IncrementAndSignalToAddressIfEqual(VAddr address, s32 value,
                                                              s32 num_to_wake) {
    auto& memory = system.Memory();

    // Ensure that we can write to the address.
    if (!memory.IsValidVirtualAddress(address)) {
        return ERR_INVALID_ADDRESS_STATE;
    }

    // Only proceed when the word currently holds the expected value.
    const u32 current = memory.Read32(address);
    if (static_cast<s32>(current) != value) {
        return ERR_INVALID_STATE;
    }

    // Increment in u32 arithmetic. The former `value + 1` on an s32 was undefined
    // behaviour for value == INT32_MAX, which the guest can supply. Unsigned arithmetic
    // wraps modulo 2^32 and gives the same bits in every other case.
    memory.Write32(address, current + 1U);
    return SignalToAddressOnly(address, num_to_wake);
}
// If the 32-bit word at `address` equals `value`, rewrites it according to how many
// threads are waiting, then wakes up to `num_to_wake` of them (all when <= 0).
// Returns:
//   ERR_INVALID_ADDRESS_STATE if `address` is not a valid virtual address,
//   ERR_INVALID_STATE         if the word does not hold `value`,
//   RESULT_SUCCESS            otherwise.
ResultCode AddressArbiter::ModifyByWaitingCountAndSignalToAddressIfEqual(VAddr address, s32 value,
                                                                         s32 num_to_wake) {
    auto& memory = system.Memory();

    // Ensure that we can write to the address.
    if (!memory.IsValidVirtualAddress(address)) {
        return ERR_INVALID_ADDRESS_STATE;
    }

    // Get threads waiting on the address.
    const std::vector<std::shared_ptr<Thread>> waiting_threads =
        GetThreadsWaitingOnAddress(address);

    // Determine the modified value depending on the waiting count:
    //   - nobody is waiting                                  -> value + 1
    //   - waking all waiters (num_to_wake <= 0, or a count
    //     at least as large as the number of waiters)        -> value - 1
    //   - waking only some of the waiters                    -> value (unchanged)
    // The arithmetic is done on u32. On s32, value == INT32_MAX (+1) or INT32_MIN (-1)
    // would be signed-overflow undefined behaviour; u32 wraps modulo 2^32 instead.
    const u32 expected = static_cast<u32>(value);
    u32 updated_value;
    if (waiting_threads.empty()) {
        updated_value = expected + 1U;
    } else if (num_to_wake <= 0 ||
               waiting_threads.size() <= static_cast<std::size_t>(num_to_wake)) {
        updated_value = expected - 1U;
    } else {
        updated_value = expected;
    }

    // Only modify and signal when the word still holds the expected value.
    if (static_cast<s32>(memory.Read32(address)) != value) {
        return ERR_INVALID_STATE;
    }

    memory.Write32(address, updated_value);
    WakeThreads(waiting_threads, num_to_wake);
    return RESULT_SUCCESS;
}
// Entry point for waiting on an address. Forwards to the handler for the requested
// ArbitrationType. Both "less than" flavours share one handler and differ only in
// whether the word is decremented first. Returns ERR_INVALID_ENUM_VALUE for an
// unrecognised `type`.
ResultCode AddressArbiter::WaitForAddress(VAddr address, ArbitrationType type, s32 value,
                                          s64 timeout_ns) {
    const bool decrement_first = type == ArbitrationType::DecrementAndWaitIfLessThan;
    if (decrement_first || type == ArbitrationType::WaitIfLessThan) {
        return WaitForAddressIfLessThan(address, value, timeout_ns, decrement_first);
    }

    if (type == ArbitrationType::WaitIfEqual) {
        return WaitForAddressIfEqual(address, value, timeout_ns);
    }

    // Anything else is not a known arbitration type.
    return ERR_INVALID_ENUM_VALUE;
}
// Waits on `address` if its signed 32-bit word is strictly less than `value`.
// When `should_decrement` is set, the word is decremented by one before waiting.
// Returns:
//   ERR_INVALID_ADDRESS_STATE if `address` is not a valid virtual address,
//   ERR_INVALID_STATE         if the word is >= `value`,
//   RESULT_TIMEOUT            immediately when `timeout` == 0 (no reschedule),
//   otherwise the result of WaitForAddressImpl.
ResultCode AddressArbiter::WaitForAddressIfLessThan(VAddr address, s32 value, s64 timeout,
                                                    bool should_decrement) {
    auto& memory = system.Memory();

    // Ensure that we can read the address.
    if (!memory.IsValidVirtualAddress(address)) {
        return ERR_INVALID_ADDRESS_STATE;
    }

    // Compare as signed, but keep the raw word for the wrap-safe decrement below.
    const u32 raw_value = memory.Read32(address);
    const s32 cur_value = static_cast<s32>(raw_value);
    if (cur_value >= value) {
        return ERR_INVALID_STATE;
    }

    if (should_decrement) {
        // Decrement in u32 arithmetic. The former `cur_value - 1` was signed-overflow
        // undefined behaviour for cur_value == INT32_MIN. That case is reachable, since
        // any cur_value < value passes the check above.
        memory.Write32(address, raw_value - 1U);
    }

    // Short-circuit without rescheduling, if timeout is zero.
    if (timeout == 0) {
        return RESULT_TIMEOUT;
    }

    return WaitForAddressImpl(address, timeout);
}
// Waits on `address` only if its 32-bit word, read as signed, equals `value`.
// Returns:
//   ERR_INVALID_ADDRESS_STATE if `address` is not a valid virtual address,
//   ERR_INVALID_STATE         if the word differs from `value`,
//   RESULT_TIMEOUT            immediately when `timeout` == 0 (no reschedule),
//   otherwise the result of WaitForAddressImpl.
ResultCode AddressArbiter::WaitForAddressIfEqual(VAddr address, s32 value, s64 timeout) {
    auto& mem = system.Memory();

    // The address must be valid before it can be inspected.
    if (!mem.IsValidVirtualAddress(address)) {
        return ERR_INVALID_ADDRESS_STATE;
    }

    const s32 observed = static_cast<s32>(mem.Read32(address));
    if (observed != value) {
        return ERR_INVALID_STATE;
    }

    // A non-zero timeout actually suspends the caller. A zero timeout is a pure poll
    // that reports a timeout without rescheduling.
    if (timeout != 0) {
        return WaitForAddressImpl(address, timeout);
    }
    return RESULT_TIMEOUT;
}
// Parks the current thread on `address` in ThreadStatus::WaitArb, schedules a wakeup
// after `timeout`, and requests a reschedule of the thread's processor.
// NOTE(review): how negative timeouts (e.g. "wait forever") are treated is decided
// inside Thread::WakeAfterDelay — confirm there.
ResultCode AddressArbiter::WaitForAddressImpl(VAddr address, s64 timeout) {
    Thread* const waiter = system.CurrentScheduler().GetCurrentThread();

    // Keep these steps in this order: record the address, change status,
    // invalidate the wakeup callback, then arm the delayed wakeup.
    waiter->SetArbiterWaitAddress(address);
    waiter->SetStatus(ThreadStatus::WaitArb);
    waiter->InvalidateWakeupCallback();
    waiter->WakeAfterDelay(timeout);

    system.PrepareReschedule(waiter->GetProcessorID());

    // RESULT_TIMEOUT is the default outcome. WakeThreads stores RESULT_SUCCESS through
    // SetWaitSynchronizationResult when the thread is signalled, which presumably
    // replaces this value as seen by the guest — confirm in the SVC layer.
    return RESULT_TIMEOUT;
}
// Returns every thread whose arbiter wait address equals `address`, ordered by
// ascending GetPriority() value (highest priority first). Threads with equal priority
// keep the order in which they appear in the scheduler's thread list.
// NOTE(review): the list comes from system.GlobalScheduler(). If that list spans all
// processes, threads of another process waiting on the same virtual address would also
// match — confirm whether it is process-scoped.
std::vector<std::shared_ptr<Thread>> AddressArbiter::GetThreadsWaitingOnAddress(
    VAddr address) const {
    // Retrieve all threads that are waiting for this address.
    std::vector<std::shared_ptr<Thread>> threads;
    const auto& scheduler = system.GlobalScheduler();
    const auto& thread_list = scheduler.GetThreadList();
    for (const auto& thread : thread_list) {
        if (thread->GetArbiterWaitAddress() == address) {
            threads.push_back(thread);
        }
    }

    // Sort them by priority, such that the highest priority ones come first.
    // std::stable_sort (rather than std::sort) keeps equal-priority waiters in thread-list
    // order. std::sort leaves their relative order unspecified, which would make the wake
    // order depend on the standard-library implementation.
    std::stable_sort(threads.begin(), threads.end(),
                     [](const std::shared_ptr<Thread>& lhs, const std::shared_ptr<Thread>& rhs) {
                         return lhs->GetPriority() < rhs->GetPriority();
                     });
    return threads;
}
} // namespace Kernel