Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Asyncio support #1273

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file added pymilvus/asyncio/__init__.py
Empty file.
Empty file.
23 changes: 23 additions & 0 deletions pymilvus/asyncio/client/grpc_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import grpc.aio

from ...client.grpc_handler import (
AbstractGrpcHandler,
Status,
MilvusException,
)


class GrpcHandler(AbstractGrpcHandler[grpc.aio.Channel]):
_insecure_channel = grpc.aio.insecure_channel
_secure_channel = grpc.aio.secure_channel

async def _channel_ready(self):
if self._channel is None:
raise MilvusException(
Status.CONNECT_FAILED,
'No channel in handler, please setup grpc channel first',
)
await self._channel.channel_ready()

def _header_adder_interceptor(self, header, value):
raise NotImplementedError # TODO
Empty file.
25 changes: 25 additions & 0 deletions pymilvus/asyncio/orm/connections.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import copy
import typing

from ...orm.connections import AbstractConnections
from ..client.grpc_handler import GrpcHandler as AsyncGrpcHandler


# pylint: disable=W0236
class Connections(AbstractConnections[AsyncGrpcHandler, typing.Awaitable[None]]):
async def _disconnect(self, alias: str):
if alias in self._connected_alias:
await self._connected_alias.pop(alias).close()

async def _connect(self, alias, **kwargs):
gh = AsyncGrpcHandler(**kwargs)

await gh._channel_ready()
kwargs.pop('password')
kwargs.pop('secure', None)

self._connected_alias[alias] = gh
self._alias[alias] = copy.deepcopy(kwargs)


connections = Connections()
61 changes: 43 additions & 18 deletions pymilvus/client/grpc_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import json
import copy
import base64
import typing
from urllib import parse

import grpc
Expand Down Expand Up @@ -66,7 +67,14 @@
from ..decorators import retry_on_rpc_failure


class GrpcHandler:
GrpcChannelT = typing.TypeVar('GrpcChannelT', grpc.Channel, grpc.aio.Channel)


class AbstractGrpcHandler(typing.Generic[GrpcChannelT]):
_insecure_channel: typing.Callable[..., GrpcChannelT]
_secure_channel: typing.Callable[..., GrpcChannelT]
_channel: typing.Optional[GrpcChannelT]

def __init__(self, uri=config.GRPC_URI, host="", port="", channel=None, **kwargs):
self._stub = None
self._channel = channel
Expand Down Expand Up @@ -108,25 +116,17 @@ def __enter__(self):
def __exit__(self, exc_type, exc_val, exc_tb):
pass

def _wait_for_channel_ready(self, timeout=10):
if self._channel is not None:
try:
grpc.channel_ready_future(self._channel).result(timeout=timeout)
return
except grpc.FutureTimeoutError as e:
raise MilvusException(Status.CONNECT_FAILED,
f'Fail connecting to server on {self._address}. Timeout') from e

raise MilvusException(Status.CONNECT_FAILED, 'No channel in handler, please setup grpc channel first')

def close(self):
self._channel.close()
return self._channel.close()

def _header_adder_interceptor(self, header, value):
raise NotImplementedError("this is abstract method")

def _setup_authorization_interceptor(self, user, password):
if user and password:
authorization = base64.b64encode(f"{user}:{password}".encode('utf-8'))
key = "authorization"
self._authorization_interceptor = interceptor.header_adder_interceptor(key, authorization)
self._authorization_interceptor = self._header_adder_interceptor(key, authorization)

