/*
 * PROJECT: Alcyone System Kernel
 * LICENSE: BSD Clause 3
 * PURPOSE: Mutexes
 * NT KERNEL: 5.11.9360
 * COPYRIGHT:  2023-2029 Dibymartanda Samanta <>
 */ 


#include <ntoskrnl.h>
#define NTDEBUG

extern "C"
/* Mutex Count (as originally documented):
   0 => can be acquired,
   1 => is acquired by a thread,
   a negative value indicates the number of threads waiting.
   NOTE(review): the code in this file initializes Count to 1 and takes the
   mutex by test-and-resetting bit 0, so bit 0 being set appears to mean
   "free" -- confirm and update this description. */

/* Bit index within Count that KeTryToAcquireFastMutex test-and-resets. */
constexpr ULONG MUTEX_READY_TO_BE_AQUIRED = 0;

/*Internal Function*/ 

/*
 * Fast Mutex definitions: layout of FAST_MUTEX::Count as used in this file.
 *
 * FM_LOCK_BIT          - bit 0; the acquire paths test-and-reset it and the
 *                        release paths add it back, so "set" means free.
 * FM_LOCK_BIT_V        - presumably the bit index of FM_LOCK_BIT; not
 *                        referenced by name in this file (TODO confirm).
 * FM_LOCK_WAITER_WOKEN - checked by the release paths before signalling the
 *                        event, so a second waiter is not woken while set.
 * FM_LOCK_WAITER_INC   - same value as the amount KiAcquireFastMutex adds to
 *                        Count before it first waits on the event.
 */
#define FM_LOCK_BIT           0x1
#define FM_LOCK_BIT_V         0x0
#define FM_LOCK_WAITER_WOKEN  0x2
#define FM_LOCK_WAITER_INC    0x4

/*
 * FAST_MUTEX
 *
 * Count      - lock state word updated with interlocked operations
 *              (see the FM_LOCK_* definitions).
 * Owner      - set to the acquiring thread after a successful acquire and
 *              cleared by the release routines.
 * Contention - incremented each time KiAcquireFastMutex (the slow path) runs.
 * Event      - synchronization event that waiters block on.
 * OldIrql    - IRQL saved by KeAcquireFastMutex and restored by
 *              KeReleaseFastMutex.
 *
 * NOTE(review): the trailing offset comments appear to assume 4-byte
 * pointers -- confirm for the target architecture.
 */
typedef struct _FAST_MUTEX
{
    LONG Count;                                                             //0x0
    VOID* Owner;                                                            //0x4
    ULONG Contention;                                                       //0x8
    struct _KEVENT Event;                                                   //0xc
    ULONG OldIrql;                                                          //0x1c
} FAST_MUTEX, *PFAST_MUTEX;                                                 //0x20 bytes (sizeof)

/* Guarded mutexes share the FAST_MUTEX layout in this file. */
typedef PFAST_MUTEX PKGUARDED_MUTEX;

/* Internal functions */
/*
 * KiAcquireFastMutex
 *
 * Slow path taken when the fast bit-test acquire failed. Loops until the
 * calling thread has cleared FM_LOCK_BIT in Mutex->Count, blocking on
 * Mutex->Event whenever the lock is found held.
 *
 * Mutex - fast mutex to acquire. The caller sets Mutex->Owner afterwards.
 */
