import numpy as np
import cv2 as cv
from vilib import Vilib
from time import sleep
from picarx import Picarx
import matplotlib.pyplot as plt
#import sklearn as sk
# from sklearn.linear_model import LinearRegression  
# from scipy.optimize import minimize

# based on https://docs.opencv.org/3.4/d4/dee/tutorial_optical_flow.html

# Use the PiCar-X to record a reference video for the optical-flow analysis below.
# Record a short reference clip while driving forward, then drive back to the
# start position. The try/finally guarantees the motors are stopped and the
# camera is released even if the script is interrupted (e.g. Ctrl-C during one
# of the sleeps) or a vilib call raises; otherwise the car would keep driving.
px = Picarx()
Vilib.camera_start(vflip=False, hflip=False)
try:
    sleep(0.8)  # give the camera a moment to start before recording

    Vilib.rec_video_set["name"] = "video"
    Vilib.rec_video_set["path"] = "/home/cs371a/picar-x/master-calibration"

    Vilib.rec_video_run()
    Vilib.rec_video_start()
    px.forward(10)
    sleep(1)
    px.stop()
    Vilib.rec_video_stop()

    # Return to the starting point so the run can be repeated.
    px.backward(10)
    sleep(1)
finally:
    px.stop()
    sleep(0.2)
    Vilib.camera_close()

# Open the recorded reference video for offline optical-flow analysis and
# fail fast with a clear message if it is missing or unreadable.
arg = "/home/cs371a/picar-x/master-calibration/video.avi"
cap = cv.VideoCapture(arg)
if not cap.isOpened():
    raise SystemExit(f"Could not open reference video: {arg}")
# params for ShiTomasi corner detection
feature_params = dict(maxCorners=100,
                      qualityLevel=0.3,
                      minDistance=7,
                      blockSize=7)
# Parameters for lucas kanade optical flow
lk_params = dict(winSize=(15, 15),
                 maxLevel=2,
                 criteria=(cv.TERM_CRITERIA_EPS | cv.TERM_CRITERIA_COUNT, 10, 0.03))
# Create some random colors (only used by the optional debug drawing code)
color = np.random.randint(0, 255, (100, 3))
# Take first frame and find corners in it
ret, old_frame = cap.read()
if not ret:
    cap.release()
    raise SystemExit(f"Could not read the first frame of {arg}")
old_gray = cv.cvtColor(old_frame, cv.COLOR_BGR2GRAY)
p0 = cv.goodFeaturesToTrack(old_gray, mask=None, **feature_params)
if p0 is None:
    # goodFeaturesToTrack yields None when it finds no corners; there is
    # nothing to track, so stop here instead of failing inside the loop.
    cap.release()
    raise SystemExit("No trackable corners found in the first frame")
# Create a mask image for drawing purposes
mask = np.zeros_like(old_frame)

# Per-point tracks: lines[k] collects the (x, y) positions of one tracked
# corner across frames; the main loop fills it in.
lines = []
frame_center_x = old_frame.shape[1] // 2
frame_center_y = old_frame.shape[0] // 2

# track_ids[k] is the index into `lines` of the point held in good_new[k]
# (and in p0[k] on the next frame). calcOpticalFlowPyrLK drops points it
# loses (st == 0), so good_new shrinks over time and its positions no longer
# line up with `lines` by plain index; this list keeps them paired.
track_ids = []
# Minimum distance (pixels) between the frame centre and a fitted flow line
# before a steering correction is applied. Tuning value: the previous check,
# `distance > 0`, is effectively always true for floats, so the "already
# centred" branch could never run.
CENTER_TOLERANCE_PX = 5.0

