#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Jun 24 14:32:34 2019

@author: kerkmann
"""

import numpy as np
import matplotlib.pyplot as plt

# Newton method (from exercise 13)
def Newton(f,df,x0,tol=1e-8,max_it=100):
    '''Newton algorithm to find a root of the function f.

    Input parameters:
    f        function to find a root of
    df       derivative (Jacobian) of f
    x0       initial guess / starting value
    tol      tolerance; the iteration stops once the max-norm of the
             Newton step is <= tol
    max_it   maximum number of iterations

    Output parameters:
    x        last iterate (approximate root); for a scalar x0 this is a
             1-element array as soon as at least one step has been taken,
             because the step comes out of np.linalg.solve as an array
    it       number of iterations performed
    n        Euclidean norm of the next (not applied) Newton step
    '''

    def newton_step(z):
        # Solve df(z) * d = f(z); atleast_2d / atleast_1d let scalar
        # problems and systems share the same code path.
        return np.linalg.solve(np.atleast_2d(df(z)), np.atleast_1d(f(z)))

    it = 0
    x = x0
    d = newton_step(x)

    while np.linalg.norm(d,np.inf) > tol and it < max_it:
        x = x - d
        d = newton_step(x)
        it += 1

    # Warn only if the iteration budget is exhausted AND the step is still
    # too large. Checking `it == max_it` alone would also warn when the very
    # last permitted iteration reached the tolerance.
    if it >= max_it and np.linalg.norm(d,np.inf) > tol:
        print('Warning: Maximum number of iterations reached. Newton\'s method might not have converged.')

    return x, it, np.linalg.norm(d)

# explicit Euler method (from exercise 27)
def Euler_explicit(x,t,f,k):
    '''One explicit (forward) Euler step of size k for x' = f(x,t).

    x   current state
    t   current time
    f   right-hand side, called as f(x,t)
    k   step size

    Returns the approximation at time t+k.
    '''
    slope = f(x,t)
    return x + k*slope

# implicit Euler method
def Euler_implicit(x,t,f,df,k):
    '''One implicit (backward) Euler step of size k for x' = f(x,t).

    x    current state
    t    current time
    f    right-hand side, called as f(u,t)
    df   derivative of f with respect to u, called as df(u,t)
    k    step size

    Solves u - k*f(u,t+k) - x = 0 for u with Newton's method and returns
    the Newton iterate. The derivative is formed as 1 - k*df, i.e. in
    scalar form.
    '''
    t_new = t + k

    def residual(u):
        return u - k*f(u,t_new) - x

    def residual_prime(u):
        return 1 - k*df(u,t_new)

    # explicit Euler as initial guess
    guess = Euler_explicit(x,t,f,k)
    root, _, _ = Newton(residual,residual_prime,guess)
    return root

# trapezoidal rule
def trap(x,t,f,df,k):
    '''One step of size k with the trapezoidal rule for x' = f(x,t).

    x    current state
    t    current time
    f    right-hand side, called as f(u,t)
    df   derivative of f with respect to u, called as df(u,t)
    k    step size

    Solves u - k/2*(f(u,t+k) + f(x,t)) - x = 0 for u with Newton's method
    and returns the Newton iterate. The derivative is formed as
    1 - k/2*df, i.e. in scalar form.
    '''
    t_new = t + k
    half_k = k/2

    def residual(u):
        return u - half_k*(f(u,t_new) + f(x,t)) - x

    def residual_prime(u):
        return 1 - half_k*df(u,t_new)

    # explicit Euler as initial guess
    guess = Euler_explicit(x,t,f,k)
    root, _, _ = Newton(residual,residual_prime,guess)
    return root

# TR-BDF 2
def TR_BDF2(x,t,f,df,k):
    '''One step of size k with the TR-BDF2 method for x' = f(x,t).

    x    current state
    t    current time
    f    right-hand side, called as f(u,t)
    df   derivative of f with respect to u, called as df(u,t)
    k    step size

    Stage 1 (trapezoidal rule over [t, t+k/2]):
        ustar = x + k/4*(f(x,t) + f(ustar,t+k/2))
    Stage 2 (BDF2 over [t, t+k] using x and ustar):
        u = (4*ustar - x + k*f(u,t+k))/3

    Both stages are solved with Newton's method; the Newton iterate of
    stage 2 is returned. The derivatives are formed as 1 - c*df, i.e. in
    scalar form.
    '''
    # Stage 1: trapezoidal rule to the midpoint t + k/2
    g = lambda u: u - k/4*(f(u,t+k/2) + f(x,t)) - x
    dg = lambda u: 1 - k/4*df(u,t+k/2)
    # explicit Euler as initial guess
    x0 = Euler_explicit(x,t,f,k)
    ustar = Newton(g,dg,x0)[0]

    # Stage 2: BDF2, written as residual u - k/3*f(u) - 4/3*ustar + 1/3*x = 0.
    # The old state enters with weight 1/3; with weight 1 the scheme would
    # not even reproduce constant solutions (f = 0 would give x/3).
    g = lambda u: u - k/3*f(u,t+k) - 4/3*ustar + x/3
    dg = lambda u: 1 - k/3*df(u,t+k)
    # explicit Euler as initial guess
    x0 = Euler_explicit(x,t,f,k)
    return Newton(g,dg,x0)[0]
    
### MAIN PROGRAM ###
    
# stiff model problem u' = l*(u - cos(t)) - sin(t) on [t_0, T]
l = -10**6

def f(u,t):
    '''Right-hand side of the ODE u' = f(u,t).'''
    return l*(u-np.cos(t)) - np.sin(t)

def df(u,t):
    '''Derivative of f with respect to u (the constant l).'''
    return l

def ex(t):
    '''Reference solution 1/2*exp(l*t) + cos(t), used for errors and plots.'''
    return 1/2*np.exp(l*t) + np.cos(t)

# start and end time
t_0 = 0
T = 3

# (a) (1 point)

# maximum step size according to the stability region is -2/l
k = 1.9*10**(-6)

# initial data: solution values, time points, step counter
U, t, n = [1.5], [t_0], 0

while t[-1] < T-1e-14:
    # shorten the last step so that it ends exactly at the final time
    if T < (t[-1]+k):
        k = T - t[-1]

    # approximation with explicit Euler
    U.append(Euler_explicit(U[-1],t[-1],f,k))
    t.append(t[-1] + k)
    n += 1

# error at the final time
err = abs(U[-1] - ex(T))

U, t = np.array(U), np.array(t)
plt.plot(t,U,label='explicit Euler')
plt.plot(t,ex(t),label='exact solution')

# (b) (2 points)

k = 5*10**(-2)

# initial data (U: implicit Euler, U2: trapezoidal rule)
U = [1.5]
U2 = [1.5]
t = [t_0]
n = 0

while t[n]<T-1e-14:
    # match final time
    if (t[-1]+k) > T:
        k = T - t[-1]

    # approximation with implicit Euler and with the trapezoidal rule
    U.append(Euler_implicit(U[n],t[-1],f,df,k))
    U2.append(trap(U2[n],t[-1],f,df,k))
    n = n+1
    t.append(t[-1] + k)

# error of the implicit Euler approximation at the final time
err = abs(U[-1] - ex(T))

# The lists start with a plain float, but the implicit steppers return
# Newton's iterate, which can be a 1-element array. np.hstack flattens such
# a mixed list into a 1-D array; np.array would treat it as ragged, which
# recent NumPy releases reject.
U = np.hstack(U)
U2 = np.hstack(U2)
t = np.array(t)
plt.plot(t,U,label='implicit Euler')
plt.plot(t,U2,label='trapezoidal rule')

# The trapezoidal rule yields oscillating results. (1 point)
# Reasoning:
# Applying it to u' = lambda * u gives
# R(z) = (1+z/2)/(1-z/2), so that |R(z)| -> 1 for |z| -> inf.
# That is, for large k, R(z) is approximately -1, i.e. the error that was
# made changes its sign in every step but only becomes slightly smaller.

# (c) (2 points)
k = 5*10**(-2)

# initial data
U = [1.5]
t = [t_0]
n = 0

while t[n]<T-1e-14:
    # match final time
    if (t[-1]+k) > T:
        k = T - t[-1]

    # approximation with TR-BDF2
    U.append(TR_BDF2(U[n],t[-1],f,df,k))
    n = n+1
    t.append(t[-1] + k)

# error of the TR-BDF2 approximation at the final time
err = abs(U[-1] - ex(T))

# The list starts with a plain float, but TR_BDF2 returns Newton's iterate,
# which can be a 1-element array. np.hstack flattens such a mixed list into
# a 1-D array; np.array would treat it as ragged, which recent NumPy
# releases reject.
U = np.hstack(U)
t = np.array(t)
plt.plot(t,U,label='TR-BDF 2')
plt.legend()

# Advantage: second order instead of first order as for the implicit Euler method (1 point)