1
// Copyright (c) 2017-2021 The Bitcoin Core developers
2
// Distributed under the MIT software license, see the accompanying
3
// file COPYING or http://www.opensource.org/licenses/mit-license.php.
4

            
5
#include <crypto/muhash.h>
6

            
7
#include <crypto/chacha20.h>
8
#include <crypto/common.h>
9
#include <hash.h>
10

            
11
#include <cassert>
12
#include <cstdio>
13
#include <limits>
14

            
15
namespace {
16

            
17
using limb_t = Num3072::limb_t;
18
using double_limb_t = Num3072::double_limb_t;
19
constexpr int LIMB_SIZE = Num3072::LIMB_SIZE;
20
/** 2^3072 - 1103717, the largest 3072-bit safe prime number, is used as the modulus. */
21
constexpr limb_t MAX_PRIME_DIFF = 1103717;
22

            
23
/** Extract the lowest limb of [c0,c1,c2] into n, and left shift the number by 1 limb. */
24
inline void extract3(limb_t& c0, limb_t& c1, limb_t& c2, limb_t& n)
25
{
26
    n = c0;
27
    c0 = c1;
28
    c1 = c2;
29
    c2 = 0;
30
}
31

            
32
/** [c0,c1] = a * b */
33
inline void mul(limb_t& c0, limb_t& c1, const limb_t& a, const limb_t& b)
34
{
35
    double_limb_t t = (double_limb_t)a * b;
36
    c1 = t >> LIMB_SIZE;
37
    c0 = t;
38
}
39

            
40
/* [c0,c1,c2] += n * [d0,d1,d2]. c2 is 0 initially */
41
inline void mulnadd3(limb_t& c0, limb_t& c1, limb_t& c2, limb_t& d0, limb_t& d1, limb_t& d2, const limb_t& n)
42
{
43
    double_limb_t t = (double_limb_t)d0 * n + c0;
44
    c0 = t;
45
    t >>= LIMB_SIZE;
46
    t += (double_limb_t)d1 * n + c1;
47
    c1 = t;
48
    t >>= LIMB_SIZE;
49
    c2 = t + d2 * n;
50
}
51

            
52
/* [c0,c1] *= n */
53
inline void muln2(limb_t& c0, limb_t& c1, const limb_t& n)
54
{
55
    double_limb_t t = (double_limb_t)c0 * n;
56
    c0 = t;
57
    t >>= LIMB_SIZE;
58
    t += (double_limb_t)c1 * n;
59
    c1 = t;
60
}
61

            
62
/** [c0,c1,c2] += a * b */
63
inline void muladd3(limb_t& c0, limb_t& c1, limb_t& c2, const limb_t& a, const limb_t& b)
64
{
65
    double_limb_t t = (double_limb_t)a * b;
66
    limb_t th = t >> LIMB_SIZE;
67
    limb_t tl = t;
68

            
69
    c0 += tl;
70
    th += (c0 < tl) ? 1 : 0;
71
    c1 += th;
72
    c2 += (c1 < th) ? 1 : 0;
73
}
74

            
75
/** [c0,c1,c2] += 2 * a * b */
76
inline void muldbladd3(limb_t& c0, limb_t& c1, limb_t& c2, const limb_t& a, const limb_t& b)
77
{
78
    double_limb_t t = (double_limb_t)a * b;
79
    limb_t th = t >> LIMB_SIZE;
80
    limb_t tl = t;
81

            
82
    c0 += tl;
83
    limb_t tt = th + ((c0 < tl) ? 1 : 0);
84
    c1 += tt;
85
    c2 += (c1 < tt) ? 1 : 0;
86
    c0 += tl;
87
    th += (c0 < tl) ? 1 : 0;
88
    c1 += th;
89
    c2 += (c1 < th) ? 1 : 0;
90
}
91

            
92
/**
93
 * Add limb a to [c0,c1]: [c0,c1] += a. Then extract the lowest
94
 * limb of [c0,c1] into n, and left shift the number by 1 limb.
95
 * */
96
inline void addnextract2(limb_t& c0, limb_t& c1, const limb_t& a, limb_t& n)
97
{
98
    limb_t c2 = 0;
99

            
100
    // add
101
    c0 += a;
102
    if (c0 < a) {
103
        c1 += 1;
104

            
105
        // Handle case when c1 has overflown
106
        if (c1 == 0)
107
            c2 = 1;
108
    }
109

            
110
    // extract
111
    n = c0;
112
    c0 = c1;
113
    c1 = c2;
114
}
115

            
116
/** in_out = in_out^(2^sq) * mul */
117
inline void square_n_mul(Num3072& in_out, const int sq, const Num3072& mul)
118
{
119
    for (int j = 0; j < sq; ++j) in_out.Square();
120
    in_out.Multiply(mul);
121
}
122

            
123
} // namespace
124

            
125
/** Indicates whether d is larger than the modulus. */
126
bool Num3072::IsOverflow() const
127
{
128
    if (this->limbs[0] <= std::numeric_limits<limb_t>::max() - MAX_PRIME_DIFF) return false;
129
    for (int i = 1; i < LIMBS; ++i) {
130
        if (this->limbs[i] != std::numeric_limits<limb_t>::max()) return false;
131
    }
132
    return true;
133
}
134

            
135
void Num3072::FullReduce()
136
{
137
    limb_t c0 = MAX_PRIME_DIFF;
138
    limb_t c1 = 0;
139
    for (int i = 0; i < LIMBS; ++i) {
140
        addnextract2(c0, c1, this->limbs[i], this->limbs[i]);
141
    }
142
}
143

            
144
/**
 * Return the modular inverse of this number.
 *
 * The inverse is obtained by raising the value to the power (modulus - 2),
 * i.e. 2^3072 - MAX_PRIME_DIFF - 2. For fast exponentiation a sliding window
 * exponentiation with repunit precomputation is utilized. See "Fast Point
 * Decompression for Standard Elliptic Curves" (Brumley, Järvinen, 2008).
 */
