import pygame
import sys
import math

# --- Configuration constants ---
SCREEN_WIDTH, SCREEN_HEIGHT = 800, 600
FPS = 60  # frame-rate cap passed to clock.tick() in main()

# Colour definitions (RGB tuples)
COLOR_BG = (30, 30, 30)           # background: dark grey
COLOR_PATH = (100, 100, 100)      # enemy path: grey
COLOR_ENEMY = (255, 50, 50)       # enemies: red
COLOR_TOWER = (50, 200, 50)       # towers: green
COLOR_BULLET = (255, 255, 0)      # bullets: yellow
COLOR_RANGE = (255, 255, 255, 50) # attack-range overlay: white with alpha 50 (true semi-transparency needs special handling; simplified here). NOTE(review): not referenced by the code in this file.

# Import-time side effects: initialise pygame and create the window, frame
# clock and HUD font as module globals. main() draws to `screen`, paces
# frames with `clock` and renders HUD text with `font` (None = pygame's
# default font, size 36).
pygame.init()
screen = pygame.display.set_mode((SCREEN_WIDTH, SCREEN_HEIGHT))
pygame.display.set_caption("Pygame Tower Defense Mini Demo")
clock = pygame.time.Clock()
font = pygame.font.Font(None, 36)

# --- Core classes ---
class Enemy(pygame.sprite.Sprite):
    """An enemy that walks along a polyline path at a constant speed.

    Args:
        path: Sequence of ``(x, y)`` waypoints in screen pixels. The enemy
            spawns on ``path[0]`` and walks toward each following point in
            turn, stopping on the last one.
        speed: Distance moved per ``update()`` call, in pixels.
        health: Starting health; bullet hits subtract from it.
    """

    def __init__(self, path, speed=2, health=100):
        super().__init__()
        self.path = path
        self.current_point = 0  # index of the last waypoint reached
        self.x, self.y = path[0]
        self.speed = speed
        self.health = health

        self.image = pygame.Surface((30, 30), pygame.SRCALPHA)
        pygame.draw.circle(self.image, COLOR_ENEMY, (15, 15), 15)
        self.rect = self.image.get_rect(center=(self.x, self.y))

    def update(self):
        """Move ``speed`` pixels along the path, stopping at the final waypoint.

        Movement left over after reaching a waypoint is spent on the next
        segment, so the enemy keeps a constant speed through corners instead
        of losing part of a frame's movement at each turn.
        """
        remaining = self.speed
        while remaining > 0 and self.current_point < len(self.path) - 1:
            target_x, target_y = self.path[self.current_point + 1]
            dx = target_x - self.x
            dy = target_y - self.y
            dist = math.hypot(dx, dy)

            if dist <= remaining:
                # Snap onto the waypoint and carry the rest of the step onward.
                self.x, self.y = target_x, target_y
                self.current_point += 1
                remaining -= dist
            else:
                self.x += remaining * dx / dist
                self.y += remaining * dy / dist
                remaining = 0
        self.rect.center = (int(self.x), int(self.y))

class Bullet(pygame.sprite.Sprite):
    """A homing projectile that chases one target and damages it on contact.

    Args:
        x, y: Spawn position in screen pixels.
        target: Sprite to chase. Its ``rect`` centre is re-read every frame,
            and its ``health`` is reduced by ``damage`` on impact.
        damage: Health removed from ``target`` when the bullet arrives.
    """

    def __init__(self, x, y, target, damage=20):
        super().__init__()
        self.target = target
        self.damage = damage
        self.speed = 7  # pixels per frame
        self.x = x
        self.y = y

        radius = 5
        self.image = pygame.Surface((radius * 2, radius * 2), pygame.SRCALPHA)
        pygame.draw.circle(self.image, COLOR_BULLET, (radius, radius), radius)
        self.rect = self.image.get_rect(center=(x, y))

    def update(self):
        """Step toward the target, or hit it once it is within one step."""
        target = self.target
        # If the target has been removed from every group, there is nothing
        # left to chase, so the bullet removes itself.
        if not target or not target.alive():
            self.kill()
            return

        goal_x, goal_y = target.rect.center
        off_x = goal_x - self.x
        off_y = goal_y - self.y
        gap = math.hypot(off_x, off_y)

        if gap > self.speed:
            self.x += self.speed * off_x / gap
            self.y += self.speed * off_y / gap
        else:
            target.health -= self.damage
            self.kill()
        self.rect.center = (int(self.x), int(self.y))

class Tower(pygame.sprite.Sprite):
    """A stationary tower that fires homing bullets at enemies in range.

    Args:
        x, y: Centre of the tower in screen pixels.
        attack_range: Maximum centre-to-centre distance, in pixels, at which
            the tower will shoot an enemy.
        fire_rate: Frames between consecutive shots. Despite the name, this
            is a period, not a rate.
        damage: Damage carried by each bullet this tower fires.
    """

    def __init__(self, x, y, attack_range=120, fire_rate=40, damage=20):
        super().__init__()
        self.range = attack_range
        self.cooldown = 0  # frames left before the tower may fire again
        self.fire_rate = fire_rate
        self.damage = damage

        self.image = pygame.Surface((40, 40), pygame.SRCALPHA)
        pygame.draw.circle(self.image, COLOR_TOWER, (20, 20), 20)
        self.rect = self.image.get_rect(center=(x, y))

    def update(self, enemies, bullets):
        """Tick the cooldown and, when ready, shoot the first enemy in range.

        Args:
            enemies: Iterable of enemy sprites. The first one (in iteration
                order) whose centre is within ``self.range`` is targeted.
            bullets: Group (anything with ``add``) that receives new bullets.
        """
        if self.cooldown > 0:
            self.cooldown -= 1
            # Re-check after counting down, so the tower may fire on the frame
            # the counter reaches zero. Shots are then exactly `fire_rate`
            # frames apart.
            if self.cooldown > 0:
                return

        # Look for an enemy within range.
        cx, cy = self.rect.center
        for enemy in enemies:
            dist = math.hypot(cx - enemy.rect.centerx, cy - enemy.rect.centery)
            if dist <= self.range:
                bullets.add(Bullet(cx, cy, enemy, damage=self.damage))
                self.cooldown = self.fire_rate
                break

# --- Level construction ---
def create_level():
    """Build the demo level.

    Returns:
        A tuple ``(path, enemies, towers, bullets)``: the list of waypoints
        enemies follow, followed by three new, empty sprite groups for
        enemies, towers and bullets.
    """
    # Waypoints of a winding route in screen pixels. The route enters at the
    # left edge and leaves at the right edge of the 800x600 window.
    waypoints = [
        (0, 300),
        (200, 300),
        (200, 100),
        (600, 100),
        (600, 500),
        (800, 500),
    ]
    enemy_group, tower_group, bullet_group = (
        pygame.sprite.Group() for _ in range(3)
    )
    return waypoints, enemy_group, tower_group, bullet_group

# --- Main loop ---
def main():
    """Run the game loop until the window is closed.

    Each frame (capped at ``FPS``) the loop does the following:
    - Handles input: a left click builds a tower at the click position.
    - Spawns an enemy every 60 frames.
    - Updates enemies, towers and bullets.
    - Removes enemies that were destroyed (+100 score) or that reached the
      end of the path.
    - Redraws the path, all sprites and the HUD.
    """
    path, enemies, towers, bullets = create_level()

    spawn_timer = 0
    # Extra health for each newly spawned enemy. Nothing raises it yet, so
    # every enemy spawns with its default health.
    wave_health_bonus = 0
    score = 0

    white = (255, 255, 255)
    # The hint never changes, so render it once rather than every frame.
    hint_text = font.render("Click to Build Tower", True, white)

    while True:
        clock.tick(FPS)

        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                pygame.quit()
                sys.exit()

            # Left mouse button: build a tower where the click happened.
            if event.type == pygame.MOUSEBUTTONDOWN and event.button == 1:
                towers.add(Tower(*event.pos))

        # Spawn one enemy every 60 frames.
        spawn_timer += 1
        if spawn_timer >= 60:
            enemy = Enemy(path)
            enemy.health += wave_health_bonus
            enemies.add(enemy)
            spawn_timer = 0

        # Update game logic. Towers add the bullets they fire to `bullets`.
        enemies.update()
        towers.update(enemies, bullets)
        bullets.update()

        # Retire enemies. Deaths are checked first, so an enemy destroyed on
        # the same frame it reaches the end of the path still scores.
        for enemy in enemies:
            if enemy.health <= 0:
                score += 100
                enemy.kill()
            elif enemy.current_point >= len(enemy.path) - 1:
                enemy.kill()

        # Draw the scene.
        screen.fill(COLOR_BG)
        pygame.draw.lines(screen, COLOR_PATH, False, path, 40)

        # Draw every group directly, so bullets (which exist only in
        # `bullets`) are rendered along with towers and enemies.
        towers.draw(screen)
        enemies.draw(screen)
        bullets.draw(screen)

        # HUD
        score_text = font.render(f"Score: {score}", True, white)
        screen.blit(score_text, (10, 10))
        screen.blit(hint_text, (SCREEN_WIDTH - 280, 10))

        pygame.display.flip()

if __name__ == "__main__":
    # Start the game only when this file is run directly, not when imported.
    main()