import pygame
import math
import sys
from typing import List, Optional

# Initialise all pygame modules. This must run before any font or display
# object below is created.
pygame.init()

# Colour constants as (R, G, B) tuples.
WHITE = (255, 255, 255)
BLACK = (0, 0, 0)
GRAY = (128, 128, 128)
DARK_GRAY = (64, 64, 64)
LIGHT_GRAY = (200, 200, 200)
BLUE = (30, 144, 255)
RED = (255, 69, 0)
GREEN = (50, 205, 50)
ORANGE = (255, 165, 0)

# Window setup: a fixed WIDTH x HEIGHT display surface and its title.
WIDTH, HEIGHT = 400, 600
screen = pygame.display.set_mode((WIDTH, HEIGHT))
pygame.display.set_caption("高级计算器")

# Fonts: pygame's default font (None) at several sizes. They are created at
# import time, after pygame.init(); font_medium is also used as Button's
# default font argument.
# NOTE(review): font_history does not appear to be used in the code shown.
font_small = pygame.font.Font(None, 24)
font_medium = pygame.font.Font(None, 32)
font_large = pygame.font.Font(None, 40)
font_history = pygame.font.Font(None, 20)


class Calculator:
    """State machine behind the calculator's buttons and keyboard shortcuts.

    ``current_input`` is the text shown on the display. A pending binary
    operation is held as ``previous_input`` (left operand text) plus
    ``operator`` and is evaluated by :meth:`calculate`. Any failed operation
    shows :attr:`ERROR`; the next digit or decimal point replaces it.
    """

    # Display text used for every failed operation.
    ERROR = "Error"
    # Integral results at or above this magnitude are shown in scientific
    # notation; a plain integer that long would be cut off by the
    # 15-character display.
    MAX_PLAIN_INT = 1e15
    # Maximum number of entries kept in ``history``.
    HISTORY_LIMIT = 5

    # Binary operators, keyed by button label.
    _BINARY_OPS = {
        "+": lambda a, b: a + b,
        "-": lambda a, b: a - b,
        "×": lambda a, b: a * b,
        "÷": lambda a, b: a / b,
        "^": lambda a, b: a ** b,
    }
    # Trigonometric functions; the calculator's inputs are in degrees.
    _TRIG_FUNCS = {"sin": math.sin, "cos": math.cos, "tan": math.tan}
    # Other one-argument functions. Invalid inputs raise ValueError (math
    # domain errors such as log(0) or sqrt(-1)), ZeroDivisionError (1/0) or
    # OverflowError.
    _UNARY_FUNCS = {
        "log": math.log10,
        "ln": math.log,
        "sqrt": math.sqrt,
        "x²": lambda v: v ** 2,
        "1/x": lambda v: 1 / v,
    }
    # Constants inserted by the π and e keys.
    _CONSTANTS = {"π": math.pi, "e": math.e}
    # Exceptions that put the display into the ERROR state instead of
    # crashing the UI.
    _MATH_ERRORS = (ValueError, OverflowError, ZeroDivisionError)

    def __init__(self):
        self.current_input = "0"
        self.previous_input = ""
        self.operator = ""
        # True right after an operator key: the right operand has not been
        # entered yet, so the next digit starts a new number.
        self.waiting_for_operand = False
        # True while the display shows a computed or recalled value (or
        # ERROR): it counts as an operand, but the next digit or decimal
        # point replaces it instead of being appended to it.
        self.replace_on_input = False
        self.history: List[str] = []
        self.memory = 0.0

    # ---- display helpers ------------------------------------------------

    @classmethod
    def _format_number(cls, value) -> str:
        """Return *value* formatted for the display.

        Integral values below MAX_PLAIN_INT are shown without a decimal
        point; everything else uses up to 10 significant digits.

        Raises:
            ValueError: for complex results (e.g. a negative base raised to
                a fractional power) and NaN.
            OverflowError: for infinite results.
        """
        if isinstance(value, complex):
            raise ValueError("complex result")
        if math.isinf(value):
            raise OverflowError("result is infinite")
        if math.isnan(value):
            raise ValueError("result is not a number")
        if value == int(value) and abs(value) < cls.MAX_PLAIN_INT:
            return str(int(value))
        return f"{value:.10g}"

    def _show_result(self, text: str):
        """Display a computed/recalled value that the next digit replaces."""
        self.current_input = text
        self.waiting_for_operand = False
        self.replace_on_input = True

    def _show_error(self):
        """Display ERROR and drop any pending operation."""
        self.current_input = self.ERROR
        self.clear_operation()
        self.replace_on_input = True

    def _record_history(self, entry: str):
        """Append *entry*, keeping only the newest HISTORY_LIMIT entries."""
        self.history.append(entry)
        del self.history[:-self.HISTORY_LIMIT]

    # ---- clearing and editing -------------------------------------------

    def clear(self):
        """Reset the display and pending operation (C); keep memory/history."""
        self.current_input = "0"
        self.previous_input = ""
        self.operator = ""
        self.waiting_for_operand = False
        self.replace_on_input = False

    def clear_entry(self):
        """Reset only the displayed entry to "0" (CE)."""
        self.current_input = "0"

    def delete_last(self):
        """Remove the last displayed character (backspace).

        ERROR, a single character, or a lone minus sign left behind
        (e.g. "-5" -> "-") all become "0".
        """
        if self.current_input == self.ERROR:
            self.current_input = "0"
            self.replace_on_input = False
            return
        trimmed = self.current_input[:-1]
        self.current_input = trimmed if trimmed not in ("", "-") else "0"
        # The user is now editing the shown value, so keep appending to it.
        self.replace_on_input = False

    # ---- memory ---------------------------------------------------------

    def memory_clear(self):
        """Reset memory to 0.0 (MC)."""
        self.memory = 0.0

    def memory_recall(self):
        """Show the memory value (MR); the next digit starts a new number."""
        try:
            text = self._format_number(self.memory)
        except self._MATH_ERRORS:
            self._show_error()
            return
        self._show_result(text)

    def memory_add(self):
        """Add the displayed number to memory (M+).

        Best-effort: a non-numeric display (e.g. ERROR) is ignored.
        """
        try:
            self.memory += float(self.current_input)
        except ValueError:
            pass

    def memory_subtract(self):
        """Subtract the displayed number from memory (M-).

        Best-effort: a non-numeric display (e.g. ERROR) is ignored.
        """
        try:
            self.memory -= float(self.current_input)
        except ValueError:
            pass

    # ---- entry ----------------------------------------------------------

    def input_number(self, number: str):
        """Type a digit: start a new number or append to the current one."""
        if (self.waiting_for_operand or self.replace_on_input
                or self.current_input == "0"):
            self.current_input = number
        else:
            self.current_input += number
        self.waiting_for_operand = False
        self.replace_on_input = False

    def input_decimal(self):
        """Type a decimal point (at most one per number)."""
        if self.waiting_for_operand or self.replace_on_input:
            self.current_input = "0."
        elif "." not in self.current_input:
            self.current_input += "."
        self.waiting_for_operand = False
        self.replace_on_input = False

    # ---- operations -----------------------------------------------------

    def set_operator(self, op: str):
        """Stage binary operator *op*, evaluating any complete pending one.

        Ignored while ERROR is displayed, so a failure never becomes an
        operand.
        """
        if self.current_input == self.ERROR:
            return
        # Chaining ("2 + 3 +") evaluates the previous operation first, but
        # only once its right operand has been entered.
        if self.operator and not self.waiting_for_operand:
            self.calculate()
            if self.current_input == self.ERROR:
                return
        self.previous_input = self.current_input
        self.operator = op
        self.waiting_for_operand = True
        self.replace_on_input = False

    def calculate(self):
        """Evaluate the pending binary operation (the = key).

        On success the result is displayed, logged in ``history`` and the
        pending operation cleared. Invalid operations (division by zero,
        0 to a negative power, complex or infinite results, non-numeric
        operands) display ERROR.
        """
        if not self.operator or not self.previous_input:
            return
        operation = self._BINARY_OPS.get(self.operator)
        if operation is None:
            return

        left_text = self.previous_input
        right_text = self.current_input
        op = self.operator
        try:
            result = operation(float(left_text), float(right_text))
            text = self._format_number(result)
        except self._MATH_ERRORS:
            self._show_error()
            return

        # Log operands exactly as they were entered (not re-formatted floats).
        self._record_history(f"{left_text} {op} {right_text} = {text}")
        self.clear_operation()
        self._show_result(text)

    def clear_operation(self):
        """Drop the pending operator and left operand."""
        self.operator = ""
        self.previous_input = ""
        self.waiting_for_operand = False

    @classmethod
    def _trig_degrees(cls, name: str, degrees: float) -> float:
        """Return sin/cos/tan of an angle given in degrees.

        Raises ValueError for tan at odd multiples of 90°, where it is
        undefined (math.tan would return a huge finite number instead).
        """
        radians = math.radians(degrees)
        if name == "tan" and abs(math.cos(radians)) < 1e-10:
            raise ValueError("tan is undefined here")
        result = cls._TRIG_FUNCS[name](radians)
        # Snap floating-point noise such as sin(180°) ≈ 1.2e-16 to zero.
        return 0.0 if abs(result) < 1e-10 else result

    def _negate(self):
        """Toggle the sign of the displayed number (the ± key).

        The display text is edited in place, so a number still being typed
        stays editable ("5", ±, "3" gives "-53"). Zero is left unsigned.
        """
        try:
            value = float(self.current_input)
        except ValueError:
            self._show_error()
            return
        if value == 0:
            return
        text = self.current_input
        negated = text[1:] if text.startswith("-") else "-" + text
        if self.waiting_for_operand:
            # Negating the value shown right after an operator key makes it
            # the right operand.
            self._show_result(negated)
        else:
            self.current_input = negated

    def scientific_function(self, func: str):
        """Apply a one-argument function or insert a constant.

        Args:
            func: one of "sin", "cos", "tan" (degrees), "log", "ln", "sqrt",
                "x²", "1/x", "±", "π", "e". Unknown names are ignored.

        The result replaces the display, and a pending binary operation is
        kept so it can still be completed ("2 + 9 sqrt =" gives 5).
        Invalid inputs display ERROR.
        """
        if func in self._CONSTANTS:
            # Constants do not depend on (or fail because of) the display.
            self._show_result(f"{self._CONSTANTS[func]:.10g}")
            return
        if func == "±":
            self._negate()
            return
        if func not in self._TRIG_FUNCS and func not in self._UNARY_FUNCS:
            return

        try:
            value = float(self.current_input)
            if func in self._TRIG_FUNCS:
                result = self._trig_degrees(func, value)
            else:
                result = self._UNARY_FUNCS[func](value)
            text = self._format_number(result)
        except self._MATH_ERRORS:
            self._show_error()
            return
        self._show_result(text)


