-
-
Notifications
You must be signed in to change notification settings - Fork 2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(bark-cpp): add new bark.cpp backend
Signed-off-by: Ettore Di Giacinto <[email protected]>
- Loading branch information
Showing
7 changed files
with
221 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,6 +2,7 @@ | |
/sources/ | ||
__pycache__/ | ||
*.a | ||
*.o | ||
get-sources | ||
prepare-sources | ||
/backend/cpp/llama/grpc-server | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
INCLUDE_PATH := $(abspath ./) | ||
LIBRARY_PATH := $(abspath ./) | ||
|
||
AR?=ar | ||
|
||
BUILD_TYPE?= | ||
# keep standard at C11 and C++11 | ||
CXXFLAGS = -I. -I$(INCLUDE_PATH)/../../../sources/bark.cpp/examples -I$(INCLUDE_PATH)/../../../sources/bark.cpp/spm-headers -I$(INCLUDE_PATH)/../../../sources/bark.cpp -O3 -DNDEBUG -std=c++17 -fPIC | ||
LDFLAGS = -L$(LIBRARY_PATH) -L$(LIBRARY_PATH)/../../../sources/bark.cpp/build/examples -lbark -lstdc++ -lm | ||
|
||
# warnings | ||
CXXFLAGS += -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function | ||
|
||
gobark.o: | ||
$(CXX) $(CXXFLAGS) gobark.cpp -o gobark.o -c $(LDFLAGS) | ||
|
||
libbark.a: gobark.o | ||
cp $(INCLUDE_PATH)/../../../sources/bark.cpp/build/libbark.a ./ | ||
$(AR) rcs libbark.a gobark.o | ||
$(AR) rcs libbark.a $(LIBRARY_PATH)/../../../sources/bark.cpp/build/encodec.cpp/ggml/src/CMakeFiles/ggml.dir/ggml.c.o | ||
$(AR) rcs libbark.a $(LIBRARY_PATH)/../../../sources/bark.cpp/build/encodec.cpp/ggml/src/CMakeFiles/ggml.dir/ggml-alloc.c.o | ||
$(AR) rcs libbark.a $(LIBRARY_PATH)/../../../sources/bark.cpp/build/encodec.cpp/ggml/src/CMakeFiles/ggml.dir/ggml-backend.c.o | ||
|
||
clean: | ||
rm -f gobark.o libbark.a |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
#include <iostream> | ||
#include <tuple> | ||
|
||
#include "bark.h" | ||
#include "gobark.h" | ||
#include "common.h" | ||
#include "ggml.h" | ||
|
||
struct bark_context *c; | ||
|
||
void bark_print_progress_callback(struct bark_context *bctx, enum bark_encoding_step step, int progress, void *user_data) { | ||
if (step == bark_encoding_step::SEMANTIC) { | ||
printf("\rGenerating semantic tokens... %d%%", progress); | ||
} else if (step == bark_encoding_step::COARSE) { | ||
printf("\rGenerating coarse tokens... %d%%", progress); | ||
} else if (step == bark_encoding_step::FINE) { | ||
printf("\rGenerating fine tokens... %d%%", progress); | ||
} | ||
fflush(stdout); | ||
} | ||
|
||
int load_model(char *model) { | ||
// initialize bark context | ||
struct bark_context_params ctx_params = bark_context_default_params(); | ||
bark_params params; | ||
|
||
params.model_path = model; | ||
|
||
// ctx_params.verbosity = verbosity; | ||
ctx_params.progress_callback = bark_print_progress_callback; | ||
ctx_params.progress_callback_user_data = nullptr; | ||
|
||
struct bark_context *bctx = bark_load_model(params.model_path.c_str(), ctx_params, params.seed); | ||
if (!bctx) { | ||
fprintf(stderr, "%s: Could not load model\n", __func__); | ||
return 1; | ||
} | ||
|
||
c = bctx; | ||
|
||
return 0; | ||
} | ||
|
||
int tts(char *text,int threads, char *dst ) { | ||
|
||
ggml_time_init(); | ||
const int64_t t_main_start_us = ggml_time_us(); | ||
|
||
// generate audio | ||
if (!bark_generate_audio(c, text, threads)) { | ||
fprintf(stderr, "%s: An error occured. If the problem persists, feel free to open an issue to report it.\n", __func__); | ||
return 1; | ||
} | ||
|
||
const float *audio_data = bark_get_audio_data(c); | ||
if (audio_data == NULL) { | ||
fprintf(stderr, "%s: Could not get audio data\n", __func__); | ||
return 1; | ||
} | ||
|
||
const int audio_arr_size = bark_get_audio_data_size(c); | ||
|
||
std::vector<float> audio_arr(audio_data, audio_data + audio_arr_size); | ||
|
||
write_wav_on_disk(audio_arr, dst); | ||
|
||
// report timing | ||
{ | ||
const int64_t t_main_end_us = ggml_time_us(); | ||
const int64_t t_load_us = bark_get_load_time(c); | ||
const int64_t t_eval_us = bark_get_eval_time(c); | ||
|
||
printf("\n\n"); | ||
printf("%s: load time = %8.2f ms\n", __func__, t_load_us / 1000.0f); | ||
printf("%s: eval time = %8.2f ms\n", __func__, t_eval_us / 1000.0f); | ||
printf("%s: total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us) / 1000.0f); | ||
} | ||
|
||
return 0; | ||
} | ||
|
||
int unload() { | ||
bark_free(c); | ||
} | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
package main | ||
|
||
// #cgo CXXFLAGS: -I${SRCDIR}/../../../sources/bark.cpp/ -I${SRCDIR}/../../../sources/bark.cpp/encodec.cpp -I${SRCDIR}/../../../sources/bark.cpp/examples -I${SRCDIR}/../../../sources/bark.cpp/spm-headers | ||
// #cgo LDFLAGS: -L${SRCDIR}/ -L${SRCDIR}/../../../sources/bark.cpp/build/examples -L${SRCDIR}/../../../sources/bark.cpp/build/encodec.cpp/ -lbark -lencodec -lcommon | ||
// #include <gobark.h> | ||
// #include <stdlib.h> | ||
import "C" | ||
|
||
import ( | ||
"fmt" | ||
"unsafe" | ||
|
||
"github.com/mudler/LocalAI/pkg/grpc/base" | ||
pb "github.com/mudler/LocalAI/pkg/grpc/proto" | ||
) | ||
|
||
type Bark struct { | ||
base.SingleThread | ||
threads int | ||
} | ||
|
||
func (sd *Bark) Load(opts *pb.ModelOptions) error { | ||
|
||
sd.threads = int(opts.Threads) | ||
|
||
modelFile := C.CString(opts.ModelFile) | ||
defer C.free(unsafe.Pointer(modelFile)) | ||
Check warning Code scanning / gosec Use of unsafe calls should be audited Warning
Use of unsafe calls should be audited
|
||
|
||
ret := C.load_model(modelFile) | ||
if ret != 0 { | ||
return fmt.Errorf("inference failed") | ||
} | ||
|
||
return nil | ||
} | ||
|
||
func (sd *Bark) TTS(opts *pb.TTSRequest) error { | ||
t := C.CString(opts.Text) | ||
defer C.free(unsafe.Pointer(t)) | ||
Check warning Code scanning / gosec Use of unsafe calls should be audited Warning
Use of unsafe calls should be audited
|
||
|
||
dst := C.CString(opts.Dst) | ||
defer C.free(unsafe.Pointer(dst)) | ||
Check warning Code scanning / gosec Use of unsafe calls should be audited Warning
Use of unsafe calls should be audited
|
||
|
||
threads := C.int(sd.threads) | ||
|
||
ret := C.tts(t, threads, dst) | ||
if ret != 0 { | ||
return fmt.Errorf("inference failed") | ||
} | ||
|
||
return nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
#ifdef __cplusplus | ||
extern "C" { | ||
#endif | ||
int load_model(char *model); | ||
int tts(char *text,int threads, char *dst ); | ||
#ifdef __cplusplus | ||
} | ||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
package main | ||
|
||
// Note: this is started internally by LocalAI and a server is allocated for each model | ||
import ( | ||
"flag" | ||
|
||
grpc "github.com/mudler/LocalAI/pkg/grpc" | ||
) | ||
|
||
var ( | ||
addr = flag.String("addr", "localhost:50051", "the address to connect to") | ||
) | ||
|
||
func main() { | ||
flag.Parse() | ||
|
||
if err := grpc.StartServer(*addr, &Bark{}); err != nil { | ||
panic(err) | ||
} | ||
} |