From 9e8d5921ac6c7f6e36dbc7d09bdec54eb1552f04 Mon Sep 17 00:00:00 2001
From: limil <gravitylmlml@gmail.com>
Date: Sun, 3 Nov 2024 11:36:09 +0800
Subject: [PATCH] =?UTF-8?q?=E6=8C=81=E4=B9=85=E5=8C=96=E6=95=B0=E6=8D=AE?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .gitignore     |  1 +
 TaskManager.py | 72 +++++++++++++++++++++++++++++++++++++-------------
 readme.md      |  4 +--
 3 files changed, 57 insertions(+), 20 deletions(-)

diff --git a/.gitignore b/.gitignore
index 901e8aa..9a07b96 100644
--- a/.gitignore
+++ b/.gitignore
@@ -106,3 +106,4 @@ venv.bak/
 # add
 .idea/
 token.json
+task.db
diff --git a/TaskManager.py b/TaskManager.py
index 15a8fb0..b890a7c 100644
--- a/TaskManager.py
+++ b/TaskManager.py
@@ -6,6 +6,9 @@ import shortuuid
 from PikPakFileSystem import PikPakFileSystem, FileNode, DirNode
 from pikpakapi import DownloadStatus
 import random
+import pickle
+
+DB_PATH = "task.db"
 
 class TaskStatus(Enum):
     PENDING = "pending"
@@ -29,24 +32,36 @@ class TaskBase:
     TAG = ""
     MAX_CONCURRENT_NUMBER = 5
 
-    def __init__(self, client : PikPakFileSystem):
+    def __init__(self):
         self.id : str = shortuuid.uuid() 
         self.status : TaskStatus = TaskStatus.PENDING
         self.worker : asyncio.Task = None
         self.handler : Callable[..., Awaitable] = None
-        self.client : PikPakFileSystem = client
 
     def Resume(self):
         if self.status in {TaskStatus.PAUSED, TaskStatus.ERROR}:
             self.status = TaskStatus.PENDING
     
+    def __getstate__(self):
+        state = self.__dict__.copy()
+        if 'handler' in state:
+            del state['handler']
+        if 'worker' in state:
+            del state['worker']
+        return state
+
+    def __setstate__(self, state):
+        self.__dict__.update(state)
+        self.worker = None
+        self.handler = None
+    
 
 class TorrentTask(TaskBase):
     TAG = "TorrentTask"
     MAX_CONCURRENT_NUMBER = 5
 
     def __init__(self, torrent : str):
-        super().__init__(self)
+        super().__init__()
         self.torrent_status : TorrentTaskStatus = TorrentTaskStatus.PENDING
         self.torrent : str = torrent
         self.info : str = ""
@@ -56,19 +71,18 @@ class TorrentTask(TaskBase):
         self.remote_base_path : str = None
         self.node_id : str = None
         self.task_id : str = None
-
-
+    
 class FileDownloadTask(TaskBase):
     TAG = "FileDownloadTask"
     MAX_CONCURRENT_NUMBER = 5
 
     def __init__(self, node_id : str, remote_path : str, owner_id : str):
-        super().__init__(self)
+        super().__init__()
         self.file_download_status : FileDownloadTaskStatus = FileDownloadTaskStatus.PENDING
         self.node_id : str = node_id
         self.remote_path : str = remote_path
         self.owner_id : str = owner_id
-
+    
 async def TaskWorker(task : TaskBase):
     try:
         if task.status != TaskStatus.PENDING:
@@ -91,15 +105,18 @@ class TaskManager:
     
     async def _loop(self):
         while True:
-            await asyncio.sleep(0.5)
-            for taskQueue in self.taskQueues.values():
-                notRunningTasks = [task for task in taskQueue if task.worker is None or task.worker.done()]
-                runningTasksNumber = len(taskQueue) - len(notRunningTasks)
-                for task in [task for task in notRunningTasks if task.status == TaskStatus.PENDING]:
-                    if runningTasksNumber >= task.MAX_CONCURRENT_NUMBER:
-                        break
-                    task.worker = asyncio.create_task(TaskWorker(task))
-                    runningTasksNumber += 1
+            try:
+                await asyncio.sleep(0.5)
+                for taskQueue in self.taskQueues.values():
+                    notRunningTasks = [task for task in taskQueue if task.worker is None or task.worker.done()]
+                    runningTasksNumber = len(taskQueue) - len(notRunningTasks)
+                    for task in [task for task in notRunningTasks if task.status == TaskStatus.PENDING]:
+                        if runningTasksNumber >= task.MAX_CONCURRENT_NUMBER:
+                            break
+                        task.worker = asyncio.create_task(TaskWorker(task))
+                        runningTasksNumber += 1
+            except Exception as e:
+                logging.error(f"task loop failed, exception occurred: {e}")
 
     async def _get_task_by_id(self, task_id : str) -> TaskBase:
         for queue in self.taskQueues.values():
@@ -241,12 +258,30 @@ class TaskManager:
 
     #endregion
 
+    def _load_tasks_from_db(self):
+        try:
+            self.taskQueues = pickle.load(open(DB_PATH, "rb"))
+            for queue in self.taskQueues.values():
+                for task in queue:
+                    if task.status == TaskStatus.RUNNING:
+                        task.status = TaskStatus.PENDING
+                    if isinstance(task, TorrentTask):
+                        task.handler = self._torrent_task_handler
+                        task.info = ""
+                    if isinstance(task, FileDownloadTask):
+                        task.handler = self._file_download_task_handler
+        except:
+            pass
+    
+    def _dump_tasks_to_db(self):
+        pickle.dump(self.taskQueues, open(DB_PATH, "wb"))
+
     #endregion
 
     #region 对外接口
 
     def Start(self):
-        # todo: 从文件中恢复任务
+        self._load_tasks_from_db()
         if self.loop is None:
             self.loop = asyncio.create_task(self._loop())
 
@@ -254,7 +289,8 @@ class TaskManager:
         if self.loop is not None:
             self.loop.cancel()
             self.loop = None
-        # todo: 保存任务到文件
+        self._dump_tasks_to_db()
+        
     
     async def CreateTorrentTask(self, torrent : str, remote_base_path : str) -> str:
         task = TorrentTask(torrent)
diff --git a/readme.md b/readme.md
index b80b180..37a52d2 100644
--- a/readme.md
+++ b/readme.md
@@ -9,9 +9,9 @@ Todo:
 - [x] 实现Task队列管理
 - [x] 自动刷新文件系统缓存
 - [x] 分析以下方法的返回值:offline_file_info、offline_list
+- [x] 持久化数据
 - [ ] 实现本地下载队列(多文件,文件夹)
-- [ ] 实现任务暂停、继续、恢复
-- [ ] 持久化数据
+- [x] 实现任务暂停、继续、恢复
 - [ ] 添加测试用例
 - [ ] 完全类型化