class Button:
    """A clickable rounded-rectangle button with a centred text label."""

    # Per-channel amount the fill colour is shifted while the pointer is
    # over the button.
    HOVER_SHIFT = 30

    def __init__(self, x: int, y: int, width: int, height: int, text: str,
                 color=LIGHT_GRAY, text_color=BLACK, font=font_medium):
        self.rect = pygame.Rect(x, y, width, height)
        self.text = text
        self.color = color
        self.text_color = text_color
        self.font = font
        self.hovered = False

    @classmethod
    def _hover_color(cls, color):
        """Return the fill colour to use while the button is hovered.

        Colours are normally lightened. A colour that is already near 255 on
        every channel (e.g. WHITE, used by the digit keys) would not visibly
        change when lightened, so it is darkened instead. Non-tuple colours
        are returned unchanged.
        """
        if not isinstance(color, tuple):
            return color
        shift = cls.HOVER_SHIFT
        if all(c >= 255 - shift for c in color):
            return tuple(max(0, c - shift) for c in color)
        return tuple(min(255, c + shift) for c in color)

    def draw(self, surface):
        """Draw the fill, a 2px black border and the centred label."""
        fill = self._hover_color(self.color) if self.hovered else self.color

        pygame.draw.rect(surface, fill, self.rect, border_radius=8)
        pygame.draw.rect(surface, BLACK, self.rect, 2, border_radius=8)

        text_surface = self.font.render(self.text, True, self.text_color)
        text_rect = text_surface.get_rect(center=self.rect.center)
        surface.blit(text_surface, text_rect)

    def is_hovered(self, pos):
        """Record and return whether *pos* lies inside the button."""
        self.hovered = self.rect.collidepoint(pos)
        return self.hovered


