-
Notifications
You must be signed in to change notification settings - Fork 0
/
gorm.go
124 lines (109 loc) · 2.64 KB
/
gorm.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
// Copyright 2023 The golang.design Initiative Authors.
// All rights reserved. Use of this source code is governed
// by a MIT license that can be found in the LICENSE file.
//
// Written by Changkun Ou <changkun.de>
package iter
import (
"errors"
"sync/atomic"
"gorm.io/gorm"
)
// errStop is the sentinel error that batchFinder returns from the
// FindInBatches callback once Stop has been called. Err recognizes it and
// reports the early termination as success rather than as a failure.
var (
	errStop = errors.New("iter: stop the iteration")
)
// GormIter is a *gorm.DB compatible database iterator that delivers the
// rows of a query in batches, driven by gorm's FindInBatches running in a
// background goroutine. To use this iterator, for example:
//
//	it := NewBatchFromGorm[T](tx, batchSize)
//	for batch, ok := it.Next(); ok; batch, ok = it.Next() {
//		// Process the batch.
//		...
//		// Stop the iteration if necessary.
//		if ... {
//			it.Stop()
//			break
//		}
//	}
//	if err := it.Err(); err != nil {
//		// Handle the error.
//	}
//
// Apart from a final call to Err, the iterator must not be used after
// Next returns false or after Stop is called. It must also not be used
// after Err returns a non-nil error, or after the underlying database
// connection is closed.
type GormIter[T any] struct {
	tx        *gorm.DB // query whose rows are iterated
	batchSize int      // number of rows requested per batch
	// next carries requests from Next to the background goroutine: each
	// request is a reply channel on which exactly one batch is delivered.
	next chan chan []T
	// stop is closed exactly once (by Stop) to end the iteration.
	stop chan struct{}
	// finished records whether stop has already been closed, so that
	// repeated Stop calls do not close it twice.
	finished atomic.Bool
	// err receives the final error of the background query and is then
	// closed.
	err chan error
}
// NewBatchFromGorm creates a new database iterator over tx that yields rows
// in batches of batchSize.
//
// It immediately starts a background goroutine that runs the query. That
// goroutine only exits once every batch has been consumed through Next or
// Stop has been called, so callers must do one of the two.
func NewBatchFromGorm[T any](tx *gorm.DB, batchSize int) *GormIter[T] {
	iter := &GormIter[T]{
		next: make(chan chan []T),
		stop: make(chan struct{}),
		// Capacity 1 lets the producer publish its final error without
		// waiting for a reader.
		err:       make(chan error, 1),
		tx:        tx,
		batchSize: batchSize,
	}
	go iter.batchFinder()
	return iter
}
func (it *GormIter[T]) batchFinder() {
var current []T
err := it.tx.FindInBatches(¤t, it.batchSize, func(tx *gorm.DB, _ int) error {
rows := make([]T, len(current))
copy(rows, current)
select {
case <-it.stop:
return errStop
case ch := <-it.next:
ch <- rows
close(ch)
}
return nil
}).Error
it.err <- err
close(it.err)
it.Stop()
}
// Next implements StopErrIter[T].
//
// Next returns the next batch of rows and true. It returns nil and false
// once the iteration has ended: all batches were consumed, Stop was called,
// or the query failed (use Err to distinguish these). It blocks while the
// background goroutine fetches the requested batch.
func (it *GormIter[T]) Next() ([]T, bool) {
	reply := make(chan []T)
	select {
	case it.next <- reply:
		// The producer accepted the request and will deliver exactly one
		// batch on reply.
		return <-reply, true
	case <-it.stop:
		return nil, false
	}
}
// Stop implements StopErrIter[T].
//
// Stop ends the iteration and tells the background goroutine to abandon
// any remaining batches. It is idempotent: only the first call closes the
// stop channel, and later calls (including the one the background goroutine
// makes when the query finishes) are no-ops.
func (it *GormIter[T]) Stop() {
	// CompareAndSwap succeeds for exactly one caller, which guarantees the
	// channel is closed once and never panics on a double close.
	if !it.finished.CompareAndSwap(false, true) {
		return
	}
	close(it.stop)
}
// Err implements StopErrIter[T].
//
// Err waits for the background query to finish and returns the error it
// produced. It returns nil if the query completed successfully or was ended
// early by Stop.
//
// Call Err only after Next has returned false or Stop has been called.
// Otherwise the background goroutine may still be waiting for the next
// request, and Err can block indefinitely.
//
// NOTE: the error is delivered over a channel that the first call drains,
// so Err should be called at most once; subsequent calls return nil.
func (it *GormIter[T]) Err() error {
	err := <-it.err
	// errors.Is also recognizes errStop if gorm wrapped it while recording
	// the callback's error, not just when it is returned verbatim.
	if errors.Is(err, errStop) {
		return nil
	}
	return err
}