import pygame
import random
import sys
import os

# Initialization
pygame.init()
# The mixer needs an audio device. Without one, keep running silently: the
# sound-loading code further down then fails and falls back to no sound.
try:
    pygame.mixer.init()
except pygame.error as err:
    print(f"warning: audio disabled: {err}", file=sys.stderr)

# Window setup
WIDTH, HEIGHT = 960, 720
screen = pygame.display.set_mode((WIDTH, HEIGHT))
pygame.display.set_caption("顶配植物大战僵尸 - Python贴图版")
clock = pygame.time.Clock()
FPS = 60  # frames per second passed to clock.tick

# Grid geometry: ROWS x COLS lawn cells of CELL_SIZE px, top-left corner at
# (GRID_X, GRID_Y).
ROWS = 5
COLS = 9
CELL_SIZE = 80
GRID_X = 80
GRID_Y = 120

# Game states
STATE_MENU = 0
STATE_DIFF = 1
STATE_PLAY = 2
STATE_GAMEOVER = 3  # NOTE: unused; defeat is tracked with the game_over flag
STATE_WIN = 4
game_state = STATE_MENU

# Difficulty presets: frames between zombie spawns, and number of waves.
diff_cfg = {
    0: {"spawn": 200, "wave_num": 3},   # easy
    1: {"spawn": 140, "wave_num": 5},   # normal
    2: {"spawn": 80, "wave_num": 8},    # hard
}
diff_level = 1
# Derived from the active preset so these defaults cannot drift out of sync
# with diff_cfg.
zombie_spawn = diff_cfg[diff_level]["spawn"]
total_wave = diff_cfg[diff_level]["wave_num"]
cur_wave = 1
wave_timer = 0      # frames counted toward the next zombie spawn
wave_start = False  # whether the current wave has been started

# Mutable game state
sunlight = 50
plants = []
bullets = []
zombies = []
suns = []
explosions = []
selected_plant = None   # key of the picked seed card ("sunflower", "pea", ...) or None
game_over = False

# Fonts. SysFont searches a comma-separated list of names in order. SimHei is
# often missing outside Windows, so fall back to other common CJK fonts rather
# than rendering the Chinese UI text as empty boxes.
CJK_FONTS = ",".join((
    "simhei",
    "microsoftyahei",
    "notosanscjksc",
    "notosanscjk",
    "wenquanyimicrohei",
    "pingfangsc",
    "arialunicodems",
))
font_lg = pygame.font.SysFont(CJK_FONTS, 72)
font_md = pygame.font.SysFont(CJK_FONTS, 36)
font_sm = pygame.font.SysFont(CJK_FONTS, 24)

# ===================== Image loading =====================
def load_img(path, size=None):
    """Load an image with per-pixel alpha, optionally scaled.

    Args:
        path: image file path.
        size: optional (width, height); when truthy the image is scaled to it.

    Returns:
        The converted pygame Surface. Loading errors propagate to the caller.
    """
    surface = pygame.image.load(path).convert_alpha()
    return pygame.transform.scale(surface, size) if size else surface

def _load_img_or_none(path, size=None):
    """Load one image via load_img, returning None if it cannot be loaded.

    Each asset falls back on its own, so one missing file no longer discards
    every other image. The draw code checks each image for None and draws
    plain shapes or colours instead.
    """
    try:
        return load_img(path, size)
    except (pygame.error, OSError) as err:
        print(f"warning: could not load {path}: {err}", file=sys.stderr)
        return None


# Every name is always bound (sun_img used to be left undefined on failure,
# which made Sun.draw raise NameError).
bg_img = _load_img_or_none("images/bg.png", (WIDTH, HEIGHT))
sun_img = _load_img_or_none("images/sun.png", (40, 40))
sunflower_img = _load_img_or_none("images/sunflower.png", (CELL_SIZE, CELL_SIZE))
pea_img = _load_img_or_none("images/peashooter.png", (CELL_SIZE, CELL_SIZE))
ice_img = _load_img_or_none("images/icepea.png", (CELL_SIZE, CELL_SIZE))
nut_img = _load_img_or_none("images/wallnut.png", (CELL_SIZE, CELL_SIZE))
cherry_img = _load_img_or_none("images/cherry.png", (CELL_SIZE, CELL_SIZE))

