Geant4 Cross Reference |
1 // 1 // 2 // -*- C++ -*- 2 // -*- C++ -*- 3 // 3 // 4 // ------------------------------------------- 4 // ----------------------------------------------------------------------- 5 // HEP Random 5 // HEP Random 6 // --- RanluxppEngine -- 6 // --- RanluxppEngine --- 7 // helper implementation f 7 // helper implementation file 8 // ------------------------------------------- 8 // ----------------------------------------------------------------------- 9 9 10 #ifndef RANLUXPP_MULMOD_H 10 #ifndef RANLUXPP_MULMOD_H 11 #define RANLUXPP_MULMOD_H 11 #define RANLUXPP_MULMOD_H 12 12 13 #include "helpers.h" 13 #include "helpers.h" 14 14 15 #include <cstdint> 15 #include <cstdint> 16 16 17 /// Multiply two 576 bit numbers, stored as 9 17 /// Multiply two 576 bit numbers, stored as 9 numbers of 64 bits each 18 /// 18 /// 19 /// \param[in] in1 first factor as 9 numbers o 19 /// \param[in] in1 first factor as 9 numbers of 64 bits each 20 /// \param[in] in2 second factor as 9 numbers 20 /// \param[in] in2 second factor as 9 numbers of 64 bits each 21 /// \param[out] out result with 18 numbers of 21 /// \param[out] out result with 18 numbers of 64 bits each 22 static void multiply9x9(const uint64_t *in1, c 22 static void multiply9x9(const uint64_t *in1, const uint64_t *in2, 23 uint64_t *out) { 23 uint64_t *out) { 24 uint64_t next = 0; 24 uint64_t next = 0; 25 unsigned nextCarry = 0; 25 unsigned nextCarry = 0; 26 26 27 #if defined(__clang__) || defined(__INTEL_COMP 27 #if defined(__clang__) || defined(__INTEL_COMPILER) 28 #pragma unroll 28 #pragma unroll 29 #elif defined(__GNUC__) && __GNUC__ >= 8 29 #elif defined(__GNUC__) && __GNUC__ >= 8 30 // This pragma was introduced in GCC version 8 30 // This pragma was introduced in GCC version 8. 31 #pragma GCC unroll 18 31 #pragma GCC unroll 18 32 #endif 32 #endif 33 for (int i = 0; i < 18; i++) { 33 for (int i = 0; i < 18; i++) { 34 uint64_t current = next; 34 uint64_t current = next; 35 unsigned carry = nextCarry; 35 unsigned carry = nextCarry; 36 36 37 next = 0; 37 next = 0; 38 nextCarry = 0; 38 nextCarry = 0; 39 39 40 #if defined(__clang__) || defined(__INTEL_COMP 40 #if defined(__clang__) || defined(__INTEL_COMPILER) 41 #pragma unroll 41 #pragma unroll 42 #elif defined(__GNUC__) && __GNUC__ >= 8 42 #elif defined(__GNUC__) && __GNUC__ >= 8 43 // This pragma was introduced in GCC version 8 43 // This pragma was introduced in GCC version 8. 44 #pragma GCC unroll 9 44 #pragma GCC unroll 9 45 #endif 45 #endif 46 for (int j = 0; j < 9; j++) { 46 for (int j = 0; j < 9; j++) { 47 int k = i - j; 47 int k = i - j; 48 if (k < 0 || k >= 9) 48 if (k < 0 || k >= 9) 49 continue; 49 continue; 50 50 51 uint64_t fac1 = in1[j]; 51 uint64_t fac1 = in1[j]; 52 uint64_t fac2 = in2[k]; 52 uint64_t fac2 = in2[k]; 53 #if defined(__SIZEOF_INT128__) && !defined(CLH 53 #if defined(__SIZEOF_INT128__) && !defined(CLHEP_NO_INT128) 54 #ifdef __GNUC__ 54 #ifdef __GNUC__ 55 #pragma GCC diagnostic push 55 #pragma GCC diagnostic push 56 #pragma GCC diagnostic ignored "-Wpedantic" 56 #pragma GCC diagnostic ignored "-Wpedantic" 57 #endif 57 #endif 58 unsigned __int128 prod = fac1; 58 unsigned __int128 prod = fac1; 59 prod = prod * fac2; 59 prod = prod * fac2; 60 60 61 uint64_t upper = prod >> 64; 61 uint64_t upper = prod >> 64; 62 uint64_t lower = static_cast<uint64_t>(p 62 uint64_t lower = static_cast<uint64_t>(prod); 63 #ifdef __GNUC__ 63 #ifdef __GNUC__ 64 #pragma GCC diagnostic pop 64 #pragma GCC diagnostic pop 65 #endif 65 #endif 66 #else 66 #else 67 uint64_t upper1 = fac1 >> 32; 67 uint64_t upper1 = fac1 >> 32; 68 uint64_t lower1 = static_cast<uint32_t>( 68 uint64_t lower1 = static_cast<uint32_t>(fac1); 69 69 70 uint64_t upper2 = fac2 >> 32; 70 uint64_t upper2 = fac2 >> 32; 71 uint64_t lower2 = static_cast<uint32_t>( 71 uint64_t lower2 = static_cast<uint32_t>(fac2); 72 72 73 // Multiply 32-bit parts, each product h 73 // Multiply 32-bit parts, each product has a maximum value of 74 // (2 ** 32 - 1) ** 2 = 2 ** 64 - 2 * 2 74 // (2 ** 32 - 1) ** 2 = 2 ** 64 - 2 * 2 ** 32 + 1. 75 uint64_t upper = upper1 * upper2; 75 uint64_t upper = upper1 * upper2; 76 uint64_t middle1 = upper1 * lower2; 76 uint64_t middle1 = upper1 * lower2; 77 uint64_t middle2 = lower1 * upper2; 77 uint64_t middle2 = lower1 * upper2; 78 uint64_t lower = lower1 * lower2; 78 uint64_t lower = lower1 * lower2; 79 79 80 // When adding the two products, the max 80 // When adding the two products, the maximum value for middle is 81 // 2 * 2 ** 64 - 4 * 2 ** 32 + 2, which 81 // 2 * 2 ** 64 - 4 * 2 ** 32 + 2, which exceeds a uint64_t. 82 unsigned overflow; 82 unsigned overflow; 83 uint64_t middle = add_overflow(middle1, 83 uint64_t middle = add_overflow(middle1, middle2, overflow); 84 // Handling the overflow by a multiplica 84 // Handling the overflow by a multiplication with 0 or 1 is cheaper 85 // than branching with an if statement, 85 // than branching with an if statement, which the compiler does not 86 // optimize to this equivalent code. Not 86 // optimize to this equivalent code. Note that we could do entirely 87 // without this overflow handling when s 87 // without this overflow handling when summing up the intermediate 88 // products differently as described in 88 // products differently as described in the following SO answer: 89 // https://stackoverflow.com/a/515872 89 // https://stackoverflow.com/a/51587262 90 // However, this approach takes at least 90 // However, this approach takes at least the same amount of thinking 91 // why a) the code gives the same result 91 // why a) the code gives the same results without b) overflowing due 92 // to the mixture of 32 bit arithmetic. 92 // to the mixture of 32 bit arithmetic. Moreover, my tests show that 93 // the scheme implemented here is actual 93 // the scheme implemented here is actually slightly more performant. 94 uint64_t overflow_add = overflow * (uint 94 uint64_t overflow_add = overflow * (uint64_t(1) << 32); 95 // This addition can never overflow beca 95 // This addition can never overflow because the maximum value of upper 96 // is 2 ** 64 - 2 * 2 ** 32 + 1 (see abo 96 // is 2 ** 64 - 2 * 2 ** 32 + 1 (see above). When now adding another 97 // 2 ** 32, the result is 2 ** 64 - 2 ** 97 // 2 ** 32, the result is 2 ** 64 - 2 ** 32 + 1 and still smaller than 98 // the maximum 2 ** 64 - 1 that can be s 98 // the maximum 2 ** 64 - 1 that can be stored in a uint64_t. 99 upper += overflow_add; 99 upper += overflow_add; 100 100 101 uint64_t middle_upper = middle >> 32; 101 uint64_t middle_upper = middle >> 32; 102 uint64_t middle_lower = middle << 32; 102 uint64_t middle_lower = middle << 32; 103 103 104 lower = add_overflow(lower, middle_lower 104 lower = add_overflow(lower, middle_lower, overflow); 105 upper += overflow; 105 upper += overflow; 106 106 107 // This still can't overflow since the m 107 // This still can't overflow since the maximum of middle_upper is 108 // - 2 ** 32 - 4 if there was an overfl 108 // - 2 ** 32 - 4 if there was an overflow for middle above, bringing 109 // the maximum value of upper to 2 ** 109 // the maximum value of upper to 2 ** 64 - 2. 110 // - otherwise upper still has the init 110 // - otherwise upper still has the initial maximum value given above 111 // and the addition of a value smalle 111 // and the addition of a value smaller than 2 ** 32 brings it to 112 // a maximum value of 2 ** 64 - 2 ** 112 // a maximum value of 2 ** 64 - 2 ** 32 + 2. 113 // (Both cases include the increment to 113 // (Both cases include the increment to handle the overflow in lower.) 114 // 114 // 115 // All the reasoning makes perfect sense 115 // All the reasoning makes perfect sense given that the product of two 116 // 64 bit numbers is smaller than or equ 116 // 64 bit numbers is smaller than or equal to 117 // (2 ** 64 - 1) ** 2 = 2 ** 128 - 2 117 // (2 ** 64 - 1) ** 2 = 2 ** 128 - 2 * 2 ** 64 + 1 118 // with the upper bits matching the 2 ** 118 // with the upper bits matching the 2 ** 64 - 2 of the first case. 119 upper += middle_upper; 119 upper += middle_upper; 120 #endif 120 #endif 121 121 122 // Add to current, remember carry. 122 // Add to current, remember carry. 123 current = add_carry(current, lower, carr 123 current = add_carry(current, lower, carry); 124 124 125 // Add to next, remember nextCarry. 125 // Add to next, remember nextCarry. 126 next = add_carry(next, upper, nextCarry) 126 next = add_carry(next, upper, nextCarry); 127 } 127 } 128 128 129 next = add_carry(next, carry, nextCarry); 129 next = add_carry(next, carry, nextCarry); 130 130 131 out[i] = current; 131 out[i] = current; 132 } 132 } 133 } 133 } 134 134 135 /// Compute a value congruent to mul modulo m 135 /// Compute a value congruent to mul modulo m less than 2 ** 576 136 /// 136 /// 137 /// \param[in] mul product from multiply9x9 wi 137 /// \param[in] mul product from multiply9x9 with 18 numbers of 64 bits each 138 /// \param[out] out result with 9 numbers of 6 138 /// \param[out] out result with 9 numbers of 64 bits each 139 /// 139 /// 140 /// \f$ m = 2^{576} - 2^{240} + 1 \f$ 140 /// \f$ m = 2^{576} - 2^{240} + 1 \f$ 141 /// 141 /// 142 /// The result in out is guaranteed to be smal 142 /// The result in out is guaranteed to be smaller than the modulus. 143 static void mod_m(const uint64_t *mul, uint64_ 143 static void mod_m(const uint64_t *mul, uint64_t *out) { 144 uint64_t r[9]; 144 uint64_t r[9]; 145 // Assign r = t0 145 // Assign r = t0 146 for (int i = 0; i < 9; i++) { 146 for (int i = 0; i < 9; i++) { 147 r[i] = mul[i]; 147 r[i] = mul[i]; 148 } 148 } 149 149 150 int64_t c = compute_r(mul + 9, r); 150 int64_t c = compute_r(mul + 9, r); 151 151 152 // To update r = r - c * m, it suffices to k 152 // To update r = r - c * m, it suffices to know c * (-2 ** 240 + 1) 153 // because the 2 ** 576 will cancel out. Als 153 // because the 2 ** 576 will cancel out. Also note that c may be zero, but 154 // the operation is still performed to avoid 154 // the operation is still performed to avoid branching. 155 155 156 // c * (-2 ** 240 + 1) in 576 bits looks as 156 // c * (-2 ** 240 + 1) in 576 bits looks as follows, depending on c: 157 // - if c = 0, the number is zero. 157 // - if c = 0, the number is zero. 158 // - if c = 1: bits 576 to 240 are set, 158 // - if c = 1: bits 576 to 240 are set, 159 // bits 239 to 1 are zero, and 159 // bits 239 to 1 are zero, and 160 // the last one is set 160 // the last one is set 161 // - if c = -1, which corresponds to all bi 161 // - if c = -1, which corresponds to all bits set (signed int64_t): 162 // bits 576 to 240 are zero and 162 // bits 576 to 240 are zero and the rest is set. 163 // Note that all bits except the last are ex 163 // Note that all bits except the last are exactly complimentary (unless c = 0) 164 // and the last byte is conveniently represe 164 // and the last byte is conveniently represented by c already. 165 // Now construct the three bit patterns from 165 // Now construct the three bit patterns from c, their names correspond to the 166 // assembly implementation by Alexei Sibidan 166 // assembly implementation by Alexei Sibidanov. 167 167 168 // c = 0 -> t0 = 0; c = 1 -> t0 = 0; c = -1 168 // c = 0 -> t0 = 0; c = 1 -> t0 = 0; c = -1 -> all bits set (sign extension) 169 // (The assembly implementation shifts by 63 169 // (The assembly implementation shifts by 63, which gives the same result.) 170 int64_t t0 = c >> 1; 170 int64_t t0 = c >> 1; 171 171 172 // Left shifting negative values is undefine 172 // Left shifting negative values is undefined behavior until C++20, cast to 173 // unsigned. 173 // unsigned. 174 uint64_t c_unsigned = static_cast<uint64_t>( 174 uint64_t c_unsigned = static_cast<uint64_t>(c); 175 175 176 // c = 0 -> t2 = 0; c = 1 -> upper 16 bits s 176 // c = 0 -> t2 = 0; c = 1 -> upper 16 bits set; c = -1 -> lower 48 bits set 177 int64_t t2 = t0 - (c_unsigned << 48); 177 int64_t t2 = t0 - (c_unsigned << 48); 178 178 179 // c = 0 -> t1 = 0; c = 1 -> all bits set (s 179 // c = 0 -> t1 = 0; c = 1 -> all bits set (sign extension); c = -1 -> t1 = 0 180 // (The assembly implementation shifts by 63 180 // (The assembly implementation shifts by 63, which gives the same result.) 181 int64_t t1 = t2 >> 48; 181 int64_t t1 = t2 >> 48; 182 182 183 unsigned carry = 0; 183 unsigned carry = 0; 184 { 184 { 185 uint64_t r_0 = r[0]; 185 uint64_t r_0 = r[0]; 186 186 187 uint64_t out_0 = sub_carry(r_0, c, carry); 187 uint64_t out_0 = sub_carry(r_0, c, carry); 188 out[0] = out_0; 188 out[0] = out_0; 189 } 189 } 190 for (int i = 1; i < 3; i++) { 190 for (int i = 1; i < 3; i++) { 191 uint64_t r_i = r[i]; 191 uint64_t r_i = r[i]; 192 r_i = sub_overflow(r_i, carry, carry); 192 r_i = sub_overflow(r_i, carry, carry); 193 193 194 uint64_t out_i = sub_carry(r_i, t0, carry) 194 uint64_t out_i = sub_carry(r_i, t0, carry); 195 out[i] = out_i; 195 out[i] = out_i; 196 } 196 } 197 { 197 { 198 uint64_t r_3 = r[3]; 198 uint64_t r_3 = r[3]; 199 r_3 = sub_overflow(r_3, carry, carry); 199 r_3 = sub_overflow(r_3, carry, carry); 200 200 201 uint64_t out_3 = sub_carry(r_3, t2, carry) 201 uint64_t out_3 = sub_carry(r_3, t2, carry); 202 out[3] = out_3; 202 out[3] = out_3; 203 } 203 } 204 for (int i = 4; i < 9; i++) { 204 for (int i = 4; i < 9; i++) { 205 uint64_t r_i = r[i]; 205 uint64_t r_i = r[i]; 206 r_i = sub_overflow(r_i, carry, carry); 206 r_i = sub_overflow(r_i, carry, carry); 207 207 208 uint64_t out_i = sub_carry(r_i, t1, carry) 208 uint64_t out_i = sub_carry(r_i, t1, carry); 209 out[i] = out_i; 209 out[i] = out_i; 210 } 210 } 211 } 211 } 212 212 213 /// Combine multiply9x9 and mod_m with interna 213 /// Combine multiply9x9 and mod_m with internal temporary storage 214 /// 214 /// 215 /// \param[in] in1 first factor with 9 numbers 215 /// \param[in] in1 first factor with 9 numbers of 64 bits each 216 /// \param[inout] inout second factor and also 216 /// \param[inout] inout second factor and also the output of the same size 217 /// 217 /// 218 /// The result in inout is guaranteed to be sm 218 /// The result in inout is guaranteed to be smaller than the modulus. 219 static void mulmod(const uint64_t *in1, uint64 219 static void mulmod(const uint64_t *in1, uint64_t *inout) { 220 uint64_t mul[2 * 9] = {0}; 220 uint64_t mul[2 * 9] = {0}; 221 multiply9x9(in1, inout, mul); 221 multiply9x9(in1, inout, mul); 222 mod_m(mul, inout); 222 mod_m(mul, inout); 223 } 223 } 224 224 225 /// Compute base to the n modulo m 225 /// Compute base to the n modulo m 226 /// 226 /// 227 /// \param[in] base with 9 numbers of 64 bits 227 /// \param[in] base with 9 numbers of 64 bits each 228 /// \param[out] res output with 9 numbers of 6 228 /// \param[out] res output with 9 numbers of 64 bits each 229 /// \param[in] n exponent 229 /// \param[in] n exponent 230 /// 230 /// 231 /// The arguments base and res may point to th 231 /// The arguments base and res may point to the same location. 232 static void powermod(const uint64_t *base, uin 232 static void powermod(const uint64_t *base, uint64_t *res, uint64_t n) { 233 uint64_t fac[9] = {0}; 233 uint64_t fac[9] = {0}; 234 fac[0] = base[0]; 234 fac[0] = base[0]; 235 res[0] = 1; 235 res[0] = 1; 236 for (int i = 1; i < 9; i++) { 236 for (int i = 1; i < 9; i++) { 237 fac[i] = base[i]; 237 fac[i] = base[i]; 238 res[i] = 0; 238 res[i] = 0; 239 } 239 } 240 240 241 uint64_t mul[18] = {0}; 241 uint64_t mul[18] = {0}; 242 while (n) { 242 while (n) { 243 if (n & 1) { 243 if (n & 1) { 244 multiply9x9(res, fac, mul); 244 multiply9x9(res, fac, mul); 245 mod_m(mul, res); 245 mod_m(mul, res); 246 } 246 } 247 n >>= 1; 247 n >>= 1; 248 if (!n) 248 if (!n) 249 break; 249 break; 250 multiply9x9(fac, fac, mul); 250 multiply9x9(fac, fac, mul); 251 mod_m(mul, fac); 251 mod_m(mul, fac); 252 } 252 } 253 } 253 } 254 254 255 #endif 255 #endif 256 256