Changes Initial commit
This commit is contained in:
123
ota_agent/compose_manager.py
Normal file
123
ota_agent/compose_manager.py
Normal file
@@ -0,0 +1,123 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import httpx
|
||||
|
||||
from .config import ComposeConfig
|
||||
from .models import CommandResult, HealthcheckResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ComposeManager:
|
||||
def __init__(self, config: ComposeConfig) -> None:
|
||||
self.config = config
|
||||
self.working_dir = Path(config.working_dir)
|
||||
self.working_dir.mkdir(parents=True, exist_ok=True)
|
||||
self.compose_file_path = self.working_dir / config.file
|
||||
self.env_path = self.working_dir / config.env_file
|
||||
self.backup_env_path = self.working_dir / config.backup_env_file
|
||||
|
||||
async def apply_manifest(self, env_mapping: dict[str, str]) -> None:
|
||||
current_text = self.env_path.read_text(encoding="utf-8") if self.env_path.exists() else ""
|
||||
if self.env_path.exists():
|
||||
self.backup_env_path.write_text(current_text, encoding="utf-8")
|
||||
logger.info("已备份镜像环境文件: %s", self.backup_env_path)
|
||||
|
||||
env_lines = self._merge_env(current_text, env_mapping)
|
||||
self.env_path.write_text("\n".join(env_lines) + "\n", encoding="utf-8")
|
||||
logger.info("已写入新镜像配置到 %s", self.env_path)
|
||||
|
||||
async def rollback(self) -> None:
|
||||
if self.backup_env_path.exists():
|
||||
self.env_path.write_text(self.backup_env_path.read_text(encoding="utf-8"), encoding="utf-8")
|
||||
logger.warning("升级失败,已恢复镜像环境文件: %s", self.env_path)
|
||||
await self.pull()
|
||||
await self.up()
|
||||
|
||||
async def pull(self) -> CommandResult:
|
||||
return await self.run_command(self._build_command(self.config.pull_command))
|
||||
|
||||
async def up(self) -> CommandResult:
|
||||
return await self.run_command(self._build_command(self.config.up_command))
|
||||
|
||||
async def health_check(self) -> HealthcheckResult:
|
||||
if not self.config.healthcheck_url:
|
||||
return HealthcheckResult(success=True, detail="healthcheck skipped")
|
||||
|
||||
deadline = asyncio.get_running_loop().time() + self.config.health_check_seconds
|
||||
timeout = httpx.Timeout(min(self.config.healthcheck_interval_seconds, self.config.health_check_seconds))
|
||||
last_detail = "healthcheck not started"
|
||||
async with httpx.AsyncClient(timeout=timeout) as client:
|
||||
while True:
|
||||
try:
|
||||
response = await client.get(self.config.healthcheck_url)
|
||||
if 200 <= response.status_code < 300:
|
||||
logger.info("健康检查通过: %s", response.status_code)
|
||||
return HealthcheckResult(success=True, detail=f"healthcheck ok: {response.status_code}")
|
||||
last_detail = f"healthcheck failed: {response.status_code} {response.text}"
|
||||
except Exception as exc:
|
||||
last_detail = f"healthcheck error: {exc}"
|
||||
|
||||
if asyncio.get_running_loop().time() >= deadline:
|
||||
logger.error("健康检查失败: %s", last_detail)
|
||||
return HealthcheckResult(success=False, detail=last_detail)
|
||||
await asyncio.sleep(self.config.healthcheck_interval_seconds)
|
||||
|
||||
async def run_command(self, command: list[str]) -> CommandResult:
|
||||
logger.info("执行命令: %s", " ".join(command))
|
||||
process = await asyncio.create_subprocess_exec(
|
||||
*command,
|
||||
cwd=str(self.working_dir),
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
stdout, stderr = await process.communicate()
|
||||
result = CommandResult(
|
||||
success=process.returncode == 0,
|
||||
stdout=stdout.decode("utf-8", errors="ignore"),
|
||||
stderr=stderr.decode("utf-8", errors="ignore"),
|
||||
returncode=process.returncode or 0,
|
||||
)
|
||||
if result.success:
|
||||
logger.info("命令执行成功(returncode=%s)", result.returncode)
|
||||
else:
|
||||
logger.error("命令执行失败(returncode=%s): %s", result.returncode, result.stderr or result.stdout)
|
||||
return result
|
||||
|
||||
def _build_command(self, command: str) -> list[str]:
|
||||
parts = command.split()
|
||||
if len(parts) >= 2 and parts[0] == "docker" and parts[1] == "compose":
|
||||
return [
|
||||
"docker",
|
||||
"compose",
|
||||
"-f",
|
||||
str(self.compose_file_path),
|
||||
"--env-file",
|
||||
str(self.env_path),
|
||||
*parts[2:],
|
||||
]
|
||||
return [*command.split()]
|
||||
|
||||
def _merge_env(self, current_text: str, env_mapping: dict[str, str]) -> list[str]:
|
||||
data: dict[str, str] = {}
|
||||
order: list[str] = []
|
||||
for raw_line in current_text.splitlines():
|
||||
line = raw_line.strip()
|
||||
if not line or line.startswith("#") or "=" not in line:
|
||||
continue
|
||||
key, value = line.split("=", 1)
|
||||
key = key.strip()
|
||||
if key not in order:
|
||||
order.append(key)
|
||||
data[key] = value.strip()
|
||||
|
||||
for key, value in env_mapping.items():
|
||||
if key not in order:
|
||||
order.append(key)
|
||||
data[key] = value
|
||||
|
||||
return [f"{key}={data[key]}" for key in order]
|
||||
Reference in New Issue
Block a user