Refactor random functions to be runner specific (#7816)

This commit is contained in:
FosterProgramming 2025-11-11 21:48:41 +01:00 committed by GitHub
parent fed29c8ae6
commit 44d4e0c1de
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 307 additions and 124 deletions

View File

@ -271,4 +271,13 @@ const void *RandomElementArrayDefault(enum RandomTag, const void *array, size_t
u8 RandomWeightedIndex(u8 *weights, u8 length);
#if TESTING
u32 RandomUniformTrials(enum RandomTag tag, u32 lo, u32 hi, bool32 (*reject)(u32), void *caller);
u32 RandomUniformDefaultValue(enum RandomTag tag, u32 lo, u32 hi, bool32 (*reject)(u32), void *caller);
u32 RandomWeightedArrayTrials(enum RandomTag tag, u32 sum, u32 n, const u8 *weights, void *caller);
u32 RandomWeightedArrayDefaultValue(enum RandomTag tag, u32 n, const u8 *weights, void *caller);
const void *RandomElementArrayTrials(enum RandomTag tag, const void *array, size_t size, size_t count, void *caller);
const void *RandomElementArrayDefaultValue(enum RandomTag tag, const void *array, size_t size, size_t count, void *caller);
#endif
#endif // GUARD_RANDOM_H

View File

@ -658,18 +658,12 @@ struct QueuedEvent
} as;
};
struct TurnRNG
{
u16 tag;
u16 value;
};
struct BattlerTurn
{
u8 hit:2;
u8 criticalHit:2;
u8 secondaryEffect:2;
struct TurnRNG rng;
struct RiggedRNG rng;
};
struct ExpectedAIAction
@ -1099,7 +1093,7 @@ enum { TURN_CLOSED, TURN_OPEN, TURN_CLOSING };
#define SKIP_TURN(battler) SkipTurn(__LINE__, battler)
#define SEND_OUT(battler, partyIndex) SendOut(__LINE__, battler, partyIndex)
#define USE_ITEM(battler, ...) UseItem(__LINE__, battler, (struct ItemContext) { R_APPEND_TRUE(__VA_ARGS__) })
#define WITH_RNG(tag, value) rng: ((struct TurnRNG) { tag, value })
#define WITH_RNG(tag, value) rng: ((struct RiggedRNG) { tag, value })
struct MoveContext
{
@ -1124,7 +1118,7 @@ struct MoveContext
u16 explicitNotExpected:1;
struct BattlePokemon *target;
bool8 explicitTarget;
struct TurnRNG rng;
struct RiggedRNG rng;
bool8 explicitRNG;
};
@ -1136,7 +1130,7 @@ struct ItemContext
u16 explicitPartyIndex:1;
u16 move;
u16 explicitMove:1;
struct TurnRNG rng;
struct RiggedRNG rng;
u16 explicitRNG:1;
};

View File

