Skip to content

Commit 3d45348

Browse files
author
kevin.zhang
committed
feat: add redis support for task state management
1 parent a0944fa commit 3d45348

File tree

5 files changed

+113
-46
lines changed

5 files changed

+113
-46
lines changed

app/controllers/v1/video.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def create_video(background_tasks: BackgroundTasks, request: Request, body: Task
2929
"request_id": request_id,
3030
"params": body.dict(),
3131
}
32-
sm.update_task(task_id)
32+
sm.state.update_task(task_id)
3333
background_tasks.add_task(tm.start, task_id=task_id, params=body)
3434
logger.success(f"video created: {utils.to_json(task)}")
3535
return utils.get_response(200, task)
@@ -46,7 +46,7 @@ def get_task(request: Request, task_id: str = Path(..., description="Task ID"),
4646
endpoint = endpoint.rstrip("/")
4747

4848
request_id = base.get_task_id(request)
49-
task = sm.get_task(task_id)
49+
task = sm.state.get_task(task_id)
5050
if task:
5151
task_dir = utils.task_dir()
5252

app/services/state.py

+93-32
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,96 @@
1-
# State Management
2-
# This module is responsible for managing the state of the application.
3-
import math
1+
import ast
2+
import json
3+
from abc import ABC, abstractmethod
4+
import redis
5+
from app.config import config
6+
from app.models import const
47

5-
# 如果你部署在分布式环境中,你可能需要一个中心化的状态管理服务,比如 Redis 或者数据库。
6-
# 如果你的应用程序是单机的,你可以使用内存来存储状态。
78

8-
# If you are deploying in a distributed environment, you might need a centralized state management service like Redis or a database.
9-
# If your application is single-node, you can use memory to store the state.
9+
# Base class for state management
10+
class BaseState(ABC):
1011

11-
from app.models import const
12-
from app.utils import utils
13-
14-
_tasks = {}
15-
16-
17-
def update_task(task_id: str, state: int = const.TASK_STATE_PROCESSING, progress: int = 0, **kwargs):
18-
"""
19-
Set the state of the task.
20-
"""
21-
progress = int(progress)
22-
if progress > 100:
23-
progress = 100
24-
25-
_tasks[task_id] = {
26-
"state": state,
27-
"progress": progress,
28-
**kwargs,
29-
}
30-
31-
def get_task(task_id: str):
32-
"""
33-
Get the state of the task.
34-
"""
35-
return _tasks.get(task_id, None)
12+
@abstractmethod
13+
def update_task(self, task_id: str, state: int, progress: int = 0, **kwargs):
14+
pass
15+
16+
@abstractmethod
17+
def get_task(self, task_id: str):
18+
pass
19+
20+
21+
# Memory state management
22+
class MemoryState(BaseState):
23+
24+
def __init__(self):
25+
self._tasks = {}
26+
27+
def update_task(self, task_id: str, state: int = const.TASK_STATE_PROCESSING, progress: int = 0, **kwargs):
28+
progress = int(progress)
29+
if progress > 100:
30+
progress = 100
31+
32+
self._tasks[task_id] = {
33+
"state": state,
34+
"progress": progress,
35+
**kwargs,
36+
}
37+
38+
def get_task(self, task_id: str):
39+
return self._tasks.get(task_id, None)
40+
41+
42+
# Redis state management
43+
class RedisState(BaseState):
44+
45+
def __init__(self, host='localhost', port=6379, db=0):
46+
self._redis = redis.StrictRedis(host=host, port=port, db=db)
47+
48+
def update_task(self, task_id: str, state: int = const.TASK_STATE_PROCESSING, progress: int = 0, **kwargs):
49+
progress = int(progress)
50+
if progress > 100:
51+
progress = 100
52+
53+
fields = {
54+
"state": state,
55+
"progress": progress,
56+
**kwargs,
57+
}
58+
59+
for field, value in fields.items():
60+
self._redis.hset(task_id, field, str(value))
61+
62+
def get_task(self, task_id: str):
63+
task_data = self._redis.hgetall(task_id)
64+
if not task_data:
65+
return None
66+
67+
task = {key.decode('utf-8'): self._convert_to_original_type(value) for key, value in task_data.items()}
68+
return task
69+
70+
@staticmethod
71+
def _convert_to_original_type(value):
72+
"""
73+
Convert the value from byte string to its original data type.
74+
You can extend this method to handle other data types as needed.
75+
"""
76+
value_str = value.decode('utf-8')
77+
78+
try:
79+
# try to convert byte string array to list
80+
return ast.literal_eval(value_str)
81+
except (ValueError, SyntaxError):
82+
pass
83+
84+
if value_str.isdigit():
85+
return int(value_str)
86+
# Add more conversions here if needed
87+
return value_str
88+
89+
90+
# Global state
91+
_enable_redis = config.app.get("enable_redis", False)
92+
_redis_host = config.app.get("redis_host", "localhost")
93+
_redis_port = config.app.get("redis_port", 6379)
94+
_redis_db = config.app.get("redis_db", 0)
95+
96+
state = RedisState(host=_redis_host, port=_redis_port, db=_redis_db) if _enable_redis else MemoryState()

app/services/task.py

+11-11
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def start(task_id, params: VideoParams):
2828
}
2929
"""
3030
logger.info(f"start task: {task_id}")
31-
sm.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=5)
31+
sm.state.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=5)
3232

3333
video_subject = params.video_subject
3434
voice_name = voice.parse_voice_name(params.voice_name)
@@ -44,7 +44,7 @@ def start(task_id, params: VideoParams):
4444
else:
4545
logger.debug(f"video script: \n{video_script}")
4646

47-
sm.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=10)
47+
sm.state.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=10)
4848

4949
logger.info("\n\n## generating video terms")
5050
video_terms = params.video_terms
@@ -70,21 +70,21 @@ def start(task_id, params: VideoParams):
7070
with open(script_file, "w", encoding="utf-8") as f:
7171
f.write(utils.to_json(script_data))
7272

73-
sm.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=20)
73+
sm.state.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=20)
7474

7575
logger.info("\n\n## generating audio")
7676
audio_file = path.join(utils.task_dir(task_id), f"audio.mp3")
7777
sub_maker = voice.tts(text=video_script, voice_name=voice_name, voice_file=audio_file)
7878
if sub_maker is None:
79-
sm.update_task(task_id, state=const.TASK_STATE_FAILED)
79+
sm.state.update_task(task_id, state=const.TASK_STATE_FAILED)
8080
logger.error(
8181
"failed to generate audio, maybe the network is not available. if you are in China, please use a VPN.")
8282
return
8383

8484
audio_duration = voice.get_audio_duration(sub_maker)
8585
audio_duration = math.ceil(audio_duration)
8686

87-
sm.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=30)
87+
sm.state.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=30)
8888

8989
subtitle_path = ""
9090
if params.subtitle_enabled:
@@ -108,7 +108,7 @@ def start(task_id, params: VideoParams):
108108
logger.warning(f"subtitle file is invalid: {subtitle_path}")
109109
subtitle_path = ""
110110

111-
sm.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=40)
111+
sm.state.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=40)
112112

113113
logger.info("\n\n## downloading videos")
114114
downloaded_videos = material.download_videos(task_id=task_id,
@@ -119,12 +119,12 @@ def start(task_id, params: VideoParams):
119119
max_clip_duration=max_clip_duration,
120120
)
121121
if not downloaded_videos:
122-
sm.update_task(task_id, state=const.TASK_STATE_FAILED)
122+
sm.state.update_task(task_id, state=const.TASK_STATE_FAILED)
123123
logger.error(
124124
"failed to download videos, maybe the network is not available. if you are in China, please use a VPN.")
125125
return
126126

127-
sm.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=50)
127+
sm.state.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=50)
128128

129129
final_video_paths = []
130130
combined_video_paths = []
@@ -146,7 +146,7 @@ def start(task_id, params: VideoParams):
146146
threads=n_threads)
147147

148148
_progress += 50 / params.video_count / 2
149-
sm.update_task(task_id, progress=_progress)
149+
sm.state.update_task(task_id, progress=_progress)
150150

151151
final_video_path = path.join(utils.task_dir(task_id), f"final-{index}.mp4")
152152

@@ -160,7 +160,7 @@ def start(task_id, params: VideoParams):
160160
)
161161

162162
_progress += 50 / params.video_count / 2
163-
sm.update_task(task_id, progress=_progress)
163+
sm.state.update_task(task_id, progress=_progress)
164164

165165
final_video_paths.append(final_video_path)
166166
combined_video_paths.append(combined_video_path)
@@ -171,5 +171,5 @@ def start(task_id, params: VideoParams):
171171
"videos": final_video_paths,
172172
"combined_videos": combined_video_paths
173173
}
174-
sm.update_task(task_id, state=const.TASK_STATE_COMPLETE, progress=100, **kwargs)
174+
sm.state.update_task(task_id, state=const.TASK_STATE_COMPLETE, progress=100, **kwargs)
175175
return kwargs

config.example.toml

+5
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,11 @@
129129

130130
material_directory = ""
131131

132+
# Used for state management of the task
133+
enable_redis = true
134+
redis_host = "localhost"
135+
redis_port = 6379
136+
redis_db = 0
132137

133138
[whisper]
134139
# Only effective when subtitle_provider is "whisper"

requirements.txt

+2-1
Original file line numberDiff line numberDiff line change
@@ -15,4 +15,5 @@ pydantic~=2.6.3
1515
g4f~=0.2.5.4
1616
dashscope~=1.15.0
1717
google.generativeai~=0.4.1
18-
python-multipart~=0.0.9
18+
python-multipart~=0.0.9
19+
redis==5.0.3

0 commit comments

Comments
 (0)