# -*- coding: utf-8 -*-


"""
Simulation des lois de la réflexion et de la réfraction (Snell-Descartes)
avec prise en compte des coefficients de Fresnel pour l'intensité des rayons.
"""

# --- Importation des bibliothèques nécessaires ---
import numpy as np  # NumPy est utilisé pour les calculs mathématiques (sin, cos, pi...)
import matplotlib.pyplot as plt  # Matplotlib est la bibliothèque principale pour le tracé
from matplotlib.widgets import Slider  # Le module 'widgets' permet d'ajouter des éléments interactifs comme des sliders


I0 = 6 # "intensité du rayon incident ; épaisseur du trait
# --- Configuration initiale de la figure et des axes ---

# Crée une figure (la fenêtre) et un axe (la zone de dessin)
# figsize=(8, 8) pour avoir une fenêtre carrée
fig, ax = plt.subplots(figsize=(8, 8))

# Ajuste la position de la zone de dessin pour faire de la place aux sliders en bas
plt.subplots_adjust(bottom=0.35)

# Définit les limites des axes x et y pour centrer la scène
ax.set_xlim(-1, 1)
ax.set_ylim(-1, 1)

# Assure que les échelles sur les axes x et y sont les mêmes (un cercle apparaîtrait comme un cercle)
ax.set_aspect('equal', adjustable='box')

# Masque les graduations et les bords du graphique pour un rendu plus propre
ax.set_xticks([])
ax.set_yticks([])
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.spines['bottom'].set_visible(False)
ax.spines['left'].set_visible(False)


# --- Tracé des éléments statiques ---

# Trace le dioptre : une ligne horizontale noire qui sépare les deux milieux
# ax.axhline trace une ligne horizontale sur tout l'axe
ax.axhline(0, color='black', lw=2, label='Dioptre')

# Trace la normale au dioptre : une ligne verticale en pointillés pour référence des angles
# ax.axvline trace une ligne verticale
ax.axvline(0, color='gray', linestyle='--', lw=1, label='Normale')


# --- Initialisation des objets graphiques des rayons ---
# On crée les objets 'Line2D' une seule fois. Ensuite, on mettra à jour leurs données (coordonnées)
# pour l'animation, ce qui est plus performant que de recréer des objets à chaque fois.
# Le rayon incident vient du milieu 1 (en haut, y > 0) et frappe l'origine (0,0).
incident_ray, = ax.plot([], [], 'r-', lw=2, label='Rayon Incident')
# Le rayon réfléchi part de l'origine et remonte dans le milieu 1.
reflected_ray, = ax.plot([], [], 'b-', lw=2, label='Rayon Réfléchi')
# Le rayon réfracté (ou transmis) part de l'origine et descend dans le milieu 2 (en bas, y < 0).
refracted_ray, = ax.plot([], [], 'g-', lw=2, label='Rayon Réfracté')

