# interactive_3d_lines.py
# 绘制 3D 折线（多条折线），用于展示随考试次数/时间变动的成绩
# 支持两种数据格式：
# 1) 宽表（每一行是一次考试，列为 Exam 及各科目，如 Chinese, Math, English），用于 mode="subject"
# 2) 长表（每行包含 Exam, Student, Subject, Score），用于 mode="student"；不会自动 pivot，需事先整理成该格式

import os
import pathlib
import tempfile
import webbrowser

import numpy as np
import pandas as pd
import plotly.graph_objects as go
import plotly.io as pio

# ---------- Configuration ----------
csv_path = None  # Path to a CSV file, e.g. "scores_wide.csv" or "scores_long.csv"; None uses the sample data below
# Plot dimension: "subject" draws one line per subject (expects a wide table);
# "student" draws one line per (student, subject) pair (expects a long table
# with columns Exam, Student, Subject, Score).
mode = "subject"  # one of "subject" or "student"

# Fail fast on a typo: any other value would otherwise silently fall through
# to the "student" branch further down.
if mode not in ("subject", "student"):
    raise ValueError(f"mode 必须是 'subject' 或 'student'，当前为: {mode!r}")

# ---------- Sample data ----------
if csv_path is None:
    # Sample A: several subjects across exams (wide table, one row per exam).
    exams = ["Exam 1", "Exam 2", "Exam 3", "Exam 4"]
    df_wide = pd.DataFrame({
        "Exam": exams,
        "Chinese": [78, 82, 85, 88],
        "Math":    [90, 87, 92, 95],
        "English": [74, 79, 80, 83],
    })
    # Sample B: several students across exams (one row per exam/student pair,
    # one column per subject).
    df_long = pd.DataFrame({
        "Exam": np.repeat(exams, 3),
        "Student": ["Alice", "Bob", "Charlie"] * len(exams),
        "Chinese": [78, 75, 80,  82, 79, 85,  85, 83, 82,  88, 86, 90],
        "Math":    [90, 88, 85,  87, 86, 84,  92, 90, 88,  95, 94, 92],
        "English": [74, 70, 78,  79, 76, 80,  80, 78, 77,  83, 81, 85],
    })
    # "student" mode needs a true long table: Exam, Student, Subject, Score.
    # melt() unpivots every subject column (anything other than Exam/Student),
    # so adding a subject column to the sample needs no further changes.
    subject_columns = [c for c in df_long.columns if c not in ("Exam", "Student")]
    df_long_formatted = df_long.melt(
        id_vars=["Exam", "Student"],
        value_vars=subject_columns,
        var_name="Subject",
        value_name="Score",
    )
    # Pick the sample table that matches the selected mode.
    df = df_wide.copy() if mode == "subject" else df_long_formatted.copy()
else:
    df = pd.read_csv(csv_path)

# ---------- Data preparation ----------
# Map each exam to an x coordinate 0..N-1. Exams keep the order in which they
# first appear in the data (they are NOT sorted alphabetically).
if "Exam" not in df.columns:
    raise ValueError("数据必须包含 'Exam' 列（考试名称/时间）。")

# unique() preserves first-seen order; missing exam labels are dropped because
# they cannot be placed on the x axis.
exam_order = df["Exam"].dropna().unique().tolist()
exam_to_x = {exam: x for x, exam in enumerate(exam_order)}
x_ticks = list(range(len(exam_order)))

# ---------- Build traces ----------
traces = []
# Colour palette, cycled with modulo when there are more lines than colours.
colors = [
    "#636EFA", "#EF553B", "#00CC96", "#AB63FA", "#FFA15A", "#19D3F3",
    "#FF6692", "#B6E880", "#FF97FF", "#FECB52"
]

if mode == "subject":
    # Wide table: one row per exam; every column except 'Exam' is a subject.
    subjects = [c for c in df.columns if c != "Exam"]
    # reindex() below requires unique labels; report duplicates clearly instead
    # of letting pandas raise a generic "duplicate labels" error.
    if df["Exam"].duplicated().any():
        raise ValueError("宽表中每次考试（'Exam'）只能出现一次。")
    # Align rows with exam_order; an exam without a row becomes NaN (a gap).
    df_sorted = df.set_index("Exam").reindex(exam_order).reset_index()
    for idx, subj in enumerate(subjects):
        # Each subject gets its own y lane so the lines are separated along Y.
        y = [idx] * len(exam_order)
        z = df_sorted[subj].astype(float).tolist()
        x = x_ticks
        traces.append(go.Scatter3d(
            x=x, y=y, z=z, mode='lines+markers',
            name=subj,
            line=dict(color=colors[idx % len(colors)], width=4),
            marker=dict(size=4)
        ))
    y_tickvals = list(range(len(subjects)))
    y_ticktext = subjects
    title = "3D 折线 — 每门科目随考试变化"
