Changes Initial commit
This commit is contained in:
1
ota_agent/__init__.py
Normal file
1
ota_agent/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""OTA Agent package."""
|
||||
306
ota_agent/app.py
Normal file
306
ota_agent/app.py
Normal file
@@ -0,0 +1,306 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from contextlib import suppress
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import FastAPI, HTTPException
|
||||
|
||||
from .cloud_client import CloudClient
|
||||
from .compose_manager import ComposeManager
|
||||
from .config import AgentConfig
|
||||
from .manifest_store import ManifestStore
|
||||
from .models import (
|
||||
AgentStatus,
|
||||
ConfirmUpgradeRequest,
|
||||
HeartbeatPayload,
|
||||
LocalStatusResponse,
|
||||
OperationResult,
|
||||
PostponeUpgradeRequest,
|
||||
ReportPayload,
|
||||
UpdateCheckRequest,
|
||||
utc_now,
|
||||
)
|
||||
from .mysql_backup import MysqlBackupManager
|
||||
from .registry_login import RegistryLoginManager
|
||||
from .state_store import StateStore
|
||||
|
||||
|
||||
class AgentService:
|
||||
def __init__(self, config: AgentConfig) -> None:
|
||||
self.config = config
|
||||
self.compose = ComposeManager(config.compose)
|
||||
self.cloud = CloudClient(config.cloud)
|
||||
self.state_store = StateStore(config.storage.state_file)
|
||||
self.manifest_store = ManifestStore(config.storage.manifest_dir)
|
||||
self.mysql_backup = MysqlBackupManager(
|
||||
enabled=config.mysql_backup.enabled,
|
||||
backup_dir=config.mysql_backup.backup_dir,
|
||||
dump_command=config.mysql_backup.dump_command,
|
||||
timeout_seconds=config.mysql_backup.timeout_seconds,
|
||||
compose_manager=self.compose,
|
||||
)
|
||||
self.registry_login = RegistryLoginManager(
|
||||
enabled=config.registry.enabled,
|
||||
server=config.registry.server,
|
||||
username=config.registry.username,
|
||||
password=config.registry.password,
|
||||
timeout_seconds=config.registry.timeout_seconds,
|
||||
compose_manager=self.compose,
|
||||
)
|
||||
self.state = self.state_store.load(
|
||||
vehicle_id=config.vehicle.vehicle_id,
|
||||
vin=config.vehicle.vin,
|
||||
current_release=config.vehicle.current_release,
|
||||
)
|
||||
self.lock = asyncio.Lock()
|
||||
self.upgrade_task: asyncio.Task[None] | None = None
|
||||
self.background_tasks: list[asyncio.Task[None]] = []
|
||||
self.last_backup_file: str | None = None
|
||||
self.last_target_release: str | None = None
|
||||
self.last_images: dict[str, str] = {}
|
||||
|
||||
async def startup(self) -> None:
|
||||
self.background_tasks = [
|
||||
asyncio.create_task(self._heartbeat_loop(), name="heartbeat-loop"),
|
||||
asyncio.create_task(self._update_loop(), name="update-loop"),
|
||||
]
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
for task in self.background_tasks:
|
||||
task.cancel()
|
||||
for task in self.background_tasks:
|
||||
with suppress(asyncio.CancelledError):
|
||||
await task
|
||||
|
||||
def get_local_status(self) -> LocalStatusResponse:
|
||||
return LocalStatusResponse(
|
||||
vehicle_id=self.state.vehicle_id,
|
||||
vin=self.state.vin,
|
||||
current_release=self.state.current_release,
|
||||
status=self.state.status.value,
|
||||
available_update=self.state.available_update,
|
||||
last_result=self.state.last_result,
|
||||
updated_at=self.state.updated_at,
|
||||
)
|
||||
|
||||
async def postpone_upgrade(self, request: PostponeUpgradeRequest) -> OperationResult:
|
||||
async with self.lock:
|
||||
if not self.state.available_update:
|
||||
return OperationResult(success=False, detail="当前没有可延期的升级任务")
|
||||
self.state.status = AgentStatus.WAIT_USER_CONFIRM
|
||||
self.state.last_result = f"用户稍后提醒: {request.reason}"
|
||||
self._touch_and_save()
|
||||
await self._safe_report(AgentStatus.WAIT_USER_CONFIRM, self.state.last_result)
|
||||
return OperationResult(success=True, detail="已记录稍后提醒")
|
||||
|
||||
async def confirm_upgrade(self, request: ConfirmUpgradeRequest) -> OperationResult:
|
||||
async with self.lock:
|
||||
if not self.state.available_update:
|
||||
raise HTTPException(status_code=400, detail="当前没有可升级版本")
|
||||
if self.upgrade_task and not self.upgrade_task.done():
|
||||
return OperationResult(success=False, detail="升级任务正在执行中")
|
||||
self.state.last_result = f"用户已确认升级: {request.confirmed_by}"
|
||||
self._touch_and_save()
|
||||
self.upgrade_task = asyncio.create_task(self._execute_upgrade(), name="upgrade-task")
|
||||
return OperationResult(success=True, detail="升级任务已启动")
|
||||
|
||||
async def check_update_once(self) -> None:
|
||||
async with self.lock:
|
||||
payload = UpdateCheckRequest(
|
||||
vehicle_id=self.state.vehicle_id,
|
||||
vin=self.state.vin,
|
||||
current_release=self.state.current_release,
|
||||
)
|
||||
try:
|
||||
response = await self.cloud.check_update(payload)
|
||||
except Exception as exc:
|
||||
async with self.lock:
|
||||
self.state.last_check_at = utc_now()
|
||||
self.state.last_result = f"检查更新失败: {exc}"
|
||||
self._touch_and_save()
|
||||
return
|
||||
|
||||
async with self.lock:
|
||||
self.state.last_check_at = utc_now()
|
||||
if response.has_update and response.manifest:
|
||||
self.manifest_store.save(response.manifest)
|
||||
self.state.available_update = response.manifest
|
||||
self.state.status = AgentStatus.WAIT_USER_CONFIRM
|
||||
self.state.last_result = f"发现新版本 {response.manifest.release_version},等待人工确认"
|
||||
else:
|
||||
if self.state.status in {AgentStatus.IDLE, AgentStatus.SUCCESS, AgentStatus.WAIT_USER_CONFIRM}:
|
||||
self.state.status = AgentStatus.IDLE
|
||||
self.state.available_update = None
|
||||
self.state.last_result = response.message or "当前无可用更新"
|
||||
self._touch_and_save()
|
||||
|
||||
async def heartbeat_once(self) -> None:
|
||||
async with self.lock:
|
||||
if self.state.status == AgentStatus.SUCCESS:
|
||||
self.state.status = AgentStatus.IDLE
|
||||
self.state.last_result = self.state.last_result or "升级成功,恢复空闲状态"
|
||||
self._touch_and_save()
|
||||
payload = HeartbeatPayload(
|
||||
vehicle_id=self.state.vehicle_id,
|
||||
vin=self.state.vin,
|
||||
current_release=self.state.current_release,
|
||||
agent_status=self.state.status.value,
|
||||
target_release=self.last_target_release,
|
||||
last_result=self.state.last_result,
|
||||
images=self.last_images,
|
||||
backup_file=self.last_backup_file,
|
||||
updated_at=self.state.updated_at,
|
||||
)
|
||||
try:
|
||||
await self.cloud.heartbeat(payload)
|
||||
async with self.lock:
|
||||
self.state.last_heartbeat_at = utc_now()
|
||||
self._touch_and_save()
|
||||
except Exception as exc:
|
||||
async with self.lock:
|
||||
self.state.last_result = f"心跳上报失败: {exc}"
|
||||
self._touch_and_save()
|
||||
|
||||
async def _heartbeat_loop(self) -> None:
|
||||
while True:
|
||||
await self.heartbeat_once()
|
||||
await asyncio.sleep(self.config.polling.heartbeat_interval_seconds)
|
||||
|
||||
async def _update_loop(self) -> None:
|
||||
while True:
|
||||
await self.check_update_once()
|
||||
await asyncio.sleep(self.config.polling.update_interval_seconds)
|
||||
|
||||
async def _execute_upgrade(self) -> None:
|
||||
async with self.lock:
|
||||
manifest = self.state.available_update
|
||||
if not manifest:
|
||||
return
|
||||
target_release = manifest.release_version
|
||||
env_mapping = manifest.components.to_env_mapping()
|
||||
self.last_target_release = target_release
|
||||
self.last_images = dict(env_mapping)
|
||||
self.last_backup_file = None
|
||||
self.state.status = AgentStatus.BACKING_UP_DATABASE
|
||||
self.state.last_result = f"开始升级到 {target_release},先执行数据库备份"
|
||||
self._touch_and_save()
|
||||
await self._safe_report(AgentStatus.BACKING_UP_DATABASE, f"开始升级到 {target_release},先执行数据库备份", target_release)
|
||||
|
||||
try:
|
||||
backup_result = await self.mysql_backup.backup_before_upgrade(target_release)
|
||||
if not backup_result.success:
|
||||
raise RuntimeError(f"数据库备份失败: {backup_result.stderr or backup_result.stdout}")
|
||||
|
||||
backup_file_path = (backup_result.stdout or "").strip()
|
||||
self.last_backup_file = backup_file_path or None
|
||||
backup_file_name = Path(backup_file_path).name if backup_file_path else ""
|
||||
backup_message = f"数据库备份完成,开始拉取镜像: {target_release}"
|
||||
if backup_file_name:
|
||||
backup_message = f"数据库备份完成({backup_file_name}),开始拉取镜像: {target_release}"
|
||||
|
||||
async with self.lock:
|
||||
self.state.status = AgentStatus.PULLING_IMAGE
|
||||
self.state.last_result = backup_message
|
||||
self._touch_and_save()
|
||||
await self._safe_report(AgentStatus.PULLING_IMAGE, backup_message, target_release)
|
||||
|
||||
login_result = await self.registry_login.login_if_needed()
|
||||
if not login_result.success:
|
||||
raise RuntimeError(f"私有仓库登录失败: {login_result.stderr or login_result.stdout}")
|
||||
|
||||
await self.compose.apply_manifest(env_mapping)
|
||||
pull_result = await self.compose.pull()
|
||||
if not pull_result.success:
|
||||
raise RuntimeError(f"拉取镜像失败: {pull_result.stderr or pull_result.stdout}")
|
||||
|
||||
async with self.lock:
|
||||
self.state.status = AgentStatus.RESTARTING_SERVICE
|
||||
self.state.last_result = "镜像拉取完成,开始重启服务"
|
||||
self._touch_and_save()
|
||||
await self._safe_report(AgentStatus.RESTARTING_SERVICE, "镜像拉取完成,开始重启服务", target_release)
|
||||
|
||||
up_result = await self.compose.up()
|
||||
if not up_result.success:
|
||||
raise RuntimeError(f"服务启动失败: {up_result.stderr or up_result.stdout}")
|
||||
|
||||
async with self.lock:
|
||||
self.state.status = AgentStatus.HEALTH_CHECKING
|
||||
self.state.last_result = "服务已启动,开始健康检查"
|
||||
self._touch_and_save()
|
||||
await self._safe_report(AgentStatus.HEALTH_CHECKING, "服务已启动,开始健康检查", target_release)
|
||||
|
||||
health = await self.compose.health_check()
|
||||
if not health.success:
|
||||
raise RuntimeError(health.detail)
|
||||
|
||||
async with self.lock:
|
||||
self.state.status = AgentStatus.SUCCESS
|
||||
self.state.current_release = target_release
|
||||
self.state.available_update = None
|
||||
self.state.last_result = f"升级成功: {target_release}"
|
||||
self._touch_and_save()
|
||||
await self._safe_report(AgentStatus.SUCCESS, f"升级成功: {target_release}", target_release)
|
||||
self.last_target_release = None
|
||||
except Exception as exc:
|
||||
await self.compose.rollback()
|
||||
async with self.lock:
|
||||
self.state.status = AgentStatus.ROLLED_BACK
|
||||
self.state.last_result = f"升级失败并已回滚: {exc}"
|
||||
self._touch_and_save()
|
||||
await self._safe_report(AgentStatus.ROLLED_BACK, f"升级失败并已回滚: {exc}", target_release)
|
||||
|
||||
async def _safe_report(self, status: AgentStatus, detail: str, release_version: str | None = None) -> None:
|
||||
payload = ReportPayload(
|
||||
vehicle_id=self.state.vehicle_id,
|
||||
vin=self.state.vin,
|
||||
current_release=self.state.current_release,
|
||||
target_release=release_version or self.last_target_release,
|
||||
agent_status=status.value,
|
||||
success=status == AgentStatus.SUCCESS,
|
||||
message=detail,
|
||||
images=self.last_images,
|
||||
backup_file=self.last_backup_file,
|
||||
)
|
||||
with suppress(Exception):
|
||||
await self.cloud.report(payload)
|
||||
|
||||
def _touch_and_save(self) -> None:
|
||||
self.state.updated_at = utc_now()
|
||||
self.state_store.save(self.state)
|
||||
|
||||
|
||||
def create_app(config: AgentConfig) -> FastAPI:
|
||||
service = AgentService(config)
|
||||
app = FastAPI(title="Vehicle OTA Agent", version="0.1.0")
|
||||
|
||||
@app.on_event("startup")
|
||||
async def on_startup() -> None:
|
||||
await service.startup()
|
||||
|
||||
@app.on_event("shutdown")
|
||||
async def on_shutdown() -> None:
|
||||
await service.shutdown()
|
||||
|
||||
@app.get("/health")
|
||||
async def health() -> dict[str, str]:
|
||||
return {"status": "ok"}
|
||||
|
||||
@app.get("/ota/status", response_model=LocalStatusResponse)
|
||||
async def local_status() -> LocalStatusResponse:
|
||||
return service.get_local_status()
|
||||
|
||||
@app.post("/ota/check-update")
|
||||
async def check_update() -> OperationResult:
|
||||
await service.check_update_once()
|
||||
return OperationResult(success=True, detail="已执行一次检查更新")
|
||||
|
||||
@app.post("/ota/confirm", response_model=OperationResult)
|
||||
async def confirm_upgrade(request: ConfirmUpgradeRequest) -> OperationResult:
|
||||
return await service.confirm_upgrade(request)
|
||||
|
||||
@app.post("/ota/postpone", response_model=OperationResult)
|
||||
async def postpone_upgrade(request: PostponeUpgradeRequest) -> OperationResult:
|
||||
return await service.postpone_upgrade(request)
|
||||
|
||||
return app
|
||||
48
ota_agent/cloud_client.py
Normal file
48
ota_agent/cloud_client.py
Normal file
@@ -0,0 +1,48 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
import httpx
|
||||
|
||||
from .config import CloudConfig
|
||||
from .models import HeartbeatPayload, ReportPayload, UpdateCheckRequest, UpdateCheckResponse
|
||||
|
||||
|
||||
class CloudClient:
|
||||
def __init__(self, config: CloudConfig) -> None:
|
||||
self.config = config
|
||||
self.timeout = httpx.Timeout(config.timeout_seconds)
|
||||
|
||||
def _headers(self) -> dict[str, str]:
|
||||
return {
|
||||
self.config.token_header: self.config.token,
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
async def heartbeat(self, payload: HeartbeatPayload) -> None:
|
||||
async with httpx.AsyncClient(timeout=self.timeout) as client:
|
||||
response = await client.post(
|
||||
self.config.base_url.rstrip("/") + self.config.heartbeat_path,
|
||||
headers=self._headers(),
|
||||
json=payload.model_dump(by_alias=True),
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
async def check_update(self, payload: UpdateCheckRequest) -> UpdateCheckResponse:
|
||||
async with httpx.AsyncClient(timeout=self.timeout) as client:
|
||||
response = await client.post(
|
||||
self.config.base_url.rstrip("/") + self.config.update_check_path,
|
||||
headers=self._headers(),
|
||||
json=payload.model_dump(by_alias=True),
|
||||
)
|
||||
response.raise_for_status()
|
||||
return UpdateCheckResponse(**response.json())
|
||||
|
||||
async def report(self, payload: ReportPayload) -> None:
|
||||
async with httpx.AsyncClient(timeout=self.timeout) as client:
|
||||
response = await client.post(
|
||||
self.config.base_url.rstrip("/") + self.config.report_path,
|
||||
headers=self._headers(),
|
||||
json=payload.model_dump(by_alias=True),
|
||||
)
|
||||
response.raise_for_status()
|
||||
123
ota_agent/compose_manager.py
Normal file
123
ota_agent/compose_manager.py
Normal file
@@ -0,0 +1,123 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import httpx
|
||||
|
||||
from .config import ComposeConfig
|
||||
from .models import CommandResult, HealthcheckResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ComposeManager:
|
||||
def __init__(self, config: ComposeConfig) -> None:
|
||||
self.config = config
|
||||
self.working_dir = Path(config.working_dir)
|
||||
self.working_dir.mkdir(parents=True, exist_ok=True)
|
||||
self.compose_file_path = self.working_dir / config.file
|
||||
self.env_path = self.working_dir / config.env_file
|
||||
self.backup_env_path = self.working_dir / config.backup_env_file
|
||||
|
||||
async def apply_manifest(self, env_mapping: dict[str, str]) -> None:
|
||||
current_text = self.env_path.read_text(encoding="utf-8") if self.env_path.exists() else ""
|
||||
if self.env_path.exists():
|
||||
self.backup_env_path.write_text(current_text, encoding="utf-8")
|
||||
logger.info("已备份镜像环境文件: %s", self.backup_env_path)
|
||||
|
||||
env_lines = self._merge_env(current_text, env_mapping)
|
||||
self.env_path.write_text("\n".join(env_lines) + "\n", encoding="utf-8")
|
||||
logger.info("已写入新镜像配置到 %s", self.env_path)
|
||||
|
||||
async def rollback(self) -> None:
|
||||
if self.backup_env_path.exists():
|
||||
self.env_path.write_text(self.backup_env_path.read_text(encoding="utf-8"), encoding="utf-8")
|
||||
logger.warning("升级失败,已恢复镜像环境文件: %s", self.env_path)
|
||||
await self.pull()
|
||||
await self.up()
|
||||
|
||||
async def pull(self) -> CommandResult:
|
||||
return await self.run_command(self._build_command(self.config.pull_command))
|
||||
|
||||
async def up(self) -> CommandResult:
|
||||
return await self.run_command(self._build_command(self.config.up_command))
|
||||
|
||||
async def health_check(self) -> HealthcheckResult:
|
||||
if not self.config.healthcheck_url:
|
||||
return HealthcheckResult(success=True, detail="healthcheck skipped")
|
||||
|
||||
deadline = asyncio.get_running_loop().time() + self.config.health_check_seconds
|
||||
timeout = httpx.Timeout(min(self.config.healthcheck_interval_seconds, self.config.health_check_seconds))
|
||||
last_detail = "healthcheck not started"
|
||||
async with httpx.AsyncClient(timeout=timeout) as client:
|
||||
while True:
|
||||
try:
|
||||
response = await client.get(self.config.healthcheck_url)
|
||||
if 200 <= response.status_code < 300:
|
||||
logger.info("健康检查通过: %s", response.status_code)
|
||||
return HealthcheckResult(success=True, detail=f"healthcheck ok: {response.status_code}")
|
||||
last_detail = f"healthcheck failed: {response.status_code} {response.text}"
|
||||
except Exception as exc:
|
||||
last_detail = f"healthcheck error: {exc}"
|
||||
|
||||
if asyncio.get_running_loop().time() >= deadline:
|
||||
logger.error("健康检查失败: %s", last_detail)
|
||||
return HealthcheckResult(success=False, detail=last_detail)
|
||||
await asyncio.sleep(self.config.healthcheck_interval_seconds)
|
||||
|
||||
async def run_command(self, command: list[str]) -> CommandResult:
|
||||
logger.info("执行命令: %s", " ".join(command))
|
||||
process = await asyncio.create_subprocess_exec(
|
||||
*command,
|
||||
cwd=str(self.working_dir),
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
stdout, stderr = await process.communicate()
|
||||
result = CommandResult(
|
||||
success=process.returncode == 0,
|
||||
stdout=stdout.decode("utf-8", errors="ignore"),
|
||||
stderr=stderr.decode("utf-8", errors="ignore"),
|
||||
returncode=process.returncode or 0,
|
||||
)
|
||||
if result.success:
|
||||
logger.info("命令执行成功(returncode=%s)", result.returncode)
|
||||
else:
|
||||
logger.error("命令执行失败(returncode=%s): %s", result.returncode, result.stderr or result.stdout)
|
||||
return result
|
||||
|
||||
def _build_command(self, command: str) -> list[str]:
|
||||
parts = command.split()
|
||||
if len(parts) >= 2 and parts[0] == "docker" and parts[1] == "compose":
|
||||
return [
|
||||
"docker",
|
||||
"compose",
|
||||
"-f",
|
||||
str(self.compose_file_path),
|
||||
"--env-file",
|
||||
str(self.env_path),
|
||||
*parts[2:],
|
||||
]
|
||||
return [*command.split()]
|
||||
|
||||
def _merge_env(self, current_text: str, env_mapping: dict[str, str]) -> list[str]:
|
||||
data: dict[str, str] = {}
|
||||
order: list[str] = []
|
||||
for raw_line in current_text.splitlines():
|
||||
line = raw_line.strip()
|
||||
if not line or line.startswith("#") or "=" not in line:
|
||||
continue
|
||||
key, value = line.split("=", 1)
|
||||
key = key.strip()
|
||||
if key not in order:
|
||||
order.append(key)
|
||||
data[key] = value.strip()
|
||||
|
||||
for key, value in env_mapping.items():
|
||||
if key not in order:
|
||||
order.append(key)
|
||||
data[key] = value
|
||||
|
||||
return [f"{key}={data[key]}" for key in order]
|
||||
88
ota_agent/config.py
Normal file
88
ota_agent/config.py
Normal file
@@ -0,0 +1,88 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import yaml
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
|
||||
class ServerConfig(BaseModel):
|
||||
host: str = "0.0.0.0"
|
||||
port: int = 19090
|
||||
|
||||
|
||||
class VehicleConfig(BaseModel):
|
||||
vehicle_id: str = "vehicle-test-001"
|
||||
vin: str = "vehicle-test-001"
|
||||
current_release: str = "vehicle-release-0.0.1"
|
||||
|
||||
|
||||
class CloudConfig(BaseModel):
|
||||
base_url: str = "http://127.0.0.1:8080"
|
||||
heartbeat_path: str = "/api/agent/heartbeat"
|
||||
update_check_path: str = "/api/agent/update-check"
|
||||
report_path: str = "/api/agent/report"
|
||||
timeout_seconds: int = 10
|
||||
token: str = "change-me"
|
||||
token_header: str = "X-OTA-TOKEN"
|
||||
|
||||
|
||||
class RegistryConfig(BaseModel):
|
||||
enabled: bool = False
|
||||
server: str = ""
|
||||
username: str = ""
|
||||
password: str = ""
|
||||
timeout_seconds: int = 30
|
||||
|
||||
|
||||
class ComposeConfig(BaseModel):
|
||||
working_dir: str
|
||||
file: str = "docker-compose.yml"
|
||||
env_file: str = ".env"
|
||||
backup_env_file: str = ".env.bak"
|
||||
pull_command: str = "docker compose pull"
|
||||
up_command: str = "docker compose up -d"
|
||||
health_check_seconds: int = 15
|
||||
healthcheck_url: str | None = None
|
||||
healthcheck_interval_seconds: int = 3
|
||||
|
||||
|
||||
class PollingConfig(BaseModel):
|
||||
update_interval_seconds: int = 30
|
||||
heartbeat_interval_seconds: int = 30
|
||||
|
||||
|
||||
class StorageConfig(BaseModel):
|
||||
state_file: str
|
||||
manifest_dir: str
|
||||
log_dir: str = "./runtime/logs"
|
||||
|
||||
|
||||
class MysqlBackupConfig(BaseModel):
|
||||
enabled: bool = False
|
||||
backup_dir: str = "./runtime/mysql-backups"
|
||||
dump_command: str = "docker exec mysql mysqldump -uroot -p123456 app_db > {backup_file}"
|
||||
timeout_seconds: int = 120
|
||||
|
||||
|
||||
class AgentConfig(BaseSettings):
|
||||
model_config = SettingsConfigDict(extra="ignore")
|
||||
|
||||
server: ServerConfig = Field(default_factory=ServerConfig)
|
||||
vehicle: VehicleConfig = Field(default_factory=VehicleConfig)
|
||||
cloud: CloudConfig = Field(default_factory=CloudConfig)
|
||||
registry: RegistryConfig = Field(default_factory=RegistryConfig)
|
||||
compose: ComposeConfig = Field(default_factory=lambda: ComposeConfig(working_dir="./runtime"))
|
||||
polling: PollingConfig = Field(default_factory=PollingConfig)
|
||||
storage: StorageConfig = Field(default_factory=lambda: StorageConfig(state_file="./runtime/state.json", manifest_dir="./runtime/manifests", log_dir="./runtime/logs"))
|
||||
mysql_backup: MysqlBackupConfig = Field(default_factory=MysqlBackupConfig)
|
||||
|
||||
|
||||
def load_config(config_path: str | Path) -> AgentConfig:
|
||||
path = Path(config_path)
|
||||
raw: dict[str, Any] = {}
|
||||
if path.exists():
|
||||
raw = yaml.safe_load(path.read_text(encoding="utf-8")) or {}
|
||||
return AgentConfig(**raw)
|
||||
33
ota_agent/logging_utils.py
Normal file
33
ota_agent/logging_utils.py
Normal file
@@ -0,0 +1,33 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from logging.handlers import RotatingFileHandler
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def setup_logging(log_dir: str) -> None:
|
||||
target_dir = Path(log_dir)
|
||||
target_dir.mkdir(parents=True, exist_ok=True)
|
||||
log_file = target_dir / "ota-agent.log"
|
||||
|
||||
root_logger = logging.getLogger()
|
||||
root_logger.setLevel(logging.INFO)
|
||||
|
||||
formatter = logging.Formatter(
|
||||
"%(asctime)s [%(levelname)s] %(name)s - %(message)s"
|
||||
)
|
||||
|
||||
file_handler = RotatingFileHandler(
|
||||
log_file,
|
||||
maxBytes=2 * 1024 * 1024,
|
||||
backupCount=3,
|
||||
encoding="utf-8",
|
||||
)
|
||||
file_handler.setFormatter(formatter)
|
||||
|
||||
stream_handler = logging.StreamHandler()
|
||||
stream_handler.setFormatter(formatter)
|
||||
|
||||
root_logger.handlers.clear()
|
||||
root_logger.addHandler(file_handler)
|
||||
root_logger.addHandler(stream_handler)
|
||||
17
ota_agent/manifest_store.py
Normal file
17
ota_agent/manifest_store.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
from .models import UpdateManifest
|
||||
|
||||
|
||||
class ManifestStore:
|
||||
def __init__(self, manifest_dir: str) -> None:
|
||||
self.path = Path(manifest_dir)
|
||||
self.path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def save(self, manifest: UpdateManifest) -> Path:
|
||||
target = self.path / f"{manifest.release_version}.json"
|
||||
target.write_text(json.dumps(manifest.model_dump(), ensure_ascii=False, indent=2), encoding="utf-8")
|
||||
return target
|
||||
130
ota_agent/models.py
Normal file
130
ota_agent/models.py
Normal file
@@ -0,0 +1,130 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from enum import Enum
|
||||
|
||||
from pydantic import BaseModel, Field, ConfigDict
|
||||
|
||||
|
||||
def to_camel(value: str) -> str:
|
||||
parts = value.split("_")
|
||||
return parts[0] + "".join(part.capitalize() for part in parts[1:])
|
||||
|
||||
|
||||
class ApiModel(BaseModel):
|
||||
model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)
|
||||
|
||||
|
||||
class AgentStatus(str, Enum):
|
||||
IDLE = "IDLE"
|
||||
HAS_UPDATE = "HAS_UPDATE"
|
||||
WAIT_USER_CONFIRM = "WAIT_USER_CONFIRM"
|
||||
BACKING_UP_DATABASE = "BACKING_UP_DATABASE"
|
||||
PULLING_IMAGE = "PULLING_IMAGE"
|
||||
RESTARTING_SERVICE = "RESTARTING_SERVICE"
|
||||
HEALTH_CHECKING = "HEALTH_CHECKING"
|
||||
SUCCESS = "SUCCESS"
|
||||
FAILED = "FAILED"
|
||||
ROLLED_BACK = "ROLLED_BACK"
|
||||
|
||||
|
||||
class ComponentImages(ApiModel):
|
||||
images: dict[str, str] = Field(default_factory=dict)
|
||||
|
||||
def to_env_mapping(self) -> dict[str, str]:
|
||||
return {key: value for key, value in self.images.items() if value}
|
||||
|
||||
|
||||
class UpdateManifest(ApiModel):
|
||||
release_version: str
|
||||
release_notes: str = ""
|
||||
components: ComponentImages = Field(default_factory=ComponentImages)
|
||||
upgrade_mode: str = "manual_confirm"
|
||||
|
||||
|
||||
class AgentState(BaseModel):
|
||||
vehicle_id: str
|
||||
vin: str
|
||||
current_release: str
|
||||
status: AgentStatus = AgentStatus.IDLE
|
||||
available_update: UpdateManifest | None = None
|
||||
last_check_at: str | None = None
|
||||
last_heartbeat_at: str | None = None
|
||||
last_result: str | None = None
|
||||
updated_at: str = Field(default_factory=lambda: utc_now())
|
||||
|
||||
|
||||
class HeartbeatPayload(ApiModel):
|
||||
vehicle_id: str
|
||||
vin: str
|
||||
current_release: str
|
||||
agent_status: str
|
||||
target_release: str | None = None
|
||||
last_result: str | None = None
|
||||
images: dict[str, str] = Field(default_factory=dict)
|
||||
backup_file: str | None = None
|
||||
updated_at: str
|
||||
|
||||
|
||||
class UpdateCheckRequest(ApiModel):
|
||||
vehicle_id: str
|
||||
vin: str
|
||||
current_release: str
|
||||
|
||||
|
||||
class UpdateCheckResponse(ApiModel):
|
||||
has_update: bool = False
|
||||
manifest: UpdateManifest | None = None
|
||||
message: str = ""
|
||||
|
||||
|
||||
class ReportPayload(ApiModel):
|
||||
vehicle_id: str
|
||||
vin: str
|
||||
current_release: str
|
||||
target_release: str | None = None
|
||||
agent_status: str
|
||||
success: bool
|
||||
message: str = ""
|
||||
images: dict[str, str] = Field(default_factory=dict)
|
||||
backup_file: str | None = None
|
||||
updated_at: str = Field(default_factory=lambda: utc_now())
|
||||
|
||||
|
||||
class ConfirmUpgradeRequest(BaseModel):
|
||||
confirmed_by: str = "android-app"
|
||||
|
||||
|
||||
class PostponeUpgradeRequest(BaseModel):
|
||||
reason: str = "user_postpone"
|
||||
|
||||
|
||||
class LocalStatusResponse(BaseModel):
|
||||
vehicle_id: str
|
||||
vin: str
|
||||
current_release: str
|
||||
status: str
|
||||
available_update: UpdateManifest | None = None
|
||||
last_result: str | None = None
|
||||
updated_at: str
|
||||
|
||||
|
||||
class CommandResult(BaseModel):
|
||||
success: bool
|
||||
stdout: str = ""
|
||||
stderr: str = ""
|
||||
returncode: int = 0
|
||||
|
||||
|
||||
class HealthcheckResult(BaseModel):
|
||||
success: bool
|
||||
detail: str
|
||||
|
||||
|
||||
class OperationResult(BaseModel):
|
||||
success: bool
|
||||
detail: str
|
||||
|
||||
|
||||
def utc_now() -> str:
|
||||
return datetime.now(timezone.utc).isoformat()
|
||||
59
ota_agent/mysql_backup.py
Normal file
59
ota_agent/mysql_backup.py
Normal file
@@ -0,0 +1,59 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
from .compose_manager import ComposeManager
|
||||
from .models import CommandResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MysqlBackupManager:
|
||||
def __init__(self, *, enabled: bool, backup_dir: str, dump_command: str, timeout_seconds: int, compose_manager: ComposeManager) -> None:
|
||||
self.enabled = enabled
|
||||
self.backup_dir = Path(backup_dir)
|
||||
self.dump_command = dump_command
|
||||
self.timeout_seconds = timeout_seconds
|
||||
self.compose_manager = compose_manager
|
||||
if self.enabled:
|
||||
self.backup_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
async def backup_before_upgrade(self, target_release: str) -> CommandResult:
|
||||
if not self.enabled:
|
||||
return CommandResult(success=True, stdout="mysql backup skipped", returncode=0)
|
||||
|
||||
timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
|
||||
backup_file = self.backup_dir / f"mysql-backup-{target_release}-{timestamp}.sql"
|
||||
command = self.dump_command.replace("{backup_file}", str(backup_file))
|
||||
logger.info("开始执行MySQL备份,目标文件: %s", backup_file)
|
||||
result = await self._run_shell_command(command)
|
||||
if result.success:
|
||||
logger.info("MySQL备份完成: %s", backup_file)
|
||||
result.stdout = str(backup_file)
|
||||
else:
|
||||
logger.error("MySQL备份失败: %s", result.stderr or result.stdout)
|
||||
return result
|
||||
|
||||
async def _run_shell_command(self, command: str) -> CommandResult:
|
||||
process = await asyncio.create_subprocess_shell(
|
||||
command,
|
||||
cwd=str(self.compose_manager.working_dir),
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
try:
|
||||
stdout, stderr = await asyncio.wait_for(process.communicate(), timeout=self.timeout_seconds)
|
||||
except asyncio.TimeoutError:
|
||||
process.kill()
|
||||
await process.wait()
|
||||
return CommandResult(success=False, stderr=f"mysql backup timeout after {self.timeout_seconds}s", returncode=-1)
|
||||
|
||||
return CommandResult(
|
||||
success=process.returncode == 0,
|
||||
stdout=stdout.decode("utf-8", errors="ignore"),
|
||||
stderr=stderr.decode("utf-8", errors="ignore"),
|
||||
returncode=process.returncode or 0,
|
||||
)
|
||||
70
ota_agent/registry_login.py
Normal file
70
ota_agent/registry_login.py
Normal file
@@ -0,0 +1,70 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
from .compose_manager import ComposeManager
|
||||
from .models import CommandResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RegistryLoginManager:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
enabled: bool,
|
||||
server: str,
|
||||
username: str,
|
||||
password: str,
|
||||
timeout_seconds: int,
|
||||
compose_manager: ComposeManager,
|
||||
) -> None:
|
||||
self.enabled = enabled
|
||||
self.server = server
|
||||
self.username = username
|
||||
self.password = password
|
||||
self.timeout_seconds = timeout_seconds
|
||||
self.compose_manager = compose_manager
|
||||
|
||||
async def login_if_needed(self) -> CommandResult:
|
||||
if not self.enabled:
|
||||
return CommandResult(success=True, stdout="registry login skipped", returncode=0)
|
||||
|
||||
command = [
|
||||
"docker",
|
||||
"login",
|
||||
self.server,
|
||||
"-u",
|
||||
self.username,
|
||||
"-p",
|
||||
self.password,
|
||||
]
|
||||
logger.info("开始执行私有仓库登录: %s", self.server)
|
||||
result = await self._run_command(command)
|
||||
if result.success:
|
||||
logger.info("私有仓库登录成功: %s", self.server)
|
||||
else:
|
||||
logger.error("私有仓库登录失败: %s", result.stderr or result.stdout)
|
||||
return result
|
||||
|
||||
async def _run_command(self, command: list[str]) -> CommandResult:
|
||||
process = await asyncio.create_subprocess_exec(
|
||||
*command,
|
||||
cwd=str(self.compose_manager.working_dir),
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
try:
|
||||
stdout, stderr = await asyncio.wait_for(process.communicate(), timeout=self.timeout_seconds)
|
||||
except asyncio.TimeoutError:
|
||||
process.kill()
|
||||
await process.wait()
|
||||
return CommandResult(success=False, stderr=f"registry login timeout after {self.timeout_seconds}s", returncode=-1)
|
||||
|
||||
return CommandResult(
|
||||
success=process.returncode == 0,
|
||||
stdout=stdout.decode("utf-8", errors="ignore"),
|
||||
stderr=stderr.decode("utf-8", errors="ignore"),
|
||||
returncode=process.returncode or 0,
|
||||
)
|
||||
39
ota_agent/state_store.py
Normal file
39
ota_agent/state_store.py
Normal file
@@ -0,0 +1,39 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
from .models import AgentState
|
||||
|
||||
|
||||
class StateStore:
|
||||
def __init__(self, state_file: str) -> None:
|
||||
self.path = Path(state_file)
|
||||
self.path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def load(self, vehicle_id: str, vin: str, current_release: str) -> AgentState:
|
||||
if not self.path.exists() or self._is_empty_file():
|
||||
state = AgentState(vehicle_id=vehicle_id, vin=vin, current_release=current_release)
|
||||
self.save(state)
|
||||
return state
|
||||
|
||||
try:
|
||||
raw = json.loads(self.path.read_text(encoding="utf-8"))
|
||||
except json.JSONDecodeError:
|
||||
state = AgentState(vehicle_id=vehicle_id, vin=vin, current_release=current_release)
|
||||
self.save(state)
|
||||
return state
|
||||
|
||||
raw["vehicle_id"] = vehicle_id
|
||||
raw["vin"] = vin
|
||||
raw["current_release"] = current_release
|
||||
return AgentState(**raw)
|
||||
|
||||
def save(self, state: AgentState) -> None:
|
||||
self.path.write_text(
|
||||
state.model_dump_json(indent=2),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
def _is_empty_file(self) -> bool:
|
||||
return self.path.stat().st_size == 0
|
||||
Reference in New Issue
Block a user