def _setup_grpc_channel(self):
""" Create a ddl grpc channel """
Expand All @@ -137,7 +137,7 @@ def _setup_grpc_channel(self):
('grpc.keepalive_time_ms', 55000),
]
if not self._secure:
self._channel = grpc.insecure_channel(
self._channel = self._insecure_channel(
self._address,
options=opts,
)
Expand All @@ -160,7 +160,7 @@ def _setup_grpc_channel(self):
else:
creds = grpc.ssl_channel_credentials(root_certificates=None, private_key=None,
certificate_chain=None)
self._channel = grpc.secure_channel(
self._channel = self._secure_channel(
self._address,
creds,
options=opts
Expand All @@ -170,11 +170,11 @@ def _setup_grpc_channel(self):
if self._authorization_interceptor:
self._final_channel = grpc.intercept_channel(self._final_channel, self._authorization_interceptor)
if self._log_level:
log_level_interceptor = interceptor.header_adder_interceptor("log_level", self._log_level)
log_level_interceptor = self._header_adder_interceptor("log_level", self._log_level)
self._final_channel = grpc.intercept_channel(self._final_channel, log_level_interceptor)
self._log_level = None
if self._request_id:
request_id_interceptor = interceptor.header_adder_interceptor("client_request_id", self._request_id)
request_id_interceptor = self._header_adder_interceptor("client_request_id", self._request_id)
self._final_channel = grpc.intercept_channel(self._final_channel, request_id_interceptor)
self._request_id = None
self._stub = milvus_pb2_grpc.MilvusServiceStub(self._final_channel)
Expand All @@ -192,6 +192,31 @@ def server_address(self):
""" Server network address """
return self._address


class GrpcHandler(AbstractGrpcHandler[grpc.Channel]):
_insecure_channel = grpc.insecure_channel
_secure_channel = grpc.secure_channel

def _wait_for_channel_ready(self, timeout=10):
if self._channel is None:
raise MilvusException(
Status.CONNECT_FAILED,
'No channel in handler, please setup grpc channel first',
)

try:
grpc.channel_ready_future(self._channel).result(timeout=timeout)
except grpc.FutureTimeoutError as exc:
raise MilvusException(
Status.CONNECT_FAILED,
f'Fail connecting to server on {self._address}. Timeout'
) from exc

def _header_adder_interceptor(self, header, value):
return interceptor.header_adder_interceptor(header, value)

#### TODO: implement methods below in asyncio.client.grpc_handler

def reset_password(self, user, old_password, new_password, timeout=None):
"""
reset password and then setup the grpc channel.
Expand Down
68 changes: 44 additions & 24 deletions pymilvus/orm/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,12 @@
import copy
import re
import threading
import typing
from urllib import parse
from typing import Tuple

from ..client.check import is_legal_host, is_legal_port, is_legal_address
from ..client.grpc_handler import GrpcHandler
from ..client.grpc_handler import GrpcHandler, AbstractGrpcHandler

from .default_config import DefaultConfig, ENV_CONNECTION_CONF
from ..exceptions import ExceptionsMessage, ConnectionConfigException, ConnectionNotExistException
Expand Down Expand Up @@ -55,7 +56,11 @@ def __new__(cls, *args, **kwargs):
return super().__new__(cls, *args, **kwargs)


class Connections(metaclass=SingleInstanceMetaClass):
NoneT = typing.TypeVar('NoneT', None, typing.Awaitable[None])
GrpcHandlerT = typing.TypeVar('GrpcHandlerT', bound=AbstractGrpcHandler)


class AbstractConnections(typing.Generic[GrpcHandlerT, NoneT], metaclass=SingleInstanceMetaClass):
""" Class for managing all connections of milvus. Used as a singleton in this module. """

def __init__(self):
Expand All @@ -66,7 +71,7 @@ def __init__(self):

"""
self._alias = {}
self._connected_alias = {}
self._connected_alias: typing.Dict[str, GrpcHandlerT] = {}

self.add_connection(default=self._read_default_config_from_os_env())

Expand Down Expand Up @@ -190,6 +195,9 @@ def __generate_address(self, uri: str, host: str, port: str) -> str:

return f"{host}:{port}"

def _disconnect(self, alias: str) -> NoneT:
raise NotImplementedError

def disconnect(self, alias: str):
""" Disconnects connection from the registry.

Expand All @@ -199,8 +207,7 @@ def disconnect(self, alias: str):
if not isinstance(alias, str):
raise ConnectionConfigException(message=ExceptionsMessage.AliasType % type(alias))

if alias in self._connected_alias:
self._connected_alias.pop(alias).close()
return self._disconnect(alias)

def remove_connection(self, alias: str):
""" Removes connection from the registry.
Expand All @@ -211,8 +218,16 @@ def remove_connection(self, alias: str):
if not isinstance(alias, str):
raise ConnectionConfigException(message=ExceptionsMessage.AliasType % type(alias))

self.disconnect(alias)
# TODO: does order matter?
# original sync implementation was
# self.disconnect(alias)
# self._alias.pop(alias, None)
belkka marked this conversation as resolved.
Show resolved Hide resolved

self._alias.pop(alias, None)
return self.disconnect(alias)

def _connect(self, alias, **kwargs) -> NoneT:
raise NotImplementedError

def connect(self, alias=DefaultConfig.DEFAULT_USING, user="", password="", **kwargs):
"""
Expand Down Expand Up @@ -265,19 +280,6 @@ def connect(self, alias=DefaultConfig.DEFAULT_USING, user="", password="", **kwa
if not isinstance(alias, str):
raise ConnectionConfigException(message=ExceptionsMessage.AliasType % type(alias))

def connect_milvus(**kwargs):
gh = GrpcHandler(**kwargs)

t = kwargs.get("timeout")
timeout = t if isinstance(t, int) else DefaultConfig.DEFAULT_CONNECT_TIMEOUT

gh._wait_for_channel_ready(timeout=timeout)
kwargs.pop('password')
kwargs.pop('secure', None)

self._connected_alias[alias] = gh
self._alias[alias] = copy.deepcopy(kwargs)

def with_config(config: Tuple) -> bool:
for c in config:
if c != "":
Expand All @@ -300,15 +302,14 @@ def with_config(config: Tuple) -> bool:
if self._alias[alias].get("address") != in_addr:
raise ConnectionConfigException(message=ExceptionsMessage.ConnDiffConf % alias)

connect_milvus(**kwargs, user=user, password=password)

else:
if alias not in self._alias:
raise ConnectionConfigException(message=ExceptionsMessage.ConnLackConf % alias)

connect_alias = dict(self._alias[alias].items())
connect_alias["user"] = user
connect_milvus(**connect_alias, password=password, **kwargs)
kwargs = dict(**self._alias[alias], **kwargs)

kwargs["user"] = user
return self._connect(alias, **kwargs, password=password)

def list_connections(self) -> list:
""" List names of all connections.
Expand Down Expand Up @@ -381,5 +382,24 @@ def _fetch_handler(self, alias=DefaultConfig.DEFAULT_USING) -> GrpcHandler:
return conn


class Connections(AbstractConnections[GrpcHandler, None]):
def _disconnect(self, alias: str):
if alias in self._connected_alias:
self._connected_alias.pop(alias).close()

def _connect(self, alias, **kwargs):
gh = GrpcHandler(**kwargs)

t = kwargs.get("timeout")
timeout = t if isinstance(t, int) else DefaultConfig.DEFAULT_CONNECT_TIMEOUT

gh._wait_for_channel_ready(timeout=timeout)
kwargs.pop('password')
kwargs.pop('secure', None)

self._connected_alias[alias] = gh
self._alias[alias] = copy.deepcopy(kwargs)


# Singleton Mode in Python
connections = Connections()