VOID
FASTCALL
KiAcquireFastMutex(
    _Inout_ PFAST_MUTEX Mutex
    )
{
    /* Amount added to Count when the lock is found held */
    LONG WaiterAdjust = FM_LOCK_WAITER_INC;
    /* Bits flipped in Count when the lock is found free */
    LONG AcquireMask = FM_LOCK_BIT;
    LONG OldCount = 0;

    PAGED_CODE();

    /* Contention is a ULONG; the interlocked primitive takes a LONG pointer */
    InterlockedIncrement((volatile LONG *)&Mutex->Contention);

    for (;;)
    {
        /* Read current count */
        OldCount = ReadForWriteAccess(&Mutex->Count);

        if ((OldCount & FM_LOCK_BIT) != 0)
        {
            /* Lock bit set: the mutex is free, try to take it */
            if (InterlockedCompareExchange(&Mutex->Count,
                                           OldCount ^ AcquireMask,
                                           OldCount) == OldCount)
            {
                /* Mutex acquired successfully */
                return;
            }

            /* Lost the race, re-read and retry */
            continue;
        }

        /* Lock bit clear: the mutex is held, register as a waiter */
        if (InterlockedCompareExchange(&Mutex->Count,
                                       OldCount + WaiterAdjust,
                                       OldCount) == OldCount)
        {
            /* Block until a releaser signals the event */
            KeWaitForSingleObject(&Mutex->Event, WrFastMutex, KernelMode, FALSE, nullptr);

            /*
             * The releasers in this file set FM_LOCK_WAITER_WOKEN and drop one
             * waiter before signalling. From now on, acquiring must also clear
             * the woken bit, and waiting again adds FM_LOCK_WAITER_WOKEN
             * (i.e. +FM_LOCK_WAITER_INC with the woken bit cleared).
             */
            WaiterAdjust = FM_LOCK_WAITER_WOKEN;
            AcquireMask = FM_LOCK_BIT | FM_LOCK_WAITER_WOKEN;
        }
    }
}

FASTCALL
KeReleaseFastMutexContended(
    IN PFAST_MUTEX FastMutex,
    IN LONG OldValue)
{
    BOOLEAN WakeWaiter = false;
    LONG NewValue = {0};
    PKTHREAD WokenThread = nullptr;
    KPRIORITY HandoffPriority = {0};

    /* Loop until we successfully update the mutex state */
    for (;;)
    {
        WakeWaiter = false;
        NewValue = OldValue + FM_LOCK_BIT;

        if (!(OldValue & FM_LOCK_WAITER_WOKEN))
        {
            NewValue = OldValue - FM_LOCK_BIT;
            WakeWaiter = true;
        }

        LONG PreviousValue = InterlockedCompareExchange(&FastMutex->Lock, NewValue, OldValue);
        if (PreviousValue == OldValue)
            break;

        OldValue = PreviousValue;
    }

    if (WakeWaiter)
    {
        /* Wake up a waiter */
        KeSetEventBoostPriority(&FastMutex->Event);
    }
}


/* Exported functions */

/*
 * KeInitializeFastMutex
 *
 * Puts a fast mutex into its initial state: every field zeroed, Count set
 * to 1, no owner, no recorded contention, and the wait event initialised
 * as a non-signalled synchronization event.
 *
 * Mutex - storage for the fast mutex to initialise.
 */
VOID
NTAPI
KeInitializeFastMutex(
    _Out_ PFAST_MUTEX Mutex
    )
{
    /* Start from an all-zero structure */
    RtlZeroMemory(Mutex, sizeof(*Mutex));

    /* Explicit initial field values */
    Mutex->Count = 1;
    Mutex->Contention = 0;
    Mutex->Owner = nullptr;

    /* The event waiters block on; starts non-signalled */
    KeInitializeEvent(&Mutex->Event, SynchronizationEvent, FALSE);
}

/*
 * KeTryToAcquireFastMutex
 *
 * Attempts to take the fast mutex without blocking.
 *
 * Mutex - fast mutex to try to acquire.
 *
 * Returns TRUE if the mutex was acquired; the caller then runs at APC_LEVEL
 * and the previous IRQL is stored in Mutex->OldIrql. Returns FALSE if the
 * mutex was already held; the IRQL is restored before returning.
 */
BOOLEAN
VECTORCALL
KeTryToAcquireFastMutex(
    _Inout_ PFAST_MUTEX Mutex)
{
    KIRQL OldIrql;

    /* Raise before touching the lock, so it is never held below APC_LEVEL */
    KeRaiseIrql(APC_LEVEL, &OldIrql);

    if (InterlockedBitTestAndReset(&Mutex->Count, MUTEX_READY_TO_BE_AQUIRED))
    {
        /* The lock bit was set (free) and is now ours */
        Mutex->Owner = (PVOID)KeGetCurrentThread();
        Mutex->OldIrql = OldIrql;
        return TRUE;
    }

    /* Failed to acquire the mutex: undo the IRQL raise */
    KeLowerIrql(OldIrql);
    KeYieldProcessor();
    return FALSE;
}