# Ajout d'un texte pour afficher les informations (angles, indices) - Décalé et agrandi
info_text = ax.text(0.02, 0.98, '', transform=ax.transAxes, verticalalignment='top', fontsize=12, bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
# Ajout d'un texte spécifique pour la réflexion totale interne
tir_text = ax.text(0.5, -0.5, 'Réflexion Totale', horizontalalignment='center', color='red', fontsize=16, visible=False)

# Affiche les légendes définies lors de la création des lignes
ax.legend(loc='lower left')

# --- Variables pour stocker les artistes des angles (arcs et textes) ---
# On les stocke pour pouvoir les supprimer et les redessiner à chaque mise à jour.
angle_artists = []

# --- Fonction principale de mise à jour de l'animation ---
# Cette fonction est appelée à chaque fois qu'un slider est déplacé.
def update(val):
    """
    Recalcule et redessine la scène en fonction des nouvelles valeurs des sliders.
    """
    global angle_artists
    # Supprime les anciens dessins d'angles avant d'en créer de nouveaux
    for artist in angle_artists:
        artist.remove()
    angle_artists = []

    # Récupère les valeurs actuelles des sliders
    n1 = slider_n1.val
    n2 = slider_n2.val
    theta_i_deg = slider_theta_i.val  # Angle d'incidence en degrés
    
    # Conversion de l'angle d'incidence en radians pour les calculs avec NumPy
    theta_i_rad = np.deg2rad(theta_i_deg)

    # --- Calcul des angles selon les lois de Snell-Descartes ---

    # 1. Loi de la réflexion : l'angle de réflexion est égal à l'angle d'incidence.
    # Le rayon repart de l'autre côté de la normale, donc on prend l'opposé pour le sens de rotation.
    theta_refl_rad = -theta_i_rad
    
    # 2. Loi de la réfraction : n1 * sin(theta_i) = n2 * sin(theta_t)
    # On isole sin(theta_t) pour trouver l'angle de réfraction theta_t.
    # On utilise np.clip pour éviter les erreurs de calcul si la valeur est infimement supérieure à 1.
    arg_asin = np.clip((n1 / n2) * np.sin(theta_i_rad), -1.0, 1.0)
    
    # --- Gestion de la réflexion totale interne ---
    total_internal_reflection = (n1 > n2) and (np.abs((n1 / n2) * np.sin(theta_i_rad)) > 1.0)

    if total_internal_reflection:
        theta_refr_rad = None
        R = 1.0
        T = 0.0
        tir_text.set_visible(True)
    else:
        theta_refr_rad = np.arcsin(arg_asin)
        tir_text.set_visible(False)
        
        # --- Calcul des coefficients de Fresnel ---
        cos_i = np.cos(theta_i_rad)
        cos_t = np.cos(theta_refr_rad)
        Rs = ((n1 * cos_i - n2 * cos_t) / (n1 * cos_i + n2 * cos_t))**2
        Rp = ((n2 * cos_i - n1 * cos_t) / (n2 * cos_i + n1 * cos_t))**2
        R = (Rs + Rp) / 2.0
        T = 1.0 - R

    # --- Mise à jour des données des rayons pour le tracé ---
    incident_ray.set_data([-np.sin(theta_i_rad), 0], [np.cos(theta_i_rad), 0])
    incident_ray.set_linewidth(I0)
    
    reflected_ray.set_data([0, -np.sin(theta_refl_rad)], [0, np.cos(theta_refl_rad)])
    reflected_ray.set_linewidth(I0 * R)

    if theta_refr_rad is not None:
        refracted_ray.set_visible(True)
        refracted_ray.set_data([0, np.sin(theta_refr_rad)], [0, -np.cos(theta_refr_rad)])
        refracted_ray.set_linewidth(I0 * T)
    else:
        refracted_ray.set_visible(False)

    # --- NOUVEAU: Ajout des flèches sur les rayons ---
    # On utilise annotate pour dessiner une flèche au milieu de chaque rayon.
    # --- AMÉLIORATION: Ajout de flèches plus esthétiques sur les rayons ---
    
    # Flèche sur le rayon incident (pointe vers l'origine)
    arrow_start_i = (-0.55 * np.sin(theta_i_rad), 0.55 * np.cos(theta_i_rad))
    arrow_end_i = (-0.45 * np.sin(theta_i_rad), 0.45 * np.cos(theta_i_rad))
    # arrow_i = ax.annotate("", xy=arrow_end_i, xytext=arrow_start_i, 
    #                       arrowprops=dict(arrowstyle="->", color="red", lw=1, head_width=I0, head_length=I0*1.5))
    arrow_i = ax.annotate("", xy=arrow_end_i, xytext=arrow_start_i, 
                          arrowprops=dict(arrowstyle="->", color="red", lw=I0/2, mutation_scale=5+6*I0))
    angle_artists.append(arrow_i)

    # Flèche sur le rayon réfléchi (part de l'origine)
    ray_lw_r = I0 * R
    # if ray_lw_r > 1.0:  # Seuil pour une bonne visibilité
    x_refl_end = -np.sin(theta_refl_rad)
    y_refl_end = np.cos(theta_refl_rad)
    arrow_start_r = (0.45 * x_refl_end, 0.45 * y_refl_end)
    arrow_end_r = (0.55 * x_refl_end, 0.55 * y_refl_end)
    arrow_r = ax.annotate("", xy=arrow_end_r, xytext=arrow_start_r,
                      arrowprops=dict(arrowstyle="->", color="blue", lw=ray_lw_r/2, mutation_scale=5+6*ray_lw_r))
    # arrow_r = ax.annotate("", xy=arrow_end_r, xytext=arrow_start_r, 
    #                       arrowprops=dict(arrowstyle="->", color="blue", lw=1, head_width=ray_lw_r, head_length=ray_lw_r*1.5))
    angle_artists.append(arrow_r)

    # Flèche sur le rayon réfracté (part de l'origine)
    # if theta_refr_rad is not None:
    ray_lw_t = I0 * T
    if ray_lw_t > 1.0: # Seuil pour une bonne visibilité
        x_refr_end = np.sin(theta_refr_rad)
        y_refr_end = -np.cos(theta_refr_rad)
        arrow_start_t = (0.45 * x_refr_end, 0.45 * y_refr_end)
        arrow_end_t = (0.55 * x_refr_end, 0.55 * y_refr_end)
        # arrow_t = ax.annotate("", xy=arrow_end_t, xytext=arrow_start_t, 
        #                       arrowprops=dict(arrowstyle="->", color="green", lw=1, head_width=ray_lw_t, head_length=ray_lw_t*1.5))
        arrow_t = ax.annotate("", xy=arrow_end_t, xytext=arrow_start_t, 
                              arrowprops=dict(arrowstyle="->", color="green", lw=ray_lw_t/2, mutation_scale=5+6*ray_lw_t))
        angle_artists.append(arrow_t)
        
    # --- Dessin des arcs et des labels pour les angles ---
    arc_radius = 0.3
    text_radius = 0.25
    
    # Angle d'incidence
    if theta_i_deg != 0:
        arc_i = ax.annotate("",
            xy=(arc_radius * -np.sin(theta_i_rad), arc_radius * np.cos(theta_i_rad)),
            xytext=(0, arc_radius),
            arrowprops=dict(arrowstyle="->", color="red", connectionstyle=f"arc3,rad={0.2*np.sign(theta_i_rad)}"))
        text_i = ax.text(text_radius * -np.sin(theta_i_rad/2), text_radius * np.cos(theta_i_rad/2), r"$\theta_1$", color="red", fontsize=14, ha='center', va='center')
        angle_artists.extend([arc_i, text_i])

    # Angle de réflexion
    if theta_i_deg != 0:
        arc_r = ax.annotate("",
            xy=(arc_radius * np.sin(theta_i_rad), arc_radius * np.cos(theta_i_rad)),
            xytext=(0, arc_radius),
            arrowprops=dict(arrowstyle="->", color="blue", connectionstyle=f"arc3,rad={-0.2*np.sign(theta_i_rad)}"))
        text_r = ax.text(text_radius * np.sin(theta_i_rad/2), text_radius * np.cos(theta_i_rad/2), r"$\theta'_1$", color="blue", fontsize=14, ha='center', va='center')
        angle_artists.extend([arc_r, text_r])
        
    # Angle de réfraction
    if theta_refr_rad is not None and np.abs(np.rad2deg(theta_refr_rad)) > 0.1:
        arc_t = ax.annotate("",
            xytext=(arc_radius * np.sin(theta_refr_rad), -arc_radius * np.cos(theta_refr_rad)),
            xy=(0, -arc_radius),
            arrowprops=dict(arrowstyle="<-", color="green", connectionstyle=f"arc3,rad={-0.2*np.sign(theta_refr_rad)}"))
        text_t = ax.text(text_radius * np.sin(theta_refr_rad/2), -text_radius * np.cos(theta_refr_rad/2), r"$\theta_2$", color="green", fontsize=14, ha='center', va='center')
        angle_artists.extend([arc_t, text_t])

    # --- Mise à jour du texte d'information ---
    info_str = (
        f'Milieu 1: n₁ = {n1:.2f}\n'
        f'Milieu 2: n₂ = {n2:.2f}\n'
        f'Angle incident: θ₁ = {theta_i_deg:.1f}°\n'
    )
    if theta_refr_rad is not None:
        info_str += f'Angle réfracté: θ₂ = {np.rad2deg(theta_refr_rad):.1f}°\n'
        info_str += f'Intensité (Réf./Réfra.): {R*100:.1f}% / {T*100:.1f}%'
    else:
        info_str += 'Réflexion Totale'
        
    info_text.set_text(info_str)

    # Redessine la figure avec les nouvelles données
    fig.canvas.draw_idle()


# --- Création des sliders ---
ax_n1 = plt.axes([0.25, 0.20, 0.65, 0.03])
ax_n2 = plt.axes([0.25, 0.15, 0.65, 0.03])
ax_theta_i = plt.axes([0.25, 0.1, 0.65, 0.03])

slider_n1 = Slider(ax=ax_n1, label='Indice n₁ (haut)', valmin=1.0, valmax=3.0, valinit=1.0, valstep=0.01)
slider_n2 = Slider(ax=ax_n2, label='Indice n₂ (bas)', valmin=1.0, valmax=3.0, valinit=1.5, valstep=0.01)
slider_theta_i = Slider(ax=ax_theta_i, label='Angle Incident θ₁ (°)', valmin=-89.9, valmax=89.9, valinit=30, valstep=0.1)

# --- Connexion des sliders à la fonction de mise à jour ---
slider_n1.on_changed(update)
slider_n2.on_changed(update)
slider_theta_i.on_changed(update)

# --- Appel initial pour dessiner la première image ---
update(None)

# --- Affichage de la fenêtre ---
plt.show()

