实现中断指令

This commit is contained in:
limil 2024-11-02 19:39:39 +08:00
parent 3234aaf757
commit fb08711287
2 changed files with 122 additions and 40 deletions

View File

@ -7,7 +7,7 @@ import os
import logging
from enum import Enum
import asyncio
import uuid
import shortuuid
from utils import PathWalker
from typing import Callable, Awaitable
@ -20,8 +20,8 @@ class TaskStatus(Enum):
class PikPakTaskStatus(Enum):
PENDING = "pending"
REMOTE_DOWNLOADING = "remote downloading"
LOCAL_DOWNLOADING = "local downloading"
REMOTE_DOWNLOADING = "remote"
LOCAL_DOWNLOADING = "local"
class FileDownloadTaskStatus(Enum):
PENDING = "pending"
@ -33,7 +33,7 @@ class UnRecoverableError(Exception):
class TaskBase:
def __init__(self, id : str, tag : str = "", maxConcurrentNumber = -1):
self.id : str = uuid.uuid4() if id is None else id
self.id : str = shortuuid.uuid() if id is None else id
self.tag : str = tag
self.maxConcurrentNumber : int = maxConcurrentNumber
@ -59,11 +59,12 @@ class FileDownloadTask(TaskBase):
TAG = "FileDownloadTask"
MAX_CONCURRENT_NUMBER = 5
def __init__(self, nodeId : str, PikPakTaskId : str, id : str = None, status : FileDownloadTaskStatus = FileDownloadTaskStatus.PENDING):
super().__init__(id, FileDownloadTask.TAG, FileDownloadTask.MAX_CONCURRENT_NUMBER)
def __init__(self, nodeId : str, PikPakTaskId : str, relativePath : str, status : FileDownloadTaskStatus = FileDownloadTaskStatus.PENDING):
super().__init__(nodeId, FileDownloadTask.TAG, FileDownloadTask.MAX_CONCURRENT_NUMBER)
self.status : FileDownloadTaskStatus = status
self.PikPakTaskId : str = PikPakTaskId
self.nodeId : str = nodeId
self.relativePath : str = relativePath
async def TaskWorker(task : TaskBase):
try:
@ -158,11 +159,33 @@ class PikPakFs:
self.nodes[task.nodeId] = DirNode(task.nodeId, task.name, task.toDirId)
else:
self.nodes[task.nodeId] = FileNode(task.nodeId, task.name, task.toDirId)
father = self.GetNodeById(task.toDirId)
if father.id is not None and task.nodeId not in father.childrenId:
father.childrenId.append(task.nodeId)
task.status = PikPakTaskStatus.LOCAL_DOWNLOADING
async def _pikpak_local_downloading(self, task : PikPakTask):
node = self.GetNodeById(task.nodeId)
if IsFile(node):
fileDownloadTask = FileDownloadTask(task.nodeId, task.id, self.NodeToPath(node, node))
fileDownloadTask.handler = self._file_download_task_handler
self._add_task(fileDownloadTask)
elif IsDir(node):
# 使用广度优先遍历
queue : list[DirNode] = [node]
while len(queue) > 0:
current = queue.pop(0)
await self.Refresh(current)
for childId in current.childrenId:
child = self.GetNodeById(childId)
if IsDir(child):
queue.append(child)
elif IsFile(child):
fileDownloadTask = FileDownloadTask(childId, task.id, self.NodeToPath(child, node))
fileDownloadTask.handler = self._file_download_task_handler
self._add_task(fileDownloadTask)
async def _pikpak_task_handler(self, task : PikPakTask):
while True:
if task.status == PikPakTaskStatus.PENDING:
@ -170,7 +193,7 @@ class PikPakFs:
elif task.status == PikPakTaskStatus.REMOTE_DOWNLOADING:
await self._pikpak_offline_downloading(task)
elif task.status == PikPakTaskStatus.LOCAL_DOWNLOADING:
break
await self._pikpak_local_downloading(task)
else:
break
@ -180,7 +203,11 @@ class PikPakFs:
def _add_task(self, task : TaskBase):
if self.taskQueues.get(task.tag) is None:
self.taskQueues[task.tag] = []
self.taskQueues[task.tag].append(task)
taskQueue = self.taskQueues[task.tag]
for t in taskQueue:
if t.id == task.id:
return
taskQueue.append(task)
async def StopTask(self, task : TaskBase):
pass
@ -317,16 +344,16 @@ class PikPakFs:
node.lastUpdate = datetime.now()
async def PathToNode(self, pathStr : str) -> FsNode:
father, sonName = await self.PathToFatherNodeAndNodeName(pathStr)
async def PathToNode(self, path : str) -> FsNode:
father, sonName = await self.PathToFatherNodeAndNodeName(path)
if sonName == "":
return father
if not IsDir(father):
return None
return self.FindChildInDirByName(father, sonName)
async def PathToFatherNodeAndNodeName(self, pathStr : str) -> tuple[FsNode, str]:
pathWalker = PathWalker(pathStr)
async def PathToFatherNodeAndNodeName(self, path : str) -> tuple[FsNode, str]:
pathWalker = PathWalker(path)
father : FsNode = None
sonName : str = None
current = self.root if pathWalker.IsAbsolute() else self.currentLocation
@ -354,12 +381,14 @@ class PikPakFs:
return father, sonName
def NodeToPath(self, node : FsNode) -> str:
if node is self.root:
def NodeToPath(self, node : FsNode, root : FsNode = None) -> str:
if root is None:
root = self.root
if node is root:
return "/"
spots : list[str] = []
current = node
while current is not self.root:
while current is not root:
spots.append(current.name)
current = self.GetFatherNode(current)
spots.append("")
@ -392,6 +421,12 @@ class PikPakFs:
task.handler = self._pikpak_task_handler
self._add_task(task)
return task
async def Pull(self, node : FsNode) -> PikPakTask:
task = PikPakTask("", node.fatherId, node.id, PikPakTaskStatus.LOCAL_DOWNLOADING)
task.handler = self._pikpak_local_downloading
self._add_task(task)
return task
async def QueryPikPakTasks(self, filterStatus : TaskStatus = None) -> list[PikPakTask]:
if PikPakTask.TAG not in self.taskQueues:
@ -400,6 +435,14 @@ class PikPakFs:
if filterStatus is None:
return taskQueue
return [task for task in taskQueue if task._status == filterStatus]
async def QueryFileDownloadTasks(self, filterStatus : TaskStatus = None) -> list[FileDownloadTask]:
if FileDownloadTask.TAG not in self.taskQueues:
return []
taskQueue = self.taskQueues[FileDownloadTask.TAG]
if filterStatus is None:
return taskQueue
return [task for task in taskQueue if task._status == filterStatus]
async def Delete(self, nodes : list[FsNode]) -> None:
nodeIds = [node.id for node in nodes]

