import glob
import os
import pickle
import time

import cv2
import numpy as np
from picamera2 import Picamera2

class PiCarX_CameraCalibrator:
    """Chessboard-based camera calibration and horizon-roll estimation for the PiCar-X.

    Typical workflow: ``capture_calibration_images()`` -> ``calibrate_camera()``
    (which also pickles the result to ``calibration_file``) -> ``undistort_image()``.
    ``load_calibration()`` restores a previously saved calibration, and
    ``find_best_roll()`` estimates the tilt of straight, near-horizontal edges in
    a single frame.
    """

    def __init__(self, pattern_size=(6, 6), calibration_file='camera_calibration.pkl'):
        """Create a calibrator.

        Args:
            pattern_size: number of *inner* chessboard corners as (columns, rows);
                the default (6, 6) corresponds to a board of 7x7 squares.
            calibration_file: path of the pickle file used by ``_save_calibration``
                and ``load_calibration``.
        """
        self.pattern_size = pattern_size
        self.calibration_file = calibration_file
        self.camera_matrix = None
        self.dist_coeffs = None
        self.calibrated = False
        self.tilt_angles = []  # per-image (roll, pitch, yaw) in degrees
        self.avg_roll = None  # mean roll (degrees) from the last calibrate_camera() run

        # Board-frame corner coordinates on the z=0 plane, in units of one square,
        # with x varying fastest: (0,0,0), (1,0,0), ..., (cols-1, rows-1, 0).
        self.objp = np.zeros((pattern_size[0] * pattern_size[1], 3), np.float32)
        self.objp[:, :2] = np.mgrid[0:pattern_size[0], 0:pattern_size[1]].T.reshape(-1, 2)

    def capture_calibration_images(self, num_images=20, save_path='calibration_images/'):
        """Interactively capture ``num_images`` 640x480 frames and save them as JPEGs.

        The user presses Enter before each capture. Files are written as
        ``calib_NN.jpg`` inside ``save_path``, which is created if missing. The
        camera is always closed, even if capturing is interrupted.
        """
        if save_path:
            os.makedirs(save_path, exist_ok=True)

        picam2 = Picamera2()
        try:
            config = picam2.create_still_configuration(main={"size": (640, 480)})
            picam2.configure(config)
            picam2.start()
            time.sleep(1)  # short warm-up before the first frame (presumably for auto-exposure)

            print(f"Capturing {num_images} calibration images...")
            for i in range(num_images):
                input(f"Press Enter to capture image {i+1}/{num_images}...")
                image = picam2.capture_array()
                filename = os.path.join(save_path, f"calib_{i:02d}.jpg")
                # cv2.imwrite reports failure via its return value, not an exception.
                if not cv2.imwrite(filename, image):
                    print(f"Failed to save {filename}")
                    continue
                print(f"Saved {filename}")
        finally:
            picam2.close()
        cv2.destroyAllWindows()

    def _collect_chessboard_points(self, image_files):
        """Detect and sub-pixel-refine chessboard corners in each image file.

        Each successful detection is displayed for 500 ms. Unreadable files, and
        files whose size differs from the first image in which a board was found,
        are skipped with a message.

        Returns:
            (obj_points, img_points, image_size) where image_size is (width, height)
            of the images used, or None when no board was detected.
        """
        # Sub-pixel refinement stops after 30 iterations or once corners move < 0.001 px.
        criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 30, 0.001)
        obj_points = []
        img_points = []
        image_size = None

        for fname in image_files:
            img = cv2.imread(fname)
            if img is None:  # cv2.imread returns None instead of raising
                print(f"Skipping {fname}: could not be read")
                continue
            gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
            size = gray.shape[::-1]  # (width, height), as calibrateCamera expects
            if image_size is not None and size != image_size:
                print(f"Skipping {fname}: size {size} differs from {image_size}")
                continue

            found, corners = cv2.findChessboardCorners(gray, self.pattern_size, None)
            if not found:
                continue

            image_size = size
            corners_refined = cv2.cornerSubPix(gray, corners, (11, 11), (-1, -1), criteria)
            obj_points.append(self.objp)
            img_points.append(corners_refined)

            cv2.drawChessboardCorners(img, self.pattern_size, corners_refined, found)
            cv2.imshow('Chessboard Corners', img)
            cv2.waitKey(500)

        cv2.destroyAllWindows()
        return obj_points, img_points, image_size

    def calibrate_camera(self, image_paths='calibration_images/*.jpg'):
        """Calibrate from the chessboard images matching the glob ``image_paths``.

        Sets ``camera_matrix``, ``dist_coeffs``, ``tilt_angles``, ``avg_roll`` and
        ``calibrated``; prints per-image Euler angles and the RMS reprojection
        error; then pickles the result to ``calibration_file``.

        Raises:
            FileNotFoundError: if the glob matches no files.
            ValueError: if no chessboard is detected in any image.
        """
        # Sorted so the "Image N" numbering below is reproducible between runs.
        images = sorted(glob.glob(image_paths))
        if not images:
            raise FileNotFoundError(f"No images found at {image_paths}")

        obj_points, img_points, image_size = self._collect_chessboard_points(images)
        if not obj_points:
            raise ValueError("No chessboard patterns found in images!")

        _, self.camera_matrix, self.dist_coeffs, rvecs, tvecs = cv2.calibrateCamera(
            obj_points, img_points, image_size, None, None)

        # NOTE(review): each rvec is the board's rotation relative to the camera for
        # one image. Calibration shots normally vary the board pose on purpose, so
        # these angles (and their averages) only describe the camera's own tilt if
        # the board was held level. The sign remarks printed below are the original
        # author's interpretation; confirm them.
        self.tilt_angles = [self._rotation_vector_to_euler_angles(rvec) for rvec in rvecs]

        #Roll -> Tilt, Pitch -> pan, Yaw -> Horizon
        print("Camera Tilt Angles (Roll, Pitch, Yaw in degrees):")
        for idx, (roll, pitch, yaw) in enumerate(self.tilt_angles, start=1):
            print(f"Image {idx}: Roll={roll:.2f}, Pitch={pitch:.2f}, Yaw={yaw:.2f}")

        self.avg_roll = np.mean([angle[0] for angle in self.tilt_angles])
        print(f"\nAverage roll angle: {self.avg_roll:.2f}° (positive = rotated clockwise)")

        avg_pitch = np.mean([angle[1] for angle in self.tilt_angles])
        print(f"\nAverage pitch angle: {avg_pitch:.2f}°")
        print("(Negative value means camera is pointing slightly downward)")
        print("You may want to manually adjust the camera tilt if needed")

        mean_error = self._calculate_reprojection_error(obj_points, img_points, rvecs, tvecs)
        print(f"\nCalibration complete! Reprojection error: {mean_error:.5f} pixels")

        self.calibrated = True
        self._save_calibration()

    @staticmethod
    def _rotation_vector_to_euler_angles(rvec):
        """Convert a Rodrigues rotation vector to ``(roll, pitch, yaw)`` in degrees.

        The rotation matrix is decomposed as ``R = Rz(yaw) @ Ry(pitch) @ Rx(roll)``,
        so roll, pitch and yaw are rotations about the x, y and z axes. At gimbal
        lock (pitch = +/-90 degrees) yaw is reported as 0.
        """
        rvec = np.asarray(rvec, dtype=np.float64).reshape(3)
        angle = float(np.linalg.norm(rvec))
        if angle < 1e-12:
            rotation = np.eye(3)
        else:
            kx, ky, kz = rvec / angle
            k_cross = np.array([[0.0, -kz, ky],
                                [kz, 0.0, -kx],
                                [-ky, kx, 0.0]])
            # Rodrigues' formula: R = I + sin(a) K + (1 - cos(a)) K^2.
            rotation = (np.eye(3) + np.sin(angle) * k_cross
                        + (1.0 - np.cos(angle)) * (k_cross @ k_cross))

        sy = float(np.hypot(rotation[0, 0], rotation[1, 0]))
        if sy > 1e-6:
            roll = np.arctan2(rotation[2, 1], rotation[2, 2])
            pitch = np.arctan2(-rotation[2, 0], sy)
            yaw = np.arctan2(rotation[1, 0], rotation[0, 0])
        else:
            # Gimbal lock: roll and yaw are not separable; fold everything into roll.
            roll = np.arctan2(-rotation[1, 2], rotation[1, 1])
            pitch = np.arctan2(-rotation[2, 0], sy)
            yaw = 0.0
        return (float(np.degrees(roll)), float(np.degrees(pitch)), float(np.degrees(yaw)))

    def _calculate_reprojection_error(self, obj_points, img_points, rvecs, tvecs):
        """Return the RMS reprojection error in pixels over every detected corner.

        Each board's 3-D corners are projected with the current calibration and
        compared with the refined image corners.
        """
        squared_sum = 0.0
        n_points = 0
        for objp, imgp, rvec, tvec in zip(obj_points, img_points, rvecs, tvecs):
            projected, _ = cv2.projectPoints(objp, rvec, tvec,
                                             self.camera_matrix, self.dist_coeffs)
            diff = imgp.reshape(-1, 2) - projected.reshape(-1, 2)
            squared_sum += float(np.sum(diff ** 2))
            n_points += diff.shape[0]
        return float(np.sqrt(squared_sum / n_points)) if n_points else 0.0

    def _save_calibration(self):
        """Pickle camera matrix, distortion coefficients and tilt angles to ``calibration_file``."""
        calibration_data = {
            'camera_matrix': self.camera_matrix,
            'dist_coeffs': self.dist_coeffs,
            'tilt_angles': self.tilt_angles,
        }
        with open(self.calibration_file, 'wb') as f:
            pickle.dump(calibration_data, f)
        print(f"\nCalibration data saved to {self.calibration_file}")

    def load_calibration(self):
        """Load a calibration previously written by ``_save_calibration``.

        Raises whatever ``open``/``pickle.load`` raise (e.g. FileNotFoundError),
        or KeyError if the file lacks camera_matrix or dist_coeffs.
        """
        # SECURITY: pickle.load can execute arbitrary code; only load calibration
        # files you created yourself or otherwise trust.
        with open(self.calibration_file, 'rb') as f:
            calibration_data = pickle.load(f)
        self.camera_matrix = calibration_data['camera_matrix']
        self.dist_coeffs = calibration_data['dist_coeffs']
        self.tilt_angles = calibration_data.get('tilt_angles', [])
        self.calibrated = True
        print("Calibration data loaded successfully")

    def undistort_image(self, image):
        """Return ``image`` with lens distortion removed, cropped to the valid region.

        ``alpha=0`` in ``getOptimalNewCameraMatrix`` keeps only valid pixels. The
        result is cropped to the returned ROI; if that ROI is empty, the uncropped
        undistorted frame is returned instead of an empty array.

        Raises:
            RuntimeError: if no calibration has been computed or loaded.
        """
        if not self.calibrated:
            raise RuntimeError("Camera not calibrated! Call calibrate_camera() first")

        height, width = image.shape[:2]
        new_camera_mtx, roi = cv2.getOptimalNewCameraMatrix(
            self.camera_matrix, self.dist_coeffs, (width, height), 0, (width, height))
        undistorted = cv2.undistort(image, self.camera_matrix, self.dist_coeffs,
                                    None, new_camera_mtx)

        x, y, roi_w, roi_h = roi
        if roi_w <= 0 or roi_h <= 0:
            return undistorted
        return undistorted[y:y + roi_h, x:x + roi_w]

    @staticmethod
    def _near_horizontal_tilts(thetas, angle_range=(-15, 15)):
        """Convert Hough normal angles to line tilts and keep the near-horizontal ones.

        ``cv2.HoughLines`` describes each line by the angle ``theta`` (radians,
        in [0, pi)) of its normal in image coordinates, with the y axis pointing
        down, so a horizontal line has ``theta == pi/2``. The tilt computed here
        is ``90 - degrees(theta)``: the line's angle from horizontal, positive
        when it rises to the right on screen (counter-clockwise).

        Returns:
            1-D float array of the tilts (degrees) that lie within ``angle_range``
            (inclusive).
        """
        tilts = 90.0 - np.degrees(np.asarray(thetas, dtype=np.float64).ravel())
        low, high = sorted(angle_range)
        return tilts[(tilts >= low) & (tilts <= high)]

    def find_best_roll(self, image, angle_range=(-15, 15), step=0.5):
        """Estimate the roll of ``image`` from its straight, near-horizontal edges.

        Args:
            image: BGR (or already single-channel) uint8 image.
            angle_range: (min, max) tilt in degrees; lines outside it (e.g. vertical
                edges) are ignored.
            step: angular resolution of the Hough transform, in degrees (> 0).

        Returns:
            Mean counter-clockwise tilt in degrees of the accepted lines, or 0.0
            when no suitable line is found. Passing ``-result`` as the angle to
            ``cv2.getRotationMatrix2D`` (positive = counter-clockwise) levels the
            image.

        Raises:
            ValueError: if ``step`` is not positive.
        """
        if step <= 0:
            raise ValueError("step must be positive")

        gray = image if image.ndim == 2 else cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        edges = cv2.Canny(gray, 50, 150, apertureSize=3)
        lines = cv2.HoughLines(edges, 1, np.deg2rad(step), threshold=100)

        if lines is None:
            print("No lines detected!")
            return 0.0

        # lines has shape (N, 1, k) with theta at index 1 of the last axis.
        tilts = self._near_horizontal_tilts(lines[..., 1], angle_range)
        if tilts.size == 0:
            print("No valid near-horizontal lines found!")
            return 0.0

        return float(np.mean(tilts))

              
