diff --git a/NTOSKRNL/KE/mutex.cpp b/NTOSKRNL/KE/mutex.cpp index 713ee57..cd54a06 100644 --- a/NTOSKRNL/KE/mutex.cpp +++ b/NTOSKRNL/KE/mutex.cpp @@ -20,6 +20,12 @@ constexpr ULONG MUTEX_READY_TO_BE_AQUIRED = 0; /*Internal Function*/ +/* Fast Mutex definitions */ +#define FM_LOCK_BIT 0x1 +#define FM_LOCK_BIT_V 0x0 +#define FM_LOCK_WAITER_WOKEN 0x2 +#define FM_LOCK_WAITER_INC 0x4 + typedef struct _FAST_MUTEX { LONG Count; //0x0 @@ -29,6 +35,8 @@ typedef struct _FAST_MUTEX ULONG OldIrql; //0x1c } FAST_MUTEX, *PFAST_MUTEX; //0x20 bytes (sizeof) +typedef PFAST_MUTEX PKGUARDED_MUTEX; + /*Internal Functio*/ VOID FASTCALL @@ -36,9 +44,9 @@ KiAcquireFastMutex( _Inout_ PFAST_MUTEX Mutex ) { - LONG AcquireMarker; - LONG AcquireBit; - LONG OldCount; + LONG AcquireMarker = {0}; + LONG AcquireBit = {0}; + LONG OldCount = {0}; PAGED_CODE(); @@ -49,7 +57,6 @@ KiAcquireFastMutex( AcquireMarker = 4; AcquireBit = 1; -AcquireLoop: while(true) { /* Read current count */ @@ -66,7 +73,7 @@ AcquireLoop: AcquireMarker = 2; AcquireBit = 3; - goto AcquireLoop; + continue; } } else @@ -81,6 +88,43 @@ AcquireLoop: } } +FASTCALL +KeReleaseFastMutexContended( + IN PFAST_MUTEX FastMutex, + IN LONG OldValue) +{ + BOOLEAN WakeWaiter = false; + LONG NewValue = {0}; + PKTHREAD WokenThread = nullptr; + KPRIORITY HandoffPriority = {0}; + + /* Loop until we successfully update the mutex state */ + for (;;) + { + WakeWaiter = false; + NewValue = OldValue + FM_LOCK_BIT; + + if (!(OldValue & FM_LOCK_WAITER_WOKEN)) + { + NewValue = OldValue - FM_LOCK_BIT; + WakeWaiter = true; + } + + LONG PreviousValue = InterlockedCompareExchange(&FastMutex->Lock, NewValue, OldValue); + if (PreviousValue == OldValue) + break; + + OldValue = PreviousValue; + } + + if (WakeWaiter) + { + /* Wake up a waiter */ + KeSetEventBoostPriority(&FastMutex->Event); + } +} + + /* Exported Function */ VOID @@ -89,7 +133,6 @@ KeInitializeFastMutex( _Out_ PFAST_MUTEX Mutex ) { - PAGED_CODE(); /* Initialize the mutex structure */ RtlZeroMemory(Mutex, sizeof(FAST_MUTEX)); @@ -127,7 +170,240 @@ KeTryToAcquireFastMutex( return Result; } +VOID +NTAPI +KeEnterCriticalRegionAndAcquireFastMutexUnsafe( + _In_ PFAST_MUTEX FastMutex) +{ + PKTHREAD OwnerThread = nullptr; + KeEnterCriticalRegion(); + + /* Get the current thread again (following the pseudocode) */ + OwnerThread = KeGetCurrentThread(); + + /* Try to acquire the FastMutex */ + if (_InterlockedBitTestAndReset(&FastMutex->Lock, 0)) + { + /* FastMutex was free, we acquired it */ + FastMutex->Owner = OwnerThread; + } + else + { + /* FastMutex was locked, we need to wait */ + KiAcquireFastMutex(FastMutex); + FastMutex->Owner = OwnerThread; + } +} + +VOID +FASTCALL +KeReleaseFastMutexUnsafeAndLeaveCriticalRegion( + _In_ PFAST_MUTEX FastMutex) +{ + LONG OldValue = {0}; + PKTHREAD CurrentThread = nullptr ; + SHORT NewValue ={0}; + + /* Clear the owner */ + FastMutex->Owner = nullptr; + + /* Try to release the FastMutex */ + OldValue = InterlockedCompareExchange(&FastMutex->Lock, 1, 0); + if (OldValue != 0) + { + /* Contended case, call the contended release function */ + KeReleaseFastMutexContended(FastMutex, OldValue); + } + + /* leave critical region*/ + KeLeaveCriticalRegion(); +} +VOID +NTAPI +KeAcquireFastMutex( + _In_ PFAST_MUTEX FastMutex) +{ + KIRQL OldIrql = {0}; + /* Raise IRQL to APC_LEVEL */ + OldIrql = KeRaiseIrqlToSynchLevel(); + /* Try to acquire the FastMutex */ + if (InterlockedBitTestAndReset(&FastMutex->Lock, 0) == 0) + { + /* We didn't acquire it, we'll have to wait */ + KiAcquireFastMutex(FastMutex); + } + + /* Set the owner thread and save the original IRQL */ + FastMutex->Owner = KeGetCurrentThread(); + FastMutex->OldIrql = OldIrql; +} + +VOID +NTAPI +KeAcquireFastMutexUnsafe( + _In_ PFAST_MUTEX FastMutex) +{ + PKTHREAD CurrentThread = nullptr; + + /* Get the current thread */ + CurrentThread = KeGetCurrentThread(); + + /* Try to acquire the FastMutex */ + if (!InterlockedBitTestAndReset(&FastMutex->Lock, 0)) + { + /* FastMutex was locked, we need to wait */ + KiAcquireFastMutex(FastMutex); + } + + /* Set the owner */ + FastMutex->Owner = CurrentThread; +} + +VOID +NTAPI +KeReleaseFastMutex( + _Inout_ PFAST_MUTEX FastMutex +) +{ + KIRQL OldIrql ={0}; + LONG OldCount ={0}; + + FastMutex->Owner = nullptr; + OldIrql = FastMutex->OldIrql; + + OldCount = InterlockedExchangeAdd(&FastMutex->Count, 1); + + if (OldCount != 0 && + (OldCount & 2) == 0 && + InterlockedCompareExchange(&FastMutex->Count, OldCount - 1, OldCount + 1) == OldCount + 1) + { + KeSetEvent(&FastMutex->Event, IO_NO_INCREMENT, FALSE); + } + + KeLowerIrql(OldIrql); +} + +VOID +NTAPI +KeReleaseFastMutexUnsafe( + _In_ PFAST_MUTEX FastMutex) +{ + LONG OldValue = {0}; + + /* Clear the owner */ + FastMutex->Owner = nullptr; + + /* Release the lock and get the old value */ + OldValue = InterlockedExchangeAdd(&FastMutex->Lock, 1); + + /* Check if there were waiters */ + if (OldValue != 0) + { + /* Check if no waiter has been woken up yet */ + if ((OldValue & FM_LOCK_WAITER_WOKEN) == 0) + { + /* Try to wake up a waiter */ + if (OldValue + 1 == InterlockedCompareExchange(&FastMutex->Lock, + OldValue - 1, + OldValue + 1)) + { + /* Wake up one waiter */ + KeSetEvent(&FastMutex->Event, IO_NO_INCREMENT, FALSE); + } + } + } +} + +/*Guarded Mutexes in Modern NT behave just like Fast Mutexes with bit of protection */ + +VOID +NTAPI +KeInitializeGuardedMutex(_Out_ PKGUARDED_MUTEX GuardedMutex) +{ + /* Initialize the GuardedMutex*/ + GuardedMutex->Count = 1; + GuardedMutex->Owner = nullptr; + GuardedMutex->Contention = 0; + /* Initialize the Mutex Gate */ + KeInitializeEvent(&Mutex->Event, SynchronizationEvent, FALSE); +} + +VOID +NTAPI +KeAcquireGuardedMutex(_Inout_ PKGUARDED_MUTEX Mutex) +{ + PKTHREAD OwnerThread = KeGetCurrentThread(); + KeEnterGuardedRegion(); + if (!_Interlockedbittestandreset(&Mutex->Count, 0) ) + KiAcquireFastMutex(Mutex); + Mutex->Owner = OwnerThread; +} + +VOID +NTAPI +KeAcquireGuardedMutexUnsafe( + _Inout_ PKGUARDED_MUTEX FastMutex +) +{ + PKTHREAD CurrentThread = nullptr; + KeEnterGuardedRegion(); + CurrentThread = KeGetCurrentThread(); + + if (!_InterlockedBitTestAndReset(&FastMutex->Count, 0)) + { + KiAcquireFastMutex(FastMutex); + } + + FastMutex->Owner = CurrentThread; +} + +VOID +NTAPI +KeReleaseGuardedMutexUnsafe( + _Inout_ PKGUARDED_MUTEX FastMutex +) +{ + LONG OldCount ={0}; + + FastMutex->Owner = nullptr; + + OldCount = _InterlockedExchangeAdd(&FastMutex->Count, 1); + + if (OldCount != 0 && + (OldCount & FM_LOCK_WAITER_WOKEN) == 0 && + OldCount + 1 == InterlockedCompareExchange(&FastMutex->Count, OldCount - 1, OldCount + 1)) + { + KeSetEvent(&FastMutex->Event, IO_NO_INCREMENT, FALSE); + } + KeLeaveGuardedRegion(); +} + +VOID +NTAPI +KeReleaseGuardedMutex( + _In_ PKGUARDED_MUTEX FastMutex) +{ + KIRQL OldIrql ={0}; + LONG OldValue ={0}; + + /* Save the old IRQL and clear the owner */ + OldIrql = FastMutex->OldIrql; + FastMutex->Owner = nullptr; + + /* Try to release the FastMutex */ + OldValue = _InterlockedExchangeAdd(&Mutex->Count, 1); + if (OldCount != 0 && + (OldCount & FM_LOCK_WAITER_WOKEN) == 0 && + OldCount + 1 == InterlockedCompareExchange(&FastMutex->Count, OldCount - 1, OldCount + 1)) + { + KeSetEvent(&FastMutex->Event, IO_NO_INCREMENT, FALSE); + } + + /* Lower IRQL */ + KeLowerIrql(OldIrql); + KeLeaveGuardedRegion(); +}