-
Notifications
You must be signed in to change notification settings - Fork 3
/
sol1.py
200 lines (159 loc) · 5.02 KB
/
sol1.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
"""
Sum of digits sequence
Problem 551
Let a(0), a(1),... be an integer sequence defined by:
a(0) = 1
for n >= 1, a(n) is the sum of the digits of all preceding terms
The sequence starts with 1, 1, 2, 4, 8, ...
You are given a(10^6) = 31054319.
Find a(10^15)
"""
# Split exponents k handled by next_term: ks[0] is the level at which it stops
# recursing and calls compute() instead, ks[-1] is the largest k supported.
ks = list(range(2, 20 + 1))
# base[j] == 10 ** j for every digit position up to the largest k.
base = [10 ** exponent for exponent in range(ks[-1] + 1)]
# memo[digitsum(b)][c] -> list of cached jumps (diff, dn, k), sorted by dn.
memo = {}
def next_term(a_i, k, i, n):
    """
    Advance a_i in place, stopping at the n-th term or at the first term
    whose low part reaches 10^k, where every term is written in the form:
        a(i) = b * 10^k + c

    For any a(i), if digitsum(b) and c have the same value, the differences
    between subsequent terms are the same until c >= 10^k. Each such jump is
    stored in the module-level ``memo`` and replayed when it fits.

    Arguments:
    a_i -- list of decimal digits of the i-th term, one's place first;
           it is modified in place
    k -- k when terms are written in the form a(i) = b * 10^k + c.
         Terms are calculated until c >= 10^k or the n-th term is reached.
    i -- position of a_i along the sequence
    n -- position of the term to stop at, if it is reached first

    Return: a tuple (diff, dn) of the difference between the ending term and
    the starting term, and the number of terms advanced. ex. if the starting
    term is a_1=1 and the ending term is a_10=62, then (61, 9) is returned.
    """
    low_len = min(len(a_i), k)
    # ds_b - digitsum(b): digits at positions k and above
    ds_b = sum(a_i[k:])
    # c - numeric value of the digits below position k
    c = sum(digit * base[pos] for pos, digit in enumerate(a_i[:k]))

    diff = dn = 0
    max_dn = n - i

    sub_memo = memo.get(ds_b)
    if sub_memo is None:
        sub_memo = {c: []}
        memo[ds_b] = sub_memo
    else:
        cached_jumps = sub_memo.get(c)
        if not cached_jumps:
            sub_memo[c] = []
        else:
            # the list is sorted by terms skipped, so scanning from the end
            # finds the largest jump that does not overshoot n and was
            # recorded for a split exponent no bigger than the current one
            for cached in reversed(cached_jumps):
                if cached[2] <= k and cached[1] <= max_dn:
                    diff, dn, _ = cached
                    # only the difference is cached, so add c back and
                    # rewrite the low digits; what is left carries into b
                    carry = diff + c
                    for pos in range(low_len):
                        carry, a_i[pos] = divmod(carry, 10)
                    if carry > 0:
                        add(a_i, k, carry)
                    break

    if dn >= max_dn or c + diff >= base[k]:
        return diff, dn

    if k > ks[0]:
        # keep doing smaller jumps until n is reached or c overflows 10^k
        while True:
            step_diff, step_terms = next_term(a_i, k - 1, i + dn, n)
            diff += step_diff
            dn += step_terms
            if dn >= max_dn or c + diff >= base[k]:
                break
    else:
        # would be too small a jump, just compute sequential terms instead
        step_diff, step_terms = compute(a_i, k, i + dn, n)
        diff += step_diff
        dn += step_terms

    # re-read the list: a nested call may have replaced sub_memo[c]
    jumps = sub_memo[c]

    # keep jumps sorted by # of terms skipped (insert after equal entries)
    position = 0
    while position < len(jumps) and jumps[position][1] <= dn:
        position += 1

    # cache the jump for this value of digitsum(b) and c
    jumps.insert(position, (diff, dn, k))
    return diff, dn
def compute(a_i, k, i, n):
    """
    Same as next_term(a_i, k, i, n) but computes the terms one by one,
    without reading or writing the memo.

    The term is viewed as a(i) = b * 10^k + c. Terms are generated until the
    n-th term is reached or an addition carries out of the k low digits.

    Arguments:
    a_i -- list of decimal digits of the i-th term, one's place first;
           modified in place (padded with zeros up to k digits if shorter)
    k -- number of low digits that make up c
    i -- position of a_i along the sequence
    n -- position of the term to stop at

    Return: a tuple (diff, dn) of the total amount added to the term and the
    number of terms advanced. When i >= n nothing is advanced, a_i is left
    untouched and (0, 0) is returned.
    """
    if i >= n:
        # already at (or past) the target: zero difference over zero terms
        return 0, 0
    if k > len(a_i):
        a_i.extend([0] * (k - len(a_i)))

    # note: a_i -> b * 10^k + c
    # ds_b -> digitsum(b), constant until a carry leaves the low digits
    # ds_c -> digitsum(c), recomputed after every step
    start_i = i
    ds_b = sum(a_i[k:])
    ds_c = sum(a_i[:k])
    diff = 0
    carry = 0

    while i < n:
        i += 1
        # the next term is the current term plus its own digit sum
        carry = ds_b + ds_c
        diff += carry
        ds_c = 0
        for j in range(k):
            carry, a_i[j] = divmod(a_i[j] + carry, 10)
            ds_c += a_i[j]

        if carry > 0:
            # c overflowed 10^k, so digitsum(b) is about to change: stop
            break

    if carry > 0:
        add(a_i, k, carry)
    return diff, i - start_i
def add(digits, k, addend):
    """
    Add addend * 10^k to the number stored in digits, in place.

    Arguments:
    digits -- list of decimal digits, one's place first; modified in place
    k -- index of the digit at which the addition starts
    addend -- non-negative integer to add; it may have more than one digit

    Digits that carry past the end of the list are appended. If k is not
    smaller than len(digits), the digits of addend are appended directly at
    the end of the list, without zero padding.
    """
    carry = addend
    for j in range(k, len(digits)):
        if carry == 0:
            break
        # one divmod handles both the new digit and the full remaining carry,
        # so addends of any size propagate correctly
        carry, digits[j] = divmod(digits[j] + carry, 10)

    # whatever is left extends the number with new high digits
    while carry > 0:
        carry, digit = divmod(carry, 10)
        digits.append(digit)
def solution(n: int = 10 ** 15) -> int:
    """
    Return the n-th term of the sequence.

    The walk starts from a(1) = 1 and repeatedly calls next_term until n - 1
    further terms have been covered. For n = 0 and n = 1 the answer is 1 and
    no jump is needed.

    Raises ValueError if n is negative.

    >>> solution(0)
    1

    >>> solution(10)
    62

    >>> solution(10**6)
    31054319

    >>> solution(10**15)
    73597483551591773
    """
    if n < 0:
        raise ValueError("n must be a non-negative integer")

    digits = [1]  # a(1) = 1, stored one's place first
    i = 1
    dn = 0
    # "<" rather than "==" so that n <= 1 terminates instead of looping forever
    while dn < n - i:
        _, terms_jumped = next_term(digits, 20, i + dn, n)
        dn += terms_jumped

    # rebuild the integer from its little-endian digit list
    return sum(digit * 10 ** j for j, digit in enumerate(digits))
if __name__ == "__main__":
    # Script entry point: solve for the default n (10 ** 15) and print the
    # result; the "=" specifier prints the expression text before its value.
    print(f"{solution() = }")