Currently I'm struggling with problem EASYMATH2.
I figured out that a possible approach is:
- iterate through all non-empty subsets of {a, a+d,...,a+4d} (powerset)
- calculate the least common multiple (lcm) of the elements of each subset
- add or subtract the number of multiples of this lcm within the range n..m
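For the last step, the number of multiples of a given lcm inside n..m can be taken from two floor divisions. A minimal sketch, assuming a positive divisor and integer bounds; the name `count_multiples` is only illustrative:

```python
def count_multiples(lo, hi, k):
    """Number of multiples of k in the closed range lo..hi (assumes k > 0)."""
    if lo > hi:
        return 0
    # Multiples up to hi, minus multiples strictly below lo.
    return hi // k - (lo - 1) // k
```

For example, `count_multiples(1, 10, 3)` counts 3, 6 and 9.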
I think my Python 3 code does exactly these three steps without too much overhead, but I still keep getting TLE. Obviously, some of the subsets result in the same lcm, so they are calculated unnecessarily, but I thought I could compensate for this with memoization. And the respective (redundant) operations are quite compact.
Where is my code too slow? Or do I have to optimize the algo to avoid redundant subsets? But how?
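One direction that might cut the overhead, sketched here without any claim that it is enough to pass: build the lcm incrementally while walking the subsets, and stop extending a subset as soon as its lcm exceeds m, since every superset then contributes zero. This assumes positive divisors and a Python that has `math.gcd` (3.5 or later); `count_not_divisible` and `walk` are hypothetical names.

```python
from math import gcd


def count_not_divisible(n, m, divisors):
    """Count x in n..m that are divisible by none of `divisors` (all > 0)."""
    total = 0

    def walk(i, lcm, sign):
        nonlocal total
        if i == len(divisors):
            # Signed count of multiples of this subset's lcm in n..m.
            total += sign * (m // lcm - (n - 1) // lcm)
            return
        # Branch 1: leave divisors[i] out of the subset.
        walk(i + 1, lcm, sign)
        # Branch 2: take divisors[i]; extend the lcm by one gcd call.
        nxt = lcm // gcd(lcm, divisors[i]) * divisors[i]
        if nxt <= m:  # prune: a larger lcm has no multiples in n..m
            walk(i + 1, nxt, -sign)

    # The empty subset (lcm 1, sign +1) contributes m - n + 1.
    walk(0, 1, 1)
    return total
```

With five divisors this visits at most 32 subsets per test case and needs no cache, so no tuple hashing or dictionary lookups are involved.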
[bbone=python3,502]""" SPOJ (classical).11391. EASY MATH. Problem code: EASYMATH"""
import sys
import itertools as it
def gcd(a, b):
    """Calculate the Greatest Common Divisor of a and b.
    Unless b==0, the result will have the same sign as b (so that when
    b is divided by it, the result comes out positive).
    """
    while b:
        a, b = b, a % b
    return a
def memoize(fn):
    stored_results = {}
    def memoized(*args):
        if args in stored_results:
            # try to get the cached result
            return stored_results[args]
        else:
            # nothing was cached for those args. let's fix that.
            result = stored_results[args] = fn(*args)
            return result
    return memoized
def count_n_in_range(start, end, div):
    if start > end:
        return 0
    res = (end - start) // div
    if (start % div == 0) or (start % div > end % div):
        return res + 1
    else:
        return res
count_n_in_range = memoize(count_n_in_range)
def lcml(numbers):
    if len(numbers) == 1:
        return sum(numbers)
    elif len(numbers) == 2:
        return numbers[0] * numbers[1] // gcd(*numbers)
    else:
        return lcml((lcml(numbers[:-1]), numbers[-1]))
lcml = memoize(lcml)
def powerset(iterable):
    "powerset([1,2,3]) --> (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)"
    # Non-empty subsets only: r starts at 1, so () is never produced.
    s = list(iterable)
    return it.chain.from_iterable(it.combinations(s, r)
                                  for r in range(1, len(s) + 1))
def calc_numbers(data):
    n, m, a, d = map(int, data)
    res = m - n + 1
    if a > m:
        # return a string here too, since main() joins the results
        return str(res)
    divis = [x for x in range(a, a + 5 * d, d) if x <= m]
    for ps in powerset(divis):
        remove = lcml(ps)
        # <= m, not < m: an lcm equal to m still has one multiple in range
        if remove <= m:
            count = count_n_in_range(n, m, remove)
            if len(ps) % 2:
                res -= count
            else:
                res += count
    return str(res)
def main():
    cases = sys.stdin.read().split('\n')
    cases = [c for c in cases if c]
    results = []
    # the first line holds the number of test cases, so skip it
    for line in cases[1:]:
        data = line.split()
        res = calc_numbers(data)
        results.append(res)
    sys.stdout.write('\n'.join(results))
    sys.stdout.write('\n')
if __name__ == '__main__':
    main()
[/bbone]
Any hints?
Thanks, Daft