AI 任务调度:别被 Cron 带偏了,这才是处理 LLM 任务的正确姿势
一、Cron 搞不定的那些事儿
用 Cron 或 Celery Beat 跑定时任务挺舒服的:定时触发、排队、重试,逻辑简单。但一旦涉及到 AI 任务,这套逻辑就崩了。
举个例子,某个 RAG 平台每天要处理 3 万篇文档入库,流程是“分块 → 向量化 → 写入 Milvus → 校验召回率”。看着像个标准流水线?麻烦在于:向量化得调外部 LLM API,有 60 RPM 的速率限制,而且文档长度不一,有的几行,有的几百块。如果固定并发数,短文档瞬间跑完,长文档堵在队列里,整体吞吐量直接被长尾拖死。更头疼的是,LLM API 偶尔会报 429 或 503,得指数退避重试,但绝不能让重试阻塞整个队列。
还有个容易被忽视的问题:资源感知。Milvus 的写入性能跟集群负载直接挂钩,如果同时跑 10 个写入任务,每个任务的延迟都会飙升。调度器得能感知下游服务的健康状态,动态调整并发数——这点 Cron 真做不到。
AI 任务调度的核心难点就在于:任务之间有资源竞争和依赖,执行环境是动态的,调度策略必须能自适应。
二、架构演进:从静态编排到动态调度
2.1 三层架构
graph TB subgraph 第一层: 静态调度 S1[Cron/Beat] --> Q1[固定队列] Q1 --> W1[Worker] Q1 --> W2[Worker] Q1 --> W3[Worker] end subgraph 第二层: DAG 编排 S2[调度器] --> DAG[DAG 引擎] DAG --> T1[分块任务] T1 --> T2[向量化任务] T2 --> T3[写入任务] T3 --> T4[校验任务] end subgraph 第三层: 自适应调度 S3[智能调度器] --> RM[资源监控] RM --> AD[自适应决策] AD --> PQ[优先级队列] PQ --> TW1[Task Worker 1] PQ --> TW2[Task Worker 2] RM --> AD end第一层适合无依赖的独立任务。第二层适合有依赖关系的流水线。第三层才是为 AI 任务准备的——资源竞争激烈、环境动态变化,必须得自适应。
2.2 自适应的核心:反馈回路
自适应调度的关键不在于“调度”,而在于“反馈”。调度器得根据执行结果动态调整策略。
sequenceDiagram participant Scheduler as 调度器 participant Queue as 优先级队列 participant Worker as Task Worker participant Monitor as 资源监控 loop 调度循环 Scheduler->>Monitor: 查询下游服务状态 Monitor-->>Scheduler: 延迟/错误率/负载 Scheduler->>Scheduler: 计算最优并发数 Scheduler->>Queue: 按优先级取任务 Queue-->>Scheduler: 返回任务 Scheduler->>Worker: 分配任务 Worker-->>Scheduler: 执行结果(成功/失败/延迟) Scheduler->>Scheduler: 更新任务优先级和并发限制 end三个关键决策点:
- 并发数自适应:根据下游服务的响应延迟和错误率,动态调整并发数。延迟上升就降低并发,延迟下降就提高并发,逻辑跟 TCP 拥塞控制差不多。
- 优先级动态调整:长任务在队列里等得越久,优先级越高(防止饥饿)。但如果有短任务能快速完成,优先调度短任务(提高吞吐量)。
- 退避策略:遇到 429/503 时,对该类任务执行指数退避,但不影响其他类型任务的调度。
三、生产级调度引擎实现
下面是一个支持 DAG 依赖、资源感知和自适应并发的调度引擎:
""" AI 任务自适应调度引擎 支持 DAG 依赖、资源感知、自适应并发和指数退避 """ import asyncio import time import uuid from dataclasses import dataclass, field from enum import Enum from typing import Any, Callable, Coroutine, Optional from collections import defaultdict class TaskState(Enum): """任务状态机""" PENDING = "pending" READY = "ready" # 依赖已满足,可调度 RUNNING = "running" SUCCESS = "success" FAILED = "failed" RETRYING = "retrying" @dataclass class TaskDefinition: """任务定义:描述一个可调度的工作单元""" task_id: str = field(default_factory=lambda: uuid.uuid4().hex[:8]) task_type: str = "" # 任务类型,如 "chunk"、"embed"、"write" payload: dict = field(default_factory=dict) dependencies: list[str] = field(default_factory=list) # 依赖的任务 ID priority: int = 0 # 初始优先级,越高越先执行 max_retries: int = 3 timeout: float = 60.0 # 运行时状态(由调度器维护) state: TaskState = TaskState.PENDING retry_count: int = 0 result: Any = None error: Optional[str] = None created_at: float = field(default_factory=time.time) started_at: float = 0.0 finished_at: float = 0.0 @dataclass class ResourceType: """资源类型定义""" name: str max_concurrency: int = 10 current_concurrency: int = 0 # 自适应参数 min_concurrency: int = 1 target_latency_ms: float = 1000.0 current_latency_ms: float = 0.0 error_rate: float = 0.0 # 拥塞控制(类似 TCP AIMD) _last_adjust_time: float = 0.0 @property def available_slots(self) -> int: return max(0, self.max_concurrency - self.current_concurrency) class AdaptiveScheduler: """ 自适应调度引擎 核心能力:DAG 依赖解析、资源感知调度、自适应并发控制 """ def __init__(self): self._tasks: dict[str, TaskDefinition] = {} self._resources: dict[str, ResourceType] = {} self._handlers: dict[str, Callable[[TaskDefinition], Coroutine]] = {} self._task_type_resource_map: dict[str, str] = {} self._running = False # 调度间隔 self._schedule_interval = 0.1 # 资源监控间隔 self._monitor_interval = 5.0 def register_resource(self, resource: ResourceType): """注册受控资源""" self._resources[resource.name] = resource def map_task_type(self, task_type: str, resource_name: str): """映射任务类型到资源,同类型任务共享资源配额""" self._task_type_resource_map[task_type] = resource_name def on_task_type(self, task_type: str, handler: Callable[[TaskDefinition], Coroutine]): """注册任务类型的处理函数""" self._handlers[task_type] = handler def add_task(self, task: TaskDefinition): """添加任务到调度器""" self._tasks[task.task_id] = task # 无依赖的任务直接标记为 READY if not task.dependencies: task.state = TaskState.READY def add_tasks(self, tasks: list[TaskDefinition]): """批量添加任务""" for task in tasks: self.add_task(task) async def start(self): """启动调度引擎""" self._running = True await asyncio.gather( self._schedule_loop(), self._dependency_resolver_loop(), self._resource_monitor_loop(), ) async def stop(self): """停止调度引擎""" self._running = False async def _dependency_resolver_loop(self): """持续解析任务依赖,将满足条件的 PENDING 任务标记为 READY""" while self._running: for task in self._tasks.values(): if task.state != TaskState.PENDING: continue # 检查所有依赖是否已成功完成 deps_met = all( self._tasks.get(dep_id, TaskDefinition()).state == TaskState.SUCCESS for dep_id in task.dependencies ) if deps_met: task.state = TaskState.READY # 继承上游任务的结果作为上下文 task.payload["_upstream_results"] = { dep_id: self._tasks[dep_id].result for dep_id in task.dependencies if dep_id in self._tasks } await asyncio.sleep(0.05) async def _schedule_loop(self): """核心调度循环:从就绪队列中选取任务并分配资源""" while self._running: # 收集所有 READY 状态的任务 ready_tasks = [ t for t in self._tasks.values() if t.state == TaskState.READY ] if not ready_tasks: await asyncio.sleep(self._schedule_interval) continue # 按优先级排序(等待时间越长优先级越高,短任务有加分) for task in ready_tasks: wait_time = time.time() - task.created_at # 动态优先级 = 基础优先级 + 等待时间加权 - 预估耗时加权 estimated_duration = task.payload.get("estimated_duration_ms", 1000) task.priority = int( task.priority + wait_time * 2 - estimated_duration / 100 ) ready_tasks.sort(key=lambda t: t.priority, reverse=True) for task in ready_tasks: resource_name = self._task_type_resource_map.get(task.task_type) if resource_name and resource_name in self._resources: resource = self._resources[resource_name] if resource.available_slots <= 0: continue # 资源已满,跳过 resource.current_concurrency += 1 # 异步执行任务 asyncio.create_task(self._execute_task(task)) await asyncio.sleep(self._schedule_interval) async def _execute_task(self, task: TaskDefinition): """执行单个任务,处理超时、重试和资源释放""" task.state = TaskState.RUNNING task.started_at = time.time() handler = self._handlers.get(task.task_type) if handler is None: task.state = TaskState.FAILED task.error = f"未注册的任务类型: {task.task_type}" self._release_resource(task) return try: result = await asyncio.wait_for( handler(task), timeout=task.timeout, ) task.state = TaskState.SUCCESS task.result = result task.finished_at = time.time() # 更新资源监控指标(成功时降低错误率) self._update_resource_metrics(task, success=True) except asyncio.TimeoutError: await self._handle_task_failure(task, "任务超时") except Exception as e: await self._handle_task_failure(task, str(e)) finally: self._release_resource(task) async def _handle_task_failure(self, task: TaskDefinition, error: str): """处理任务失败:重试或标记为最终失败""" task.retry_count += 1 task.error = error if task.retry_count < task.max_retries: # 指数退避重试 task.state = TaskState.RETRYING backoff = min(2 ** task.retry_count, 60) # 最大 60 秒 await asyncio.sleep(backoff) task.state = TaskState.READY # 重新入队 else: task.state = TaskState.FAILED task.finished_at = time.time() # 更新资源监控指标(失败时提高错误率) self._update_resource_metrics(task, success=False) def _release_resource(self, task: TaskDefinition): """释放任务占用的资源槽位""" resource_name = self._task_type_resource_map.get(task.task_type) if resource_name and resource_name in self._resources: resource = self._resources[resource_name] resource.current_concurrency = max(0, resource.current_concurrency - 1) def _update_resource_metrics(self, task: TaskDefinition, success: bool): """根据任务执行结果更新资源监控指标""" resource_name = self._task_type_resource_map.get(task.task_type) if not resource_name or resource_name not in self._resources: return resource = self._resources[resource_name] if task.started_at > 0 and task.finished_at > 0: latency_ms = (task.finished_at - task.started_at) * 1000 # 指数移动平均更新延迟 alpha = 0.3 resource.current_latency_ms = ( alpha * latency_ms + (1 - alpha) * resource.current_latency_ms ) # 更新错误率(指数移动平均) error_signal = 0.0 if success else 1.0 resource.error_rate = 0.2 * error_signal + 0.8 * resource.error_rate async def _resource_monitor_loop(self): """资源监控循环:根据延迟和错误率自适应调整并发数""" while self._running: for resource in self._resources.values(): now = time.time() # 至少间隔 5 秒调整一次,避免震荡 if now - resource._last_adjust_time < 5.0: continue resource._last_adjust_time = now old_max = resource.max_concurrency if resource.error_rate > 0.3: # 错误率过高,急剧降低并发(类似 TCP 拥塞避免的乘法减小) resource.max_concurrency = max( resource.min_concurrency, int(resource.max_concurrency * 0.5), ) elif resource.current_latency_ms > resource.target_latency_ms * 2: # 延迟过高,温和降低并发 resource.max_concurrency = max( resource.min_concurrency, resource.max_concurrency - 2, ) elif (resource.error_rate < 0.05 and resource.current_latency_ms < resource.target_latency_ms): # 延迟和错误率都正常,温和提高并发(类似 TCP 的加法增大) resource.max_concurrency = min( 50, # 硬上限 resource.max_concurrency + 1, ) if old_max != resource.max_concurrency: print(f"[资源调整] {resource.name}: " f"并发 {old_max} → {resource.max_concurrency} " f"(延迟={resource.current_latency_ms:.0f}ms, " f"错误率={resource.error_rate:.1%})") await asyncio.sleep(self._monitor_interval) def get_stats(self) -> dict: """获取调度器统计信息""" states = defaultdict(int) for task in self._tasks.values(): states[task.state.value] += 1 return { "tasks": dict(states), "resources": { name: { "concurrency": f"{r.current_concurrency}/{r.max_concurrency}", "latency_ms": f"{r.current_latency_ms:.0f}", "error_rate": f"{r.error_rate:.1%}", } for name, r in self._resources.items() }, } # ===== 使用示例:RAG 文档入库调度 ===== async def chunk_handler(task: TaskDefinition) -> dict: """分块任务处理器""" doc = task.payload.get("content", "") # 模拟分块逻辑 chunks = [doc[i:i+500] for i in range(0, len(doc), 500)] await asyncio.sleep(0.05) return {"chunk_count": len(chunks), "chunks": chunks} async def embed_handler(task: TaskDefinition) -> dict: """向量化任务处理器""" upstream = task.payload.get("_upstream_results", {}) chunk_count = upstream.get("chunk_count", 0) if upstream else 0 await asyncio.sleep(0.1 * chunk_count) # 模拟 API 调用 return {"vector_count": chunk_count} async def write_handler(task: TaskDefinition) -> dict: """写入 Milvus 任务处理器""" await asyncio.sleep(0.2) return {"written": True} async def main(): """构建 RAG 入库 DAG 并调度执行""" scheduler = AdaptiveScheduler() # 注册资源:向量化 API 有速率限制 scheduler.register_resource(ResourceType( name="embed_api", max_concurrency=5, min_concurrency=1, target_latency_ms=500.0, )) scheduler.register_resource(ResourceType( name="milvus", max_concurrency=3, min_concurrency=1, target_latency_ms=1000.0, )) # 映射任务类型到资源 scheduler.map_task_type("embed", "embed_api") scheduler.map_task_type("write", "milvus") # 注册处理器 scheduler.on_task_type("chunk", chunk_handler) scheduler.on_task_type("embed", embed_handler) scheduler.on_task_type("write", write_handler) # 构建 DAG:chunk → embed → write doc_id = "doc-001" chunk_task = TaskDefinition( task_type="chunk", payload={"content": "这是一篇很长的文档..." * 100}, ) embed_task = TaskDefinition( task_type="embed", dependencies=[chunk_task.task_id], ) write_task = TaskDefinition( task_type="write", dependencies=[embed_task.task_id], ) scheduler.add_tasks([chunk_task, embed_task, write_task]) # 启动调度(实际生产中会持续运行) run_task = asyncio.create_task(scheduler.start()) await asyncio.sleep(5) await scheduler.stop() print(scheduler.get_stats()) if __name__ == "__main__": asyncio.run(main())这个框架重点做了三件事:
DAG 依赖解析与上下文传递:_dependency_resolver_loop持续扫描 PENDING 任务,检查依赖是否满足。一旦依赖满足,就把上游任务的结果注入_upstream_results,下游任务可以直接用,不用额外查数据库。
AIMD 拥塞控制:直接借鉴 TCP 的加法增大/乘法减小(AIMD)算法。错误率超过 30% 时并发数减半(乘法减小),延迟和错误率正常时并发数加 1(加法增大)。稳定时慢慢加吞吐量,异常时快速降压力。
动态优先级:优先级不是固定的,而是根据等待时间和预估耗时动态计算。等得越久优先级越高(防饥饿),耗时越短优先级越高(提吞吐)。两者加权平衡,避免走极端。
四、自适应调度的坑:复杂度与可观测性
4.1 调试困难:调度决策不透明
自适应调度最头疼的问题就是决策过程不透明。任务执行异常时,很难判断是调度策略的问题还是任务本身的问题。比如任务延迟升高,可能是因为并发数调得太高,也可能是因为下游服务本身变慢了。
应对策略:每次调度决策都记录日志,包括当前并发数、延迟、错误率和调整方向。用结构化日志(JSON 格式)方便后续分析。同时暴露 Prometheus 指标,用 Grafana 面板实时监控调度状态。
4.2 参数调优成本高
AIMD 的参数(错误率阈值、延迟阈值、增减步长)得根据实际业务调。不同下游服务的特性差异很大:LLM API 对并发敏感,Milvus 对批量大小敏感。一套参数不可能适用所有场景。
应对策略:为每种资源类型设置独立的参数。先用保守参数上线,然后根据监控数据逐步调整。建议初始并发数设为预期值的 50%,避免上线即过载。
4.3 适用边界与禁用场景
- 简单定时任务:每天凌晨跑个报表,用 Cron 就够了。自适应调度引入的复杂度不值得。
- 无依赖的独立任务:如果任务之间没有依赖关系,用简单的并发池(
asyncio.Semaphore)更直接。 - 强一致性要求:调度器的状态在内存中,进程重启会丢失。如果需要持久化调度状态,需要配合数据库或 Redis。
五、总结
AI 任务调度的核心挑战是资源竞争和动态环境。传统静态调度搞不定 LLM API 的速率限制和下游服务的负载波动。自适应调度通过反馈回路动态调整并发数,借鉴 TCP 的 AIMD 算法实现拥塞控制。DAG 依赖解析支持多阶段流水线,动态优先级在防饥饿和提吞吐之间取得平衡。代价是调试困难和参数调优成本高,需要完善的可观测性支撑。在简单定时任务和无依赖独立任务场景下,建议选择更轻量的方案。