diff --git a/core/backend/llm.go b/core/backend/llm.go index 4491a191eeb4..9e121f7999f2 100644 --- a/core/backend/llm.go +++ b/core/backend/llm.go @@ -117,8 +117,12 @@ func ModelInference(ctx context.Context, s string, messages []schema.Message, im ss := "" var partialRune []byte - err := inferenceModel.PredictStream(ctx, opts, func(chars []byte) { - partialRune = append(partialRune, chars...) + err := inferenceModel.PredictStream(ctx, opts, func(reply *proto.Reply) { + msg := reply.GetMessage() + partialRune = append(partialRune, msg...) + + tokenUsage.Prompt = int(reply.PromptTokens) + tokenUsage.Completion = int(reply.Tokens) for len(partialRune) > 0 { r, size := utf8.DecodeRune(partialRune) @@ -132,6 +136,10 @@ func ModelInference(ctx context.Context, s string, messages []schema.Message, im partialRune = partialRune[size:] } + + if len(msg) == 0 { + tokenCallback("", tokenUsage) + } }) return LLMResponse{ Response: ss, diff --git a/core/http/endpoints/openai/chat.go b/core/http/endpoints/openai/chat.go index 1ac1387eed3e..b03b18bd1d18 100644 --- a/core/http/endpoints/openai/chat.go +++ b/core/http/endpoints/openai/chat.go @@ -39,11 +39,15 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup responses <- initialMessage ComputeChoices(req, s, config, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool { + choices := []schema.Choice{} + if s != "" { + choices = append(choices, schema.Choice{Delta: &schema.Message{Content: &s}, Index: 0}) + } resp := schema.OpenAIResponse{ ID: id, Created: created, Model: req.Model, // we have to return what the user sent here, due to OpenAI spec. - Choices: []schema.Choice{{Delta: &schema.Message{Content: &s}, Index: 0}}, + Choices: choices, Object: "chat.completion.chunk", Usage: schema.OpenAIUsage{ PromptTokens: usage.Prompt, @@ -465,6 +469,9 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup toolsCalled := false for ev := range responses { usage = &ev.Usage // Copy a pointer to the latest usage chunk so that the stop message can reference it + if len(ev.Choices) == 0 { + break + } if len(ev.Choices[0].Delta.ToolCalls) > 0 { toolsCalled = true } diff --git a/pkg/grpc/backend.go b/pkg/grpc/backend.go index 21435891fa79..fabc026853b0 100644 --- a/pkg/grpc/backend.go +++ b/pkg/grpc/backend.go @@ -37,7 +37,7 @@ type Backend interface { Embeddings(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.EmbeddingResult, error) Predict(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.Reply, error) LoadModel(ctx context.Context, in *pb.ModelOptions, opts ...grpc.CallOption) (*pb.Result, error) - PredictStream(ctx context.Context, in *pb.PredictOptions, f func(s []byte), opts ...grpc.CallOption) error + PredictStream(ctx context.Context, in *pb.PredictOptions, f func(reply *pb.Reply), opts ...grpc.CallOption) error GenerateImage(ctx context.Context, in *pb.GenerateImageRequest, opts ...grpc.CallOption) (*pb.Result, error) TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.CallOption) (*pb.Result, error) SoundGeneration(ctx context.Context, in *pb.SoundGenerationRequest, opts ...grpc.CallOption) (*pb.Result, error) diff --git a/pkg/grpc/client.go b/pkg/grpc/client.go index 9c8b302e13a1..ca207c3fd8e9 100644 --- a/pkg/grpc/client.go +++ b/pkg/grpc/client.go @@ -136,7 +136,7 @@ func (c *Client) LoadModel(ctx context.Context, in *pb.ModelOptions, opts ...grp return client.LoadModel(ctx, in, opts...) } -func (c *Client) PredictStream(ctx context.Context, in *pb.PredictOptions, f func(s []byte), opts ...grpc.CallOption) error { +func (c *Client) PredictStream(ctx context.Context, in *pb.PredictOptions, f func(reply *pb.Reply), opts ...grpc.CallOption) error { if !c.parallel { c.opMutex.Lock() defer c.opMutex.Unlock() @@ -158,7 +158,7 @@ func (c *Client) PredictStream(ctx context.Context, in *pb.PredictOptions, f fun } for { - feature, err := stream.Recv() + reply, err := stream.Recv() if err == io.EOF { break } @@ -167,7 +167,7 @@ func (c *Client) PredictStream(ctx context.Context, in *pb.PredictOptions, f fun return err } - f(feature.GetMessage()) + f(reply) } return nil diff --git a/pkg/grpc/embed.go b/pkg/grpc/embed.go index a5828a5fd21c..79648c5aed99 100644 --- a/pkg/grpc/embed.go +++ b/pkg/grpc/embed.go @@ -35,7 +35,7 @@ func (e *embedBackend) LoadModel(ctx context.Context, in *pb.ModelOptions, opts return e.s.LoadModel(ctx, in) } -func (e *embedBackend) PredictStream(ctx context.Context, in *pb.PredictOptions, f func(s []byte), opts ...grpc.CallOption) error { +func (e *embedBackend) PredictStream(ctx context.Context, in *pb.PredictOptions, f func(reply *pb.Reply), opts ...grpc.CallOption) error { bs := &embedBackendServerStream{ ctx: ctx, fn: f, @@ -97,11 +97,11 @@ func (e *embedBackend) GetTokenMetrics(ctx context.Context, in *pb.MetricsReques type embedBackendServerStream struct { ctx context.Context - fn func(s []byte) + fn func(reply *pb.Reply) } func (e *embedBackendServerStream) Send(reply *pb.Reply) error { - e.fn(reply.GetMessage()) + e.fn(reply) return nil }