Num3072 Num3072::GetInverse() const
{
    // Repunit table: p[i] = a^(2^(2^i)-1), i.e. an exponent of 2^i one-bits.
    Num3072 p[12];
    p[0] = *this;
    for (int i = 0; i < 11; ++i) {
        // Doubling the run of one-bits: shift the exponent left by 2^i bits
        // (2^i squarings) and fill the freed low bits by multiplying with p[i].
        p[i + 1] = p[i];
        square_n_mul(p[i + 1], 1 << i, p[i]);
    }

    // Each step shifts the exponent left by `squarings` bits and then appends
    // the run of one-bits held in p[repunit].
    struct Step {
        int squarings;
        int repunit;
    };
    static constexpr Step STEPS[] = {
        {512, 9},
        {256, 8},
        {128, 7},
        {64, 6},
        {32, 5},
        {8, 3},
        {2, 1},
        {1, 0},
        // From here on the exponent is no longer a pure run of ones.
        {5, 2},
        {3, 0},
        {2, 0},
        {4, 0},
        {4, 1},
        {3, 0},
    };

    // Start from the longest precomputed run (2048 one-bits).
    Num3072 out = p[11];
    for (const Step& step : STEPS) {
        square_n_mul(out, step.squarings, p[step.repunit]);
    }

    return out;
}
180

            
181
void Num3072::Multiply(const Num3072& a)
182
{
183
    limb_t c0 = 0, c1 = 0, c2 = 0;
184
    Num3072 tmp;
185

            
186
    /* Compute limbs 0..N-2 of this*a into tmp, including one reduction. */
187
    for (int j = 0; j < LIMBS - 1; ++j) {
188
        limb_t d0 = 0, d1 = 0, d2 = 0;
189
        mul(d0, d1, this->limbs[1 + j], a.limbs[LIMBS + j - (1 + j)]);
190
        for (int i = 2 + j; i < LIMBS; ++i) muladd3(d0, d1, d2, this->limbs[i], a.limbs[LIMBS + j - i]);
191
        mulnadd3(c0, c1, c2, d0, d1, d2, MAX_PRIME_DIFF);
192
        for (int i = 0; i < j + 1; ++i) muladd3(c0, c1, c2, this->limbs[i], a.limbs[j - i]);
193
        extract3(c0, c1, c2, tmp.limbs[j]);
194
    }
195

            
196
    /* Compute limb N-1 of a*b into tmp. */
197
    assert(c2 == 0);
198
    for (int i = 0; i < LIMBS; ++i) muladd3(c0, c1, c2, this->limbs[i], a.limbs[LIMBS - 1 - i]);
199
    extract3(c0, c1, c2, tmp.limbs[LIMBS - 1]);
200

            
201
    /* Perform a second reduction. */
202
    muln2(c0, c1, MAX_PRIME_DIFF);
203
    for (int j = 0; j < LIMBS; ++j) {
204
        addnextract2(c0, c1, tmp.limbs[j], this->limbs[j]);
205
    }
206

            
207
    assert(c1 == 0);
208
    assert(c0 == 0 || c0 == 1);
209

            
210
    /* Perform up to two more reductions if the internal state has already
211
     * overflown the MAX of Num3072 or if it is larger than the modulus or
212
     * if both are the case.
213
     * */
214
    if (this->IsOverflow()) this->FullReduce();
215
    if (c0) this->FullReduce();
216
}
217

            
218
void Num3072::Square()
219
{
220
    limb_t c0 = 0, c1 = 0, c2 = 0;
221
    Num3072 tmp;
222

            
223
    /* Compute limbs 0..N-2 of this*this into tmp, including one reduction. */
224
    for (int j = 0; j < LIMBS - 1; ++j) {
225
        limb_t d0 = 0, d1 = 0, d2 = 0;
226
        for (int i = 0; i < (LIMBS - 1 - j) / 2; ++i) muldbladd3(d0, d1, d2, this->limbs[i + j + 1], this->limbs[LIMBS - 1 - i]);
227
        if ((j + 1) & 1) muladd3(d0, d1, d2, this->limbs[(LIMBS - 1 - j) / 2 + j + 1], this->limbs[LIMBS - 1 - (LIMBS - 1 - j) / 2]);
228
        mulnadd3(c0, c1, c2, d0, d1, d2, MAX_PRIME_DIFF);
229
        for (int i = 0; i < (j + 1) / 2; ++i) muldbladd3(c0, c1, c2, this->limbs[i], this->limbs[j - i]);
230
        if ((j + 1) & 1) muladd3(c0, c1, c2, this->limbs[(j + 1) / 2], this->limbs[j - (j + 1) / 2]);
231
        extract3(c0, c1, c2, tmp.limbs[j]);
232
    }
233

            
234
    assert(c2 == 0);
235
    for (int i = 0; i < LIMBS / 2; ++i) muldbladd3(c0, c1, c2, this->limbs[i], this->limbs[LIMBS - 1 - i]);
236
    extract3(c0, c1, c2, tmp.limbs[LIMBS - 1]);
237

            
238
    /* Perform a second reduction. */
239
    muln2(c0, c1, MAX_PRIME_DIFF);
240
    for (int j = 0; j < LIMBS; ++j) {
241
        addnextract2(c0, c1, tmp.limbs[j], this->limbs[j]);
242
    }
243

            
244
    assert(c1 == 0);
245
    assert(c0 == 0 || c0 == 1);
246

            
247
    /* Perform up to two more reductions if the internal state has already
248
     * overflown the MAX of Num3072 or if it is larger than the modulus or
249
     * if both are the case.
250
     * */
251
    if (this->IsOverflow()) this->FullReduce();
252
    if (c0) this->FullReduce();
253
}
254

            
255
void Num3072::SetToOne()
256
{
257
    this->limbs[0] = 1;
258
    for (int i = 1; i < LIMBS; ++i) this->limbs[i] = 0;
259
}
260

            
261
void Num3072::Divide(const Num3072& a)
262
{
263
    if (this->IsOverflow()) this->FullReduce();
264

            
265
    Num3072 inv{};
266
    if (a.IsOverflow()) {
267
        Num3072 b = a;
268
        b.FullReduce();
269
        inv = b.GetInverse();
270
    } else {
271
        inv = a.GetInverse();
272
    }
273

            
274
    this->Multiply(inv);
275
    if (this->IsOverflow()) this->FullReduce();
276
}
277

            
278
/**
 * Construct a Num3072 from BYTE_SIZE bytes, read as consecutive little-endian
 * limbs with the lowest limb first.
 *
 * The bytes are taken as-is; no modular reduction is applied here.
 */
