295 lines
6.4 KiB
Go
295 lines
6.4 KiB
Go
package repo
|
||
|
||
import (
|
||
"context"
|
||
"errors"
|
||
"reflect"
|
||
"taotie-api/utils/sctx"
|
||
"time"
|
||
|
||
"github.com/duke-git/lancet/slice"
|
||
"github.com/google/wire"
|
||
"go.mongodb.org/mongo-driver/bson"
|
||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||
"go.mongodb.org/mongo-driver/mongo"
|
||
"go.mongodb.org/mongo-driver/mongo/options"
|
||
)
|
||
|
||
// 依赖注入节点
|
||
var RepoProd = wire.NewSet(
|
||
NewUserRepo,
|
||
NewTenantRepo,
|
||
)
|
||
|
||
type BaseRepo[T IEntity] struct {
|
||
mdb *MongoDb
|
||
tableName string
|
||
}
|
||
|
||
// ProcessFilter 处理查询条件
|
||
func (rp *BaseRepo[T]) RawProcessFilter(ctx context.Context, qp *QueryParams) bson.M {
|
||
filter := bson.M{}
|
||
|
||
// 查询条件
|
||
if len(qp.Where) != 0 {
|
||
wh := make(map[string]any, 0)
|
||
for k, v := range qp.Where {
|
||
if (k[len(k)-2:]) == "Id" {
|
||
oid, _ := primitive.ObjectIDFromHex(v.(string))
|
||
wh[k] = oid
|
||
continue
|
||
}
|
||
wh[k] = v
|
||
}
|
||
filter = wh
|
||
}
|
||
|
||
// 查询 ids
|
||
if len(qp.WhereInIds) != 0 {
|
||
oids := make([]primitive.ObjectID, 0)
|
||
for _, id := range qp.WhereInIds {
|
||
oid, _ := primitive.ObjectIDFromHex(id)
|
||
oids = append(oids, oid)
|
||
}
|
||
filter["_id"] = bson.M{"$in": oids}
|
||
}
|
||
|
||
// 过滤删除
|
||
filter["deletedAt"] = 0
|
||
|
||
// 过滤租户
|
||
cuser := sctx.GetCurrentUser(ctx)
|
||
if cuser != nil {
|
||
filter["tenantId"], _ = primitive.ObjectIDFromHex(cuser.TenantId)
|
||
}
|
||
|
||
return filter
|
||
}
|
||
|
||
// ProcessSetData 处理需要修改的数据,限制只能修改已拥有的字段
|
||
func (rp *BaseRepo[T]) RawProcessSetData(ctx context.Context, setData map[string]any) (map[string]any, error) {
|
||
// 提取结构体
|
||
t := reflect.TypeFor[T]().Elem()
|
||
if t.Kind() == reflect.Ptr {
|
||
t = t.Elem()
|
||
}
|
||
|
||
// 提取结构体字段
|
||
fields := make([]string, 0, t.NumField())
|
||
for i := 0; i < t.NumField(); i++ {
|
||
var bs = t.Field(i).Tag.Get("bson")
|
||
if bs == ",inline" {
|
||
continue
|
||
}
|
||
fields = append(fields, bs)
|
||
}
|
||
|
||
// 筛选
|
||
sd := make(map[string]any, 0)
|
||
for k, v := range setData {
|
||
if slice.Contain(fields, k) {
|
||
// 处理 ObjectID 类型
|
||
if (k[len(k)-2:]) == "Id" {
|
||
oid, _ := primitive.ObjectIDFromHex(v.(string))
|
||
sd[k] = oid
|
||
continue
|
||
}
|
||
if (k[len(k)-3:]) == "Ids" {
|
||
oids := make([]primitive.ObjectID, 0)
|
||
for _, id := range v.([]any) {
|
||
oid, _ := primitive.ObjectIDFromHex(id.(string))
|
||
oids = append(oids, oid)
|
||
}
|
||
sd[k] = oids
|
||
continue
|
||
}
|
||
|
||
sd[k] = v
|
||
}
|
||
}
|
||
|
||
return sd, nil
|
||
}
|
||
|
||
// Create 创建
|
||
func (rp *BaseRepo[T]) Create(ctx context.Context, tdata T) (string, error) {
|
||
collection := rp.mdb.Db().Collection(rp.tableName)
|
||
|
||
// 设置默认值
|
||
cuser := sctx.GetCurrentUser(ctx)
|
||
tdata.SetCreatedBy(cuser.UserId) // 默认创建人
|
||
tdata.SetTenantId(cuser.TenantId) // 默认租户
|
||
|
||
// 设置基础时间
|
||
now := time.Now().UnixMilli()
|
||
tdata.SetNow(now)
|
||
|
||
result, err := collection.InsertOne(ctx, tdata)
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
|
||
// 插入成功后,返回插入的ID
|
||
return result.InsertedID.(primitive.ObjectID).Hex(), nil
|
||
}
|
||
|
||
// CreateMany 创建多个
|
||
func (rp *BaseRepo[T]) CreateMany(ctx context.Context, tdatas []T) (ids []string, err error) {
|
||
collection := rp.mdb.Db().Collection(rp.tableName)
|
||
|
||
// 设置默认值
|
||
cuser := sctx.GetCurrentUser(ctx)
|
||
now := time.Now().UnixMilli()
|
||
tdataanys := make([]any, 0, len(tdatas))
|
||
for _, v := range tdatas {
|
||
if cuser != nil {
|
||
v.SetCreatedBy(cuser.UserId) // 默认创建人
|
||
v.SetTenantId(cuser.TenantId) // 默认租户
|
||
}
|
||
|
||
// 设置基础时间
|
||
v.SetNow(now)
|
||
|
||
// 转换
|
||
tdataanys = append(tdataanys, v)
|
||
}
|
||
// 插入数据
|
||
result, err := collection.InsertMany(ctx, tdataanys)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
ids = make([]string, 0, len(result.InsertedIDs))
|
||
for _, id := range result.InsertedIDs {
|
||
ids = append(ids, id.(primitive.ObjectID).Hex())
|
||
}
|
||
|
||
return ids, nil
|
||
}
|
||
|
||
// Delete 删除(软删除)
|
||
func (rp *BaseRepo[T]) Delete(ctx context.Context, qp *QueryParams) error {
|
||
collection := rp.mdb.Db().Collection(rp.tableName)
|
||
|
||
filter := rp.RawProcessFilter(ctx, qp)
|
||
|
||
// 查询参数不能为空
|
||
if len(filter) == 0 {
|
||
return errors.New("查询参数不能为空")
|
||
}
|
||
|
||
// 软删除
|
||
_, err := collection.UpdateMany(ctx, filter, bson.M{"$set": bson.M{"deletedAt": time.Now().UnixMilli()}})
|
||
|
||
return err
|
||
}
|
||
|
||
func (rp *BaseRepo[T]) Update(ctx context.Context, qp *QueryParams, setData map[string]any) error {
|
||
collection := rp.mdb.Db().Collection(rp.tableName)
|
||
|
||
sd, err := rp.RawProcessSetData(ctx, setData)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
// 检查是否有可更新的字段
|
||
if len(sd) == 0 {
|
||
return errors.New("更新参数不能为空")
|
||
}
|
||
|
||
// 更新时间
|
||
sd["updatedAt"] = time.Now().UnixMilli()
|
||
|
||
filter := rp.RawProcessFilter(ctx, qp)
|
||
|
||
if len(filter) == 0 {
|
||
return errors.New("查询参数不能为空")
|
||
}
|
||
|
||
// 更新数据
|
||
_, err = collection.UpdateMany(ctx, filter, bson.M{"$set": sd})
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// Count 查询数量
|
||
func (rp *BaseRepo[T]) RawCount(ctx context.Context, qp *QueryParams) (int64, error) {
|
||
collection := rp.mdb.Db().Collection(rp.tableName)
|
||
|
||
filter := rp.RawProcessFilter(ctx, qp)
|
||
|
||
// 查询总量
|
||
total, err := collection.CountDocuments(ctx, filter)
|
||
if err != nil {
|
||
return 0, err
|
||
}
|
||
|
||
return total, nil
|
||
}
|
||
|
||
// Find 查询
|
||
func (rp *BaseRepo[T]) RawFind(ctx context.Context, qp *QueryParams) (results []T, total int64, err error) {
|
||
collection := rp.mdb.Db().Collection(rp.tableName)
|
||
|
||
filter := rp.RawProcessFilter(ctx, qp)
|
||
|
||
opts := &options.FindOptions{}
|
||
|
||
// 设置分页
|
||
if qp.Page > 0 && qp.PageSize > 0 {
|
||
opts.SetSkip(int64((qp.Page - 1) * qp.PageSize))
|
||
opts.SetLimit(int64(qp.PageSize))
|
||
}
|
||
|
||
// 设置排序
|
||
if len(qp.OrderBy) != 0 {
|
||
opts.SetSort(qp.OrderBy)
|
||
}
|
||
|
||
// 执行查询
|
||
cursor, err := collection.Find(ctx, filter, opts)
|
||
if err != nil {
|
||
return nil, 0, err
|
||
}
|
||
defer cursor.Close(ctx)
|
||
|
||
// 解析结果
|
||
if err = cursor.All(ctx, &results); err != nil {
|
||
// 处理空文档错误
|
||
if err == mongo.ErrNilDocument {
|
||
return nil, 0, nil
|
||
}
|
||
return nil, 0, err
|
||
}
|
||
|
||
// 查询总量
|
||
total, err = collection.CountDocuments(ctx, filter)
|
||
if err != nil {
|
||
return nil, 0, err
|
||
}
|
||
|
||
return results, total, nil
|
||
}
|
||
|
||
// FindOne 查询单条记录
|
||
func (rp *BaseRepo[T]) RawFindOne(ctx context.Context, qp *QueryParams) (result T, err error) {
|
||
collection := rp.mdb.Db().Collection(rp.tableName)
|
||
|
||
filter := rp.RawProcessFilter(ctx, qp)
|
||
|
||
// 执行查询
|
||
err = collection.FindOne(ctx, filter).Decode(&result)
|
||
if err != nil {
|
||
// 处理空文档错误
|
||
if err == mongo.ErrNilDocument {
|
||
return result, nil
|
||
}
|
||
return result, err
|
||
}
|
||
|
||
return result, nil
|
||
}
|