"""Approximate multiplication with piecewise-linear ReLU units.

The product is rebuilt from the polarization identity
``x * y = ((x + y)**2 - (x - y)**2) / 4``.  Each square is replaced by a sum
of ReLU hinges that interpolates ``z**2`` linearly between consecutive
integers, so the result is exact whenever ``x + y`` and ``x - y`` are
integers inside the covered range.
"""

import numpy as np


def relu(x):
    """Return ``max(0, x)`` element-wise for a scalar or NumPy array."""
    return np.maximum(0, x)


def square_approx_pos(z, max_z=18):
    """Approximate ``z**2`` for ``0 <= z <= max_z`` with ReLU hinges.

    The approximation is the piecewise-linear interpolant of ``z**2`` with
    breakpoints at the integers ``0, 1, ..., max_z``; it equals ``z**2``
    exactly at those integers.

    Args:
        z: Non-negative scalar or NumPy array.
        max_z: Last integer breakpoint (non-negative int).  Inputs above it
            saturate at ``max_z**2`` because no hinge covers them; negative
            inputs give 0.

    Returns:
        The approximation, with the same shape as ``z``.

    Raises:
        ValueError: If ``max_z`` is negative.
    """
    if max_z < 0:
        raise ValueError(f"max_z must be non-negative, got {max_z}")

    # The segment [i, i + 1] has slope (i + 1)**2 - i**2 = 2*i + 1.  The ReLU
    # difference equals clip(z - i, 0, 1): 0 left of the segment, the position
    # inside it, and 1 to its right, so the terms stack into the interpolant.
    # Starting the sum at 0.0 keeps the result floating point.
    return sum(
        (
            (2 * i + 1) * (relu(z - i) - relu(z - (i + 1)))
            for i in range(max_z)
        ),
        0.0,
    )


def square_approx(z, max_z=18):
    """Approximate ``z**2`` for ``|z| <= max_z``.

    Uses ``z**2 == |z|**2`` to reuse the approximation for non-negative
    inputs; see ``square_approx_pos`` for the saturation behaviour.
    """
    return square_approx_pos(np.abs(z), max_z)


def relu_multiplication(x, y, max_z=18):
    """Approximate ``x * y`` from two ReLU-based square approximations.

    Applies ``x * y = ((x + y)**2 - (x - y)**2) / 4`` with both squares
    replaced by ``square_approx``.  The result is exact when ``x + y`` and
    ``x - y`` are integers with absolute value at most ``max_z``; beyond that
    range the squares saturate and the result is wrong.

    Args:
        x: Scalar or NumPy array.
        y: Scalar or NumPy array (broadcast against ``x``).
        max_z: Largest ``|x + y|`` / ``|x - y|`` the hinges cover.

    Returns:
        The approximate product as a float scalar or array.
    """
    return 0.25 * (square_approx(x + y, max_z) - square_approx(x - y, max_z))


def plot_and_evaluate():
    """Tabulate and plot the approximation error on the 1..9 x 1..9 grid.

    Prints one row per ``(x, y)`` pair followed by the mean squared error,
    then shows two heatmaps: the approximated products and the absolute
    error.
    """
    # Imported lazily: only this function needs a plotting stack, so the
    # numeric helpers stay importable where matplotlib is not installed.
    import matplotlib.pyplot as plt

    x_vals = np.arange(1, 10)
    y_vals = np.arange(1, 10)

    # Broadcast to the full grid: rows follow x, columns follow y.
    grid_x = x_vals[:, None]
    grid_y = y_vals[None, :]
    true_products = grid_x * grid_y
    approx_products = relu_multiplication(grid_x, grid_y)
    abs_errors = np.abs(true_products - approx_products)

    print("x y | true approx error")
    print("----------------------------")
    for i, x in enumerate(x_vals):
        for j, y in enumerate(y_vals):
            true_val = true_products[i, j]
            approx_val = approx_products[i, j]
            err = abs_errors[i, j]
            print(f"{x:2} {y:2} | {true_val:5.1f} {approx_val:6.2f} {err:6.2f}")

    mse = np.mean((true_products - approx_products) ** 2)
    print(f"\nMean Squared Error (MSE): {mse:.4f}")

    panels = (
        (approx_products, "viridis", "ReLU Approximation"),
        (abs_errors, "hot", "Absolute Error"),
    )
    plt.figure(figsize=(10, 4))
    for position, (data, cmap, title) in enumerate(panels, start=1):
        plt.subplot(1, len(panels), position)
        # imshow draws rows along the vertical axis, hence x on the y-axis.
        plt.imshow(data, cmap=cmap, origin="lower")
        plt.colorbar()
        plt.title(title)
        plt.xlabel("y")
        plt.ylabel("x")
        plt.xticks(ticks=np.arange(len(y_vals)), labels=y_vals)
        plt.yticks(ticks=np.arange(len(x_vals)), labels=x_vals)
    plt.tight_layout()
    plt.show()


if __name__ == "__main__":
    plot_and_evaluate()
↑のコードを実行した結果が↓。誤差は 0 です。
x y | true approx error ---------------------------- 1 1 | 1.0 1.00 0.00 1 2 | 2.0 2.00 0.00 1 3 | 3.0 3.00 0.00 1 4 | 4.0 4.00 0.00 1 5 | 5.0 5.00 0.00 1 6 | 6.0 6.00 0.00 1 7 | 7.0 7.00 0.00 1 8 | 8.0 8.00 0.00 1 9 | 9.0 9.00 0.00 2 1 | 2.0 2.00 0.00 2 2 | 4.0 4.00 0.00 2 3 | 6.0 6.00 0.00 2 4 | 8.0 8.00 0.00 2 5 | 10.0 10.00 0.00 2 6 | 12.0 12.00 0.00 2 7 | 14.0 14.00 0.00
しかし、図示してもらうと、まったく的外れな結果になった。
