from numpy.random import rand
import numpy as np
import sympy as sp
import matplotlib.pyplot as plt
from bruch import Bruch


def LRZerlegungVektorisiert(A, dtype):
    n = A.shape[0]
    L = np.eye(n, dtype=object)
    for i in range(n):
        L[i, i] = dtype(L[i, i])
    R = A.copy()
    for kk in range(n - 1):
        L[kk+1:, kk] = R[kk+1:, kk] / R[kk, kk]
        R[kk+1:, :] = R[kk+1:, :] - L[kk+1:, [kk]] * R[[kk], :]
    return L, R


def vorwaerts(L, b):
    n = L.shape[0]
    assert (n == L.shape[1] and n == b.shape[0] and b.shape[1] == 1),\
        "Dimensionen passen nicht."
    assert np.prod(L.diagonal()), "L ist nicht invertierbar."
    x = np.zeros_like(b)
    x[0] = b[0]/L[0, 0]
    for i in range(1, n):
        x[i] = (b[i] - L[i, :i] @ x[:i])/L[i, i]
    return x


def rueckwaerts(R, b):
    n = R.shape[0]
    assert (n == R.shape[1] and n == b.shape[0] and b.shape[1] == 1),\
        "Dimensionen passen nicht."
    assert np.prod(R.diagonal()), "R ist nicht invertierbar."
    x = np.zeros_like(b)
    x[n-1] = b[n-1]/R[n-1, n-1]
    for i in range(n-2, -1, -1):
        x[i] = (b[i] - R[i, i+1:] @ x[i+1:])/R[i, i]
    return x


# %% LR Zerlegung und Vorwaerts/Rueckwaertssubstitution
A, b = np.array([[1e-20, 1.], [1., 1.]]), np.array([[1.], [0.]])
L, R = LRZerlegungVektorisiert(A, np.float64)
y = vorwaerts(L, b)
x = rueckwaerts(R, y)
print(f'Nichtpermutiertes A \n{A} \nrechte Seite b \n{b}')
print(f'L = \n{L} \n \nR = \n{R}')
print(f'Fehler in LR-A: {np.max(abs(L@R-A))}\n')
print(f'y = \n{y} \n x = \n{x}\n')

# %% LR Zerlegung und Vorwaerts/Rueckwaertssubstitution mit Permutation
# wir tauschen die 1. mit der 2. Zeile
PA, Pb = np.array([[1., 1.], [1e-20, 1.]]), np.array([[0.], [1.]])
PL, PR = LRZerlegungVektorisiert(PA, np.float64)
y = vorwaerts(PL, Pb)
x = rueckwaerts(PR, y)
print(f'Permutiertes A \n{PA} \nrechte Seite b \n{Pb}')
print(f'L = \n{L} \n \nR = \n{R}')
print(f'Fehler in LR-A: {np.max(abs(PL@PR-PA))}')
print(f'y = \n{y} \n x = \n{x}\n')

# %%
# exakte Loesung
sym_A = np.array([[sp.S(10)**-20, sp.S(1)], [sp.S(1), sp.S(1)]])
b = np.array([[sp.S(1)], [sp.S(0)]])
L, R = LRZerlegungVektorisiert(sym_A, sp.S)
y = vorwaerts(L, b)
x = rueckwaerts(R, y)
print('\n exakte Loesung mit sympy')
print(f'L = \n{L} \n \n R = \n{R}')
print(f'y = \n{y} \n x = \n{x}\n')

# %% Beispiel aus der Vorlesung
A_ = np.array([[2, 1, 1, 0],
               [4, 3, 3, 1],
               [8, 7, 9, 5],
               [6, 7, 9, 8]])  # leider sind das hier numpy integer, die wir erst mit int casten müssen, damit unsere Bruchklasse funktioniert.

A = np.array([Bruch(int(a)) for a in A_.flatten()]).reshape(A_.shape)
print(f'A = \n{A} \n')

L, R = LRZerlegungVektorisiert(A, Bruch)


print(f'L = \n{L} \n \nR = \n{R}\n')

# p = [2, 3, 1, 0] Permutationsvektor

P = np.array([[0, 0, 1, 0], [0, 0, 0, 1], [0, 1, 0, 0], [1, 0, 0, 0]])

print('PA  = \n', P@A)

# print(f'PA = \n {A[p]}')

L, R = LRZerlegungVektorisiert(P@A, Bruch)

print(f'L = \n{L} \n \nR = \n{R} \n\n')

# %%
# ' ## Graphik

x = np.linspace(-19.57, 0, 200)
f = 1/7*(-x)**(3/2)*((3/2)**(np.sqrt(-x))-np.floor((3/2)**(np.sqrt(-x))))
plt.figure(1)
plt.plot(x, f)


# %%
# ' in Grün

plt.figure(2)
plt.plot(x, f, x, -f, linewidth=3, color='green')


# %%
# ' Aufrichten

fig = plt.figure(3)

ax = fig.add_subplot(111)
ax.plot(f, x, -f, x, linewidth=4, color='lightgreen')
ax.fill(f, x, -f, x, linewidth=2, color='green')

# %%
# ' Ein Rechteck

ax.fill([-.5, .5, .5, -.5, -.5],
        [-19.5, -19.5, -22, -22, -19.5], color='brown', linewidth='4')


# %%
# ' Stern

def n_stern(n):
    x = [(0.5 + j % 2)*np.sin(np.pi*2*j/(2*n)) for j in range(2*n+1)]
    y = [(0.5 + j % 2)*np.cos(np.pi*2*j/(2*n)) for j in range(2*n+1)]
    return x, y


ax.plot(*n_stern(8), color='gold')

# %%
# ' Hintergrund

fig.set_facecolor('darkblue')
ax.axis('off')


# %%
# ' 3D Tannenbaum mit Schnee

fig = plt.figure(100)
ax = fig.add_subplot(111, projection='3d')

theta = np.linspace(-np.pi, np.pi, 50)
X = f.reshape((-1, 1)) * np.sin(theta).reshape(1, -1)
Y = f.reshape((-1, 1)) * np.cos(theta).reshape(1, -1)
Z = x.reshape((-1, 1)) * np.ones_like(theta).reshape(1, -1)
sc = ax.plot_surface(X, Y, Z, color='green')

nFlocken = 10
xs = 20*rand(nFlocken, 1)-10
ys = 20*rand(nFlocken, 1)-10
zs = -20*rand(nFlocken, 1)+2
sc = ax.scatter3D(xs, ys, zs, s=50*np.ones_like(xs), color='white', marker='3')
sc = ax.scatter3D(xs, ys, zs, s=50*np.ones_like(xs), color='white', marker='4')
fig.set_facecolor('darkblue')
ax.set_facecolor('darkblue') 
ax.axis('off')