/*
 * KeEnterCriticalRegionAndAcquireFastMutexUnsafe
 *
 * Enters a critical region and then acquires the fast mutex, blocking in
 * KiAcquireFastMutex if it is currently held. Does not change the IRQL.
 *
 * FastMutex - fast mutex to acquire.
 */
VOID
NTAPI
KeEnterCriticalRegionAndAcquireFastMutexUnsafe( 
    _In_ PFAST_MUTEX FastMutex)
{
    KeEnterCriticalRegion();

    /* Fast path: take the lock bit if it is set (mutex free) */
    if (!InterlockedBitTestAndReset(&FastMutex->Count, 0))
    {
        /* FastMutex was locked, we need to wait */
        KiAcquireFastMutex(FastMutex);
    }

    /* Either way we own it now */
    FastMutex->Owner = KeGetCurrentThread();
}

/*
 * KeReleaseFastMutexUnsafeAndLeaveCriticalRegion
 *
 * Releases the fast mutex and leaves the critical region. If Count shows
 * anything other than "held, nothing else pending", the release is handed
 * to KeReleaseFastMutexContended.
 *
 * FastMutex - fast mutex owned by the caller.
 */
VOID
FASTCALL
KeReleaseFastMutexUnsafeAndLeaveCriticalRegion(  
    _In_ PFAST_MUTEX FastMutex)
{
    LONG OldValue;

    /* Clear the owner */
    FastMutex->Owner = nullptr;

    /* Uncontended release: Count 0 (held, nothing pending) -> FM_LOCK_BIT (free) */
    OldValue = InterlockedCompareExchange(&FastMutex->Count, FM_LOCK_BIT, 0);
    if (OldValue != 0)
    {
        /* Contended case, call the contended release function */
        KeReleaseFastMutexContended(FastMutex, OldValue);
    }

    /* Leave critical region */
    KeLeaveCriticalRegion();
}


/*
 * KeAcquireFastMutex
 *
 * Raises the IRQL to APC_LEVEL and acquires the fast mutex, blocking in
 * KiAcquireFastMutex if it is currently held. The previous IRQL is stored
 * in FastMutex->OldIrql for KeReleaseFastMutex to restore.
 *
 * FastMutex - fast mutex to acquire.
 */
VOID
NTAPI
KeAcquireFastMutex( 
    _In_ PFAST_MUTEX FastMutex)
{
    KIRQL OldIrql;

    /* Raise IRQL to APC_LEVEL. The slow path may block in
       KeWaitForSingleObject, so the IRQL must not go any higher. */
    KeRaiseIrql(APC_LEVEL, &OldIrql);

    /* Try to acquire the FastMutex */
    if (!InterlockedBitTestAndReset(&FastMutex->Count, 0))
    {
        /* We didn't acquire it, we'll have to wait */
        KiAcquireFastMutex(FastMutex);
    }

    /* Set the owner thread and save the original IRQL */
    FastMutex->Owner = KeGetCurrentThread();
    FastMutex->OldIrql = OldIrql;
}

/*
 * KeAcquireFastMutexUnsafe
 *
 * Acquires the fast mutex without changing the IRQL or entering any region,
 * blocking in KiAcquireFastMutex if it is currently held.
 *
 * FastMutex - fast mutex to acquire.
 */
VOID
NTAPI
KeAcquireFastMutexUnsafe(
    _In_ PFAST_MUTEX FastMutex)  
{
    /* Try to acquire the FastMutex */
    if (!InterlockedBitTestAndReset(&FastMutex->Count, 0))
    {
        /* FastMutex was locked, we need to wait */
        KiAcquireFastMutex(FastMutex);
    }

    /* Set the owner */
    FastMutex->Owner = KeGetCurrentThread();
}