def create_buttons():
    """Build the calculator's buttons in draw / hit-test order.

    Each layout entry is ``(label, x, y, width, height[, colour])``; an entry
    without a colour falls back to LIGHT_GRAY.
    """
    layout = [
        # Row 1: scientific functions
        ("sin", 10, 120, 60, 40, BLUE),
        ("cos", 80, 120, 60, 40, BLUE),
        ("tan", 150, 120, 60, 40, BLUE),
        ("log", 220, 120, 60, 40, BLUE),
        ("ln", 290, 120, 60, 40, BLUE),

        # Row 2: more scientific functions and π
        ("√", 10, 170, 60, 40, BLUE),
        ("x²", 80, 170, 60, 40, BLUE),
        ("1/x", 150, 170, 60, 40, BLUE),
        ("±", 220, 170, 60, 40, BLUE),
        ("π", 290, 170, 60, 40, ORANGE),

        # Clearing keys, division and e
        ("C", 10, 220, 60, 50, RED),
        ("CE", 80, 220, 60, 50, RED),
        ("⌫", 150, 220, 60, 50, RED),
        ("÷", 220, 220, 60, 50, ORANGE),
        ("e", 290, 220, 60, 50, ORANGE),

        # Digit rows with operators; memory keys stacked in the right column
        ("7", 10, 280, 60, 50, WHITE),
        ("8", 80, 280, 60, 50, WHITE),
        ("9", 150, 280, 60, 50, WHITE),
        ("×", 220, 280, 60, 50, ORANGE),
        ("^", 290, 280, 60, 50, ORANGE),

        ("4", 10, 340, 60, 50, WHITE),
        ("5", 80, 340, 60, 50, WHITE),
        ("6", 150, 340, 60, 50, WHITE),
        ("-", 220, 340, 60, 50, ORANGE),
        ("MC", 290, 340, 60, 24, GRAY),

        ("1", 10, 400, 60, 50, WHITE),
        ("2", 80, 400, 60, 50, WHITE),
        ("3", 150, 400, 60, 50, WHITE),
        ("+", 220, 400, 60, 50, ORANGE),
        ("MR", 290, 366, 60, 24, GRAY),

        ("0", 10, 460, 130, 50, WHITE),
        (".", 150, 460, 60, 50, WHITE),
        ("=", 220, 460, 60, 50, GREEN),
        ("M+", 290, 392, 60, 24, GRAY),
        ("M-", 290, 418, 60, 24, GRAY),
    ]

    def build(entry):
        # Five-field entries get the default colour appended.
        full = (*entry, LIGHT_GRAY) if len(entry) == 5 else entry
        label, left, top, width, height, fill = full
        return Button(left, top, width, height, label, fill)

    return [build(entry) for entry in layout]


def main():
    """Run the calculator window until it is closed.

    Each frame: handle events, refresh hover state, redraw the display panel
    and buttons, then cap the loop at 60 FPS. Shuts pygame down and exits
    the process via sys.exit() when the window is closed.
    """
    clock = pygame.time.Clock()
    calculator = Calculator()
    buttons = create_buttons()

    running = True
    while running:
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                running = False
            elif event.type == pygame.MOUSEBUTTONDOWN and event.button == 1:
                # Left click: hit-test at the position stored in the event.
                # A pygame.mouse.get_pos() sampled before event.get() reflects
                # the previous frame and can miss a click after a fast move.
                for button in buttons:
                    if button.rect.collidepoint(event.pos):
                        handle_button_click(calculator, button.text)
                        break
            elif event.type == pygame.KEYDOWN:
                handle_key_press(calculator, event.key)

        # Update hover highlighting from the pointer position after this
        # frame's events have been processed.
        mouse_pos = pygame.mouse.get_pos()
        for button in buttons:
            button.is_hovered(mouse_pos)

        # Draw the interface.
        screen.fill(DARK_GRAY)

        # Display panel.
        display_rect = pygame.Rect(10, 10, WIDTH - 20, 100)
        pygame.draw.rect(screen, BLACK, display_rect, border_radius=10)

        # Current entry, right-aligned; only its last 15 characters are shown.
        current_text = calculator.current_input[-15:]
        current_surface = font_large.render(current_text, True, WHITE)
        current_rect = current_surface.get_rect(bottomright=(WIDTH - 20, 90))
        screen.blit(current_surface, current_rect)

        # Up to the three most recent history entries, newest first.
        for row, entry in enumerate(reversed(calculator.history[-3:])):
            history_surface = font_small.render(entry, True, LIGHT_GRAY)
            screen.blit(history_surface, (20, 20 + 20 * row))

        # Memory indicator, shown only while memory is non-zero.
        if calculator.memory != 0:
            memory_text = f"M: {calculator.memory:.4g}"
            memory_surface = font_small.render(memory_text, True, GREEN)
            screen.blit(memory_surface, (WIDTH - 120, 20))

        # Buttons.
        for button in buttons:
            button.draw(screen)

        pygame.display.flip()
        clock.tick(60)

    pygame.quit()
    sys.exit()


