Skip to content

Commit abe12ab

Browse files
author
kevin.zhang
committed
feat: add support for maximum concurrency of /api/v1/videos
1 parent 414bcb0 commit abe12ab

File tree

7 files changed

+170
-2
lines changed

7 files changed

+170
-2
lines changed
+57
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
import threading
2+
from typing import Callable, Any, Dict
3+
4+
5+
class TaskManager:
6+
def __init__(self, max_concurrent_tasks: int):
7+
self.max_concurrent_tasks = max_concurrent_tasks
8+
self.current_tasks = 0
9+
self.lock = threading.Lock()
10+
self.queue = self.create_queue()
11+
12+
def create_queue(self):
13+
raise NotImplementedError()
14+
15+
def add_task(self, func: Callable, *args: Any, **kwargs: Any):
16+
with self.lock:
17+
if self.current_tasks < self.max_concurrent_tasks:
18+
print(f"add task: {func.__name__}, current_tasks: {self.current_tasks}")
19+
self.execute_task(func, *args, **kwargs)
20+
else:
21+
print(f"enqueue task: {func.__name__}, current_tasks: {self.current_tasks}")
22+
self.enqueue({"func": func, "args": args, "kwargs": kwargs})
23+
24+
def execute_task(self, func: Callable, *args: Any, **kwargs: Any):
25+
thread = threading.Thread(target=self.run_task, args=(func, *args), kwargs=kwargs)
26+
thread.start()
27+
28+
def run_task(self, func: Callable, *args: Any, **kwargs: Any):
29+
try:
30+
with self.lock:
31+
self.current_tasks += 1
32+
func(*args, **kwargs) # 在这里调用函数,传递*args和**kwargs
33+
finally:
34+
self.task_done()
35+
36+
def check_queue(self):
37+
with self.lock:
38+
if self.current_tasks < self.max_concurrent_tasks and not self.is_queue_empty():
39+
task_info = self.dequeue()
40+
func = task_info['func']
41+
args = task_info.get('args', ())
42+
kwargs = task_info.get('kwargs', {})
43+
self.execute_task(func, *args, **kwargs)
44+
45+
def task_done(self):
46+
with self.lock:
47+
self.current_tasks -= 1
48+
self.check_queue()
49+
50+
def enqueue(self, task: Dict):
51+
raise NotImplementedError()
52+
53+
def dequeue(self):
54+
raise NotImplementedError()
55+
56+
def is_queue_empty(self):
57+
raise NotImplementedError()
+18
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
from queue import Queue
2+
from typing import Dict
3+
4+
from app.controllers.manager.base_manager import TaskManager
5+
6+
7+
class InMemoryTaskManager(TaskManager):
8+
def create_queue(self):
9+
return Queue()
10+
11+
def enqueue(self, task: Dict):
12+
self.queue.put(task)
13+
14+
def dequeue(self):
15+
return self.queue.get()
16+
17+
def is_queue_empty(self):
18+
return self.queue.empty()
+48
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
import json
2+
from typing import Dict
3+
4+
import redis
5+
6+
from app.controllers.manager.base_manager import TaskManager
7+
from app.models.schema import VideoParams
8+
from app.services import task as tm
9+
10+
FUNC_MAP = {
11+
'start': tm.start,
12+
# 'start_test': tm.start_test
13+
}
14+
15+
16+
class RedisTaskManager(TaskManager):
17+
def __init__(self, max_concurrent_tasks: int, redis_url: str):
18+
self.redis_client = redis.Redis.from_url(redis_url)
19+
super().__init__(max_concurrent_tasks)
20+
21+
def create_queue(self):
22+
return "task_queue"
23+
24+
def enqueue(self, task: Dict):
25+
task_with_serializable_params = task.copy()
26+
27+
if 'params' in task['kwargs'] and isinstance(task['kwargs']['params'], VideoParams):
28+
task_with_serializable_params['kwargs']['params'] = task['kwargs']['params'].dict()
29+
30+
# 将函数对象转换为其名称
31+
task_with_serializable_params['func'] = task['func'].__name__
32+
self.redis_client.rpush(self.queue, json.dumps(task_with_serializable_params))
33+
34+
def dequeue(self):
35+
task_json = self.redis_client.lpop(self.queue)
36+
if task_json:
37+
task_info = json.loads(task_json)
38+
# 将函数名称转换回函数对象
39+
task_info['func'] = FUNC_MAP[task_info['func']]
40+
41+
if 'params' in task_info['kwargs'] and isinstance(task_info['kwargs']['params'], dict):
42+
task_info['kwargs']['params'] = VideoParams(**task_info['kwargs']['params'])
43+
44+
return task_info
45+
return None
46+
47+
def is_queue_empty(self):
48+
return self.redis_client.llen(self.queue) == 0