Num3072::Num3072(const unsigned char (&data)[BYTE_SIZE]) {
    // Without these checks an unexpected limb width would silently leave the
    // limbs unset, and a size mismatch would read outside data or ignore part
    // of it.
    static_assert(sizeof(limb_t) == 4 || sizeof(limb_t) == 8, "Num3072 limbs must be 32 or 64 bits wide");
    static_assert(sizeof(limb_t) * LIMBS == BYTE_SIZE, "Num3072 limbs must cover exactly BYTE_SIZE bytes");

    for (int i = 0; i < LIMBS; ++i) {
        if (sizeof(limb_t) == 4) {
            this->limbs[i] = ReadLE32(data + 4 * i);
        } else if (sizeof(limb_t) == 8) {
            this->limbs[i] = ReadLE64(data + 8 * i);
        }
    }
}
287

            
288
void Num3072::ToBytes(unsigned char (&out)[BYTE_SIZE]) {
289
    for (int i = 0; i < LIMBS; ++i) {
290
        if (sizeof(limb_t) == 4) {
291
            WriteLE32(out + i * 4, this->limbs[i]);
292
        } else if (sizeof(limb_t) == 8) {
293
            WriteLE64(out + i * 8, this->limbs[i]);
294
        }
295
    }
296
}
297

            
298
/**
 * Map arbitrary input bytes to a Num3072.
 *
 * The input is hashed with SHA256, the digest is used to set up a ChaCha20
 * instance, and the first Num3072::BYTE_SIZE bytes of its keystream are
 * interpreted as the number.
 */
