import pygame
import math
import random

# --- Initialisation ---
# Runs at import time: initialises pygame, then creates the window, the frame
# clock and the fonts that the classes and main() below rely on.
pygame.init()
WIDTH, HEIGHT = 800, 600  # window size in pixels; also used as the play-field bounds
screen = pygame.display.set_mode((WIDTH, HEIGHT))
pygame.display.set_caption("图形工厂枪战 - Geometry Factory")  # caption reads "Geometry Factory shoot-out"
clock = pygame.time.Clock()  # main() calls clock.tick(60) to cap the frame rate

# Cyberpunk colour palette (RGB tuples)
BG_COLOR = (20, 20, 30)
GRID_COLOR = (40, 40, 60)
CORE_COLOR = (0, 255, 200)
PLAYER_COLOR = (255, 255, 255)
LASER_COLOR = (255, 50, 50)
UI_COLOR = (0, 255, 200)

# HUD fonts. Created after pygame.init() because SysFont needs the font module
# initialised; pygame falls back to its default font if "consolas" is missing.
FONT = pygame.font.SysFont("consolas", 24)
BIG_FONT = pygame.font.SysFont("consolas", 60)

# --- Class definitions ---

class Laser:
    """A diamond-shaped bullet fired by the player.

    The bullet travels 15 px per update along the firing ``angle`` (radians).
    It marks itself dead (``alive = False``) as soon as its 8x8 hit box is no
    longer entirely inside the window.
    """

    def __init__(self, x, y, angle):
        self.x, self.y = x, y
        self.speed = 15
        # Velocity is fixed once, at spawn time, from the firing direction.
        heading_x, heading_y = math.cos(angle), math.sin(angle)
        self.vx = heading_x * self.speed
        self.vy = heading_y * self.speed
        # Square hit box of side 2 * half, centred on the spawn point.
        half = 4
        self.rect = pygame.Rect(x - half, y - half, 2 * half, 2 * half)
        self.alive = True

    def update(self):
        """Advance one step and flag the bullet dead once it leaves the screen."""
        self.x, self.y = self.x + self.vx, self.y + self.vy
        self.rect.center = (self.x, self.y)
        on_screen = screen.get_rect().contains(self.rect)
        if not on_screen:
            self.alive = False

    def draw(self, surface):
        """Draw the bullet as a filled diamond 8 px wide and 12 px tall."""
        cx, cy = self.x, self.y
        # Vertices in order: top, right, bottom, left.
        outline = [(cx, cy - 6), (cx + 4, cy), (cx, cy + 6), (cx - 4, cy)]
        pygame.draw.polygon(surface, LASER_COLOR, outline)

class GeoEnemy:
    """A geometric enemy that homes in on a target point while spinning.

    ``shape_type`` selects the outline: 'tri' (orange triangle), 'rect'
    (blue square) or anything else (yellow circle). The collision ``rect`` is
    always an axis-aligned 2*size x 2*size box centred on the enemy. It does
    not rotate with the drawn outline.
    """

    # Outline colour per shape; any other shape (including 'circle') uses _DEFAULT_COLOR.
    _SHAPE_COLORS = {'tri': (255, 100, 50), 'rect': (100, 100, 255)}
    _DEFAULT_COLOR = (255, 255, 50)
    _SPIN_PER_FRAME = 3  # degrees added to ``angle`` on every update

    def __init__(self, x, y, shape_type):
        self.x = x
        self.y = y
        self.shape = shape_type  # 'tri', 'rect' or 'circle'
        # Per-enemy speed, uniformly drawn from [1.2, 2.0) px per update.
        self.speed = 1.2 + random.random() * 0.8
        self.size = 20
        self.rect = pygame.Rect(x - self.size, y - self.size, self.size * 2, self.size * 2)
        self.alive = True
        self.angle = 0  # current spin in degrees, applied by draw()

    def update(self, target_x, target_y):
        """Move ``speed`` pixels toward (target_x, target_y) and advance the spin."""
        dx = target_x - self.x
        dy = target_y - self.y
        dist = math.hypot(dx, dy)
        if dist > 0:  # already on the target: no direction to move in
            self.x += (dx / dist) * self.speed
            self.y += (dy / dist) * self.speed
        self.rect.center = (self.x, self.y)
        # Wrap into [0, 360) so the counter stays bounded over a long game.
        self.angle = (self.angle + self._SPIN_PER_FRAME) % 360

    def _outline_points(self):
        """Return the vertices of the 'tri' or 'rect' outline, rotated by ``angle`` degrees about (x, y).

        At angle 0 the triangle has its apex ``size`` above the centre and its
        base ``size`` below it. The square has side ``1.3 * size`` and is
        centred exactly on (x, y).
        """
        s = self.size
        if self.shape == 'tri':
            local = [(0, -s), (-s, s), (s, s)]
        else:
            h = s * 1.3 / 2  # half the square's side
            local = [(-h, -h), (h, -h), (h, h), (-h, h)]
        rad = math.radians(self.angle)
        cos_a, sin_a = math.cos(rad), math.sin(rad)
        return [
            (self.x + px * cos_a - py * sin_a, self.y + px * sin_a + py * cos_a)
            for px, py in local
        ]

    def draw(self, surface):
        """Draw the enemy's 3 px outline; triangles and squares are drawn rotated."""
        color = self._SHAPE_COLORS.get(self.shape, self._DEFAULT_COLOR)
        if self.shape in ('tri', 'rect'):
            pygame.draw.polygon(surface, color, self._outline_points(), 3)
        else:  # circle: rotation would be invisible
            pygame.draw.circle(surface, color, (int(self.x), int(self.y)), self.size, 3)