/*
 * KeReleaseFastMutex
 *
 * Releases a fast mutex and lowers the IRQL to the value stored in
 * FastMutex->OldIrql. If there was other state in Count and no waiter has
 * been marked as woken yet, one waiter is signalled through the event.
 *
 * FastMutex - fast mutex owned by the caller.
 */
VOID
NTAPI
KeReleaseFastMutex(  
    _Inout_ PFAST_MUTEX FastMutex
)
{
    KIRQL SavedIrql = {0};
    LONG Previous = {0};
    LONG Released = {0};

    FastMutex->Owner = nullptr;

    /* Read the saved IRQL before the lock bit is given back */
    SavedIrql = FastMutex->OldIrql;

    /* Give the lock bit back; Previous is Count before the add */
    Previous = InterlockedExchangeAdd(&FastMutex->Count, 1);

    if (Previous != 0)
    {
        if ((Previous & FM_LOCK_WAITER_WOKEN) == 0)
        {
            /* Only signal if Count is still exactly what we left behind */
            Released = Previous + 1;
            if (InterlockedCompareExchange(&FastMutex->Count, Previous - 1, Released) == Released)
            {
                KeSetEvent(&FastMutex->Event, IO_NO_INCREMENT, FALSE);
            }
        }
    }

    KeLowerIrql(SavedIrql);
}

/*
 * KeReleaseFastMutexUnsafe
 *
 * Releases a fast mutex without touching the IRQL. If there was other state
 * in Count and no waiter has been marked as woken yet, one waiter is
 * signalled through the event.
 *
 * FastMutex - fast mutex owned by the caller.
 */
VOID
NTAPI
KeReleaseFastMutexUnsafe( 
    _In_ PFAST_MUTEX FastMutex)
{
    LONG OldValue;

    /* Clear the owner */
    FastMutex->Owner = nullptr;

    /* Release the lock and get the old value */
    OldValue = InterlockedExchangeAdd(&FastMutex->Count, FM_LOCK_BIT);

    /* Nothing else recorded in Count, or a waiter is already woken: done */
    if (OldValue == 0 || (OldValue & FM_LOCK_WAITER_WOKEN) != 0)
    {
        return;
    }

    /* Try to turn one waiter into the woken state (net OldValue - 1);
       give up if Count changed since we released. */
    if (InterlockedCompareExchange(&FastMutex->Count,
                                   OldValue - 1,
                                   OldValue + 1) == OldValue + 1)
    {
        /* Wake up one waiter */
        KeSetEvent(&FastMutex->Event, IO_NO_INCREMENT, FALSE);
    }
}

/* Guarded mutexes in modern NT behave just like fast mutexes, with a bit of extra protection. */

/*
 * KeInitializeGuardedMutex
 *
 * Puts a guarded mutex into its initial state: Count set to 1, no owner,
 * no recorded contention, and the wait event initialised as a non-signalled
 * synchronization event.
 *
 * GuardedMutex - storage for the guarded mutex to initialise.
 */
VOID
NTAPI
KeInitializeGuardedMutex(_Out_ PKGUARDED_MUTEX GuardedMutex)
{
    /* Initialize the GuardedMutex */
    GuardedMutex->Count = 1;
    GuardedMutex->Owner = nullptr;
    GuardedMutex->Contention = 0;
    /* Leave no field uninitialised */
    GuardedMutex->OldIrql = 0;

    /* Initialize the Mutex Gate */
    KeInitializeEvent(&GuardedMutex->Event, SynchronizationEvent, FALSE);
}

/*
 * KeAcquireGuardedMutex
 *
 * Enters a guarded region and acquires the guarded mutex, blocking in
 * KiAcquireFastMutex if it is currently held. Does not change the IRQL.
 *
 * Mutex - guarded mutex to acquire.
 */
