-
Notifications
You must be signed in to change notification settings - Fork 323
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
4 changed files
with
257 additions
and
21 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,147 @@ | ||
from ...orm.collection import AbstractCollection, DataTypeNotMatchException, ExceptionsMessage | ||
from .connections import connections | ||
|
||
|
||
class Collection(AbstractCollection): | ||
connections = connections | ||
|
||
async def search(self, data, anns_field, param, limit, expr=None, partition_names=None, | ||
output_fields=None, timeout=None, round_decimal=-1, **kwargs): | ||
""" Conducts a vector similarity search with an optional boolean expression as filter. | ||
Args: | ||
data (``List[List[float]]``): The vectors of search data. | ||
the length of data is number of query (nq), and the dim of every vector in data must be equal to | ||
the vector field's of collection. | ||
anns_field (``str``): The name of the vector field used to search of collection. | ||
param (``dict[str, Any]``): | ||
The parameters of search. The followings are valid keys of param. | ||
* *nprobe*, *ef*, *search_k*, etc | ||
Corresponding search params for a certain index. | ||
* *metric_type* (``str``) | ||
similar metricy types, the value must be of type str. | ||
* *offset* (``int``, optional) | ||
offset for pagination. | ||
* *limit* (``int``, optional) | ||
limit for the search results and pagination. | ||
example for param:: | ||
{ | ||
"nprobe": 128, | ||
"metric_type": "L2", | ||
"offset": 10, | ||
"limit": 10, | ||
} | ||
limit (``int``): The max number of returned record, also known as `topk`. | ||
expr (``str``): The boolean expression used to filter attribute. Default to None. | ||
example for expr:: | ||
"id_field >= 0", "id_field in [1, 2, 3, 4]" | ||
partition_names (``List[str]``, optional): The names of partitions to search on. Default to None. | ||
output_fields (``List[str]``, optional): | ||
The name of fields to return in the search result. Can only get scalar fields. | ||
round_decimal (``int``, optional): The specified number of decimal places of returned distance. | ||
Defaults to -1 means no round to returned distance. | ||
timeout (``float``, optional): A duration of time in seconds to allow for the RPC. Defaults to None. | ||
If timeout is set to None, the client keeps waiting until the server responds or an error occurs. | ||
**kwargs (``dict``): Optional search params | ||
* *_async* (``bool``, optional) | ||
Indicate if invoke asynchronously. | ||
Returns a SearchFuture if True, else returns results from server directly. | ||
* *_callback* (``function``, optional) | ||
The callback function which is invoked after server response successfully. | ||
It functions only if _async is set to True. | ||
* *consistency_level* (``str/int``, optional) | ||
Which consistency level to use when searching in the collection. | ||
Options of consistency level: Strong, Bounded, Eventually, Session, Customized. | ||
Note: this parameter will overwrite the same parameter specified when user created the collection, | ||
if no consistency level was specified, search will use the consistency level when you create the | ||
collection. | ||
* *guarantee_timestamp* (``int``, optional) | ||
Instructs Milvus to see all operations performed before this timestamp. | ||
By default Milvus will search all operations performed to date. | ||
Note: only valid in Customized consistency level. | ||
* *graceful_time* (``int``, optional) | ||
Search will use the (current_timestamp - the graceful_time) as the | ||
`guarantee_timestamp`. By default with 5s. | ||
Note: only valid in Bounded consistency level | ||
* *travel_timestamp* (``int``, optional) | ||
A specific timestamp to get results based on a data view at. | ||
Returns: | ||
SearchResult: | ||
Returns ``SearchResult`` if `_async` is False , otherwise ``SearchFuture`` | ||
.. _Metric type documentations: | ||
https://milvus.io/docs/v2.2.x/metric.md | ||
.. _Index documentations: | ||
https://milvus.io/docs/v2.2.x/index.md | ||
.. _How guarantee ts works: | ||
https://github.com/milvus-io/milvus/blob/master/docs/developer_guides/how-guarantee-ts-works.md | ||
Raises: | ||
MilvusException: If anything goes wrong | ||
Examples: | ||
>>> from pymilvus import connections, Collection, FieldSchema, CollectionSchema, DataType | ||
>>> import random | ||
>>> connections.connect() | ||
>>> schema = CollectionSchema([ | ||
... FieldSchema("film_id", DataType.INT64, is_primary=True), | ||
... FieldSchema("films", dtype=DataType.FLOAT_VECTOR, dim=2) | ||
... ]) | ||
>>> collection = Collection("test_collection_search", schema) | ||
>>> # insert | ||
>>> data = [ | ||
... [i for i in range(10)], | ||
... [[random.random() for _ in range(2)] for _ in range(10)], | ||
... ] | ||
>>> collection.insert(data) | ||
>>> collection.create_index("films", {"index_type": "FLAT", "metric_type": "L2", "params": {}}) | ||
>>> collection.load() | ||
>>> # search | ||
>>> search_param = { | ||
... "data": [[1.0, 1.0]], | ||
... "anns_field": "films", | ||
... "param": {"metric_type": "L2", "offset": 1}, | ||
... "limit": 2, | ||
... "expr": "film_id > 0", | ||
... } | ||
>>> res = collection.search(**search_param) | ||
>>> assert len(res) == 1 | ||
>>> hits = res[0] | ||
>>> assert len(hits) == 2 | ||
>>> print(f"- Total hits: {len(hits)}, hits ids: {hits.ids} ") | ||
- Total hits: 2, hits ids: [8, 5] | ||
>>> print(f"- Top1 hit id: {hits[0].id}, distance: {hits[0].distance}, score: {hits[0].score} ") | ||
- Top1 hit id: 8, distance: 0.10143111646175385, score: 0.10143111646175385 | ||
""" | ||
if expr is not None and not isinstance(expr, str): | ||
raise DataTypeNotMatchException(message=ExceptionsMessage.ExprType % type(expr)) | ||
|
||
conn = self._get_connection() | ||
res = conn.search(self._name, data, anns_field, param, limit, expr, | ||
partition_names, output_fields, round_decimal, timeout=timeout, | ||
schema=self._schema_dict, **kwargs) | ||
if kwargs.get("_async", False): | ||
return SearchFuture(res) | ||
return SearchResult(res) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters