import tkinter as tk
from tkinter import ttk, messagebox, filedialog
import numpy as np
import threading
import time
import sys
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg, NavigationToolbar2Tk
from matplotlib.figure import Figure
from matplotlib import cm


class ComplexMountainGenerator:
    """Interactive 3D fractal-mountain generator (Tkinter + Matplotlib).

    The heightmap is built on a background thread (Diamond-Square plus
    optional multi-octave noise); every Tk / Matplotlib update is scheduled
    back onto the Tk main loop with ``root.after``.
    """

    # Parameters coerced to int / float by update_param().
    _INT_PARAMS = frozenset({'size', 'seed', 'octaves'})
    _FLOAT_PARAMS = frozenset({'roughness', 'scale', 'persistence', 'lacunarity',
                               'elevation_factor', 'water_level'})

    # Camera pose used at start-up and restored by reset_camera().
    _DEFAULT_CAMERA = {'elev': 25, 'azim': -60, 'dist': 10}

    # Camera distance that maps to zoom == 1 in _apply_camera().
    _REFERENCE_DIST = 10.0

    def __init__(self, root):
        """Initialise state, build the UI and bind keyboard shortcuts.

        Args:
            root: the ``tk.Tk`` window hosting the application.
        """
        self.root = root
        self.root.title("3D 分形山脉生成器 - 高级图形系统")
        self.root.geometry("1400x900")
        self.root.configure(bg='#2c3e50')

        # Data state
        self.heightmap = None      # last generated heightmap (2-D ndarray); saved by export
        self.current_surf = None   # surface artist currently displayed
        self.colorbar = None       # colorbar attached to self.ax; replaced on each redraw
        self.render_thread = None
        self.is_rendering = False
        self.progress_var = tk.StringVar(value="就绪")
        self.progress_val = tk.DoubleVar(value=0)

        # Default generation parameters
        self.params = {
            'size': 257,          # grid size (2^n + 1)
            'roughness': 0.65,    # initial random displacement amplitude
            'scale': 100.0,       # half-extent of the X/Y axes
            'seed': 42,           # random seed
            'color_map': 'terrain',
            'elevation_factor': 1.5,
            'water_level': 0.2,
            'octaves': 5,
            'persistence': 0.5,
            'lacunarity': 2.0
        }

        # Camera parameters (elevation / azimuth in degrees, viewing distance)
        self.camera_params = dict(self._DEFAULT_CAMERA)

        # Tk variables backing each control, keyed by parameter name
        self.param_vars = {}

        self.setup_ui()
        self.setup_shortcuts()

    def setup_ui(self):
        """Build the control panel (left) and the 3D plot area (right)."""
        main_frame = tk.Frame(self.root, bg='#2c3e50')
        main_frame.pack(fill=tk.BOTH, expand=True, padx=10, pady=10)

        # Left: fixed-width control panel
        control_panel = tk.Frame(main_frame, bg='#34495e', width=300)
        control_panel.pack(side=tk.LEFT, fill=tk.Y, padx=(0, 10))
        control_panel.pack_propagate(False)

        title_label = tk.Label(control_panel, text="⚡ 山脉生成控制台",
                               font=('Arial', 16, 'bold'),
                               bg='#34495e', fg='#ecf0f1')
        title_label.pack(pady=15)

        # Parameter groups: tuple config -> slider, list config -> combobox
        self.create_param_section(control_panel, "基础参数", [
            ('网格大小', 'size', ['257', '513', '1025']),
            ('粗糙度', 'roughness', (0.3, 0.95, 0.01)),
            ('缩放比例', 'scale', (50, 200, 5)),
            ('随机种子', 'seed', (1, 9999, 1))
        ])

        self.create_param_section(control_panel, "分形参数", [
            ('八度数', 'octaves', (1, 8, 1)),
            ('持久性', 'persistence', (0.3, 0.8, 0.01)),
            ('倍频程', 'lacunarity', (1.5, 3.0, 0.1))
        ])

        self.create_param_section(control_panel, "可视化", [
            ('色彩映射', 'color_map', ['terrain', 'viridis', 'plasma', 'coolwarm', 'jet', 'gist_earth']),
            ('海拔因子', 'elevation_factor', (0.5, 3.0, 0.05)),
            ('水位线', 'water_level', (0, 0.5, 0.01))
        ])

        self.create_camera_section(control_panel, "相机控制")

        # Action buttons
        btn_frame = tk.Frame(control_panel, bg='#34495e')
        btn_frame.pack(fill=tk.X, padx=15, pady=10)

        self.generate_btn = tk.Button(btn_frame, text="🚀 生成山脉",
                                      command=self.start_generation_thread,
                                      bg='#27ae60', fg='white',
                                      font=('Arial', 12, 'bold'),
                                      cursor='hand2')
        self.generate_btn.pack(fill=tk.X, pady=5)

        self.reset_btn = tk.Button(btn_frame, text="🔄 重置视角",
                                   command=self.reset_camera,
                                   bg='#3498db', fg='white',
                                   font=('Arial', 10), cursor='hand2')
        self.reset_btn.pack(fill=tk.X, pady=2)

        self.export_btn = tk.Button(btn_frame, text="💾 导出高度图",
                                    command=self.export_heightmap,
                                    bg='#9b59b6', fg='white',
                                    font=('Arial', 10), cursor='hand2')
        self.export_btn.pack(fill=tk.X, pady=2)

        # Progress label + bar
        progress_frame = tk.Frame(control_panel, bg='#34495e')
        progress_frame.pack(fill=tk.X, padx=15, pady=15)

        self.progress_label = tk.Label(progress_frame, textvariable=self.progress_var,
                                       bg='#34495e', fg='#ecf0f1', font=('Arial', 9))
        self.progress_label.pack()

        self.progress_bar = ttk.Progressbar(progress_frame, variable=self.progress_val,
                                            length=250, mode='determinate')
        self.progress_bar.pack(pady=5)

        # Read-only help text
        info_text = tk.Text(control_panel, height=8, bg='#2c3e50', fg='#ecf0f1',
                            font=('Consolas', 8), wrap=tk.WORD)
        info_text.pack(fill=tk.BOTH, padx=15, pady=10)
        info_text.insert('1.0',
            "📊 实时信息:\n"
            "• 使用 Diamond-Square 算法\n"
            "• 支持多线程渲染\n"
            "• 动态颜色映射\n"
            "• 交互式3D视图\n\n"
            "💡 提示: 点击生成后稍等\n"
            "   可使用鼠标拖拽旋转视图\n"
            "   快捷键: G-生成, R-重置视角, S-导出")
        info_text.config(state='disabled')

        # Right: 3D view in its own frame. The toolbar is created (and packs
        # itself at the bottom) before the canvas widget is packed, so the
        # expanding canvas cannot squeeze the toolbar out.
        plot_frame = tk.Frame(main_frame, bg='#2c3e50')
        plot_frame.pack(side=tk.RIGHT, fill=tk.BOTH, expand=True)

        self.fig = Figure(figsize=(10, 8), dpi=100, facecolor='#1a252c')
        self.ax = self.fig.add_subplot(111, projection='3d')
        self.ax.set_facecolor('#1a252c')
        self.ax.xaxis.pane.fill = False
        self.ax.yaxis.pane.fill = False
        self.ax.zaxis.pane.fill = False

        self.canvas = FigureCanvasTkAgg(self.fig, master=plot_frame)
        toolbar = NavigationToolbar2Tk(self.canvas, plot_frame)
        toolbar.update()
        self.canvas.get_tk_widget().pack(side=tk.TOP, fill=tk.BOTH, expand=True)

        self.display_default_plot()

    def _add_row(self, parent, label):
        """Create one labelled row inside *parent* and return the row frame."""
        row = tk.Frame(parent, bg='#34495e')
        row.pack(fill=tk.X, padx=10, pady=3)
        tk.Label(row, text=label, bg='#34495e', fg='#bdc3c7',
                 width=12, anchor='w').pack(side=tk.LEFT)
        return row

    def _add_slider(self, row, key, config, initial, on_change):
        """Add a horizontal slider to *row* that calls ``on_change(key, value)``.

        Args:
            row: container frame.
            key: parameter name; also the key stored in ``self.param_vars``.
            config: ``(minimum, maximum, resolution)``.
            initial: starting value.
            on_change: callback receiving the key and the slider's value.
        """
        low, high, step = config
        var = tk.DoubleVar(value=initial)
        scale = tk.Scale(row, from_=low, to=high,
                         resolution=step, orient=tk.HORIZONTAL,
                         variable=var, bg='#34495e', fg='white',
                         length=150, showvalue=True)
        scale.pack(side=tk.LEFT, padx=5)
        # key/var are bound as defaults so this callback is tied to this
        # slider's own variable, never to a later one.
        scale.config(command=lambda *_args, k=key, v=var: on_change(k, v.get()))
        self.param_vars[key] = var

    def create_param_section(self, parent, title, params):
        """Create a labelled group of generation-parameter controls.

        Args:
            parent: container widget.
            title: caption of the group.
            params: iterable of ``(label, param_key, config)``. A tuple config
                ``(min, max, step)`` creates a slider; a list config creates a
                read-only combobox offering those values.
        """
        frame = tk.LabelFrame(parent, text=title, bg='#34495e', fg='#ecf0f1',
                              font=('Arial', 10, 'bold'))
        frame.pack(fill=tk.X, padx=15, pady=5)

        for label, param_key, config in params:
            row = self._add_row(frame, label)

            if isinstance(config, tuple):
                self._add_slider(row, param_key, config,
                                 self.params[param_key], self.update_param)

            elif isinstance(config, list):
                var = tk.StringVar(value=str(self.params[param_key]))
                combo = ttk.Combobox(row, textvariable=var,
                                     values=config, state='readonly', width=12)
                combo.pack(side=tk.LEFT, padx=5)
                # Bind key and var as defaults (avoids late binding to the loop variable).
                combo.bind('<<ComboboxSelected>>',
                           lambda _event, k=param_key, v=var: self.update_param(k, v.get()))
                self.param_vars[param_key] = var

    def create_camera_section(self, parent, title):
        """Create the elevation / azimuth / distance sliders."""
        frame = tk.LabelFrame(parent, text=title, bg='#34495e', fg='#ecf0f1',
                              font=('Arial', 10, 'bold'))
        frame.pack(fill=tk.X, padx=15, pady=5)

        camera_params = [
            ('俯仰角', 'elev', (-90, 90, 5)),
            ('方位角', 'azim', (-180, 180, 5)),
            ('距离', 'dist', (5, 20, 0.5))
        ]

        for label, param_key, config in camera_params:
            row = self._add_row(frame, label)
            self._add_slider(row, param_key, config,
                             self.camera_params[param_key], self.update_camera_param)

    def update_param(self, key, value):
        """Store a generation parameter, coercing it to int/float where known.

        Args:
            key: parameter name.
            value: new value; numeric keys accept numeric strings as well.
        """
        if key in self._INT_PARAMS:
            value = int(float(value))
        elif key in self._FLOAT_PARAMS:
            value = float(value)

        self.params[key] = value

    def update_camera_param(self, key, value):
        """Store a camera parameter and refresh the view once a plot exists."""
        self.camera_params[key] = float(value)
        if self.current_surf is not None:
            self.update_camera()

    def setup_shortcuts(self):
        """Bind G (generate), R (reset camera) and S (export); case-insensitive."""
        actions = {
            'g': self.start_generation_thread,
            'r': self.reset_camera,
            's': self.export_heightmap,
        }

        def on_key(event):
            action = actions.get(event.char.lower())
            if action is not None:
                action()

        self.root.bind('<Key>', on_key)

    def _apply_camera(self):
        """Apply ``self.camera_params`` to the 3D axes (without redrawing)."""
        self.ax.view_init(elev=self.camera_params['elev'],
                          azim=self.camera_params['azim'])
        dist = self.camera_params['dist']
        try:
            # Axes3D.dist is deprecated since Matplotlib 3.6 in favour of the
            # ``zoom`` argument of set_box_aspect; larger distance -> smaller zoom.
            self.ax.set_box_aspect(None, zoom=self._REFERENCE_DIST / dist)
        except (TypeError, AttributeError):
            # Older Matplotlib without set_box_aspect / its ``zoom`` keyword.
            self.ax.dist = dist

    def update_camera(self):
        """Apply the current camera parameters and schedule a redraw."""
        self._apply_camera()
        self.canvas.draw_idle()

    def reset_camera(self):
        """Restore the default camera pose, sync the sliders and redraw."""
        self.camera_params.update(self._DEFAULT_CAMERA)

        # Keep the sliders in sync (they exist once setup_ui has run).
        for key, value in self._DEFAULT_CAMERA.items():
            var = self.param_vars.get(key)
            if var is not None:
                var.set(value)

        self.update_camera()

    def _remove_colorbar(self):
        """Detach the current colorbar (if any) so redraws do not stack new ones.

        Called before ``self.ax.clear()`` so the colorbar's mappable is still
        attached to its axes while it is being removed.
        """
        colorbar, self.colorbar = self.colorbar, None
        if colorbar is not None:
            colorbar.remove()

    def display_default_plot(self):
        """Show a placeholder ripple surface until the first mountain is generated."""
        self._remove_colorbar()
        self.ax.clear()
        x = np.linspace(-5, 5, 50)
        y = np.linspace(-5, 5, 50)
        X, Y = np.meshgrid(x, y)
        Z = np.sin(np.sqrt(X**2 + Y**2)) * 2
        surf = self.ax.plot_surface(X, Y, Z, cmap=cm.terrain, linewidth=0,
                                    antialiased=True, alpha=0.8)
        self.colorbar = self.fig.colorbar(surf, ax=self.ax, shrink=0.5, aspect=20)
        self.ax.set_title("等待生成山脉...", color='white', fontsize=12)
        self.ax.set_xlabel("X", color='white')
        self.ax.set_ylabel("Y", color='white')
        self.ax.set_zlabel("高度", color='white')
        self.ax.tick_params(colors='white')
        self.canvas.draw()
        self.current_surf = surf

    def start_generation_thread(self):
        """Start generation on a daemon thread unless one is already running."""
        if self.is_rendering:
            messagebox.showwarning("提示", "正在生成中，请稍等...")
            return

        self.is_rendering = True
        self.generate_btn.config(state='disabled', text='⏳ 生成中...')
        self.progress_var.set("初始化分形算法...")
        self.progress_val.set(0)

        # Snapshot the parameters on the Tk thread so the worker never reads
        # a dict the UI callbacks may be mutating concurrently.
        params = dict(self.params)
        self.render_thread = threading.Thread(target=self.generate_mountain,
                                              args=(params,), daemon=True)
        self.render_thread.start()

    def _schedule(self, callback, *args):
        """Queue ``callback(*args)`` on the Tk main loop.

        The worker thread uses this instead of touching Tk objects directly.
        If the window has already been destroyed there is nothing left to
        update, so the call is dropped.
        """
        try:
            self.root.after(0, callback, *args)
        except (RuntimeError, tk.TclError):
            pass

    def _post_progress(self, text=None, value=None):
        """Schedule a progress update; a None argument leaves that part unchanged."""
        def apply():
            if text is not None:
                self.progress_var.set(text)
            if value is not None:
                self.progress_val.set(value)
        self._schedule(apply)

    @staticmethod
    def _build_heightmap(params, report=None):
        """Generate a heightmap normalised to ``[0, elevation_factor]``.

        Runs Diamond-Square on a (2^n + 1)-square grid, then (when
        ``octaves > 1``) blends in ``octaves - 1`` layers of
        nearest-neighbour-upsampled Gaussian noise whose grid shrinks by
        ``lacunarity`` per octave.

        Args:
            params: mapping with 'size', 'seed', 'roughness', 'octaves',
                'persistence', 'lacunarity' and 'elevation_factor'.
            report: optional ``report(text, value)`` progress callback; either
                argument may be None.

        Returns:
            2-D float ndarray of shape (size, size), where size is
            ``params['size']`` rounded down to 2^n + 1 (minimum 3).
        """
        def notify(text=None, value=None):
            if report is not None:
                report(text, value)

        # A private generator leaves the process-wide NumPy RNG untouched;
        # RandomState(seed) yields the same stream as np.random.seed(seed).
        rng = np.random.RandomState(int(params['seed']))
        roughness = params['roughness']
        octaves = int(params['octaves'])
        persistence = params['persistence']
        lacunarity = params['lacunarity']

        # Diamond-Square needs a (2^n + 1) grid: round down, at least 3x3.
        power = max(1, int(np.log2(max(int(params['size']) - 1, 1))))
        size = 2 ** power + 1

        notify(f"创建 {size}x{size} 网格...", 10)
        heightmap = np.zeros((size, size))

        # Random corner values (draw order fixes the output for a given seed)
        last = size - 1
        heightmap[0, 0] = rng.uniform(-1, 1)
        heightmap[0, last] = rng.uniform(-1, 1)
        heightmap[last, 0] = rng.uniform(-1, 1)
        heightmap[last, last] = rng.uniform(-1, 1)

        notify("执行 Diamond-Square 算法...")
        step_size = last
        current_roughness = roughness
        step_count = 0

        while step_size > 1:
            half_step = step_size // 2

            # Diamond step: centre of each square = mean of its 4 corners + noise
            for x in range(half_step, size, step_size):
                for y in range(half_step, size, step_size):
                    avg = (heightmap[x - half_step, y - half_step]
                           + heightmap[x - half_step, y + half_step]
                           + heightmap[x + half_step, y - half_step]
                           + heightmap[x + half_step, y + half_step]) / 4.0
                    heightmap[x, y] = avg + rng.uniform(-current_roughness, current_roughness)

            # Square step: edge midpoints = mean of in-bounds neighbours + noise
            for x in range(0, size, half_step):
                for y in range((x + half_step) % step_size, size, step_size):
                    points = []
                    if x - half_step >= 0:
                        points.append(heightmap[x - half_step, y])
                    if x + half_step < size:
                        points.append(heightmap[x + half_step, y])
                    if y - half_step >= 0:
                        points.append(heightmap[x, y - half_step])
                    if y + half_step < size:
                        points.append(heightmap[x, y + half_step])

                    if points:
                        avg = sum(points) / len(points)
                        heightmap[x, y] = avg + rng.uniform(-current_roughness, current_roughness)

            current_roughness *= (0.5 ** persistence)
            step_size //= 2
            step_count += 1

            if report is not None:
                # The loop runs exactly `power` times, so this ends at 60.
                notify(value=20 + (step_count / power) * 40)
                time.sleep(0.01)  # give the Tk thread a chance to repaint

        # Multi-octave noise for extra detail
        if octaves > 1:
            notify("应用分形噪声增强...")
            enhanced = heightmap.copy()
            amplitude = 0.5
            ramp = np.arange(size) / size  # fractional position of each row/column

            for octave in range(1, octaves):
                small_size = max(4, int(size / lacunarity ** octave))
                noise = rng.randn(small_size, small_size) * amplitude
                # Nearest-neighbour upsampling to (size, size), vectorised.
                idx = np.minimum((ramp * small_size).astype(int), small_size - 1)
                enhanced += noise[np.ix_(idx, idx)] * (0.5 ** octave)
                amplitude *= persistence

            heightmap = enhanced / (1 + sum(0.5 ** i for i in range(1, octaves)))

        notify("标准化高度图...", 65)
        low, high = heightmap.min(), heightmap.max()
        span = high - low
        if span > 0:
            heightmap = (heightmap - low) / span
        else:
            heightmap = np.zeros_like(heightmap)  # flat map: avoid 0/0 -> NaN
        return heightmap * params['elevation_factor']

    def generate_mountain(self, params=None):
        """Worker-thread entry point: build the heightmap and hand it to the UI.

        Args:
            params: parameter snapshot; defaults to a copy of ``self.params``.
        """
        params = dict(self.params if params is None else params)
        try:
            heightmap = self._build_heightmap(params, report=self._post_progress)

            self._post_progress("生成网格数据...", 80)
            size = heightmap.shape[0]
            x = np.linspace(-params['scale'], params['scale'], size)
            y = np.linspace(-params['scale'], params['scale'], size)
            X, Y = np.meshgrid(x, y)

            self._post_progress("渲染3D图形...", 90)
            # Rendering touches Matplotlib/Tk, so it runs on the main loop.
            self._schedule(self.update_plot, X, Y, heightmap, params)

        except Exception as e:
            self._schedule(self.handle_error, str(e))

    def update_plot(self, X, Y, Z, params=None):
        """Render a generated heightmap (runs on the Tk main loop).

        Args:
            X, Y: coordinate grids from ``np.meshgrid``.
            Z: heightmap with the same shape as X and Y.
            params: parameter snapshot used for generation; defaults to
                ``self.params``.
        """
        params = self.params if params is None else params
        try:
            # Keep the raw data so export_heightmap() can save it.
            self.heightmap = Z

            self._remove_colorbar()
            self.ax.clear()

            cmap_dict = {
                'terrain': cm.terrain,
                'viridis': cm.viridis,
                'plasma': cm.plasma,
                'coolwarm': cm.coolwarm,
                'jet': cm.jet,
                'gist_earth': cm.gist_earth
            }
            cmap = cmap_dict.get(params['color_map'], cm.terrain)

            surf = self.ax.plot_surface(X, Y, Z, cmap=cmap, linewidth=0,
                                        antialiased=True, alpha=0.95,
                                        rstride=1, cstride=1)

            # Water-line contour
            water_level = params['water_level']
            if 0 < water_level < Z.max():
                try:
                    self.ax.contour(X, Y, Z, levels=[water_level],
                                    colors='cyan', linewidths=1, alpha=0.6)
                except Exception:
                    # Decorative overlay only: keep the rendered surface
                    # rather than failing the whole render.
                    pass

            cbar = self.fig.colorbar(surf, ax=self.ax, shrink=0.6, aspect=20,
                                     label='海拔高度', pad=0.1)
            self.colorbar = cbar
            cbar.ax.yaxis.label.set_color('white')
            cbar.ax.tick_params(colors='white')

            self.ax.set_title(f"3D 分形山脉 | 粗糙度: {params['roughness']:.2f} | 八度: {params['octaves']}",
                              color='white', fontsize=12, pad=20)
            self.ax.set_xlabel("X 轴", color='white', fontsize=10)
            self.ax.set_ylabel("Y 轴", color='white', fontsize=10)
            self.ax.set_zlabel("海拔高度", color='white', fontsize=10)

            # ax.clear() resets styling, so re-apply it
            self.ax.xaxis.pane.fill = False
            self.ax.yaxis.pane.fill = False
            self.ax.zaxis.pane.fill = False
            self.ax.grid(True, alpha=0.3)
            self.ax.tick_params(colors='white')

            self._apply_camera()

            self.canvas.draw()
            self.current_surf = surf

            stats = f"✅ 生成完成!\n"
            stats += f"📐 网格大小: {X.shape[0]}x{X.shape[1]}\n"
            stats += f"⛰️  最高点: {Z.max():.2f}\n"
            stats += f"🏞️  最低点: {Z.min():.2f}\n"
            stats += f"📊 平均高度: {Z.mean():.2f}"

            self.progress_var.set(stats)
            self.progress_val.set(100)

            # Clear the bar after 3 s unless a new generation has started
            def reset_progress():
                if not self.is_rendering:
                    self.progress_val.set(0)
            self.root.after(3000, reset_progress)

        except Exception as e:
            self.handle_error(f"渲染错误: {str(e)}")
        finally:
            self.is_rendering = False
            self.generate_btn.config(state='normal', text='🚀 生成山脉')

    def handle_error(self, error_msg):
        """Re-enable the UI and report *error_msg* in the status line and a dialog."""
        self.is_rendering = False
        self.generate_btn.config(state='normal', text='🚀 生成山脉')
        self.progress_var.set(f"错误: {error_msg}")
        messagebox.showerror("生成错误", f"山脉生成失败:\n{error_msg}")

    def export_heightmap(self):
        """Save the last generated heightmap as ``.npy`` (binary) or ``.txt`` (text)."""
        if self.heightmap is None:
            messagebox.showwarning("警告", "请先生成一个山脉")
            return

        heightmap = self.heightmap  # snapshot in case a new render replaces it
        try:
            filename = filedialog.asksaveasfilename(
                defaultextension=".npy",
                filetypes=[("NumPy 数组", "*.npy"), ("文本文件", "*.txt")]
            )
            if not filename:
                return  # dialog cancelled

            if filename.lower().endswith('.txt'):
                np.savetxt(filename, heightmap, fmt='%.6f')
            else:
                # Write through a file handle so np.save keeps the exact name
                # the user chose (it would otherwise append '.npy').
                with open(filename, 'wb') as fh:
                    np.save(fh, heightmap)

            messagebox.showinfo("导出功能", f"高度图已导出到:\n{filename}")

        except Exception as e:
            messagebox.showerror("导出失败", f"无法导出文件: {str(e)}")


def main():
    """Create the Tk root, apply the ttk theme, start the app and run the event loop."""
    root = tk.Tk()

    # Widget theme
    style = ttk.Style()
    style.theme_use('clam')

    # Custom progress-bar style
    style.configure("TProgressbar",
                    background='#27ae60',
                    troughcolor='#2c3e50',
                    thickness=10)

    app = ComplexMountainGenerator(root)  # noqa: F841 - keep a reference for the main loop's lifetime

    # Centre the window on screen. Offsets are clamped to >= 0 so the title
    # bar stays reachable on screens smaller than the window.
    root.update_idletasks()
    width = root.winfo_width()
    height = root.winfo_height()
    x = max(0, (root.winfo_screenwidth() // 2) - (width // 2))
    y = max(0, (root.winfo_screenheight() // 2) - (height // 2))
    root.geometry(f'{width}x{height}+{x}+{y}')

    root.mainloop()


if __name__ == "__main__":
    # numpy and matplotlib are imported unconditionally at the top of this
    # module, so a missing library already fails with ImportError before this
    # point is reached; a dependency check here could never trigger.
    main()