VOID 
NTAPI  
KeAcquireGuardedMutex(_Inout_ PKGUARDED_MUTEX Mutex)
{
    PKTHREAD OwnerThread = KeGetCurrentThread();

    KeEnterGuardedRegion();

    /* Fast path: take the lock bit if it is set (mutex free) */
    if (!InterlockedBitTestAndReset(&Mutex->Count, 0))
    {
        /* Held by someone else: wait in the slow path */
        KiAcquireFastMutex(Mutex);
    }

    Mutex->Owner = OwnerThread;
}

/*
 * KeAcquireGuardedMutexUnsafe
 *
 * Enters a guarded region and acquires the guarded mutex, blocking in
 * KiAcquireFastMutex if it is currently held.
 *
 * NOTE(review): "Unsafe" variants conventionally expect the caller to have
 * disabled APC delivery already and do not enter a guarded region themselves.
 * This routine does enter one; KeReleaseGuardedMutexUnsafe in this file
 * leaves one, so the pair is balanced. Confirm the intended contract before
 * changing either side.
 *
 * FastMutex - guarded mutex to acquire.
 */
VOID
NTAPI
KeAcquireGuardedMutexUnsafe(
    _Inout_ PKGUARDED_MUTEX FastMutex
)
{   
    PKTHREAD CurrentThread = nullptr;

    KeEnterGuardedRegion();
    CurrentThread = KeGetCurrentThread();

    /* Fast path: take the lock bit if it is set (mutex free) */
    if (!InterlockedBitTestAndReset(&FastMutex->Count, 0))
    {
        /* Held by someone else: wait in the slow path */
        KiAcquireFastMutex(FastMutex);
    }

    FastMutex->Owner = CurrentThread;
}

/*
 * KeReleaseGuardedMutexUnsafe
 *
 * Releases a guarded mutex and leaves the guarded region. If there was other
 * state in Count and no waiter has been marked as woken yet, one waiter is
 * signalled through the event.
 *
 * FastMutex - guarded mutex owned by the caller.
 */
VOID
NTAPI
KeReleaseGuardedMutexUnsafe(
    _Inout_ PKGUARDED_MUTEX FastMutex
)
{
    LONG Previous = {0};
    LONG Released = {0};

    FastMutex->Owner = nullptr;

    /* Give the lock bit back; Previous is Count before the add */
    Previous = _InterlockedExchangeAdd(&FastMutex->Count, 1);

    if (Previous != 0)
    {
        if ((Previous & FM_LOCK_WAITER_WOKEN) == 0)
        {
            /* Only signal if Count is still exactly what we left behind */
            Released = Previous + 1;
            if (Released == InterlockedCompareExchange(&FastMutex->Count, Previous - 1, Released))
            {
                KeSetEvent(&FastMutex->Event, IO_NO_INCREMENT, FALSE);
            }
        }
    }

    KeLeaveGuardedRegion();
}

/*
 * KeReleaseGuardedMutex
 *
 * Releases a guarded mutex taken with KeAcquireGuardedMutex and leaves the
 * guarded region. If there was other state in Count and no waiter has been
 * marked as woken yet, one waiter is signalled through the event.
 *
 * The IRQL is not touched: KeAcquireGuardedMutex enters a guarded region
 * and never raises the IRQL or stores a value in OldIrql.
 *
 * FastMutex - guarded mutex owned by the caller.
 */
VOID
NTAPI
KeReleaseGuardedMutex(
    _In_ PKGUARDED_MUTEX FastMutex) 
{
    LONG OldCount;

    /* Clear the owner */
    FastMutex->Owner = nullptr;

    /* Give the lock bit back; OldCount is Count before the add */
    OldCount = _InterlockedExchangeAdd(&FastMutex->Count, 1);

    if (OldCount != 0 && 
        (OldCount & FM_LOCK_WAITER_WOKEN) == 0 && 
        OldCount + 1 == InterlockedCompareExchange(&FastMutex->Count, OldCount - 1, OldCount + 1))
    {
        /* One waiter was moved to the woken state: signal it */
        KeSetEvent(&FastMutex->Event, IO_NO_INCREMENT, FALSE);
    }

    /* Re-enable what KeEnterGuardedRegion disabled on acquire */
    KeLeaveGuardedRegion();
}