#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Apr 30 12:24:14 2019

@author: kerkmann
"""

import numpy as np
import matplotlib.pyplot as plt

def five_point_star(f,a,b,m,bvals):
    """Solve  u_xx + u_yy = f  on the square [a, b] x [a, b] with Dirichlet
    boundary values, using the five point star operator of 2nd order.

    Parameters
    ----------
    f : callable
        Right-hand side, called as ``f(X, Y)`` with the (m, m) coordinate
        arrays of the interior grid points. It may return an array of that
        shape or a scalar (constant right-hand side).
    a, b : float
        Left and right boundary; the computational domain is the square
        [a, b] x [a, b].
    m : int
        Number of interior grid points per direction (must be >= 1).
    bvals : array_like, shape (m+2, m+2)
        Grid values whose first/last rows and columns supply the boundary
        data; interior entries are ignored. Entry [i, j] belongs to the
        point (x_j, y_i), matching ``np.meshgrid`` default ('xy') ordering.

    Returns
    -------
    U : ndarray of float, shape (m+2, m+2)
        Numerical solution, with the boundary values of ``bvals`` kept on
        the first/last rows and columns.

    Raises
    ------
    ValueError
        If m < 1 or ``bvals`` does not have shape (m+2, m+2).

    Notes
    -----
    The system matrix is built and solved densely (m**2 x m**2), so memory
    grows like m**4.
    """
    if m < 1:
        raise ValueError("m must be at least 1, got {0}".format(m))
    # Float copy: with integer boundary data, writing the solution back into
    # a copy of the caller's array would silently truncate it.
    bvals = np.array(bvals, dtype=float)
    if bvals.shape != (m+2, m+2):
        raise ValueError("bvals must have shape ({0}, {0}), got {1}"
                         .format(m+2, bvals.shape))

    x = np.linspace(a,b,m+2)
    h = (b-a)/(m+1)

    # Matrix setup: T is the 1D second difference within a grid row
    # (x-direction, diagonal -4 includes the y-part of the centre weight),
    # S couples neighbouring rows (y-direction).
    I = np.eye(m)
    e = np.ones(m)
    T = np.diag(e[:-1],-1) - 4*np.diag(e) + np.diag(e[:-1],1)
    S = np.diag(e[:-1],-1) + np.diag(e[:-1],1)
    A = (np.kron(I,T) + np.kron(S,I))/h**2

    # Right-hand side at the interior points. Broadcasting lets f return a
    # scalar; the copy makes F writable for the boundary adjustment below.
    X,Y = np.meshgrid(x,x)
    Xint,Yint = X[1:-1,1:-1], Y[1:-1,1:-1]
    F = np.broadcast_to(np.asarray(f(Xint,Yint), dtype=float), (m,m)).copy()

    # Move the known boundary neighbours to the right-hand side
    F[0,:] -= bvals[0,1:-1]/h**2      # neighbours on y = a
    F[-1,:] -= bvals[-1,1:-1]/h**2    # neighbours on y = b
    F[:,0] -= bvals[1:-1,0]/h**2      # neighbours on x = a
    F[:,-1] -= bvals[1:-1,-1]/h**2    # neighbours on x = b

    # Solve the linear system; row-major ordering matches the kron layout
    UInt = np.linalg.solve(A, F.reshape(m**2)).reshape((m,m))

    # Augment U with boundary
    U = bvals
    U[1:-1,1:-1] = UInt

    return U


def nine_point_star(f,a,b,m,bvals):
    """Solve  u_xx + u_yy = f  on the square [a, b] x [a, b] with Dirichlet
    boundary values, using the nine point star operator of 4th order
    (including deferred correction of the right-hand side).

    Parameters
    ----------
    f : callable
        Right-hand side, called as ``f(X, Y)`` with (m, m) coordinate arrays
        at the interior grid points and at those points shifted by +-h in x
        or y (deferred correction). It may return an array of that shape or
        a scalar (constant right-hand side).
    a, b : float
        Left and right boundary; the computational domain is the square
        [a, b] x [a, b].
    m : int
        Number of interior grid points per direction (must be >= 1).
    bvals : array_like, shape (m+2, m+2)
        Grid values whose first/last rows and columns supply the boundary
        data; interior entries are ignored. Entry [i, j] belongs to the
        point (x_j, y_i), matching ``np.meshgrid`` default ('xy') ordering.

    Returns
    -------
    U : ndarray of float, shape (m+2, m+2)
        Numerical solution, with the boundary values of ``bvals`` kept on
        the first/last rows and columns.

    Raises
    ------
    ValueError
        If m < 1 or ``bvals`` does not have shape (m+2, m+2).

    Notes
    -----
    The system matrix is built and solved densely (m**2 x m**2), so memory
    grows like m**4.
    """
    if m < 1:
        raise ValueError("m must be at least 1, got {0}".format(m))
    # Float copy: with integer boundary data, writing the solution back into
    # a copy of the caller's array would silently truncate it.
    bvals = np.array(bvals, dtype=float)
    if bvals.shape != (m+2, m+2):
        raise ValueError("bvals must have shape ({0}, {0}), got {1}"
                         .format(m+2, bvals.shape))

    x = np.linspace(a,b,m+2)
    h = (b-a)/(m+1)

    # Matrix setup: stencil weights (1/(6h^2)) * [[1,4,1],[4,-20,4],[1,4,1]].
    # T holds the weights within a grid row, R the weights of the
    # neighbouring rows, and S selects those neighbouring rows.
    I = np.eye(m)
    e = np.ones(m)
    T = 4*np.diag(e[:-1],-1) - 20*np.diag(e) + 4*np.diag(e[:-1],1)
    S = np.diag(e[:-1],-1) + np.diag(e[:-1],1)
    R = np.diag(e[:-1],-1) + 4*np.diag(e) + np.diag(e[:-1],1)
    A = (np.kron(I,T) + np.kron(S,R))/6/h**2

    # Right-hand side at the interior points. Broadcasting lets f return a
    # scalar; F0 is kept unmodified for the deferred correction.
    X,Y = np.meshgrid(x,x)
    Xint,Yint = X[1:-1,1:-1], Y[1:-1,1:-1]
    F0 = np.broadcast_to(np.asarray(f(Xint,Yint), dtype=float), (m,m))
    F = F0.copy()

    # Move the known boundary neighbours (weights 1, 4, 1) to the RHS
    F[0,:] -= (bvals[0,0:-2] + 4*bvals[0,1:-1] + bvals[0,2:])/6/h**2
    F[-1,:] -= (bvals[-1,0:-2] + 4*bvals[-1,1:-1] + bvals[-1,2:])/6/h**2
    F[:,0] -= (bvals[0:-2,0] + 4*bvals[1:-1,0] + bvals[2:,0])/6/h**2
    F[:,-1] -= (bvals[0:-2,-1] + 4*bvals[1:-1,-1] + bvals[2:,-1])/6/h**2

    # Each corner value of bvals was subtracted twice above (once via its
    # row and once via its column); add it back once.
    F[0,0] += bvals[0,0]/6/h**2
    F[-1,0] += bvals[-1,0]/6/h**2
    F[0,-1] += bvals[0,-1]/6/h**2
    F[-1,-1] += bvals[-1,-1]/6/h**2

    # Deferred correction: add (h^2/12) * (discrete Laplacian of f)
    lap_f_h2 = (f(Xint+h,Yint) + f(Xint-h,Yint)
                + f(Xint,Yint+h) + f(Xint,Yint-h) - 4*F0)
    F = F + lap_f_h2/12

    # Solve the linear system; row-major ordering matches the kron layout
    UInt = np.linalg.solve(A, F.reshape(m**2)).reshape((m,m))

    # Augment U with boundary
    U = bvals
    U[1:-1,1:-1] = UInt

    return U


### MAIN PROGRAM ###
# Convergence study for  u_xx + u_yy = f  on the unit square with the exact
# solution u(x, y) = exp(x + y/2), i.e. f = 1.25*exp(x + y/2).

a = 0
b = 1

f = lambda x,y: 1.25*np.exp(x+y/2)

ms = [10, 20, 40, 80]
err5 = np.zeros(len(ms))
err9 = np.zeros(len(ms))

# solve for different meshes
for k, m in enumerate(ms):
    x = np.linspace(a,b,m+2)
    X,Y = np.meshgrid(x,x)

    utrue = np.exp(X+Y/2)
    bvals = utrue            # only boundary values are used

    # solve with five point star
    U = five_point_star(f,a,b,m,bvals)

    # error in max norm
    err5[k] = np.max(np.abs(U-utrue))

    # solve with nine point star
    U = nine_point_star(f,a,b,m,bvals)

    # error in max norm
    err9[k] = np.max(np.abs(U-utrue))

# experimental order of convergence: the mesh width is h = (b-a)/(m+1), so
# doubling m does NOT exactly halve h; divide by the log of the actual
# mesh-width ratio instead of log(2).
hs = (b-a)/(np.array(ms)+1)
EOC5 = np.log(err5[:-1]/err5[1:])/np.log(hs[:-1]/hs[1:])
EOC9 = np.log(err9[:-1]/err9[1:])/np.log(hs[:-1]/hs[1:])

# console outputs
print('Five point star:\n Error: {0}\n EOC: {1}'.format(err5,EOC5))
print('Nine point star:\n Error: {0}\n EOC: {1}'.format(err9,EOC9))

# plot (optional): last nine point star solution on the finest mesh
#plt.contour(X,Y,U)
    