Skip to content

Commit

Permalink
fix: enable limiter for restful server(#35350) (#35354)
Browse files Browse the repository at this point in the history
related: #35350

Signed-off-by: MrPresent-Han <[email protected]>
Co-authored-by: MrPresent-Han <[email protected]>
  • Loading branch information
MrPresent-Han and MrPresent-Han authored Aug 13, 2024
1 parent 0320961 commit 20e2658
Show file tree
Hide file tree
Showing 13 changed files with 181 additions and 35 deletions.
43 changes: 43 additions & 0 deletions internal/distributed/proxy/httpserver/handler_v1.go
Original file line number Diff line number Diff line change
Expand Up @@ -454,6 +454,13 @@ func (h *HandlersV1) query(c *gin.Context) {
if !h.checkDatabase(ctx, c, req.DbName) {
return
}
if _, err := CheckLimiter(&req, h.proxy); err != nil {
c.AbortWithStatusJSON(http.StatusOK, gin.H{
HTTPReturnCode: merr.Code(merr.ErrHTTPRateLimit),
HTTPReturnMessage: merr.ErrHTTPRateLimit.Error() + ", error: " + err.Error(),
})
return
}
response, err := h.proxy.Query(ctx, &req)
if err == nil {
err = merr.Error(response.GetStatus())
Expand Down Expand Up @@ -523,6 +530,13 @@ func (h *HandlersV1) get(c *gin.Context) {
})
return
}
if _, err := CheckLimiter(&req, h.proxy); err != nil {
c.AbortWithStatusJSON(http.StatusOK, gin.H{
HTTPReturnCode: merr.Code(merr.ErrHTTPRateLimit),
HTTPReturnMessage: merr.ErrHTTPRateLimit.Error() + ", error: " + err.Error(),
})
return
}
req.Expr = filter
response, err := h.proxy.Query(ctx, &req)
if err == nil {
Expand Down Expand Up @@ -594,6 +608,13 @@ func (h *HandlersV1) delete(c *gin.Context) {
}
req.Expr = filter
}
if _, err := CheckLimiter(&req, h.proxy); err != nil {
c.AbortWithStatusJSON(http.StatusOK, gin.H{
HTTPReturnCode: merr.Code(merr.ErrHTTPRateLimit),
HTTPReturnMessage: merr.ErrHTTPRateLimit.Error() + ", error: " + err.Error(),
})
return
}
response, err := h.proxy.Delete(ctx, &req)
if err == nil {
err = merr.Error(response.GetStatus())
Expand Down Expand Up @@ -669,6 +690,14 @@ func (h *HandlersV1) insert(c *gin.Context) {
})
return
}
if _, err := CheckLimiter(&req, h.proxy); err != nil {
log.Warn("high level restful api, fail to insert for limiting", zap.Error(err))
c.AbortWithStatusJSON(http.StatusOK, gin.H{
HTTPReturnCode: merr.Code(merr.ErrHTTPRateLimit),
HTTPReturnMessage: merr.ErrHTTPRateLimit.Error() + ", error: " + err.Error(),
})
return
}
response, err := h.proxy.Insert(ctx, &req)
if err == nil {
err = merr.Error(response.GetStatus())
Expand Down Expand Up @@ -765,6 +794,13 @@ func (h *HandlersV1) upsert(c *gin.Context) {
})
return
}
if _, err := CheckLimiter(&req, h.proxy); err != nil {
c.AbortWithStatusJSON(http.StatusOK, gin.H{
HTTPReturnCode: merr.Code(merr.ErrHTTPRateLimit),
HTTPReturnMessage: merr.ErrHTTPRateLimit.Error() + ", error: " + err.Error(),
})
return
}
response, err := h.proxy.Upsert(ctx, &req)
if err == nil {
err = merr.Error(response.GetStatus())
Expand Down Expand Up @@ -859,6 +895,13 @@ func (h *HandlersV1) search(c *gin.Context) {
if !h.checkDatabase(ctx, c, req.DbName) {
return
}
if _, err := CheckLimiter(&req, h.proxy); err != nil {
c.AbortWithStatusJSON(http.StatusOK, gin.H{
HTTPReturnCode: merr.Code(merr.ErrHTTPRateLimit),
HTTPReturnMessage: merr.ErrHTTPRateLimit.Error() + ", error: " + err.Error(),
})
return
}
response, err := h.proxy.Search(ctx, &req)
if err == nil {
err = merr.Error(response.GetStatus())
Expand Down
26 changes: 25 additions & 1 deletion internal/distributed/proxy/httpserver/handler_v1_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -506,6 +506,9 @@ func TestQuery(t *testing.T) {
exceptCode: 200,
expectedBody: "{\"code\":200,\"data\":[{\"book_id\":1,\"book_intro\":[0.1,0.11],\"word_count\":1000},{\"book_id\":2,\"book_intro\":[0.2,0.22],\"word_count\":2000},{\"book_id\":3,\"book_intro\":[0.3,0.33],\"word_count\":3000}]}",
})
// disable rate limit
paramtable.Get().Save(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key, "false")
defer paramtable.Get().Save(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key, "true")

for _, tt := range testCases {
reqs := []*http.Request{genQueryRequest(), genGetRequest()}
Expand Down Expand Up @@ -590,7 +593,9 @@ func TestDelete(t *testing.T) {
exceptCode: 200,
expectedBody: "{\"code\":200,\"data\":{}}",
})

// disable rate limit
paramtable.Get().Save(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key, "false")
defer paramtable.Get().Save(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key, "true")
for _, tt := range testCases {
t.Run(tt.name, func(t *testing.T) {
testEngine := initHTTPServer(tt.mp, true)
Expand All @@ -614,11 +619,15 @@ func TestDelete(t *testing.T) {
}

func TestDeleteForFilter(t *testing.T) {
paramtable.Init()
jsonBodyList := [][]byte{
[]byte(`{"collectionName": "` + DefaultCollectionName + `" , "id": [1,2,3]}`),
[]byte(`{"collectionName": "` + DefaultCollectionName + `" , "filter": "id in [1,2,3]"}`),
[]byte(`{"collectionName": "` + DefaultCollectionName + `" , "id": [1,2,3], "filter": "id in [1,2,3]"}`),
}
// disable rate limit
paramtable.Get().Save(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key, "false")
defer paramtable.Get().Save(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key, "true")
for _, jsonBody := range jsonBodyList {
t.Run("delete success", func(t *testing.T) {
mp := mocks.NewMockProxy(t)
Expand Down Expand Up @@ -716,6 +725,9 @@ func TestInsert(t *testing.T) {
HTTPCollectionName: DefaultCollectionName,
HTTPReturnData: rows[0],
})
// disable rate limit
paramtable.Get().Save(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key, "false")
defer paramtable.Get().Save(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key, "true")
for _, tt := range testCases {
t.Run(tt.name, func(t *testing.T) {
testEngine := initHTTPServer(tt.mp, true)
Expand Down Expand Up @@ -761,6 +773,9 @@ func TestInsertForDataType(t *testing.T) {
"[success]with dynamic field": withDynamicField(newCollectionSchema(generateCollectionSchema(schemapb.DataType_Int64))),
"[success]with array fields": withArrayField(newCollectionSchema(generateCollectionSchema(schemapb.DataType_Int64))),
}
// disable rate limit
paramtable.Get().Save(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key, "false")
defer paramtable.Get().Save(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key, "true")
for name, schema := range schemas {
t.Run(name, func(t *testing.T) {
mp := mocks.NewMockProxy(t)
Expand Down Expand Up @@ -828,6 +843,9 @@ func TestReturnInt64(t *testing.T) {
schemapb.DataType_Int64: "1,2,3",
schemapb.DataType_VarChar: "\"1\",\"2\",\"3\"",
}
// disable rate limit
paramtable.Get().Save(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key, "false")
defer paramtable.Get().Save(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key, "true")
for _, dataType := range schemas {
t.Run("[insert]httpCfg.allow: false", func(t *testing.T) {
schema := newCollectionSchema(generateCollectionSchema(dataType))
Expand Down Expand Up @@ -1157,6 +1175,9 @@ func TestUpsert(t *testing.T) {
HTTPCollectionName: DefaultCollectionName,
HTTPReturnData: rows[0],
})
// disable rate limit
paramtable.Get().Save(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key, "false")
defer paramtable.Get().Save(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key, "true")
for _, tt := range testCases {
t.Run(tt.name, func(t *testing.T) {
testEngine := initHTTPServer(tt.mp, true)
Expand Down Expand Up @@ -1255,6 +1276,9 @@ func TestSearch(t *testing.T) {
exceptCode: 200,
expectedBody: "{\"code\":200,\"data\":[{\"book_id\":1,\"book_intro\":[0.1,0.11],\"distance\":0.01,\"word_count\":1000},{\"book_id\":2,\"book_intro\":[0.2,0.22],\"distance\":0.04,\"word_count\":2000},{\"book_id\":3,\"book_intro\":[0.3,0.33],\"distance\":0.09,\"word_count\":3000}]}",
})
// disable rate limit
paramtable.Get().Save(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key, "false")
defer paramtable.Get().Save(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key, "true")

for _, tt := range testCases {
t.Run(tt.name, func(t *testing.T) {
Expand Down
24 changes: 24 additions & 0 deletions internal/distributed/proxy/httpserver/handler_v2.go
Original file line number Diff line number Diff line change
Expand Up @@ -566,6 +566,10 @@ func (h *HandlersV2) query(ctx context.Context, c *gin.Context, anyReq any, dbNa
req.QueryParams = append(req.QueryParams, &commonpb.KeyValuePair{Key: ParamLimit, Value: strconv.FormatInt(int64(httpReq.Limit), 10)})
}
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/Query", func(reqCtx context.Context, req any) (interface{}, error) {
resp, err := CheckLimiter(req, h.proxy)
if err != nil {
return resp, err
}
return h.proxy.Query(reqCtx, req.(*milvuspb.QueryRequest))
})
if err == nil {
Expand Down Expand Up @@ -609,6 +613,10 @@ func (h *HandlersV2) get(ctx context.Context, c *gin.Context, anyReq any, dbName
Expr: filter,
}
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/Query", func(reqCtx context.Context, req any) (interface{}, error) {
resp, err := CheckLimiter(req, h.proxy)
if err != nil {
return resp, err
}
return h.proxy.Query(reqCtx, req.(*milvuspb.QueryRequest))
})
if err == nil {
Expand Down Expand Up @@ -653,6 +661,10 @@ func (h *HandlersV2) delete(ctx context.Context, c *gin.Context, anyReq any, dbN
req.Expr = filter
}
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/Delete", func(reqCtx context.Context, req any) (interface{}, error) {
resp, err := CheckLimiter(req, h.proxy)
if err != nil {
return resp, err
}
return h.proxy.Delete(reqCtx, req.(*milvuspb.DeleteRequest))
})
if err == nil {
Expand Down Expand Up @@ -694,6 +706,10 @@ func (h *HandlersV2) insert(ctx context.Context, c *gin.Context, anyReq any, dbN
return nil, err
}
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/Insert", func(reqCtx context.Context, req any) (interface{}, error) {
resp, err := CheckLimiter(req, h.proxy)
if err != nil {
return resp, err
}
return h.proxy.Insert(reqCtx, req.(*milvuspb.InsertRequest))
})
if err == nil {
Expand Down Expand Up @@ -756,6 +772,10 @@ func (h *HandlersV2) upsert(ctx context.Context, c *gin.Context, anyReq any, dbN
return nil, err
}
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/Upsert", func(reqCtx context.Context, req any) (interface{}, error) {
resp, err := CheckLimiter(req, h.proxy)
if err != nil {
return resp, err
}
return h.proxy.Upsert(reqCtx, req.(*milvuspb.UpsertRequest))
})
if err == nil {
Expand Down Expand Up @@ -882,6 +902,10 @@ func (h *HandlersV2) search(ctx context.Context, c *gin.Context, anyReq any, dbN
GuaranteeTimestamp: BoundedTimestamp,
}
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/Search", func(reqCtx context.Context, req any) (interface{}, error) {
resp, err := CheckLimiter(req, h.proxy)
if err != nil {
return resp, err
}
return h.proxy.Search(reqCtx, req.(*milvuspb.SearchRequest))
})
if err == nil {
Expand Down
9 changes: 9 additions & 0 deletions internal/distributed/proxy/httpserver/handler_v2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ func (DefaultReq) GetBase() *commonpb.MsgBase {
func (req *DefaultReq) GetDbName() string { return req.DbName }

func TestHTTPWrapper(t *testing.T) {
paramtable.Init()
postTestCases := []requestBodyTestCase{}
postTestCasesTrace := []requestBodyTestCase{}
ginHandler := gin.Default()
Expand Down Expand Up @@ -1181,6 +1182,9 @@ func TestDML(t *testing.T) {
errCode: 65535,
})

// disable rate limit
paramtable.Get().Save(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key, "false")
defer paramtable.Get().Save(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key, "true")
for _, testcase := range queryTestCases {
t.Run("query", func(t *testing.T) {
bodyReader := bytes.NewReader(testcase.requestBody)
Expand Down Expand Up @@ -1228,6 +1232,11 @@ func TestSearchV2(t *testing.T) {
Status: &StatusSuccess,
}, nil).Times(4)
mp.EXPECT().Search(mock.Anything, mock.Anything).Return(&milvuspb.SearchResults{Status: commonSuccessStatus, Results: &schemapb.SearchResultData{TopK: int64(0)}}, nil).Twice()

// disable rate limit
paramtable.Get().Save(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key, "false")
defer paramtable.Get().Save(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key, "true")

testEngine := initHTTPServerV2(mp, false)
queryTestCases := []requestBodyTestCase{}
queryTestCases = append(queryTestCases, requestBodyTestCase{
Expand Down
33 changes: 33 additions & 0 deletions internal/distributed/proxy/httpserver/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,47 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/proxy"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/metrics"
"github.com/milvus-io/milvus/pkg/util"
"github.com/milvus-io/milvus/pkg/util/funcutil"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/parameterutil"
"github.com/milvus-io/milvus/pkg/util/paramtable"
)

func CheckLimiter(req interface{}, pxy types.ProxyComponent) (any, error) {
if !paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.GetAsBool() {
return nil, nil
}
// apply limiter for http/http2 server
limiter, err := pxy.GetRateLimiter()
if err != nil {
log.Error("Get proxy rate limiter for httpV1/V2 server failed", zap.Error(err))
return nil, err
}

collectionIDs, rt, n, err := proxy.GetRequestInfo(req)
if err != nil {
return nil, err
}
err = limiter.Check(collectionIDs, rt, n)
nodeID := strconv.FormatInt(paramtable.GetNodeID(), 10)
metrics.ProxyRateLimitReqCount.WithLabelValues(nodeID, rt.String(), metrics.TotalLabel).Inc()
if err != nil {
metrics.ProxyRateLimitReqCount.WithLabelValues(nodeID, rt.String(), metrics.FailLabel).Inc()
rsp := proxy.GetFailedResponse(req, err)
if rsp != nil {
return rsp, err
}
}
metrics.ProxyRateLimitReqCount.WithLabelValues(nodeID, rt.String(), metrics.SuccessLabel).Inc()
return nil, nil
}

func ParseUsernamePassword(c *gin.Context) (string, string, bool) {
username, password, ok := c.Request.BasicAuth()
if !ok {
Expand Down
6 changes: 3 additions & 3 deletions internal/proxy/impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ func (node *Proxy) CreateCollection(ctx context.Context, request *milvuspb.Creat
zap.String("consistency_level", request.ConsistencyLevel.String()),
)

log.Debug(rpcReceived(method))
log.Info(rpcReceived(method))

if err := node.sched.ddQueue.Enqueue(cct); err != nil {
log.Warn(
Expand All @@ -369,7 +369,7 @@ func (node *Proxy) CreateCollection(ctx context.Context, request *milvuspb.Creat
return merr.Status(err), nil
}

log.Debug(
log.Info(
rpcEnqueued(method),
zap.Uint64("BeginTs", cct.BeginTs()),
zap.Uint64("EndTs", cct.EndTs()),
Expand All @@ -387,7 +387,7 @@ func (node *Proxy) CreateCollection(ctx context.Context, request *milvuspb.Creat
return merr.Status(err), nil
}

log.Debug(
log.Info(
rpcDone(method),
zap.Uint64("BeginTs", cct.BeginTs()),
zap.Uint64("EndTs", cct.EndTs()),
Expand Down
12 changes: 6 additions & 6 deletions internal/proxy/rate_limit_interceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ import (
// RateLimitInterceptor returns a new unary server interceptors that performs request rate limiting.
func RateLimitInterceptor(limiter types.Limiter) grpc.UnaryServerInterceptor {
return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
collectionIDs, rt, n, err := getRequestInfo(req)
collectionIDs, rt, n, err := GetRequestInfo(req)
if err != nil {
return handler(ctx, req)
}
Expand All @@ -46,7 +46,7 @@ func RateLimitInterceptor(limiter types.Limiter) grpc.UnaryServerInterceptor {
metrics.ProxyRateLimitReqCount.WithLabelValues(nodeID, rt.String(), metrics.TotalLabel).Inc()
if err != nil {
metrics.ProxyRateLimitReqCount.WithLabelValues(nodeID, rt.String(), metrics.FailLabel).Inc()
rsp := getFailedResponse(req, err)
rsp := GetFailedResponse(req, err)
if rsp != nil {
return rsp, nil
}
Expand All @@ -56,8 +56,8 @@ func RateLimitInterceptor(limiter types.Limiter) grpc.UnaryServerInterceptor {
}
}

// getRequestInfo returns collection name and rateType of request and return tokens needed.
func getRequestInfo(req interface{}) ([]int64, internalpb.RateType, int, error) {
// GetRequestInfo returns collection name and rateType of request and return tokens needed.
func GetRequestInfo(req interface{}) ([]int64, internalpb.RateType, int, error) {
switch r := req.(type) {
case *milvuspb.InsertRequest:
collectionID, _ := globalMetaCache.GetCollectionID(context.TODO(), r.GetDbName(), r.GetCollectionName())
Expand Down Expand Up @@ -132,8 +132,8 @@ func failedMutationResult(err error) *milvuspb.MutationResult {
}
}

// getFailedResponse returns failed response.
func getFailedResponse(req any, err error) any {
// GetFailedResponse returns failed response.
func GetFailedResponse(req any, err error) any {
switch req.(type) {
case *milvuspb.InsertRequest, *milvuspb.DeleteRequest, *milvuspb.UpsertRequest:
return failedMutationResult(err)
Expand Down
Loading

0 comments on commit 20e2658

Please sign in to comment.