From 44d4e0c1de9ee4aecaaf5588ac4107b988f6e097 Mon Sep 17 00:00:00 2001 From: FosterProgramming Date: Tue, 11 Nov 2025 21:48:41 +0100 Subject: [PATCH] Refactor random functions to be runner specific (#7816) --- include/random.h | 9 ++ include/test/battle.h | 14 +-- include/test/test.h | 16 +++ test/test_runner.c | 173 ++++++++++++++++++++++++++++++ test/test_runner_battle.c | 219 ++++++++++++++++++-------------------- 5 files changed, 307 insertions(+), 124 deletions(-) diff --git a/include/random.h b/include/random.h index 08650c0c78..84318bedc0 100644 --- a/include/random.h +++ b/include/random.h @@ -271,4 +271,13 @@ const void *RandomElementArrayDefault(enum RandomTag, const void *array, size_t u8 RandomWeightedIndex(u8 *weights, u8 length); +#if TESTING +u32 RandomUniformTrials(enum RandomTag tag, u32 lo, u32 hi, bool32 (*reject)(u32), void *caller); +u32 RandomUniformDefaultValue(enum RandomTag tag, u32 lo, u32 hi, bool32 (*reject)(u32), void *caller); +u32 RandomWeightedArrayTrials(enum RandomTag tag, u32 sum, u32 n, const u8 *weights, void *caller); +u32 RandomWeightedArrayDefaultValue(enum RandomTag tag, u32 n, const u8 *weights, void *caller); +const void *RandomElementArrayTrials(enum RandomTag tag, const void *array, size_t size, size_t count, void *caller); +const void *RandomElementArrayDefaultValue(enum RandomTag tag, const void *array, size_t size, size_t count, void *caller); +#endif + #endif // GUARD_RANDOM_H diff --git a/include/test/battle.h b/include/test/battle.h index bf1e7ceee9..d0f0458a7e 100644 --- a/include/test/battle.h +++ b/include/test/battle.h @@ -658,18 +658,12 @@ struct QueuedEvent } as; }; -struct TurnRNG -{ - u16 tag; - u16 value; -}; - struct BattlerTurn { u8 hit:2; u8 criticalHit:2; u8 secondaryEffect:2; - struct TurnRNG rng; + struct RiggedRNG rng; }; struct ExpectedAIAction @@ -1099,7 +1093,7 @@ enum { TURN_CLOSED, TURN_OPEN, TURN_CLOSING }; #define SKIP_TURN(battler) SkipTurn(__LINE__, battler) #define SEND_OUT(battler, partyIndex) SendOut(__LINE__, battler, partyIndex) #define USE_ITEM(battler, ...) UseItem(__LINE__, battler, (struct ItemContext) { R_APPEND_TRUE(__VA_ARGS__) }) -#define WITH_RNG(tag, value) rng: ((struct TurnRNG) { tag, value }) +#define WITH_RNG(tag, value) rng: ((struct RiggedRNG) { tag, value }) struct MoveContext { @@ -1124,7 +1118,7 @@ struct MoveContext u16 explicitNotExpected:1; struct BattlePokemon *target; bool8 explicitTarget; - struct TurnRNG rng; + struct RiggedRNG rng; bool8 explicitRNG; }; @@ -1136,7 +1130,7 @@ struct ItemContext u16 explicitPartyIndex:1; u16 move; u16 explicitMove:1; - struct TurnRNG rng; + struct RiggedRNG rng; u16 explicitRNG:1; }; diff --git a/include/test/test.h b/include/test/test.h index 7b3225934d..15f71d50b1 100644 --- a/include/test/test.h +++ b/include/test/test.h @@ -2,8 +2,10 @@ #define GUARD_TEST_H #include "test_runner.h" +#include "random.h" #define MAX_PROCESSES 32 // See also tools/mgba-rom-test-hydra/main.c +#define RIGGED_RNG_COUNT 8 enum TestResult { @@ -26,6 +28,9 @@ struct TestRunner void (*tearDown)(void *); bool32 (*checkProgress)(void *); bool32 (*handleExitWithResult)(void *, enum TestResult); + u32 (*randomUniform)(enum RandomTag tag, u32 lo, u32 hi, bool32 (*reject)(u32), void *caller); + u32 (*randomWeightedArray)(enum RandomTag tag, u32 sum, u32 n, const u8 *weights, void *caller); + const void* (*randomElementArray)(enum RandomTag tag, const void *array, size_t size, size_t count, void *caller); }; struct Test @@ -76,11 +81,18 @@ extern const char gTestRunnerArgv[256]; extern const struct TestRunner gAssumptionsRunner; +struct RiggedRNG +{ + u16 tag; + u16 value; +}; + struct FunctionTestRunnerState { u16 parameters; u16 runParameter; u16 checkProgressParameter; + struct RiggedRNG rngList[RIGGED_RNG_COUNT]; }; extern const struct TestRunner gFunctionTestRunner; @@ -97,6 +109,8 @@ void Test_ExpectCrash(bool32); void Test_ExitWithResult(enum TestResult, u32 stopLine, const char *fmt, ...); u32 SourceLine(u32 sourceLineOffset); u32 SourceLineOffset(u32 sourceLine); +void SetupRiggedRng(u32 sourceLine, enum RandomTag randomTag, u32 value); +void ClearRiggedRng(); s32 Test_MgbaPrintf(const char *fmt, ...); @@ -245,6 +259,8 @@ static inline struct Benchmark BenchmarkStop(void) #define PARAMETRIZE_LABEL(f, label) if (gFunctionTestRunnerState->parameters++ == gFunctionTestRunnerState->runParameter && (Test_MgbaPrintf(":N%s: " f " (%d/%d)", gTestRunnerState.test->name, label, gFunctionTestRunnerState->runParameter + 1, gFunctionTestRunnerState->parameters), 1)) +#define SET_RNG(tag, value) SetupRiggedRng(__LINE__, tag, value) + #define TO_DO \ do { \ Test_ExpectedResult(TEST_RESULT_TODO); \ diff --git a/test/test_runner.c b/test/test_runner.c index ac4b20cd2c..ef12e8d5f1 100644 --- a/test/test_runner.c +++ b/test/test_runner.c @@ -10,6 +10,7 @@ #include "constants/characters.h" #include "test_runner.h" #include "test/test.h" +#include "test/battle.h" #define TIMEOUT_SECONDS 60 @@ -482,6 +483,7 @@ void Test_ExpectCrash(bool32 expectCrash) static void FunctionTest_SetUp(void *data) { (void)data; + ClearRiggedRng(); gFunctionTestRunnerState = AllocZeroed(sizeof(*gFunctionTestRunnerState)); SeedRng(0); } @@ -513,12 +515,84 @@ static bool32 FunctionTest_CheckProgress(void *data) return madeProgress; } +static u32 FunctionTest_RandomUniform(enum RandomTag tag, u32 lo, u32 hi, bool32 (*reject)(u32), void *caller) +{ + //rigged + for (u32 i = 0; i < RIGGED_RNG_COUNT; i++) + { + if (gFunctionTestRunnerState->rngList[i].tag == tag) + { + if (reject && reject(gFunctionTestRunnerState->rngList[i].value)) + Test_ExitWithResult(TEST_RESULT_INVALID, SourceLine(0), ":LWITH_RNG specified a rejected value (%d)", gFunctionTestRunnerState->rngList[i].value); + return gFunctionTestRunnerState->rngList[i].value; + } + } + //trials + /* + if (tag == STATE->rngTag) + return RandomUniformTrials(tag, lo, hi, reject); + */ + + //default + return RandomUniformDefaultValue(tag, lo, hi, reject, caller); +} + +static u32 FunctionTest_RandomWeightedArray(enum RandomTag tag, u32 sum, u32 n, const u8 *weights, void *caller) +{ + //rigged + for (u32 i = 0; i < RIGGED_RNG_COUNT; i++) + { + if (gFunctionTestRunnerState->rngList[i].tag == tag) + return gFunctionTestRunnerState->rngList[i].value; + } + + //trials + /* + if (tag == STATE->rngTag) + return RandomWeightedArrayTrials(tag, sum, n, weights); + */ + + //default + return RandomWeightedArrayDefaultValue(tag, n, weights, caller); +} + +static const void* FunctionTest_RandomElementArray(enum RandomTag tag, const void *array, size_t size, size_t count, void *caller) +{ + //rigged + for (u32 i = 0; i < RIGGED_RNG_COUNT; i++) + { + if (gFunctionTestRunnerState->rngList[i].tag == tag) + { + u32 element = 0; + for (u32 index = 0; index < count; index++) + { + memcpy(&element, (const u8 *)array + size * index, size); + if (element == gFunctionTestRunnerState->rngList[i].value) + return (const u8 *)array + size * index; + } + Test_ExitWithResult(TEST_RESULT_ERROR, SourceLine(0), ":L%s: RandomElement illegal value requested: %d", gTestRunnerState.test->filename, gFunctionTestRunnerState->rngList[i].value); + } + } + + //trials + /* + if (tag == STATE->rngTag) + return RandomElementTrials(tag, array, size, count); + */ + + //default + return RandomElementArrayDefaultValue(tag, array, size, count, caller); +} + const struct TestRunner gFunctionTestRunner = { .setUp = FunctionTest_SetUp, .run = FunctionTest_Run, .tearDown = FunctionTest_TearDown, .checkProgress = FunctionTest_CheckProgress, + .randomUniform = FunctionTest_RandomUniform, + .randomWeightedArray = FunctionTest_RandomWeightedArray, + .randomElementArray = FunctionTest_RandomElementArray, }; static void Assumptions_Run(void *data) @@ -826,3 +900,102 @@ u32 SourceLineOffset(u32 sourceLine) else return sourceLine - test->sourceLine; } + +u32 RandomUniform(enum RandomTag tag, u32 lo, u32 hi) +{ + void *caller = __builtin_extract_return_addr(__builtin_return_address(0)); + if (gTestRunnerState.test->runner->randomUniform) + return gTestRunnerState.test->runner->randomUniform(tag, lo, hi, NULL, caller); + else + return RandomUniformDefault(tag, lo, hi); +} + +u32 RandomUniformExcept(enum RandomTag tag, u32 lo, u32 hi, bool32 (*reject)(u32)) +{ + void *caller = __builtin_extract_return_addr(__builtin_return_address(0)); + if (gTestRunnerState.test->runner->randomUniform) + return gTestRunnerState.test->runner->randomUniform(tag, lo, hi, reject, caller); + else + return RandomUniformExceptDefault(tag, lo, hi, reject); +} + +u32 RandomWeightedArray(enum RandomTag tag, u32 sum, u32 n, const u8 *weights) +{ + void *caller = __builtin_extract_return_addr(__builtin_return_address(0)); + if (gTestRunnerState.test->runner->randomWeightedArray) + return gTestRunnerState.test->runner->randomWeightedArray(tag, sum, n, weights, caller); + else + return RandomWeightedArrayDefault(tag, sum, n, weights); +} + +const void *RandomElementArray(enum RandomTag tag, const void *array, size_t size, size_t count) +{ + void *caller = __builtin_extract_return_addr(__builtin_return_address(0)); + if (gTestRunnerState.test->runner->randomElementArray) + return gTestRunnerState.test->runner->randomElementArray(tag, array, size, count, caller); + else + return RandomElementArrayDefault(tag, array, size, count); +} + +u32 RandomUniformDefaultValue(enum RandomTag tag, u32 lo, u32 hi, bool32 (*reject)(u32), void *caller) +{ + u32 default_ = hi; + if (reject) + { + while (reject(default_)) + { + if (default_ == lo) + Test_ExitWithResult(TEST_RESULT_ERROR, SourceLine(0), ":LRandomUniformExcept called from %p with tag %d rejected all values", caller, tag); + default_--; + } + } + return default_; +} + +u32 RandomWeightedArrayDefaultValue(enum RandomTag tag, u32 n, const u8 *weights, void *caller) +{ + while (weights[n-1] == 0) + { + if (n == 1) + Test_ExitWithResult(TEST_RESULT_ERROR, SourceLine(0), ":LRandomWeightedArray called from %p with tag %d and all zero weights", caller, tag); + n--; + } + return n-1; +} + +const void *RandomElementArrayDefaultValue(enum RandomTag tag, const void *array, size_t size, size_t count, void *caller) +{ + return (const u8 *)array + size * (count - 1); +} + +void ClearRiggedRng(void) +{ + struct RiggedRNG zeroRng = {.tag = RNG_NONE, .value = 0}; + for (u32 i = 0; i < RIGGED_RNG_COUNT; i++) + memcpy(&gFunctionTestRunnerState->rngList[i], &zeroRng, sizeof(zeroRng)); +} + +void SetupRiggedRng(u32 sourceLine, enum RandomTag randomTag, u32 value) +{ + struct RiggedRNG rng = {.tag = randomTag, .value = value}; + u32 i; + for (i = 0; i < RIGGED_RNG_COUNT; i++) + { + if (gFunctionTestRunnerState->rngList[i].tag == randomTag) + { + memcpy(&gFunctionTestRunnerState->rngList[i], &rng, sizeof(rng)); + break; + } + else if (gFunctionTestRunnerState->rngList[i].tag > RNG_NONE) + { + continue; + } + else + { + memcpy(&gFunctionTestRunnerState->rngList[i], &rng, sizeof(rng)); + break; + } + } + if (i == RIGGED_RNG_COUNT) + Test_ExitWithResult(TEST_RESULT_FAIL, __LINE__, ":L%s:%d: Too many rigged RNGs to set up", gTestRunnerState.test->filename, sourceLine); +} diff --git a/test/test_runner_battle.c b/test/test_runner_battle.c index daa65869bb..a42b671552 100644 --- a/test/test_runner_battle.c +++ b/test/test_runner_battle.c @@ -554,96 +554,113 @@ static void BattleTest_Run(void *data) PrintTestName(); } -u32 RandomUniform(enum RandomTag tag, u32 lo, u32 hi) +u32 RandomUniformTrials(enum RandomTag tag, u32 lo, u32 hi, bool32 (*reject)(u32), void *caller) { - const struct BattlerTurn *turn = NULL; - - if (gCurrentTurnActionNumber < gBattlersCount) + STATE->didRunRandomly = TRUE; + if (STATE->trials == 1) { - u32 battlerId = gBattlerByTurnOrder[gCurrentTurnActionNumber]; - turn = &DATA.battleRecordTurns[gBattleResults.battleTurnCounter][battlerId]; - if (turn && turn->rng.tag == tag) - return turn->rng.value; - } - - if (tag == STATE->rngTag) - { - STATE->didRunRandomly = TRUE; - u32 n = hi - lo + 1; - if (STATE->trials == 1) + u32 n = 0, i; + if (reject) { - STATE->trials = n; - PrintTestName(); - } - else if (STATE->trials != n) - { - Test_ExitWithResult(TEST_RESULT_ERROR, SourceLine(0), ":LRandomUniform called from %p with tag %d and inconsistent trials %d and %d", __builtin_extract_return_addr(__builtin_return_address(0)), tag, STATE->trials, n); - } - STATE->trialRatio = Q_4_12(1) / STATE->trials; - return STATE->runTrial + lo; - } - - return hi; -} - -u32 RandomUniformExcept(enum RandomTag tag, u32 lo, u32 hi, bool32 (*reject)(u32)) -{ - const struct BattlerTurn *turn = NULL; - u32 default_; - - if (gCurrentTurnActionNumber < gBattlersCount) - { - u32 battlerId = gBattlerByTurnOrder[gCurrentTurnActionNumber]; - turn = &DATA.battleRecordTurns[gBattleResults.battleTurnCounter][battlerId]; - if (turn && turn->rng.tag == tag) - { - if (reject(turn->rng.value)) - Test_ExitWithResult(TEST_RESULT_INVALID, SourceLine(0), ":LWITH_RNG specified a rejected value (%d)", turn->rng.value); - return turn->rng.value; - } - } - - if (tag == STATE->rngTag) - { - STATE->didRunRandomly = TRUE; - if (STATE->trials == 1) - { - u32 n = 0, i; for (i = lo; i <= hi; i++) if (!reject(i)) n++; STATE->trials = n; - PrintTestName(); } - STATE->trialRatio = Q_4_12(1) / STATE->trials; - - while (reject(STATE->runTrial + lo + STATE->rngTrialOffset)) - { - if (STATE->runTrial + lo + STATE->rngTrialOffset > hi) - Test_ExitWithResult(TEST_RESULT_ERROR, SourceLine(0), ":LRandomUniformExcept called from %p with tag %d and inconsistent reject", __builtin_extract_return_addr(__builtin_return_address(0)), tag); - STATE->rngTrialOffset++; - } - - return STATE->runTrial + lo + STATE->rngTrialOffset; + else + STATE->trials = hi - lo + 1; + PrintTestName(); } + STATE->trialRatio = Q_4_12(1) / STATE->trials; - default_ = hi; - while (reject(default_)) + if (!reject) { - if (default_ == lo) - Test_ExitWithResult(TEST_RESULT_ERROR, SourceLine(0), ":LRandomUniformExcept called from %p with tag %d rejected all values", __builtin_extract_return_addr(__builtin_return_address(0)), tag); - default_--; + if (STATE->trials != (hi - lo + 1)) + Test_ExitWithResult(TEST_RESULT_ERROR, SourceLine(0), ":LRandomUniform called from %p with tag %d and inconsistent trials %d and %d", caller, tag, STATE->trials, hi - lo + 1); + return STATE->runTrial + lo; } - return default_; + + while (reject(STATE->runTrial + lo + STATE->rngTrialOffset)) + { + if (STATE->runTrial + lo + STATE->rngTrialOffset > hi) + Test_ExitWithResult(TEST_RESULT_ERROR, SourceLine(0), ":LRandomUniformExcept called from %p with tag %d and inconsistent reject", caller, tag); + STATE->rngTrialOffset++; + } + + return STATE->runTrial + lo + STATE->rngTrialOffset; + } -u32 RandomWeightedArray(enum RandomTag tag, u32 sum, u32 n, const u8 *weights) +u32 RandomWeightedArrayTrials(enum RandomTag tag, u32 sum, u32 n, const u8 *weights, void *caller) { + //Detect inconsistent sum + u32 weightSum = 0; + if (STATE->runTrial == 0) + { + for (u32 i = 0; i < n; i++) + weightSum += weights[i]; + if (weightSum != sum) + Test_ExitWithResult(TEST_RESULT_ERROR, SourceLine(0), ":LRandomWeighted called from %p has weights not matching its sum", caller); + } + + STATE->didRunRandomly = TRUE; + if (STATE->trials == 1) + { + STATE->trials = n; + PrintTestName(); + } + else if (STATE->trials != n) + { + Test_ExitWithResult(TEST_RESULT_ERROR, SourceLine(0), ":LRandomWeighted called from %p with tag %d and inconsistent trials %d and %d", caller, tag, STATE->trials, n); + } + + STATE->trialRatio = Q_4_12(weights[STATE->runTrial]) / sum; + return STATE->runTrial; +} + +const void *RandomElementArrayTrials(enum RandomTag tag, const void *array, size_t size, size_t count, void *caller) +{ + STATE->didRunRandomly = TRUE; + if (STATE->trials == 1) + { + STATE->trials = count; + PrintTestName(); + } + else if (STATE->trials != count) + { + Test_ExitWithResult(TEST_RESULT_ERROR, SourceLine(0), ":LRandomElement called from %p with tag %d and inconsistent trials %d and %d", caller, tag, STATE->trials, count); + } + STATE->trialRatio = Q_4_12(1) / count; + return (const u8 *)array + size * STATE->runTrial; +} + +static u32 BattleTest_RandomUniform(enum RandomTag tag, u32 lo, u32 hi, bool32 (*reject)(u32), void *caller) +{ + //rigged const struct BattlerTurn *turn = NULL; + if (gCurrentTurnActionNumber < gBattlersCount) + { + u32 battlerId = gBattlerByTurnOrder[gCurrentTurnActionNumber]; + turn = &DATA.battleRecordTurns[gBattleResults.battleTurnCounter][battlerId]; + if (turn && turn->rng.tag == tag) + { + if (reject && reject(turn->rng.value)) + Test_ExitWithResult(TEST_RESULT_INVALID, SourceLine(0), ":LWITH_RNG specified a rejected value (%d)", turn->rng.value); + return turn->rng.value; + } + } + //trials + if (tag == STATE->rngTag) + return RandomUniformTrials(tag, lo, hi, reject, caller); - if (sum == 0) - Test_ExitWithResult(TEST_RESULT_ERROR, SourceLine(0), ":LRandomWeightedArray called with zero sum"); + //default + return RandomUniformDefaultValue(tag, lo, hi, reject, caller); +} +static u32 BattleTest_RandomWeightedArray(enum RandomTag tag, u32 sum, u32 n, const u8 *weights, void *caller) +{ + //rigged + const struct BattlerTurn *turn = NULL; if (gCurrentTurnActionNumber < gBattlersCount || tag == RNG_SHELL_SIDE_ARM) { u32 battlerId = gBattlerByTurnOrder[gCurrentTurnActionNumber]; @@ -652,23 +669,11 @@ u32 RandomWeightedArray(enum RandomTag tag, u32 sum, u32 n, const u8 *weights) return turn->rng.value; } + //trials if (tag == STATE->rngTag) - { - STATE->didRunRandomly = TRUE; - if (STATE->trials == 1) - { - STATE->trials = n; - PrintTestName(); - } - else if (STATE->trials != n) - { - Test_ExitWithResult(TEST_RESULT_ERROR, SourceLine(0), ":LRandomWeighted called from %p with tag %d and inconsistent trials %d and %d", __builtin_extract_return_addr(__builtin_return_address(0)), tag, STATE->trials, n); - } - // TODO: Detect inconsistent sum. - STATE->trialRatio = Q_4_12(weights[STATE->runTrial]) / sum; - return STATE->runTrial; - } + return RandomWeightedArrayTrials(tag, sum, n, weights, caller); + //default switch (tag) { case RNG_ACCURACY: @@ -693,21 +698,14 @@ u32 RandomWeightedArray(enum RandomTag tag, u32 sum, u32 n, const u8 *weights) return TRUE; default: - while (weights[n-1] == 0) - { - if (n == 1) - Test_ExitWithResult(TEST_RESULT_ERROR, SourceLine(0), ":LRandomWeightedArray called from %p with tag %d and all zero weights", __builtin_extract_return_addr(__builtin_return_address(0)), tag); - n--; - } - return n-1; + return RandomWeightedArrayDefaultValue(tag, n, weights, caller); } } -const void *RandomElementArray(enum RandomTag tag, const void *array, size_t size, size_t count) +static const void *BattleTest_RandomElementArray(enum RandomTag tag, const void *array, size_t size, size_t count, void *caller) { + //rigged const struct BattlerTurn *turn = NULL; - u32 index = count-1; - if (gCurrentTurnActionNumber < gBattlersCount) { u32 battlerId = gBattlerByTurnOrder[gCurrentTurnActionNumber]; @@ -715,33 +713,23 @@ const void *RandomElementArray(enum RandomTag tag, const void *array, size_t siz if (turn && turn->rng.tag == tag) { u32 element = 0; - for (index = 0; index < count; index++) + for (u32 index = 0; index < count; index++) { memcpy(&element, (const u8 *)array + size * index, size); if (element == turn->rng.value) return (const u8 *)array + size * index; } - // TODO: Incorporate the line number. Test_ExitWithResult(TEST_RESULT_ERROR, SourceLine(0), ":L%s: RandomElement illegal value requested: %d", gTestRunnerState.test->filename, turn->rng.value); } } + + //trials if (tag == STATE->rngTag) - { - STATE->didRunRandomly = TRUE; - if (STATE->trials == 1) - { - STATE->trials = count; - PrintTestName(); - } - else if (STATE->trials != count) - { - Test_ExitWithResult(TEST_RESULT_ERROR, SourceLine(0), ":LRandomElement called from %p with tag %d and inconsistent trials %d and %d", __builtin_extract_return_addr(__builtin_return_address(0)), tag, STATE->trials, count); - } - STATE->trialRatio = Q_4_12(1) / count; - return (const u8 *)array + size * STATE->runTrial; - } - return (const u8 *)array + size * index; + return RandomElementArrayTrials(tag, array, size, count, caller); + + //default + return RandomElementArrayDefaultValue(tag, array, size, count, caller); } static s32 TryAbilityPopUp(s32 i, s32 n, u32 battlerId, enum Ability ability) @@ -1738,6 +1726,9 @@ const struct TestRunner gBattleTestRunner = .tearDown = BattleTest_TearDown, .checkProgress = BattleTest_CheckProgress, .handleExitWithResult = BattleTest_HandleExitWithResult, + .randomUniform = BattleTest_RandomUniform, + .randomWeightedArray = BattleTest_RandomWeightedArray, + .randomElementArray = BattleTest_RandomElementArray, }; void SetFlagForTest(u32 sourceLine, u16 flagId)