zom_norm_img = _load_img_or_none("images/zombie_normal.png", (CELL_SIZE, CELL_SIZE))
zom_cone_img = _load_img_or_none("images/zombie_cone.png", (CELL_SIZE, CELL_SIZE))
zom_bucket_img = _load_img_or_none("images/zombie_bucket.png", (CELL_SIZE, CELL_SIZE))

# ===================== Sound effects and background music =====================
def _load_sound_or_none(path):
    """Return a pygame Sound for path, or None if it cannot be loaded.

    This also returns None when the mixer was never initialised (no audio
    device), because pygame raises pygame.error in that case.
    """
    try:
        return pygame.mixer.Sound(path)
    except (pygame.error, OSError) as err:
        print(f"warning: could not load {path}: {err}", file=sys.stderr)
        return None


# Background music is optional and independent of the effects: a missing
# bgm.mp3 no longer disables every sound effect.
try:
    pygame.mixer.music.load("sounds/bgm.mp3")
    pygame.mixer.music.play(-1)  # -1 = loop forever
    pygame.mixer.music.set_volume(0.6)
except (pygame.error, OSError) as err:
    print(f"warning: background music disabled: {err}", file=sys.stderr)

sound_sun = _load_sound_or_none("sounds/sun.wav")
sound_plant = _load_sound_or_none("sounds/plant.wav")
sound_explode = _load_sound_or_none("sounds/explode.wav")

# ===================== Game entity classes =====================
# Sun
class Sun:
    """A collectible sun worth `val` sunlight that falls a little every frame."""

    def __init__(self, x, y):
        self.x, self.y = x, y
        self.val = 25                          # sunlight awarded when clicked
        self.speed = 1.2                       # pixels fallen per frame
        self.rect = pygame.Rect(x, y, 40, 40)  # click hit box

    def update(self):
        """Fall by `speed` pixels and keep the hit box in sync."""
        self.rect.y = self.y = self.y + self.speed

    def draw(self):
        """Blit the sun sprite, or a yellow circle when the image is missing."""
        if not sun_img:
            centre = (self.x + 20, self.y + 20)
            pygame.draw.circle(screen, (255, 255, 0), centre, 18)
            return
        screen.blit(sun_img, (self.x, self.y))

# Bullet (pea)
class Bullet:
    """A pea moving right; peas with `ice` set are drawn blue."""

    def __init__(self, x, y, ice=False):
        self.x, self.y = x, y
        self.ice = ice
        self.speed = 7   # pixels moved per frame
        self.dmg = 20    # HP removed from the zombie it hits
        self.rect = pygame.Rect(x, y, 12, 6)

    def update(self):
        """Advance `speed` pixels to the right, moving the hit box along."""
        self.x = self.x + self.speed
        self.rect.x = self.x

    def draw(self):
        """Draw the pea as a blue (ice) or green (normal) ellipse."""
        if self.ice:
            colour = (0, 150, 255)
        else:
            colour = (0, 200, 0)
        pygame.draw.ellipse(screen, colour, self.rect)

# Plant base class
class Plant:
    """Shared plant state: top-left position, HP and a cell-sized hit box."""

    def __init__(self, x, y):
        self.x, self.y = x, y
        self.hp = 100  # default HP; subclasses may overwrite it
        self.rect = pygame.Rect(self.x, self.y, CELL_SIZE, CELL_SIZE)

# Sunflower: produces sun
class SunFlower(Plant):
    """Appends a new Sun to `suns` every `cd_max` frames."""
    cost = 50

    def __init__(self, x, y):
        super().__init__(x, y)
        self.cd_max = 120  # frames between suns
        self.prod_cd = 0   # frames elapsed since the last sun

    def update(self):
        """Advance the production timer; spawn a Sun above the plant when it fills."""
        self.prod_cd += 1
        if self.prod_cd < self.cd_max:
            return
        self.prod_cd = 0
        suns.append(Sun(self.x + 20, self.y - 20))

    def draw(self):
        """Blit the sprite, or an orange-yellow square when it is missing."""
        if not sunflower_img:
            pygame.draw.rect(screen, (255, 200, 0), self.rect)
        else:
            screen.blit(sunflower_img, (self.x, self.y))

