124 lines
5.1 KiB
Python
124 lines
5.1 KiB
Python
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import logging
|
|
from pathlib import Path
|
|
|
|
import httpx
|
|
|
|
from .config import ComposeConfig
|
|
from .models import CommandResult, HealthcheckResult
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class ComposeManager:
    """Manage the env file of a compose deployment and run its commands.

    The manager rewrites the env file with new values (keeping a backup so a
    failed upgrade can be rolled back), runs the configured pull/up commands
    inside the working directory, and polls an optional HTTP health-check URL.
    """

    def __init__(self, config: ComposeConfig) -> None:
        """Resolve all paths from *config* and create the working directory.

        Args:
            config: Settings supplying ``working_dir``, ``file``, ``env_file``
                and ``backup_env_file``; the other fields are read later by
                the individual methods.
        """
        self.config = config
        self.working_dir = Path(config.working_dir)
        # Created eagerly so later file writes and subprocess ``cwd`` succeed.
        self.working_dir.mkdir(parents=True, exist_ok=True)
        self.compose_file_path = self.working_dir / config.file
        self.env_path = self.working_dir / config.env_file
        self.backup_env_path = self.working_dir / config.backup_env_file

    async def apply_manifest(self, env_mapping: dict[str, str]) -> None:
        """Back up the current env file, then write it merged with *env_mapping*.

        When the env file does not exist yet, no backup is written and the new
        file contains only the entries of *env_mapping*.

        Args:
            env_mapping: Variables to add or override in the env file.
        """
        # Read once and branch on the outcome instead of calling ``exists()``
        # twice: the file could appear or vanish between separate checks.
        try:
            current_text = self.env_path.read_text(encoding="utf-8")
        except FileNotFoundError:
            current_text = ""
        else:
            self.backup_env_path.write_text(current_text, encoding="utf-8")
            logger.info("已备份镜像环境文件: %s", self.backup_env_path)

        env_lines = self._merge_env(current_text, env_mapping)
        self.env_path.write_text("\n".join(env_lines) + "\n", encoding="utf-8")
        logger.info("已写入新镜像配置到 %s", self.env_path)

    async def rollback(self) -> None:
        """Restore the env file from its backup (if any), then pull and up again.

        The pull/up commands run even when no backup exists, and their results
        are not inspected here; ``run_command`` logs any failure.
        """
        if self.backup_env_path.exists():
            backup_text = self.backup_env_path.read_text(encoding="utf-8")
            self.env_path.write_text(backup_text, encoding="utf-8")
            logger.warning("升级失败,已恢复镜像环境文件: %s", self.env_path)
        await self.pull()
        await self.up()

    async def pull(self) -> CommandResult:
        """Run the configured pull command and return its result."""
        return await self.run_command(self._build_command(self.config.pull_command))

    async def up(self) -> CommandResult:
        """Run the configured up command and return its result."""
        return await self.run_command(self._build_command(self.config.up_command))

    async def health_check(self) -> HealthcheckResult:
        """Poll the health-check URL until it returns 2xx or the deadline passes.

        Returns:
            A successful result immediately when no URL is configured or as
            soon as a 2xx response is seen; otherwise a failed result carrying
            the detail of the last attempt.
        """
        if not self.config.healthcheck_url:
            return HealthcheckResult(success=True, detail="healthcheck skipped")

        loop = asyncio.get_running_loop()
        interval = self.config.healthcheck_interval_seconds
        total_seconds = self.config.health_check_seconds
        deadline = loop.time() + total_seconds
        # A single request may never take longer than one polling interval,
        # nor longer than the whole health-check window.
        timeout = httpx.Timeout(min(interval, total_seconds))
        last_detail = "healthcheck not started"
        async with httpx.AsyncClient(timeout=timeout) as client:
            while True:
                try:
                    response = await client.get(self.config.healthcheck_url)
                    if 200 <= response.status_code < 300:
                        logger.info("健康检查通过: %s", response.status_code)
                        return HealthcheckResult(success=True, detail=f"healthcheck ok: {response.status_code}")
                    last_detail = f"healthcheck failed: {response.status_code} {response.text}"
                except Exception as exc:
                    # Any failure of an attempt just means "not healthy yet";
                    # keep polling until the deadline.
                    last_detail = f"healthcheck error: {exc}"

                remaining = deadline - loop.time()
                if remaining <= 0:
                    logger.error("健康检查失败: %s", last_detail)
                    return HealthcheckResult(success=False, detail=last_detail)
                # Never sleep past the deadline: a full interval here could
                # overrun the configured window by up to one interval.
                await asyncio.sleep(min(interval, remaining))

    async def run_command(self, command: list[str]) -> CommandResult:
        """Execute *command* in the working directory and capture its output.

        Args:
            command: Program and arguments, already split into tokens.

        Returns:
            A result whose ``success`` is true when the exit code is 0, with
            stdout/stderr decoded as UTF-8 (undecodable bytes are dropped).
        """
        logger.info("执行命令: %s", " ".join(command))
        process = await asyncio.create_subprocess_exec(
            *command,
            cwd=str(self.working_dir),
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )
        stdout, stderr = await process.communicate()
        result = CommandResult(
            success=process.returncode == 0,
            stdout=stdout.decode("utf-8", errors="ignore"),
            stderr=stderr.decode("utf-8", errors="ignore"),
            returncode=process.returncode or 0,
        )
        if result.success:
            logger.info("命令执行成功(returncode=%s)", result.returncode)
        else:
            logger.error("命令执行失败(returncode=%s): %s", result.returncode, result.stderr or result.stdout)
        return result

    def _build_command(self, command: str) -> list[str]:
        """Tokenize *command*, injecting ``-f``/``--env-file`` for docker compose.

        Commands starting with ``docker compose`` get the managed compose file
        and env file inserted right after those two tokens; any other command
        is returned as its plain token list.

        Args:
            command: Command line as a single string; shell-style quoting is
                honoured so a quoted argument stays one token.

        Returns:
            The argument list to execute (empty for an empty command).
        """
        import shlex  # stdlib; only this method needs it

        try:
            parts = shlex.split(command)
        except ValueError:
            # Unbalanced quotes: keep the previous whitespace-only behaviour
            # rather than failing on a command that used to be accepted.
            logger.warning("命令解析失败,回退为按空白分割: %s", command)
            parts = command.split()

        if parts[:2] == ["docker", "compose"]:
            return [
                "docker",
                "compose",
                "-f",
                str(self.compose_file_path),
                "--env-file",
                str(self.env_path),
                *parts[2:],
            ]
        return parts

    def _merge_env(self, current_text: str, env_mapping: dict[str, str]) -> list[str]:
        """Merge ``KEY=VALUE`` lines of *current_text* with *env_mapping*.

        Blank lines, ``#`` comments and lines without ``=`` are dropped. Keys
        keep the position of their first occurrence; a later duplicate and any
        entry of *env_mapping* overrides the value. Keys that appear only in
        *env_mapping* are appended in its order.

        Args:
            current_text: Existing env file content (may be empty).
            env_mapping: Variables to add or override.

        Returns:
            The merged file as a list of ``KEY=VALUE`` lines.
        """
        # A dict preserves first-insertion order, so it tracks both the
        # values and the output order without a separate key list.
        merged: dict[str, str] = {}
        for raw_line in current_text.splitlines():
            line = raw_line.strip()
            if not line or line.startswith("#") or "=" not in line:
                continue
            key, _, value = line.partition("=")
            merged[key.strip()] = value.strip()
        merged.update(env_mapping)
        return [f"{key}={value}" for key, value in merged.items()]
|