class Player:
    """The player's hexagonal ship.

    The ship moves with WASD, clamped to a 40 px margin inside the window. It
    aims at the mouse cursor and fires ``Laser`` shots with an 8-frame cooldown.
    """

    def __init__(self):
        self.x = WIDTH // 2
        self.y = HEIGHT // 2
        self.speed = 5  # pixels per update, in any direction
        self.rect = pygame.Rect(self.x - 15, self.y - 15, 30, 30)
        self.angle = 0  # aim direction in radians (from atan2)
        self.cooldown = 0  # updates remaining before the next shot is allowed

    @staticmethod
    def _move_vector(up, down, left, right, speed):
        """Return the (dx, dy) step for the pressed direction keys.

        Opposite keys cancel out. A diagonal step is scaled by 1/sqrt(2) so
        its length is ``speed``, the same as a straight step. Without this,
        diagonal movement would be about 41% faster.
        """
        dx = int(right) - int(left)
        dy = int(down) - int(up)
        step = speed / math.sqrt(2) if dx and dy else speed
        return dx * step, dy * step

    def update(self, mouse_pos):
        """Apply WASD movement, clamp to the play field, aim at the mouse and tick the cooldown."""
        keys = pygame.key.get_pressed()
        dx, dy = self._move_vector(
            keys[pygame.K_w], keys[pygame.K_s], keys[pygame.K_a], keys[pygame.K_d], self.speed
        )
        self.x += dx
        self.y += dy

        # Keep the player inside the factory floor.
        margin = 40
        self.x = max(margin, min(WIDTH - margin, self.x))
        self.y = max(margin, min(HEIGHT - margin, self.y))
        self.rect.center = (self.x, self.y)

        # Aim toward the mouse cursor.
        self.angle = math.atan2(mouse_pos[1] - self.y, mouse_pos[0] - self.x)

        if self.cooldown > 0:
            self.cooldown -= 1

    def shoot(self, lasers):
        """Append a new Laser to ``lasers`` if the cooldown has expired."""
        if self.cooldown <= 0:
            # Spawn the bullet 30 px out along the aim direction, near the gun tip.
            bx = self.x + math.cos(self.angle) * 30
            by = self.y + math.sin(self.angle) * 30
            lasers.append(Laser(bx, by, self.angle))
            self.cooldown = 8  # fire rate: at most one shot every 8 updates

    def draw(self, surface):
        """Draw the hexagon body, rotated to the aim angle, and the gun barrel."""
        # Body: hexagon of radius 18 with one vertex on the aim direction.
        points = []
        for i in range(6):
            angle = self.angle + i * math.pi / 3
            px = self.x + math.cos(angle) * 18
            py = self.y + math.sin(angle) * 18
            points.append((px, py))
        pygame.draw.polygon(surface, PLAYER_COLOR, points, 2)

        # Gun barrel: 35 px line from the centre along the aim direction.
        gun_end_x = self.x + math.cos(self.angle) * 35
        gun_end_y = self.y + math.sin(self.angle) * 35
        pygame.draw.line(surface, PLAYER_COLOR, (self.x, self.y), (gun_end_x, gun_end_y), 3)

# --- Main program ---

def _random_edge_position(width, height, offset=30):
    """Return a random (x, y) lying ``offset`` px outside one edge of a width x height field."""
    side = random.choice(['top', 'bottom', 'left', 'right'])
    if side == 'top':
        return random.randint(0, width), -offset
    if side == 'bottom':
        return random.randint(0, width), height + offset
    if side == 'left':
        return -offset, random.randint(0, height)
    return width + offset, random.randint(0, height)


