Geant4 Cross Reference

Cross-Referencing   Geant4
Geant4/externals/clhep/include/CLHEP/Random/ranluxpp/mulmod.h

Version: [ ReleaseNotes ] [ 1.0 ] [ 1.1 ] [ 2.0 ] [ 3.0 ] [ 3.1 ] [ 3.2 ] [ 4.0 ] [ 4.0.p1 ] [ 4.0.p2 ] [ 4.1 ] [ 4.1.p1 ] [ 5.0 ] [ 5.0.p1 ] [ 5.1 ] [ 5.1.p1 ] [ 5.2 ] [ 5.2.p1 ] [ 5.2.p2 ] [ 6.0 ] [ 6.0.p1 ] [ 6.1 ] [ 6.2 ] [ 6.2.p1 ] [ 6.2.p2 ] [ 7.0 ] [ 7.0.p1 ] [ 7.1 ] [ 7.1.p1 ] [ 8.0 ] [ 8.0.p1 ] [ 8.1 ] [ 8.1.p1 ] [ 8.1.p2 ] [ 8.2 ] [ 8.2.p1 ] [ 8.3 ] [ 8.3.p1 ] [ 8.3.p2 ] [ 9.0 ] [ 9.0.p1 ] [ 9.0.p2 ] [ 9.1 ] [ 9.1.p1 ] [ 9.1.p2 ] [ 9.1.p3 ] [ 9.2 ] [ 9.2.p1 ] [ 9.2.p2 ] [ 9.2.p3 ] [ 9.2.p4 ] [ 9.3 ] [ 9.3.p1 ] [ 9.3.p2 ] [ 9.4 ] [ 9.4.p1 ] [ 9.4.p2 ] [ 9.4.p3 ] [ 9.4.p4 ] [ 9.5 ] [ 9.5.p1 ] [ 9.5.p2 ] [ 9.6 ] [ 9.6.p1 ] [ 9.6.p2 ] [ 9.6.p3 ] [ 9.6.p4 ] [ 10.0 ] [ 10.0.p1 ] [ 10.0.p2 ] [ 10.0.p3 ] [ 10.0.p4 ] [ 10.1 ] [ 10.1.p1 ] [ 10.1.p2 ] [ 10.1.p3 ] [ 10.2 ] [ 10.2.p1 ] [ 10.2.p2 ] [ 10.2.p3 ] [ 10.3 ] [ 10.3.p1 ] [ 10.3.p2 ] [ 10.3.p3 ] [ 10.4 ] [ 10.4.p1 ] [ 10.4.p2 ] [ 10.4.p3 ] [ 10.5 ] [ 10.5.p1 ] [ 10.6 ] [ 10.6.p1 ] [ 10.6.p2 ] [ 10.6.p3 ] [ 10.7 ] [ 10.7.p1 ] [ 10.7.p2 ] [ 10.7.p3 ] [ 10.7.p4 ] [ 11.0 ] [ 11.0.p1 ] [ 11.0.p2 ] [ 11.0.p3 ] [ 11.0.p4 ] [ 11.1 ] [ 11.1.1 ] [ 11.1.2 ] [ 11.1.3 ] [ 11.2 ] [ 11.2.1 ] [ 11.2.2 ] [ 11.3.0 ]

  1 //
  2 // -*- C++ -*-
  3 //
  4 // -----------------------------------------------------------------------
  5 //                             HEP Random
  6 //                       --- RanluxppEngine ---
  7 //                     helper implementation file
  8 // -----------------------------------------------------------------------
  9 
 10 #ifndef RANLUXPP_MULMOD_H
 11 #define RANLUXPP_MULMOD_H
 12 
 13 #include "helpers.h"
 14 
 15 #include <cstdint>
 16 
 17 /// Multiply two 576 bit numbers, stored as 9 numbers of 64 bits each
 18 ///
 19 /// \param[in] in1 first factor as 9 numbers of 64 bits each
 20 /// \param[in] in2 second factor as 9 numbers of 64 bits each
 21 /// \param[out] out result with 18 numbers of 64 bits each
 22 static void multiply9x9(const uint64_t *in1, const uint64_t *in2,
 23                         uint64_t *out) {
 24   uint64_t next = 0;
 25   unsigned nextCarry = 0;
 26 
 27 #if defined(__clang__) || defined(__INTEL_COMPILER)
 28 #pragma unroll
 29 #elif defined(__GNUC__) && __GNUC__ >= 8
 30 // This pragma was introduced in GCC version 8.
 31 #pragma GCC unroll 18
 32 #endif
 33   for (int i = 0; i < 18; i++) {
 34     uint64_t current = next;
 35     unsigned carry = nextCarry;
 36 
 37     next = 0;
 38     nextCarry = 0;
 39 
 40 #if defined(__clang__) || defined(__INTEL_COMPILER)
 41 #pragma unroll
 42 #elif defined(__GNUC__) && __GNUC__ >= 8
 43 // This pragma was introduced in GCC version 8.
 44 #pragma GCC unroll 9
 45 #endif
 46     for (int j = 0; j < 9; j++) {
 47       int k = i - j;
 48       if (k < 0 || k >= 9)
 49         continue;
 50 
 51       uint64_t fac1 = in1[j];
 52       uint64_t fac2 = in2[k];
 53 #if defined(__SIZEOF_INT128__) && !defined(CLHEP_NO_INT128)
 54 #ifdef __GNUC__
 55 #pragma GCC diagnostic push
 56 #pragma GCC diagnostic ignored "-Wpedantic"
 57 #endif
 58       unsigned __int128 prod = fac1;
 59       prod = prod * fac2;
 60 
 61       uint64_t upper = prod >> 64;
 62       uint64_t lower = static_cast<uint64_t>(prod);
 63 #ifdef __GNUC__
 64 #pragma GCC diagnostic pop
 65 #endif
 66 #else
 67       uint64_t upper1 = fac1 >> 32;
 68       uint64_t lower1 = static_cast<uint32_t>(fac1);
 69 
 70       uint64_t upper2 = fac2 >> 32;
 71       uint64_t lower2 = static_cast<uint32_t>(fac2);
 72 
 73       // Multiply 32-bit parts, each product has a maximum value of
 74       // (2 ** 32 - 1) ** 2 = 2 ** 64 - 2 * 2 ** 32 + 1.
 75       uint64_t upper = upper1 * upper2;
 76       uint64_t middle1 = upper1 * lower2;
 77       uint64_t middle2 = lower1 * upper2;
 78       uint64_t lower = lower1 * lower2;
 79 
 80       // When adding the two products, the maximum value for middle is
 81       // 2 * 2 ** 64 - 4 * 2 ** 32 + 2, which exceeds a uint64_t.
 82       unsigned overflow;
 83       uint64_t middle = add_overflow(middle1, middle2, overflow);
 84       // Handling the overflow by a multiplication with 0 or 1 is cheaper
 85       // than branching with an if statement, which the compiler does not
 86       // optimize to this equivalent code. Note that we could do entirely
 87       // without this overflow handling when summing up the intermediate
 88       // products differently as described in the following SO answer:
 89       //    https://stackoverflow.com/a/51587262
 90       // However, this approach takes at least the same amount of thinking
 91       // why a) the code gives the same results without b) overflowing due
 92       // to the mixture of 32 bit arithmetic. Moreover, my tests show that
 93       // the scheme implemented here is actually slightly more performant.
 94       uint64_t overflow_add = overflow * (uint64_t(1) << 32);
 95       // This addition can never overflow because the maximum value of upper
 96       // is 2 ** 64 - 2 * 2 ** 32 + 1 (see above). When now adding another
 97       // 2 ** 32, the result is 2 ** 64 - 2 ** 32 + 1 and still smaller than
 98       // the maximum 2 ** 64 - 1 that can be stored in a uint64_t.
 99       upper += overflow_add;
100 
101       uint64_t middle_upper = middle >> 32;
102       uint64_t middle_lower = middle << 32;
103 
104       lower = add_overflow(lower, middle_lower, overflow);
105       upper += overflow;
106 
107       // This still can't overflow since the maximum of middle_upper is
108       //  - 2 ** 32 - 4 if there was an overflow for middle above, bringing
109       //    the maximum value of upper to 2 ** 64 - 2.
110       //  - otherwise upper still has the initial maximum value given above
111       //    and the addition of a value smaller than 2 ** 32 brings it to
112       //    a maximum value of 2 ** 64 - 2 ** 32 + 2.
113       // (Both cases include the increment to handle the overflow in lower.)
114       //
115       // All the reasoning makes perfect sense given that the product of two
116       // 64 bit numbers is smaller than or equal to
117       //     (2 ** 64 - 1) ** 2 = 2 ** 128 - 2 * 2 ** 64 + 1
118       // with the upper bits matching the 2 ** 64 - 2 of the first case.
119       upper += middle_upper;
120 #endif
121 
122       // Add to current, remember carry.
123       current = add_carry(current, lower, carry);
124 
125       // Add to next, remember nextCarry.
126       next = add_carry(next, upper, nextCarry);
127     }
128 
129     next = add_carry(next, carry, nextCarry);
130 
131     out[i] = current;
132   }
133 }
134 
135 /// Compute a value congruent to mul modulo m less than 2 ** 576
136 ///
137 /// \param[in] mul product from multiply9x9 with 18 numbers of 64 bits each
138 /// \param[out] out result with 9 numbers of 64 bits each
139 ///
140 /// \f$ m = 2^{576} - 2^{240} + 1 \f$
141 ///
142 /// The result in out is guaranteed to be smaller than the modulus.
143 static void mod_m(const uint64_t *mul, uint64_t *out) {
144   uint64_t r[9];
145   // Assign r = t0
146   for (int i = 0; i < 9; i++) {
147     r[i] = mul[i];
148   }
149 
150   int64_t c = compute_r(mul + 9, r);
151 
152   // To update r = r - c * m, it suffices to know c * (-2 ** 240 + 1)
153   // because the 2 ** 576 will cancel out. Also note that c may be zero, but
154   // the operation is still performed to avoid branching.
155 
156   // c * (-2 ** 240 + 1) in 576 bits looks as follows, depending on c:
157   //  - if c = 0, the number is zero.
158   //  - if c = 1: bits 576 to 240 are set,
159   //              bits 239 to 1 are zero, and
160   //              the last one is set
161   //  - if c = -1, which corresponds to all bits set (signed int64_t):
162   //              bits 576 to 240 are zero and the rest is set.
163   // Note that all bits except the last are exactly complimentary (unless c = 0)
164   // and the last byte is conveniently represented by c already.
165   // Now construct the three bit patterns from c, their names correspond to the
166   // assembly implementation by Alexei Sibidanov.
167 
168   // c = 0 -> t0 = 0; c = 1 -> t0 = 0; c = -1 -> all bits set (sign extension)
169   // (The assembly implementation shifts by 63, which gives the same result.)
170   int64_t t0 = c >> 1;
171 
172   // Left shifting negative values is undefined behavior until C++20, cast to
173   // unsigned.
174   uint64_t c_unsigned = static_cast<uint64_t>(c);
175 
176   // c = 0 -> t2 = 0; c = 1 -> upper 16 bits set; c = -1 -> lower 48 bits set
177   int64_t t2 = t0 - (c_unsigned << 48);
178 
179   // c = 0 -> t1 = 0; c = 1 -> all bits set (sign extension); c = -1 -> t1 = 0
180   // (The assembly implementation shifts by 63, which gives the same result.)
181   int64_t t1 = t2 >> 48;
182 
183   unsigned carry = 0;
184   {
185     uint64_t r_0 = r[0];
186 
187     uint64_t out_0 = sub_carry(r_0, c, carry);
188     out[0] = out_0;
189   }
190   for (int i = 1; i < 3; i++) {
191     uint64_t r_i = r[i];
192     r_i = sub_overflow(r_i, carry, carry);
193 
194     uint64_t out_i = sub_carry(r_i, t0, carry);
195     out[i] = out_i;
196   }
197   {
198     uint64_t r_3 = r[3];
199     r_3 = sub_overflow(r_3, carry, carry);
200 
201     uint64_t out_3 = sub_carry(r_3, t2, carry);
202     out[3] = out_3;
203   }
204   for (int i = 4; i < 9; i++) {
205     uint64_t r_i = r[i];
206     r_i = sub_overflow(r_i, carry, carry);
207 
208     uint64_t out_i = sub_carry(r_i, t1, carry);
209     out[i] = out_i;
210   }
211 }
212 
213 /// Combine multiply9x9 and mod_m with internal temporary storage
214 ///
215 /// \param[in] in1 first factor with 9 numbers of 64 bits each
216 /// \param[inout] inout second factor and also the output of the same size
217 ///
218 /// The result in inout is guaranteed to be smaller than the modulus.
219 static void mulmod(const uint64_t *in1, uint64_t *inout) {
220   uint64_t mul[2 * 9] = {0};
221   multiply9x9(in1, inout, mul);
222   mod_m(mul, inout);
223 }
224 
225 /// Compute base to the n modulo m
226 ///
227 /// \param[in] base with 9 numbers of 64 bits each
228 /// \param[out] res output with 9 numbers of 64 bits each
229 /// \param[in] n exponent
230 ///
231 /// The arguments base and res may point to the same location.
232 static void powermod(const uint64_t *base, uint64_t *res, uint64_t n) {
233   uint64_t fac[9] = {0};
234   fac[0] = base[0];
235   res[0] = 1;
236   for (int i = 1; i < 9; i++) {
237     fac[i] = base[i];
238     res[i] = 0;
239   }
240 
241   uint64_t mul[18] = {0};
242   while (n) {
243     if (n & 1) {
244       multiply9x9(res, fac, mul);
245       mod_m(mul, res);
246     }
247     n >>= 1;
248     if (!n)
249       break;
250     multiply9x9(fac, fac, mul);
251     mod_m(mul, fac);
252   }
253 }
254 
255 #endif
256