def handle_button_click(calculator: Calculator, text: str):
    """Route a button label to the matching Calculator method.

    Keyboard shortcuts reuse this by passing the equivalent label. Labels
    that match nothing are ignored.
    """
    # Scientific keys map to scientific_function names; only √ differs.
    scientific_names = {
        "sin": "sin", "cos": "cos", "tan": "tan", "log": "log", "ln": "ln",
        "√": "sqrt", "x²": "x²", "1/x": "1/x", "±": "±", "π": "π", "e": "e",
    }
    name = scientific_names.get(text)
    if name is not None:
        calculator.scientific_function(name)
        return

    # Zero-argument actions, looked up by method name so that only the
    # method actually needed is fetched from the calculator.
    zero_arg_methods = {
        "MC": "memory_clear",
        "MR": "memory_recall",
        "M+": "memory_add",
        "M-": "memory_subtract",
        "C": "clear",
        "CE": "clear_entry",
        "⌫": "delete_last",
        "=": "calculate",
        ".": "input_decimal",
    }
    method_name = zero_arg_methods.get(text)
    if method_name is not None:
        getattr(calculator, method_name)()
    elif text in ("+", "-", "×", "÷", "^"):
        calculator.set_operator(text)
    elif text.isdigit():
        calculator.input_number(text)


def handle_key_press(calculator: Calculator, key):
    """Translate a pygame key code into the equivalent button press.

    Supports main-row digits, the numeric keypad, operator keys, Enter/=,
    Backspace, Delete (CE), C/Escape (clear), P (π) and E (e).

    SDL reports a shifted key with its unshifted key code, so on US-style
    layouts "+" arrives as K_EQUALS, "*" as K_8 and "^" as K_6. While Shift
    is held those three keys are mapped to their shifted operators.
    NOTE(review): this shifted mapping assumes a US-style layout.
    """
    key_map = {
        pygame.K_0: "0", pygame.K_1: "1", pygame.K_2: "2", pygame.K_3: "3",
        pygame.K_4: "4", pygame.K_5: "5", pygame.K_6: "6", pygame.K_7: "7",
        pygame.K_8: "8", pygame.K_9: "9", pygame.K_PERIOD: ".",
        pygame.K_PLUS: "+", pygame.K_MINUS: "-", pygame.K_ASTERISK: "×",
        pygame.K_SLASH: "÷", pygame.K_CARET: "^",
        pygame.K_EQUALS: "=", pygame.K_RETURN: "=",
        pygame.K_BACKSPACE: "⌫", pygame.K_DELETE: "CE",
        pygame.K_c: "C", pygame.K_ESCAPE: "C",
        pygame.K_p: "π", pygame.K_e: "e",
        # Numeric keypad.
        pygame.K_KP0: "0", pygame.K_KP1: "1", pygame.K_KP2: "2",
        pygame.K_KP3: "3", pygame.K_KP4: "4", pygame.K_KP5: "5",
        pygame.K_KP6: "6", pygame.K_KP7: "7", pygame.K_KP8: "8",
        pygame.K_KP9: "9", pygame.K_KP_PERIOD: ".",
        pygame.K_KP_PLUS: "+", pygame.K_KP_MINUS: "-",
        pygame.K_KP_MULTIPLY: "×", pygame.K_KP_DIVIDE: "÷",
        pygame.K_KP_ENTER: "=", pygame.K_KP_EQUALS: "=",
    }
    # Shift + key combinations that produce operators on US-style layouts.
    shifted_map = {
        pygame.K_EQUALS: "+",
        pygame.K_8: "×",
        pygame.K_6: "^",
    }

    if pygame.key.get_mods() & pygame.KMOD_SHIFT and key in shifted_map:
        label = shifted_map[key]
    else:
        label = key_map.get(key)

    if label is not None:
        handle_button_click(calculator, label)


# Start the calculator only when this file is executed as a script.
if __name__ == "__main__":
    main()
