#include <folly/coro/SharedMutex.h>
#if FOLLY_HAS_COROUTINES
using namespace folly::coro;
// Destructor: performs debug-only sanity checks on the final mutex state.
//
// Both checks live entirely inside assert(), so in builds that define NDEBUG
// they compile away, including the state_.lock() calls themselves. Each
// assert acquires state_ independently.
SharedMutexFair::~SharedMutexFair() {
// The state word must equal kUnlocked when the mutex is destroyed.
assert(state_.lock()->lockedFlagAndReaderCount_ == kUnlocked);
// The waiter list must be empty when the mutex is destroyed.
assert(state_.lock()->waitersHead_ == nullptr);
}
// Attempts to take the lock without suspending.
//
// Succeeds only when the state word is exactly kUnlocked, in which case the
// state is set to kExclusiveLockFlag. Returns true on success and false
// otherwise; the state is left untouched on failure.
bool SharedMutexFair::try_lock() noexcept {
  auto guardedState = state_.contextualLock();

  // Any state other than kUnlocked means the attempt fails immediately.
  if (guardedState->lockedFlagAndReaderCount_ != kUnlocked) {
    return false;
  }

  guardedState->lockedFlagAndReaderCount_ = kExclusiveLockFlag;
  return true;
}
bool SharedMutexFair::try_lock_shared() noexcept {
auto lock = state_.contextualLock();
if (lock->lockedFlagAndReaderCount_ == kUnlocked ||
(lock->lockedFlagAndReaderCount_ >= kSharedLockCountIncrement &&
lock->waitersHead_ == nullptr)) {
lock->lockedFlagAndReaderCount_ += kSharedLockCountIncrement;
return true;
}
return false;
}
void SharedMutexFair::unlock() noexcept {
LockAwaiterBase* awaitersToResume = nullptr;
{
auto lockedState = state_.contextualLock();
assert(lockedState->lockedFlagAndReaderCount_ == kExclusiveLockFlag);
awaitersToResume = unlockOrGetNextWaitersToResume(*lockedState);
}
resumeWaiters(awaitersToResume);
}
void SharedMutexFair::unlock_shared() noexcept {
LockAwaiterBase* awaitersToResume = nullptr;
{
auto lockedState = state_.contextualLock();
assert(lockedState->lockedFlagAndReaderCount_ >= kSharedLockCountIncrement);
lockedState->lockedFlagAndReaderCount_ -= kSharedLockCountIncrement;
if (lockedState->lockedFlagAndReaderCount_ != kUnlocked) {
return;
}
awaitersToResume = unlockOrGetNextWaitersToResume(*lockedState);
}
resumeWaiters(awaitersToResume);
}
// Hands the mutex to the next waiter(s), or marks it unlocked if none queue.
//
// Must be called with `state` already locked by the caller. Behaviour:
//   * Empty waiter list: state word becomes kUnlocked, returns nullptr.
//   * Head waiter is EXCLUSIVE: that single waiter is detached and the state
//     word becomes kExclusiveLockFlag.
//   * Otherwise: the head plus every directly following SHARED waiter is
//     detached as one batch, and the state word becomes
//     kSharedLockCountIncrement times the batch size.
// The returned list is null-terminated and starts at the old head. When the
// queue becomes empty, waitersTailNext_ is reset to point at waitersHead_.
SharedMutexFair::LockAwaiterBase*
SharedMutexFair::unlockOrGetNextWaitersToResume(
    SharedMutexFair::State& state) noexcept {
  auto* const batchHead = state.waitersHead_;

  if (batchHead == nullptr) {
    // Nobody is waiting: simply release the lock.
    state.lockedFlagAndReaderCount_ = kUnlocked;
    return batchHead;
  }

  if (batchHead->lockType_ == LockType::EXCLUSIVE) {
    // Detach only the head and give it the lock.
    state.lockedFlagAndReaderCount_ = kExclusiveLockFlag;
    state.waitersHead_ = batchHead->nextAwaiter_;
    batchHead->nextAwaiter_ = nullptr;
  } else {
    // Walk to the end of the leading run of SHARED waiters, counting one
    // increment per waiter (the head itself accounts for the initial value).
    std::size_t readerState = kSharedLockCountIncrement;
    auto* batchTail = batchHead;
    while (batchTail->nextAwaiter_ != nullptr &&
           batchTail->nextAwaiter_->lockType_ == LockType::SHARED) {
      batchTail = batchTail->nextAwaiter_;
      readerState += kSharedLockCountIncrement;
    }

    // Cut the batch off; whatever followed it becomes the new queue head.
    state.waitersHead_ = std::exchange(batchTail->nextAwaiter_, nullptr);
    state.lockedFlagAndReaderCount_ = readerState;
  }

  if (state.waitersHead_ == nullptr) {
    // Queue drained: the tail link must refer to the head slot again.
    state.waitersTailNext_ = &state.waitersHead_;
  }

  return batchHead;
}
// Resumes every awaiter in the null-terminated list starting at `awaiters`,
// in list order. A null list is a no-op.
void SharedMutexFair::resumeWaiters(LockAwaiterBase* awaiters) noexcept {
  while (awaiters != nullptr) {
    auto* const current = awaiters;
    // Read the successor before calling resume(), so `current` is never
    // touched again after it has been resumed.
    awaiters = current->nextAwaiter_;
    current->resume();
  }
}
#endif