#include <stdio.h>
#include <inttypes.h>

typedef int32_t p_int;
typedef uint32_t p_uint;

/*
 * Return the upper half of the double-width product rand * n, i.e.
 * floor(rand * n / 2^W) where W is the bit width of p_uint, without
 * using any integer type wider than p_uint.
 *
 * Because rand < 2^W, the result is always < n when n > 0 (and 0 when
 * n == 0).
 *
 * Both operands are split into half-width pieces so that every partial
 * product, and every running sum below, fits in a p_uint.
 */
p_uint scale(p_uint rand, p_uint n)
{
    /* Block-scoped constants: nothing leaks into the rest of the file. */
    enum { HALF_BITS = sizeof(p_uint) * 8 / 2 };
    /* Build the mask in p_uint so the shift never overflows a signed int. */
    const p_uint half_mask = ((p_uint)1 << HALF_BITS) - 1;

    const p_uint rand_h = rand >> HALF_BITS;
    const p_uint rand_l = rand & half_mask;

    const p_uint n_h = n >> HALF_BITS;
    const p_uint n_l = n & half_mask;

    /* The high x high partial product lands entirely in the upper word. */
    p_uint result_h = rand_h * n_h;

    /*
     * The two cross products and the top half of low x low all sit one
     * half-word below the result. They are added one at a time: after
     * each addition the overflow above the half-word is moved into
     * result_h and the accumulator is masked back to half width, so the
     * next addition cannot overflow p_uint.
     */
    p_uint middle = rand_h * n_l;
    result_h += middle >> HALF_BITS;
    middle &= half_mask;

    middle += rand_l * n_h;
    result_h += middle >> HALF_BITS;
    middle &= half_mask;

    /* Only the upper half of low x low can still influence the result. */
    middle += (rand_l * n_l) >> HALF_BITS;
    result_h += middle >> HALF_BITS;

    return result_h;
}

/*
 * Compute the high 32 bits of rand * n twice, once with scale() and once
 * with a direct 64-bit multiplication, and print one line with both
 * operands and both results in hexadecimal. The line starts with " "
 * when the two results agree and "!" when they differ.
 */
void check(uint32_t rand, uint32_t n)
{
    uint32_t result1 = scale(rand, n);
    /* Reference value: widen to 64 bits, multiply, keep the upper word. */
    uint32_t result2 = (uint32_t)(((uint64_t)rand * (uint64_t)n) >> 32);

    /* PRIX32 keeps the conversion matched to uint32_t on every target. */
    printf("%s %08" PRIX32 " x %08" PRIX32 " = %08" PRIX32 ", %08" PRIX32 "\n",
           result1 == result2 ? " " : "!", rand, n, result1, result2);
}

/*
 * Run check() over a fixed list of operand pairs, in order. Each call
 * prints one result line.
 */
int main()
{
    /* Each row is one { rand, n } pair passed to check(). */
    static const uint32_t cases[][2] = {
        { 1,          1          },
        { 0x7FFFFFFF, 2          },
        { 0x80000000, 2          },
        { 2,          0x7FFFFFFF },
        { 2,          0x80000000 },
        { 0xFFFFFFFF, 10         },
        { 10,         0xFFFFFFFF },
        { 0xFFFF0000, 0xFFFF0000 },
        { 0x0000FFFF, 0xFFFF0000 },
        { 0xFFFF0000, 0x0000FFFF },
        { 0x0000FFFF, 0x0000FFFF },
        { 0xFFFF0000, 0x00010000 },
        { 0x0000FFFF, 0x00010000 },
        { 0xFFFFFFFF, 0xFFFFFFFF },
    };
    const size_t count = sizeof cases / sizeof cases[0];

    for (size_t i = 0; i < count; i++) {
        check(cases[i][0], cases[i][1]);
    }

    return 0;
}
