# omr_camera_preprocess.py
"""
Camera-photo preprocessing for OMR answer sheets.

Drop this file next to omr_grader.py.  It is imported automatically when
omr_grader.py is run – no CLI changes needed.

The module adds five corrections that turn a hand-held camera photo of an
answer sheet into something the existing contour-based OMR engine can read:

  1. Illumination normalisation  – CLAHE on the L* channel evens out flash
                                   hotspots and edge shadows.
  2. Page segmentation           – the outline of the largest quadrilateral
                                   region (the answer sheet itself) is found.
  3. Perspective de-warp         – a four-corner homography straightens the
                                   trapezoidal distortion from an angled shot.
  4. Fine deskew                 – Hough-line angle correction removes residual
                                   rotation after de-warping.
  5. Unsharp-mask sharpening     – enhances fine detail (pencil marks, cell
                                   borders) softened by camera optics and JPEG.

None of these steps are applied to images that look like flat scans (bright,
uniform corner regions and a portrait A4/Letter aspect ratio), so the existing
scanning behaviour is preserved for them.

Public API
----------
maybe_deskew_camera_photo(img, max_skew_deg=30.0, debug=False) -> np.ndarray

    img           : BGR numpy array produced by pdf_to_images() or cv2.imread()
    max_skew_deg  : rotations beyond this magnitude are ignored (safety guard)
    debug         : write /tmp/_omr_camera_corrected.png for visual inspection

    Returns a corrected BGR array ready to pass into process_image().
"""

from __future__ import annotations

import math
import statistics
from typing import Optional

import cv2
import numpy as np


# ===========================================================================
#  Public entry point
# ===========================================================================

def maybe_deskew_camera_photo(img: np.ndarray,
                               max_skew_deg: float = 30.0,
                               debug: bool = False) -> np.ndarray:
    """
    Run the full camera-photo correction pipeline on *img*.

    ``None`` / empty arrays, and any image that *_looks_like_flat_scan*
    classifies as a flatbed scan, are returned as-is, so the function can be
    applied to every page of a mixed batch.  Everything else goes through:
    illumination normalisation, page extraction + de-warp (skipped when no
    page quad is found), Hough-line deskew limited to *max_skew_deg*, and
    unsharp-mask sharpening.

    When *debug* is true the corrected image is also written to
    /tmp/_omr_camera_corrected.png.
    """
    # Short-circuit order matters: the scan heuristic must not see None/empty.
    if img is None or img.size == 0 or _looks_like_flat_scan(img):
        return img

    corrected = _normalise_illumination(img)
    page = _extract_and_dewarp_page(corrected)
    # Fall back to the whole (normalised) frame when no page was found.
    corrected = _deskew(corrected if page is None else page, max_skew_deg)
    corrected = _unsharp_mask(corrected)

    if debug:
        cv2.imwrite("/tmp/_omr_camera_corrected.png", corrected)
    return corrected


# ===========================================================================
#  Step 0 – flat-scan heuristic
# ===========================================================================

