#!/usr/bin/env python3
# -*- coding: utf-8 -*-


# Musterlösung für Aufgabe 28

import numpy as np
import matplotlib.pyplot as plt 



def sphere():
    phi1D      = np.linspace(np.pi/4, 3*np.pi/4, 10)[1:-1]
    theta1D    = np.linspace(-3*np.pi/4, 3*np.pi/4, 22)[:-1]
    phi, theta = np.meshgrid(phi1D, theta1D)

    x0 = (np.sin(phi)*np.cos(theta)).flatten()
    x1 = (np.sin(phi)*np.sin(theta)).flatten()
    x2 = np.cos(phi).flatten()

    xYin  = np.vstack([ x0, x1, x2])
    xYang = np.vstack([-x0, x2, x1])
    x     = np.hstack((xYin, xYang))

    return x




def plotBildSphere(A, u = None):
    x = sphere()
    # x ist ein Array der Dimension (3,336), wobei 336 die Anzahl der 
    # geplotteten Pfeile ist. Die Abbildung erhalten wir einfach per 
    # Matrix-Matrix-Multiplikation:
    xabb = A@x

    fig = plt.figure(figsize=(10, 5))

    ax1 = fig.add_subplot(121, projection='3d')
    ax2 = fig.add_subplot(122, projection='3d')

    # neue Werte für die Achsengrenzen bestimmen
    m1 = np.min(np.hstack([x, xabb]))
    m2 = np.max(np.hstack([x, xabb]))

    for (a, xx) in [(ax1, x), (ax2, xabb)]:
        a.set_xlim3d(m1, m2)
        a.set_xlabel('$x_0$')
        a.set_ylim3d(m1, m2)
        a.set_ylabel('$x_1$')
        a.set_zlim3d(m1, m2)
        a.set_zlabel('$x_2$')

        for k in range(xx.shape[1]):
            col = plt.cm.rainbow(k/xx.shape[1])
            a.quiver3D(0, 0, 0, xx[0, k], xx[1, k], xx[2, k], colors=col)
            
    # nur für das Plotten vom Richtungsvektor u
    if u is not None:
        ax2.quiver3D(0, 0, 0, u[0], u[1], u[2], colors='k')
        
        
        