while True:
    ret, frame = cap.read()
    if not ret:
        print('No frames grabbed!')
        break
    frame_gray = cv.cvtColor(frame, cv.COLOR_BGR2GRAY)
    # calculate optical flow
    p1, st, err = cv.calcOpticalFlowPyrLK(old_gray, frame_gray, p0, None, **lk_params)
    if p1 is None:
        # Nothing was tracked; good_new/good_old would be stale or undefined.
        print('Optical flow returned no points!')
        break
    # Select good points (st == 1 means the point was found in this frame)
    found = st.ravel() == 1
    good_new = p1[st == 1]
    good_old = p0[st == 1]
    if len(good_new) == 0:
        # Nothing left to track; stop instead of looping on an empty point set.
        print('All tracked points lost!')
        break

    # Debug: plot old (circles) vs new (squares) points to check visually
    # whether the wheels are calibrated to true 0.
    #for old_point, new_point in zip(good_old, good_new):
    #    old_x, old_y = old_point.ravel()
    #    new_x, new_y = new_point.ravel()
    #    plt.plot(old_x, old_y, marker='o', linestyle='-', color='r', label='Old Points(Circles)')
    #    plt.plot(new_x, new_y, marker='s', linestyle='-', color='g', label='New Points(Squares)')
    #plt.show()

    # Append each surviving point's new position to its own track.
    if not lines:
        # First frame pair: start one track per point found in both frames,
        # seeded with its position in the first frame.
        lines = [[point] for point in good_old]
        track_ids = list(range(len(lines)))
    else:
        # Keep only the tracks whose point survived this frame, in order.
        track_ids = [tid for tid, ok in zip(track_ids, found) if ok]
    for tid, point in zip(track_ids, good_new):
        lines[tid].append(point)

    # TODO: Take the linear regression lines and find a point whose "normal distance" (look it up, it's a calculus thing) is the average lowest for all values.
    # TODO: improve this code so that it runs better utilizing both camera and sensors.
    # Fit y = m*x + b to every track once per frame. (This previously sat
    # inside the per-point append loop and was repeated for every point.)
    regression_lines = []
    for line in lines:
        pts = np.asarray(line)
        x, y = pts[:, 0], pts[:, 1]
        # A fit needs at least two samples with distinct x; a constant-x
        # (vertical or stationary) track only gives a rank-deficient fit.
        if len(x) < 2 or x.min() == x.max():
            continue
        m, b = np.polyfit(x, y, 1)
        regression_lines.append((m, b))

        # Recalibration: perpendicular distance from the frame centre to the
        # fitted line y = m*x + b.
        distance = abs(m * frame_center_x - frame_center_y + b) / np.sqrt(m**2 + 1)

        if distance > CENTER_TOLERANCE_PX:
            # Steer slightly toward the side given by the slope sign, creep
            # forward briefly, then stop and settle.
            px.set_dir_servo_angle(2 if m > 0 else -2)
            sleep(0.5)
            px.forward(10)
            sleep(0.5)
            px.stop()
            sleep(0.5)
        else:
            px.stop()

    # TODO: From that point, adjust so the point is in the center of the camera screen

    # Optional visual reference for the optical-flow tracks; uncomment to see
    # it (it is wasted computation for the calibration itself).
    #for i, (new, old) in enumerate(zip(good_new, good_old)):
    #    a, b = new.ravel()
    #    c, d = old.ravel()
    #    mask = cv.line(mask, (int(a), int(b)), (int(c), int(d)), color[i].tolist(), 2)
    #    frame = cv.circle(frame, (int(a), int(b)), 5, color[i].tolist(), -1)
    #img = cv.add(frame, mask)
    #cv.imshow('frame', img)
    k = cv.waitKey(30) & 0xff
    if k == 27:  # Esc
        break
    # Now update the previous frame and previous points
    old_gray = frame_gray.copy()
    p0 = good_new.reshape(-1, 1, 2)

# Release the video file handle and close any OpenCV windows.
cap.release()
cv.destroyAllWindows()

# ADDED: The commented-out code below plots a single optical-flow track (currently lines[4]). It is unnecessary for the final project and only used for testing.
#line0 = [[x] for x in lines[4]]
#plt.figure(figsize = (len(line0),len(line0)))
#print(line0[0][0][0])
#x_values = [point[0][0] for point in line0]
#y_values = [point[0][1] for point in line0]
    
#plt.plot(x_values, y_values, marker = 'o', linestyle = '-', label =f'Set {i+1}')
#plt.show()





