Common/Fiber: Implement Rewinding.
This commit is contained in:
		
							parent
							
								
									41013381d6
								
							
						
					
					
						commit
						137d862d9b
					
				| @ -12,10 +12,13 @@ | ||||
| 
 | ||||
| namespace Common { | ||||
| 
 | ||||
| constexpr std::size_t default_stack_size = 256 * 1024; // 256kb
 | ||||
| 
 | ||||
| #if defined(_WIN32) || defined(WIN32) | ||||
| 
 | ||||
| struct Fiber::FiberImpl { | ||||
|     LPVOID handle = nullptr; | ||||
|     LPVOID rewind_handle = nullptr; | ||||
| }; | ||||
| 
 | ||||
| void Fiber::start() { | ||||
| @ -26,15 +29,29 @@ void Fiber::start() { | ||||
|     UNREACHABLE(); | ||||
| } | ||||
| 
 | ||||
| void Fiber::onRewind() { | ||||
|     ASSERT(impl->handle != nullptr); | ||||
|     DeleteFiber(impl->handle); | ||||
|     impl->handle = impl->rewind_handle; | ||||
|     impl->rewind_handle = nullptr; | ||||
|     rewind_point(rewind_parameter); | ||||
|     UNREACHABLE(); | ||||
| } | ||||
| 
 | ||||
| void __stdcall Fiber::FiberStartFunc(void* fiber_parameter) { | ||||
|     auto fiber = static_cast<Fiber*>(fiber_parameter); | ||||
|     fiber->start(); | ||||
| } | ||||
| 
 | ||||
| void __stdcall Fiber::RewindStartFunc(void* fiber_parameter) { | ||||
|     auto fiber = static_cast<Fiber*>(fiber_parameter); | ||||
|     fiber->onRewind(); | ||||
| } | ||||
| 
 | ||||
| Fiber::Fiber(std::function<void(void*)>&& entry_point_func, void* start_parameter) | ||||
|     : entry_point{std::move(entry_point_func)}, start_parameter{start_parameter} { | ||||
|     impl = std::make_unique<FiberImpl>(); | ||||
|     impl->handle = CreateFiber(0, &FiberStartFunc, this); | ||||
|     impl->handle = CreateFiber(default_stack_size, &FiberStartFunc, this); | ||||
| } | ||||
| 
 | ||||
| Fiber::Fiber() { | ||||
| @ -60,6 +77,18 @@ void Fiber::Exit() { | ||||
|     guard.unlock(); | ||||
| } | ||||
| 
 | ||||
| void Fiber::SetRewindPoint(std::function<void(void*)>&& rewind_func, void* start_parameter) { | ||||
|     rewind_point = std::move(rewind_func); | ||||
|     rewind_parameter = start_parameter; | ||||
| } | ||||
| 
 | ||||
| void Fiber::Rewind() { | ||||
|     ASSERT(rewind_point); | ||||
|     ASSERT(impl->rewind_handle == nullptr); | ||||
|     impl->rewind_handle = CreateFiber(default_stack_size, &RewindStartFunc, this); | ||||
|     SwitchToFiber(impl->rewind_handle); | ||||
| } | ||||
| 
 | ||||
| void Fiber::YieldTo(std::shared_ptr<Fiber> from, std::shared_ptr<Fiber> to) { | ||||
|     ASSERT_MSG(from != nullptr, "Yielding fiber is null!"); | ||||
|     ASSERT_MSG(to != nullptr, "Next fiber is null!"); | ||||
| @ -81,7 +110,6 @@ std::shared_ptr<Fiber> Fiber::ThreadToFiber() { | ||||
| } | ||||
| 
 | ||||
| #else | ||||
| constexpr std::size_t default_stack_size = 1024 * 1024; // 1MB
 | ||||
| 
 | ||||
| struct Fiber::FiberImpl { | ||||
|     alignas(64) std::array<u8, default_stack_size> stack; | ||||
|  | ||||
| @ -46,6 +46,10 @@ public: | ||||
|     static void YieldTo(std::shared_ptr<Fiber> from, std::shared_ptr<Fiber> to); | ||||
|     static std::shared_ptr<Fiber> ThreadToFiber(); | ||||
| 
 | ||||
|     void SetRewindPoint(std::function<void(void*)>&& rewind_func, void* start_parameter); | ||||
| 
 | ||||
|     void Rewind(); | ||||
| 
 | ||||
|     /// Only call from main thread's fiber
 | ||||
|     void Exit(); | ||||
| 
 | ||||
| @ -58,8 +62,10 @@ private: | ||||
|     Fiber(); | ||||
| 
 | ||||
| #if defined(_WIN32) || defined(WIN32) | ||||
|     void onRewind(); | ||||
|     void start(); | ||||
|     static void FiberStartFunc(void* fiber_parameter); | ||||
|     static void RewindStartFunc(void* fiber_parameter); | ||||
| #else | ||||
|     void start(boost::context::detail::transfer_t& transfer); | ||||
|     static void FiberStartFunc(boost::context::detail::transfer_t transfer); | ||||
| @ -69,6 +75,8 @@ private: | ||||
| 
 | ||||
|     SpinLock guard{}; | ||||
|     std::function<void(void*)> entry_point{}; | ||||
|     std::function<void(void*)> rewind_point{}; | ||||
|     void* rewind_parameter{}; | ||||
|     void* start_parameter{}; | ||||
|     std::shared_ptr<Fiber> previous_fiber{}; | ||||
|     std::unique_ptr<FiberImpl> impl; | ||||
|  | ||||
| @ -309,4 +309,50 @@ TEST_CASE("Fibers::StartRace", "[common]") { | ||||
|     REQUIRE(test_control.value3 == 1); | ||||
| } | ||||
| 
 | ||||
| class TestControl4; | ||||
| 
 | ||||
| static void WorkControl4(void* control); | ||||
| 
 | ||||
| class TestControl4 { | ||||
| public: | ||||
|     TestControl4() { | ||||
|         fiber1 = std::make_shared<Fiber>(std::function<void(void*)>{WorkControl4}, this); | ||||
|         goal_reached = false; | ||||
|         rewinded = false; | ||||
|     } | ||||
| 
 | ||||
|     void Execute() { | ||||
|         thread_fiber = Fiber::ThreadToFiber(); | ||||
|         Fiber::YieldTo(thread_fiber, fiber1); | ||||
|         thread_fiber->Exit(); | ||||
|     } | ||||
| 
 | ||||
|     void DoWork() { | ||||
|         fiber1->SetRewindPoint(std::function<void(void*)>{WorkControl4}, this); | ||||
|         if (rewinded) { | ||||
|             goal_reached = true; | ||||
|             Fiber::YieldTo(fiber1, thread_fiber); | ||||
|         } | ||||
|         rewinded = true; | ||||
|         fiber1->Rewind(); | ||||
|     } | ||||
| 
 | ||||
|     std::shared_ptr<Common::Fiber> fiber1; | ||||
|     std::shared_ptr<Common::Fiber> thread_fiber; | ||||
|     bool goal_reached; | ||||
|     bool rewinded; | ||||
| }; | ||||
| 
 | ||||
| static void WorkControl4(void* control) { | ||||
|     auto* test_control = static_cast<TestControl4*>(control); | ||||
|     test_control->DoWork(); | ||||
| } | ||||
| 
 | ||||
| TEST_CASE("Fibers::Rewind", "[common]") { | ||||
|     TestControl4 test_control{}; | ||||
|     test_control.Execute(); | ||||
|     REQUIRE(test_control.goal_reached); | ||||
|     REQUIRE(test_control.rewinded); | ||||
| } | ||||
| 
 | ||||
| } // namespace Common
 | ||||
|  | ||||
		Loading…
	
		Reference in New Issue
	
	Block a user
	 Fernando Sahmkow
						Fernando Sahmkow