if __name__ == "__main__":
    calibrator = PiCarX_CameraCalibrator()

    # Step 1: Uncomment to capture new calibration images
    #calibrator.capture_calibration_images(num_images=15)

    # Step 2: Uncomment to calibrate the camera (this also saves the calibration file)
    # calibrator.calibrate_camera()

    # Step 3: Capture one test frame, then estimate and correct its horizon roll.
    # The camera is closed in `finally` so it is released even if capture fails.
    picam2 = Picamera2()
    try:
        config = picam2.create_still_configuration(main={"size": (640, 480)})
        picam2.configure(config)
        picam2.start()
        # NOTE(review): long warm-up, presumably to let auto-exposure settle; tune if needed.
        time.sleep(10)
        image = picam2.capture_array()
    finally:
        picam2.close()

    # Validate the frame before any OpenCV call touches it.
    print("Checking captured image...")
    if image is None:
        raise ValueError("Image is None — failed to capture anything.")
    if not isinstance(image, np.ndarray):
        raise ValueError(f"Image is not a numpy array: {type(image)}")
    print("Image shape:", image.shape)
    print("Image dtype:", image.dtype)
    print("Image type:", type(image))

    # Expect a 3-channel colour frame; grayscale or 4-channel output is rejected.
    if image.ndim != 3 or image.shape[2] != 3:
        raise ValueError(f"Bad image shape: {image.shape}")

    if image.dtype != np.uint8:
        print(f"Fixing image dtype {image.dtype} to uint8")
        image = image.astype(np.uint8)

    print("Image is valid, proceeding to process.")
    cv2.imshow("Original Image", image)

    # Estimate the roll of near-horizontal edges in the frame.
    best_roll_angle = calibrator.find_best_roll(image)
    print(f"Best roll correction angle: {best_roll_angle:.2f} degrees")

    # Rotate about the image centre to level it. cv2.getRotationMatrix2D treats
    # positive angles as counter-clockwise, so negating the estimate rotates the
    # image clockwise by that amount.
    h, w = image.shape[:2]
    center = (w // 2, h // 2)
    M = cv2.getRotationMatrix2D(center, angle=-best_roll_angle, scale=1.0)
    leveled_image = cv2.warpAffine(image, M, (w, h), flags=cv2.INTER_LINEAR,
                                   borderMode=cv2.BORDER_REPLICATE)

    # By running this code multiple times at several distances we found that the
    # horizon tilt in our camera is -2.5
    print("the tilt in our camera is -2.5")

    # Show the result until a key is pressed.
    cv2.imshow("Leveled Image", leveled_image)
    cv2.waitKey(0)
    cv2.destroyAllWindows()
