AIうぉ--!(ai-wo-katsuyo-shitai !)

AIを上手く使ってみせたい!!自分なりに。

(あまり反響ないけど、、、)これでRELUで掛け算ができる!

import numpy as np
import matplotlib.pyplot as plt

def relu(x):
    return np.maximum(0, x)

# piecewise-linear exact at integer breakpoints 0..18
def square_approx_pos(z):
    # z in [0,18] を整数ごとに線形補間(整数点では z^2 に一致)
    # 連続的に書くため、ヒンジの差分スロープで構成
    hinges = list(range(0, 19))  # 0,1,2,...,18
    # 区間 i..i+1 の傾きは ( (i+1)^2 - i^2 ) / 1 = 2i+1
    # これを ReLU の差分で積み上げる
    out = 0.0
    for i in range(0, 18):
        slope = 2*i + 1
        out += slope * (relu(z - i) - relu(z - (i+1)))
    return out

def square_approx(z):
    # z^2 = (|z|)^2 なので、正側の近似に |z| を入れる
    return square_approx_pos(np.abs(z))

def relu_multiplication(x, y):
    # 恒等式: xy = 1/4 * [ (x+y)^2 - (x-y)^2 ]
    return 0.25 * (square_approx(x + y) - square_approx(x - y))

def plot_and_evaluate():
    x_vals = np.arange(1, 10)
    y_vals = np.arange(1, 10)
    Z = np.zeros((9, 9))
    TrueZ = np.zeros((9, 9))
    Errors = np.zeros((9, 9))

    print("x y | true  approx  error")
    print("----------------------------")

    for i, x in enumerate(x_vals):
        for j, y in enumerate(y_vals):
            true_val = x * y
            approx_val = relu_multiplication(x, y)
            Z[i, j] = approx_val
            TrueZ[i, j] = true_val
            err = abs(true_val - approx_val)
            Errors[i, j] = err
            print(f"{x:2} {y:2} | {true_val:5.1f}  {approx_val:6.2f}  {err:6.2f}")

    mse = np.mean((TrueZ - Z)**2)
    print("\nMean Squared Error (MSE): {:.4f}".format(mse))

    plt.figure(figsize=(10, 4))
    plt.subplot(1, 2, 1)
    plt.imshow(Z, cmap='viridis', origin='lower')
    plt.colorbar()
    plt.title("ReLU Approximation")
    plt.xlabel('y')
    plt.ylabel('x')
    plt.xticks(ticks=np.arange(9), labels=np.arange(1,10))
    plt.yticks(ticks=np.arange(9), labels=np.arange(1,10))

    plt.subplot(1, 2, 2)
    plt.imshow(Errors, cmap='hot', origin='lower')
    plt.colorbar()
    plt.title("Absolute Error")
    plt.xlabel('y')
    plt.ylabel('x')
    plt.xticks(ticks=np.arange(9), labels=np.arange(1,10))
    plt.yticks(ticks=np.arange(9), labels=np.arange(1,10))

    plt.tight_layout()
    plt.show()

plot_and_evaluate()

↓↑ 誤差、0っす。

x y | true  approx  error
----------------------------
 1  1 |   1.0    1.00    0.00
 1  2 |   2.0    2.00    0.00
 1  3 |   3.0    3.00    0.00
 1  4 |   4.0    4.00    0.00
 1  5 |   5.0    5.00    0.00
 1  6 |   6.0    6.00    0.00
 1  7 |   7.0    7.00    0.00
 1  8 |   8.0    8.00    0.00
 1  9 |   9.0    9.00    0.00
 2  1 |   2.0    2.00    0.00
 2  2 |   4.0    4.00    0.00
 2  3 |   6.0    6.00    0.00
 2  4 |   8.0    8.00    0.00
 2  5 |  10.0   10.00    0.00
 2  6 |  12.0   12.00    0.00
 2  7 |  14.0   14.00    0.00

しかーし、図示してもらうと、超絶トンチンカン