Geant4 Cross Reference

Cross-Referencing   Geant4
Geant4/externals/clhep/include/CLHEP/Random/ranluxpp/mulmod.h

Version: [ ReleaseNotes ] [ 1.0 ] [ 1.1 ] [ 2.0 ] [ 3.0 ] [ 3.1 ] [ 3.2 ] [ 4.0 ] [ 4.0.p1 ] [ 4.0.p2 ] [ 4.1 ] [ 4.1.p1 ] [ 5.0 ] [ 5.0.p1 ] [ 5.1 ] [ 5.1.p1 ] [ 5.2 ] [ 5.2.p1 ] [ 5.2.p2 ] [ 6.0 ] [ 6.0.p1 ] [ 6.1 ] [ 6.2 ] [ 6.2.p1 ] [ 6.2.p2 ] [ 7.0 ] [ 7.0.p1 ] [ 7.1 ] [ 7.1.p1 ] [ 8.0 ] [ 8.0.p1 ] [ 8.1 ] [ 8.1.p1 ] [ 8.1.p2 ] [ 8.2 ] [ 8.2.p1 ] [ 8.3 ] [ 8.3.p1 ] [ 8.3.p2 ] [ 9.0 ] [ 9.0.p1 ] [ 9.0.p2 ] [ 9.1 ] [ 9.1.p1 ] [ 9.1.p2 ] [ 9.1.p3 ] [ 9.2 ] [ 9.2.p1 ] [ 9.2.p2 ] [ 9.2.p3 ] [ 9.2.p4 ] [ 9.3 ] [ 9.3.p1 ] [ 9.3.p2 ] [ 9.4 ] [ 9.4.p1 ] [ 9.4.p2 ] [ 9.4.p3 ] [ 9.4.p4 ] [ 9.5 ] [ 9.5.p1 ] [ 9.5.p2 ] [ 9.6 ] [ 9.6.p1 ] [ 9.6.p2 ] [ 9.6.p3 ] [ 9.6.p4 ] [ 10.0 ] [ 10.0.p1 ] [ 10.0.p2 ] [ 10.0.p3 ] [ 10.0.p4 ] [ 10.1 ] [ 10.1.p1 ] [ 10.1.p2 ] [ 10.1.p3 ] [ 10.2 ] [ 10.2.p1 ] [ 10.2.p2 ] [ 10.2.p3 ] [ 10.3 ] [ 10.3.p1 ] [ 10.3.p2 ] [ 10.3.p3 ] [ 10.4 ] [ 10.4.p1 ] [ 10.4.p2 ] [ 10.4.p3 ] [ 10.5 ] [ 10.5.p1 ] [ 10.6 ] [ 10.6.p1 ] [ 10.6.p2 ] [ 10.6.p3 ] [ 10.7 ] [ 10.7.p1 ] [ 10.7.p2 ] [ 10.7.p3 ] [ 10.7.p4 ] [ 11.0 ] [ 11.0.p1 ] [ 11.0.p2 ] [ 11.0.p3, ] [ 11.0.p4 ] [ 11.1 ] [ 11.1.1 ] [ 11.1.2 ] [ 11.1.3 ] [ 11.2 ] [ 11.2.1 ] [ 11.2.2 ] [ 11.3.0 ]

Diff markup

