diff --git a/.gitignore b/.gitignore index 901e8aa..9a07b96 100644 --- a/.gitignore +++ b/.gitignore @@ -106,3 +106,4 @@ venv.bak/ # add .idea/ token.json +task.db diff --git a/TaskManager.py b/TaskManager.py index 15a8fb0..b890a7c 100644 --- a/TaskManager.py +++ b/TaskManager.py @@ -6,6 +6,9 @@ import shortuuid from PikPakFileSystem import PikPakFileSystem, FileNode, DirNode from pikpakapi import DownloadStatus import random +import pickle + +DB_PATH = "task.db" class TaskStatus(Enum): PENDING = "pending" @@ -29,24 +32,36 @@ class TaskBase: TAG = "" MAX_CONCURRENT_NUMBER = 5 - def __init__(self, client : PikPakFileSystem): + def __init__(self): self.id : str = shortuuid.uuid() self.status : TaskStatus = TaskStatus.PENDING self.worker : asyncio.Task = None self.handler : Callable[..., Awaitable] = None - self.client : PikPakFileSystem = client def Resume(self): if self.status in {TaskStatus.PAUSED, TaskStatus.ERROR}: self.status = TaskStatus.PENDING + def __getstate__(self): + state = self.__dict__.copy() + if 'handler' in state: + del state['handler'] + if 'worker' in state: + del state['worker'] + return state + + def __setstate__(self, state): + self.__dict__.update(state) + self.worker = None + self.handler = None + class TorrentTask(TaskBase): TAG = "TorrentTask" MAX_CONCURRENT_NUMBER = 5 def __init__(self, torrent : str): - super().__init__(self) + super().__init__() self.torrent_status : TorrentTaskStatus = TorrentTaskStatus.PENDING self.torrent : str = torrent self.info : str = "" @@ -56,19 +71,18 @@ class TorrentTask(TaskBase): self.remote_base_path : str = None self.node_id : str = None self.task_id : str = None - - + class FileDownloadTask(TaskBase): TAG = "FileDownloadTask" MAX_CONCURRENT_NUMBER = 5 def __init__(self, node_id : str, remote_path : str, owner_id : str): - super().__init__(self) + super().__init__() self.file_download_status : FileDownloadTaskStatus = FileDownloadTaskStatus.PENDING self.node_id : str = node_id self.remote_path : str = remote_path self.owner_id : str = owner_id - + async def TaskWorker(task : TaskBase): try: if task.status != TaskStatus.PENDING: @@ -91,15 +105,18 @@ class TaskManager: async def _loop(self): while True: - await asyncio.sleep(0.5) - for taskQueue in self.taskQueues.values(): - notRunningTasks = [task for task in taskQueue if task.worker is None or task.worker.done()] - runningTasksNumber = len(taskQueue) - len(notRunningTasks) - for task in [task for task in notRunningTasks if task.status == TaskStatus.PENDING]: - if runningTasksNumber >= task.MAX_CONCURRENT_NUMBER: - break - task.worker = asyncio.create_task(TaskWorker(task)) - runningTasksNumber += 1 + try: + await asyncio.sleep(0.5) + for taskQueue in self.taskQueues.values(): + notRunningTasks = [task for task in taskQueue if task.worker is None or task.worker.done()] + runningTasksNumber = len(taskQueue) - len(notRunningTasks) + for task in [task for task in notRunningTasks if task.status == TaskStatus.PENDING]: + if runningTasksNumber >= task.MAX_CONCURRENT_NUMBER: + break + task.worker = asyncio.create_task(TaskWorker(task)) + runningTasksNumber += 1 + except Exception as e: + logging.error(f"task loop failed, exception occurred: {e}") async def _get_task_by_id(self, task_id : str) -> TaskBase: for queue in self.taskQueues.values(): @@ -241,12 +258,30 @@ class TaskManager: #endregion + def _load_tasks_from_db(self): + try: + self.taskQueues = pickle.load(open(DB_PATH, "rb")) + for queue in self.taskQueues.values(): + for task in queue: + if task.status == TaskStatus.RUNNING: + task.status = TaskStatus.PENDING + if isinstance(task, TorrentTask): + task.handler = self._torrent_task_handler + task.info = "" + if isinstance(task, FileDownloadTask): + task.handler = self._file_download_task_handler + except: + pass + + def _dump_tasks_to_db(self): + pickle.dump(self.taskQueues, open(DB_PATH, "wb")) + #endregion #region 对外接口 def Start(self): - # todo: 从文件中恢复任务 + self._load_tasks_from_db() if self.loop is None: self.loop = asyncio.create_task(self._loop()) @@ -254,7 +289,8 @@ class TaskManager: if self.loop is not None: self.loop.cancel() self.loop = None - # todo: 保存任务到文件 + self._dump_tasks_to_db() + async def CreateTorrentTask(self, torrent : str, remote_base_path : str) -> str: task = TorrentTask(torrent) diff --git a/readme.md b/readme.md index b80b180..37a52d2 100644 --- a/readme.md +++ b/readme.md @@ -9,9 +9,9 @@ Todo: - [x] 实现Task队列管理 - [x] 自动刷新文件系统缓存 - [x] 分析以下方法的返回值:offline_file_info、offline_list +- [x] 持久化数据 - [ ] 实现本地下载队列(多文件,文件夹) -- [ ] 实现任务暂停、继续、恢复 -- [ ] 持久化数据 +- [x] 实现任务暂停、继续、恢复 - [ ] 添加测试用例 - [ ] 完全类型化