
Commit

merge
iamjoel committed Dec 3, 2024
2 parents 9bea4e5 + 0c61608 commit ef4e51c
Showing 66 changed files with 1,657 additions and 600 deletions.
1 change: 0 additions & 1 deletion api/.env.example
@@ -413,4 +413,3 @@ RESET_PASSWORD_TOKEN_EXPIRY_MINUTES=5

CREATE_TIDB_SERVICE_JOB_ENABLED=false

RETRIEVAL_TOP_N=0
3 changes: 3 additions & 0 deletions api/.ruff.toml
@@ -20,6 +20,8 @@ select = [
"PLC0208", # iteration-over-set
"PLC2801", # unnecessary-dunder-call
"PLC0414", # useless-import-alias
"PLE0604", # invalid-all-object
"PLE0605", # invalid-all-format
"PLR0402", # manual-from-import
"PLR1711", # useless-return
"PLR1714", # repeated-equality-comparison
@@ -28,6 +30,7 @@ select = [
"RUF100", # unused-noqa
"RUF101", # redirected-noqa
"RUF200", # invalid-pyproject-toml
"RUF022", # unsorted-dunder-all
"S506", # unsafe-yaml-load
"SIM", # flake8-simplify rules
"TRY400", # error-instead-of-exception
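The three added codes all police __all__ declarations, which is what drives the wholesale reordering of export lists later in this commit. A minimal sketch of what each rule flags, assuming ruff's documented semantics for these codes:

def helper() -> None: ...

class Widget: ...

# PLE0605 (invalid-all-format): __all__ must be a list or tuple of strings.
# __all__ = "helper"              # flagged: a bare string, not a sequence

# PLE0604 (invalid-all-object): every entry must itself be a string.
# __all__ = [helper, "Widget"]    # flagged: helper is an object, not a str

# RUF022 (unsorted-dunder-all): entries should follow ruff's isort-style order
# (SCREAMING_SNAKE_CASE first, then CamelCase, then the rest); auto-fixable.
__all__ = ["helper", "Widget"]    # flagged; the fix reorders it to ["Widget", "helper"]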
2 changes: 0 additions & 2 deletions api/configs/feature/__init__.py
@@ -626,8 +626,6 @@ class DataSetConfig(BaseSettings):
default=30,
)

RETRIEVAL_TOP_N: int = Field(description="number of retrieval top_n", default=0)


class WorkspaceConfig(BaseSettings):
"""
2 changes: 1 addition & 1 deletion api/configs/packaging/__init__.py
@@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):

CURRENT_VERSION: str = Field(
description="Dify version",
default="0.12.1",
default="0.13.0",
)

COMMIT_SHA: str = Field(
12 changes: 6 additions & 6 deletions api/core/file/__init__.py
@@ -7,13 +7,13 @@
)

__all__ = [
"FILE_MODEL_IDENTITY",
"ArrayFileAttribute",
"File",
"FileAttribute",
"FileBelongsTo",
"FileTransferMethod",
"FileType",
"FileUploadConfig",
"FileTransferMethod",
"FileBelongsTo",
"File",
"ImageConfig",
"FileAttribute",
"ArrayFileAttribute",
"FILE_MODEL_IDENTITY",
]
24 changes: 12 additions & 12 deletions api/core/model_runtime/entities/__init__.py
@@ -18,25 +18,25 @@
from .model_entities import ModelPropertyKey

__all__ = [
"AssistantPromptMessage",
"AudioPromptMessageContent",
"DocumentPromptMessageContent",
"ImagePromptMessageContent",
"VideoPromptMessageContent",
"PromptMessage",
"PromptMessageRole",
"LLMResult",
"LLMResultChunk",
"LLMResultChunkDelta",
"LLMUsage",
"ModelPropertyKey",
"AssistantPromptMessage",
"PromptMessage",
"PromptMessage",
"PromptMessageContent",
"PromptMessageContentType",
"PromptMessageRole",
"PromptMessageRole",
"PromptMessageTool",
"SystemPromptMessage",
"TextPromptMessageContent",
"UserPromptMessage",
"PromptMessageTool",
"ToolPromptMessage",
"PromptMessageContentType",
"LLMResult",
"LLMResultChunk",
"LLMResultChunkDelta",
"AudioPromptMessageContent",
"DocumentPromptMessageContent",
"UserPromptMessage",
"VideoPromptMessageContent",
]
2 changes: 1 addition & 1 deletion api/core/model_runtime/model_providers/moonshot/llm/llm.py
@@ -252,7 +252,7 @@ def get_tool_call(tool_name: str):
# ignore sse comments
if chunk.startswith(":"):
continue
decoded_chunk = chunk.strip().lstrip("data: ").lstrip()
decoded_chunk = chunk.strip().removeprefix("data: ")
chunk_json = None
try:
chunk_json = json.loads(decoded_chunk)
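The change from lstrip("data: ") to removeprefix("data: ") above (repeated in the SSE parsers below) fixes a character-set pitfall: str.lstrip takes a set of characters, not a literal prefix. A minimal sketch with a made-up payload:

chunk = "data: data_points: [1, 2]"   # hypothetical SSE line whose payload starts with "data"

# lstrip strips any leading run of the characters {'d', 'a', 't', ':', ' '},
# so it also eats the "data" at the start of the payload itself.
print(chunk.strip().lstrip("data: ").lstrip())   # -> _points: [1, 2]

# removeprefix (Python 3.9+) removes the exact prefix once, or nothing at all.
print(chunk.strip().removeprefix("data: "))      # -> data_points: [1, 2]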
@@ -462,7 +462,7 @@ def get_tool_call(tool_call_id: str):
# ignore sse comments
if chunk.startswith(":"):
continue
decoded_chunk = chunk.strip().lstrip("data: ").lstrip()
decoded_chunk = chunk.strip().removeprefix("data: ")
if decoded_chunk == "[DONE]": # Some provider returns "data: [DONE]"
continue

2 changes: 1 addition & 1 deletion api/core/model_runtime/model_providers/stepfun/llm/llm.py
@@ -250,7 +250,7 @@ def get_tool_call(tool_name: str):
# ignore sse comments
if chunk.startswith(":"):
continue
decoded_chunk = chunk.strip().lstrip("data: ").lstrip()
decoded_chunk = chunk.strip().removeprefix("data: ")
chunk_json = None
try:
chunk_json = json.loads(decoded_chunk)
@@ -1,4 +1,4 @@
from .common import ChatRole
from .maas import MaasError, MaasService

__all__ = ["MaasService", "ChatRole", "MaasError"]
__all__ = ["ChatRole", "MaasError", "MaasService"]
17 changes: 15 additions & 2 deletions api/core/model_runtime/model_providers/wenxin/rerank/rerank.py
@@ -17,15 +17,25 @@ class WenxinRerank(_CommonWenxin):
def rerank(self, model: str, query: str, docs: list[str], top_n: Optional[int] = None):
access_token = self._get_access_token()
url = f"{self.api_bases[model]}?access_token={access_token}"

# For issue #11252
# for wenxin Rerank model top_n length should be equal or less than docs length
if top_n is not None and top_n > len(docs):
top_n = len(docs)
# for wenxin Rerank model, query should not be an empty string
if query == "":
query = " " # FIXME: this is a workaround for wenxin rerank model for better user experience.
try:
response = httpx.post(
url,
json={"model": model, "query": query, "documents": docs, "top_n": top_n},
headers={"Content-Type": "application/json"},
)
response.raise_for_status()
return response.json()
data = response.json()
# wenxin error handling
if "error_code" in data:
raise InternalServerError(data["error_msg"])
return data
except httpx.HTTPStatusError as e:
raise InternalServerError(str(e))

@@ -69,6 +79,9 @@ def _invoke(
results = wenxin_rerank.rerank(model, query, docs, top_n)

rerank_documents = []
if "results" not in results:
raise ValueError("results key not found in response")

for result in results["results"]:
index = result["index"]
if "document" in result:
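The new error_code check above matters because raise_for_status() only reacts to the HTTP status, and the Wenxin API evidently reports failures inside an otherwise successful response body (which is why the check sits after raise_for_status()). A sketch of that pattern; the error code and message below are hypothetical, not taken from the API docs:

class InternalServerError(Exception):
    """Stand-in for the model-runtime error class raised in the real module."""

# Hypothetical 200-OK response body carrying an API-level error.
data = {"error_code": 336003, "error_msg": "query should not be empty"}

try:
    # response.raise_for_status() would pass silently on a 200, so inspect the body:
    if "error_code" in data:
        raise InternalServerError(data["error_msg"])
except InternalServerError as exc:
    print(f"surfaced as InternalServerError: {exc}")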
7 changes: 3 additions & 4 deletions api/core/rag/datasource/retrieval_service.py
@@ -3,7 +3,6 @@

from flask import Flask, current_app

from configs import DifyConfig
from core.rag.data_post_processor.data_post_processor import DataPostProcessor
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.datasource.vdb.vector_factory import Vector
@@ -114,7 +113,7 @@ def retrieve(
query=query,
documents=all_documents,
score_threshold=score_threshold,
top_n=DifyConfig.RETRIEVAL_TOP_N or top_k,
top_n=top_k,
)

return all_documents
@@ -186,7 +185,7 @@ def embedding_search(
query=query,
documents=documents,
score_threshold=score_threshold,
top_n=DifyConfig.RETRIEVAL_TOP_N or len(documents),
top_n=len(documents),
)
)
else:
@@ -231,7 +230,7 @@ def full_text_index_search(
query=query,
documents=documents,
score_threshold=score_threshold,
top_n=DifyConfig.RETRIEVAL_TOP_N or len(documents),
top_n=len(documents),
)
)
else:
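Dropping DifyConfig.RETRIEVAL_TOP_N here is behaviour-preserving for anyone who never set it: the removed default was 0, and 0 is falsy, so the old "RETRIEVAL_TOP_N or top_k" expression already fell through to the per-request value. A quick sketch:

RETRIEVAL_TOP_N = 0   # the default removed from .env.example and DataSetConfig above
top_k = 4             # hypothetical per-request value

old_top_n = RETRIEVAL_TOP_N or top_k   # 0 is falsy, so this was already top_k
new_top_n = top_k                      # the simplified call sites in this file

assert old_top_n == new_top_n == 4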
11 changes: 5 additions & 6 deletions api/core/rag/datasource/vdb/oceanbase/oceanbase_vector.py
@@ -104,8 +104,7 @@ def _create_collection(self) -> None:
val = int(row[6])
vals.append(val)
if len(vals) == 0:
print("ob_vector_memory_limit_percentage not found in parameters.")
exit(1)
raise ValueError("ob_vector_memory_limit_percentage not found in parameters.")
if any(val == 0 for val in vals):
try:
self._client.perform_raw_text_sql("ALTER SYSTEM SET ob_vector_memory_limit_percentage = 30")
@@ -200,10 +199,10 @@ def init_vector(
return OceanBaseVector(
collection_name,
OceanBaseVectorConfig(
host=dify_config.OCEANBASE_VECTOR_HOST,
port=dify_config.OCEANBASE_VECTOR_PORT,
user=dify_config.OCEANBASE_VECTOR_USER,
host=dify_config.OCEANBASE_VECTOR_HOST or "",
port=dify_config.OCEANBASE_VECTOR_PORT or 0,
user=dify_config.OCEANBASE_VECTOR_USER or "",
password=(dify_config.OCEANBASE_VECTOR_PASSWORD or ""),
database=dify_config.OCEANBASE_VECTOR_DATABASE,
database=dify_config.OCEANBASE_VECTOR_DATABASE or "",
),
)
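The added fallbacks (or "" / or 0) read as typing fixes: the OCEANBASE_VECTOR_* settings are presumably Optional in dify_config, while OceanBaseVectorConfig expects concrete str/int values, so None has to be collapsed to a typed default before construction. A minimal sketch of the pattern with stand-in types:

from dataclasses import dataclass
from typing import Optional

@dataclass
class OBConfig:          # stand-in for OceanBaseVectorConfig
    host: str
    port: int

def build(host: Optional[str], port: Optional[int]) -> OBConfig:
    # Unset (None) settings collapse to typed defaults instead of leaking None
    # into fields annotated as str / int.
    return OBConfig(host=host or "", port=port or 0)

print(build(None, None))          # OBConfig(host='', port=0)
print(build("127.0.0.1", 2881))   # OBConfig(host='127.0.0.1', port=2881)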
42 changes: 21 additions & 21 deletions api/core/variables/__init__.py
@@ -32,32 +32,32 @@
)

__all__ = [
"IntegerVariable",
"FloatVariable",
"ObjectVariable",
"SecretVariable",
"StringVariable",
"ArrayAnyVariable",
"Variable",
"SegmentType",
"SegmentGroup",
"Segment",
"NoneSegment",
"NoneVariable",
"IntegerSegment",
"FloatSegment",
"ObjectSegment",
"ArrayAnySegment",
"StringSegment",
"ArrayStringVariable",
"ArrayNumberVariable",
"ArrayObjectVariable",
"ArraySegment",
"ArrayAnyVariable",
"ArrayFileSegment",
"ArrayFileVariable",
"ArrayNumberSegment",
"ArrayNumberVariable",
"ArrayObjectSegment",
"ArrayObjectVariable",
"ArraySegment",
"ArrayStringSegment",
"ArrayStringVariable",
"FileSegment",
"FileVariable",
"ArrayFileVariable",
"FloatSegment",
"FloatVariable",
"IntegerSegment",
"IntegerVariable",
"NoneSegment",
"NoneVariable",
"ObjectSegment",
"ObjectVariable",
"SecretVariable",
"Segment",
"SegmentGroup",
"SegmentType",
"StringSegment",
"StringVariable",
"Variable",
]
2 changes: 1 addition & 1 deletion api/core/workflow/callbacks/__init__.py
@@ -2,6 +2,6 @@
from .workflow_logging_callback import WorkflowLoggingCallback

__all__ = [
"WorkflowLoggingCallback",
"WorkflowCallback",
"WorkflowLoggingCallback",
]
2 changes: 1 addition & 1 deletion api/core/workflow/nodes/answer/__init__.py
@@ -1,4 +1,4 @@
from .answer_node import AnswerNode
from .entities import AnswerStreamGenerateRoute

__all__ = ["AnswerStreamGenerateRoute", "AnswerNode"]
__all__ = ["AnswerNode", "AnswerStreamGenerateRoute"]
2 changes: 1 addition & 1 deletion api/core/workflow/nodes/base/__init__.py
@@ -1,4 +1,4 @@
from .entities import BaseIterationNodeData, BaseIterationState, BaseNodeData
from .node import BaseNode

__all__ = ["BaseNode", "BaseNodeData", "BaseIterationNodeData", "BaseIterationState"]
__all__ = ["BaseIterationNodeData", "BaseIterationState", "BaseNode", "BaseNodeData"]
2 changes: 1 addition & 1 deletion api/core/workflow/nodes/end/__init__.py
@@ -1,4 +1,4 @@
from .end_node import EndNode
from .entities import EndStreamParam

__all__ = ["EndStreamParam", "EndNode"]
__all__ = ["EndNode", "EndStreamParam"]
4 changes: 2 additions & 2 deletions api/core/workflow/nodes/event/__init__.py
@@ -2,9 +2,9 @@
from .types import NodeEvent

__all__ = [
"ModelInvokeCompletedEvent",
"NodeEvent",
"RunCompletedEvent",
"RunRetrieverResourceEvent",
"RunStreamChunkEvent",
"NodeEvent",
"ModelInvokeCompletedEvent",
]
2 changes: 1 addition & 1 deletion api/core/workflow/nodes/http_request/__init__.py
@@ -1,4 +1,4 @@
from .entities import BodyData, HttpRequestNodeAuthorization, HttpRequestNodeBody, HttpRequestNodeData
from .node import HttpRequestNode

__all__ = ["HttpRequestNodeData", "HttpRequestNodeAuthorization", "HttpRequestNodeBody", "BodyData", "HttpRequestNode"]
__all__ = ["BodyData", "HttpRequestNode", "HttpRequestNodeAuthorization", "HttpRequestNodeBody", "HttpRequestNodeData"]
5 changes: 3 additions & 2 deletions api/core/workflow/nodes/iteration/iteration_node.py
@@ -116,7 +116,7 @@ def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
variable_pool.add([self.node_id, "item"], iterator_list_value[0])

# init graph engine
from core.workflow.graph_engine.graph_engine import GraphEngine, GraphEngineThreadPool
from core.workflow.graph_engine.graph_engine import GraphEngine

graph_engine = GraphEngine(
tenant_id=self.tenant_id,
@@ -162,7 +162,8 @@ def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
if self.node_data.is_parallel:
futures: list[Future] = []
q = Queue()
thread_pool = GraphEngineThreadPool(max_workers=self.node_data.parallel_nums, max_submit_count=100)
thread_pool = graph_engine.workflow_thread_pool_mapping[graph_engine.thread_pool_id]
thread_pool._max_workers = self.node_data.parallel_nums
for index, item in enumerate(iterator_list_value):
future: Future = thread_pool.submit(
self._run_single_iter_parallel,
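Rather than building a fresh GraphEngineThreadPool for every parallel iteration, the node above now borrows the pool the parent GraphEngine registered under its thread_pool_id and only adjusts its worker count, so nested iterations stop multiplying threads. A rough sketch of that registry pattern; names beyond those visible in the diff are illustrative:

from concurrent.futures import ThreadPoolExecutor

# Illustrative registry keyed by thread_pool_id, mirroring
# graph_engine.workflow_thread_pool_mapping in the hunk above.
workflow_thread_pool_mapping: dict[str, ThreadPoolExecutor] = {}

def pool_for(thread_pool_id: str, parallel_nums: int) -> ThreadPoolExecutor:
    pool = workflow_thread_pool_mapping.setdefault(
        thread_pool_id, ThreadPoolExecutor(max_workers=parallel_nums)
    )
    # The commit resizes the existing pool in place; note this pokes a private
    # attribute of ThreadPoolExecutor, exactly as the diff does.
    pool._max_workers = parallel_nums
    return pool

shared = pool_for("wf-123", 4)
assert pool_for("wf-123", 8) is shared   # same pool reused, only resized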
2 changes: 1 addition & 1 deletion api/core/workflow/nodes/question_classifier/__init__.py
@@ -1,4 +1,4 @@
from .entities import QuestionClassifierNodeData
from .question_classifier_node import QuestionClassifierNode

__all__ = ["QuestionClassifierNodeData", "QuestionClassifierNode"]
__all__ = ["QuestionClassifierNode", "QuestionClassifierNodeData"]
Empty file.
13 changes: 9 additions & 4 deletions api/extensions/ext_redis.py
@@ -1,3 +1,5 @@
from typing import Any, Union

import redis
from redis.cluster import ClusterNode, RedisCluster
from redis.connection import Connection, SSLConnection
@@ -46,11 +48,11 @@ def __getattr__(self, item):

def init_app(app: DifyApp):
global redis_client
connection_class = Connection
connection_class: type[Union[Connection, SSLConnection]] = Connection
if dify_config.REDIS_USE_SSL:
connection_class = SSLConnection

redis_params = {
redis_params: dict[str, Any] = {
"username": dify_config.REDIS_USERNAME,
"password": dify_config.REDIS_PASSWORD,
"db": dify_config.REDIS_DB,
@@ -60,6 +62,7 @@ def init_app(app: DifyApp):
}

if dify_config.REDIS_USE_SENTINEL:
assert dify_config.REDIS_SENTINELS is not None, "REDIS_SENTINELS must be set when REDIS_USE_SENTINEL is True"
sentinel_hosts = [
(node.split(":")[0], int(node.split(":")[1])) for node in dify_config.REDIS_SENTINELS.split(",")
]
@@ -74,11 +77,13 @@ def init_app(app: DifyApp):
master = sentinel.master_for(dify_config.REDIS_SENTINEL_SERVICE_NAME, **redis_params)
redis_client.initialize(master)
elif dify_config.REDIS_USE_CLUSTERS:
assert dify_config.REDIS_CLUSTERS is not None, "REDIS_CLUSTERS must be set when REDIS_USE_CLUSTERS is True"
nodes = [
ClusterNode(host=node.split(":")[0], port=int(node.split.split(":")[1]))
ClusterNode(host=node.split(":")[0], port=int(node.split(":")[1]))
for node in dify_config.REDIS_CLUSTERS.split(",")
]
redis_client.initialize(RedisCluster(startup_nodes=nodes, password=dify_config.REDIS_CLUSTERS_PASSWORD))
# FIXME: mypy error here, try to figure out how to fix it
redis_client.initialize(RedisCluster(startup_nodes=nodes, password=dify_config.REDIS_CLUSTERS_PASSWORD)) # type: ignore
else:
redis_params.update(
{
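The two new asserts document a precondition (and satisfy mypy): REDIS_SENTINELS and REDIS_CLUSTERS are optional settings that get .split(",") immediately afterwards, so they must be non-None once their respective modes are enabled. Both branches parse the same comma-separated host:port convention; a small sketch:

def parse_nodes(raw: str) -> list[tuple[str, int]]:
    # Parse a comma-separated "host:port" list, as used for sentinel and cluster nodes.
    return [(node.split(":")[0], int(node.split(":")[1])) for node in raw.split(",")]

# Hypothetical values for REDIS_SENTINELS / REDIS_CLUSTERS:
print(parse_nodes("10.0.0.1:26379,10.0.0.2:26379"))
# [('10.0.0.1', 26379), ('10.0.0.2', 26379)]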