else:
    # mode == "student": long table with columns Exam, Student, Subject, Score.
    # One line is drawn per (Student, Subject) pair.
    if not set(["Student","Subject","Score"]).issubset(df.columns):
        raise ValueError("长表格式需要包含列: 'Exam','Student','Subject','Score'。")
    # Each (Student, Subject) slice is reindexed by Exam below, which requires
    # at most one score per exam; report duplicates explicitly.
    if df.duplicated(["Student", "Subject", "Exam"]).any():
        raise ValueError("长表中同一学生、同一科目的每次考试只能有一条成绩。")
    subjects = sorted(df["Subject"].unique())
    students = sorted(df["Student"].unique())
    # Each subject occupies a band of len(students) consecutive y lanes;
    # inside a band, the i-th student sits at offset i.
    subject_to_base_y = {subj: i*len(students) for i, subj in enumerate(subjects)}
    # Colour by student so one student is recognisable across all subjects.
    color_map = {s: colors[i % len(colors)] for i, s in enumerate(students)}
    for si, student in enumerate(students):
        for subj in subjects:
            # Scores for this student & subject, aligned to exam_order
            # (missing exams become NaN, i.e. gaps in the line).
            sel = df[(df["Student"] == student) & (df["Subject"] == subj)]
            sel = sel.set_index("Exam").reindex(exam_order).reset_index()
            z = sel["Score"].astype(float).tolist()
            x = x_ticks
            # y lane = start of the subject's band + the student's offset
            y_base = subject_to_base_y[subj]
            y = [y_base + si for _ in x]
            traces.append(go.Scatter3d(
                x=x, y=y, z=z, mode='lines+markers',
                name=f"{student} - {subj}",
                line=dict(color=color_map[student], width=3),
                marker=dict(size=3)
            ))
    # One y tick per (subject, student) lane, labelled "Subject:Student".
    y_tickvals = []
    y_ticktext = []
    for subj in subjects:
        for si, student in enumerate(students):
            y_tickvals.append(subject_to_base_y[subj] + si)
            y_ticktext.append(f"{subj}:{student}")
    title = "3D 折线 — 每个学生各科成绩随考试变化"

# ---------- Figure layout ----------
# The z axis defaults to 0-100 but widens to include any score outside that
# range (e.g. exams marked out of 150), so no points are clipped out of view.
# NaN gaps are ignored; with no finite scores the default range is kept.
_z_values = np.array([v for trace in traces for v in trace.z], dtype=float)
_z_finite = _z_values[np.isfinite(_z_values)]
if _z_finite.size:
    z_range = [min(0.0, float(_z_finite.min())), max(100.0, float(_z_finite.max()))]
else:
    z_range = [0, 100]

layout = go.Layout(
    title=title,
    scene=dict(
        xaxis=dict(title="Exam", tickvals=x_ticks, ticktext=exam_order),
        yaxis=dict(title="Subject / Student", tickvals=y_tickvals, ticktext=y_ticktext),
        zaxis=dict(title="Score", range=z_range)
    ),
    margin=dict(l=0, r=0, t=40, b=0),
    legend=dict(itemsizing='constant')
)

fig = go.Figure(data=traces, layout=layout)

# ---------- Save and open ----------
# Write a standalone HTML file to the temp dir and open it in the default browser.
tmpfile = os.path.join(tempfile.gettempdir(), "scores_3d_lines.html")
pio.write_html(fig, file=tmpfile, auto_open=False)
print(f"Saved interactive plot to: {tmpfile}")
# as_uri() builds a valid file:// URI on every OS; plain "file://" + path is
# malformed on Windows (drive letter, backslashes).
webbrowser.open(pathlib.Path(tmpfile).resolve().as_uri())

# In Jupyter/Colab, display inline instead: remove the save/open lines above
# and call fig.show().
# fig.show()