@ -2,8 +2,10 @@
#define GUARD_TEST_H
#include "test_runner.h"
#include "random.h"
#define MAX_PROCESSES 32 // See also tools/mgba-rom-test-hydra/main.c
#define RIGGED_RNG_COUNT 8
enum TestResult
{
@ -26,6 +28,9 @@ struct TestRunner
void (*tearDown)(void *);
bool32 (*checkProgress)(void *);
bool32 (*handleExitWithResult)(void *, enum TestResult);
u32 (*randomUniform)(enum RandomTag tag, u32 lo, u32 hi, bool32 (*reject)(u32), void *caller);
u32 (*randomWeightedArray)(enum RandomTag tag, u32 sum, u32 n, const u8 *weights, void *caller);
const void* (*randomElementArray)(enum RandomTag tag, const void *array, size_t size, size_t count, void *caller);
};
struct Test
@ -76,11 +81,18 @@ extern const char gTestRunnerArgv[256];
extern const struct TestRunner gAssumptionsRunner;
struct RiggedRNG
{
u16 tag;
u16 value;
};
struct FunctionTestRunnerState
{
u16 parameters;
u16 runParameter;
u16 checkProgressParameter;
struct RiggedRNG rngList[RIGGED_RNG_COUNT];
};
extern const struct TestRunner gFunctionTestRunner;
@ -97,6 +109,8 @@ void Test_ExpectCrash(bool32);
void Test_ExitWithResult(enum TestResult, u32 stopLine, const char *fmt, ...);
u32 SourceLine(u32 sourceLineOffset);
u32 SourceLineOffset(u32 sourceLine);
void SetupRiggedRng(u32 sourceLine, enum RandomTag randomTag, u32 value);
void ClearRiggedRng();
s32 Test_MgbaPrintf(const char *fmt, ...);
@ -245,6 +259,8 @@ static inline struct Benchmark BenchmarkStop(void)
#define PARAMETRIZE_LABEL(f, label) if (gFunctionTestRunnerState->parameters++ == gFunctionTestRunnerState->runParameter && (Test_MgbaPrintf(":N%s: " f " (%d/%d)", gTestRunnerState.test->name, label, gFunctionTestRunnerState->runParameter + 1, gFunctionTestRunnerState->parameters), 1))
#define SET_RNG(tag, value) SetupRiggedRng(__LINE__, tag, value)
#define TO_DO \
do { \
Test_ExpectedResult(TEST_RESULT_TODO); \

View File

@ -10,6 +10,7 @@
#include "constants/characters.h"
#include "test_runner.h"
#include "test/test.h"
#include "test/battle.h"
#define TIMEOUT_SECONDS 60
@ -482,6 +483,7 @@ void Test_ExpectCrash(bool32 expectCrash)
static void FunctionTest_SetUp(void *data)
{
(void)data;
ClearRiggedRng();
gFunctionTestRunnerState = AllocZeroed(sizeof(*gFunctionTestRunnerState));
SeedRng(0);
}
@ -513,12 +515,84 @@ static bool32 FunctionTest_CheckProgress(void *data)
return madeProgress;
}
static u32 FunctionTest_RandomUniform(enum RandomTag tag, u32 lo, u32 hi, bool32 (*reject)(u32), void *caller)
{
//rigged
for (u32 i = 0; i < RIGGED_RNG_COUNT; i++)
{
if (gFunctionTestRunnerState->rngList[i].tag == tag)
{
if (reject && reject(gFunctionTestRunnerState->rngList[i].value))
Test_ExitWithResult(TEST_RESULT_INVALID, SourceLine(0), ":LWITH_RNG specified a rejected value (%d)", gFunctionTestRunnerState->rngList[i].value);
return gFunctionTestRunnerState->rngList[i].value;
}
}
//trials
/*
if (tag == STATE->rngTag)
return RandomUniformTrials(tag, lo, hi, reject);
*/
//default
return RandomUniformDefaultValue(tag, lo, hi, reject, caller);
}
static u32 FunctionTest_RandomWeightedArray(enum RandomTag tag, u32 sum, u32 n, const u8 *weights, void *caller)
{
//rigged
for (u32 i = 0; i < RIGGED_RNG_COUNT; i++)
{
if (gFunctionTestRunnerState->rngList[i].tag == tag)
return gFunctionTestRunnerState->rngList[i].value;
}
//trials
/*
if (tag == STATE->rngTag)
return RandomWeightedArrayTrials(tag, sum, n, weights);
*/
//default
return RandomWeightedArrayDefaultValue(tag, n, weights, caller);
}
static const void* FunctionTest_RandomElementArray(enum RandomTag tag, const void *array, size_t size, size_t count, void *caller)
{
//rigged
for (u32 i = 0; i < RIGGED_RNG_COUNT; i++)
{
if (gFunctionTestRunnerState->rngList[i].tag == tag)
{
u32 element = 0;
for (u32 index = 0; index < count; index++)
{
memcpy(&element, (const u8 *)array + size * index, size);
if (element == gFunctionTestRunnerState->rngList[i].value)
return (const u8 *)array + size * index;
}
Test_ExitWithResult(TEST_RESULT_ERROR, SourceLine(0), ":L%s: RandomElement illegal value requested: %d", gTestRunnerState.test->filename, gFunctionTestRunnerState->rngList[i].value);
}
}
//trials
/*
if (tag == STATE->rngTag)
return RandomElementTrials(tag, array, size, count);
*/
//default
return RandomElementArrayDefaultValue(tag, array, size, count, caller);
}
const struct TestRunner gFunctionTestRunner =
{
.setUp = FunctionTest_SetUp,
.run = FunctionTest_Run,
.tearDown = FunctionTest_TearDown,
.checkProgress = FunctionTest_CheckProgress,
.randomUniform = FunctionTest_RandomUniform,
.randomWeightedArray = FunctionTest_RandomWeightedArray,
.randomElementArray = FunctionTest_RandomElementArray,
};
static void Assumptions_Run(void *data)
@ -826,3 +900,102 @@ u32 SourceLineOffset(u32 sourceLine)
else
return sourceLine - test->sourceLine;
}
u32 RandomUniform(enum RandomTag tag, u32 lo, u32 hi)
{
void *caller = __builtin_extract_return_addr(__builtin_return_address(0));
if (gTestRunnerState.test->runner->randomUniform)
return gTestRunnerState.test->runner->randomUniform(tag, lo, hi, NULL, caller);
else
return RandomUniformDefault(tag, lo, hi);
}
u32 RandomUniformExcept(enum RandomTag tag, u32 lo, u32 hi, bool32 (*reject)(u32))
{
void *caller = __builtin_extract_return_addr(__builtin_return_address(0));
if (gTestRunnerState.test->runner->randomUniform)
return gTestRunnerState.test->runner->randomUniform(tag, lo, hi, reject, caller);
else
return RandomUniformExceptDefault(tag, lo, hi, reject);
}
u32 RandomWeightedArray(enum RandomTag tag, u32 sum, u32 n, const u8 *weights)
{
void *caller = __builtin_extract_return_addr(__builtin_return_address(0));
if (gTestRunnerState.test->runner->randomWeightedArray)
return gTestRunnerState.test->runner->randomWeightedArray(tag, sum, n, weights, caller);
else
return RandomWeightedArrayDefault(tag, sum, n, weights);
}
const void *RandomElementArray(enum RandomTag tag, const void *array, size_t size, size_t count)
{
void *caller = __builtin_extract_return_addr(__builtin_return_address(0));
if (gTestRunnerState.test->runner->randomElementArray)
return gTestRunnerState.test->runner->randomElementArray(tag, array, size, count, caller);
else
return RandomElementArrayDefault(tag, array, size, count);
}
u32 RandomUniformDefaultValue(enum RandomTag tag, u32 lo, u32 hi, bool32 (*reject)(u32), void *caller)
{
u32 default_ = hi;
if (reject)
{
while (reject(default_))
{
if (default_ == lo)
Test_ExitWithResult(TEST_RESULT_ERROR, SourceLine(0), ":LRandomUniformExcept called from %p with tag %d rejected all values", caller, tag);
default_--;
}
}
return default_;
}
u32 RandomWeightedArrayDefaultValue(enum RandomTag tag, u32 n, const u8 *weights, void *caller)
{
while (weights[n-1] == 0)
{
if (n == 1)
Test_ExitWithResult(TEST_RESULT_ERROR, SourceLine(0), ":LRandomWeightedArray called from %p with tag %d and all zero weights", caller, tag);
n--;
}
return n-1;
}
const void *RandomElementArrayDefaultValue(enum RandomTag tag, const void *array, size_t size, size_t count, void *caller)
{
return (const u8 *)array + size * (count - 1);
}
void ClearRiggedRng(void)
{
struct RiggedRNG zeroRng = {.tag = RNG_NONE, .value = 0};
for (u32 i = 0; i < RIGGED_RNG_COUNT; i++)
memcpy(&gFunctionTestRunnerState->rngList[i], &zeroRng, sizeof(zeroRng));
}
void SetupRiggedRng(u32 sourceLine, enum RandomTag randomTag, u32 value)
{
struct RiggedRNG rng = {.tag = randomTag, .value = value};
u32 i;
for (i = 0; i < RIGGED_RNG_COUNT; i++)
{
if (gFunctionTestRunnerState->rngList[i].tag == randomTag)
{
memcpy(&gFunctionTestRunnerState->rngList[i], &rng, sizeof(rng));
break;
}
else if (gFunctionTestRunnerState->rngList[i].tag > RNG_NONE)
{
continue;
}
else
{
memcpy(&gFunctionTestRunnerState->rngList[i], &rng, sizeof(rng));
break;
}
}
if (i == RIGGED_RNG_COUNT)
Test_ExitWithResult(TEST_RESULT_FAIL, __LINE__, ":L%s:%d: Too many rigged RNGs to set up", gTestRunnerState.test->filename, sourceLine);
}