89
main.py
View File

@ -6,7 +6,7 @@ import threading
import colorlog
from PikPakFs import PikPakFs, IsDir, IsFile, TaskStatus
import os
import keyboard
from tabulate import tabulate
LogFormatter = colorlog.ColoredFormatter(
"%(log_color)s%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@ -34,19 +34,32 @@ setup_logging()
MainLoop : asyncio.AbstractEventLoop = None
Client = PikPakFs("token.json", proxy="http://127.0.0.1:7897")
def RunSync(func):
@wraps(func)
def decorator(*args, **kwargs):
class RunSync:
_current_task : asyncio.Task = None
def StopCurrentRunningCoroutine():
if RunSync._current_task is not None:
RunSync._current_task.cancel()
def __init__(self, func):
wraps(func)(self)
def __call__(self, *args, **kwargs):
currentLoop = None
try:
currentLoop = asyncio.get_running_loop()
except RuntimeError:
logging.error("Not in an event loop")
pass
func = self.__wrapped__
if currentLoop is MainLoop:
return MainLoop.run_until_complete(func(*args, **kwargs))
task = asyncio.Task(func(*args, **kwargs))
RunSync._current_task = task
result = MainLoop.run_until_complete(task)
RunSync._current_task = None
return result
else:
return asyncio.run_coroutine_threadsafe(func(*args, **kwargs), MainLoop).result()
return decorator
class Console(cmd2.Cmd):
def _io_worker(self, loop):
@ -76,7 +89,7 @@ class Console(cmd2.Cmd):
# 1. 设置忽略SIGINT
import signal
def signal_handler(sig, frame):
pass
RunSync.StopCurrentRunningCoroutine()
signal.signal(signal.SIGINT, signal_handler)
# 2. 创建IO线程处理输入输出
@ -125,8 +138,8 @@ class Console(cmd2.Cmd):
login_parser = cmd2.Cmd2ArgumentParser()
login_parser.add_argument("username", help="username", nargs="?")
login_parser.add_argument("password", help="password", nargs="?")
@RunSync
@cmd2.with_argparser(login_parser)
@RunSync
async def do_login(self, args):
"""
Login to pikpak
@ -134,7 +147,7 @@ class Console(cmd2.Cmd):
await Client.Login(args.username, args.password)
await self.Print("Logged in successfully")
async def _path_completer(self, text, line, begidx, endidx, filterfiles):
async def _path_completer(self, text, line, begidx, endidx, ignoreFiles):
father, sonName = await Client.PathToFatherNodeAndNodeName(text)
if not IsDir(father):
return []
@ -142,7 +155,7 @@ class Console(cmd2.Cmd):
matchesNode = []
for childId in father.childrenId:
child = Client.GetNodeById(childId)
if filterfiles and IsFile(child):
if ignoreFiles and IsFile(child):
continue
if child.name.startswith(sonName):
self.display_matches.append(child.name)
@ -163,8 +176,8 @@ class Console(cmd2.Cmd):
ls_parser = cmd2.Cmd2ArgumentParser()
ls_parser.add_argument("path", help="path", default="", nargs="?", type=RunSync(Client.PathToNode))
@RunSync
@cmd2.with_argparser(ls_parser)
@RunSync
async def do_ls(self, args):
"""
List files in a directory
@ -187,8 +200,8 @@ class Console(cmd2.Cmd):
cd_parser = cmd2.Cmd2ArgumentParser()
cd_parser.add_argument("path", help="path", default="", nargs="?", type=RunSync(Client.PathToNode))
@RunSync
@cmd2.with_argparser(cd_parser)
@RunSync
async def do_cd(self, args):
"""
Change directory
@ -218,8 +231,8 @@ class Console(cmd2.Cmd):
rm_parser = cmd2.Cmd2ArgumentParser()
rm_parser.add_argument("paths", help="paths", default="", nargs="+", type=RunSync(Client.PathToNode))
@RunSync
@cmd2.with_argparser(rm_parser)
@RunSync
async def do_rm(self, args):
"""
Remove a file or directory
@ -232,8 +245,8 @@ class Console(cmd2.Cmd):
mkdir_parser = cmd2.Cmd2ArgumentParser()
mkdir_parser.add_argument("path_and_son", help="path and son", default="", nargs="?", type=RunSync(Client.PathToFatherNodeAndNodeName))
@RunSync
@cmd2.with_argparser(mkdir_parser)
@RunSync
async def do_mkdir(self, args):
"""
Create a directory
@ -251,11 +264,11 @@ class Console(cmd2.Cmd):
download_parser = cmd2.Cmd2ArgumentParser()
download_parser.add_argument("url", help="url")
download_parser.add_argument("path", help="path", default="", nargs="?", type=RunSync(Client.PathToNode))
@RunSync
@cmd2.with_argparser(download_parser)
@RunSync
async def do_download(self, args):
"""
Download a file
Download a file or directory
"""
node = args.path
if not IsDir(node):
@ -264,24 +277,47 @@ class Console(cmd2.Cmd):
task = await Client.Download(args.url, node)
await self.Print(f"Task {task.id} created")
query_parser = cmd2.Cmd2ArgumentParser()
query_parser.add_argument("-f", "--filter", help="filter", nargs="?", choices=[member.value for member in TaskStatus])
@RunSync
async def complete_pull(self, text, line, begidx, endidx):
return await self._path_completer(text, line, begidx, endidx, False)
pull_parser = cmd2.Cmd2ArgumentParser()
pull_parser.add_argument("target", help="pull target", type=RunSync(Client.PathToNode))
@cmd2.with_argparser(pull_parser)
@RunSync
async def do_pull(self, args):
"""
Pull a file or directory
"""
await Client.Pull(args.target)
query_parser = cmd2.Cmd2ArgumentParser()
query_parser.add_argument("-t", "--type", help="type", nargs="?", choices=["pikpak", "filedownload"], default="pikpak")
query_parser.add_argument("-f", "--filter", help="filter", nargs="?", choices=[member.value for member in TaskStatus])
@cmd2.with_argparser(query_parser)
@RunSync
async def do_query(self, args):
"""
Query All Tasks
"""
tasks = await Client.QueryPikPakTasks(TaskStatus(args.filter) if args.filter is not None else None)
# 格式化输出所有task信息idstatuslastStatus的信息输出表格
await self.Print("tstatus\tdetails\tid")
for task in tasks:
await self.Print(f"{task._status.value}\t{task.status.value}\t{task.id}")
if args.type == "pikpak":
tasks = await Client.QueryPikPakTasks(TaskStatus(args.filter) if args.filter is not None else None)
# 格式化输出所有task信息idstatuslastStatus的信息输出表格
table = [[task.id, task._status.value, task.status.value] for task in tasks]
headers = ["id", "status", "details"]
await self.Print(tabulate(table, headers, tablefmt="grid"))
elif args.type == "filedownload":
tasks = await Client.QueryFileDownloadTasks(TaskStatus(args.filter) if args.filter is not None else None)
# 格式化输出所有task信息idstatuslastStatus的信息输出表格
table = [[task.id, task._status.value, task.status.value, task.relativePath] for task in tasks]
headers = ["id", "status", "details", "path"]
await self.Print(tabulate(table, headers, tablefmt="grid"))
retry_parser = cmd2.Cmd2ArgumentParser()
retry_parser.add_argument("taskId", help="taskId")
@RunSync
@cmd2.with_argparser(retry_parser)
@RunSync
async def do_retry(self, args):
"""
Retry a task
@ -299,7 +335,10 @@ async def mainLoop():
stop = False
while not stop:
line = await console.Input(console.prompt)
stop = console.onecmd_plus_hooks(line)
try:
stop = console.onecmd_plus_hooks(line)
except asyncio.CancelledError:
await console.Print("^C: Task cancelled")
finally:
console.postloop()
clientWorker.cancel()