Skip to content

Commit

Permalink
增加plan和function_call两种返回type (#618)
Browse files Browse the repository at this point in the history
* 增加plan和function_call两种返回type

* 为数据类型添加extra属性

* 为数据类型添加extra属性

* 为数据类型添加extra属性

---------

Co-authored-by: yepeiwen01 <[email protected]>
  • Loading branch information
peiwenYe and yepeiwen01 authored Nov 27, 2024
1 parent 72107d8 commit ffeaea6
Show file tree
Hide file tree
Showing 5 changed files with 106 additions and 22 deletions.
44 changes: 33 additions & 11 deletions python/core/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,28 +60,28 @@ def extract_values_to_dict(self):
return inputs


class Text(BaseModel):
class Text(BaseModel, extra='allow'):
info: str = Field(default="", description="具体文本内容")


class Code(BaseModel):
class Code(BaseModel, extra='allow'):
code: str = Field(default="", description="代码片段")


class Files(BaseModel):
class Files(BaseModel, extra='allow'):
filename: str = Field(default="", description="文件名")
url: str = Field(default="", description="文件url")


class Urls(BaseModel):
class Urls(BaseModel, extra='allow'):
url: str = Field(default="", description="链接地址")


class OralText(BaseModel):
class OralText(BaseModel, extra='allow'):
info: str = Field(default="", description="口语化文本内容")


class References(BaseModel):
class References(BaseModel, extra='allow'):
type: str = Field(default="", description="类型")
resource_type: str = Field(default="", description="资源类型")
icon: str = Field(default="", description="站点图标")
Expand All @@ -96,21 +96,35 @@ class References(BaseModel):
video_url: str = Field(default="", description="视频url")


class Image(BaseModel):
class Image(BaseModel, extra='allow'):
filename: str = Field(default="", description="图片名称")
url: str = Field(default="", description="图片url")
byte: Optional[bytes] = Field(default=b'', description="图片二进制数据")


class Chart(BaseModel):
class Chart(BaseModel, extra='allow'):
filename: str = Field(default="", description="图表名称")
url: str = Field(default="", description="图表url")


class Audio(BaseModel):
class Audio(BaseModel, extra='allow'):
filename: str = Field(default="", description="音频名称")
url: str = Field(default="", description="音频url")
byte: Optional[bytes] = Field(default=b'', description="音频二进制数据")


class PlanStep(BaseModel, extra='allow'):
    # A single step inside a Plan: the step's name and the arguments for it.
    # The model is declared with extra='allow', so keys beyond the two
    # declared fields are not rejected by pydantic.

    # Name of the step; empty string when not supplied.
    name: str = Field(
        description="step名",
        default="",
    )
    # Free-form argument mapping for the step; empty dict when not supplied.
    arguments: dict = Field(
        description="step参数",
        default={},
    )

class Plan(BaseModel, extra='allow'):
    # Payload used when an output's type is "plan": a textual detail plus an
    # ordered list of PlanStep entries. Declared with extra='allow', so
    # undeclared keys are not rejected by pydantic.

    # Human-readable description of the plan; empty string by default.
    detail: str = Field(
        description="计划详情",
        default="",
    )
    # Steps making up the plan; empty list by default.
    steps: list[PlanStep] = Field(
        description="步骤列表",
        default=[],
    )

class FunctionCall(BaseModel, extra='allow'):
    # Payload used when an output's type is "function_call": the thought
    # text, the tool name, and the arguments for the call. Declared with
    # extra='allow', so undeclared keys are not rejected by pydantic.

    # Thought text attached to the call; empty string by default.
    thought: str = Field(
        description="思考结果",
        default="",
    )
    # Name of the tool being called; empty string by default.
    name: str = Field(
        description="工具名",
        default="",
    )
    # Argument mapping for the tool call; empty dict by default.
    arguments: dict = Field(
        description="参数列表",
        default={},
    )


class Content(BaseModel):
Expand All @@ -126,7 +140,7 @@ class Content(BaseModel):
description="耗时、性能、内存等trace及debug所需信息")
type: str = Field(default="text",
description="代表event 类型,包括 text、code、files、urls、oral_text、references、image、chart、audio该字段的取值决定了下面text字段的内容结构")
text: Union[Text, Code, Files, Urls, OralText, References, Image, Chart, Audio] = Field(default=Text,
text: Union[Text, Code, Files, Urls, OralText, References, Image, Chart, Audio, Plan, FunctionCall] = Field(default=Text,
description="代表当前 event 元素的内容,每一种 event 对应的 text 结构固定")

@field_validator('text', mode='before')
Expand All @@ -149,6 +163,10 @@ def set_text(cls, v, values, **kwargs):
return Chart(**v)
elif values.data['type'] == 'audio':
return Audio(**v)
elif values.data['type'] == 'plan':
return Plan(**v)
elif values.data['type'] == 'function_call':
return FunctionCall(**v)
else:
raise ValueError(f"Invalid value for 'type': {values['type']}")

Expand Down Expand Up @@ -514,9 +532,13 @@ def create_output(self, type, text, role="tool", name="", visible_scope="all", r
key_list = ["filename", "url"]
elif type == "audio":
key_list = ["filename", "url"]
elif type == "plan":
key_list = ["detail", "steps"]
elif type == "function_call":
key_list = ["thought", "name", "arguments"]
else:
raise ValueError("Unknown type: {}".format(type))
assert all(key in text for key in key_list), "all keys:{} must be included in the text field".format(key_list)
# assert all(key in text for key in key_list), "all keys:{} must be included in the text field".format(key_list)
else:
raise ValueError("text must be str or dict")

Expand Down
60 changes: 59 additions & 1 deletion python/tests/component_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,62 @@
"required": ["filename", "url"]
}


# JSON schema for one output item whose "type" is "plan".
# Start from a deep copy of the shared base schema so the overrides below
# do not leak back into base_item_schema.
plan_schema = copy.deepcopy(base_item_schema)
plan_schema["$schema"] = "plan_schema"
plan_schema["properties"].update({
    # The item type is pinned to the single literal "plan".
    "type": {"type": "string", "enum": ["plan"]},
    # "text" carries the plan body: a detail string and a list of steps.
    "text": {
        "type": "object",
        "properties": {
            "detail": {"type": "string"},
            "steps": {
                "type": "array",
                # Every step needs a name and an open-ended arguments object.
                "items": {
                    "type": "object",
                    "properties": {
                        "name": {"type": "string"},
                        "arguments": {
                            "type": "object",
                            "additionalProperties": True,
                        },
                    },
                    "required": ["name", "arguments"],
                },
            },
        },
        "required": ["detail", "steps"],
    },
})


# JSON schema for one output item whose "type" is "function_call".
# Start from a deep copy of the shared base schema so the overrides below
# do not leak back into base_item_schema.
function_call_schema = copy.deepcopy(base_item_schema)
function_call_schema["$schema"] = "function_call_schema"
function_call_schema["properties"].update({
    # The item type is pinned to the single literal "function_call".
    "type": {"type": "string", "enum": ["function_call"]},
    # "text" carries the call: thought, tool name, and open-ended arguments.
    "text": {
        "type": "object",
        "properties": {
            "thought": {"type": "string"},
            "name": {"type": "string"},
            "arguments": {
                "type": "object",
                "additionalProperties": True,
            },
        },
        "required": ["thought", "name", "arguments"],
    },
})


type_to_json_schemas = {
"text": text_schema,
'code': code_schema,
Expand All @@ -256,5 +312,7 @@
"references": references_schema,
"image": image_schema,
"chart": chart_schema,
"audio": audio_schema
"audio": audio_schema,
"plan": plan_schema,
"function_call": function_call_schema,
}
2 changes: 1 addition & 1 deletion python/tests/component_tool_eval_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from appbuilder.tests.component_schemas import text_schema, url_schema, image_schema, code_schema, file_schema, oral_text_schema, references_schema, chart_schema, audio_schema
from appbuilder.tests.component_schemas import text_schema, url_schema, image_schema, code_schema, file_schema, oral_text_schema, references_schema, chart_schema, audio_schema, plan_schema, function_call_schema

components_tool_eval_output_json_maps = {
"AnimalRecognition": [text_schema],
Expand Down
2 changes: 1 addition & 1 deletion python/tests/test_all_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def write_error_data(txt_file_path, error_df,error_stats):
file.write(f"错误信息: {error}, 出现次数: {count}\n")
print(f"\n错误信息已写入: {txt_file_path}")

@unittest.skipUnless(os.getenv("TEST_CASE", "UNKNOWN") == "CPU_PARALLEL", "")
@unittest.skipUnless(os.getenv("TEST_CASE", "UNKNOWN") == "CPU_SERIAL", "")
class TestComponentManifestsAndToolEval(unittest.TestCase):
def setUp(self) -> None:
self.all_components = get_all_components()
Expand Down
20 changes: 12 additions & 8 deletions python/tests/test_base_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ def test_valid_output_with_dict(self):
output6 = self.component.create_output(type="image", text={"filename": "file.png", "url": "http://www.baidu.com"})
output7 = self.component.create_output(type="chart", text={"filename": "file.jpg", "url": "http://www.baidu.com"})
output8 = self.component.create_output(type="audio", text={"filename": "file.mp3", "url": "http://www.baidu.com"})
output9 = self.component.create_output(type="plan", text={"detail": "hello", "steps":[{"name": "1", "arguments": {"query": "a", "chat_history": "world"}}]})
output10 = self.component.create_output(type="function_call", text={"thought": "hello", "name": "AppBuilder", "arguments": {"query": "a", "chat_history": "world"}})
self.assertIsInstance(output1, ComponentOutput)
self.assertIsInstance(output2, ComponentOutput)
self.assertIsInstance(output3, ComponentOutput)
Expand All @@ -36,6 +38,8 @@ def test_valid_output_with_dict(self):
self.assertIsInstance(output6, ComponentOutput)
self.assertIsInstance(output7, ComponentOutput)
self.assertIsInstance(output8, ComponentOutput)
self.assertIsInstance(output9, ComponentOutput)
self.assertIsInstance(output10, ComponentOutput)

def test_valid_output_type_with_same_key(self):
output1 = self.component.create_output(type="urls", text={"url": "http://www.baidu.com"})
Expand All @@ -52,14 +56,14 @@ def test_valid_output_type_with_same_key(self):
def test_invalid_output_type_json(self):
with self.assertRaises(ValueError):
output = self.component.create_output(type="json", text="")
with self.assertRaises(AssertionError):
output = self.component.create_output(type="files", text={})
with self.assertRaises(AssertionError):
output = self.component.create_output(type="references", text={"info": "text"})
with self.assertRaises(AssertionError):
output = self.component.create_output(type="image", text={"url": "https://example.com/img"})
with self.assertRaises(AssertionError):
output = self.component.create_output(type="chart", text={"url": "https://example.com/chart"})
# with self.assertRaises(AssertionError):
# output = self.component.create_output(type="files", text={})
# with self.assertRaises(AssertionError):
# output = self.component.create_output(type="references", text={"info": "text"})
# with self.assertRaises(AssertionError):
# output = self.component.create_output(type="image", text={"url": "https://example.com/img"})
# with self.assertRaises(AssertionError):
# output = self.component.create_output(type="chart", text={"url": "https://example.com/chart"})


if __name__ == '__main__':
Expand Down

0 comments on commit ffeaea6

Please sign in to comment.