View File

@ -554,22 +554,56 @@ static void BattleTest_Run(void *data)
PrintTestName();
}
u32 RandomUniform(enum RandomTag tag, u32 lo, u32 hi)
u32 RandomUniformTrials(enum RandomTag tag, u32 lo, u32 hi, bool32 (*reject)(u32), void *caller)
{
const struct BattlerTurn *turn = NULL;
if (gCurrentTurnActionNumber < gBattlersCount)
STATE->didRunRandomly = TRUE;
if (STATE->trials == 1)
{
u32 battlerId = gBattlerByTurnOrder[gCurrentTurnActionNumber];
turn = &DATA.battleRecordTurns[gBattleResults.battleTurnCounter][battlerId];
if (turn && turn->rng.tag == tag)
return turn->rng.value;
u32 n = 0, i;
if (reject)
{
for (i = lo; i <= hi; i++)
if (!reject(i))
n++;
STATE->trials = n;
}
else
STATE->trials = hi - lo + 1;
PrintTestName();
}
STATE->trialRatio = Q_4_12(1) / STATE->trials;
if (!reject)
{
if (STATE->trials != (hi - lo + 1))
Test_ExitWithResult(TEST_RESULT_ERROR, SourceLine(0), ":LRandomUniform called from %p with tag %d and inconsistent trials %d and %d", caller, tag, STATE->trials, hi - lo + 1);
return STATE->runTrial + lo;
}
if (tag == STATE->rngTag)
while (reject(STATE->runTrial + lo + STATE->rngTrialOffset))
{
if (STATE->runTrial + lo + STATE->rngTrialOffset > hi)
Test_ExitWithResult(TEST_RESULT_ERROR, SourceLine(0), ":LRandomUniformExcept called from %p with tag %d and inconsistent reject", caller, tag);
STATE->rngTrialOffset++;
}
return STATE->runTrial + lo + STATE->rngTrialOffset;
}
u32 RandomWeightedArrayTrials(enum RandomTag tag, u32 sum, u32 n, const u8 *weights, void *caller)
{
//Detect inconsistent sum
u32 weightSum = 0;
if (STATE->runTrial == 0)
{
for (u32 i = 0; i < n; i++)
weightSum += weights[i];
if (weightSum != sum)
Test_ExitWithResult(TEST_RESULT_ERROR, SourceLine(0), ":LRandomWeighted called from %p has weights not matching its sum", caller);
}
STATE->didRunRandomly = TRUE;
u32 n = hi - lo + 1;
if (STATE->trials == 1)
{
STATE->trials = n;
@ -577,73 +611,56 @@ u32 RandomUniform(enum RandomTag tag, u32 lo, u32 hi)
}
else if (STATE->trials != n)
{
Test_ExitWithResult(TEST_RESULT_ERROR, SourceLine(0), ":LRandomUniform called from %p with tag %d and inconsistent trials %d and %d", __builtin_extract_return_addr(__builtin_return_address(0)), tag, STATE->trials, n);
}
STATE->trialRatio = Q_4_12(1) / STATE->trials;
return STATE->runTrial + lo;
Test_ExitWithResult(TEST_RESULT_ERROR, SourceLine(0), ":LRandomWeighted called from %p with tag %d and inconsistent trials %d and %d", caller, tag, STATE->trials, n);
}
return hi;
STATE->trialRatio = Q_4_12(weights[STATE->runTrial]) / sum;
return STATE->runTrial;
}
u32 RandomUniformExcept(enum RandomTag tag, u32 lo, u32 hi, bool32 (*reject)(u32))
const void *RandomElementArrayTrials(enum RandomTag tag, const void *array, size_t size, size_t count, void *caller)
{
const struct BattlerTurn *turn = NULL;
u32 default_;
STATE->didRunRandomly = TRUE;
if (STATE->trials == 1)
{
STATE->trials = count;
PrintTestName();
}
else if (STATE->trials != count)
{
Test_ExitWithResult(TEST_RESULT_ERROR, SourceLine(0), ":LRandomElement called from %p with tag %d and inconsistent trials %d and %d", caller, tag, STATE->trials, count);
}
STATE->trialRatio = Q_4_12(1) / count;
return (const u8 *)array + size * STATE->runTrial;
}
static u32 BattleTest_RandomUniform(enum RandomTag tag, u32 lo, u32 hi, bool32 (*reject)(u32), void *caller)
{
//rigged
const struct BattlerTurn *turn = NULL;
if (gCurrentTurnActionNumber < gBattlersCount)
{
u32 battlerId = gBattlerByTurnOrder[gCurrentTurnActionNumber];
turn = &DATA.battleRecordTurns[gBattleResults.battleTurnCounter][battlerId];
if (turn && turn->rng.tag == tag)
{
if (reject(turn->rng.value))
if (reject && reject(turn->rng.value))
Test_ExitWithResult(TEST_RESULT_INVALID, SourceLine(0), ":LWITH_RNG specified a rejected value (%d)", turn->rng.value);
return turn->rng.value;
}
}
//trials
if (tag == STATE->rngTag)
{
STATE->didRunRandomly = TRUE;
if (STATE->trials == 1)
{
u32 n = 0, i;
for (i = lo; i <= hi; i++)
if (!reject(i))
n++;
STATE->trials = n;
PrintTestName();
}
STATE->trialRatio = Q_4_12(1) / STATE->trials;
return RandomUniformTrials(tag, lo, hi, reject, caller);
while (reject(STATE->runTrial + lo + STATE->rngTrialOffset))
{
if (STATE->runTrial + lo + STATE->rngTrialOffset > hi)
Test_ExitWithResult(TEST_RESULT_ERROR, SourceLine(0), ":LRandomUniformExcept called from %p with tag %d and inconsistent reject", __builtin_extract_return_addr(__builtin_return_address(0)), tag);
STATE->rngTrialOffset++;
}
return STATE->runTrial + lo + STATE->rngTrialOffset;
}
default_ = hi;
while (reject(default_))
{
if (default_ == lo)
Test_ExitWithResult(TEST_RESULT_ERROR, SourceLine(0), ":LRandomUniformExcept called from %p with tag %d rejected all values", __builtin_extract_return_addr(__builtin_return_address(0)), tag);
default_--;
}
return default_;
//default
return RandomUniformDefaultValue(tag, lo, hi, reject, caller);
}
u32 RandomWeightedArray(enum RandomTag tag, u32 sum, u32 n, const u8 *weights)
static u32 BattleTest_RandomWeightedArray(enum RandomTag tag, u32 sum, u32 n, const u8 *weights, void *caller)
{
//rigged
const struct BattlerTurn *turn = NULL;
if (sum == 0)
Test_ExitWithResult(TEST_RESULT_ERROR, SourceLine(0), ":LRandomWeightedArray called with zero sum");
if (gCurrentTurnActionNumber < gBattlersCount || tag == RNG_SHELL_SIDE_ARM)
{
u32 battlerId = gBattlerByTurnOrder[gCurrentTurnActionNumber];
@ -652,23 +669,11 @@ u32 RandomWeightedArray(enum RandomTag tag, u32 sum, u32 n, const u8 *weights)
return turn->rng.value;
}
//trials
if (tag == STATE->rngTag)
{
STATE->didRunRandomly = TRUE;
if (STATE->trials == 1)
{
STATE->trials = n;
PrintTestName();
}
else if (STATE->trials != n)
{
Test_ExitWithResult(TEST_RESULT_ERROR, SourceLine(0), ":LRandomWeighted called from %p with tag %d and inconsistent trials %d and %d", __builtin_extract_return_addr(__builtin_return_address(0)), tag, STATE->trials, n);
}
// TODO: Detect inconsistent sum.
STATE->trialRatio = Q_4_12(weights[STATE->runTrial]) / sum;
return STATE->runTrial;
}
return RandomWeightedArrayTrials(tag, sum, n, weights, caller);
//default
switch (tag)
{
case RNG_ACCURACY:
@ -693,21 +698,14 @@ u32 RandomWeightedArray(enum RandomTag tag, u32 sum, u32 n, const u8 *weights)
return TRUE;
default:
while (weights[n-1] == 0)
{
if (n == 1)
Test_ExitWithResult(TEST_RESULT_ERROR, SourceLine(0), ":LRandomWeightedArray called from %p with tag %d and all zero weights", __builtin_extract_return_addr(__builtin_return_address(0)), tag);
n--;
}
return n-1;
return RandomWeightedArrayDefaultValue(tag, n, weights, caller);
}
}
const void *RandomElementArray(enum RandomTag tag, const void *array, size_t size, size_t count)
static const void *BattleTest_RandomElementArray(enum RandomTag tag, const void *array, size_t size, size_t count, void *caller)
{
//rigged
const struct BattlerTurn *turn = NULL;
u32 index = count-1;
if (gCurrentTurnActionNumber < gBattlersCount)
{
u32 battlerId = gBattlerByTurnOrder[gCurrentTurnActionNumber];
@ -715,33 +713,23 @@ const void *RandomElementArray(enum RandomTag tag, const void *array, size_t siz
if (turn && turn->rng.tag == tag)
{
u32 element = 0;
for (index = 0; index < count; index++)
for (u32 index = 0; index < count; index++)
{
memcpy(&element, (const u8 *)array + size * index, size);
if (element == turn->rng.value)
return (const u8 *)array + size * index;
}
// TODO: Incorporate the line number.
Test_ExitWithResult(TEST_RESULT_ERROR, SourceLine(0), ":L%s: RandomElement illegal value requested: %d", gTestRunnerState.test->filename, turn->rng.value);
}
}
//trials
if (tag == STATE->rngTag)
{
STATE->didRunRandomly = TRUE;
if (STATE->trials == 1)
{
STATE->trials = count;
PrintTestName();
}
else if (STATE->trials != count)
{
Test_ExitWithResult(TEST_RESULT_ERROR, SourceLine(0), ":LRandomElement called from %p with tag %d and inconsistent trials %d and %d", __builtin_extract_return_addr(__builtin_return_address(0)), tag, STATE->trials, count);
}
STATE->trialRatio = Q_4_12(1) / count;
return (const u8 *)array + size * STATE->runTrial;
}
return (const u8 *)array + size * index;
return RandomElementArrayTrials(tag, array, size, count, caller);
//default
return RandomElementArrayDefaultValue(tag, array, size, count, caller);
}
static s32 TryAbilityPopUp(s32 i, s32 n, u32 battlerId, enum Ability ability)
@ -1738,6 +1726,9 @@ const struct TestRunner gBattleTestRunner =
.tearDown = BattleTest_TearDown,
.checkProgress = BattleTest_CheckProgress,
.handleExitWithResult = BattleTest_HandleExitWithResult,
.randomUniform = BattleTest_RandomUniform,
.randomWeightedArray = BattleTest_RandomWeightedArray,
.randomElementArray = BattleTest_RandomElementArray,
};
void SetFlagForTest(u32 sourceLine, u16 flagId)