# Peashooter
class PeaShooter(Plant):
    """Fires a normal pea every `cd_max` frames."""
    cost = 100

    def __init__(self, x, y):
        super().__init__(x, y)
        self.cd_max = 55  # frames between shots
        self.atk_cd = 0   # frames since the last shot

    def update(self):
        """Tick the reload timer; return a new Bullet when it fires, else None."""
        self.atk_cd += 1
        if self.atk_cd < self.cd_max:
            return None
        self.atk_cd = 0
        # Spawn at the plant's right edge, vertically centred in the cell.
        muzzle = (self.x + CELL_SIZE, self.y + CELL_SIZE // 2)
        return Bullet(*muzzle)

    def draw(self):
        """Blit the sprite, or a green square when it is missing."""
        if not pea_img:
            pygame.draw.rect(screen, (0, 180, 0), self.rect)
        else:
            screen.blit(pea_img, (self.x, self.y))

# Snow pea (ice shooter)
class IceShooter(Plant):
    """Fires an ice pea every `cd_max` frames."""
    cost = 175

    def __init__(self, x, y):
        super().__init__(x, y)
        self.cd_max = 65  # frames between shots
        self.atk_cd = 0   # frames since the last shot

    def update(self):
        """Tick the reload timer; return a new ice Bullet when it fires, else None."""
        self.atk_cd += 1
        if self.atk_cd < self.cd_max:
            return None
        self.atk_cd = 0
        # Spawn at the plant's right edge, vertically centred in the cell.
        muzzle = (self.x + CELL_SIZE, self.y + CELL_SIZE // 2)
        return Bullet(*muzzle, ice=True)

    def draw(self):
        """Blit the sprite, or a light-blue square when it is missing."""
        if not ice_img:
            pygame.draw.rect(screen, (100, 200, 255), self.rect)
        else:
            screen.blit(ice_img, (self.x, self.y))

# Wall-nut
class WallNut(Plant):
    """High-HP plant (500 HP) with no update behaviour of its own."""
    cost = 50

    def __init__(self, x, y):
        super().__init__(x, y)
        self.hp = 500

    def draw(self):
        """Blit the sprite, or a brown square when it is missing."""
        if not nut_img:
            pygame.draw.rect(screen, (150, 100, 50), self.rect)
            return
        screen.blit(nut_img, (self.x, self.y))

# Cherry bomb
class CherryBomb(Plant):
    """One-shot plant whose update() returns True exactly once, when its fuse ends."""
    cost = 150

    def __init__(self, x, y):
        super().__init__(x, y)
        self.count = 25        # fuse length in frames
        self.exploded = False

    def update(self):
        """Burn one frame of fuse; return True only on the frame it goes off."""
        if self.exploded:
            return False
        self.count -= 1
        if self.count <= 0:
            self.exploded = True
            return True
        return False

    def draw(self):
        """Blit the sprite, or a red circle centred in the cell when it is missing."""
        if cherry_img:
            screen.blit(cherry_img, (self.x, self.y))
        else:
            # Derive centre and radius from CELL_SIZE (40 and 30 at 80 px)
            # rather than hard-coding them, like the other plants' offsets.
            half = CELL_SIZE // 2
            radius = CELL_SIZE * 3 // 8
            pygame.draw.circle(screen, (255, 0, 0), (self.x + half, self.y + half), radius)

# Zombie base class
class Zombie:
    """Walks left along `row`; moves at 30% speed while `ice_slow` > 0."""

    def __init__(self, row):
        self.row = row
        self.x = WIDTH - 60                 # spawns near the right edge
        self.y = GRID_Y + row * CELL_SIZE   # top of its lawn row
        self.speed = 0.35                   # pixels per frame when not slowed
        self.hp = self.max_hp = 100
        self.ice_slow = 0                   # remaining slowed frames
        self.rect = pygame.Rect(self.x, self.y, CELL_SIZE, CELL_SIZE)

    def update(self):
        """Step left (slower while chilled) and count down the slow timer."""
        slowed = self.ice_slow > 0
        if slowed:
            self.ice_slow -= 1
        self.x -= self.speed * 0.3 if slowed else self.speed
        self.rect.x = self.x

    def draw_hp(self):
        """Draw a 60 px health bar 12 px above the zombie: red track, green fill."""
        bar_x, bar_y = self.x, self.y - 12
        fill_w = 60 * (self.hp / self.max_hp)
        pygame.draw.rect(screen, (255, 0, 0), (bar_x, bar_y, 60, 6))
        pygame.draw.rect(screen, (0, 255, 0), (bar_x, bar_y, fill_w, 6))

class NormalZom(Zombie):
    """Basic zombie with the base-class 100 HP."""

    def draw(self):
        """Draw the sprite (dark-red box if missing), then the health bar."""
        sprite = zom_norm_img
        if not sprite:
            pygame.draw.rect(screen, (100, 0, 0), self.rect)
        else:
            screen.blit(sprite, (self.x, self.y))
        self.draw_hp()

class ConeZom(Zombie):
    """Conehead zombie: 180 HP."""

    def __init__(self, row):
        super().__init__(row)
        self.max_hp = 180
        self.hp = self.max_hp

    def draw(self):
        """Draw the sprite (darker-red box if missing), then the health bar."""
        sprite = zom_cone_img
        if not sprite:
            pygame.draw.rect(screen, (80, 0, 0), self.rect)
        else:
            screen.blit(sprite, (self.x, self.y))
        self.draw_hp()

class BucketZom(Zombie):
    """Buckethead zombie: 320 HP."""

    def __init__(self, row):
        super().__init__(row)
        self.max_hp = 320
        self.hp = self.max_hp

    def draw(self):
        """Draw the sprite (darkest-red box if missing), then the health bar."""
        sprite = zom_bucket_img
        if not sprite:
            pygame.draw.rect(screen, (60, 0, 0), self.rect)
        else:
            screen.blit(sprite, (self.x, self.y))
        self.draw_hp()

# Explosion effect
class Explosion:
    """Expanding, fading orange circle centred on (x, y)."""

    def __init__(self, x, y):
        self.x = x
        self.y = y
        self.r = 20                 # current radius, grows 6 px per frame
        self.max_r = 130            # largest radius drawn (half the surface size)
        self.life = 25              # remaining frames
        self.max_life = self.life   # initial life, used for the alpha fade

    def update(self):
        """Grow the ring and consume one frame of life."""
        self.r += 6
        self.life -= 1

    def draw(self):
        """Draw the blast on a temporary per-pixel-alpha surface."""
        # Fade linearly to transparent; clamp so a call at life <= 0 cannot
        # produce an invalid negative alpha.
        alpha = max(0, min(255, int(255 * self.life / self.max_life)))
        # r reaches 20 + 6*25 = 170, beyond the (2*max_r)^2 surface, which
        # clipped the circle into a square; cap it at max_r.
        radius = min(self.r, self.max_r)
        side = self.max_r * 2
        surf = pygame.Surface((side, side), pygame.SRCALPHA)
        pygame.draw.circle(surf, (255, 80, 0, alpha), (self.max_r, self.max_r), radius)
        screen.blit(surf, (self.x - self.max_r, self.y - self.max_r))

# ===================== UI screens =====================
def draw_menu():
    """Draw the main menu and return its (start, quit) button rects.

    The returned rects are also used by the caller for click hit-testing.
    """
    if bg_img:
        screen.blit(bg_img, (0, 0))
    else:
        screen.fill((20, 60, 20))
    title = font_lg.render("植物大战僵尸", True, (255, 255, 0))
    # Centre text from its rendered size. The old fixed x-offsets only lined up
    # for one font's glyph widths.
    screen.blit(title, title.get_rect(midtop=(WIDTH // 2, 150)))
    btn_start = pygame.Rect(WIDTH // 2 - 130, 320, 260, 60)
    btn_quit = pygame.Rect(WIDTH // 2 - 130, 410, 260, 60)
    pygame.draw.rect(screen, (139, 69, 19), btn_start, border_radius=12)
    pygame.draw.rect(screen, (200, 0, 0), btn_quit, border_radius=12)
    for text, btn in (("开始游戏", btn_start), ("退出游戏", btn_quit)):
        label = font_md.render(text, True, (255, 255, 255))
        screen.blit(label, label.get_rect(center=btn.center))
    return btn_start, btn_quit

def draw_diff():
    """Draw the difficulty screen and return the (easy, normal, hard) button rects.

    The tuple order corresponds to difficulty levels 0, 1, 2.
    """
    if bg_img:
        screen.blit(bg_img, (0, 0))
    else:
        screen.fill((20, 60, 20))
    title = font_lg.render("选择难度", True, (255, 255, 0))
    # Centre text from its rendered size rather than fixed x-offsets.
    screen.blit(title, title.get_rect(midtop=(WIDTH // 2, 130)))
    options = (
        ("简单", (0, 180, 0), (255, 255, 255)),
        ("普通", (255, 255, 0), (0, 0, 0)),    # dark text on the yellow button
        ("困难", (255, 0, 0), (255, 255, 255)),
    )
    buttons = []
    for i, (text, fill, fg) in enumerate(options):
        btn = pygame.Rect(WIDTH // 2 - 120, 260 + 70 * i, 240, 55)
        pygame.draw.rect(screen, fill, btn, border_radius=10)
        label = font_md.render(text, True, fg)
        screen.blit(label, label.get_rect(center=btn.center))
        buttons.append(btn)
    return tuple(buttons)

def draw_card_bar():
    """Draw the seed-card column on the left: one coloured card per plant with its cost.

    Card y positions (130, 205, 280, 355, 430; 50 px tall, x 20..70) must stay
    in sync with the click hit boxes in the main loop.
    """
    pygame.draw.rect(screen, (40, 40, 40), (15, 120, 60, 420), border_radius=8)
    # (cost, card top y, card colour, cost text colour). Card colours match each
    # plant's placeholder colour. The sunflower card used to be the same green
    # as the peashooter card, so the two were indistinguishable.
    cards = (
        (SunFlower.cost, 130, (255, 200, 0), (0, 0, 0)),
        (PeaShooter.cost, 205, (0, 180, 0), (255, 255, 255)),
        (IceShooter.cost, 280, (100, 200, 255), (255, 255, 255)),
        (WallNut.cost, 355, (150, 100, 50), (255, 255, 255)),
        (CherryBomb.cost, 430, (255, 0, 0), (255, 255, 255)),
    )
    for cost, y, card_colour, text_colour in cards:
        pygame.draw.rect(screen, card_colour, (20, y, 50, 50), border_radius=6)
        screen.blit(font_sm.render(str(cost), True, text_colour), (25, y + 12))

def get_grid_pos(mx, my):
    """Map a screen point to a lawn cell.

    Args:
        mx, my: screen coordinates in pixels.

    Returns:
        (row, col) of the cell containing the point, or None if the point is
        outside the ROWS x COLS grid.
    """
    # Half-open bounds [start, end): the grid's first pixel column/row belongs
    # to cell 0. A strict '<' on the left rejected it.
    inside_x = GRID_X <= mx < GRID_X + COLS * CELL_SIZE
    inside_y = GRID_Y <= my < GRID_Y + ROWS * CELL_SIZE
    if not (inside_x and inside_y):
        return None
    return (my - GRID_Y) // CELL_SIZE, (mx - GRID_X) // CELL_SIZE

def has_plant(r, c):
    """Return True if some plant already sits at the top-left of cell (r, c)."""
    cell_x = GRID_X + c * CELL_SIZE
    cell_y = GRID_Y + r * CELL_SIZE
    return any(p.x == cell_x and p.y == cell_y for p in plants)

# ===================== Main loop =====================
# Zombies spawned per wave. wave_timer is reset on every spawn, so it can never
# be used to detect the end of a wave. Count spawned zombies instead: a wave
# ends once all of its zombies have spawned and died.
ZOMBIES_PER_WAVE = 8
wave_spawned = 0  # zombies spawned so far in the current wave

# Seed-card hit boxes: (top y of the 50 px card, plant key); x spans 20..70.
CARD_SLOTS = (
    (130, "sunflower"),
    (205, "pea"),
    (280, "ice"),
    (355, "nut"),
    (430, "cherry"),
)
# Plant key -> class, used for the affordability check and for planting.
PLANT_TYPES = {
    "sunflower": SunFlower,
    "pea": PeaShooter,
    "ice": IceShooter,
    "nut": WallNut,
    "cherry": CherryBomb,
}
CHEW_DAMAGE = 0.25  # plant HP removed per frame by a zombie eating it

running = True
while running:
    # ---------------- input ----------------
    for e in pygame.event.get():
        if e.type == pygame.QUIT:
            running = False
            continue
        if e.type != pygame.MOUSEBUTTONDOWN:
            continue
        mx, my = e.pos  # where this click happened
        if e.button == 3:
            selected_plant = None  # right click cancels the picked card
            continue
        if e.button != 1:
            continue  # ignore middle click and wheel "buttons" 4/5

        if game_state == STATE_MENU:
            btn_start, btn_quit = draw_menu()
            if btn_start.collidepoint(mx, my):
                game_state = STATE_DIFF
            elif btn_quit.collidepoint(mx, my):
                running = False
        elif game_state == STATE_DIFF:
            chosen = None
            for level, btn in enumerate(draw_diff()):
                if btn.collidepoint(mx, my):
                    chosen = level
            if chosen is None:
                continue  # a click outside the buttons must not start a game
            # Start a fresh game with the chosen difficulty.
            diff_level = chosen
            zombie_spawn = diff_cfg[diff_level]["spawn"]
            total_wave = diff_cfg[diff_level]["wave_num"]
            cur_wave = 1
            wave_timer = 0
            wave_spawned = 0
            wave_start = False
            plants.clear()
            zombies.clear()
            bullets.clear()
            suns.clear()
            explosions.clear()
            sunlight = 50
            selected_plant = None
            game_over = False
            game_state = STATE_PLAY
        elif game_state == STATE_PLAY and not game_over:
            # Collect every sun under the cursor. A click that collects a sun
            # does nothing else, so it cannot also plant beneath the sun.
            hit_suns = [s for s in suns if s.rect.collidepoint(mx, my)]
            if hit_suns:
                for s in hit_suns:
                    sunlight += s.val
                    suns.remove(s)
                    if sound_sun:
                        sound_sun.play()
                continue
            # Pick a seed card if it is affordable.
            if 20 < mx < 70:
                for top, key in CARD_SLOTS:
                    if top < my < top + 50:
                        if sunlight >= PLANT_TYPES[key].cost:
                            selected_plant = key
                        break
            # Plant the selected card on an empty cell (cost re-checked here).
            pos = get_grid_pos(mx, my)
            if pos and selected_plant:
                r, c = pos
                plant_cls = PLANT_TYPES[selected_plant]
                if not has_plant(r, c) and sunlight >= plant_cls.cost:
                    plants.append(plant_cls(GRID_X + c * CELL_SIZE, GRID_Y + r * CELL_SIZE))
                    sunlight -= plant_cls.cost
                    if sound_plant:
                        sound_plant.play()
                    selected_plant = None
        elif game_state in (STATE_PLAY, STATE_WIN):
            # Defeat or victory screen: any left click returns to the main menu.
            game_state = STATE_MENU

    # ---------------- update & draw ----------------
    if game_state == STATE_MENU:
        draw_menu()
    elif game_state == STATE_DIFF:
        draw_diff()
    elif game_state == STATE_PLAY:
        if bg_img:
            screen.blit(bg_img, (0, 0))
        else:
            screen.fill((30, 80, 20))
        draw_card_bar()
        # HUD: sunlight and wave counter
        screen.blit(font_md.render(f"阳光: {sunlight}", True, (255, 255, 0)), (80, 30))
        screen.blit(font_md.render(f"波次: {cur_wave}/{total_wave}", True, (255, 255, 255)), (300, 30))

        if not game_over:
            # Wave system: spawn ZOMBIES_PER_WAVE zombies, one every
            # zombie_spawn frames. The wave ends when all of them are dead.
            if not wave_start:
                wave_start = True
                wave_timer = 0
                wave_spawned = 0
            wave_timer += 1
            if wave_spawned < ZOMBIES_PER_WAVE and wave_timer > zombie_spawn:
                wave_timer = 0
                wave_spawned += 1
                row = random.randrange(ROWS)
                rd = random.random()
                if rd < 0.55:
                    zombies.append(NormalZom(row))
                elif rd < 0.85:
                    zombies.append(ConeZom(row))
                else:
                    zombies.append(BucketZom(row))
            if wave_spawned >= ZOMBIES_PER_WAVE and not zombies:
                cur_wave += 1
                wave_start = False
                if cur_wave > total_wave:
                    game_state = STATE_WIN

            # Suns fall. Drop the ones below the screen so the list stays bounded.
            for s in suns[:]:
                s.update()
                if s.y > HEIGHT:
                    suns.remove(s)

            # Zombies: eat the first plant they touch, otherwise walk. Stopping
            # while eating is what lets plants (e.g. the wall-nut) hold a lane.
            for z in zombies:
                target = next((p for p in plants if z.rect.colliderect(p.rect)), None)
                if target is None:
                    z.update()
                else:
                    target.hp -= CHEW_DAMAGE
                    if target.hp <= 0:
                        plants.remove(target)
                    if z.ice_slow > 0:
                        z.ice_slow -= 1  # the slow wears off even while eating
                if z.x < GRID_X - 30:
                    game_over = True

            # Plants. Iterate a copy because exploding cherries remove themselves.
            for p in plants[:]:
                if isinstance(p, SunFlower):
                    p.update()
                elif isinstance(p, (PeaShooter, IceShooter)):
                    b = p.update()
                    if b:
                        bullets.append(b)
                elif isinstance(p, CherryBomb) and p.update():
                    half = CELL_SIZE // 2
                    explosions.append(Explosion(p.x + half, p.y + half))
                    if sound_explode:
                        sound_explode.play()
                    for z in zombies[:]:
                        if abs(z.x - p.x) < 130 and abs(z.y - p.y) < 130:
                            z.hp -= 160
                            if z.hp <= 0:
                                zombies.remove(z)
                    plants.remove(p)

            # Bullets: move, drop off-screen ones, hit at most one zombie each.
            for b in bullets[:]:
                b.update()
                if b.x > WIDTH:
                    bullets.remove(b)
                    continue
                for z in zombies:
                    if b.rect.colliderect(z.rect):
                        z.hp -= b.dmg
                        if b.ice:
                            z.ice_slow = 70
                        bullets.remove(b)
                        if z.hp <= 0:
                            zombies.remove(z)
                        break  # stop iterating right after any removal

            # Explosion effects
            for ex in explosions[:]:
                ex.update()
                if ex.life <= 0:
                    explosions.remove(ex)

        # Draw everything
        for s in suns:
            s.draw()
        for p in plants:
            p.draw()
        for b in bullets:
            b.draw()
        for z in zombies:
            z.draw()
        for ex in explosions:
            ex.draw()

        if game_over:
            txt = font_lg.render("游戏失败", True, (255, 0, 0))
            screen.blit(txt, (WIDTH // 2 - 160, HEIGHT // 2))
            hint = font_sm.render("点击任意处返回菜单", True, (255, 255, 255))
            screen.blit(hint, hint.get_rect(midtop=(WIDTH // 2, HEIGHT // 2 + 90)))
    elif game_state == STATE_WIN:
        if bg_img:
            screen.blit(bg_img, (0, 0))
        else:
            screen.fill((20, 60, 20))  # clear the last play frame
        txt = font_lg.render("恭喜通关!", True, (0, 255, 0))
        screen.blit(txt, (WIDTH // 2 - 180, HEIGHT // 2))
        hint = font_sm.render("点击任意处返回菜单", True, (255, 255, 255))
        screen.blit(hint, hint.get_rect(midtop=(WIDTH // 2, HEIGHT // 2 + 90)))

    pygame.display.flip()
    clock.tick(FPS)

pygame.quit()
sys.exit()