Num3072 MuHash3072::ToNum3072(Span<const unsigned char> in) {
    uint256 seed = (CHashWriter(SER_DISK, 0) << in).GetSHA256();

    unsigned char keystream[Num3072::BYTE_SIZE];
    ChaCha20 stream(seed.data(), seed.size());
    stream.Keystream(keystream, Num3072::BYTE_SIZE);

    return Num3072{keystream};
}
307

            
308
/** Construct a MuHash3072 whose numerator is the Num3072 derived from in. */
MuHash3072::MuHash3072(Span<const unsigned char> in) noexcept
{
    this->m_numerator = this->ToNum3072(in);
}
312

            
313
/**
 * Collapse the numerator/denominator pair into a single number and write the
 * SHA256 hash of its serialization to out.
 *
 * The quotient is kept in the numerator and the denominator is reset to one,
 * so the object still represents the same value afterwards.
 */
void MuHash3072::Finalize(uint256& out) noexcept
{
    m_numerator.Divide(m_denominator);
    m_denominator.SetToOne();  // Needed to keep the MuHash object valid

    unsigned char serialized[Num3072::BYTE_SIZE];
    m_numerator.ToBytes(serialized);

    CHashWriter hasher(SER_DISK, 0);
    hasher << serialized;
    out = hasher.GetSHA256();
}
323

            
324
/** Multiply by mul: numerators and denominators are multiplied pairwise. */
MuHash3072& MuHash3072::operator*=(const MuHash3072& mul) noexcept
{
    this->m_numerator.Multiply(mul.m_numerator);
    this->m_denominator.Multiply(mul.m_denominator);
    return *this;
}
330

            
331
/**
 * Divide by div without computing an inverse: the numerator is multiplied by
 * div's denominator and the denominator by div's numerator.
 */
MuHash3072& MuHash3072::operator/=(const MuHash3072& div) noexcept
{
    this->m_numerator.Multiply(div.m_denominator);
    this->m_denominator.Multiply(div.m_numerator);
    return *this;
}
337

            
338
/** Insert an element: multiply the numerator by the Num3072 derived from in. */
MuHash3072& MuHash3072::Insert(Span<const unsigned char> in) noexcept {
    const Num3072 element = ToNum3072(in);
    m_numerator.Multiply(element);
    return *this;
}
342

            
343
/** Remove an element: multiply the denominator by the Num3072 derived from in. */
MuHash3072& MuHash3072::Remove(Span<const unsigned char> in) noexcept {
    const Num3072 element = ToNum3072(in);
    m_denominator.Multiply(element);
    return *this;
}