Skip to content

Commit

Permalink
make http handler take an optional requests.Session (#825)
Browse files Browse the repository at this point in the history
* make http handler take an optional requests.Session

for #792

* adds test_session_attribute

* fix linter

* apply suggestions from ddelange

---------

Co-authored-by: Aron Bartle <[email protected]>
Co-authored-by: Michael Penkov <[email protected]>
  • Loading branch information
3 people authored Oct 4, 2024
1 parent fed7b78 commit 672db98
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 7 deletions.
27 changes: 21 additions & 6 deletions smart_open/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def open_uri(uri, mode, transport_params):


def open(uri, mode, kerberos=False, user=None, password=None, cert=None,
headers=None, timeout=None, buffer_size=DEFAULT_BUFFER_SIZE):
headers=None, timeout=None, session=None, buffer_size=DEFAULT_BUFFER_SIZE):
"""Implement streamed reader from a web site.
Supports Kerberos and Basic HTTP authentication.
Expand All @@ -73,6 +73,9 @@ def open(uri, mode, kerberos=False, user=None, password=None, cert=None,
Any headers to send in the request. If ``None``, the default headers are sent:
``{'Accept-Encoding': 'identity'}``. To use no headers at all,
set this variable to an empty dict, ``{}``.
session: object, optional
The requests Session object to use with http get requests.
Can be used for OAuth2 clients.
buffer_size: int, optional
The buffer size to use when performing I/O.
Expand All @@ -86,7 +89,7 @@ def open(uri, mode, kerberos=False, user=None, password=None, cert=None,
fobj = SeekableBufferedInputBase(
uri, mode, buffer_size=buffer_size, kerberos=kerberos,
user=user, password=password, cert=cert,
headers=headers, timeout=timeout,
headers=headers, session=session, timeout=timeout,
)
fobj.name = os.path.basename(urllib.parse.urlparse(uri).path)
return fobj
Expand All @@ -97,7 +100,10 @@ def open(uri, mode, kerberos=False, user=None, password=None, cert=None,
class BufferedInputBase(io.BufferedIOBase):
def __init__(self, url, mode='r', buffer_size=DEFAULT_BUFFER_SIZE,
kerberos=False, user=None, password=None, cert=None,
headers=None, timeout=None):
headers=None, session=None, timeout=None):

self.session = session or requests

if kerberos:
import requests_kerberos
auth = requests_kerberos.HTTPKerberosAuth()
Expand All @@ -116,7 +122,14 @@ def __init__(self, url, mode='r', buffer_size=DEFAULT_BUFFER_SIZE,

self.timeout = timeout

self.response = requests.get(
self.response = session.get(
url,
auth=auth,
cert=cert,
stream=True,
headers=self.headers,
timeout=self.timeout,
) if session is not None else requests.get(
url,
auth=auth,
cert=cert,
Expand Down Expand Up @@ -217,7 +230,7 @@ class SeekableBufferedInputBase(BufferedInputBase):

def __init__(self, url, mode='r', buffer_size=DEFAULT_BUFFER_SIZE,
kerberos=False, user=None, password=None, cert=None,
headers=None, timeout=None):
headers=None, session=None, timeout=None):
"""
If Kerberos is True, will attempt to use the local Kerberos credentials.
If cert is set, will try to use a client certificate
Expand All @@ -227,6 +240,8 @@ def __init__(self, url, mode='r', buffer_size=DEFAULT_BUFFER_SIZE,
"""
self.url = url

self.session = session or requests

if kerberos:
import requests_kerberos
self.auth = requests_kerberos.HTTPKerberosAuth()
Expand Down Expand Up @@ -332,7 +347,7 @@ def _partial_request(self, start_pos=None):
if start_pos is not None:
self.headers.update({"range": smart_open.utils.make_range_string(start_pos)})

response = requests.get(
response = self.session.get(
self.url,
auth=self.auth,
stream=True,
Expand Down
11 changes: 10 additions & 1 deletion smart_open/tests/test_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import smart_open.http
import smart_open.s3
import smart_open.constants

import requests

BYTES = b'i tried so hard and got so far but in the end it doesn\'t even matter'
URL = 'http://localhost'
Expand Down Expand Up @@ -159,6 +159,15 @@ def test_timeout_attribute(self):
assert hasattr(reader, 'timeout')
assert reader.timeout == timeout

@responses.activate
def test_session_attribute(self):
session = requests.Session()
responses.add_callback(responses.GET, URL, callback=request_callback)
reader = smart_open.open(URL, "rb", transport_params={'session': session})
assert hasattr(reader, 'session')
assert reader.session == session
assert reader.read() == BYTES


@responses.activate
def test_seek_implicitly_enabled(numbytes=10):
Expand Down

0 comments on commit 672db98

Please sign in to comment.