#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
成绩走势图（Tkinter + Matplotlib）
- 左侧：学生列表
- 右侧：图表 + 成绩表格
- 顶部工具：载入/保存 CSV、添加记录、删除选中记录、导出图像
- 双击成绩明细中的一行可修改该条记录的分数
CSV 格式示例（UTF-8，不含 BOM）:
Student,Term,Subject,Score
张三,2021-09,数学,85
张三,2021-12,数学,88
李四,2021-09,语文,78
"""

import statistics
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
from matplotlib.figure import Figure
import csv
import os
import tkinter as tk
from tkinter import ttk, filedialog, messagebox, simpledialog
from collections import defaultdict, OrderedDict
import matplotlib
matplotlib.use("TkAgg")

# ---------- 示例数据 ----------
# Built-in demo records, shaped exactly like rows read from CSV:
# one dict per record with keys Student/Term/Subject/Score, all values strings.
SAMPLE_DATA = [
    {"Student": student, "Term": term, "Subject": subject, "Score": score}
    for student, term, subject, score in (
        # (Student, Term,     Subject, Score)
        ("张三", "2021-09", "数学", "85"),
        ("张三", "2021-12", "数学", "88"),
        ("张三", "2022-06", "数学", "90"),
        ("张三", "2021-09", "语文", "78"),
        ("张三", "2021-12", "语文", "82"),
        ("李四", "2021-09", "数学", "72"),
        ("李四", "2021-12", "数学", "75"),
        ("李四", "2022-06", "数学", "80"),
    )
]

# ---------- 应用 ----------


class GradeTrendApp(tk.Tk):
    """Main window: student list on the left, trend chart and score table on the right.

    Records live in ``self.data`` as a list of dicts with the keys
    ``Student``, ``Term``, ``Subject`` and ``Score``; all values are strings,
    exactly as read from / written to CSV.
    """

    def __init__(self):
        super().__init__()
        self.title("成绩走势图")
        self.geometry("1000x650")
        self.minsize(800, 500)

        # Copy every record dict (not just the list) so that editing a score
        # never mutates the module-level SAMPLE_DATA.
        self.data = [dict(r) for r in SAMPLE_DATA]
        # Treeview item id -> the record dict displayed on that row.  Delete and
        # edit act on the exact record object instead of re-matching by value,
        # which broke on duplicate rows and on Tk converting e.g. "2021" to 2021.
        self._row_records = {}

        self.build_ui()
        self.refresh_student_list()

    def build_ui(self):
        """Create all widgets: toolbar, status bar, student list, chart and table."""
        # Top toolbar
        toolbar = ttk.Frame(self, padding=6)
        toolbar.pack(side=tk.TOP, fill=tk.X)

        ttk.Button(toolbar, text="载入 CSV", command=self.load_csv).pack(
            side=tk.LEFT, padx=4)
        ttk.Button(toolbar, text="保存 CSV", command=self.save_csv).pack(
            side=tk.LEFT, padx=4)
        ttk.Button(toolbar, text="添加记录", command=self.add_record_dialog).pack(
            side=tk.LEFT, padx=4)
        ttk.Button(toolbar, text="删除选中记录", command=self.delete_selected_records).pack(
            side=tk.LEFT, padx=4)
        ttk.Button(toolbar, text="导出图像 (PNG)",
                   command=self.export_png).pack(side=tk.LEFT, padx=4)

        # Status bar.  It is packed BEFORE the expanding main area: pack hands out
        # space in packing order, so a bar packed last is the first widget to be
        # squeezed out when the window gets small.
        self.status_var = tk.StringVar(value="就绪")
        status = ttk.Label(self, textvariable=self.status_var,
                           relief=tk.SUNKEN, anchor=tk.W)
        status.pack(side=tk.BOTTOM, fill=tk.X)

        # Main area: student list on the left, chart + table on the right
        main = ttk.Frame(self)
        main.pack(fill=tk.BOTH, expand=True, padx=6, pady=6)

        # Left: searchable student list
        left = ttk.Frame(main, width=220)
        left.pack(side=tk.LEFT, fill=tk.Y)
        ttk.Label(left, text="学生列表", font=(
            "Segoe UI", 10, "bold")).pack(anchor=tk.W)
        self.search_var = tk.StringVar()
        search_entry = ttk.Entry(left, textvariable=self.search_var)
        search_entry.pack(fill=tk.X, pady=(4, 6))
        search_entry.bind(
            "<KeyRelease>", lambda e: self.refresh_student_list())

        self.student_listbox = tk.Listbox(left, exportselection=False)
        self.student_listbox.pack(fill=tk.BOTH, expand=True)
        self.student_listbox.bind(
            "<<ListboxSelect>>", lambda e: self.on_student_selected())

        # Right side
        right = ttk.Frame(main)
        right.pack(side=tk.LEFT, fill=tk.BOTH, expand=True)

        # Upper half: subject selector + smoothing toggle + chart
        plot_top = ttk.Frame(right)
        plot_top.pack(fill=tk.BOTH, expand=True)

        control_row = ttk.Frame(plot_top)
        control_row.pack(fill=tk.X, pady=(0, 4))
        ttk.Label(control_row, text="科目:").pack(side=tk.LEFT)
        self.subject_var = tk.StringVar(value="全部科目")
        self.subject_combo = ttk.Combobox(
            control_row, textvariable=self.subject_var, state="readonly", width=30)
        self.subject_combo.pack(side=tk.LEFT, padx=(4, 8))
        self.subject_combo.bind("<<ComboboxSelected>>",
                                lambda e: self.update_plot())

        self.smooth_var = tk.BooleanVar(value=False)
        ttk.Checkbutton(control_row, text="显示移动平均(窗口=3)", variable=self.smooth_var,
                        command=self.update_plot).pack(side=tk.LEFT, padx=6)

        # Matplotlib figure embedded in Tk
        self.fig = Figure(figsize=(6, 4), dpi=100)
        self.ax = self.fig.add_subplot(111)
        self.ax.set_title("成绩趋势")
        self.ax.set_xlabel("Term")
        self.ax.set_ylabel("Score")
        self.canvas = FigureCanvasTkAgg(self.fig, master=plot_top)
        self.canvas.get_tk_widget().pack(fill=tk.BOTH, expand=True)

        # Lower half: score table (Treeview)
        ttk.Label(right, text="成绩明细", font=("Segoe UI", 10, "bold")
                  ).pack(anchor=tk.W, pady=(6, 0))
        self.tree = ttk.Treeview(right, columns=(
            "Term", "Subject", "Score"), show="headings", selectmode="extended")
        self.tree.heading("Term", text="Term")
        self.tree.heading("Subject", text="Subject")
        self.tree.heading("Score", text="Score")
        self.tree.pack(fill=tk.BOTH, expand=True)
        self.tree.bind("<Double-1>", self.on_tree_double_click)

    # ---------- data / file operations ----------
    def load_csv(self):
        """Ask for a CSV file and replace all records with its rows.

        The header must contain Student, Term, Subject and Score.  Values are
        stripped; a missing cell in a short row becomes "".  A leading UTF-8
        byte-order mark (as written by Excel) is tolerated.
        """
        fn = filedialog.askopenfilename(title="打开 CSV 文件", filetypes=[
                                        ("CSV 文件", "*.csv"), ("所有文件", "*.*")])
        if not fn:
            return
        required = ("Student", "Term", "Subject", "Score")
        try:
            # utf-8-sig decodes plain UTF-8 unchanged but drops a BOM that would
            # otherwise be glued onto the first header name.
            with open(fn, newline='', encoding='utf-8-sig') as f:
                reader = csv.DictReader(f)
                # Validate the header once instead of on every row.
                fieldnames = reader.fieldnames or []
                if not all(k in fieldnames for k in required):
                    raise ValueError(
                        "CSV 必须包含列：Student, Term, Subject, Score")
                # DictReader yields None for cells missing from short rows.
                rows = [{k: (r.get(k) or "").strip() for k in required}
                        for r in reader]
        except Exception as e:
            messagebox.showerror("加载失败", str(e))
            return
        self.data = rows
        self.refresh_student_list()
        self.status_var.set(f"已载入 {len(rows)} 条记录：{os.path.basename(fn)}")

    def save_csv(self):
        """Ask for a destination and write all records as UTF-8 CSV (no BOM)."""
        fn = filedialog.asksaveasfilename(
            title="保存 CSV", defaultextension=".csv", filetypes=[("CSV 文件", "*.csv")])
        if not fn:
            return
        try:
            with open(fn, "w", newline='', encoding='utf-8') as f:
                writer = csv.DictWriter(
                    f, fieldnames=["Student", "Term", "Subject", "Score"])
                writer.writeheader()
                writer.writerows(self.data)
        except Exception as e:
            messagebox.showerror("保存失败", str(e))
            return
        self.status_var.set(f"已保存到 {os.path.basename(fn)}")

    # ---------- student list and table refresh ----------
    def _selected_student(self):
        """Return the name selected in the student list, or None if nothing is selected."""
        sel = self.student_listbox.curselection()
        return self.student_listbox.get(sel[0]) if sel else None

    def refresh_student_list(self, select=None):
        """Rebuild the student list from ``self.data`` filtered by the search box.

        The entry highlighted afterwards is, in order of preference: ``select``
        (if given and visible), the previously selected student (if still
        visible), otherwise the first entry.  Chart and table are refreshed
        for it; if no student matches, they are cleared.
        """
        previous = self._selected_student()
        self.student_listbox.delete(0, tk.END)
        q = self.search_var.get().strip().lower()
        students = sorted({r["Student"] for r in self.data})
        filtered = [s for s in students if q in s.lower()]
        for s in filtered:
            self.student_listbox.insert(tk.END, s)

        if not filtered:
            self.clear_plot_and_table()
            return

        index = 0
        for candidate in (select, previous):
            if candidate in filtered:
                index = filtered.index(candidate)
                break
        self.student_listbox.selection_set(index)
        self.student_listbox.see(index)
        # selection_set does not fire <<ListboxSelect>>, so update explicitly.
        self.on_student_selected()

    def on_student_selected(self):
        """Refresh subject choices, table and chart for the selected student."""
        student = self._selected_student()
        if student is None:
            self.clear_plot_and_table()
            return
        # Offer only the subjects this student actually has records for.
        subjects = sorted({r["Subject"]
                          for r in self.data if r["Student"] == student})
        combo_values = ["全部科目"] + subjects
        self.subject_combo['values'] = combo_values
        if self.subject_var.get() not in combo_values:
            self.subject_var.set("全部科目")
        self.refresh_table(student)
        self.update_plot()

    def refresh_table(self, student):
        """Show ``student``'s records in the table, sorted by Term (string order)."""
        self.tree.delete(*self.tree.get_children())
        self._row_records = {}
        recs = sorted((r for r in self.data if r["Student"] == student),
                      key=lambda x: x["Term"])
        for r in recs:
            iid = self.tree.insert("", tk.END, values=(
                r["Term"], r["Subject"], r["Score"]))
            self._row_records[iid] = r

    def clear_plot_and_table(self):
        """Blank the chart, empty the table and reset the subject choices."""
        self.ax.clear()
        self.ax.set_title("成绩趋势")
        self.canvas.draw()
        self.tree.delete(*self.tree.get_children())
        self._row_records = {}
        self.subject_combo['values'] = []
        self.title("成绩走势图")

    # ---------- plotting ----------
    def _plot_series(self, arr, color, label=None):
        """Plot one subject's scores; ``arr`` holds one value (or None) per term.

        Missing terms are skipped.  When smoothing is on and at least three
        points exist, a dashed moving-average line is drawn in the same colour.
        Returns True if anything was drawn.
        """
        points = [(x, y) for x, y in enumerate(arr) if y is not None]
        if not points:
            return False
        xs = [x for x, _ in points]
        ys = [y for _, y in points]
        line_kwargs = {"marker": "o", "color": color}
        if label is not None:
            line_kwargs["label"] = label
        self.ax.plot(xs, ys, **line_kwargs)
        if self.smooth_var.get() and len(ys) >= 3:
            self.ax.plot(xs, self.moving_average(ys, window=3),
                         linestyle='--', color=color)
        return True

    def update_plot(self):
        """Redraw the chart for the selected student and the chosen subject."""
        student = self._selected_student()
        if student is None:
            return
        subject_choice = self.subject_var.get()
        recs = [r for r in self.data if r["Student"] == student]
        if not recs:
            self.clear_plot_and_table()
            return
        # x axis: every term this student has a record for, in string order
        terms = sorted({r["Term"] for r in recs})
        term_index = {t: i for i, t in enumerate(terms)}

        self.ax.clear()
        self.ax.set_title(f"{student} 的成绩趋势")
        self.ax.set_xlabel("Term")
        self.ax.set_ylabel("Score")
        self.ax.set_xticks(range(len(terms)))
        self.ax.set_xticklabels(terms, rotation=30, ha='right')

        # subject -> scores aligned to `terms` (None where a term has no score)
        subj_map = defaultdict(lambda: [None] * len(terms))
        for r in recs:
            try:
                score = float(r["Score"])
            except (TypeError, ValueError):
                continue  # non-numeric score: leave a gap rather than fail
            subj_map[r["Subject"]][term_index[r["Term"]]] = score

        colors = ["#1f77b4", "#ff7f0e", "#2ca02c",
                  "#d62728", "#9467bd", "#8c564b"]
        plotted = False
        if subject_choice == "全部科目":
            for i, (subj, arr) in enumerate(sorted(subj_map.items())):
                if self._plot_series(arr, colors[i % len(colors)], label=subj):
                    plotted = True
            if plotted:
                self.ax.legend()
        else:
            arr = subj_map.get(subject_choice)
            if arr:
                plotted = self._plot_series(arr, colors[0])
        if not plotted:
            self.ax.text(0.5, 0.5, "没有可绘制的数据", ha='center',
                         va='center', transform=self.ax.transAxes)
        self.fig.tight_layout()
        self.canvas.draw()

    def moving_average(self, data, window=3):
        """Return a moving average of ``data`` with the same length as ``data``.

        Each value is the mean of ``window`` consecutive inputs, centred on the
        current position where possible.  Near either end the window is
        shifted inward (never truncated), so every output averages the same
        number of points; if ``data`` is shorter than ``window``, each output
        is the mean of all of ``data``.  ``window <= 1`` returns ``data`` as-is.
        """
        if window <= 1:
            return data
        n = len(data)
        last_start = max(0, n - window)
        res = []
        for i in range(n):
            # Centre on i, then clamp so the window stays fully inside data.
            start = min(max(0, i - window // 2), last_start)
            res.append(statistics.mean(data[start:start + window]))
        return res

    # ---------- add / delete / edit records ----------
    def add_record_dialog(self):
        """Open AddRecordDialog and append the entered record if it is valid."""
        dlg = AddRecordDialog(self)
        self.wait_window(dlg)
        rec = dlg.result
        if not rec:
            return
        if not all(rec.get(k) for k in ("Student", "Term", "Subject", "Score")):
            messagebox.showwarning("无效", "请填写全部字段")
            return
        try:
            float(rec["Score"])
        except (TypeError, ValueError):
            messagebox.showwarning("无效", "Score 必须为数字")
            return
        self.data.append(rec)
        # Jump to the student the record was added for (if the filter shows them).
        self.refresh_student_list(select=rec["Student"])
        self.status_var.set("已添加记录")

    def delete_selected_records(self):
        """Delete exactly the records whose rows are selected in the table."""
        sel_items = self.tree.selection()
        if not sel_items:
            messagebox.showinfo("提示", "请先在下方成绩明细中选择要删除的记录（可多选）")
            return
        # Match by object identity, so two identical rows delete two records
        # (re-matching by value used to pick the same index twice and then pop
        # an unrelated record).
        doomed = {id(self._row_records[it])
                  for it in sel_items if it in self._row_records}
        before = len(self.data)
        self.data[:] = [r for r in self.data if id(r) not in doomed]
        removed = before - len(self.data)
        self.refresh_student_list()
        self.status_var.set(f"删除了 {removed} 条记录")

    def on_tree_double_click(self, event):
        """Prompt for a new score for the row under the pointer and store it."""
        # Use the row actually double-clicked; a heading or empty space yields ""
        # and is ignored (the current selection may be a different row).
        record = self._row_records.get(self.tree.identify_row(event.y))
        if record is None:
            return
        new_score = simpledialog.askstring(
            "编辑分数", f"修改 {record['Term']} / {record['Subject']} 的分数：",
            initialvalue=record["Score"], parent=self)
        if new_score is None:
            return
        new_score = new_score.strip()
        try:
            float(new_score)
        except ValueError:
            messagebox.showwarning("无效", "请输入数字")
            return
        record["Score"] = new_score
        self.refresh_student_list()
        self.status_var.set("已更新分数")

    # ---------- export ----------
    def export_png(self):
        """Save the current chart to a PNG file chosen by the user."""
        fn = filedialog.asksaveasfilename(
            title="导出为 PNG", defaultextension=".png", filetypes=[("PNG 图像", "*.png")])
        if not fn:
            return
        try:
            self.fig.savefig(fn, dpi=150)
        except Exception as e:
            messagebox.showerror("导出失败", str(e))
            return
        self.status_var.set(f"已导出图像：{os.path.basename(fn)}")


# 简易添加记录对话框
class AddRecordDialog(tk.Toplevel):
    """Modal dialog that collects one record (Student, Term, Subject, Score).

    After the window closes, ``result`` is the entered record as a dict of
    stripped strings, or None if the dialog was cancelled.
    """

    def __init__(self, parent):
        super().__init__(parent)
        self.title("添加记录")
        self.transient(parent)
        self.result = None

        frm = ttk.Frame(self, padding=10)
        frm.pack(fill=tk.BOTH, expand=True)

        ttk.Label(frm, text="Student:").grid(
            row=0, column=0, sticky=tk.W, pady=4)
        self.student_e = ttk.Entry(frm)
        self.student_e.grid(row=0, column=1, sticky=tk.EW, pady=4)

        ttk.Label(frm, text="Term:").grid(row=1, column=0, sticky=tk.W, pady=4)
        self.term_e = ttk.Entry(frm)
        self.term_e.grid(row=1, column=1, sticky=tk.EW, pady=4)
        ttk.Label(frm, text="示例: 2021-09 或 第1学期").grid(row=1,
                                                       column=2, sticky=tk.W, padx=6)

        ttk.Label(frm, text="Subject:").grid(
            row=2, column=0, sticky=tk.W, pady=4)
        self.subj_e = ttk.Entry(frm)
        self.subj_e.grid(row=2, column=1, sticky=tk.EW, pady=4)

        ttk.Label(frm, text="Score:").grid(
            row=3, column=0, sticky=tk.W, pady=4)
        self.score_e = ttk.Entry(frm)
        self.score_e.grid(row=3, column=1, sticky=tk.EW, pady=4)

        btns = ttk.Frame(frm)
        btns.grid(row=4, column=0, columnspan=3, pady=(8, 0))
        ttk.Button(btns, text="取消", command=self.on_cancel).pack(
            side=tk.RIGHT, padx=6)
        ttk.Button(btns, text="添加", command=self.on_add).pack(side=tk.RIGHT)

        frm.columnconfigure(1, weight=1)
        self.protocol("WM_DELETE_WINDOW", self.on_cancel)
        # Keyboard shortcuts: Enter submits, Escape cancels.
        self.bind("<Return>", lambda e: self.on_add())
        self.bind("<Escape>", lambda e: self.on_cancel())

        self.student_e.focus_set()
        # Grabbing a window that is not yet mapped fails on some platforms
        # ("grab failed: window not viewable"); wait until it is shown first,
        # as tkinter.simpledialog.Dialog does.
        self.wait_visibility()
        self.grab_set()

    def on_add(self):
        """Validate the fields; on success store them in ``result`` and close."""
        s = self.student_e.get().strip()
        t = self.term_e.get().strip()
        sub = self.subj_e.get().strip()
        sc = self.score_e.get().strip()
        if not (s and t and sub and sc):
            messagebox.showwarning("提示", "请填写全部字段", parent=self)
            return
        # Reject a non-numeric score here, while the user's input is still editable.
        try:
            float(sc)
        except ValueError:
            messagebox.showwarning("无效", "Score 必须为数字", parent=self)
            return
        self.result = {"Student": s, "Term": t, "Subject": sub, "Score": sc}
        self.destroy()

    def on_cancel(self):
        """Close the dialog without a result."""
        self.result = None
        self.destroy()


if __name__ == "__main__":
    # Script entry point: build the main window and hand control to Tk's event loop.
    GradeTrendApp().mainloop()
