from matplotlib import pyplot as plt
from scipy.optimize import curve_fit

def func1(x, a, b):
    """Linear trend model evaluated at ``x``: slope ``a``, intercept ``b``."""
    slope_term = a * x
    return slope_term + b

def func2(x, a, b, c):
    """Quadratic trend model evaluated at ``x`` with coefficients ``a``, ``b``, ``c``."""
    # Same operation order as a*x*x + b*x + c, so float results are identical.
    quadratic_term = a * x * x
    linear_term = b * x
    return quadratic_term + linear_term + c

    
def _show_fit(X, Y, fitted, X_ax):
    """Plot data (green) and fitted values (red), then block until user input."""
    fig = plt.figure()
    try:
        plt.plot(X, Y, 'go')
        plt.plot(X, fitted, 'ro')
        plt.ylabel('O-C', fontsize=15)
        plt.xlabel(X_ax, fontsize=15)
        plt.grid()
        plt.draw()
        plt.pause(0.1)
        plt.waitforbuttonpress(0)  # timeout 0: wait indefinitely for input
    finally:
        # Always release the figure, even if drawing/waiting raises.
        plt.close(fig)


def trender(X, Y, Obj_X, order, view, X_ax):
    """Fit a polynomial trend to (X, Y) and evaluate it at X and Obj_X.

    Parameters
    ----------
    X, Y : array-like
        Data used for the least-squares fit (passed to
        ``scipy.optimize.curve_fit``).
    Obj_X : array-like or scalar
        Abscissa value(s) at which the fitted trend is also evaluated.
    order : int
        Polynomial order of the trend: 1 (linear, ``func1``) or
        2 (quadratic, ``func2``).
    view : str
        When equal to ``'yes'``, plot the data and the fitted values and
        block until a key or mouse button is pressed.
    X_ax : str
        Label for the plot's x axis (only used when ``view == 'yes'``).

    Returns
    -------
    tuple
        ``(trend evaluated at X, trend evaluated at Obj_X)``.

    Raises
    ------
    ValueError
        If ``order`` is not 1 or 2 (previously such calls silently
        returned ``None``).
    """
    models = {1: func1, 2: func2}
    model = models.get(order)
    if model is None:
        raise ValueError(
            "unsupported order {!r}; expected 1 or 2".format(order))

    popt, _ = curve_fit(model, X, Y)  # covariance matrix is not needed
    fitted = model(X, *popt)          # computed once, reused for plot and return

    if view == 'yes':
        _show_fit(X, Y, fitted, X_ax)

    return fitted, model(Obj_X, *popt)
