1
0
mirror of https://github.com/osmarks/random-stuff synced 2025-01-28 01:44:54 +00:00
random-stuff/code-guessing/matrix-ts.py

98 lines
2.9 KiB
Python
Raw Normal View History

2021-03-21 19:36:03 +00:00
#!/usr/bin/env python3
import random
import collections
import subprocess
import ctypes
import tempfile
import os
def random_matrix(n):
    """Return an n-by-n matrix of random integers as a list of row lists.

    Each entry is drawn independently with ``random.randint`` from the
    inclusive range [-0xFFF, 0xFFF].
    """
    return [
        [random.randint(-0xFFF, 0xFFF) for _ in range(n)]
        for _ in range(n)
    ]
# https://en.wikipedia.org/wiki/Matrix_multiplication_algorithm#Iterative_algorithm
def simple_multiply(m1, m2):
    """Multiply two square matrices with the straightforward triple loop.

    Both arguments are lists of row lists. The size ``n`` is taken from
    ``len(m1)`` and both matrices are indexed as n-by-n. The result is a
    new n-by-n list of row lists; the inputs are not modified.
    """
    n = len(m1)
    # out[i][j] is the dot product of row i of m1 with column j of m2.
    return [
        [sum(m1[i][k] * m2[k][j] for k in range(n)) for j in range(n)]
        for i in range(n)
    ]
def print_matrix(m):
    """Format matrix ``m`` as a multi-line string with right-aligned columns.

    Despite the name, nothing is printed: the formatted text is returned.
    Each column is as wide as its longest ``str()`` entry, and entries are
    separated by single spaces. A line of spaces of the same width is
    placed above and below the rows.

    NOTE(review): the edge/corner strings are empty in this copy of the
    file; they may originally have held box-drawing characters. They are
    kept exactly as found.
    """
    col_widths = collections.defaultdict(int)
    for row in m:
        for index, col in enumerate(row):
            col_widths[index] = max(col_widths[index], len(str(col)))
    # One separating space before every column plus one after the last.
    # The column count (not the row count) is what matters here, so that
    # non-square matrices get top/bottom lines as wide as their rows.
    total_width = sum(col_widths.values()) + len(col_widths) + 1
    out = ["" + (" " * total_width) + ""]
    for row in m:
        things = [""]
        things.extend(
            str(col).rjust(col_widths[index]) for index, col in enumerate(row)
        )
        things.append("")
        out.append(" ".join(things))
    out.append("" + (" " * total_width) + "")
    return "\n".join(out)
def broken_entry(m1, m2):
    """Return an incorrect "product" of two square matrices.

    The loop structure is the ordinary triple-loop matrix multiplication,
    except that 3 is subtracted from every term of each sum. Judging by
    the function's name this error is intentional (a known-bad entry to
    exercise the tester with), so it must not be "fixed".
    """
    n = len(m1)
    out = [[None for _ in range(n)] for _ in range(n)]
    for i in range(n):
        for j in range(n):
            total = 0
            for k in range(n):
                # The "- 3" is what makes this entry wrong.
                total += m1[i][k] * m2[k][j] - 3
            out[i][j] = total
    return out
def flatten(arr):
    """Yield the elements of each inner iterable of ``arr``, in order.

    Only one level of nesting is removed.
    """
    for xs in arr:
        yield from xs
def c_wrapper(file):
    """Compile a C source file with gcc and wrap its ``entry`` function.

    The file is built as a shared library, loaded with ctypes, and its
    ``entry`` symbol is called as ``entry(int *m1, int *m2, int n)`` with a
    returned ``int *``. Matrices are passed, and the result is read back,
    as flat arrays indexed ``i * n + j``.

    Returns a function ``wrapper(m1, m2)`` that takes and returns lists of
    row lists. Raises ``ValueError`` if gcc exits with a non-zero status.
    """
    print("Compiling", file)
    # mkstemp creates the file atomically. The deprecated mktemp only
    # invents a name, which another process could claim before gcc writes.
    fd, temp = tempfile.mkstemp(prefix="lib-compile-")
    os.close(fd)
    print(temp)
    try:
        # -fPIC: code linked into a shared object should be
        # position-independent.
        if subprocess.run(["gcc", file, "-o", temp, "-shared", "-fPIC"]).returncode != 0:
            raise ValueError("compilation failed")
        library = ctypes.CDLL(temp)
    finally:
        # Best-effort cleanup so temporary libraries do not pile up. Some
        # platforms refuse to delete a loaded library; that is not fatal.
        try:
            os.unlink(temp)
        except OSError:
            pass
    entry = library.entry
    entry.restype = ctypes.POINTER(ctypes.c_int)
    # Declare the argument types so ctypes checks and converts each call.
    entry.argtypes = [
        ctypes.POINTER(ctypes.c_int),
        ctypes.POINTER(ctypes.c_int),
        ctypes.c_int,
    ]

    def wrapper(m1, m2):
        n = len(m1)
        Matrix = ctypes.c_int * (n * n)
        m1_c = Matrix(*flatten(m1))
        m2_c = Matrix(*flatten(m2))
        out_p = entry(m1_c, m2_c, n)
        # NOTE(review): the buffer returned by entry is never freed here;
        # who owns it is not visible from this file.
        return [[out_p[i * n + j] for j in range(n)] for i in range(n)]

    return wrapper
def test(entry):
    """Check ``entry`` against ``simple_multiply`` on random matrices.

    Runs 100 rounds. Each round picks a size n between 2 and 16
    (inclusive), generates two random n-by-n matrices and compares
    ``entry(m1, m2)`` with the reference product. On the first mismatch
    the inputs, the answer received and the expected answer are printed
    and the function returns. If every round matches, a success message
    is printed. Always returns ``None``.
    """
    for _ in range(100):
        n = random.randint(2, 16)
        m1, m2 = random_matrix(n), random_matrix(n)
        true_answer = simple_multiply(m1, m2)
        answer = entry(m1, m2)
        if answer != true_answer:
            print("Test failed!", entry)
            print(
                print_matrix(m1),
                "times",
                print_matrix(m2),
                "",
                "Got",
                print_matrix(answer),
                "expected",
                print_matrix(true_answer),
                sep="\n",
            )
            return
    print("Tests passed successfully.", entry)
# Alternative entries to run the tester against; uncomment one to use it:
#test(c_wrapper("./c_matrix_test.c"))
#test(broken_entry)
# Run the tester against the `entry` function of the multiply_matrices module.
import multiply_matrices
test(multiply_matrices.entry)