def _looks_like_flat_scan(img: np.ndarray) -> bool:
    """
    Decide whether *img* looks like a clean flatbed scan.

    All three conditions must hold:
      • the four corner patches (each 1/10 of the height × 1/10 of the width)
        have an average mean grey level above 205;
      • their average standard deviation is below 20;
      • the aspect ratio H / W lies in [1.25, 1.60] — a portrait band that
        contains A4 (≈1.414) and US Letter (≈1.294).

    A camera photo of paper on a desk usually breaks at least one of these
    (dark desk in the corners, shadows, an arbitrary crop).
    """
    grey = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    height, width = grey.shape
    patch_h = max(1, height // 10)
    patch_w = max(1, width // 10)

    # Order: top-left, top-right, bottom-left, bottom-right.
    patches = [grey[rows, cols]
               for rows in (slice(None, patch_h), slice(-patch_h, None))
               for cols in (slice(None, patch_w), slice(-patch_w, None))]

    avg_brightness = float(np.mean([p.mean() for p in patches]))
    avg_spread = float(np.mean([p.std() for p in patches]))
    aspect = height / width

    if avg_brightness <= 205:      # corners not white enough
        return False
    if avg_spread >= 20:           # corners not uniform enough
        return False
    return 1.25 <= aspect <= 1.60  # portrait A4 / Letter band


# ===========================================================================
#  Step 1 – illumination normalisation
# ===========================================================================

def _normalise_illumination(img: np.ndarray,
                             clip_limit: float = 2.0,
                             tile_size: int = 8) -> np.ndarray:
    """
    Equalise local contrast by applying CLAHE to the lightness channel only.

    The BGR image is converted to CIE L*a*b*.  CLAHE, with clip limit
    *clip_limit* and a *tile_size* × *tile_size* tile grid, is applied to L*.
    The image is then converted back to BGR.  The a*/b* channels are passed
    through unchanged.
    """
    equaliser = cv2.createCLAHE(clipLimit=clip_limit,
                                tileGridSize=(tile_size, tile_size))
    lab_channels = list(cv2.split(cv2.cvtColor(img, cv2.COLOR_BGR2LAB)))
    lab_channels[0] = equaliser.apply(lab_channels[0])   # L* only
    return cv2.cvtColor(cv2.merge(lab_channels), cv2.COLOR_LAB2BGR)


# ===========================================================================
#  Steps 2 + 3 – page extraction and perspective de-warp
# ===========================================================================

def _extract_and_dewarp_page(img: np.ndarray) -> Optional[np.ndarray]:
    """
    Find the answer-sheet outline in a camera photo and warp it to a flat,
    front-facing view.

    Strategy
    --------
    1. Convert to grey and blur strongly, so the fine internal grid lines do
       not fragment the outline.
    2. Adaptive threshold, then invert.  Pixels noticeably darker than their
       local neighbourhood (edges, grid lines, print, typically the darker
       side of the page border) become white.  Flat regions become black.
    3. Morphological close to merge those features into solid regions.
    4. Take the largest external contour.  It must cover > 15 % of the image.
    5. Approximate it to a quadrilateral.  Try approxPolyDP directly, then on
       the convex hull, then fall back to the bounding rectangle.
    6. Order the corners TL, TR, BR, BL.  Reject the quad if it is not convex,
       or if its width/height is < 35 % of the image width/height.
    7. Warp the quad to an upright rectangle of the quad's own width × height
       (longest side capped at 3000 px).

    Returns
    -------
    The de-warped BGR image, or None when no convincing quad is found.
    """
    H, W = img.shape[:2]
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

    # Blur strongly so that the fine grid lines don't fragment the outline.
    blur = cv2.GaussianBlur(gray, (9, 9), 0)

    # Adaptive threshold copes with uneven lighting better than global Otsu.
    # With THRESH_BINARY, pixels brighter than (local mean - C) become 255.
    # After the inversion below, only locally-dark detail is white; flat
    # paper/desk areas are black.
    binary = cv2.adaptiveThreshold(
        blur, 255,
        cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
        cv2.THRESH_BINARY,
        blockSize=51,
        C=10,
    )
    binary = cv2.bitwise_not(binary)

    # Close gaps so the sheet's detail merges into one solid closed region.
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (11, 11))
    closed = cv2.morphologyEx(binary, cv2.MORPH_CLOSE, kernel, iterations=3)

    cnts, _ = cv2.findContours(closed, cv2.RETR_EXTERNAL,
                               cv2.CHAIN_APPROX_SIMPLE)
    if not cnts:
        return None

    # Only the largest contour is a candidate; it must cover > 15 % of frame.
    page_cnt = max(cnts, key=cv2.contourArea)
    if cv2.contourArea(page_cnt) <= W * H * 0.15:
        return None

    peri = cv2.arcLength(page_cnt, True)
    approx = cv2.approxPolyDP(page_cnt, 0.02 * peri, True)

    # Accept a 4-point poly; fall back to convex-hull approximation otherwise
    if len(approx) == 4:
        quad = approx.reshape(4, 2).astype(np.float32)
    else:
        hull = cv2.convexHull(page_cnt)
        hull_approx = cv2.approxPolyDP(
            hull, 0.02 * cv2.arcLength(hull, True), True)
        if len(hull_approx) == 4:
            quad = hull_approx.reshape(4, 2).astype(np.float32)
        else:
            # Last resort: bounding rectangle
            rx, ry, rw, rh = cv2.boundingRect(page_cnt)
            quad = np.array(
                [[rx, ry], [rx + rw, ry],
                 [rx + rw, ry + rh], [rx, ry + rh]],
                dtype=np.float32,
            )

    # Order corners TL, TR, BR, BL *before* measuring.  Otherwise "width"
    # and "height" depend on where approxPolyDP started the polygon, and the
    # size check below would compare the wrong sides against W and H.
    quad = _order_quad(quad)

    # A bad corner ordering (a point used twice) or a concave approximation
    # gives a degenerate / self-intersecting polygon; warping it is useless.
    if not cv2.isContourConvex(quad.reshape(-1, 1, 2)):
        return None

    quad_w = max(np.linalg.norm(quad[1] - quad[0]),
                 np.linalg.norm(quad[2] - quad[3]))
    quad_h = max(np.linalg.norm(quad[2] - quad[1]),
                 np.linalg.norm(quad[3] - quad[0]))

    # Sanity check: the sheet must span at least 35 % of the frame each way.
    if quad_w < W * 0.35 or quad_h < H * 0.35:
        return None

    # Output size: real width × height of the detected quad, capped at 3000 px
    dst_w = int(quad_w)
    dst_h = int(quad_h)
    scale = min(1.0, 3000.0 / max(dst_w, dst_h, 1))
    dst_w = max(1, int(dst_w * scale))
    dst_h = max(1, int(dst_h * scale))

    dst_pts = np.array(
        [[0, 0], [dst_w - 1, 0],
         [dst_w - 1, dst_h - 1], [0, dst_h - 1]],
        dtype=np.float32,
    )

    M = cv2.getPerspectiveTransform(quad, dst_pts)
    return cv2.warpPerspective(
        img, M, (dst_w, dst_h),
        flags=cv2.INTER_LANCZOS4,
        borderMode=cv2.BORDER_REPLICATE,
    )


def _order_quad(pts: np.ndarray) -> np.ndarray:
    """
    Reorder four 2-D points to (top-left, top-right, bottom-right, bottom-left).

    Primary rule: TL = min(x+y), BR = max(x+y), TR = min(y−x), BL = max(y−x).
    For a quad rotated by roughly 45°, two of those extremes can land on the
    same point (ties).  That would duplicate one corner and drop another,
    yielding a degenerate homography.  In that case the points are instead
    sorted by angle around their centroid, which is clockwise on screen
    because y points down.  The cycle is then rotated to start at the
    min(x+y) point.

    Parameters
    ----------
    pts : array-like of four (x, y) points, shape (4, 2) or (4, 1, 2).

    Returns
    -------
    A new float32 array of shape (4, 2); the input is not modified.
    """
    pts = np.asarray(pts, dtype=np.float32).reshape(4, 2)
    s = pts.sum(axis=1)                   # x + y
    diff = np.diff(pts, axis=1).ravel()   # y - x
    order = [int(np.argmin(s)),       # top-left
             int(np.argmin(diff)),    # top-right
             int(np.argmax(s)),       # bottom-right
             int(np.argmax(diff))]    # bottom-left
    if len(set(order)) == 4:
        return pts[order]

    # Fallback: angular sort around the centroid.  With y pointing down,
    # ascending atan2 visits TL (-135°), TR (-45°), BR (45°), BL (135°).
    centre = pts.mean(axis=0)
    angles = np.arctan2(pts[:, 1] - centre[1], pts[:, 0] - centre[0])
    ring = pts[np.argsort(angles, kind="stable")]
    start = int(np.argmin(ring.sum(axis=1)))
    return np.roll(ring, -start, axis=0)


# ===========================================================================
#  Step 4 – fine deskew via Hough lines
# ===========================================================================

def _deskew(img: np.ndarray, max_skew_deg: float = 15.0) -> np.ndarray:
    """
    Measure residual rotation from long Hough line segments and
    counter-rotate the image.  The output canvas is enlarged so that no
    content is clipped; new border pixels replicate the image edge.

    Each segment votes with its deviation from horizontal or, for
    near-vertical segments, from vertical.  The segment direction is first
    folded into (-90°, 90°], because HoughLinesP does not guarantee endpoint
    order.  A right-to-left segment therefore votes the same as a
    left-to-right one.  Deviations beyond ±max_skew_deg are discarded, which
    prevents wild rotations when line detection goes wrong.  The correction
    is the median of the remaining votes.  Values of max_skew_deg >= 45°
    make the horizontal/vertical split ambiguous.

    Returns the input unchanged when fewer than four segments are found,
    when no segment passes the filter, or when the median is below 0.3°.
    """
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    H, W = gray.shape
    edges = cv2.Canny(gray, 50, 150, apertureSize=3)

    lines = cv2.HoughLinesP(
        edges,
        rho=1,
        theta=math.pi / 180,
        threshold=max(50, W // 15),
        minLineLength=W // 6,
        maxLineGap=W // 30,
    )
    if lines is None or len(lines) < 4:
        return img

    angles: list[float] = []
    for x1, y1, x2, y2 in lines[:, 0]:
        # atan2 is well-defined for dx == 0; perfectly vertical segments are
        # valid zero-skew evidence and must not be skipped.
        ang = math.degrees(math.atan2(y2 - y1, x2 - x1))
        # Fold into (-90, 90] so endpoint order does not matter.
        if ang > 90:
            ang -= 180
        elif ang <= -90:
            ang += 180

        if abs(ang) <= max_skew_deg:
            angles.append(ang)                        # near-horizontal
        elif abs(abs(ang) - 90) <= max_skew_deg:
            # Near-vertical line → its skew is the deviation from ±90°
            angles.append(ang - 90 if ang > 0 else ang + 90)

    if not angles:
        return img

    angle = statistics.median(angles)
    if abs(angle) < 0.3:
        return img          # negligible: skip rotation

    # Positive angle = counter-clockwise in OpenCV's top-left-origin frame,
    # which cancels a clockwise (downward-sloping) skew.
    cx, cy = W / 2.0, H / 2.0
    M = cv2.getRotationMatrix2D((cx, cy), angle, 1.0)

    # Expand output canvas so corners are not clipped
    cos_a = abs(M[0, 0])
    sin_a = abs(M[0, 1])
    new_W = int(H * sin_a + W * cos_a)
    new_H = int(H * cos_a + W * sin_a)
    M[0, 2] += (new_W - W) / 2.0
    M[1, 2] += (new_H - H) / 2.0

    return cv2.warpAffine(
        img, M, (new_W, new_H),
        flags=cv2.INTER_LANCZOS4,
        borderMode=cv2.BORDER_REPLICATE,
    )


# ===========================================================================
#  Step 5 – unsharp-mask sharpening
# ===========================================================================

def _unsharp_mask(img: np.ndarray,
                  sigma: float = 1.0,
                  strength: float = 1.5) -> np.ndarray:
    """
    Sharpen *img* with a classic unsharp mask.

    Computes ``(1 + strength) * img - strength * blur(img)``, where ``blur``
    is a Gaussian of standard deviation *sigma* (kernel size derived from
    sigma).  The result is clipped to [0, 255] and returned as uint8.

    strength=1.5 is conservative; increase to 2.0–2.5 for very blurry shots.
    """
    low_pass = cv2.GaussianBlur(img, (0, 0), sigma)
    keep_weight = 1.0 + strength      # weight of the original image
    subtract_weight = -strength       # weight of the blurred copy
    boosted = cv2.addWeighted(img, keep_weight, low_pass, subtract_weight, 0)
    return boosted.clip(0, 255).astype(np.uint8)