Differences between /externals/clhep/include/CLHEP/Random/ranluxpp/mulmod.h (Version 11.3.0) and /externals/clhep/include/CLHEP/Random/ranluxpp/mulmod.h (Version 11.1)


  1 //                                                  1 //
  2 // -*- C++ -*-                                      2 // -*- C++ -*-
  3 //                                                  3 //
  4 // -------------------------------------------      4 // -----------------------------------------------------------------------
  5 //                             HEP Random           5 //                             HEP Random
  6 //                       --- RanluxppEngine --      6 //                       --- RanluxppEngine ---
  7 //                     helper implementation f      7 //                     helper implementation file
  8 // -------------------------------------------      8 // -----------------------------------------------------------------------
  9                                                     9 
 10 #ifndef RANLUXPP_MULMOD_H                          10 #ifndef RANLUXPP_MULMOD_H
 11 #define RANLUXPP_MULMOD_H                          11 #define RANLUXPP_MULMOD_H
 12                                                    12 
 13 #include "helpers.h"                               13 #include "helpers.h"
 14                                                    14 
 15 #include <cstdint>                                 15 #include <cstdint>
 16                                                    16 
 17 /// Multiply two 576 bit numbers, stored as 9      17 /// Multiply two 576 bit numbers, stored as 9 numbers of 64 bits each
 18 ///                                                18 ///
 19 /// \param[in] in1 first factor as 9 numbers o     19 /// \param[in] in1 first factor as 9 numbers of 64 bits each
 20 /// \param[in] in2 second factor as 9 numbers      20 /// \param[in] in2 second factor as 9 numbers of 64 bits each
 21 /// \param[out] out result with 18 numbers of      21 /// \param[out] out result with 18 numbers of 64 bits each
 22 static void multiply9x9(const uint64_t *in1, c     22 static void multiply9x9(const uint64_t *in1, const uint64_t *in2,
 23                         uint64_t *out) {           23                         uint64_t *out) {
 24   uint64_t next = 0;                               24   uint64_t next = 0;
 25   unsigned nextCarry = 0;                          25   unsigned nextCarry = 0;
 26                                                    26 
 27 #if defined(__clang__) || defined(__INTEL_COMP     27 #if defined(__clang__) || defined(__INTEL_COMPILER)
 28 #pragma unroll                                     28 #pragma unroll
 29 #elif defined(__GNUC__) && __GNUC__ >= 8           29 #elif defined(__GNUC__) && __GNUC__ >= 8
 30 // This pragma was introduced in GCC version 8     30 // This pragma was introduced in GCC version 8.
 31 #pragma GCC unroll 18                              31 #pragma GCC unroll 18
 32 #endif                                             32 #endif
 33   for (int i = 0; i < 18; i++) {                   33   for (int i = 0; i < 18; i++) {
 34     uint64_t current = next;                       34     uint64_t current = next;
 35     unsigned carry = nextCarry;                    35     unsigned carry = nextCarry;
 36                                                    36 
 37     next = 0;                                      37     next = 0;
 38     nextCarry = 0;                                 38     nextCarry = 0;
 39                                                    39 
 40 #if defined(__clang__) || defined(__INTEL_COMP     40 #if defined(__clang__) || defined(__INTEL_COMPILER)
 41 #pragma unroll                                     41 #pragma unroll
 42 #elif defined(__GNUC__) && __GNUC__ >= 8           42 #elif defined(__GNUC__) && __GNUC__ >= 8
 43 // This pragma was introduced in GCC version 8     43 // This pragma was introduced in GCC version 8.
 44 #pragma GCC unroll 9                               44 #pragma GCC unroll 9
 45 #endif                                             45 #endif
 46     for (int j = 0; j < 9; j++) {                  46     for (int j = 0; j < 9; j++) {
 47       int k = i - j;                               47       int k = i - j;
 48       if (k < 0 || k >= 9)                         48       if (k < 0 || k >= 9)
 49         continue;                                  49         continue;
 50                                                    50 
 51       uint64_t fac1 = in1[j];                      51       uint64_t fac1 = in1[j];
 52       uint64_t fac2 = in2[k];                      52       uint64_t fac2 = in2[k];
 53 #if defined(__SIZEOF_INT128__) && !defined(CLH     53 #if defined(__SIZEOF_INT128__) && !defined(CLHEP_NO_INT128)
 54 #ifdef __GNUC__                                    54 #ifdef __GNUC__
 55 #pragma GCC diagnostic push                        55 #pragma GCC diagnostic push
 56 #pragma GCC diagnostic ignored "-Wpedantic"        56 #pragma GCC diagnostic ignored "-Wpedantic"
 57 #endif                                             57 #endif
 58       unsigned __int128 prod = fac1;               58       unsigned __int128 prod = fac1;
 59       prod = prod * fac2;                          59       prod = prod * fac2;
 60                                                    60 
 61       uint64_t upper = prod >> 64;                 61       uint64_t upper = prod >> 64;
 62       uint64_t lower = static_cast<uint64_t>(p     62       uint64_t lower = static_cast<uint64_t>(prod);
 63 #ifdef __GNUC__                                    63 #ifdef __GNUC__
 64 #pragma GCC diagnostic pop                         64 #pragma GCC diagnostic pop
 65 #endif                                             65 #endif
 66 #else                                              66 #else
 67       uint64_t upper1 = fac1 >> 32;                67       uint64_t upper1 = fac1 >> 32;
 68       uint64_t lower1 = static_cast<uint32_t>(     68       uint64_t lower1 = static_cast<uint32_t>(fac1);
 69                                                    69 
 70       uint64_t upper2 = fac2 >> 32;                70       uint64_t upper2 = fac2 >> 32;
 71       uint64_t lower2 = static_cast<uint32_t>(     71       uint64_t lower2 = static_cast<uint32_t>(fac2);
 72                                                    72 
 73       // Multiply 32-bit parts, each product h     73       // Multiply 32-bit parts, each product has a maximum value of
 74       // (2 ** 32 - 1) ** 2 = 2 ** 64 - 2 * 2      74       // (2 ** 32 - 1) ** 2 = 2 ** 64 - 2 * 2 ** 32 + 1.
 75       uint64_t upper = upper1 * upper2;            75       uint64_t upper = upper1 * upper2;
 76       uint64_t middle1 = upper1 * lower2;          76       uint64_t middle1 = upper1 * lower2;
 77       uint64_t middle2 = lower1 * upper2;          77       uint64_t middle2 = lower1 * upper2;
 78       uint64_t lower = lower1 * lower2;            78       uint64_t lower = lower1 * lower2;
 79                                                    79 
 80       // When adding the two products, the max     80       // When adding the two products, the maximum value for middle is
 81       // 2 * 2 ** 64 - 4 * 2 ** 32 + 2, which      81       // 2 * 2 ** 64 - 4 * 2 ** 32 + 2, which exceeds a uint64_t.
 82       unsigned overflow;                           82       unsigned overflow;
 83       uint64_t middle = add_overflow(middle1,      83       uint64_t middle = add_overflow(middle1, middle2, overflow);
 84       // Handling the overflow by a multiplica     84       // Handling the overflow by a multiplication with 0 or 1 is cheaper
 85       // than branching with an if statement,      85       // than branching with an if statement, which the compiler does not
 86       // optimize to this equivalent code. Not     86       // optimize to this equivalent code. Note that we could do entirely
 87       // without this overflow handling when s     87       // without this overflow handling when summing up the intermediate
 88       // products differently as described in      88       // products differently as described in the following SO answer:
 89       //    https://stackoverflow.com/a/515872     89       //    https://stackoverflow.com/a/51587262
 90       // However, this approach takes at least     90       // However, this approach takes at least the same amount of thinking
 91       // why a) the code gives the same result     91       // why a) the code gives the same results without b) overflowing due
 92       // to the mixture of 32 bit arithmetic.      92       // to the mixture of 32 bit arithmetic. Moreover, my tests show that
 93       // the scheme implemented here is actual     93       // the scheme implemented here is actually slightly more performant.
 94       uint64_t overflow_add = overflow * (uint     94       uint64_t overflow_add = overflow * (uint64_t(1) << 32);
 95       // This addition can never overflow beca     95       // This addition can never overflow because the maximum value of upper
 96       // is 2 ** 64 - 2 * 2 ** 32 + 1 (see abo     96       // is 2 ** 64 - 2 * 2 ** 32 + 1 (see above). When now adding another
 97       // 2 ** 32, the result is 2 ** 64 - 2 **     97       // 2 ** 32, the result is 2 ** 64 - 2 ** 32 + 1 and still smaller than
 98       // the maximum 2 ** 64 - 1 that can be s     98       // the maximum 2 ** 64 - 1 that can be stored in a uint64_t.
 99       upper += overflow_add;                       99       upper += overflow_add;
100                                                   100 
101       uint64_t middle_upper = middle >> 32;       101       uint64_t middle_upper = middle >> 32;
102       uint64_t middle_lower = middle << 32;       102       uint64_t middle_lower = middle << 32;
103                                                   103 
104       lower = add_overflow(lower, middle_lower    104       lower = add_overflow(lower, middle_lower, overflow);
105       upper += overflow;                          105       upper += overflow;
106                                                   106 
107       // This still can't overflow since the m    107       // This still can't overflow since the maximum of middle_upper is
108       //  - 2 ** 32 - 4 if there was an overfl    108       //  - 2 ** 32 - 4 if there was an overflow for middle above, bringing
109       //    the maximum value of upper to 2 **    109       //    the maximum value of upper to 2 ** 64 - 2.
110       //  - otherwise upper still has the init    110       //  - otherwise upper still has the initial maximum value given above
111       //    and the addition of a value smalle    111       //    and the addition of a value smaller than 2 ** 32 brings it to
112       //    a maximum value of 2 ** 64 - 2 **     112       //    a maximum value of 2 ** 64 - 2 ** 32 + 2.
113       // (Both cases include the increment to     113       // (Both cases include the increment to handle the overflow in lower.)
114       //                                          114       //
115       // All the reasoning makes perfect sense    115       // All the reasoning makes perfect sense given that the product of two
116       // 64 bit numbers is smaller than or equ    116       // 64 bit numbers is smaller than or equal to
117       //     (2 ** 64 - 1) ** 2 = 2 ** 128 - 2    117       //     (2 ** 64 - 1) ** 2 = 2 ** 128 - 2 * 2 ** 64 + 1
118       // with the upper bits matching the 2 **    118       // with the upper bits matching the 2 ** 64 - 2 of the first case.
119       upper += middle_upper;                      119       upper += middle_upper;
120 #endif                                            120 #endif
121                                                   121 
122       // Add to current, remember carry.          122       // Add to current, remember carry.
123       current = add_carry(current, lower, carr    123       current = add_carry(current, lower, carry);
124                                                   124 
125       // Add to next, remember nextCarry.         125       // Add to next, remember nextCarry.
126       next = add_carry(next, upper, nextCarry)    126       next = add_carry(next, upper, nextCarry);
127     }                                             127     }
128                                                   128 
129     next = add_carry(next, carry, nextCarry);     129     next = add_carry(next, carry, nextCarry);
130                                                   130 
131     out[i] = current;                             131     out[i] = current;
132   }                                               132   }
133 }                                                 133 }
134                                                   134 
135 /// Compute a value congruent to mul modulo m     135 /// Compute a value congruent to mul modulo m less than 2 ** 576
136 ///                                               136 ///
137 /// \param[in] mul product from multiply9x9 wi    137 /// \param[in] mul product from multiply9x9 with 18 numbers of 64 bits each
138 /// \param[out] out result with 9 numbers of 6    138 /// \param[out] out result with 9 numbers of 64 bits each
139 ///                                               139 ///
140 /// \f$ m = 2^{576} - 2^{240} + 1 \f$             140 /// \f$ m = 2^{576} - 2^{240} + 1 \f$
141 ///                                               141 ///
142 /// The result in out is guaranteed to be smal    142 /// The result in out is guaranteed to be smaller than the modulus.
143 static void mod_m(const uint64_t *mul, uint64_    143 static void mod_m(const uint64_t *mul, uint64_t *out) {
144   uint64_t r[9];                                  144   uint64_t r[9];
145   // Assign r = t0                                145   // Assign r = t0
146   for (int i = 0; i < 9; i++) {                   146   for (int i = 0; i < 9; i++) {
147     r[i] = mul[i];                                147     r[i] = mul[i];
148   }                                               148   }
149                                                   149 
150   int64_t c = compute_r(mul + 9, r);              150   int64_t c = compute_r(mul + 9, r);
151                                                   151 
152   // To update r = r - c * m, it suffices to k    152   // To update r = r - c * m, it suffices to know c * (-2 ** 240 + 1)
153   // because the 2 ** 576 will cancel out. Als    153   // because the 2 ** 576 will cancel out. Also note that c may be zero, but
154   // the operation is still performed to avoid    154   // the operation is still performed to avoid branching.
155                                                   155 
156   // c * (-2 ** 240 + 1) in 576 bits looks as     156   // c * (-2 ** 240 + 1) in 576 bits looks as follows, depending on c:
157   //  - if c = 0, the number is zero.             157   //  - if c = 0, the number is zero.
158   //  - if c = 1: bits 576 to 240 are set,        158   //  - if c = 1: bits 576 to 240 are set,
159   //              bits 239 to 1 are zero, and     159   //              bits 239 to 1 are zero, and
160   //              the last one is set             160   //              the last one is set
161   //  - if c = -1, which corresponds to all bi    161   //  - if c = -1, which corresponds to all bits set (signed int64_t):
162   //              bits 576 to 240 are zero and    162   //              bits 576 to 240 are zero and the rest is set.
163   // Note that all bits except the last are ex    163   // Note that all bits except the last are exactly complimentary (unless c = 0)
164   // and the last byte is conveniently represe    164   // and the last byte is conveniently represented by c already.
165   // Now construct the three bit patterns from    165   // Now construct the three bit patterns from c, their names correspond to the
166   // assembly implementation by Alexei Sibidan    166   // assembly implementation by Alexei Sibidanov.
167                                                   167 
168   // c = 0 -> t0 = 0; c = 1 -> t0 = 0; c = -1     168   // c = 0 -> t0 = 0; c = 1 -> t0 = 0; c = -1 -> all bits set (sign extension)
169   // (The assembly implementation shifts by 63    169   // (The assembly implementation shifts by 63, which gives the same result.)
170   int64_t t0 = c >> 1;                            170   int64_t t0 = c >> 1;
171                                                   171 
172   // Left shifting negative values is undefine    172   // Left shifting negative values is undefined behavior until C++20, cast to
173   // unsigned.                                    173   // unsigned.
174   uint64_t c_unsigned = static_cast<uint64_t>(    174   uint64_t c_unsigned = static_cast<uint64_t>(c);
175                                                   175 
176   // c = 0 -> t2 = 0; c = 1 -> upper 16 bits s    176   // c = 0 -> t2 = 0; c = 1 -> upper 16 bits set; c = -1 -> lower 48 bits set
177   int64_t t2 = t0 - (c_unsigned << 48);           177   int64_t t2 = t0 - (c_unsigned << 48);
178                                                   178 
179   // c = 0 -> t1 = 0; c = 1 -> all bits set (s    179   // c = 0 -> t1 = 0; c = 1 -> all bits set (sign extension); c = -1 -> t1 = 0
180   // (The assembly implementation shifts by 63    180   // (The assembly implementation shifts by 63, which gives the same result.)
181   int64_t t1 = t2 >> 48;                          181   int64_t t1 = t2 >> 48;
182                                                   182 
183   unsigned carry = 0;                             183   unsigned carry = 0;
184   {                                               184   {
185     uint64_t r_0 = r[0];                          185     uint64_t r_0 = r[0];
186                                                   186 
187     uint64_t out_0 = sub_carry(r_0, c, carry);    187     uint64_t out_0 = sub_carry(r_0, c, carry);
188     out[0] = out_0;                               188     out[0] = out_0;
189   }                                               189   }
190   for (int i = 1; i < 3; i++) {                   190   for (int i = 1; i < 3; i++) {
191     uint64_t r_i = r[i];                          191     uint64_t r_i = r[i];
192     r_i = sub_overflow(r_i, carry, carry);        192     r_i = sub_overflow(r_i, carry, carry);
193                                                   193 
194     uint64_t out_i = sub_carry(r_i, t0, carry)    194     uint64_t out_i = sub_carry(r_i, t0, carry);
195     out[i] = out_i;                               195     out[i] = out_i;
196   }                                               196   }
197   {                                               197   {
198     uint64_t r_3 = r[3];                          198     uint64_t r_3 = r[3];
199     r_3 = sub_overflow(r_3, carry, carry);        199     r_3 = sub_overflow(r_3, carry, carry);
200                                                   200 
201     uint64_t out_3 = sub_carry(r_3, t2, carry)    201     uint64_t out_3 = sub_carry(r_3, t2, carry);
202     out[3] = out_3;                               202     out[3] = out_3;
203   }                                               203   }
204   for (int i = 4; i < 9; i++) {                   204   for (int i = 4; i < 9; i++) {
205     uint64_t r_i = r[i];                          205     uint64_t r_i = r[i];
206     r_i = sub_overflow(r_i, carry, carry);        206     r_i = sub_overflow(r_i, carry, carry);
207                                                   207 
208     uint64_t out_i = sub_carry(r_i, t1, carry)    208     uint64_t out_i = sub_carry(r_i, t1, carry);
209     out[i] = out_i;                               209     out[i] = out_i;
210   }                                               210   }
211 }                                                 211 }
212                                                   212 
213 /// Combine multiply9x9 and mod_m with interna    213 /// Combine multiply9x9 and mod_m with internal temporary storage
214 ///                                               214 ///
215 /// \param[in] in1 first factor with 9 numbers    215 /// \param[in] in1 first factor with 9 numbers of 64 bits each
216 /// \param[inout] inout second factor and also    216 /// \param[inout] inout second factor and also the output of the same size
217 ///                                               217 ///
218 /// The result in inout is guaranteed to be sm    218 /// The result in inout is guaranteed to be smaller than the modulus.
219 static void mulmod(const uint64_t *in1, uint64    219 static void mulmod(const uint64_t *in1, uint64_t *inout) {
220   uint64_t mul[2 * 9] = {0};                      220   uint64_t mul[2 * 9] = {0};
221   multiply9x9(in1, inout, mul);                   221   multiply9x9(in1, inout, mul);
222   mod_m(mul, inout);                              222   mod_m(mul, inout);
223 }                                                 223 }
224                                                   224 
225 /// Compute base to the n modulo m                225 /// Compute base to the n modulo m
226 ///                                               226 ///
227 /// \param[in] base with 9 numbers of 64 bits     227 /// \param[in] base with 9 numbers of 64 bits each
228 /// \param[out] res output with 9 numbers of 6    228 /// \param[out] res output with 9 numbers of 64 bits each
229 /// \param[in] n exponent                         229 /// \param[in] n exponent
230 ///                                               230 ///
231 /// The arguments base and res may point to th    231 /// The arguments base and res may point to the same location.
232 static void powermod(const uint64_t *base, uin    232 static void powermod(const uint64_t *base, uint64_t *res, uint64_t n) {
233   uint64_t fac[9] = {0};                          233   uint64_t fac[9] = {0};
234   fac[0] = base[0];                               234   fac[0] = base[0];
235   res[0] = 1;                                     235   res[0] = 1;
236   for (int i = 1; i < 9; i++) {                   236   for (int i = 1; i < 9; i++) {
237     fac[i] = base[i];                             237     fac[i] = base[i];
238     res[i] = 0;                                   238     res[i] = 0;
239   }                                               239   }
240                                                   240 
241   uint64_t mul[18] = {0};                         241   uint64_t mul[18] = {0};
242   while (n) {                                     242   while (n) {
243     if (n & 1) {                                  243     if (n & 1) {
244       multiply9x9(res, fac, mul);                 244       multiply9x9(res, fac, mul);
245       mod_m(mul, res);                            245       mod_m(mul, res);
246     }                                             246     }
247     n >>= 1;                                      247     n >>= 1;
248     if (!n)                                       248     if (!n)
249       break;                                      249       break;
250     multiply9x9(fac, fac, mul);                   250     multiply9x9(fac, fac, mul);
251     mod_m(mul, fac);                              251     mod_m(mul, fac);
252   }                                               252   }
253 }                                                 253 }
254                                                   254 
255 #endif                                            255 #endif
256                                                   256