
Commit

draft
belkka committed Feb 6, 2023
1 parent 0d22892 commit f04f676
Showing 4 changed files with 257 additions and 21 deletions.
81 changes: 81 additions & 0 deletions pymilvus/asyncio/client/grpc_handler.py
@@ -4,6 +4,11 @@
    AbstractGrpcHandler,
    Status,
    MilvusException,
    retry_on_rpc_failure,
    check_pass_param,
    get_consistency_level,
    ts_utils,
    Prepare,
)


@@ -21,3 +26,79 @@ async def _channel_ready(self):

    def _header_adder_interceptor(self, header, value):
        raise NotImplementedError  # TODO

    # @retry_on_rpc_failure()
    async def describe_collection(self, collection_name, timeout=None, **kwargs):
        raise NotImplementedError("TODO")
        check_pass_param(collection_name=collection_name)
        request = Prepare.describe_collection_request(collection_name)
        rf = self._stub.DescribeCollection.future(request, timeout=timeout)
        response = rf.result()
        status = response.status

        if status.error_code == 0:
            return CollectionSchema(raw=response).dict()

        raise DescribeCollectionException(status.error_code, status.reason)

    async def _execute_search_requests(self, requests, timeout=None, **kwargs):
        raise NotImplementedError("TODO")
        auto_id = kwargs.get("auto_id", True)

        try:
            if kwargs.get("_async", False):
                futures = []
                for request in requests:
                    ft = self._stub.Search.future(request, timeout=timeout)
                    futures.append(ft)
                func = kwargs.get("_callback", None)
                return ChunkedSearchFuture(futures, func, auto_id)

            raws = []
            for request in requests:
                response = self._stub.Search(request, timeout=timeout)

                if response.status.error_code != 0:
                    raise MilvusException(response.status.error_code, response.status.reason)

                raws.append(response)
            round_decimal = kwargs.get("round_decimal", -1)
            return ChunkedQueryResult(raws, auto_id, round_decimal)

        except Exception as pre_err:
            if kwargs.get("_async", False):
                return SearchFuture(None, None, True, pre_err)
            raise pre_err


    # @retry_on_rpc_failure(retry_on_deadline=False)
    async def search(self, collection_name, data, anns_field, param, limit,
                     expression=None, partition_names=None, output_fields=None,
                     round_decimal=-1, timeout=None, schema=None, **kwargs):
        check_pass_param(
            limit=limit,
            round_decimal=round_decimal,
            anns_field=anns_field,
            search_data=data,
            partition_name_array=partition_names,
            output_fields=output_fields,
            travel_timestamp=kwargs.get("travel_timestamp", 0),
            guarantee_timestamp=kwargs.get("guarantee_timestamp", 0)
        )

        if schema is None:
            schema = await self.describe_collection(collection_name, timeout=timeout, **kwargs)

        consistency_level = schema["consistency_level"]
        # overwrite the consistency level defined when user created the collection
        consistency_level = get_consistency_level(kwargs.get("consistency_level", consistency_level))

        ts_utils.construct_guarantee_ts(consistency_level, collection_name, kwargs)

        requests = Prepare.search_requests_with_expr(collection_name, data, anns_field, param, limit, schema,
                                                     expression, partition_names, output_fields, round_decimal,
                                                     **kwargs)

        auto_id = schema["auto_id"]
        return await self._execute_search_requests(requests, timeout, round_decimal=round_decimal, auto_id=auto_id, **kwargs)  # TODO
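The method bodies above appear to be carried over from the synchronous handler and are gated behind `raise NotImplementedError("TODO")`. For reference only, an awaited variant of `describe_collection` might look like the following sketch, assuming the handler eventually holds a grpc.aio stub (the `self._async_stub` name is hypothetical, not part of this commit) and keeps the sync handler's `CollectionSchema` and `DescribeCollectionException`:

    # Sketch only, not part of this commit: with grpc.aio the stub call itself is
    # awaitable, so no .future()/.result() pair is needed.
    async def describe_collection(self, collection_name, timeout=None, **kwargs):
        check_pass_param(collection_name=collection_name)
        request = Prepare.describe_collection_request(collection_name)
        response = await self._async_stub.DescribeCollection(request, timeout=timeout)
        if response.status.error_code == 0:
            return CollectionSchema(raw=response).dict()
        raise DescribeCollectionException(response.status.error_code, response.status.reason)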

147 changes: 147 additions & 0 deletions pymilvus/asyncio/orm/collection.py
@@ -0,0 +1,147 @@
from ...orm.collection import AbstractCollection, DataTypeNotMatchException, ExceptionsMessage
# SearchResult/SearchFuture are referenced in search() below; assumed to live with the sync ORM helpers
from ...orm.search import SearchResult
from ...orm.future import SearchFuture
from .connections import connections


class Collection(AbstractCollection):
    connections = connections

    async def search(self, data, anns_field, param, limit, expr=None, partition_names=None,
                     output_fields=None, timeout=None, round_decimal=-1, **kwargs):
        """ Conducts a vector similarity search with an optional boolean expression as filter.

        Args:
            data (``List[List[float]]``): The vectors of search data. The length of data is the number
                of queries (nq), and the dim of every vector in data must equal the dim of the
                collection's vector field.
            anns_field (``str``): The name of the collection's vector field to search on.
            param (``dict[str, Any]``):
                The parameters of search. The following are valid keys of param.

                * *nprobe*, *ef*, *search_k*, etc
                    Corresponding search params for a certain index.
                * *metric_type* (``str``)
                    The similarity metric type; the value must be of type str.
                * *offset* (``int``, optional)
                    Offset for pagination.
                * *limit* (``int``, optional)
                    Limit for the search results and pagination.

                example for param::

                    {
                        "nprobe": 128,
                        "metric_type": "L2",
                        "offset": 10,
                        "limit": 10,
                    }

            limit (``int``): The max number of returned records, also known as `topk`.
            expr (``str``): The boolean expression used to filter attributes. Defaults to None.

                example for expr::

                    "id_field >= 0", "id_field in [1, 2, 3, 4]"

            partition_names (``List[str]``, optional): The names of partitions to search on. Defaults to None.
            output_fields (``List[str]``, optional):
                The names of fields to return in the search result. Can only get scalar fields.
            round_decimal (``int``, optional): The specified number of decimal places of returned distance.
                Defaults to -1, meaning the returned distance is not rounded.
            timeout (``float``, optional): A duration of time in seconds to allow for the RPC. Defaults to None.
                If timeout is set to None, the client keeps waiting until the server responds or an error occurs.
            **kwargs (``dict``): Optional search params.

                * *_async* (``bool``, optional)
                    Indicates whether to invoke asynchronously.
                    Returns a SearchFuture if True, else returns results from the server directly.
                * *_callback* (``function``, optional)
                    The callback function invoked after the server responds successfully.
                    It functions only if _async is set to True.
                * *consistency_level* (``str/int``, optional)
                    Which consistency level to use when searching in the collection.
                    Options of consistency level: Strong, Bounded, Eventually, Session, Customized.
                    Note: this parameter overwrites the consistency level specified when the collection was
                    created; if no consistency level is specified here, the search uses the collection's
                    consistency level.
                * *guarantee_timestamp* (``int``, optional)
                    Instructs Milvus to see all operations performed before this timestamp.
                    By default Milvus searches all operations performed to date.
                    Note: only valid in Customized consistency level.
                * *graceful_time* (``int``, optional)
                    Search uses (current_timestamp - graceful_time) as the `guarantee_timestamp`.
                    Defaults to 5s.
                    Note: only valid in Bounded consistency level.
                * *travel_timestamp* (``int``, optional)
                    A specific timestamp to get results based on a data view at that point in time.

        Returns:
            SearchResult:
                Returns ``SearchResult`` if `_async` is False, otherwise ``SearchFuture``.

        .. _Metric type documentations:
            https://milvus.io/docs/v2.2.x/metric.md
        .. _Index documentations:
            https://milvus.io/docs/v2.2.x/index.md
        .. _How guarantee ts works:
            https://github.com/milvus-io/milvus/blob/master/docs/developer_guides/how-guarantee-ts-works.md

        Raises:
            MilvusException: If anything goes wrong.

        Examples:
            >>> from pymilvus import connections, Collection, FieldSchema, CollectionSchema, DataType
            >>> import random
            >>> connections.connect()
            >>> schema = CollectionSchema([
            ...     FieldSchema("film_id", DataType.INT64, is_primary=True),
            ...     FieldSchema("films", dtype=DataType.FLOAT_VECTOR, dim=2)
            ... ])
            >>> collection = Collection("test_collection_search", schema)
            >>> # insert
            >>> data = [
            ...     [i for i in range(10)],
            ...     [[random.random() for _ in range(2)] for _ in range(10)],
            ... ]
            >>> collection.insert(data)
            >>> collection.create_index("films", {"index_type": "FLAT", "metric_type": "L2", "params": {}})
            >>> collection.load()
            >>> # search
            >>> search_param = {
            ...     "data": [[1.0, 1.0]],
            ...     "anns_field": "films",
            ...     "param": {"metric_type": "L2", "offset": 1},
            ...     "limit": 2,
            ...     "expr": "film_id > 0",
            ... }
            >>> res = collection.search(**search_param)
            >>> assert len(res) == 1
            >>> hits = res[0]
            >>> assert len(hits) == 2
            >>> print(f"- Total hits: {len(hits)}, hits ids: {hits.ids} ")
            - Total hits: 2, hits ids: [8, 5]
            >>> print(f"- Top1 hit id: {hits[0].id}, distance: {hits[0].distance}, score: {hits[0].score} ")
            - Top1 hit id: 8, distance: 0.10143111646175385, score: 0.10143111646175385
        """
        if expr is not None and not isinstance(expr, str):
            raise DataTypeNotMatchException(message=ExceptionsMessage.ExprType % type(expr))

        conn = self._get_connection()
        # the asyncio gRPC handler's search() is a coroutine, so it has to be awaited here
        res = await conn.search(self._name, data, anns_field, param, limit, expr,
                                partition_names, output_fields, round_decimal, timeout=timeout,
                                schema=self._schema_dict, **kwargs)
        if kwargs.get("_async", False):
            return SearchFuture(res)
        return SearchResult(res)
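For reference, the call pattern this coroutine is aiming at would look roughly as follows once the asyncio handler stops raising NotImplementedError — a usage sketch only, assuming the async connection plumbing from this draft is completed:

    import asyncio
    from pymilvus.asyncio.orm.collection import Collection

    async def main():
        # connection setup omitted; the asyncio connections registry is still a stub in this commit
        collection = Collection("test_collection_search")
        res = await collection.search(
            data=[[1.0, 1.0]],
            anns_field="films",
            param={"metric_type": "L2"},
            limit=2,
            expr="film_id > 0",
        )
        print(res[0].ids)

    asyncio.run(main())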
33 changes: 22 additions & 11 deletions pymilvus/orm/collection.py
@@ -15,7 +15,7 @@
from typing import List
import pandas

from .connections import connections
from .connections import connections, AbstractConnections
from .schema import (
    CollectionSchema,
    FieldSchema,
@@ -46,8 +46,9 @@
from ..client.configs import DefaultConfigs


class AbstractCollection:
    connections: AbstractConnections

class Collection:
    def __init__(self, name: str, schema: CollectionSchema=None, using: str="default", shards_num: int=2, **kwargs):
        """ Constructs a collection by name, schema and other parameters.
@@ -91,6 +92,24 @@ def __init__(self, name: str, schema: CollectionSchema=None, using: str="default
        self._using = using
        self._shards_num = shards_num
        self._kwargs = kwargs
        self._prepare(schema, **kwargs)

    def _prepare(self, schema: CollectionSchema=None, **kwargs):
        raise NotImplementedError

    def _get_connection(self):
        return self.connections._fetch_handler(self._using)

    @property
    def name(self) -> str:
        """str: the name of the collection. """
        return self._name


class Collection(AbstractCollection):
    connections = connections

    def _prepare(self, schema: CollectionSchema=None, **kwargs):
        conn = self._get_connection()

        has = conn.has_collection(self._name, **kwargs)
@@ -113,7 +132,7 @@ def __init__(self, name: str, schema: CollectionSchema=None, using: str="default

        else:
            if schema is None:
                raise SchemaNotReadyException(message=ExceptionsMessage.CollectionNotExistNoSchema % name)
                raise SchemaNotReadyException(message=ExceptionsMessage.CollectionNotExistNoSchema % self._name)
            if isinstance(schema, CollectionSchema):
                check_schema(schema)
                consistency_level = get_consistency_level(kwargs.get("consistency_level", DEFAULT_CONSISTENCY_LEVEL))
@@ -139,9 +158,6 @@ def __repr__(self):
            r.append(s.format(k, v))
        return "".join(r)

    def _get_connection(self):
        return connections._fetch_handler(self._using)

    @classmethod
    def construct_from_dataframe(cls, name, dataframe, **kwargs):
        if dataframe is None:
@@ -212,11 +228,6 @@ def description(self) -> str:
        """str: a text description of the collection. """
        return self._schema.description

    @property
    def name(self) -> str:
        """str: the name of the collection. """
        return self._name

    @property
    def is_empty(self) -> bool:
        """bool: whether the collection is empty or not."""
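Taken together, this hunk turns the connection registry and the schema bootstrap into overridable hooks: a subclass only supplies a `connections` object and a `_prepare()` implementation, which is how the asyncio Collection above plugs in its own registry. A minimal sketch of the extension point (the subclass name and the dummy registry are hypothetical, not part of this commit):

    from pymilvus.orm.collection import AbstractCollection

    class _DummyConnections:
        # stand-in registry for illustration; real code uses the pymilvus connections module
        def _fetch_handler(self, alias):
            raise RuntimeError(f"no handler registered for alias {alias!r}")

    class MinimalCollection(AbstractCollection):
        connections = _DummyConnections()

        def _prepare(self, schema=None, **kwargs):
            # skip server-side checks; just remember the schema that was passed in
            self._schema = schema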
17 changes: 7 additions & 10 deletions tests/test_collection.py
@@ -10,14 +10,13 @@


class TestCollections:

    # @pytest.fixture(scope="function",)
    # def collection(self):
    #     name = gen_collection_name()
    #     schema = gen_schema()
    #     yield Collection(name, schema=schema)
    #     if connections.get_connection().has_collection(name):
    #         connections.get_connection().drop_collection(name)
    @pytest.fixture(scope="function")
    def collection(self):
        name = gen_collection_name()
        schema = gen_schema()
        yield Collection(name, schema=schema)
        if connections.get_connection().has_collection(name):
            connections.get_connection().drop_collection(name)

    def test_collection_by_DataFrame(self):
        from pymilvus import Collection
@@ -54,11 +53,9 @@ def test_collection_by_DataFrame(self):
        with mock.patch(f"{prefix}.close", return_value=None):
            connections.disconnect("default")

    @pytest.mark.xfail
    def test_constructor(self, collection):
        assert type(collection) is Collection

    @pytest.mark.xfail
    def test_construct_from_dataframe(self):
        assert type(Collection.construct_from_dataframe(gen_collection_name(), gen_pd_data(default_nb), primary_field="int64")[0]) is Collection

