package repo import ( "context" "errors" "reflect" "taotie-api/utils/sctx" "time" "github.com/duke-git/lancet/slice" "github.com/google/wire" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/bson/primitive" "go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo/options" ) // 依赖注入节点 var RepoProd = wire.NewSet( NewUserRepo, NewTenantRepo, ) type BaseRepo[T IEntity] struct { mdb *MongoDb tableName string } // ProcessFilter 处理查询条件 func (rp *BaseRepo[T]) RawProcessFilter(ctx context.Context, qp *QueryParams) bson.M { filter := bson.M{} // 查询条件 if len(qp.Where) != 0 { wh := make(map[string]any, 0) for k, v := range qp.Where { if (k[len(k)-2:]) == "Id" { oid, _ := primitive.ObjectIDFromHex(v.(string)) wh[k] = oid continue } wh[k] = v } filter = wh } // 查询 ids if len(qp.WhereInIds) != 0 { oids := make([]primitive.ObjectID, 0) for _, id := range qp.WhereInIds { oid, _ := primitive.ObjectIDFromHex(id) oids = append(oids, oid) } filter["_id"] = bson.M{"$in": oids} } // 过滤删除 filter["deletedAt"] = 0 // 过滤租户 cuser := sctx.GetCurrentUser(ctx) if cuser != nil { filter["tenantId"], _ = primitive.ObjectIDFromHex(cuser.TenantId) } return filter } // ProcessSetData 处理需要修改的数据,限制只能修改已拥有的字段 func (rp *BaseRepo[T]) RawProcessSetData(ctx context.Context, setData map[string]any) (map[string]any, error) { // 提取结构体 t := reflect.TypeFor[T]().Elem() if t.Kind() == reflect.Ptr { t = t.Elem() } // 提取结构体字段 fields := make([]string, 0, t.NumField()) for i := 0; i < t.NumField(); i++ { var bs = t.Field(i).Tag.Get("bson") if bs == ",inline" { continue } fields = append(fields, bs) } // 筛选 sd := make(map[string]any, 0) for k, v := range setData { if slice.Contain(fields, k) { // 处理 ObjectID 类型 if (k[len(k)-2:]) == "Id" { oid, _ := primitive.ObjectIDFromHex(v.(string)) sd[k] = oid continue } if (k[len(k)-3:]) == "Ids" { oids := make([]primitive.ObjectID, 0) for _, id := range v.([]any) { oid, _ := primitive.ObjectIDFromHex(id.(string)) oids = append(oids, oid) } sd[k] = oids continue } sd[k] = v } } return sd, nil } // Create 创建 func (rp *BaseRepo[T]) Create(ctx context.Context, tdata T) (string, error) { collection := rp.mdb.Db().Collection(rp.tableName) // 设置默认值 cuser := sctx.GetCurrentUser(ctx) tdata.SetCreatedBy(cuser.UserId) // 默认创建人 tdata.SetTenantId(cuser.TenantId) // 默认租户 // 设置基础时间 now := time.Now().UnixMilli() tdata.SetNow(now) result, err := collection.InsertOne(ctx, tdata) if err != nil { return "", err } // 插入成功后,返回插入的ID return result.InsertedID.(primitive.ObjectID).Hex(), nil } // CreateMany 创建多个 func (rp *BaseRepo[T]) CreateMany(ctx context.Context, tdatas []T) (ids []string, err error) { collection := rp.mdb.Db().Collection(rp.tableName) // 设置默认值 cuser := sctx.GetCurrentUser(ctx) now := time.Now().UnixMilli() tdataanys := make([]any, 0, len(tdatas)) for _, v := range tdatas { if cuser != nil { v.SetCreatedBy(cuser.UserId) // 默认创建人 v.SetTenantId(cuser.TenantId) // 默认租户 } // 设置基础时间 v.SetNow(now) // 转换 tdataanys = append(tdataanys, v) } // 插入数据 result, err := collection.InsertMany(ctx, tdataanys) if err != nil { return nil, err } ids = make([]string, 0, len(result.InsertedIDs)) for _, id := range result.InsertedIDs { ids = append(ids, id.(primitive.ObjectID).Hex()) } return ids, nil } // Delete 删除(软删除) func (rp *BaseRepo[T]) Delete(ctx context.Context, qp *QueryParams) error { collection := rp.mdb.Db().Collection(rp.tableName) filter := rp.RawProcessFilter(ctx, qp) // 查询参数不能为空 if len(filter) == 0 { return errors.New("查询参数不能为空") } // 软删除 _, err := collection.UpdateMany(ctx, filter, bson.M{"$set": bson.M{"deletedAt": time.Now().UnixMilli()}}) return err } func (rp *BaseRepo[T]) Update(ctx context.Context, qp *QueryParams, setData map[string]any) error { collection := rp.mdb.Db().Collection(rp.tableName) sd, err := rp.RawProcessSetData(ctx, setData) if err != nil { return err } // 检查是否有可更新的字段 if len(sd) == 0 { return errors.New("更新参数不能为空") } // 更新时间 sd["updatedAt"] = time.Now().UnixMilli() filter := rp.RawProcessFilter(ctx, qp) if len(filter) == 0 { return errors.New("查询参数不能为空") } // 更新数据 _, err = collection.UpdateMany(ctx, filter, bson.M{"$set": sd}) if err != nil { return err } return nil } // Count 查询数量 func (rp *BaseRepo[T]) RawCount(ctx context.Context, qp *QueryParams) (int64, error) { collection := rp.mdb.Db().Collection(rp.tableName) filter := rp.RawProcessFilter(ctx, qp) // 查询总量 total, err := collection.CountDocuments(ctx, filter) if err != nil { return 0, err } return total, nil } // Find 查询 func (rp *BaseRepo[T]) RawFind(ctx context.Context, qp *QueryParams) (results []T, total int64, err error) { collection := rp.mdb.Db().Collection(rp.tableName) filter := rp.RawProcessFilter(ctx, qp) opts := &options.FindOptions{} // 设置分页 if qp.Page > 0 && qp.PageSize > 0 { opts.SetSkip(int64((qp.Page - 1) * qp.PageSize)) opts.SetLimit(int64(qp.PageSize)) } // 设置排序 if len(qp.OrderBy) != 0 { opts.SetSort(qp.OrderBy) } // 执行查询 cursor, err := collection.Find(ctx, filter, opts) if err != nil { return nil, 0, err } defer cursor.Close(ctx) // 解析结果 if err = cursor.All(ctx, &results); err != nil { // 处理空文档错误 if err == mongo.ErrNilDocument { return nil, 0, nil } return nil, 0, err } // 查询总量 total, err = collection.CountDocuments(ctx, filter) if err != nil { return nil, 0, err } return results, total, nil } // FindOne 查询单条记录 func (rp *BaseRepo[T]) RawFindOne(ctx context.Context, qp *QueryParams) (result T, err error) { collection := rp.mdb.Db().Collection(rp.tableName) filter := rp.RawProcessFilter(ctx, qp) // 执行查询 err = collection.FindOne(ctx, filter).Decode(&result) if err != nil { // 处理空文档错误 if err == mongo.ErrNilDocument { return result, nil } return result, err } return result, nil }