feat: include tokens usage for streamed output (#4282)
Pass pb.Reply to the stream callback instead of the []byte obtained via Reply.GetMessage() in the llama gRPC client, so that the proper usage data is available in reply streaming mode at the last [DONE] frame.

Co-authored-by: Ettore Di Giacinto <[email protected]>
mintyleaf and mudler authored Nov 28, 2024
1 parent e001fad commit 0d6c3a7
Showing 5 changed files with 25 additions and 10 deletions.
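To make the streaming contract concrete, here is a minimal, self-contained sketch of the shape the new callback assumes: every frame delivers the full reply, intermediate frames carry the text delta in Message, and the final frame may carry only the token counters with an empty message. The Reply and TokenUsage types below are simplified stand-ins for the generated proto.Reply and the backend's TokenUsage, kept local so the example runs on its own; they are not the repository's actual definitions.

package main

import "fmt"

// Simplified stand-in for the generated proto.Reply (illustrative only).
type Reply struct {
    Message      []byte
    Tokens       int32
    PromptTokens int32
}

// Simplified stand-in for the backend's TokenUsage (illustrative only).
type TokenUsage struct {
    Prompt     int
    Completion int
}

func main() {
    var tokenUsage TokenUsage

    // Callback in the shape now used by PredictStream: it receives the whole
    // reply, so the usage counters are refreshed on every frame, and a frame
    // with an empty message is treated as the final usage-only frame.
    onReply := func(reply *Reply) {
        msg := reply.Message
        tokenUsage.Prompt = int(reply.PromptTokens)
        tokenUsage.Completion = int(reply.Tokens)

        if len(msg) == 0 {
            fmt.Printf("final frame: prompt=%d completion=%d\n", tokenUsage.Prompt, tokenUsage.Completion)
            return
        }
        fmt.Printf("text delta: %q\n", string(msg))
    }

    // Two text frames followed by the usage-only frame.
    onReply(&Reply{Message: []byte("Hello"), Tokens: 1, PromptTokens: 5})
    onReply(&Reply{Message: []byte(" world"), Tokens: 2, PromptTokens: 5})
    onReply(&Reply{Tokens: 2, PromptTokens: 5})
}
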
core/backend/llm.go (10 additions, 2 deletions)

@@ -117,8 +117,12 @@ func ModelInference(ctx context.Context, s string, messages []schema.Message, im
     ss := ""

     var partialRune []byte
-    err := inferenceModel.PredictStream(ctx, opts, func(chars []byte) {
-        partialRune = append(partialRune, chars...)
+    err := inferenceModel.PredictStream(ctx, opts, func(reply *proto.Reply) {
+        msg := reply.GetMessage()
+        partialRune = append(partialRune, msg...)
+
+        tokenUsage.Prompt = int(reply.PromptTokens)
+        tokenUsage.Completion = int(reply.Tokens)

         for len(partialRune) > 0 {
             r, size := utf8.DecodeRune(partialRune)
@@ -132,6 +136,10 @@ func ModelInference(ctx context.Context, s string, messages []schema.Message, im

             partialRune = partialRune[size:]
         }
+
+        if len(msg) == 0 {
+            tokenCallback("", tokenUsage)
+        }
     })
     return LLMResponse{
         Response: ss,

core/http/endpoints/openai/chat.go (8 additions, 1 deletion)

@@ -39,11 +39,15 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
             responses <- initialMessage

             ComputeChoices(req, s, config, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool {
+                choices := []schema.Choice{}
+                if s != "" {
+                    choices = append(choices, schema.Choice{Delta: &schema.Message{Content: &s}, Index: 0})
+                }
                 resp := schema.OpenAIResponse{
                     ID:      id,
                     Created: created,
                     Model:   req.Model, // we have to return what the user sent here, due to OpenAI spec.
-                    Choices: []schema.Choice{{Delta: &schema.Message{Content: &s}, Index: 0}},
+                    Choices: choices,
                     Object:  "chat.completion.chunk",
                     Usage: schema.OpenAIUsage{
                         PromptTokens: usage.Prompt,
@@ -465,6 +469,9 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
         toolsCalled := false
         for ev := range responses {
             usage = &ev.Usage // Copy a pointer to the latest usage chunk so that the stop message can reference it
+            if len(ev.Choices) == 0 {
+                break
+            }
             if len(ev.Choices[0].Delta.ToolCalls) > 0 {
                 toolsCalled = true
             }

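For API consumers, the practical effect of this chat.go change is that a streamed chat.completion.chunk can now arrive with an empty choices array and only the usage counters populated, so client code should not index choices[0] unconditionally. A rough sketch of such a guard follows, using simplified local types rather than the project's schema package; the JSON is an illustrative chunk, not captured output.

package main

import (
    "encoding/json"
    "fmt"
)

// Minimal, illustrative subset of an OpenAI-style streaming chunk.
type chunk struct {
    Object  string `json:"object"`
    Choices []struct {
        Delta struct {
            Content string `json:"content"`
        } `json:"delta"`
    } `json:"choices"`
    Usage struct {
        PromptTokens     int `json:"prompt_tokens"`
        CompletionTokens int `json:"completion_tokens"`
    } `json:"usage"`
}

func main() {
    // A hypothetical final frame: no choices, only usage counters.
    data := `{"object":"chat.completion.chunk","choices":[],"usage":{"prompt_tokens":5,"completion_tokens":2}}`

    var c chunk
    if err := json.Unmarshal([]byte(data), &c); err != nil {
        panic(err)
    }

    if len(c.Choices) == 0 {
        // Usage-only chunk: read the counters instead of indexing Choices[0].
        fmt.Printf("usage: prompt=%d completion=%d\n", c.Usage.PromptTokens, c.Usage.CompletionTokens)
        return
    }
    fmt.Println("delta:", c.Choices[0].Delta.Content)
}
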
pkg/grpc/backend.go (1 addition, 1 deletion)

@@ -37,7 +37,7 @@ type Backend interface
     Embeddings(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.EmbeddingResult, error)
     Predict(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.Reply, error)
     LoadModel(ctx context.Context, in *pb.ModelOptions, opts ...grpc.CallOption) (*pb.Result, error)
-    PredictStream(ctx context.Context, in *pb.PredictOptions, f func(s []byte), opts ...grpc.CallOption) error
+    PredictStream(ctx context.Context, in *pb.PredictOptions, f func(reply *pb.Reply), opts ...grpc.CallOption) error
     GenerateImage(ctx context.Context, in *pb.GenerateImageRequest, opts ...grpc.CallOption) (*pb.Result, error)
     TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.CallOption) (*pb.Result, error)
     SoundGeneration(ctx context.Context, in *pb.SoundGenerationRequest, opts ...grpc.CallOption) (*pb.Result, error)

pkg/grpc/client.go (3 additions, 3 deletions)

@@ -136,7 +136,7 @@ func (c *Client) LoadModel(ctx context.Context, in *pb.ModelOptions, opts ...grp
     return client.LoadModel(ctx, in, opts...)
 }

-func (c *Client) PredictStream(ctx context.Context, in *pb.PredictOptions, f func(s []byte), opts ...grpc.CallOption) error {
+func (c *Client) PredictStream(ctx context.Context, in *pb.PredictOptions, f func(reply *pb.Reply), opts ...grpc.CallOption) error {
     if !c.parallel {
         c.opMutex.Lock()
         defer c.opMutex.Unlock()
@@ -158,7 +158,7 @@ func (c *Client) PredictStream(ctx context.Context, in *pb.PredictOptions, f fun
     }

     for {
-        feature, err := stream.Recv()
+        reply, err := stream.Recv()
         if err == io.EOF {
             break
         }
@@ -167,7 +167,7 @@ func (c *Client) PredictStream(ctx context.Context, in *pb.PredictOptions, f fun

             return err
         }
-        f(feature.GetMessage())
+        f(reply)
     }

     return nil

pkg/grpc/embed.go (3 additions, 3 deletions)

@@ -35,7 +35,7 @@ func (e *embedBackend) LoadModel(ctx context.Context, in *pb.ModelOptions, opts
     return e.s.LoadModel(ctx, in)
 }

-func (e *embedBackend) PredictStream(ctx context.Context, in *pb.PredictOptions, f func(s []byte), opts ...grpc.CallOption) error {
+func (e *embedBackend) PredictStream(ctx context.Context, in *pb.PredictOptions, f func(reply *pb.Reply), opts ...grpc.CallOption) error {
     bs := &embedBackendServerStream{
         ctx: ctx,
         fn:  f,
@@ -97,11 +97,11 @@ func (e *embedBackend) GetTokenMetrics(ctx context.Context, in *pb.MetricsReques

 type embedBackendServerStream struct {
     ctx context.Context
-    fn  func(s []byte)
+    fn  func(reply *pb.Reply)
 }

 func (e *embedBackendServerStream) Send(reply *pb.Reply) error {
-    e.fn(reply.GetMessage())
+    e.fn(reply)
     return nil
 }
