#!/usr/bin/ruby
# Tests for the Number `sum_of_squares(n)` method.
# Reference implementation used to cross-check `Number#sum_of_squares`:
# collects the representations of `n` as a sum of two squares, n = a^2 + b^2.
#
# Returns an array of [a, b] pairs. Each pair is sorted ascending, and the
# list is deduplicated and ordered by the smaller element.
# Special cases visible below:
#   * n == 0 returns [[0, 0]];
#   * [] is returned as soon as a prime p = 3 (mod 4) has an odd exponent.
#
# The function is memoized (`is cached`) and calls itself through `__FUNC__`.
# NOTE(review): negative `n` is not handled explicitly; the caller in this
# file only passes `irand(2**k)` -- confirm before reusing elsewhere.
func sum_of_two_squares_solutions(n) is cached {
n == 0 && return [[0, 0]]
# The loop below splits n so that n == prod1 * prod2**2:
#   prod1 - primes p = 1 (mod 4) with their full exponent, plus a single
#           factor 2 when the exponent of 2 is odd;
#   prod2 - square root of the remaining square part; it only scales results.
var prod1 = 1
var prod2 = 1
# [p, e] pairs whose product p**e over all entries equals prod1.
var prime_powers = []
for p,e in (n.factor_exp) {
if (p % 4 == 3) { # p = 3 (mod 4)
e.is_even || return [] # power must be even
prod2 *= p**(e >> 1)
}
elsif (p == 2) { # p = 2
if (e.is_even) { # power is even
prod2 *= p**(e >> 1)
}
else { # power is odd
prod1 *= p
prod2 *= p**((e - 1) >> 1)
prime_powers.append([p, 1])
}
}
else { # p = 1 (mod 4)
prod1 *= p**e
prime_powers.append([p, e])
}
}
# n = prod2^2 = 0^2 + prod2^2
prod1 == 1 && return [[0, prod2]]
# n = 2 * prod2^2 = prod2^2 + prod2^2
prod1 == 2 && return [[prod2, prod2]]
# All the solutions to the congruence: x^2 = -1 (mod prod1)
# Inner gather: for every prime power pp, the two residues r and pp - r as
# [residue, modulus] pairs. The cartesian product then picks one pair per
# prime power and Math.chinese merges each selection into a single residue.
var square_roots = gather {
gather {
for p,e in (prime_powers) {
var pp = p**e
var r = sqrtmod(-1, pp)
take([[r, pp], [pp - r, pp]])
}
}.cartesian { |*a|
take(Math.chinese(a...))
}
}
var solutions = []
# For each root r, run Euclidean remainder steps on (prod1, r) until the
# current remainder s satisfies s*s <= prod1; s and the next remainder
# (q % s), both scaled by prod2, are recorded as one pair.
for r in (square_roots) {
var s = r
var q = prod1
while (s*s > prod1) {
(s, q) = (q % s, s)
}
solutions.append([prod2 * s, prod2 * (q % s)])
}
# Add the pairs whose two members share a factor: for each prime power p**e
# and each i = e%2, e%2 + 2, ... below e, recurse on prod1 / p**(e - i) and
# scale the returned pairs by sq * prod2, where sq**2 == p**(e - i)
# (e - i is always even here). The trailing `...` expands the mapped array
# into a list for `<<`.
for p,e in (prime_powers) {
for (var i = e%2; i < e; i += 2) {
var sq = p**((e - i) >> 1)
var pp = p**(e - i)
solutions <<
__FUNC__(prod1 / pp).map { |pair|
pair.map {|r| sq * prod2 * r }
}...
}
}
# Normalize: sort inside each pair, drop pairs repeating the same smaller
# element, then order the list by that smaller element.
solutions.map {|pair| pair.sort } \
.uniq_by {|pair| pair[0] } \
.sort_by {|pair| pair[0] }
}
# Closed-form count that the tests in this file compare against the length of
# `n.sum_of_squares`.
#
# Takes the number of divisors of `n` congruent to 1 (mod 4) minus the number
# congruent to 3 (mod 4), adds one for each of `n` and `n/2` that is a perfect
# square, and halves the total.
func square_partition_count(n) {
    var divs = n.divisors

    var residue_one   = divs.count {|d| d % 4 == 1 }
    var residue_three = divs.count {|d| d % 4 == 3 }
    var balance       = (residue_one - residue_three)

    (balance + is_square(n) + is_square(n/2))/2
}
# Randomized cross-check: for 70 rounds, draw a random n bounded by 2**k and
# require the built-in `sum_of_squares` to agree with the reference
# implementation and with the closed-form solution count.
70.times { |k|
    var n = irand(2**k)

    var expected = sum_of_two_squares_solutions(n)
    var actual   = n.sum_of_squares

    # Print the decompositions only when there is at least one.
    if (expected) {
        var terms  = expected.map {|pair| "#{pair[0]}^2 + #{pair[1]}^2" }
        var joined = terms.join(' = ')
        say "#{n} = #{joined}"
    }

    assert_eq(expected, actual)
    assert_eq(actual.len, square_partition_count(n))
}
# Fixed regression cases, as [n, expected pairs] rows.
for n,expected in ([
    [2025,   [[0, 45], [27, 36]]],
    [164025, [[0, 405], [243, 324]]],
    [99025,  [[41, 312], [48, 311], [95, 300], [104, 297], [183, 256], [220, 225]]]
]) {
    assert_eq(n.sum_of_squares, expected)
}
# Every integer in -10..160 for which `sum_of_squares` returns at least one
# pair, compared against the expected list.
assert_eq(
    (-10 .. 160).grep {|n| n.sum_of_squares.len > 0 },
    %n[0, 1, 2, 4, 5, 8, 9, 10, 13, 16, 17, 18, 20, 25, 26, 29, 32, 34, 36, 37, 40, 41, 45, 49, 50, 52, 53, 58, 61, 64, 65, 68, 72, 73, 74, 80, 81, 82, 85, 89, 90, 97, 98, 100, 101, 104, 106, 109, 113, 116, 117, 121, 122, 125, 128, 130, 136, 137, 144, 145, 146, 148, 149, 153, 157, 160]
)

# The number of pairs must match the closed-form count for every n in 1..200.
assert_eq(
    (1..200).map {|n| n.sum_of_squares.len },
    (1..200).map {|n| square_partition_count(n) }
)
# Large input with many representations, checked against the full expected list.
var big_input = 11392163240756069707031250

var big_expected = [
    [39309472125, 3374998963875],
    [216763660575, 3368260197225],
    [477329304375, 3341305130625],
    [729359177085, 3295481517405],
    [735019741071, 3294223614297],
    [907262616645, 3251005657515],
    [982736803125, 3228992353125],
    [1151205969375, 3172835964375],
    [1224793301193, 3145162095999],
    [1393801568775, 3074000720175],
    [1622919634875, 2959441687125],
    [1847545189875, 2824666354125],
    [1993551800625, 2723584854375],
    [2056446956025, 2676413487825],
    [2194367046795, 2564549961435],
    [2198769707673, 2560776252111],
    [2386646521875, 2386646521875]
]

assert_eq(sum_of_squares(big_input), big_expected)
# Walk c = 1, 2, 3, ... and look for a pair [a, b] with a^2 + b^2 == c^2 and
# a + b + c == 1000; the first hit must satisfy a * b * c == 31875000.
for c in (1..Inf) {
    var pairs = sum_of_squares(c**2)
    pairs || next

    var legs = pairs.first {|pair| pair.sum + c == 1000 }
    legs || next

    assert_eq(legs.prod * c, 31875000)
    break
}

say "** Test passed!"