Files
taotie-api/repo/repo.go

295 lines
6.4 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package repo
import (
	"context"
	"errors"
	"reflect"
	"strings"
	"time"

	"github.com/duke-git/lancet/slice"
	"github.com/google/wire"
	"go.mongodb.org/mongo-driver/bson"
	"go.mongodb.org/mongo-driver/bson/primitive"
	"go.mongodb.org/mongo-driver/mongo"
	"go.mongodb.org/mongo-driver/mongo/options"

	"taotie-api/utils/sctx"
)
// RepoProd is the Wire provider set for this package: it exposes the
// repository constructors NewUserRepo and NewTenantRepo to dependency injection.
var RepoProd = wire.NewSet(NewUserRepo, NewTenantRepo)
// BaseRepo is a generic MongoDB repository for entities of type T.
// Every method works on the collection named tableName, reached through
// mdb.Db().Collection(tableName).
type BaseRepo[T IEntity] struct {
	mdb       *MongoDb // database handle; collections are obtained via mdb.Db()
	tableName string   // name of the collection this repository reads and writes
}
// RawProcessFilter builds the MongoDB filter used by every query in BaseRepo.
//
// It copies qp.Where into the filter. For keys ending in "Id" whose value is a
// string, the hex string is converted to a primitive.ObjectID. Any other value,
// such as an ObjectID or an operator document like {"$in": ...}, is passed
// through unchanged.
//
// When qp.WhereInIds is non-empty, "_id" is restricted to those ids. The filter
// always excludes soft-deleted documents (deletedAt == 0). When a current user
// is present in ctx, the query is scoped to that user's tenant, overriding any
// tenantId given in qp.Where.
//
// Invalid hex strings are not reported: the error from
// primitive.ObjectIDFromHex is ignored and the zero ObjectID is used.
// A nil qp is treated as "no criteria".
func (rp *BaseRepo[T]) RawProcessFilter(ctx context.Context, qp *QueryParams) bson.M {
	filter := bson.M{}

	if qp != nil {
		// Caller-supplied conditions. Only plain strings are converted, and
		// HasSuffix is safe for keys shorter than the suffix.
		for k, v := range qp.Where {
			if s, ok := v.(string); ok && strings.HasSuffix(k, "Id") {
				oid, _ := primitive.ObjectIDFromHex(s)
				filter[k] = oid
				continue
			}
			filter[k] = v
		}

		// Restrict to an explicit id list.
		if len(qp.WhereInIds) != 0 {
			oids := make([]primitive.ObjectID, 0, len(qp.WhereInIds))
			for _, id := range qp.WhereInIds {
				oid, _ := primitive.ObjectIDFromHex(id)
				oids = append(oids, oid)
			}
			filter["_id"] = bson.M{"$in": oids}
		}
	}

	// Hide soft-deleted documents.
	filter["deletedAt"] = 0

	// Tenant isolation: this is set last so callers cannot widen it via qp.Where.
	if cuser := sctx.GetCurrentUser(ctx); cuser != nil {
		filter["tenantId"], _ = primitive.ObjectIDFromHex(cuser.TenantId)
	}
	return filter
}
// RawProcessSetData filters setData down to the fields of T that may be
// modified, and converts ObjectID reference values from their hex form.
//
// The allowed keys are the bson names of T's direct struct fields. T may be a
// struct type or a pointer to one. Only the name part of a bson tag is
// compared, so a field tagged "name,omitempty" is matched by the key "name".
//
// These fields are never updatable:
//   - fields tagged "-";
//   - fields without a bson name;
//   - fields marked inline, so nothing promoted from an inline embedded struct
//     can be set here.
//
// Keys outside the allowed set are dropped silently.
//
// For allowed keys:
//   - A key ending in "Id" with a string value is converted to a
//     primitive.ObjectID. Any other value (for example an ObjectID or nil) is
//     kept as is.
//   - A key ending in "Ids" with a []string or []any value becomes a
//     []primitive.ObjectID. String elements are converted, ObjectID elements
//     are kept, and any other element type is reported as an error. Other
//     values are kept as is.
//   - Every other key is copied unchanged.
//
// An invalid hex string is not reported: the error from
// primitive.ObjectIDFromHex is ignored and the zero ObjectID is stored.
// ctx is currently unused.
func (rp *BaseRepo[T]) RawProcessSetData(ctx context.Context, setData map[string]any) (map[string]any, error) {
	// Dereference until the entity struct itself is reached.
	t := reflect.TypeFor[T]()
	for t.Kind() == reflect.Pointer {
		t = t.Elem()
	}

	// Collect the bson field names callers are allowed to set.
	fields := make([]string, 0, t.NumField())
	for i := 0; i < t.NumField(); i++ {
		name, opts, _ := strings.Cut(t.Field(i).Tag.Get("bson"), ",")
		if name == "" || name == "-" || slice.Contain(strings.Split(opts, ","), "inline") {
			continue
		}
		fields = append(fields, name)
	}

	sd := make(map[string]any, len(setData))
	for k, v := range setData {
		if !slice.Contain(fields, k) {
			continue
		}
		switch {
		case strings.HasSuffix(k, "Id"):
			// Single reference: convert hex strings and keep anything else.
			if s, ok := v.(string); ok {
				oid, _ := primitive.ObjectIDFromHex(s)
				sd[k] = oid
			} else {
				sd[k] = v
			}
		case strings.HasSuffix(k, "Ids"):
			// Reference list: accept both []string and []any.
			switch ids := v.(type) {
			case []string:
				oids := make([]primitive.ObjectID, 0, len(ids))
				for _, id := range ids {
					oid, _ := primitive.ObjectIDFromHex(id)
					oids = append(oids, oid)
				}
				sd[k] = oids
			case []any:
				oids := make([]primitive.ObjectID, 0, len(ids))
				for _, id := range ids {
					switch x := id.(type) {
					case string:
						oid, _ := primitive.ObjectIDFromHex(x)
						oids = append(oids, oid)
					case primitive.ObjectID:
						oids = append(oids, x)
					default:
						return nil, errors.New("ID列表包含无法转换为ObjectID的元素")
					}
				}
				sd[k] = oids
			default:
				sd[k] = v
			}
		default:
			sd[k] = v
		}
	}
	return sd, nil
}
// Create inserts a single entity and returns the hex string of its new _id.
//
// Before insertion the entity is stamped via SetNow with the current time in
// Unix milliseconds. When a current user is present in ctx, the entity's
// createdBy and tenantId are set from that user. Without a current user,
// those fields are left as the caller set them (the same rule CreateMany
// uses) instead of dereferencing a nil user.
func (rp *BaseRepo[T]) Create(ctx context.Context, tdata T) (string, error) {
	collection := rp.mdb.Db().Collection(rp.tableName)

	// Default creator and tenant come from the request's user, if any.
	if cuser := sctx.GetCurrentUser(ctx); cuser != nil {
		tdata.SetCreatedBy(cuser.UserId)
		tdata.SetTenantId(cuser.TenantId)
	}
	tdata.SetNow(time.Now().UnixMilli())

	result, err := collection.InsertOne(ctx, tdata)
	if err != nil {
		return "", err
	}
	// Assumes entities use a primitive.ObjectID _id; any other type panics here.
	return result.InsertedID.(primitive.ObjectID).Hex(), nil
}
// CreateMany inserts all entities in tdatas with a single InsertMany call. It
// returns the hex strings of their new _ids in the order the driver reports
// them.
//
// Every entity is stamped via SetNow with one shared Unix-millisecond
// timestamp. When a current user is present in ctx, each entity also gets
// createdBy and tenantId from that user.
//
// An empty tdatas is a no-op that returns an empty, non-nil slice and no
// error. The driver's InsertMany returns an error for an empty document list.
func (rp *BaseRepo[T]) CreateMany(ctx context.Context, tdatas []T) (ids []string, err error) {
	if len(tdatas) == 0 {
		return []string{}, nil
	}

	collection := rp.mdb.Db().Collection(rp.tableName)
	cuser := sctx.GetCurrentUser(ctx)
	now := time.Now().UnixMilli()

	docs := make([]any, 0, len(tdatas))
	for _, v := range tdatas {
		if cuser != nil {
			v.SetCreatedBy(cuser.UserId) // default creator
			v.SetTenantId(cuser.TenantId) // default tenant
		}
		v.SetNow(now)
		docs = append(docs, v)
	}

	result, err := collection.InsertMany(ctx, docs)
	if err != nil {
		return nil, err
	}
	ids = make([]string, 0, len(result.InsertedIDs))
	for _, id := range result.InsertedIDs {
		// Assumes entities use a primitive.ObjectID _id; any other type panics here.
		ids = append(ids, id.(primitive.ObjectID).Hex())
	}
	return ids, nil
}
// Delete soft-deletes the documents matching qp by setting their deletedAt to
// the current time in Unix milliseconds. Nothing is physically removed.
//
// The filter comes from RawProcessFilter, so already-deleted documents are not
// matched, and when a current user is in ctx only that user's tenant is
// affected.
//
// qp must supply at least one criterion (Where or WhereInIds). Otherwise an
// error is returned and nothing is written.
func (rp *BaseRepo[T]) Delete(ctx context.Context, qp *QueryParams) error {
	// RawProcessFilter always adds deletedAt (and tenantId when a user is
	// present), so the built filter is never empty and cannot be used for this
	// check. Inspect the caller's own criteria instead; otherwise an empty qp
	// would soft-delete every live document.
	if qp == nil || (len(qp.Where) == 0 && len(qp.WhereInIds) == 0) {
		return errors.New("查询参数不能为空")
	}

	collection := rp.mdb.Db().Collection(rp.tableName)
	filter := rp.RawProcessFilter(ctx, qp)
	_, err := collection.UpdateMany(ctx, filter, bson.M{"$set": bson.M{"deletedAt": time.Now().UnixMilli()}})
	return err
}
// Update applies setData to every document matching qp and sets updatedAt to
// the current time in Unix milliseconds.
//
// setData is first filtered and normalized by RawProcessSetData. If nothing
// updatable remains, an error is returned.
//
// qp must supply at least one criterion (Where or WhereInIds). Otherwise an
// error is returned and nothing is written. Documents are matched with
// RawProcessFilter, which excludes soft-deleted documents and applies tenant
// scoping.
func (rp *BaseRepo[T]) Update(ctx context.Context, qp *QueryParams, setData map[string]any) error {
	sd, err := rp.RawProcessSetData(ctx, setData)
	if err != nil {
		return err
	}
	if len(sd) == 0 {
		return errors.New("更新参数不能为空")
	}

	// RawProcessFilter always adds deletedAt (and tenantId when a user is
	// present), so the built filter is never empty and cannot be used for this
	// check. Inspect the caller's own criteria instead; otherwise an empty qp
	// would update every live document.
	if qp == nil || (len(qp.Where) == 0 && len(qp.WhereInIds) == 0) {
		return errors.New("查询参数不能为空")
	}

	sd["updatedAt"] = time.Now().UnixMilli()
	collection := rp.mdb.Db().Collection(rp.tableName)
	filter := rp.RawProcessFilter(ctx, qp)
	_, err = collection.UpdateMany(ctx, filter, bson.M{"$set": sd})
	return err
}
// RawCount returns how many documents match qp. It uses the filter built by
// RawProcessFilter, the same one RawFind uses.
func (rp *BaseRepo[T]) RawCount(ctx context.Context, qp *QueryParams) (int64, error) {
	coll := rp.mdb.Db().Collection(rp.tableName)
	n, countErr := coll.CountDocuments(ctx, rp.RawProcessFilter(ctx, qp))
	if countErr != nil {
		return 0, countErr
	}
	return n, nil
}
// RawFind returns the documents matching qp together with the total number of
// matches.
//
// When both qp.Page and qp.PageSize are positive, results are paginated and
// Page is 1-based; otherwise every match is returned. A non-empty qp.OrderBy
// is passed to the driver as the sort specification.
//
// total is counted with the same filter but without pagination, so it is the
// overall match count rather than len(results).
func (rp *BaseRepo[T]) RawFind(ctx context.Context, qp *QueryParams) (results []T, total int64, err error) {
	coll := rp.mdb.Db().Collection(rp.tableName)
	filter := rp.RawProcessFilter(ctx, qp)

	findOpts := options.Find()
	if qp.Page > 0 && qp.PageSize > 0 {
		findOpts.SetSkip(int64((qp.Page - 1) * qp.PageSize)).SetLimit(int64(qp.PageSize))
	}
	if len(qp.OrderBy) > 0 {
		findOpts.SetSort(qp.OrderBy)
	}

	cur, findErr := coll.Find(ctx, filter, findOpts)
	if findErr != nil {
		return nil, 0, findErr
	}
	defer cur.Close(ctx)

	if decodeErr := cur.All(ctx, &results); decodeErr != nil {
		// A nil-document error is treated as an empty result.
		if decodeErr == mongo.ErrNilDocument {
			return nil, 0, nil
		}
		return nil, 0, decodeErr
	}

	count, countErr := coll.CountDocuments(ctx, filter)
	if countErr != nil {
		return nil, 0, countErr
	}
	return results, count, nil
}
// RawFindOne decodes the first document matching qp into result. It uses the
// filter built by RawProcessFilter. A mongo.ErrNilDocument from Decode is
// swallowed and yields (result, nil); any other error is returned unchanged.
//
// NOTE(review): the mongo-driver documents that SingleResult.Decode reports a
// missing document as mongo.ErrNoDocuments, not mongo.ErrNilDocument. A
// no-match therefore appears to reach the caller as an error rather than as
// (zero value, nil). Confirm what callers rely on before changing this check.
func (rp *BaseRepo[T]) RawFindOne(ctx context.Context, qp *QueryParams) (result T, err error) {
	coll := rp.mdb.Db().Collection(rp.tableName)
	filter := rp.RawProcessFilter(ctx, qp)

	err = coll.FindOne(ctx, filter).Decode(&result)
	switch {
	case err == nil, err == mongo.ErrNilDocument:
		return result, nil
	default:
		return result, err
	}
}