app/controllers/v1/video.py

+33-1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010

1111
from app.config import config
1212
from app.controllers import base
13+
from app.controllers.manager.memory_manager import InMemoryTaskManager
14+
from app.controllers.manager.redis_manager import RedisTaskManager
1315
from app.controllers.v1.base import new_router
1416
from app.models.exception import HttpException
1517
from app.models.schema import TaskVideoRequest, TaskQueryResponse, TaskResponse, TaskQueryRequest, \
@@ -22,6 +24,35 @@
2224
# router = new_router(dependencies=[Depends(base.verify_token)])
2325
router = new_router()
2426

27+
_enable_redis = config.app.get("enable_redis", False)
28+
_redis_host = config.app.get("redis_host", "localhost")
29+
_redis_port = config.app.get("redis_port", 6379)
30+
_redis_db = config.app.get("redis_db", 0)
31+
_redis_password = config.app.get("redis_password", None)
32+
_max_concurrent_tasks = config.app.get("max_concurrent_tasks", 5)
33+
34+
redis_url = f"redis://:{_redis_password}@{_redis_host}:{_redis_port}/{_redis_db}"
35+
# 根据配置选择合适的任务管理器
36+
if _enable_redis:
37+
task_manager = RedisTaskManager(max_concurrent_tasks=_max_concurrent_tasks, redis_url=redis_url)
38+
else:
39+
task_manager = InMemoryTaskManager(max_concurrent_tasks=_max_concurrent_tasks)
40+
41+
# @router.post("/videos-test", response_model=TaskResponse, summary="Generate a short video")
42+
# async def create_video_test(request: Request, body: TaskVideoRequest):
43+
# task_id = utils.get_uuid()
44+
# request_id = base.get_task_id(request)
45+
# try:
46+
# task = {
47+
# "task_id": task_id,
48+
# "request_id": request_id,
49+
# "params": body.dict(),
50+
# }
51+
# task_manager.add_task(tm.start_test, task_id=task_id, params=body)
52+
# return utils.get_response(200, task)
53+
# except ValueError as e:
54+
# raise HttpException(task_id=task_id, status_code=400, message=f"{request_id}: {str(e)}")
55+
2556

2657
@router.post("/videos", response_model=TaskResponse, summary="Generate a short video")
2758
def create_video(background_tasks: BackgroundTasks, request: Request, body: TaskVideoRequest):
@@ -34,7 +65,8 @@ def create_video(background_tasks: BackgroundTasks, request: Request, body: Task
3465
"params": body.dict(),
3566
}
3667
sm.state.update_task(task_id)
37-
background_tasks.add_task(tm.start, task_id=task_id, params=body)
68+
# background_tasks.add_task(tm.start, task_id=task_id, params=body)
69+
task_manager.add_task(tm.start, task_id=task_id, params=body)
3870
logger.success(f"video created: {utils.to_json(task)}")
3971
return utils.get_response(200, task)
4072
except ValueError as e:

app/models/schema.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ class MaterialInfo:
7373
# ]
7474

7575

76-
class VideoParams:
76+
class VideoParams(BaseModel):
7777
"""
7878
{
7979
"video_subject": "",

app/services/task.py

+6
Original file line numberDiff line numberDiff line change
@@ -173,3 +173,9 @@ def start(task_id, params: VideoParams):
173173
}
174174
sm.state.update_task(task_id, state=const.TASK_STATE_COMPLETE, progress=100, **kwargs)
175175
return kwargs
176+
177+
178+
# def start_test(task_id, params: VideoParams):
179+
# print(f"start task {task_id} \n")
180+
# time.sleep(5)
181+
# print(f"task {task_id} finished \n")

config.example.toml

+7
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
listen_host = "127.0.0.1"
2+
listen_port = 8502
3+
14
[app]
25
# Pexels API Key
36
# Register at https://www.pexels.com/api/ to get your API key.
@@ -134,6 +137,10 @@
134137
redis_host = "localhost"
135138
redis_port = 6379
136139
redis_db = 0
140+
redis_password = ""
141+
142+
# 文生视频时的最大并发任务数
143+
max_concurrent_tasks = 5
137144

138145
[whisper]
139146
# Only effective when subtitle_provider is "whisper"

0 commit comments

Comments
 (0)