def _draw_grid(surface):
    """Draw the 40 px background grid across the whole window."""
    for x in range(0, WIDTH, 40):
        pygame.draw.line(surface, GRID_COLOR, (x, 0), (x, HEIGHT))
    for y in range(0, HEIGHT, 40):
        pygame.draw.line(surface, GRID_COLOR, (0, y), (WIDTH, y))


def _draw_hud(surface, core_hp, score, game_over):
    """Draw the HP and score read-outs, plus the game-over banner when ``game_over``."""
    hp_txt = FONT.render(f"CORE INTEGRITY: {max(0, core_hp)}%", True, UI_COLOR)
    score_txt = FONT.render(f"OUTPUT SCORE: {score}", True, UI_COLOR)
    surface.blit(hp_txt, (10, 10))
    surface.blit(score_txt, (WIDTH - 250, 10))

    if game_over:
        over_txt = BIG_FONT.render("FACTORY BREACH", True, (255, 50, 50))
        surface.blit(over_txt, over_txt.get_rect(center=(WIDTH // 2, HEIGHT // 2 - 30)))
        restart_txt = FONT.render("Press R to Reboot System", True, UI_COLOR)
        surface.blit(restart_txt, restart_txt.get_rect(center=(WIDTH // 2, HEIGHT // 2 + 30)))


def main():
    """Run the game loop at 60 FPS until the window is closed, then shut pygame down.

    Each frame does three things:

    * Handle input: left click shoots; R reboots after game over.
    * Update the world while the game is running: spawn enemies, move lasers
      and enemies, resolve hits.
    * Redraw everything.

    Pressing R resets all per-game state inside this loop. The original
    version called ``main()`` recursively, which added a stack frame and
    kept the finished game's state alive on every reboot.
    """
    running = True
    restart = True
    while running:
        if restart:
            # (Re)initialise everything that belongs to one play-through.
            player = Player()
            lasers = []
            enemies = []
            score = 0
            spawn_timer = 0
            game_over = False
            core_hp = 100
            restart = False

        clock.tick(60)
        mouse_pos = pygame.mouse.get_pos()

        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                running = False
            elif event.type == pygame.MOUSEBUTTONDOWN and not game_over:
                if event.button == 1:
                    player.shoot(lasers)
            elif event.type == pygame.KEYDOWN and game_over and event.key == pygame.K_r:
                restart = True
                break
        if restart:
            continue  # start the new game on the next iteration

        if not game_over:
            player.update(mouse_pos)

            # Spawn an enemy just off-screen; the interval shrinks as the score grows.
            spawn_timer += 1
            spawn_interval = max(20, 60 - score // 100)
            if spawn_timer > spawn_interval:
                spawn_timer = 0
                ex, ey = _random_edge_position(WIDTH, HEIGHT)
                shape = random.choice(['tri', 'rect', 'circle'])
                enemies.append(GeoEnemy(ex, ey, shape))

            # Move lasers and drop those that left the screen.
            for laser in lasers:
                laser.update()
            lasers = [laser for laser in lasers if laser.alive]

            # Move enemies. A laser destroys at most one enemy; an enemy that
            # survives and touches the player damages the core instead.
            for enemy in enemies:
                enemy.update(player.x, player.y)
                for laser in lasers:
                    if laser.alive and laser.rect.colliderect(enemy.rect):
                        enemy.alive = False
                        laser.alive = False
                        score += 10
                        break
                if enemy.alive and enemy.rect.colliderect(player.rect):
                    core_hp -= 10
                    enemy.alive = False
                    if core_hp <= 0:
                        game_over = True
            lasers = [laser for laser in lasers if laser.alive]
            enemies = [enemy for enemy in enemies if enemy.alive]

        # --- Drawing ---
        screen.fill(BG_COLOR)
        _draw_grid(screen)

        # Factory core rings, drawn around the player's position.
        pygame.draw.circle(screen, CORE_COLOR, (player.x, player.y), 40, 2)
        pygame.draw.circle(screen, CORE_COLOR, (player.x, player.y), 30, 1)

        player.draw(screen)
        for enemy in enemies:
            enemy.draw(screen)
        for laser in lasers:
            laser.draw(screen)

        _draw_hud(screen, core_hp, score, game_over)
        pygame.display.flip()

    pygame.quit()


if __name__ == "__main__":
    main()