From 8f955f130f8d15ab6e9a93a83fe5d059a98cbb81 Mon Sep 17 00:00:00 2001 From: XYenon <20698483+XYenon@users.noreply.github.com> Date: Sun, 31 Jul 2022 20:46:30 +0800 Subject: [PATCH 1/3] feat: add aria2 seeding --- assets | 2 +- pkg/aria2/aria2.go | 6 +- pkg/aria2/common/common.go | 9 +- pkg/aria2/common/common_test.go | 25 ++- pkg/aria2/monitor/monitor.go | 48 +++++- pkg/aria2/monitor/monitor_test.go | 11 +- pkg/aria2/rpc/resp.go | 60 +++---- .../driver/shadow/slaveinmaster/handler.go | 46 +++++- pkg/serializer/slave.go | 21 +++ pkg/serializer/slave_test.go | 14 +- pkg/task/job.go | 4 + pkg/task/job_test.go | 14 +- pkg/task/recycle.go | 155 ++++++++++++++++++ pkg/task/recycle_test.go | 131 +++++++++++++++ pkg/task/slavetask/recycle.go | 95 +++++++++++ pkg/task/slavetask/transfer.go | 13 +- pkg/task/tranfer.go | 13 -- routers/controllers/slave.go | 11 ++ routers/router.go | 1 + service/aria2/manage.go | 2 +- service/explorer/slave.go | 27 ++- service/user/register.go | 5 +- 22 files changed, 627 insertions(+), 86 deletions(-) create mode 100644 pkg/task/recycle.go create mode 100644 pkg/task/recycle_test.go create mode 100644 pkg/task/slavetask/recycle.go diff --git a/assets b/assets index a1028e7e0a..02d93206cc 160000 --- a/assets +++ b/assets @@ -1 +1 @@ -Subproject commit a1028e7e0ae96be4bb67d8c117cf39e07c207473 +Subproject commit 02d93206cc5b943c34b5f5ac86c23dd96f5ef603 diff --git a/pkg/aria2/aria2.go b/pkg/aria2/aria2.go index 60d254e524..f91766faf9 100644 --- a/pkg/aria2/aria2.go +++ b/pkg/aria2/aria2.go @@ -3,8 +3,6 @@ package aria2 import ( "context" "fmt" - "github.com/cloudreve/Cloudreve/v3/pkg/cluster" - "github.com/cloudreve/Cloudreve/v3/pkg/mq" "net/url" "sync" "time" @@ -14,6 +12,8 @@ import ( "github.com/cloudreve/Cloudreve/v3/pkg/aria2/monitor" "github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc" "github.com/cloudreve/Cloudreve/v3/pkg/balancer" + "github.com/cloudreve/Cloudreve/v3/pkg/cluster" + "github.com/cloudreve/Cloudreve/v3/pkg/mq" ) // Instance 默认使用的Aria2处理实例 @@ -40,7 +40,7 @@ func Init(isReload bool, pool cluster.Pool, mqClient mq.MQ) { if !isReload { // 从数据库中读取未完成任务,创建监控 - unfinished := model.GetDownloadsByStatus(common.Ready, common.Paused, common.Downloading) + unfinished := model.GetDownloadsByStatus(common.Ready, common.Paused, common.Downloading, common.Seeding) for i := 0; i < len(unfinished); i++ { // 创建任务监控 diff --git a/pkg/aria2/common/common.go b/pkg/aria2/common/common.go index 8f281d810a..d4a8313d89 100644 --- a/pkg/aria2/common/common.go +++ b/pkg/aria2/common/common.go @@ -38,6 +38,8 @@ const ( Downloading // Paused 暂停中 Paused + // Seeding 做种中 + Seeding // Error 出错 Error // Complete 完成 @@ -94,11 +96,14 @@ func (instance *DummyAria2) DeleteTempFile(src *model.Download) error { } // GetStatus 将给定的状态字符串转换为状态标识数字 -func GetStatus(status string) int { - switch status { +func GetStatus(status rpc.StatusInfo) int { + switch status.Status { case "complete": return Complete case "active": + if status.BitTorrent.Mode != "" && status.CompletedLength == status.TotalLength { + return Seeding + } return Downloading case "waiting": return Ready diff --git a/pkg/aria2/common/common_test.go b/pkg/aria2/common/common_test.go index a93f5f806a..7b0f2378e9 100644 --- a/pkg/aria2/common/common_test.go +++ b/pkg/aria2/common/common_test.go @@ -1,9 +1,11 @@ package common import ( + "testing" + model "github.com/cloudreve/Cloudreve/v3/models" + "github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc" "github.com/stretchr/testify/assert" - "testing" ) func TestDummyAria2(t *testing.T) { @@ -35,11 +37,18 @@ func TestDummyAria2(t *testing.T) { func TestGetStatus(t *testing.T) { a := assert.New(t) - a.Equal(GetStatus("complete"), Complete) - a.Equal(GetStatus("active"), Downloading) - a.Equal(GetStatus("waiting"), Ready) - a.Equal(GetStatus("paused"), Paused) - a.Equal(GetStatus("error"), Error) - a.Equal(GetStatus("removed"), Canceled) - a.Equal(GetStatus("unknown"), Unknown) + a.Equal(GetStatus(rpc.StatusInfo{Status: "complete"}), Complete) + a.Equal(GetStatus(rpc.StatusInfo{Status: "active", + BitTorrent: rpc.BitTorrentInfo{Mode: ""}}), Downloading) + a.Equal(GetStatus(rpc.StatusInfo{Status: "active", + BitTorrent: rpc.BitTorrentInfo{Mode: "single"}, + TotalLength: "100", CompletedLength: "50"}), Downloading) + a.Equal(GetStatus(rpc.StatusInfo{Status: "active", + BitTorrent: rpc.BitTorrentInfo{Mode: "multi"}, + TotalLength: "100", CompletedLength: "100"}), Seeding) + a.Equal(GetStatus(rpc.StatusInfo{Status: "waiting"}), Ready) + a.Equal(GetStatus(rpc.StatusInfo{Status: "paused"}), Paused) + a.Equal(GetStatus(rpc.StatusInfo{Status: "error"}), Error) + a.Equal(GetStatus(rpc.StatusInfo{Status: "removed"}), Canceled) + a.Equal(GetStatus(rpc.StatusInfo{Status: "unknown"}), Unknown) } diff --git a/pkg/aria2/monitor/monitor.go b/pkg/aria2/monitor/monitor.go index a515b66f86..6f6de7e976 100644 --- a/pkg/aria2/monitor/monitor.go +++ b/pkg/aria2/monitor/monitor.go @@ -109,14 +109,14 @@ func (monitor *Monitor) Update() bool { util.Log().Debug("离线下载[%s]更新状态[%s]", status.Gid, status.Status) - switch status.Status { - case "complete": + switch common.GetStatus(status) { + case common.Complete, common.Seeding: return monitor.Complete(task.TaskPoll) - case "error": + case common.Error: return monitor.Error(status) - case "active", "waiting", "paused": + case common.Downloading, common.Ready, common.Paused: return false - case "removed": + case common.Canceled: monitor.Task.Status = common.Canceled monitor.Task.Save() monitor.RemoveTempFolder() @@ -132,7 +132,7 @@ func (monitor *Monitor) UpdateTaskInfo(status rpc.StatusInfo) error { originSize := monitor.Task.TotalSize monitor.Task.GID = status.Gid - monitor.Task.Status = common.GetStatus(status.Status) + monitor.Task.Status = common.GetStatus(status) // 文件大小、已下载大小 total, err := strconv.ParseUint(status.TotalLength, 10, 64) @@ -235,6 +235,40 @@ func (monitor *Monitor) RemoveTempFolder() { // Complete 完成下载,返回是否中断监控 func (monitor *Monitor) Complete(pool task.Pool) bool { + // 未开始转存,提交转存任务 + if monitor.Task.TaskID == 0 { + return monitor.transfer(pool) + } + + // 做种完成 + if common.GetStatus(monitor.Task.StatusInfo) == common.Complete { + transferTask, err := model.GetTasksByID(monitor.Task.TaskID) + if err != nil { + monitor.setErrorStatus(err) + monitor.RemoveTempFolder() + return true + } + + // 转存完成,回收下载目录 + if transferTask.Type == task.TransferTaskType && transferTask.Status >= task.Error { + job, err := task.NewRecycleTask(monitor.Task.UserID, monitor.Task.Parent, monitor.node.ID()) + if err != nil { + monitor.setErrorStatus(err) + monitor.RemoveTempFolder() + return true + } + + // 提交回收任务 + pool.Submit(job) + + return true + } + } + + return false +} + +func (monitor *Monitor) transfer(pool task.Pool) bool { // 创建中转任务 file := make([]string, 0, len(monitor.Task.StatusInfo.Files)) sizes := make(map[string]uint64, len(monitor.Task.StatusInfo.Files)) @@ -269,7 +303,7 @@ func (monitor *Monitor) Complete(pool task.Pool) bool { monitor.Task.TaskID = job.Model().ID monitor.Task.Save() - return true + return false } func (monitor *Monitor) setErrorStatus(err error) { diff --git a/pkg/aria2/monitor/monitor_test.go b/pkg/aria2/monitor/monitor_test.go index 885484a385..a6be586add 100644 --- a/pkg/aria2/monitor/monitor_test.go +++ b/pkg/aria2/monitor/monitor_test.go @@ -3,6 +3,8 @@ package monitor import ( "database/sql" "errors" + "testing" + "github.com/DATA-DOG/go-sqlmock" model "github.com/cloudreve/Cloudreve/v3/models" "github.com/cloudreve/Cloudreve/v3/pkg/aria2/common" @@ -13,7 +15,6 @@ import ( "github.com/jinzhu/gorm" "github.com/stretchr/testify/assert" testMock "github.com/stretchr/testify/mock" - "testing" ) var mock sqlmock.Sqlmock @@ -431,6 +432,14 @@ func TestMonitor_Complete(t *testing.T) { mock.ExpectExec("UPDATE(.+)downloads").WillReturnResult(sqlmock.NewResult(1, 1)) mock.ExpectCommit() + mock.ExpectQuery("SELECT(.+)tasks").WillReturnRows(sqlmock.NewRows([]string{"id", "type", "status"}).AddRow(1, 2, 4)) + mock.ExpectQuery("SELECT(.+)users").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(9414)) + mock.ExpectBegin() + mock.ExpectExec("INSERT(.+)tasks").WillReturnResult(sqlmock.NewResult(2, 1)) + mock.ExpectCommit() + + a.False(m.Complete(mockPool)) + m.Task.StatusInfo.Status = "complete" a.True(m.Complete(mockPool)) a.NoError(mock.ExpectationsWereMet()) mockNode.AssertExpectations(t) diff --git a/pkg/aria2/rpc/resp.go b/pkg/aria2/rpc/resp.go index e685ce66f2..3614228fe4 100644 --- a/pkg/aria2/rpc/resp.go +++ b/pkg/aria2/rpc/resp.go @@ -4,35 +4,27 @@ package rpc // StatusInfo represents response of aria2.tellStatus type StatusInfo struct { - Gid string `json:"gid"` // GID of the download. - Status string `json:"status"` // active for currently downloading/seeding downloads. waiting for downloads in the queue; download is not started. paused for paused downloads. error for downloads that were stopped because of error. complete for stopped and completed downloads. removed for the downloads removed by user. - TotalLength string `json:"totalLength"` // Total length of the download in bytes. - CompletedLength string `json:"completedLength"` // Completed length of the download in bytes. - UploadLength string `json:"uploadLength"` // Uploaded length of the download in bytes. - BitField string `json:"bitfield"` // Hexadecimal representation of the download progress. The highest bit corresponds to the piece at index 0. Any set bits indicate loaded pieces, while unset bits indicate not yet loaded and/or missing pieces. Any overflow bits at the end are set to zero. When the download was not started yet, this key will not be included in the response. - DownloadSpeed string `json:"downloadSpeed"` // Download speed of this download measured in bytes/sec. - UploadSpeed string `json:"uploadSpeed"` // LocalUpload speed of this download measured in bytes/sec. - InfoHash string `json:"infoHash"` // InfoHash. BitTorrent only. - NumSeeders string `json:"numSeeders"` // The number of seeders aria2 has connected to. BitTorrent only. - Seeder string `json:"seeder"` // true if the local endpoint is a seeder. Otherwise false. BitTorrent only. - PieceLength string `json:"pieceLength"` // Piece length in bytes. - NumPieces string `json:"numPieces"` // The number of pieces. - Connections string `json:"connections"` // The number of peers/servers aria2 has connected to. - ErrorCode string `json:"errorCode"` // The code of the last error for this item, if any. The value is a string. The error codes are defined in the EXIT STATUS section. This value is only available for stopped/completed downloads. - ErrorMessage string `json:"errorMessage"` // The (hopefully) human readable error message associated to errorCode. - FollowedBy []string `json:"followedBy"` // List of GIDs which are generated as the result of this download. For example, when aria2 downloads a Metalink file, it generates downloads described in the Metalink (see the --follow-metalink option). This value is useful to track auto-generated downloads. If there are no such downloads, this key will not be included in the response. - BelongsTo string `json:"belongsTo"` // GID of a parent download. Some downloads are a part of another download. For example, if a file in a Metalink has BitTorrent resources, the downloads of ".torrent" files are parts of that parent. If this download has no parent, this key will not be included in the response. - Dir string `json:"dir"` // Directory to save files. - Files []FileInfo `json:"files"` // Returns the list of files. The elements of this list are the same structs used in aria2.getFiles() method. - BitTorrent struct { - AnnounceList [][]string `json:"announceList"` // List of lists of announce URIs. If the torrent contains announce and no announce-list, announce is converted to the announce-list format. - Comment string `json:"comment"` // The comment of the torrent. comment.utf-8 is used if available. - CreationDate int64 `json:"creationDate"` // The creation time of the torrent. The value is an integer since the epoch, measured in seconds. - Mode string `json:"mode"` // File mode of the torrent. The value is either single or multi. - Info struct { - Name string `json:"name"` // name in info dictionary. name.utf-8 is used if available. - } `json:"info"` // Struct which contains data from Info dictionary. It contains following keys. - } `json:"bittorrent"` // Struct which contains information retrieved from the .torrent (file). BitTorrent only. It contains following keys. + Gid string `json:"gid"` // GID of the download. + Status string `json:"status"` // active for currently downloading/seeding downloads. waiting for downloads in the queue; download is not started. paused for paused downloads. error for downloads that were stopped because of error. complete for stopped and completed downloads. removed for the downloads removed by user. + TotalLength string `json:"totalLength"` // Total length of the download in bytes. + CompletedLength string `json:"completedLength"` // Completed length of the download in bytes. + UploadLength string `json:"uploadLength"` // Uploaded length of the download in bytes. + BitField string `json:"bitfield"` // Hexadecimal representation of the download progress. The highest bit corresponds to the piece at index 0. Any set bits indicate loaded pieces, while unset bits indicate not yet loaded and/or missing pieces. Any overflow bits at the end are set to zero. When the download was not started yet, this key will not be included in the response. + DownloadSpeed string `json:"downloadSpeed"` // Download speed of this download measured in bytes/sec. + UploadSpeed string `json:"uploadSpeed"` // LocalUpload speed of this download measured in bytes/sec. + InfoHash string `json:"infoHash"` // InfoHash. BitTorrent only. + NumSeeders string `json:"numSeeders"` // The number of seeders aria2 has connected to. BitTorrent only. + Seeder string `json:"seeder"` // true if the local endpoint is a seeder. Otherwise false. BitTorrent only. + PieceLength string `json:"pieceLength"` // Piece length in bytes. + NumPieces string `json:"numPieces"` // The number of pieces. + Connections string `json:"connections"` // The number of peers/servers aria2 has connected to. + ErrorCode string `json:"errorCode"` // The code of the last error for this item, if any. The value is a string. The error codes are defined in the EXIT STATUS section. This value is only available for stopped/completed downloads. + ErrorMessage string `json:"errorMessage"` // The (hopefully) human readable error message associated to errorCode. + FollowedBy []string `json:"followedBy"` // List of GIDs which are generated as the result of this download. For example, when aria2 downloads a Metalink file, it generates downloads described in the Metalink (see the --follow-metalink option). This value is useful to track auto-generated downloads. If there are no such downloads, this key will not be included in the response. + BelongsTo string `json:"belongsTo"` // GID of a parent download. Some downloads are a part of another download. For example, if a file in a Metalink has BitTorrent resources, the downloads of ".torrent" files are parts of that parent. If this download has no parent, this key will not be included in the response. + Dir string `json:"dir"` // Directory to save files. + Files []FileInfo `json:"files"` // Returns the list of files. The elements of this list are the same structs used in aria2.getFiles() method. + BitTorrent BitTorrentInfo `json:"bittorrent"` // Struct which contains information retrieved from the .torrent (file). BitTorrent only. It contains following keys. } // URIInfo represents an element of response of aria2.getUris @@ -100,3 +92,13 @@ type Method struct { Name string `json:"methodName"` // Method name to call Params []interface{} `json:"params"` // Array containing parameters to the method call } + +type BitTorrentInfo struct { + AnnounceList [][]string `json:"announceList"` // List of lists of announce URIs. If the torrent contains announce and no announce-list, announce is converted to the announce-list format. + Comment string `json:"comment"` // The comment of the torrent. comment.utf-8 is used if available. + CreationDate int64 `json:"creationDate"` // The creation time of the torrent. The value is an integer since the epoch, measured in seconds. + Mode string `json:"mode"` // File mode of the torrent. The value is either single or multi. + Info struct { + Name string `json:"name"` // name in info dictionary. name.utf-8 is used if available. + } `json:"info"` // Struct which contains data from Info dictionary. It contains following keys. +} diff --git a/pkg/filesystem/driver/shadow/slaveinmaster/handler.go b/pkg/filesystem/driver/shadow/slaveinmaster/handler.go index 7fc7b0983f..84116394d2 100644 --- a/pkg/filesystem/driver/shadow/slaveinmaster/handler.go +++ b/pkg/filesystem/driver/shadow/slaveinmaster/handler.go @@ -5,6 +5,9 @@ import ( "context" "encoding/json" "errors" + "net/url" + "time" + model "github.com/cloudreve/Cloudreve/v3/models" "github.com/cloudreve/Cloudreve/v3/pkg/cluster" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver" @@ -13,8 +16,6 @@ import ( "github.com/cloudreve/Cloudreve/v3/pkg/mq" "github.com/cloudreve/Cloudreve/v3/pkg/request" "github.com/cloudreve/Cloudreve/v3/pkg/serializer" - "net/url" - "time" ) // Driver 影子存储策略,将上传任务指派给从机节点处理,并等待从机通知上传结果 @@ -118,6 +119,45 @@ func (d *Driver) List(ctx context.Context, path string, recursive bool) ([]respo } // 取消上传凭证 -func (handler Driver) CancelToken(ctx context.Context, uploadSession *serializer.UploadSession) error { +func (d *Driver) CancelToken(ctx context.Context, uploadSession *serializer.UploadSession) error { + return nil +} + +func (d *Driver) Recycle(ctx context.Context, path string) error { + req := serializer.SlaveRecycleReq{ + Path: path, + } + + body, err := json.Marshal(req) + if err != nil { + return err + } + + // 订阅回收结果 + resChan := mq.GlobalMQ.Subscribe(req.Hash(model.GetSettingByName("siteID")), 0) + defer mq.GlobalMQ.Unsubscribe(req.Hash(model.GetSettingByName("siteID")), resChan) + + res, err := d.client.Request("PUT", "task/recycle", bytes.NewReader(body)). + CheckHTTPResponse(200). + DecodeResponse() + if err != nil { + return err + } + + if res.Code != 0 { + return serializer.NewErrorFromResponse(res) + } + + // 等待回收结果或者超时 + waitTimeout := model.GetIntSetting("slave_transfer_timeout", 172800) + select { + case <-time.After(time.Duration(waitTimeout) * time.Second): + return ErrWaitResultTimeout + case msg := <-resChan: + if msg.Event != serializer.SlaveRecycleSuccess { + return errors.New(msg.Content.(serializer.SlaveRecycleResult).Error) + } + } + return nil } diff --git a/pkg/serializer/slave.go b/pkg/serializer/slave.go index 245767a9e4..4179d4555d 100644 --- a/pkg/serializer/slave.go +++ b/pkg/serializer/slave.go @@ -4,6 +4,7 @@ import ( "crypto/sha1" "encoding/gob" "fmt" + model "github.com/cloudreve/Cloudreve/v3/models" ) @@ -53,15 +54,35 @@ func (s *SlaveTransferReq) Hash(id string) string { return fmt.Sprintf("%x", bs) } +// SlaveRecycleReq 从机回收任务创建请求 +type SlaveRecycleReq struct { + Path string `json:"path"` +} + +// Hash 返回创建请求的唯一标识,保持创建请求幂等 +func (s *SlaveRecycleReq) Hash(id string) string { + h := sha1.New() + h.Write([]byte(fmt.Sprintf("transfer-%s-%s", id, s.Path))) + bs := h.Sum(nil) + return fmt.Sprintf("%x", bs) +} + const ( SlaveTransferSuccess = "success" SlaveTransferFailed = "failed" + SlaveRecycleSuccess = "success" + SlaveRecycleFailed = "failed" ) type SlaveTransferResult struct { Error string } +type SlaveRecycleResult struct { + Error string +} + func init() { gob.Register(SlaveTransferResult{}) + gob.Register(SlaveRecycleResult{}) } diff --git a/pkg/serializer/slave_test.go b/pkg/serializer/slave_test.go index 6471542124..add3a63492 100644 --- a/pkg/serializer/slave_test.go +++ b/pkg/serializer/slave_test.go @@ -1,9 +1,10 @@ package serializer import ( + "testing" + model "github.com/cloudreve/Cloudreve/v3/models" "github.com/stretchr/testify/assert" - "testing" ) func TestSlaveTransferReq_Hash(t *testing.T) { @@ -18,3 +19,14 @@ func TestSlaveTransferReq_Hash(t *testing.T) { } a.NotEqual(s1.Hash("1"), s2.Hash("1")) } + +func TestSlaveRecycleReq_Hash(t *testing.T) { + a := assert.New(t) + s1 := &SlaveRecycleReq{ + Path: "1", + } + s2 := &SlaveRecycleReq{ + Path: "2", + } + a.NotEqual(s1.Hash("1"), s2.Hash("1")) +} diff --git a/pkg/task/job.go b/pkg/task/job.go index 781c4608fe..9bf52d74ad 100644 --- a/pkg/task/job.go +++ b/pkg/task/job.go @@ -13,6 +13,8 @@ const ( DecompressTaskType // TransferTaskType 中转任务 TransferTaskType + // RecycleTaskType 回收任务 + RecycleTaskType // ImportTaskType 导入任务 ImportTaskType ) @@ -113,6 +115,8 @@ func GetJobFromModel(task *model.Task) (Job, error) { return NewTransferTaskFromModel(task) case ImportTaskType: return NewImportTaskFromModel(task) + case RecycleTaskType: + return NewRecycleTaskFromModel(task) default: return nil, ErrUnknownTaskType } diff --git a/pkg/task/job_test.go b/pkg/task/job_test.go index 81793ee6bc..737f5b7684 100644 --- a/pkg/task/job_test.go +++ b/pkg/task/job_test.go @@ -2,12 +2,12 @@ package task import ( "errors" - testMock "github.com/stretchr/testify/mock" "testing" "github.com/DATA-DOG/go-sqlmock" model "github.com/cloudreve/Cloudreve/v3/models" "github.com/stretchr/testify/assert" + testMock "github.com/stretchr/testify/mock" ) func TestRecord(t *testing.T) { @@ -103,4 +103,16 @@ func TestGetJobFromModel(t *testing.T) { asserts.Nil(job) asserts.Error(err) } + // RecycleTaskType + { + task := &model.Task{ + Status: 0, + Type: RecycleTaskType, + } + mock.ExpectQuery("SELECT(.+)users(.+)").WillReturnError(errors.New("error")) + job, err := GetJobFromModel(task) + asserts.NoError(mock.ExpectationsWereMet()) + asserts.Nil(job) + asserts.Error(err) + } } diff --git a/pkg/task/recycle.go b/pkg/task/recycle.go new file mode 100644 index 0000000000..23abd96b49 --- /dev/null +++ b/pkg/task/recycle.go @@ -0,0 +1,155 @@ +package task + +import ( + "context" + "encoding/json" + "os" + + model "github.com/cloudreve/Cloudreve/v3/models" + "github.com/cloudreve/Cloudreve/v3/pkg/cluster" + "github.com/cloudreve/Cloudreve/v3/pkg/filesystem" + "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/shadow/slaveinmaster" + "github.com/cloudreve/Cloudreve/v3/pkg/util" +) + +// RecycleTask 文件回收任务 +type RecycleTask struct { + User *model.User + TaskModel *model.Task + TaskProps RecycleProps + Err *JobError +} + +// RecycleProps 回收任务属性 +type RecycleProps struct { + // 回收目录 + Path string `json:"path"` + // 负责处理回收任务的节点ID + NodeID uint `json:"node_id"` +} + +// Props 获取任务属性 +func (job *RecycleTask) Props() string { + res, _ := json.Marshal(job.TaskProps) + return string(res) +} + +// Type 获取任务状态 +func (job *RecycleTask) Type() int { + return RecycleTaskType +} + +// Creator 获取创建者ID +func (job *RecycleTask) Creator() uint { + return job.User.ID +} + +// Model 获取任务的数据库模型 +func (job *RecycleTask) Model() *model.Task { + return job.TaskModel +} + +// SetStatus 设定状态 +func (job *RecycleTask) SetStatus(status int) { + job.TaskModel.SetStatus(status) +} + +// SetError 设定任务失败信息 +func (job *RecycleTask) SetError(err *JobError) { + job.Err = err + res, _ := json.Marshal(job.Err) + job.TaskModel.SetError(string(res)) + +} + +// SetErrorMsg 设定任务失败信息 +func (job *RecycleTask) SetErrorMsg(msg string, err error) { + jobErr := &JobError{Msg: msg} + if err != nil { + jobErr.Error = err.Error() + } + job.SetError(jobErr) +} + +// GetError 返回任务失败信息 +func (job *RecycleTask) GetError() *JobError { + return job.Err +} + +// Do 开始执行任务 +func (job *RecycleTask) Do() { + if job.TaskProps.NodeID == 1 { + err := os.RemoveAll(job.TaskProps.Path) + if err != nil { + util.Log().Warning("无法删除中转临时目录[%s], %s", job.TaskProps.Path, err) + job.SetErrorMsg("文件回收失败", err) + } + } else { + // 指定为从机回收 + + // 创建文件系统 + fs, err := filesystem.NewFileSystem(job.User) + if err != nil { + job.SetErrorMsg(err.Error(), nil) + return + } + + // 获取从机节点 + node := cluster.Default.GetNodeByID(job.TaskProps.NodeID) + if node == nil { + job.SetErrorMsg("从机节点不可用", nil) + } + + // 切换为从机节点处理回收 + fs.SwitchToSlaveHandler(node) + handler := fs.Handler.(*slaveinmaster.Driver) + err = handler.Recycle(context.Background(), job.TaskProps.Path) + if err != nil { + util.Log().Warning("无法删除中转临时目录[%s], %s", job.TaskProps.Path, err) + job.SetErrorMsg("文件回收失败", err) + } + } +} + +// NewRecycleTask 新建回收任务 +func NewRecycleTask(user uint, path string, node uint) (Job, error) { + creator, err := model.GetActiveUserByID(user) + if err != nil { + return nil, err + } + + newTask := &RecycleTask{ + User: &creator, + TaskProps: RecycleProps{ + Path: path, + NodeID: node, + }, + } + + record, err := Record(newTask) + if err != nil { + return nil, err + } + newTask.TaskModel = record + + return newTask, nil +} + +// NewRecycleTaskFromModel 从数据库记录中恢复回收任务 +func NewRecycleTaskFromModel(task *model.Task) (Job, error) { + user, err := model.GetActiveUserByID(task.UserID) + if err != nil { + return nil, err + } + newTask := &RecycleTask{ + User: &user, + TaskModel: task, + } + + err = json.Unmarshal([]byte(task.Props), &newTask.TaskProps) + if err != nil { + return nil, err + } + + return newTask, nil +} diff --git a/pkg/task/recycle_test.go b/pkg/task/recycle_test.go new file mode 100644 index 0000000000..3fad4778a9 --- /dev/null +++ b/pkg/task/recycle_test.go @@ -0,0 +1,131 @@ +package task + +import ( + "errors" + "testing" + + "github.com/DATA-DOG/go-sqlmock" + model "github.com/cloudreve/Cloudreve/v3/models" + "github.com/jinzhu/gorm" + "github.com/stretchr/testify/assert" +) + +func TestRecycleTask_Props(t *testing.T) { + asserts := assert.New(t) + task := &RecycleTask{ + User: &model.User{}, + } + asserts.NotEmpty(task.Props()) + asserts.Equal(RecycleTaskType, task.Type()) + asserts.EqualValues(0, task.Creator()) + asserts.Nil(task.Model()) +} + +func TestRecycleTask_SetStatus(t *testing.T) { + asserts := assert.New(t) + task := &RecycleTask{ + User: &model.User{}, + TaskModel: &model.Task{ + Model: gorm.Model{ID: 1}, + }, + } + mock.ExpectBegin() + mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectCommit() + task.SetStatus(3) + asserts.NoError(mock.ExpectationsWereMet()) +} + +func TestRecycleTask_SetError(t *testing.T) { + asserts := assert.New(t) + task := &RecycleTask{ + User: &model.User{}, + TaskModel: &model.Task{ + Model: gorm.Model{ID: 1}, + }, + } + + mock.ExpectBegin() + mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectCommit() + + task.SetErrorMsg("error", nil) + asserts.NoError(mock.ExpectationsWereMet()) + asserts.Equal("error", task.GetError().Msg) +} + +func TestRecycleTask_Do(t *testing.T) { + asserts := assert.New(t) + task := &RecycleTask{ + TaskModel: &model.Task{ + Model: gorm.Model{ID: 1}, + }, + } + + // 目录不存在 + { + task.TaskProps.Path = "test/not_exist" + task.User = &model.User{ + Policy: model.Policy{ + Type: "unknown", + }, + } + mock.ExpectBegin() + mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, + 1)) + mock.ExpectCommit() + task.Do() + asserts.NoError(mock.ExpectationsWereMet()) + asserts.NotEmpty(task.GetError().Msg) + } +} + +func TestNewRecycleTask(t *testing.T) { + asserts := assert.New(t) + + // 成功 + { + mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) + mock.ExpectBegin() + mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectCommit() + job, err := NewRecycleTask(1, "/", 0) + asserts.NoError(mock.ExpectationsWereMet()) + asserts.NotNil(job) + asserts.NoError(err) + } + + // 失败 + { + mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) + mock.ExpectBegin() + mock.ExpectExec("INSERT(.+)").WillReturnError(errors.New("error")) + mock.ExpectRollback() + job, err := NewRecycleTask(1, "test/not_exist", 0) + asserts.NoError(mock.ExpectationsWereMet()) + asserts.Nil(job) + asserts.Error(err) + } +} + +func TestNewRecycleTaskFromModel(t *testing.T) { + asserts := assert.New(t) + + // 成功 + { + mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) + job, err := NewRecycleTaskFromModel(&model.Task{Props: "{}"}) + asserts.NoError(mock.ExpectationsWereMet()) + asserts.NoError(err) + asserts.NotNil(job) + } + + // JSON解析失败 + { + mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) + job, err := NewRecycleTaskFromModel(&model.Task{Props: "?"}) + asserts.NoError(mock.ExpectationsWereMet()) + asserts.Error(err) + asserts.Nil(job) + } +} diff --git a/pkg/task/slavetask/recycle.go b/pkg/task/slavetask/recycle.go new file mode 100644 index 0000000000..d8c7bc8db5 --- /dev/null +++ b/pkg/task/slavetask/recycle.go @@ -0,0 +1,95 @@ +package slavetask + +import ( + "os" + + model "github.com/cloudreve/Cloudreve/v3/models" + "github.com/cloudreve/Cloudreve/v3/pkg/cluster" + "github.com/cloudreve/Cloudreve/v3/pkg/mq" + "github.com/cloudreve/Cloudreve/v3/pkg/serializer" + "github.com/cloudreve/Cloudreve/v3/pkg/task" + "github.com/cloudreve/Cloudreve/v3/pkg/util" +) + +// RecycleTask 文件回收任务 +type RecycleTask struct { + Err *task.JobError + Req *serializer.SlaveRecycleReq + MasterID string +} + +// Props 获取任务属性 +func (job *RecycleTask) Props() string { + return "" +} + +// Type 获取任务类型 +func (job *RecycleTask) Type() int { + return 0 +} + +// Creator 获取创建者ID +func (job *RecycleTask) Creator() uint { + return 0 +} + +// Model 获取任务的数据库模型 +func (job *RecycleTask) Model() *model.Task { + return nil +} + +// SetStatus 设定状态 +func (job *RecycleTask) SetStatus(status int) { +} + +// SetError 设定任务失败信息 +func (job *RecycleTask) SetError(err *task.JobError) { + job.Err = err +} + +// SetErrorMsg 设定任务失败信息 +func (job *RecycleTask) SetErrorMsg(msg string, err error) { + jobErr := &task.JobError{Msg: msg} + if err != nil { + jobErr.Error = err.Error() + } + + job.SetError(jobErr) + + notifyMsg := mq.Message{ + TriggeredBy: job.MasterID, + Event: serializer.SlaveRecycleFailed, + Content: serializer.SlaveRecycleResult{ + Error: err.Error(), + }, + } + + if err = cluster.DefaultController.SendNotification(job.MasterID, job.Req.Hash(job.MasterID), notifyMsg); err != nil { + util.Log().Warning("无法发送回收失败通知到从机, %s", err) + } +} + +// GetError 返回任务失败信息 +func (job *RecycleTask) GetError() *task.JobError { + return job.Err +} + +// Do 开始执行任务 +func (job *RecycleTask) Do() { + err := os.RemoveAll(job.Req.Path) + if err != nil { + util.Log().Warning("无法删除中转临时文件[%s], %s", job.Req.Path, err) + job.SetErrorMsg("文件回收失败", err) + return + } + + msg := mq.Message{ + TriggeredBy: job.MasterID, + Event: serializer.SlaveRecycleSuccess, + Content: serializer.SlaveRecycleResult{}, + } + + if err = cluster.DefaultController.SendNotification(job.MasterID, job.Req.Hash(job.MasterID), msg); err != nil { + util.Log().Warning("无法发送回收成功通知到从机, %s", err) + } +} diff --git a/pkg/task/slavetask/transfer.go b/pkg/task/slavetask/transfer.go index 20c5fcc969..818028eb4d 100644 --- a/pkg/task/slavetask/transfer.go +++ b/pkg/task/slavetask/transfer.go @@ -2,6 +2,8 @@ package slavetask import ( "context" + "os" + model "github.com/cloudreve/Cloudreve/v3/models" "github.com/cloudreve/Cloudreve/v3/pkg/cluster" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem" @@ -10,7 +12,6 @@ import ( "github.com/cloudreve/Cloudreve/v3/pkg/serializer" "github.com/cloudreve/Cloudreve/v3/pkg/task" "github.com/cloudreve/Cloudreve/v3/pkg/util" - "os" ) // TransferTask 文件中转任务 @@ -79,8 +80,6 @@ func (job *TransferTask) GetError() *task.JobError { // Do 开始执行任务 func (job *TransferTask) Do() { - defer job.Recycle() - fs, err := filesystem.NewAnonymousFileSystem() if err != nil { job.SetErrorMsg("无法初始化匿名文件系统", err) @@ -137,11 +136,3 @@ func (job *TransferTask) Do() { util.Log().Warning("无法发送转存成功通知到从机, %s", err) } } - -// Recycle 回收临时文件 -func (job *TransferTask) Recycle() { - err := os.Remove(job.Req.Src) - if err != nil { - util.Log().Warning("无法删除中转临时文件[%s], %s", job.Req.Src, err) - } -} diff --git a/pkg/task/tranfer.go b/pkg/task/tranfer.go index 5f9aa58e07..f115e80329 100644 --- a/pkg/task/tranfer.go +++ b/pkg/task/tranfer.go @@ -3,7 +3,6 @@ package task import ( "context" "encoding/json" - "os" "path" "path/filepath" "strings" @@ -87,8 +86,6 @@ func (job *TransferTask) GetError() *JobError { // Do 开始执行任务 func (job *TransferTask) Do() { - defer job.Recycle() - // 创建文件系统 fs, err := filesystem.NewFileSystem(job.User) if err != nil { @@ -139,16 +136,6 @@ func (job *TransferTask) Do() { } -// Recycle 回收临时文件 -func (job *TransferTask) Recycle() { - if job.TaskProps.NodeID == 1 { - err := os.RemoveAll(job.TaskProps.Parent) - if err != nil { - util.Log().Warning("无法删除中转临时目录[%s], %s", job.TaskProps.Parent, err) - } - } -} - // NewTransferTask 新建中转任务 func NewTransferTask(user uint, src []string, dst, parent string, trim bool, node uint, sizes map[string]uint64) (Job, error) { creator, err := model.GetActiveUserByID(user) diff --git a/routers/controllers/slave.go b/routers/controllers/slave.go index e1e7de22da..2b5b15cebc 100644 --- a/routers/controllers/slave.go +++ b/routers/controllers/slave.go @@ -212,6 +212,17 @@ func SlaveCreateTransferTask(c *gin.Context) { } } +// SlaveCreateRecycleTask 从机创建回收任务 +func SlaveCreateRecycleTask(c *gin.Context) { + var service serializer.SlaveRecycleReq + if err := c.ShouldBindJSON(&service); err == nil { + res := explorer.CreateRecycleTask(c, &service) + c.JSON(200, res) + } else { + c.JSON(200, ErrorResponse(err)) + } +} + // SlaveNotificationPush 处理从机发送的消息推送 func SlaveNotificationPush(c *gin.Context) { var service node.SlaveNotificationService diff --git a/routers/router.go b/routers/router.go index 0727fe6f73..f7586b3a43 100644 --- a/routers/router.go +++ b/routers/router.go @@ -88,6 +88,7 @@ func InitSlaveRouter() *gin.Engine { task := v3.Group("task") { task.PUT("transfer", controllers.SlaveCreateTransferTask) + task.PUT("recycle", controllers.SlaveCreateRecycleTask) } } return r diff --git a/service/aria2/manage.go b/service/aria2/manage.go index 6344ddd62f..6f55a8f972 100644 --- a/service/aria2/manage.go +++ b/service/aria2/manage.go @@ -33,7 +33,7 @@ func (service *DownloadListService) Finished(c *gin.Context, user *model.User) s // Downloading 获取正在下载中的任务 func (service *DownloadListService) Downloading(c *gin.Context, user *model.User) serializer.Response { // 查找下载记录 - downloads := model.GetDownloadsByStatusAndUser(service.Page, user.ID, common.Downloading, common.Paused, common.Ready) + downloads := model.GetDownloadsByStatusAndUser(service.Page, user.ID, common.Downloading, common.Seeding, common.Paused, common.Ready) intervals := make(map[uint]int) for _, download := range downloads { if _, ok := intervals[download.ID]; !ok { diff --git a/service/explorer/slave.go b/service/explorer/slave.go index 1435640d7e..253fcb876a 100644 --- a/service/explorer/slave.go +++ b/service/explorer/slave.go @@ -5,6 +5,10 @@ import ( "encoding/base64" "encoding/json" "fmt" + "net/http" + "net/url" + "time" + model "github.com/cloudreve/Cloudreve/v3/models" "github.com/cloudreve/Cloudreve/v3/pkg/cache" "github.com/cloudreve/Cloudreve/v3/pkg/cluster" @@ -16,9 +20,6 @@ import ( "github.com/cloudreve/Cloudreve/v3/pkg/util" "github.com/gin-gonic/gin" "github.com/jinzhu/gorm" - "net/http" - "net/url" - "time" ) // SlaveDownloadService 从机文件下載服务 @@ -165,6 +166,26 @@ func CreateTransferTask(c *gin.Context, req *serializer.SlaveTransferReq) serial return serializer.ParamErr("未知的主机节点ID", nil) } +// CreateRecycleTask 创建从机文件回收任务 +func CreateRecycleTask(c *gin.Context, req *serializer.SlaveRecycleReq) serializer.Response { + if id, ok := c.Get("MasterSiteID"); ok { + job := &slavetask.RecycleTask{ + Req: req, + MasterID: id.(string), + } + + if err := cluster.DefaultController.SubmitTask(job.MasterID, job, req.Hash(job.MasterID), func(job interface{}) { + task.TaskPoll.Submit(job.(task.Job)) + }); err != nil { + return serializer.Err(serializer.CodeCreateTaskError, "", err) + } + + return serializer.Response{} + } + + return serializer.ParamErr("未知的主机节点ID", nil) +} + // SlaveListService 从机上传会话服务 type SlaveCreateUploadSessionService struct { Session serializer.UploadSession `json:"session" binding:"required"` diff --git a/service/user/register.go b/service/user/register.go index d3c81b5eae..35e8253d96 100644 --- a/service/user/register.go +++ b/service/user/register.go @@ -1,14 +1,15 @@ package user import ( + "net/url" + "strings" + model "github.com/cloudreve/Cloudreve/v3/models" "github.com/cloudreve/Cloudreve/v3/pkg/auth" "github.com/cloudreve/Cloudreve/v3/pkg/email" "github.com/cloudreve/Cloudreve/v3/pkg/hashid" "github.com/cloudreve/Cloudreve/v3/pkg/serializer" "github.com/gin-gonic/gin" - "net/url" - "strings" ) // UserRegisterService 管理用户注册的服务 From 915fefc5d463d93a9c8c7073f00ef64c4239fa98 Mon Sep 17 00:00:00 2001 From: XYenon <20698483+XYenon@users.noreply.github.com> Date: Sun, 21 Aug 2022 15:08:39 +0800 Subject: [PATCH 2/3] fix: move RecycleTaskType to the bottom --- pkg/task/job.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pkg/task/job.go b/pkg/task/job.go index 9bf52d74ad..e9d54d8f85 100644 --- a/pkg/task/job.go +++ b/pkg/task/job.go @@ -13,10 +13,10 @@ const ( DecompressTaskType // TransferTaskType 中转任务 TransferTaskType - // RecycleTaskType 回收任务 - RecycleTaskType // ImportTaskType 导入任务 ImportTaskType + // RecycleTaskType 回收任务 + RecycleTaskType ) // 任务状态 From b7c6155d3d70022af4033667f52c2db6ecb231cc Mon Sep 17 00:00:00 2001 From: XYenon <20698483+XYenon@users.noreply.github.com> Date: Sat, 27 Aug 2022 22:30:42 +0800 Subject: [PATCH 3/3] refactor: refactor recycle aria2 temp file --- pkg/aria2/common/common.go | 4 +- pkg/aria2/monitor/monitor.go | 2 +- .../driver/shadow/slaveinmaster/handler.go | 39 -------- pkg/serializer/slave.go | 20 ---- pkg/task/recycle.go | 71 +++++--------- pkg/task/recycle_test.go | 42 +++----- pkg/task/slavetask/recycle.go | 95 ------------------- routers/controllers/slave.go | 11 --- routers/router.go | 1 - service/aria2/manage.go | 2 +- service/explorer/slave.go | 20 ---- 11 files changed, 41 insertions(+), 266 deletions(-) delete mode 100644 pkg/task/slavetask/recycle.go diff --git a/pkg/aria2/common/common.go b/pkg/aria2/common/common.go index d4a8313d89..455c89f01a 100644 --- a/pkg/aria2/common/common.go +++ b/pkg/aria2/common/common.go @@ -38,8 +38,6 @@ const ( Downloading // Paused 暂停中 Paused - // Seeding 做种中 - Seeding // Error 出错 Error // Complete 完成 @@ -48,6 +46,8 @@ const ( Canceled // Unknown 未知状态 Unknown + // Seeding 做种中 + Seeding ) var ( diff --git a/pkg/aria2/monitor/monitor.go b/pkg/aria2/monitor/monitor.go index 6f6de7e976..531d6edd16 100644 --- a/pkg/aria2/monitor/monitor.go +++ b/pkg/aria2/monitor/monitor.go @@ -251,7 +251,7 @@ func (monitor *Monitor) Complete(pool task.Pool) bool { // 转存完成,回收下载目录 if transferTask.Type == task.TransferTaskType && transferTask.Status >= task.Error { - job, err := task.NewRecycleTask(monitor.Task.UserID, monitor.Task.Parent, monitor.node.ID()) + job, err := task.NewRecycleTask(monitor.Task) if err != nil { monitor.setErrorStatus(err) monitor.RemoveTempFolder() diff --git a/pkg/filesystem/driver/shadow/slaveinmaster/handler.go b/pkg/filesystem/driver/shadow/slaveinmaster/handler.go index 84116394d2..4dd9da876f 100644 --- a/pkg/filesystem/driver/shadow/slaveinmaster/handler.go +++ b/pkg/filesystem/driver/shadow/slaveinmaster/handler.go @@ -122,42 +122,3 @@ func (d *Driver) List(ctx context.Context, path string, recursive bool) ([]respo func (d *Driver) CancelToken(ctx context.Context, uploadSession *serializer.UploadSession) error { return nil } - -func (d *Driver) Recycle(ctx context.Context, path string) error { - req := serializer.SlaveRecycleReq{ - Path: path, - } - - body, err := json.Marshal(req) - if err != nil { - return err - } - - // 订阅回收结果 - resChan := mq.GlobalMQ.Subscribe(req.Hash(model.GetSettingByName("siteID")), 0) - defer mq.GlobalMQ.Unsubscribe(req.Hash(model.GetSettingByName("siteID")), resChan) - - res, err := d.client.Request("PUT", "task/recycle", bytes.NewReader(body)). - CheckHTTPResponse(200). - DecodeResponse() - if err != nil { - return err - } - - if res.Code != 0 { - return serializer.NewErrorFromResponse(res) - } - - // 等待回收结果或者超时 - waitTimeout := model.GetIntSetting("slave_transfer_timeout", 172800) - select { - case <-time.After(time.Duration(waitTimeout) * time.Second): - return ErrWaitResultTimeout - case msg := <-resChan: - if msg.Event != serializer.SlaveRecycleSuccess { - return errors.New(msg.Content.(serializer.SlaveRecycleResult).Error) - } - } - - return nil -} diff --git a/pkg/serializer/slave.go b/pkg/serializer/slave.go index 4179d4555d..04d56d3d0b 100644 --- a/pkg/serializer/slave.go +++ b/pkg/serializer/slave.go @@ -54,35 +54,15 @@ func (s *SlaveTransferReq) Hash(id string) string { return fmt.Sprintf("%x", bs) } -// SlaveRecycleReq 从机回收任务创建请求 -type SlaveRecycleReq struct { - Path string `json:"path"` -} - -// Hash 返回创建请求的唯一标识,保持创建请求幂等 -func (s *SlaveRecycleReq) Hash(id string) string { - h := sha1.New() - h.Write([]byte(fmt.Sprintf("transfer-%s-%s", id, s.Path))) - bs := h.Sum(nil) - return fmt.Sprintf("%x", bs) -} - const ( SlaveTransferSuccess = "success" SlaveTransferFailed = "failed" - SlaveRecycleSuccess = "success" - SlaveRecycleFailed = "failed" ) type SlaveTransferResult struct { Error string } -type SlaveRecycleResult struct { - Error string -} - func init() { gob.Register(SlaveTransferResult{}) - gob.Register(SlaveRecycleResult{}) } diff --git a/pkg/task/recycle.go b/pkg/task/recycle.go index 23abd96b49..17eaf3c21e 100644 --- a/pkg/task/recycle.go +++ b/pkg/task/recycle.go @@ -1,14 +1,10 @@ package task import ( - "context" "encoding/json" - "os" model "github.com/cloudreve/Cloudreve/v3/models" "github.com/cloudreve/Cloudreve/v3/pkg/cluster" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/shadow/slaveinmaster" "github.com/cloudreve/Cloudreve/v3/pkg/util" ) @@ -22,10 +18,8 @@ type RecycleTask struct { // RecycleProps 回收任务属性 type RecycleProps struct { - // 回收目录 - Path string `json:"path"` - // 负责处理回收任务的节点ID - NodeID uint `json:"node_id"` + // 下载任务 GID + DownloadGID string `json:"download_gid"` } // Props 获取任务属性 @@ -59,7 +53,6 @@ func (job *RecycleTask) SetError(err *JobError) { job.Err = err res, _ := json.Marshal(job.Err) job.TaskModel.SetError(string(res)) - } // SetErrorMsg 设定任务失败信息 @@ -78,51 +71,33 @@ func (job *RecycleTask) GetError() *JobError { // Do 开始执行任务 func (job *RecycleTask) Do() { - if job.TaskProps.NodeID == 1 { - err := os.RemoveAll(job.TaskProps.Path) - if err != nil { - util.Log().Warning("无法删除中转临时目录[%s], %s", job.TaskProps.Path, err) - job.SetErrorMsg("文件回收失败", err) - } - } else { - // 指定为从机回收 - - // 创建文件系统 - fs, err := filesystem.NewFileSystem(job.User) - if err != nil { - job.SetErrorMsg(err.Error(), nil) - return - } - - // 获取从机节点 - node := cluster.Default.GetNodeByID(job.TaskProps.NodeID) - if node == nil { - job.SetErrorMsg("从机节点不可用", nil) - } - - // 切换为从机节点处理回收 - fs.SwitchToSlaveHandler(node) - handler := fs.Handler.(*slaveinmaster.Driver) - err = handler.Recycle(context.Background(), job.TaskProps.Path) - if err != nil { - util.Log().Warning("无法删除中转临时目录[%s], %s", job.TaskProps.Path, err) - job.SetErrorMsg("文件回收失败", err) - } + download, err := model.GetDownloadByGid(job.TaskProps.DownloadGID, job.User.ID) + if err != nil { + util.Log().Warning("回收任务 %d 找不到下载记录", job.TaskModel.ID) + job.SetErrorMsg("无法找到下载任务", err) + return } -} - -// NewRecycleTask 新建回收任务 -func NewRecycleTask(user uint, path string, node uint) (Job, error) { - creator, err := model.GetActiveUserByID(user) + nodeID := download.GetNodeID() + node := cluster.Default.GetNodeByID(nodeID) + if node == nil { + util.Log().Warning("回收任务 %d 找不到节点", job.TaskModel.ID) + job.SetErrorMsg("从机节点不可用", nil) + return + } + err = node.GetAria2Instance().DeleteTempFile(download) if err != nil { - return nil, err + util.Log().Warning("无法删除中转临时目录[%s], %s", download.Parent, err) + job.SetErrorMsg("文件回收失败", err) + return } +} +// NewRecycleTask 新建回收任务 +func NewRecycleTask(download *model.Download) (Job, error) { newTask := &RecycleTask{ - User: &creator, + User: download.GetOwner(), TaskProps: RecycleProps{ - Path: path, - NodeID: node, + DownloadGID: download.GID, }, } diff --git a/pkg/task/recycle_test.go b/pkg/task/recycle_test.go index 3fad4778a9..0092a30c11 100644 --- a/pkg/task/recycle_test.go +++ b/pkg/task/recycle_test.go @@ -54,32 +54,6 @@ func TestRecycleTask_SetError(t *testing.T) { asserts.Equal("error", task.GetError().Msg) } -func TestRecycleTask_Do(t *testing.T) { - asserts := assert.New(t) - task := &RecycleTask{ - TaskModel: &model.Task{ - Model: gorm.Model{ID: 1}, - }, - } - - // 目录不存在 - { - task.TaskProps.Path = "test/not_exist" - task.User = &model.User{ - Policy: model.Policy{ - Type: "unknown", - }, - } - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, - 1)) - mock.ExpectCommit() - task.Do() - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NotEmpty(task.GetError().Msg) - } -} - func TestNewRecycleTask(t *testing.T) { asserts := assert.New(t) @@ -89,7 +63,13 @@ func TestNewRecycleTask(t *testing.T) { mock.ExpectBegin() mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) mock.ExpectCommit() - job, err := NewRecycleTask(1, "/", 0) + job, err := NewRecycleTask(&model.Download{ + Model: gorm.Model{ID: 1}, + GID: "test_g_id", + Parent: "/", + UserID: 1, + NodeID: 1, + }) asserts.NoError(mock.ExpectationsWereMet()) asserts.NotNil(job) asserts.NoError(err) @@ -101,7 +81,13 @@ func TestNewRecycleTask(t *testing.T) { mock.ExpectBegin() mock.ExpectExec("INSERT(.+)").WillReturnError(errors.New("error")) mock.ExpectRollback() - job, err := NewRecycleTask(1, "test/not_exist", 0) + job, err := NewRecycleTask(&model.Download{ + Model: gorm.Model{ID: 1}, + GID: "test_g_id", + Parent: "test/not_exist", + UserID: 1, + NodeID: 1, + }) asserts.NoError(mock.ExpectationsWereMet()) asserts.Nil(job) asserts.Error(err) diff --git a/pkg/task/slavetask/recycle.go b/pkg/task/slavetask/recycle.go deleted file mode 100644 index d8c7bc8db5..0000000000 --- a/pkg/task/slavetask/recycle.go +++ /dev/null @@ -1,95 +0,0 @@ -package slavetask - -import ( - "os" - - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/cluster" - "github.com/cloudreve/Cloudreve/v3/pkg/mq" - "github.com/cloudreve/Cloudreve/v3/pkg/serializer" - "github.com/cloudreve/Cloudreve/v3/pkg/task" - "github.com/cloudreve/Cloudreve/v3/pkg/util" -) - -// RecycleTask 文件回收任务 -type RecycleTask struct { - Err *task.JobError - Req *serializer.SlaveRecycleReq - MasterID string -} - -// Props 获取任务属性 -func (job *RecycleTask) Props() string { - return "" -} - -// Type 获取任务类型 -func (job *RecycleTask) Type() int { - return 0 -} - -// Creator 获取创建者ID -func (job *RecycleTask) Creator() uint { - return 0 -} - -// Model 获取任务的数据库模型 -func (job *RecycleTask) Model() *model.Task { - return nil -} - -// SetStatus 设定状态 -func (job *RecycleTask) SetStatus(status int) { -} - -// SetError 设定任务失败信息 -func (job *RecycleTask) SetError(err *task.JobError) { - job.Err = err -} - -// SetErrorMsg 设定任务失败信息 -func (job *RecycleTask) SetErrorMsg(msg string, err error) { - jobErr := &task.JobError{Msg: msg} - if err != nil { - jobErr.Error = err.Error() - } - - job.SetError(jobErr) - - notifyMsg := mq.Message{ - TriggeredBy: job.MasterID, - Event: serializer.SlaveRecycleFailed, - Content: serializer.SlaveRecycleResult{ - Error: err.Error(), - }, - } - - if err = cluster.DefaultController.SendNotification(job.MasterID, job.Req.Hash(job.MasterID), notifyMsg); err != nil { - util.Log().Warning("无法发送回收失败通知到从机, %s", err) - } -} - -// GetError 返回任务失败信息 -func (job *RecycleTask) GetError() *task.JobError { - return job.Err -} - -// Do 开始执行任务 -func (job *RecycleTask) Do() { - err := os.RemoveAll(job.Req.Path) - if err != nil { - util.Log().Warning("无法删除中转临时文件[%s], %s", job.Req.Path, err) - job.SetErrorMsg("文件回收失败", err) - return - } - - msg := mq.Message{ - TriggeredBy: job.MasterID, - Event: serializer.SlaveRecycleSuccess, - Content: serializer.SlaveRecycleResult{}, - } - - if err = cluster.DefaultController.SendNotification(job.MasterID, job.Req.Hash(job.MasterID), msg); err != nil { - util.Log().Warning("无法发送回收成功通知到从机, %s", err) - } -} diff --git a/routers/controllers/slave.go b/routers/controllers/slave.go index 2b5b15cebc..e1e7de22da 100644 --- a/routers/controllers/slave.go +++ b/routers/controllers/slave.go @@ -212,17 +212,6 @@ func SlaveCreateTransferTask(c *gin.Context) { } } -// SlaveCreateRecycleTask 从机创建回收任务 -func SlaveCreateRecycleTask(c *gin.Context) { - var service serializer.SlaveRecycleReq - if err := c.ShouldBindJSON(&service); err == nil { - res := explorer.CreateRecycleTask(c, &service) - c.JSON(200, res) - } else { - c.JSON(200, ErrorResponse(err)) - } -} - // SlaveNotificationPush 处理从机发送的消息推送 func SlaveNotificationPush(c *gin.Context) { var service node.SlaveNotificationService diff --git a/routers/router.go b/routers/router.go index f7586b3a43..0727fe6f73 100644 --- a/routers/router.go +++ b/routers/router.go @@ -88,7 +88,6 @@ func InitSlaveRouter() *gin.Engine { task := v3.Group("task") { task.PUT("transfer", controllers.SlaveCreateTransferTask) - task.PUT("recycle", controllers.SlaveCreateRecycleTask) } } return r diff --git a/service/aria2/manage.go b/service/aria2/manage.go index 6f55a8f972..115a440532 100644 --- a/service/aria2/manage.go +++ b/service/aria2/manage.go @@ -57,7 +57,7 @@ func (service *DownloadTaskService) Delete(c *gin.Context) serializer.Response { return serializer.Err(serializer.CodeNotFound, "Download record not found", err) } - if download.Status >= common.Error { + if download.Status >= common.Error && download.Status <= common.Unknown { // 如果任务已完成,则删除任务记录 if err := download.Delete(); err != nil { return serializer.DBErr("Failed to delete task record", err) diff --git a/service/explorer/slave.go b/service/explorer/slave.go index 253fcb876a..afb61af60d 100644 --- a/service/explorer/slave.go +++ b/service/explorer/slave.go @@ -166,26 +166,6 @@ func CreateTransferTask(c *gin.Context, req *serializer.SlaveTransferReq) serial return serializer.ParamErr("未知的主机节点ID", nil) } -// CreateRecycleTask 创建从机文件回收任务 -func CreateRecycleTask(c *gin.Context, req *serializer.SlaveRecycleReq) serializer.Response { - if id, ok := c.Get("MasterSiteID"); ok { - job := &slavetask.RecycleTask{ - Req: req, - MasterID: id.(string), - } - - if err := cluster.DefaultController.SubmitTask(job.MasterID, job, req.Hash(job.MasterID), func(job interface{}) { - task.TaskPoll.Submit(job.(task.Job)) - }); err != nil { - return serializer.Err(serializer.CodeCreateTaskError, "", err) - } - - return serializer.Response{} - } - - return serializer.ParamErr("未知的主机节点ID", nil) -} - // SlaveListService 从机上传会话服务 type SlaveCreateUploadSessionService struct { Session serializer.UploadSession `json:"session" binding:"required"`