New Upstream Release - smart-open
Ready changes
Summary
Merged new upstream version: 6.3.0 (was: 5.2.1).
Diff
diff --git a/.flake8 b/.flake8
new file mode 100644
index 0000000..dd042f5
--- /dev/null
+++ b/.flake8
@@ -0,0 +1,9 @@
+[flake8]
+# E121: Continuation line under-indented for hanging indent
+# E123: Continuation line missing indentation or outdented
+# E125: Continuation line with same indent as next logical line
+# E128: Continuation line under-indented for visual indent
+# E226: Missing whitespace around arithmetic operator
+# W503: Line break occurred before a binary operator
+ignore=E121,E123,E125,E128,E226,W503
+max-line-length=110
\ No newline at end of file
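For context, the two pycodestyle rules most likely to surprise contributors are W503 and E226. A minimal sketch of code that flake8 would normally flag but that passes under this configuration (the variable names are illustrative only):

```python
first_part, second_part = 1, 2
width, height = 3, 4

# W503 (line break before a binary operator) is suppressed by this config:
total = (first_part
         + second_part)

# E226 (missing whitespace around an arithmetic operator) is also suppressed:
area = width*height
```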
diff --git a/.github/workflows/python-package-win.yml b/.github/workflows/python-package-win.yml
deleted file mode 100644
index 2692788..0000000
--- a/.github/workflows/python-package-win.yml
+++ /dev/null
@@ -1,47 +0,0 @@
-name: Test under Windows
-on: [push, pull_request]
-jobs:
- build:
- runs-on: windows-2019
- strategy:
- matrix:
- include:
- - python-version: '3.6'
- toxenv: "py36-doctest"
-
- - python-version: '3.6'
- toxenv: "py36-test"
-
- - python-version: '3.7'
- toxenv: "py37-doctest"
-
- - python-version: '3.7'
- toxenv: "py37-test"
-
- - python-version: '3.8'
- toxenv: "py38-doctest"
-
- - python-version: '3.8'
- toxenv: "py38-test"
-
- - python-version: '3.9'
- toxenv: "py39-doctest"
-
- - python-version: '3.9'
- toxenv: "py39-test"
-
- steps:
- - uses: actions/checkout@v2
-
- - uses: actions/setup-python@v2
- with:
- python-version: ${{ matrix.python-version }}
-
- - name: Update pip
- run: python -m pip install -U pip
-
- - name: Install tox
- run: python -m pip install tox
-
- - name: Test using Tox
- run: tox
diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml
index bbe1c45..c2cb5b3 100644
--- a/.github/workflows/python-package.yml
+++ b/.github/workflows/python-package.yml
@@ -1,33 +1,81 @@
name: Test
on: [push, pull_request]
jobs:
- build:
- env:
- BOTO_CONFIG: "/dev/null"
- SO_BUCKET: smart-open
-
+ linters:
runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v2
+
+ - name: Set up Python 3.10
+ uses: actions/setup-python@v2
+ with:
+ python-version: "3.10"
+
+ - name: Update pip
+ run: python -m pip install -U pip
+
+ - name: Install dependencies
+ run: python -m pip install flake8
+
+ - name: Run flake8 linter (source)
+ run: flake8 --show-source smart_open
+
+ unit_tests:
+ needs: [linters]
+ runs-on: ${{ matrix.os }}
strategy:
matrix:
include:
- - python-version: '3.6'
- toxenv: "check_keys,py36-doctest,py36-test,py36-benchmark,py36-integration"
- result_key: benchmark-results-py36
+ - {python: '3.7', os: ubuntu-20.04}
+ - {python: '3.8', os: ubuntu-20.04}
+ - {python: '3.9', os: ubuntu-20.04}
+ - {python: '3.10', os: ubuntu-20.04}
+
+ - {python: '3.7', os: windows-2019}
+ - {python: '3.8', os: windows-2019}
+ - {python: '3.9', os: windows-2019}
+ - {python: '3.10', os: windows-2019}
+ steps:
+ - uses: actions/checkout@v2
+
+ - uses: actions/setup-python@v2
+ with:
+ python-version: ${{ matrix.python }}
- - python-version: '3.7'
- toxenv: "check_keys,py37-doctest,enable_moto_server,py37-test,py37-benchmark,py37-integration,disable_moto_server"
- enable_moto_server: "1"
+ - name: Update pip
+ run: python -m pip install -U pip
+
+ #
+ # https://askubuntu.com/questions/1428181/module-lib-has-no-attribute-x509-v-flag-cb-issuer-check
+ #
+ - name: Upgrade PyOpenSSL
+ run: python -m pip install pyOpenSSL --upgrade
- - python-version: '3.8'
- toxenv: "check_keys,py38-doctest,test_coverage,py38-integration"
- coveralls: true
+ - name: Install smart_open and its dependencies
+ run: pip install -e .[test]
- - python-version: '3.9'
- toxenv: "check_keys,py39-doctest,test_coverage,py39-integration"
- coveralls: true
+ - name: Run unit tests
+ run: pytest smart_open -v -rfxECs --durations=20
+
+ doctest:
+ needs: [linters,unit_tests]
+ runs-on: ${{ matrix.os }}
+ strategy:
+ matrix:
+ include:
+ - {python: '3.7', os: ubuntu-20.04}
+ - {python: '3.8', os: ubuntu-20.04}
+ - {python: '3.9', os: ubuntu-20.04}
+ - {python: '3.10', os: ubuntu-20.04}
- - python-version: '3.8'
- toxenv: "flake8"
+ #
+ # Some of the doctests don't pass on Windows because of Windows-specific
+ # character encoding issues.
+ #
+ # - {python: '3.7', os: windows-2019}
+ # - {python: '3.8', os: windows-2019}
+ # - {python: '3.9', os: windows-2019}
+ # - {python: '3.10', os: windows-2019}
steps:
- uses: actions/checkout@v2
@@ -39,15 +87,112 @@ jobs:
- name: Update pip
run: python -m pip install -U pip
- - name: Install tox
- run: python -m pip install tox
+ - name: Upgrade PyOpenSSL
+ run: python -m pip install pyOpenSSL --upgrade
- - name: Test using Tox
+ - name: Install smart_open and its dependencies
+ run: pip install -e .[test]
+
+ - name: Run doctests
+ run: python ci_helpers/doctest.py
env:
- SO_RESULT_KEY: ${{ matrix.result_key }}
- SO_ENABLE_MOTO_SERVER: ${{ matrix.enable_moto_server }}
- TOXENV: ${{ matrix.toxenv }}
- run: tox
+ AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
+ AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
+
+ integration:
+ needs: [linters,unit_tests]
+ runs-on: ${{ matrix.os }}
+ strategy:
+ matrix:
+ include:
+ - {python: '3.7', os: ubuntu-20.04, moto_server: true}
+ - {python: '3.8', os: ubuntu-20.04}
+ - {python: '3.9', os: ubuntu-20.04}
+ - {python: '3.10', os: ubuntu-20.04}
+
+ # Not sure why we exclude these, perhaps for historical reasons?
+ #
+ # - {python: '3.7', os: windows-2019}
+ # - {python: '3.8', os: windows-2019}
+ # - {python: '3.9', os: windows-2019}
+ # - {python: '3.10', os: windows-2019}
+
+ steps:
+ - uses: actions/checkout@v2
+
+ - uses: actions/setup-python@v2
+ with:
+ python-version: ${{ matrix.python }}
+
+ - name: Update pip
+ run: python -m pip install -U pip
+
+ - name: Upgrade PyOpenSSL
+ run: python -m pip install pyOpenSSL --upgrade
+
+ - run: python -m pip install numpy
+
+ - name: Install smart_open and its dependencies
+ run: pip install -e .[test]
+
+ - run: bash ci_helpers/helpers.sh enable_moto_server
+ if: ${{ matrix.moto_server }}
+
+ - run: |
+ sudo apt-get install vsftpd
+ sudo bash ci_helpers/helpers.sh create_ftp_ftps_servers
+
+ - name: Run integration tests
+ run: python ci_helpers/run_integration_tests.py
+ env:
+ AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
+ AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
+
+ - run: bash ci_helpers/helpers.sh disable_moto_server
+ if: ${{ matrix.moto_server }}
+
+ - run: sudo bash ci_helpers/helpers.sh delete_ftp_ftps_servers
+
+ benchmarks:
+ needs: [linters,unit_tests]
+ runs-on: ${{ matrix.os }}
+ strategy:
+ matrix:
+ include:
+ - {python: '3.7', os: ubuntu-20.04}
+ - {python: '3.8', os: ubuntu-20.04}
+ - {python: '3.9', os: ubuntu-20.04}
+ - {python: '3.10', os: ubuntu-20.04}
+
+ # - {python: '3.7', os: windows-2019}
+ # - {python: '3.8', os: windows-2019}
+ # - {python: '3.9', os: windows-2019}
+ # - {python: '3.10', os: windows-2019}
+
+ steps:
+ - uses: actions/checkout@v2
+
+ - uses: actions/setup-python@v2
+ with:
+ python-version: ${{ matrix.python }}
+
+ - name: Update pip
+ run: python -m pip install -U pip
+
+ - name: Upgrade PyOpenSSL
+ run: python -m pip install pyOpenSSL --upgrade
+
+ - name: Install smart_open and its dependencies
+ run: pip install -e .[test]
+
+ - run: pip install awscli pytest_benchmark
+
+ - name: Run benchmarks
+ run: python ci_helpers/run_benchmarks.py
+ env:
+ SO_BUCKET: smart-open
+ AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
+ AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
#
# The test_coverage environment in tox.ini generates coverage data and
@@ -60,11 +205,11 @@ jobs:
# (https://github.com/coverallsapp/github-action/issues/30) but it does
# not work with pytest output.
#
- - name: Upload code coverage to coveralls.io
- if: ${{ matrix.coveralls }}
- continue-on-error: true
- env:
- GITHUB_TOKEN: ${{ github.token }}
- run: |
- pip install coveralls
- coveralls
+ # - name: Upload code coverage to coveralls.io
+ # if: ${{ matrix.coveralls }}
+ # continue-on-error: true
+ # env:
+ # GITHUB_TOKEN: ${{ github.token }}
+ # run: |
+ # pip install coveralls
+ # coveralls
diff --git a/.gitignore b/.gitignore
index 485c692..3aa10d6 100644
--- a/.gitignore
+++ b/.gitignore
@@ -65,3 +65,4 @@ target/
# env files
.env
+.venv
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 4fc77ff..eb914a2 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,5 +1,54 @@
# Unreleased
+## 6.3.0, 2022-12-12
+
+* Refactor Google Cloud Storage to use blob.open (__[ddelange](https://github.com/ddelange)__, [#744](https://github.com/RaRe-Technologies/smart_open/pull/744))
+* Add FTP/FTPS support (#33) (__[RachitSharma2001](https://github.com/RachitSharma2001)__, [#739](https://github.com/RaRe-Technologies/smart_open/pull/739))
+* Bring back compression_wrapper(filename) + use case-insensitive extension matching (__[piskvorky](https://github.com/piskvorky)__, [#737](https://github.com/RaRe-Technologies/smart_open/pull/737))
+* Fix avoidable S3 race condition (#693) (__[RachitSharma2001](https://github.com/RachitSharma2001)__, [#735](https://github.com/RaRe-Technologies/smart_open/pull/735))
+* setup.py: Remove pathlib2 (__[jayvdb](https://github.com/jayvdb)__, [#733](https://github.com/RaRe-Technologies/smart_open/pull/733))
+* Add flake8 config globally (__[cadnce](https://github.com/cadnce)__, [#732](https://github.com/RaRe-Technologies/smart_open/pull/732))
+* Added buffer_size parameter to http module (__[mullenkamp](https://github.com/mullenkamp)__, [#730](https://github.com/RaRe-Technologies/smart_open/pull/730))
+* Added documentation to support GCS anonymously (__[cadnce](https://github.com/cadnce)__, [#728](https://github.com/RaRe-Technologies/smart_open/pull/728))
+* Reconnect inactive sftp clients automatically (__[Kache](https://github.com/Kache)__, [#719](https://github.com/RaRe-Technologies/smart_open/pull/719))
+
+# 6.2.0, 14 September 2022
+
+- Fix quadratic time ByteBuffer operations (PR [#711](https://github.com/RaRe-Technologies/smart_open/pull/711), [@Joshua-Landau-Anthropic](https://github.com/Joshua-Landau-Anthropic))
+
+# 6.1.0, 21 August 2022
+
+- Add cert parameter to http transport params (PR [#703](https://github.com/RaRe-Technologies/smart_open/pull/703), [@stev-0](https://github.com/stev-0))
+- Allow passing additional kwargs for Azure writes (PR [#702](https://github.com/RaRe-Technologies/smart_open/pull/702), [@ddelange](https://github.com/ddelange))
+
+# 6.0.0, 24 April 2022
+
+This release deprecates the old `ignore_ext` parameter.
+Use the `compression` parameter instead.
+
+```python
+fin = smart_open.open("/path/file.gz", ignore_ext=True) # No
+fin = smart_open.open("/path/file.gz", compression="disable") # Yes
+
+fin = smart_open.open("/path/file.gz", ignore_ext=False) # No
+fin = smart_open.open("/path/file.gz") # Yes
+fin = smart_open.open("/path/file.gz", compression="infer_from_extension") # Yes, if you want to be explicit
+
+fin = smart_open.open("/path/file", compression=".gz") # Yes
+```
+
+- Make Python 3.7 the required minimum (PR [#688](https://github.com/RaRe-Technologies/smart_open/pull/688), [@mpenkov](https://github.com/mpenkov))
+- Drop deprecated ignore_ext parameter (PR [#661](https://github.com/RaRe-Technologies/smart_open/pull/661), [@mpenkov](https://github.com/mpenkov))
+- Drop support for passing buffers to smart_open.open (PR [#660](https://github.com/RaRe-Technologies/smart_open/pull/660), [@mpenkov](https://github.com/mpenkov))
+- Support working directly with file descriptors (PR [#659](https://github.com/RaRe-Technologies/smart_open/pull/659), [@mpenkov](https://github.com/mpenkov))
+- Added support for viewfs:// URLs (PR [#665](https://github.com/RaRe-Technologies/smart_open/pull/665), [@ChandanChainani](https://github.com/ChandanChainani))
+- Fix AttributeError when reading passthrough zstandard (PR [#658](https://github.com/RaRe-Technologies/smart_open/pull/658), [@mpenkov](https://github.com/mpenkov))
+- Make UploadFailedError picklable (PR [#689](https://github.com/RaRe-Technologies/smart_open/pull/689), [@birgerbr](https://github.com/birgerbr))
+- Support container client and blob client for azure blob storage (PR [#652](https://github.com/RaRe-Technologies/smart_open/pull/652), [@cbare](https://github.com/cbare))
+- Pin google-cloud-storage to >=1.31.1 in extras (PR [#687](https://github.com/RaRe-Technologies/smart_open/pull/687), [@PLPeeters](https://github.com/PLPeeters))
+- Expose certain transport-specific methods e.g. to_boto3 in top layer (PR [#664](https://github.com/RaRe-Technologies/smart_open/pull/664), [@mpenkov](https://github.com/mpenkov))
+- Use pytest instead of parameterizedtestcase (PR [#657](https://github.com/RaRe-Technologies/smart_open/pull/657), [@mpenkov](https://github.com/mpenkov))
+
# 5.2.1, 28 August 2021
- make HTTP/S seeking less strict (PR [#646](https://github.com/RaRe-Technologies/smart_open/pull/646), [@mpenkov](https://github.com/mpenkov))
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
new file mode 100644
index 0000000..c55d71e
--- /dev/null
+++ b/CONTRIBUTING.md
@@ -0,0 +1,28 @@
+# Quickstart
+
+Clone the repo and use a python installation to create a venv:
+
+```sh
+git clone git@github.com:RaRe-Technologies/smart_open.git
+cd smart_open
+python -m venv .venv
+```
+
+Activate the venv to start working and install test deps:
+
+```sh
+source .venv/bin/activate
+pip install -e ".[test]"
+```
+
+Tests should pass:
+
+```sh
+pytest
+```
+
+That's it! When you're done, deactivate the venv:
+
+```sh
+deactivate
+```
diff --git a/MIGRATING_FROM_OLDER_VERSIONS.rst b/MIGRATING_FROM_OLDER_VERSIONS.rst
index e1b6e08..96115a5 100644
--- a/MIGRATING_FROM_OLDER_VERSIONS.rst
+++ b/MIGRATING_FROM_OLDER_VERSIONS.rst
@@ -1,3 +1,21 @@
+Migrating to the new compression parameter
+==========================================
+
+smart_open versions 6.0.0 and above no longer support the ``ignore_ext`` parameter.
+Use the ``compression`` parameter instead:
+
+.. code-block:: python
+
+    fin = smart_open.open("/path/file.gz", ignore_ext=True) # No
+    fin = smart_open.open("/path/file.gz", compression="disable") # Yes
+
+    fin = smart_open.open("/path/file.gz", ignore_ext=False) # No
+    fin = smart_open.open("/path/file.gz") # Yes
+    fin = smart_open.open("/path/file.gz", compression="infer_from_extension") # Yes, if you want to be explicit
+
+    fin = smart_open.open("/path/file", compression=".gz") # Yes
+
Migrating to the new client-based S3 API
========================================
diff --git a/README.rst b/README.rst
index 02108bf..6443523 100644
--- a/README.rst
+++ b/README.rst
@@ -348,15 +348,14 @@ Since going over all (or select) keys in an S3 bucket is a very common operation
.. code-block:: python
>>> from smart_open import s3
- >>> # get data corresponding to 2010 and later under "silo-open-data/annual/monthly_rain"
>>> # we use workers=1 for reproducibility; you should use as many workers as you have cores
>>> bucket = 'silo-open-data'
- >>> prefix = 'annual/monthly_rain/'
+ >>> prefix = 'Official/annual/monthly_rain/'
>>> for key, content in s3.iter_bucket(bucket, prefix=prefix, accept_key=lambda key: '/201' in key, workers=1, key_limit=3):
... print(key, round(len(content) / 2**20))
- annual/monthly_rain/2010.monthly_rain.nc 13
- annual/monthly_rain/2011.monthly_rain.nc 13
- annual/monthly_rain/2012.monthly_rain.nc 13
+ Official/annual/monthly_rain/2010.monthly_rain.nc 13
+ Official/annual/monthly_rain/2011.monthly_rain.nc 13
+ Official/annual/monthly_rain/2012.monthly_rain.nc 13
GCS Credentials
---------------
@@ -418,36 +417,6 @@ to setting up authentication.
If you need more credential options, refer to the
`Azure Storage authentication guide <https://docs.microsoft.com/en-us/azure/storage/common/storage-samples-python#authentication>`__.
-File-like Binary Streams
-------------------------
-
-The ``open`` function also accepts file-like objects.
-This is useful when you already have a `binary file <https://docs.python.org/3/glossary.html#term-binary-file>`_ open, and would like to wrap it with transparent decompression:
-
-
-.. code-block:: python
-
- >>> import io, gzip
- >>>
- >>> # Prepare some gzipped binary data in memory, as an example.
- >>> # Any binary file will do; we're using BytesIO here for simplicity.
- >>> buf = io.BytesIO()
- >>> with gzip.GzipFile(fileobj=buf, mode='w') as fout:
- ... _ = fout.write(b'this is a bytestring')
- >>> _ = buf.seek(0)
- >>>
- >>> # Use case starts here.
- >>> buf.name = 'file.gz' # add a .name attribute so smart_open knows what compressor to use
- >>> import smart_open
- >>> smart_open.open(buf, 'rb').read() # will gzip-decompress transparently!
- b'this is a bytestring'
-
-
-In this case, ``smart_open`` relied on the ``.name`` attribute of our `binary I/O stream <https://docs.python.org/3/library/io.html#binary-i-o>`_ ``buf`` object to determine which decompressor to use.
-If your file object doesn't have one, set the ``.name`` attribute to an appropriate value.
-Furthermore, that value has to end with a **known** file extension (see the ``register_compressor`` function).
-Otherwise, the transparent decompression will not occur.
-
Drop-in replacement of ``pathlib.Path.open``
--------------------------------------------
@@ -502,3 +471,4 @@ issues or pull requests there. Suggestions, pull requests and improvements welco
``smart_open`` is open source software released under the `MIT license <https://github.com/piskvorky/smart_open/blob/master/LICENSE>`_.
Copyright (c) 2015-now `Radim Řehůřek <https://radimrehurek.com>`_.
+
diff --git a/benchmark/bytebuffer_bench.py b/benchmark/bytebuffer_bench.py
new file mode 100644
index 0000000..257e0e2
--- /dev/null
+++ b/benchmark/bytebuffer_bench.py
@@ -0,0 +1,34 @@
+import time
+import sys
+
+import smart_open
+from smart_open.bytebuffer import ByteBuffer
+
+
+def raw_bytebuffer_benchmark():
+ buffer = ByteBuffer()
+
+ start = time.time()
+ for _ in range(10_000):
+ assert buffer.fill([b"X" * 1000]) == 1000
+ return time.time() - start
+
+
+def file_read_benchmark(filename):
+ file = smart_open.open(filename, mode="rb")
+
+ start = time.time()
+ read = file.read(100_000_000)
+ end = time.time()
+
+ if len(read) < 100_000_000:
+ print("File smaller than 100MB")
+
+ return end - start
+
+
+print("Raw ByteBuffer benchmark:", raw_bytebuffer_benchmark())
+
+if len(sys.argv) > 1:
+ bench_result = file_read_benchmark(sys.argv[1])
+ print("File read benchmark", bench_result)
diff --git a/ci_helpers/README.txt b/ci_helpers/README.txt
new file mode 100644
index 0000000..d645134
--- /dev/null
+++ b/ci_helpers/README.txt
@@ -0,0 +1,3 @@
+This subdirectory contains helper scripts for our continuous integration workflow files.
+
+They are designed to be platform-independent: they run on both Linux and Windows.
diff --git a/tox_helpers/doctest.py b/ci_helpers/doctest.py
similarity index 100%
rename from tox_helpers/doctest.py
rename to ci_helpers/doctest.py
diff --git a/ci_helpers/helpers.sh b/ci_helpers/helpers.sh
new file mode 100644
index 0000000..4533b33
--- /dev/null
+++ b/ci_helpers/helpers.sh
@@ -0,0 +1,58 @@
+#!/bin/bash
+
+set -e
+set -x
+
+enable_moto_server(){
+ moto_server -p5000 2>/dev/null&
+}
+
+create_ftp_ftps_servers(){
+ #
+ # Must be run as root
+ #
+ home_dir=/home/user
+ user=user
+ pass=123
+ ftp_port=21
+ ftps_port=90
+
+ mkdir $home_dir
+ useradd -p $(echo $pass | openssl passwd -1 -stdin) -d $home_dir $user
+ chown $user:$user $home_dir
+
+ server_setup='''
+listen=YES
+listen_ipv6=NO
+write_enable=YES
+pasv_enable=YES
+pasv_min_port=40000
+pasv_max_port=40009
+chroot_local_user=YES
+allow_writeable_chroot=YES'''
+
+ additional_ssl_setup='''
+ssl_enable=YES
+allow_anon_ssl=NO
+force_local_data_ssl=NO
+force_local_logins_ssl=NO
+require_ssl_reuse=NO
+'''
+
+ cp /etc/vsftpd.conf /etc/vsftpd-ssl.conf
+ echo -e "$server_setup\nlisten_port=${ftp_port}" >> /etc/vsftpd.conf
+ echo -e "$server_setup\nlisten_port=${ftps_port}\n$additional_ssl_setup" >> /etc/vsftpd-ssl.conf
+
+ service vsftpd restart
+ vsftpd /etc/vsftpd-ssl.conf &
+}
+
+disable_moto_server(){
+ lsof -i tcp:5000 | tail -n1 | cut -f2 -d" " | xargs kill -9
+}
+
+delete_ftp_ftps_servers(){
+ service vsftpd stop
+}
+
+"$@"
diff --git a/tox_helpers/run_benchmarks.py b/ci_helpers/run_benchmarks.py
similarity index 100%
rename from tox_helpers/run_benchmarks.py
rename to ci_helpers/run_benchmarks.py
diff --git a/tox_helpers/run_integration_tests.py b/ci_helpers/run_integration_tests.py
similarity index 91%
rename from tox_helpers/run_integration_tests.py
rename to ci_helpers/run_integration_tests.py
index 51726cb..5417ca8 100644
--- a/tox_helpers/run_integration_tests.py
+++ b/ci_helpers/run_integration_tests.py
@@ -9,6 +9,7 @@ subprocess.check_call(
'pytest',
'integration-tests/test_207.py',
'integration-tests/test_http.py',
+ 'integration-tests/test_ftp.py'
]
)
diff --git a/tox_helpers/test_missing_dependencies.py b/ci_helpers/test_missing_dependencies.py
similarity index 100%
rename from tox_helpers/test_missing_dependencies.py
rename to ci_helpers/test_missing_dependencies.py
diff --git a/debian/changelog b/debian/changelog
index a6e7d6f..eb81f4e 100644
--- a/debian/changelog
+++ b/debian/changelog
@@ -1,3 +1,9 @@
+smart-open (6.3.0-1) UNRELEASED; urgency=low
+
+ * New upstream release.
+
+ -- Debian Janitor <janitor@jelmer.uk> Mon, 06 Mar 2023 15:27:26 -0000
+
smart-open (5.2.1-5) unstable; urgency=medium
[ Debian Janitor ]
diff --git a/help.txt b/help.txt
index 6d642bb..b5e250e 100644
--- a/help.txt
+++ b/help.txt
@@ -124,6 +124,8 @@ FUNCTIONS
The username for authenticating over HTTP
password: str, optional
The password for authenticating over HTTP
+ cert: str/tuple, optional
+ If String, path to ssl client cert file (.pem). If Tuple, (‘cert’, ‘key’)
headers: dict, optional
Any headers to send in the request. If ``None``, the default headers are sent:
``{'Accept-Encoding': 'identity'}``. To use no headers at all,
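To make the new `cert` parameter concrete, here is a hedged sketch of passing a client-side TLS certificate through `transport_params`; the URL and file paths are placeholders, not part of the upstream docs:

```python
import smart_open

# Per the docstring above, `cert` is either a path to a .pem file or a
# ('cert', 'key') tuple, mirroring the convention used by `requests`.
transport_params = {'cert': ('/path/to/client.pem', '/path/to/client.key')}

with smart_open.open('https://example.com/data.txt', transport_params=transport_params) as fin:
    print(fin.readline())
```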
diff --git a/howto.md b/howto.md
index 7ef716e..9a56a39 100644
--- a/howto.md
+++ b/howto.md
@@ -25,7 +25,11 @@ Finally, ensure all the guides still work by running:
python -m doctest howto.md
-The above command shouldn't print anything to standard output/error and return zero.
+The above command should print nothing to standard output/error and should exit with status zero, provided your local environment is set up correctly:
+
+- you have a working Internet connection
+- localstack is running, and the `mybucket` S3 bucket has been created
+- the GITHUB_TOKEN environment variable is set to a valid [access token](https://docs.github.com/en/authentication/keeping-your-account-and-data-secure/creating-a-personal-access-token)
## How to Read/Write Zip Files
@@ -156,6 +160,7 @@ You can then interact with the object using the `boto3` API:
```python
>>> import boto3
+>>> from smart_open import open
>>> resource = boto3.resource('s3') # Pass additional resource parameters here
>>> with open('s3://commoncrawl/robots.txt') as fin:
... boto3_object = fin.to_boto3(resource)
@@ -171,9 +176,11 @@ This works only when reading and writing via S3.
For versioned objects, the returned object will be slightly different:
```python
+>>> from smart_open import open
+>>> resource = boto3.resource('s3')
>>> params = {'version_id': 'KiQpZPsKI5Dm2oJZy_RzskTOtl2snjBg'}
>>> with open('s3://smart-open-versioned/demo.txt', transport_params=params) as fin:
-... print(fin.to_boto3())
+... print(fin.to_boto3(resource))
s3.ObjectVersion(bucket_name='smart-open-versioned', object_key='demo.txt', id='KiQpZPsKI5Dm2oJZy_RzskTOtl2snjBg')
```
@@ -242,7 +249,7 @@ To access such buckets, you need to pass some special transport parameters:
```python
>>> from smart_open import open
->>> params = {'client_kwargs': {'S3.Client.get_object': {RequestPayer': 'requester'}}}
+>>> params = {'client_kwargs': {'S3.Client.get_object': {'RequestPayer': 'requester'}}}
>>> with open('s3://arxiv/pdf/arXiv_pdf_manifest.xml', transport_params=params) as fin:
... print(fin.readline())
<?xml version='1.0' standalone='yes'?>
@@ -266,9 +273,11 @@ You can fine-tune it using several ways:
>>> config = botocore.config.Config(retries={'mode': 'standard'})
>>> client = boto3.client('s3', config=config)
>>> tp = {'client': client}
->>> with smart_open.open('s3://commoncrawl/robots.txt', transport_params=tp) as fin:
+>>> with open('s3://commoncrawl/robots.txt', transport_params=tp) as fin:
... print(fin.readline())
User-Agent: *
+<BLANKLINE>
+
```
To verify your settings have effect:
@@ -288,9 +297,13 @@ Instead, `smart_open` offers the caller of the function to pass additional param
```python
>>> import boto3
->>> client_kwargs = {'S3.Client.get_object': {RequestPayer': 'requester'}}}
+>>> from smart_open import open
+>>> transport_params = {'client_kwargs': {'S3.Client.get_object': {'RequestPayer': 'requester'}}}
>>> with open('s3://arxiv/pdf/arXiv_pdf_manifest.xml', transport_params=params) as fin:
-... pass
+... print(fin.readline())
+<?xml version='1.0' standalone='yes'?>
+<BLANKLINE>
+
```
The above example influences how the [S3.Client.get_object function](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/s3.html#S3.Client.get_object) gets called by `smart_open` when reading the specified URL.
@@ -318,7 +331,7 @@ More specifically, here's the direct method:
```python
import boto3
import smart_open
-with smart_open.open('s3://bucket/key', 'wb') as fout:
+with open('s3://bucket/key', 'wb') as fout:
fout.write(b'hello world!')
client = boto3.client('s3')
client.put_object_acl(ACL=acl_as_string)
@@ -329,7 +342,7 @@ Here's the same code that passes the above parameter via `smart_open`:
```python
import smart_open
tp = {'client_kwargs': {'S3.Client.create_multipart_upload': {'ACL': acl_as_string}}}
-with smart_open.open('s3://bucket/key', 'wb', transport_params=tp) as fout:
+with open('s3://bucket/key', 'wb', transport_params=tp) as fout:
fout.write(b'hello world!')
```
@@ -342,22 +355,20 @@ access to. Below is an example for how users can read a file with smart_open. Fo
[Github API documentation](https://docs.github.com/en/rest/reference/repos#contents).
```python
->>> from smart_open import open
>>> import base64
+>>> import gzip
>>> import json
->>> owner = "RaRe-Technologies"
->>> repo = "smart_open"
->>> path = "howto.md"
->>> git_token = "..."
+>>> import os
+>>> from smart_open import open
+>>> owner, repo, path = "RaRe-Technologies", "smart_open", "howto.md"
+>>> github_token = os.environ['GITHUB_TOKEN']
>>> url = f"https://api.github.com/repos/{owner}/{repo}/contents/{path}"
->>> transport_params = {
-... "headers" : {
-... "Authorization" : "Bearer " + git_token
-... }
-... }
->>> with open(url, transport_params=transport_params) as obj:
-... response_contents = json.loads(obj.read())["contents"]
-... file_text = base64.b64decode(response_contents).decode()
+>>> params = {"headers" : {"Authorization" : "Bearer " + github_token}}
+>>> with open(url, 'rb', transport_params=params) as fin:
+... response = json.loads(gzip.decompress(fin.read()))
+>>> response["path"]
+'howto.md'
+
```
Note: If you are accessing a file in a Github Enterprise org, you will likely have a different base dns than
@@ -389,7 +400,7 @@ You can now read/write to the bucket the same way you would to a real S3 bucket:
>>> client = boto3.client('s3', endpoint_url='http://localhost:4566')
>>> tparams = {'client': client}
>>> with open('s3://mybucket/hello.txt', 'wt', transport_params=tparams) as fout:
-... fout.write('hello world!')
+... _ = fout.write('hello world!')
>>> with open('s3://mybucket/hello.txt', 'rt', transport_params=tparams) as fin:
... fin.read()
'hello world!'
@@ -419,4 +430,20 @@ for an explanation). To download all files in a directory you can do this:
... print(f.name)
... break # just show the first iteration for the test
LC08/01/044/034/LC08_L1GT_044034_20130330_20170310_01_T2/LC08_L1GT_044034_20130330_20170310_01_T2_ANG.txt
+
+```
+
+## How to Access Google Cloud Anonymously
+
+The `google-cloud-storage` library that `smart_open` uses expects credentials and authenticated access by default.
+If you would like to access GCS without using an account you need to explicitly use an anonymous client.
+
+```python
+>>> from google.cloud import storage
+>>> from smart_open import open
+>>> client = storage.Client.create_anonymous_client()
+>>> f = open("gs://gcp-public-data-landsat/index.csv.gz", transport_params=dict(client=client))
+>>> f.readline()
+'SCENE_ID,PRODUCT_ID,SPACECRAFT_ID,SENSOR_ID,DATE_ACQUIRED,COLLECTION_NUMBER,COLLECTION_CATEGORY,SENSING_TIME,DATA_TYPE,WRS_PATH,WRS_ROW,CLOUD_COVER,NORTH_LAT,SOUTH_LAT,WEST_LON,EAST_LON,TOTAL_SIZE,BASE_URL\n'
+
```
diff --git a/integration-tests/test_ftp.py b/integration-tests/test_ftp.py
new file mode 100644
index 0000000..000faae
--- /dev/null
+++ b/integration-tests/test_ftp.py
@@ -0,0 +1,84 @@
+from __future__ import unicode_literals
+import pytest
+from smart_open import open
+
+
+@pytest.fixture(params=[("ftp", 21), ("ftps", 90)])
+def server_info(request):
+ return request.param
+
+def test_nonbinary(server_info):
+ server_type = server_info[0]
+ port_num = server_info[1]
+ file_contents = "Test Test \n new test \n another tests"
+ appended_content1 = "Added \n to end"
+
+ with open(f"{server_type}://user:123@localhost:{port_num}/file", "w") as f:
+ f.write(file_contents)
+
+ with open(f"{server_type}://user:123@localhost:{port_num}/file", "r") as f:
+ read_contents = f.read()
+ assert read_contents == file_contents
+
+ with open(f"{server_type}://user:123@localhost:{port_num}/file", "a") as f:
+ f.write(appended_content1)
+
+ with open(f"{server_type}://user:123@localhost:{port_num}/file", "r") as f:
+ read_contents = f.read()
+ assert read_contents == file_contents + appended_content1
+
+def test_binary(server_info):
+ server_type = server_info[0]
+ port_num = server_info[1]
+ file_contents = b"Test Test \n new test \n another tests"
+ appended_content1 = b"Added \n to end"
+
+ with open(f"{server_type}://user:123@localhost:{port_num}/file2", "wb") as f:
+ f.write(file_contents)
+
+ with open(f"{server_type}://user:123@localhost:{port_num}/file2", "rb") as f:
+ read_contents = f.read()
+ assert read_contents == file_contents
+
+ with open(f"{server_type}://user:123@localhost:{port_num}/file2", "ab") as f:
+ f.write(appended_content1)
+
+ with open(f"{server_type}://user:123@localhost:{port_num}/file2", "rb") as f:
+ read_contents = f.read()
+ assert read_contents == file_contents + appended_content1
+
+def test_line_endings_non_binary(server_info):
+ server_type = server_info[0]
+ port_num = server_info[1]
+ B_CLRF = b'\r\n'
+ CLRF = '\r\n'
+ file_contents = f"Test Test {CLRF} new test {CLRF} another tests{CLRF}"
+
+ with open(f"{server_type}://user:123@localhost:{port_num}/file3", "w") as f:
+ f.write(file_contents)
+
+ with open(f"{server_type}://user:123@localhost:{port_num}/file3", "r") as f:
+ for line in f:
+ assert CLRF not in line
+
+ with open(f"{server_type}://user:123@localhost:{port_num}/file3", "rb") as f:
+ for line in f:
+ assert B_CLRF in line
+
+def test_line_endings_binary(server_info):
+ server_type = server_info[0]
+ port_num = server_info[1]
+ B_CLRF = b'\r\n'
+ CLRF = '\r\n'
+ file_contents = f"Test Test {CLRF} new test {CLRF} another tests{CLRF}".encode('utf-8')
+
+ with open(f"{server_type}://user:123@localhost:{port_num}/file4", "wb") as f:
+ f.write(file_contents)
+
+ with open(f"{server_type}://user:123@localhost:{port_num}/file4", "r") as f:
+ for line in f:
+ assert CLRF not in line
+
+ with open(f"{server_type}://user:123@localhost:{port_num}/file4", "rb") as f:
+ for line in f:
+ assert B_CLRF in line
\ No newline at end of file
diff --git a/integration-tests/test_s3.py b/integration-tests/test_s3.py
index 8af3229..2a6cc84 100644
--- a/integration-tests/test_s3.py
+++ b/integration-tests/test_s3.py
@@ -11,7 +11,6 @@ import contextlib
import io
import os
import random
-import subprocess
import string
import boto3
diff --git a/integration-tests/test_s3_ported.py b/integration-tests/test_s3_ported.py
index 67dfcd9..755918b 100644
--- a/integration-tests/test_s3_ported.py
+++ b/integration-tests/test_s3_ported.py
@@ -20,7 +20,7 @@ import uuid
import warnings
import boto3
-from parameterizedtestcase import ParameterizedTestCase as PTestCase
+import pytest
import smart_open
import smart_open.concurrency
@@ -170,7 +170,7 @@ class ReaderTest(unittest.TestCase):
expected = CONTENTS[key_name]
with smart_open.s3.Reader(BUCKET_NAME, key_name) as fin:
- returned_obj = fin.to_boto3()
+ returned_obj = fin.to_boto3(boto3.resource('s3'))
boto3_body = returned_obj.get()['Body'].read()
self.assertEqual(expected, boto3_body)
@@ -277,7 +277,7 @@ def force(multiprocessing=False, concurrent_futures=False):
smart_open.concurrency._CONCURRENT_FUTURES = old_concurrent_futures
-class IterBucketTest(PTestCase):
+class IterBucketTest(unittest.TestCase):
def setUp(self):
self.expected = [
(key, value)
@@ -321,11 +321,17 @@ class IterBucketTest(PTestCase):
self.assertEqual(len(expected), len(actual))
self.assertEqual(expected, sorted(actual))
- @PTestCase.parameterize(('workers',), [(x,) for x in (1, 4, 8, 16, 64)])
- def test_workers(self, workers):
- actual = list(smart_open.s3.iter_bucket(BUCKET_NAME, prefix='iter_bucket', workers=workers))
- self.assertEqual(len(self.expected), len(actual))
- self.assertEqual(self.expected, sorted(actual))
+
+@pytest.mark.parametrize('workers', [1, 4, 8, 16, 64])
+def test_workers(workers):
+ expected = sorted([
+ (key, value)
+ for (key, value) in CONTENTS.items()
+ if key.startswith('iter_bucket/')
+ ])
+ actual = sorted(smart_open.s3.iter_bucket(BUCKET_NAME, prefix='iter_bucket', workers=workers))
+ assert len(expected) == len(actual)
+ assert expected == actual
class DownloadKeyTest(unittest.TestCase):
diff --git a/integration-tests/test_ssh.py b/integration-tests/test_ssh.py
new file mode 100644
index 0000000..252a6fb
--- /dev/null
+++ b/integration-tests/test_ssh.py
@@ -0,0 +1,42 @@
+# -*- coding: utf-8 -*-
+#
+# Copyright (C) 2022 Radim Rehurek <me@radimrehurek.com>
+#
+# This code is distributed under the terms and conditions
+# from the MIT License (MIT).
+#
+
+import os
+import tempfile
+import pytest
+
+import smart_open
+import smart_open.ssh
+
+
+def explode(*args, **kwargs):
+ raise RuntimeError("this function should never have been called")
+
+
+@pytest.mark.skipif("SMART_OPEN_SSH" not in os.environ, reason="this test only works on the dev machine")
+def test():
+ with smart_open.open("ssh://misha@localhost/Users/misha/git/smart_open/README.rst") as fin:
+ readme = fin.read()
+
+ assert 'smart_open — utils for streaming large files in Python' in readme
+
+ #
+ # Ensure the cache is being used
+ #
+ assert ('localhost', 'misha') in smart_open.ssh._SSH
+
+ try:
+ connect_ssh = smart_open.ssh._connect_ssh
+ smart_open.ssh._connect_ssh = explode
+
+ with smart_open.open("ssh://misha@localhost/Users/misha/git/smart_open/howto.md") as fin:
+ howto = fin.read()
+
+ assert 'How-to Guides' in howto
+ finally:
+ smart_open.ssh._connect_ssh = connect_ssh
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..6cab0ef
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,2 @@
+[tool.pytest.ini_options]
+testpaths = ["smart_open"]
diff --git a/setup.py b/setup.py
index e44dc0e..57ad9e6 100644
--- a/setup.py
+++ b/setup.py
@@ -35,19 +35,18 @@ __version__ = _get_version()
def read(fname):
return io.open(os.path.join(os.path.dirname(__file__), fname), encoding='utf-8').read()
+
aws_deps = ['boto3']
-gcs_deps = ['google-cloud-storage']
+gcs_deps = ['google-cloud-storage>=2.6.0']
azure_deps = ['azure-storage-blob', 'azure-common', 'azure-core']
http_deps = ['requests']
+ssh_deps = ['paramiko']
-all_deps = aws_deps + gcs_deps + azure_deps + http_deps
+all_deps = aws_deps + gcs_deps + azure_deps + http_deps + ssh_deps
tests_require = all_deps + [
- 'moto[server]==1.3.14', # Older versions of moto appear broken
- 'pathlib2',
+ 'moto[server]',
'responses',
'boto3',
- 'paramiko',
- 'parameterizedtestcase',
'pytest',
'pytest-rerunfailures'
]
@@ -80,6 +79,7 @@ setup(
'all': all_deps,
'http': http_deps,
'webhdfs': http_deps,
+ 'ssh': ssh_deps,
},
python_requires=">=3.6,<4.0",
@@ -91,9 +91,10 @@ setup(
'Intended Audience :: Developers',
'License :: OSI Approved :: MIT License',
'Operating System :: OS Independent',
- 'Programming Language :: Python :: 3.6',
'Programming Language :: Python :: 3.7',
'Programming Language :: Python :: 3.8',
+ 'Programming Language :: Python :: 3.9',
+ 'Programming Language :: Python :: 3.10',
'Topic :: System :: Distributed Computing',
'Topic :: Database :: Front-Ends',
],
diff --git a/smart_open/azure.py b/smart_open/azure.py
index 7bac6a2..96f944a 100644
--- a/smart_open/azure.py
+++ b/smart_open/azure.py
@@ -68,7 +68,8 @@ def open(
container_id,
blob_id,
mode,
- client=None, # type: azure.storage.blob.BlobServiceClient
+ client=None, # type: Union[azure.storage.blob.BlobServiceClient, azure.storage.blob.ContainerClient, azure.storage.blob.BlobClient] # noqa
+ blob_kwargs=None,
buffer_size=DEFAULT_BUFFER_SIZE,
min_part_size=_DEFAULT_MIN_PART_SIZE,
max_concurrency=DEFAULT_MAX_CONCURRENCY,
@@ -83,12 +84,15 @@ def open(
The name of the blob within the bucket.
mode: str
The mode for opening the object. Must be either "rb" or "wb".
- client: azure.storage.blob.BlobServiceClient
+ client: azure.storage.blob.BlobServiceClient, ContainerClient, or BlobClient
The Azure Blob Storage client to use when working with azure-storage-blob.
+ blob_kwargs: dict, optional
+ Additional parameters to pass to `BlobClient.commit_block_list`.
+ For writing only.
buffer_size: int, optional
The buffer size to use when performing I/O. For reading only.
min_part_size: int, optional
- The minimum part size for multipart uploads. For writing only.
+ The minimum part size for multipart uploads. For writing only.
max_concurrency: int, optional
The number of parallel connections with which to download. For reading only.
@@ -110,12 +114,34 @@ def open(
container_id,
blob_id,
client,
+ blob_kwargs=blob_kwargs,
min_part_size=min_part_size
)
else:
raise NotImplementedError('Azure Blob Storage support for mode %r not implemented' % mode)
+def _get_blob_client(client, container, blob):
+ # type: (Union[azure.storage.blob.BlobServiceClient, azure.storage.blob.ContainerClient, azure.storage.blob.BlobClient], str, str) -> azure.storage.blob.BlobClient # noqa
+ """
+ Return an Azure BlobClient starting with any of BlobServiceClient,
+ ContainerClient, or BlobClient plus container name and blob name.
+ """
+ if hasattr(client, "get_container_client"):
+ client = client.get_container_client(container)
+
+ if hasattr(client, "container_name") and client.container_name != container:
+ raise ValueError(
+ "Client for %r doesn't match "
+ "container %r" % (client.container_name, container)
+ )
+
+ if hasattr(client, "get_blob_client"):
+ client = client.get_blob_client(blob)
+
+ return client
+
+
class _RawReader(object):
"""Read an Azure Blob Storage file."""
@@ -175,15 +201,16 @@ class Reader(io.BufferedIOBase):
self,
container,
blob,
- client, # type: azure.storage.blob.BlobServiceClient
+ client, # type: Union[azure.storage.blob.BlobServiceClient, azure.storage.blob.ContainerClient, azure.storage.blob.BlobClient] # noqa
buffer_size=DEFAULT_BUFFER_SIZE,
line_terminator=smart_open.constants.BINARY_NEWLINE,
max_concurrency=DEFAULT_MAX_CONCURRENCY,
):
- self._container_client = client.get_container_client(container)
- # type: azure.storage.blob.ContainerClient
+ self._container_name = container
+
+ self._blob = _get_blob_client(client, container, blob)
+ # type: azure.storage.blob.BlobClient
- self._blob = self._container_client.get_blob_client(blob)
if self._blob is None:
raise azure.core.exceptions.ResourceNotFoundError(
'blob %s not found in %s' % (blob, container)
@@ -345,12 +372,12 @@ class Reader(io.BufferedIOBase):
def __str__(self):
return "(%s, %r, %r)" % (self.__class__.__name__,
- self._container.container_name,
+ self._container_name,
self._blob.blob_name)
def __repr__(self):
return "%s(container=%r, blob=%r)" % (
- self.__class__.__name__, self._container_client.container_name, self._blob.blob_name,
+ self.__class__.__name__, self._container_name, self._blob.blob_name,
)
@@ -363,13 +390,17 @@ class Writer(io.BufferedIOBase):
self,
container,
blob,
- client, # type: azure.storage.blob.BlobServiceClient
+ client, # type: Union[azure.storage.blob.BlobServiceClient, azure.storage.blob.ContainerClient, azure.storage.blob.BlobClient] # noqa
+ blob_kwargs=None,
min_part_size=_DEFAULT_MIN_PART_SIZE,
):
- self._client = client
- self._container_client = self._client.get_container_client(container)
- # type: azure.storage.blob.ContainerClient
- self._blob = self._container_client.get_blob_client(blob) # type: azure.storage.blob.BlobClient
+ self._is_closed = False
+ self._container_name = container
+
+ self._blob = _get_blob_client(client, container, blob)
+ self._blob_kwargs = blob_kwargs or {}
+ # type: azure.storage.blob.BlobClient
+
self._min_part_size = min_part_size
self._total_size = 0
@@ -394,14 +425,14 @@ class Writer(io.BufferedIOBase):
if not self.closed:
if self._current_part.tell() > 0:
self._upload_part()
- self._blob.commit_block_list(self._block_list)
+ self._blob.commit_block_list(self._block_list, **self._blob_kwargs)
self._block_list = []
- self._client = None
+ self._is_closed = True
logger.debug("successfully closed")
@property
def closed(self):
- return self._client is None
+ return self._is_closed
def writable(self):
"""Return True if the stream supports writing."""
@@ -483,14 +514,14 @@ class Writer(io.BufferedIOBase):
def __str__(self):
return "(%s, %r, %r)" % (
self.__class__.__name__,
- self._container_client.container_name,
+ self._container_name,
self._blob.blob_name
)
def __repr__(self):
return "%s(container=%r, blob=%r, min_part_size=%r)" % (
self.__class__.__name__,
- self._container_client.container_name,
+ self._container_name,
self._blob.blob_name,
self._min_part_size
)
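Taken together, these Azure changes let callers pass any of the three client types and forward extra keyword arguments to the final `commit_block_list` call. A hedged usage sketch; the connection string, container name, and blob tier are illustrative assumptions:

```python
from azure.storage.blob import ContainerClient, StandardBlobTier
import smart_open

# Any of BlobServiceClient, ContainerClient, or BlobClient is accepted now;
# _get_blob_client() resolves whichever is given down to a BlobClient.
client = ContainerClient.from_connection_string('<connection-string>', 'mycontainer')

transport_params = {
    'client': client,
    # Forwarded verbatim to BlobClient.commit_block_list() when the writer closes.
    'blob_kwargs': {'standard_blob_tier': StandardBlobTier.COOL},
}

with smart_open.open('azure://mycontainer/key.bin', 'wb', transport_params=transport_params) as fout:
    fout.write(b'hello world')
```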
diff --git a/smart_open/bytebuffer.py b/smart_open/bytebuffer.py
index 65e9f27..6aaa251 100644
--- a/smart_open/bytebuffer.py
+++ b/smart_open/bytebuffer.py
@@ -105,12 +105,12 @@ class ByteBuffer(object):
if size < 0 or size > len(self):
size = len(self)
- part = self._bytes[self._pos:self._pos+size]
+ part = bytes(self._bytes[self._pos:self._pos+size])
return part
def empty(self):
"""Remove all bytes from the buffer"""
- self._bytes = b''
+ self._bytes = bytearray()
self._pos = 0
def fill(self, source, size=-1):
@@ -151,7 +151,7 @@ class ByteBuffer(object):
if hasattr(source, 'read'):
new_bytes = source.read(size)
else:
- new_bytes = b''
+ new_bytes = bytearray()
for more_bytes in source:
new_bytes += more_bytes
if len(new_bytes) >= size:
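The `bytes` to `bytearray` switch is what fixes the quadratic behavior noted in the 6.2.0 changelog entry above: `bytes` is immutable, so every `+=` reallocates and copies the entire accumulated buffer, whereas `bytearray` appends in place in amortized constant time. A minimal self-contained comparison (chunk sizes and counts are illustrative):

```python
import time

def append_chunks(buf, n=2_000, chunk=b'X' * 1000):
    start = time.perf_counter()
    for _ in range(n):
        buf += chunk  # O(len(buf)) copy per iteration for bytes; amortized O(1) for bytearray
    return time.perf_counter() - start

print('bytes:    ', append_chunks(b''))          # quadratic overall
print('bytearray:', append_chunks(bytearray()))  # linear overall
```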
diff --git a/smart_open/compression.py b/smart_open/compression.py
index ac66d62..08469e7 100644
--- a/smart_open/compression.py
+++ b/smart_open/compression.py
@@ -62,6 +62,7 @@ def register_compressor(ext, callback):
"""
if not (ext and ext[0] == '.'):
raise ValueError('ext must be a string starting with ., not %r' % ext)
+ ext = ext.lower()
if ext in _COMPRESSOR_REGISTRY:
logger.warning('overriding existing compression handler for %r', ext)
_COMPRESSOR_REGISTRY[ext] = callback
@@ -103,23 +104,23 @@ def _handle_gzip(file_obj, mode):
return result
-def compression_wrapper(file_obj, mode, compression):
+def compression_wrapper(file_obj, mode, compression=INFER_FROM_EXTENSION, filename=None):
"""
- This function will wrap the file_obj with an appropriate
- [de]compression mechanism based on the specified extension.
+ Wrap `file_obj` with an appropriate [de]compression mechanism based on its file extension.
- file_obj must either be a filehandle object, or a class which behaves
- like one. It must have a .name attribute.
+ If the filename extension isn't recognized, simply return the original `file_obj` unchanged.
- If the filename extension isn't recognized, will simply return the original
- file_obj.
+ `file_obj` must either be a filehandle object, or a class which behaves like one.
+
+ If `filename` is specified, it will be used to extract the extension.
+ If not, the `file_obj.name` attribute is used as the filename.
"""
if compression == NO_COMPRESSION:
return file_obj
elif compression == INFER_FROM_EXTENSION:
try:
- filename = file_obj.name
+ filename = (filename or file_obj.name).lower()
except (AttributeError, TypeError):
logger.warning(
'unable to transparently decompress %r because it '
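Because extensions are now lower-cased both at registration and at lookup, matching is case-insensitive. A hedged sketch of a custom compressor under the new behavior, adapted from the `register_compressor` example in the project README:

```python
import lzma
import smart_open

def _handle_xz(file_obj, mode):
    return lzma.LZMAFile(filename=file_obj, mode=mode, format=lzma.FORMAT_XZ)

smart_open.register_compressor('.xz', _handle_xz)

# With case-insensitive matching, an upper-case extension resolves to the
# same handler as a lower-case one:
with smart_open.open('example.XZ', 'wb') as fout:
    fout.write('hello world'.encode('utf-8'))
with smart_open.open('example.XZ', 'rb') as fin:
    assert fin.read() == b'hello world'
```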
diff --git a/smart_open/ftp.py b/smart_open/ftp.py
new file mode 100644
index 0000000..3dbe26f
--- /dev/null
+++ b/smart_open/ftp.py
@@ -0,0 +1,160 @@
+# -*- coding: utf-8 -*-
+#
+# Copyright (C) 2019 Radim Rehurek <me@radimrehurek.com>
+#
+# This code is distributed under the terms and conditions
+# from the MIT License (MIT).
+#
+
+"""Implements I/O streams over FTP.
+"""
+
+import logging
+import urllib.parse
+import smart_open.utils
+from ftplib import FTP, FTP_TLS, error_reply
+import types
+logger = logging.getLogger(__name__)
+
+SCHEMES = ("ftp", "ftps")
+
+"""Supported URL schemes."""
+
+DEFAULT_PORT = 21
+
+URI_EXAMPLES = (
+ "ftp://username@host/path/file",
+ "ftp://username:password@host/path/file",
+ "ftp://username:password@host:port/path/file",
+ "ftps://username@host/path/file",
+ "ftps://username:password@host/path/file",
+ "ftps://username:password@host:port/path/file",
+)
+
+
+def _unquote(text):
+ return text and urllib.parse.unquote(text)
+
+
+def parse_uri(uri_as_string):
+ split_uri = urllib.parse.urlsplit(uri_as_string)
+ assert split_uri.scheme in SCHEMES
+ return dict(
+ scheme=split_uri.scheme,
+ uri_path=_unquote(split_uri.path),
+ user=_unquote(split_uri.username),
+ host=split_uri.hostname,
+ port=int(split_uri.port or DEFAULT_PORT),
+ password=_unquote(split_uri.password),
+ )
+
+
+def open_uri(uri, mode, transport_params):
+ smart_open.utils.check_kwargs(open, transport_params)
+ parsed_uri = parse_uri(uri)
+ uri_path = parsed_uri.pop("uri_path")
+ scheme = parsed_uri.pop("scheme")
+ secure_conn = True if scheme == "ftps" else False
+ return open(uri_path, mode, secure_connection=secure_conn,
+ transport_params=transport_params, **parsed_uri)
+
+
+def convert_transport_params_to_args(transport_params):
+ supported_keywords = [
+ "timeout",
+ "source_address",
+ "encoding",
+ ]
+ unsupported_keywords = [k for k in transport_params if k not in supported_keywords]
+ kwargs = {k: v for (k, v) in transport_params.items() if k in supported_keywords}
+
+ if unsupported_keywords:
+ logger.warning(
+ "ignoring unsupported ftp keyword arguments: %r", unsupported_keywords
+ )
+
+ return kwargs
+
+
+def _connect(hostname, username, port, password, secure_connection, transport_params):
+ kwargs = convert_transport_params_to_args(transport_params)
+ if secure_connection:
+ ftp = FTP_TLS(**kwargs)
+ else:
+ ftp = FTP(**kwargs)
+ try:
+ ftp.connect(hostname, port)
+ except Exception as e:
+ logger.error("Unable to connect to FTP server: try checking the host and port!")
+ raise e
+ try:
+ ftp.login(username, password)
+ except error_reply as e:
+ logger.error("Unable to login to FTP server: try checking the username and password!")
+ raise e
+ if secure_connection:
+ ftp.prot_p()
+ return ftp
+
+
+def open(
+ path,
+ mode="r",
+ host=None,
+ user=None,
+ password=None,
+ port=DEFAULT_PORT,
+ secure_connection=False,
+ transport_params=None,
+):
+ """Open a file for reading or writing via FTP/FTPS.
+
+ Parameters
+ ----------
+ path: str
+ The path on the remote server
+ mode: str
+ Must be "rb" or "wb"
+ host: str
+ The host to connect to
+ user: str
+ The username to use for the connection
+ password: str
+ The password for the specified username
+ port: int
+ The port to connect to
+ secure_connection: bool
+ True for FTPS, False for FTP
+ transport_params: dict
+ Additional parameters for the FTP connection.
+ Currently supported parameters: timeout, source_address, encoding.
+ """
+ if not host:
+ raise ValueError("you must specify the host to connect to")
+ if not user:
+ raise ValueError("you must specify the user")
+ if not transport_params:
+ transport_params = {}
+ conn = _connect(host, user, port, password, secure_connection, transport_params)
+ mode_to_ftp_cmds = {
+ "rb": ("RETR", "rb"),
+ "wb": ("STOR", "wb"),
+ "ab": ("APPE", "wb"),
+ }
+ try:
+ ftp_mode, file_obj_mode = mode_to_ftp_cmds[mode]
+ except KeyError:
+ raise ValueError(f"unsupported mode: {mode!r}")
+ socket = conn.transfercmd(f"{ftp_mode} {path}")
+ fobj = socket.makefile(file_obj_mode)
+
+ def full_close(self):
+ self.orig_close()
+ self.socket.close()
+ self.conn.close()
+ fobj.orig_close = fobj.close
+ fobj.socket = socket
+ fobj.conn = conn
+ fobj.close = types.MethodType(full_close, fobj)
+ return fobj
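A hedged usage sketch of the new transport; the credentials, host, and path below mirror the `URI_EXAMPLES` and integration tests rather than any real server:

```python
import smart_open

# Write then read back over FTP; use an ftps:// URL for TLS (FTP_TLS).
with smart_open.open('ftp://user:123@localhost:21/file.txt', 'wb') as fout:
    fout.write(b'hello from smart_open\n')

# transport_params may carry timeout, source_address, and encoding;
# anything else is ignored with a warning (see convert_transport_params_to_args).
with smart_open.open('ftp://user:123@localhost:21/file.txt', 'rb',
                     transport_params={'timeout': 30}) as fin:
    print(fin.read())
```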
diff --git a/smart_open/gcs.py b/smart_open/gcs.py
index 1baa4b4..1f20ea6 100644
--- a/smart_open/gcs.py
+++ b/smart_open/gcs.py
@@ -5,11 +5,10 @@
# This code is distributed under the terms and conditions
# from the MIT License (MIT).
#
-
"""Implements file-like objects for reading and writing to/from GCS."""
-import io
import logging
+import warnings
try:
import google.cloud.exceptions
@@ -25,70 +24,13 @@ from smart_open import constants
logger = logging.getLogger(__name__)
-_BINARY_TYPES = (bytes, bytearray, memoryview)
-"""Allowed binary buffer types for writing to the underlying GCS stream"""
-
-_UNKNOWN = '*'
-
SCHEME = "gs"
"""Supported scheme for GCS"""
-_MIN_MIN_PART_SIZE = _REQUIRED_CHUNK_MULTIPLE = 256 * 1024
-"""Google requires you to upload in multiples of 256 KB, except for the last part."""
-
_DEFAULT_MIN_PART_SIZE = 50 * 1024**2
"""Default minimum part size for GCS multipart uploads"""
-DEFAULT_BUFFER_SIZE = 256 * 1024
-"""Default buffer size for working with GCS"""
-
-_UPLOAD_INCOMPLETE_STATUS_CODES = (308, )
-_UPLOAD_COMPLETE_STATUS_CODES = (200, 201)
-
-
-def _make_range_string(start, stop=None, end=None):
- #
- # GCS seems to violate RFC-2616 (see utils.make_range_string), so we
- # need a separate implementation.
- #
- # https://cloud.google.com/storage/docs/xml-api/resumable-upload#step_3upload_the_file_blocks
- #
- if end is None:
- end = _UNKNOWN
- if stop is None:
- return 'bytes %d-/%s' % (start, end)
- return 'bytes %d-%d/%s' % (start, stop, end)
-
-
-class UploadFailedError(Exception):
- def __init__(self, message, status_code, text):
- """Raise when a multi-part upload to GCS returns a failed response status code.
-
- Parameters
- ----------
- message: str
- The error message to display.
- status_code: int
- The status code returned from the upload response.
- text: str
- The text returned from the upload response.
-
- """
- super(UploadFailedError, self).__init__(message)
- self.status_code = status_code
- self.text = text
-
-
-def _fail(response, part_num, content_length, total_size, headers):
- status_code = response.status_code
- response_text = response.text
- total_size_gb = total_size / 1024.0 ** 3
-
- msg = (
- "upload failed (status code: %(status_code)d, response text: %(response_text)s), "
- "part #%(part_num)d, %(total_size)d bytes (total %(total_size_gb).3fGB), headers: %(headers)r"
- ) % locals()
- raise UploadFailedError(msg, response.status_code, response.text)
+_DEFAULT_WRITE_OPEN_KWARGS = {'ignore_flush': True}
def parse_uri(uri_as_string):
@@ -105,15 +47,21 @@ def open_uri(uri, mode, transport_params):
return open(parsed_uri['bucket_id'], parsed_uri['blob_id'], mode, **kwargs)
+def warn_deprecated(parameter_name):
+ message = f"Parameter {parameter_name} is deprecated, this parameter no-longer has any effect"
+ warnings.warn(message, UserWarning)
+
+
def open(
- bucket_id,
- blob_id,
- mode,
- buffer_size=DEFAULT_BUFFER_SIZE,
- min_part_size=_MIN_MIN_PART_SIZE,
- client=None, # type: google.cloud.storage.Client
- blob_properties=None
- ):
+ bucket_id,
+ blob_id,
+ mode,
+ buffer_size=None,
+ min_part_size=_DEFAULT_MIN_PART_SIZE,
+ client=None, # type: google.cloud.storage.Client
+ blob_properties=None,
+ blob_open_kwargs=None,
+):
"""Open an GCS blob for reading or writing.
Parameters
@@ -123,476 +71,102 @@ def open(
blob_id: str
The name of the blob within the bucket.
mode: str
- The mode for opening the object. Must be either "rb" or "wb".
- buffer_size: int, optional
- The buffer size to use when performing I/O. For reading only.
+ The mode for opening the object. Must be either "rb" or "wb".
+ buffer_size:
+ deprecated
min_part_size: int, optional
- The minimum part size for multipart uploads. For writing only.
+ The minimum part size for multipart uploads. For writing only.
client: google.cloud.storage.Client, optional
The GCS client to use when working with google-cloud-storage.
blob_properties: dict, optional
- Set properties on blob before writing. For writing only.
+ Set properties on blob before writing. For writing only.
+ blob_open_kwargs: dict, optional
+ Additional keyword arguments to propagate to the blob.open method
+ of the google-cloud-storage library.
"""
- if mode == constants.READ_BINARY:
- fileobj = Reader(
- bucket_id,
- blob_id,
- buffer_size=buffer_size,
- line_terminator=constants.BINARY_NEWLINE,
- client=client,
- )
- elif mode == constants.WRITE_BINARY:
- fileobj = Writer(
- bucket_id,
- blob_id,
- min_part_size=min_part_size,
- client=client,
- blob_properties=blob_properties,
- )
+ if blob_open_kwargs is None:
+ blob_open_kwargs = {}
+
+ if buffer_size is not None:
+ warn_deprecated('buffer_size')
+
+ if mode in (constants.READ_BINARY, 'r', 'rt'):
+ _blob = Reader(bucket=bucket_id,
+ key=blob_id,
+ client=client,
+ blob_open_kwargs=blob_open_kwargs)
+
+ elif mode in (constants.WRITE_BINARY, 'w', 'wt'):
+ _blob = Writer(bucket=bucket_id,
+ blob=blob_id,
+ min_part_size=min_part_size,
+ client=client,
+ blob_properties=blob_properties,
+ blob_open_kwargs=blob_open_kwargs)
+
else:
- raise NotImplementedError('GCS support for mode %r not implemented' % mode)
+ raise NotImplementedError(f'GCS support for mode {mode} not implemented')
- fileobj.name = blob_id
- return fileobj
+ return _blob
-class _RawReader(object):
- """Read an GCS object."""
+def Reader(bucket,
+ key,
+ buffer_size=None,
+ line_terminator=None,
+ client=None,
+ blob_open_kwargs=None):
- def __init__(self, gcs_blob, size):
- # type: (google.cloud.storage.Blob, int) -> None
- self._blob = gcs_blob
- self._size = size
- self._position = 0
+ if blob_open_kwargs is None:
+ blob_open_kwargs = {}
+ if client is None:
+ client = google.cloud.storage.Client()
+ if buffer_size is not None:
+ warn_deprecated('buffer_size')
+ if line_terminator is not None:
+ warn_deprecated('line_terminator')
- def seek(self, position):
- """Seek to the specified position (byte offset) in the GCS key.
+ bkt = client.bucket(bucket)
+ blob = bkt.get_blob(key)
- :param int position: The byte offset from the beginning of the key.
+ if blob is None:
+ raise google.cloud.exceptions.NotFound(f'blob {key} not found in {bucket}')
- Returns the position after seeking.
- """
- self._position = position
- return self._position
+ return blob.open('rb', **blob_open_kwargs)
- def read(self, size=-1):
- if self._position >= self._size:
- return b''
- binary = self._download_blob_chunk(size)
- self._position += len(binary)
- return binary
- def _download_blob_chunk(self, size):
- start = position = self._position
- if position == self._size:
- #
- # When reading, we can't seek to the first byte of an empty file.
- # Similarly, we can't seek past the last byte. Do nothing here.
- #
- binary = b''
- elif size == -1:
- binary = self._blob.download_as_bytes(start=start)
- else:
- end = position + size
- binary = self._blob.download_as_bytes(start=start, end=end)
- return binary
+def Writer(bucket,
+ blob,
+ min_part_size=None,
+ client=None,
+ blob_properties=None,
+ blob_open_kwargs=None):
+ if blob_open_kwargs is None:
+ blob_open_kwargs = {}
+ if blob_properties is None:
+ blob_properties = {}
+ if client is None:
+ client = google.cloud.storage.Client()
-class Reader(io.BufferedIOBase):
- """Reads bytes from GCS.
+ blob_open_kwargs = {**_DEFAULT_WRITE_OPEN_KWARGS, **blob_open_kwargs}
- Implements the io.BufferedIOBase interface of the standard library.
+ g_bucket = client.bucket(bucket)
+ if not g_bucket.exists():
+ raise google.cloud.exceptions.NotFound(f'bucket {bucket} not found')
- :raises google.cloud.exceptions.NotFound: Raised when the blob to read from does not exist.
+ g_blob = g_bucket.blob(
+ blob,
+ chunk_size=min_part_size,
+ )
- """
- def __init__(
- self,
- bucket,
- key,
- buffer_size=DEFAULT_BUFFER_SIZE,
- line_terminator=constants.BINARY_NEWLINE,
- client=None, # type: google.cloud.storage.Client
- ):
- if client is None:
- client = google.cloud.storage.Client()
-
- self._blob = client.bucket(bucket).get_blob(key) # type: google.cloud.storage.Blob
-
- if self._blob is None:
- raise google.cloud.exceptions.NotFound('blob %s not found in %s' % (key, bucket))
-
- self._size = self._blob.size if self._blob.size is not None else 0
-
- self._raw_reader = _RawReader(self._blob, self._size)
- self._current_pos = 0
- self._current_part_size = buffer_size
- self._current_part = smart_open.bytebuffer.ByteBuffer(buffer_size)
- self._eof = False
- self._line_terminator = line_terminator
-
- #
- # This member is part of the io.BufferedIOBase interface.
- #
- self.raw = None
-
- #
- # Override some methods from io.IOBase.
- #
- def close(self):
- """Flush and close this stream."""
- logger.debug("close: called")
- self._blob = None
- self._current_part = None
- self._raw_reader = None
-
- def readable(self):
- """Return True if the stream can be read from."""
- return True
-
- def seekable(self):
- """If False, seek(), tell() and truncate() will raise IOError.
-
- We offer only seek support, and no truncate support."""
- return True
-
- #
- # io.BufferedIOBase methods.
- #
- def detach(self):
- """Unsupported."""
- raise io.UnsupportedOperation
-
- def seek(self, offset, whence=constants.WHENCE_START):
- """Seek to the specified position.
-
- :param int offset: The offset in bytes.
- :param int whence: Where the offset is from.
-
- Returns the position after seeking."""
- logger.debug('seeking to offset: %r whence: %r', offset, whence)
- if whence not in constants.WHENCE_CHOICES:
- raise ValueError('invalid whence, expected one of %r' % constants.WHENCE_CHOICES)
-
- if whence == constants.WHENCE_START:
- new_position = offset
- elif whence == constants.WHENCE_CURRENT:
- new_position = self._current_pos + offset
- else:
- new_position = self._size + offset
- new_position = smart_open.utils.clamp(new_position, 0, self._size)
- self._current_pos = new_position
- self._raw_reader.seek(new_position)
- logger.debug('current_pos: %r', self._current_pos)
-
- self._current_part.empty()
- self._eof = self._current_pos == self._size
- return self._current_pos
-
- def tell(self):
- """Return the current position within the file."""
- return self._current_pos
-
- def truncate(self, size=None):
- """Unsupported."""
- raise io.UnsupportedOperation
-
- def read(self, size=-1):
- """Read up to size bytes from the object and return them."""
- if size == 0:
- return b''
- elif size < 0:
- self._current_pos = self._size
- return self._read_from_buffer() + self._raw_reader.read()
-
- #
- # Return unused data first
- #
- if len(self._current_part) >= size:
- return self._read_from_buffer(size)
-
- #
- # If the stream is finished, return what we have.
- #
- if self._eof:
- return self._read_from_buffer()
-
- #
- # Fill our buffer to the required size.
- #
- self._fill_buffer(size)
- return self._read_from_buffer(size)
-
- def read1(self, size=-1):
- """This is the same as read()."""
- return self.read(size=size)
-
- def readinto(self, b):
- """Read up to len(b) bytes into b, and return the number of bytes
- read."""
- data = self.read(len(b))
- if not data:
- return 0
- b[:len(data)] = data
- return len(data)
-
- def readline(self, limit=-1):
- """Read up to and including the next newline. Returns the bytes read."""
- if limit != -1:
- raise NotImplementedError('limits other than -1 not implemented yet')
- the_line = io.BytesIO()
- while not (self._eof and len(self._current_part) == 0):
- #
- # In the worst case, we're reading the unread part of self._current_part
- # twice here, once in the if condition and once when calling index.
- #
- # This is sub-optimal, but better than the alternative: wrapping
- # .index in a try..except, because that is slower.
- #
- remaining_buffer = self._current_part.peek()
- if self._line_terminator in remaining_buffer:
- next_newline = remaining_buffer.index(self._line_terminator)
- the_line.write(self._read_from_buffer(next_newline + 1))
- break
- else:
- the_line.write(self._read_from_buffer())
- self._fill_buffer()
- return the_line.getvalue()
-
- #
- # Internal methods.
- #
- def _read_from_buffer(self, size=-1):
- """Remove at most size bytes from our buffer and return them."""
- # logger.debug('reading %r bytes from %r byte-long buffer', size, len(self._current_part))
- size = size if size >= 0 else len(self._current_part)
- part = self._current_part.read(size)
- self._current_pos += len(part)
- # logger.debug('part: %r', part)
- return part
-
- def _fill_buffer(self, size=-1):
- size = size if size >= 0 else self._current_part._chunk_size
- while len(self._current_part) < size and not self._eof:
- bytes_read = self._current_part.fill(self._raw_reader)
- if bytes_read == 0:
- logger.debug('reached EOF while filling buffer')
- self._eof = True
-
- def __str__(self):
- return "(%s, %r, %r)" % (self.__class__.__name__, self._blob.bucket.name, self._blob.name)
-
- def __repr__(self):
- return "%s(bucket=%r, blob=%r, buffer_size=%r)" % (
- self.__class__.__name__, self._blob.bucket.name, self._blob.name, self._current_part_size,
- )
-
-
-class Writer(io.BufferedIOBase):
- """Writes bytes to GCS.
-
- Implements the io.BufferedIOBase interface of the standard library."""
-
- def __init__(
- self,
- bucket,
- blob,
- min_part_size=_DEFAULT_MIN_PART_SIZE,
- client=None, # type: google.cloud.storage.Client
- blob_properties=None,
- ):
- if client is None:
- client = google.cloud.storage.Client()
- self._client = client
- self._blob = self._client.bucket(bucket).blob(blob) # type: google.cloud.storage.Blob
- assert min_part_size % _REQUIRED_CHUNK_MULTIPLE == 0, 'min part size must be a multiple of 256KB'
- assert min_part_size >= _MIN_MIN_PART_SIZE, 'min part size must be greater than 256KB'
- self._min_part_size = min_part_size
-
- self._total_size = 0
- self._total_parts = 0
- self._bytes_uploaded = 0
- self._current_part = io.BytesIO()
-
- self._session = google.auth.transport.requests.AuthorizedSession(client._credentials)
-
- if blob_properties:
- for k, v in blob_properties.items():
- setattr(self._blob, k, v)
-
- #
- # https://cloud.google.com/storage/docs/json_api/v1/how-tos/resumable-upload#start-resumable
- #
- self._resumable_upload_url = self._blob.create_resumable_upload_session()
-
- #
- # This member is part of the io.BufferedIOBase interface.
- #
- self.raw = None
-
- def flush(self):
- pass
-
- #
- # Override some methods from io.IOBase.
- #
- def close(self):
- logger.debug("closing")
- if not self.closed:
- if self._total_size == 0: # empty files
- self._upload_empty_part()
- else:
- self._upload_part(is_last=True)
- self._client = None
- logger.debug("successfully closed")
-
- @property
- def closed(self):
- return self._client is None
-
- def writable(self):
- """Return True if the stream supports writing."""
- return True
-
- def seekable(self):
- """If False, seek(), tell() and truncate() will raise IOError.
-
- We offer only tell support, and no seek or truncate support."""
- return True
-
- def seek(self, offset, whence=constants.WHENCE_START):
- """Unsupported."""
- raise io.UnsupportedOperation
-
- def truncate(self, size=None):
- """Unsupported."""
- raise io.UnsupportedOperation
-
- def tell(self):
- """Return the current stream position."""
- return self._total_size
-
- #
- # io.BufferedIOBase methods.
- #
- def detach(self):
- raise io.UnsupportedOperation("detach() not supported")
-
- def write(self, b):
- """Write the given bytes (binary string) to the GCS file.
-
- There's buffering happening under the covers, so this may not actually
- do any HTTP transfer right away."""
-
- if not isinstance(b, _BINARY_TYPES):
- raise TypeError("input must be one of %r, got: %r" % (_BINARY_TYPES, type(b)))
-
- self._current_part.write(b)
- self._total_size += len(b)
-
- #
- # If the size of this part is precisely equal to the minimum part size,
- # we don't perform the actual write now, and wait until we see more data.
- # We do this because the very last part of the upload must be handled slightly
- # differently (see comments in the _upload_part method).
- #
- if self._current_part.tell() > self._min_part_size:
- self._upload_part()
-
- return len(b)
-
- def terminate(self):
- """Cancel the underlying resumable upload."""
- #
- # https://cloud.google.com/storage/docs/xml-api/resumable-upload#example_cancelling_an_upload
- #
- self._session.delete(self._resumable_upload_url)
-
- #
- # Internal methods.
- #
- def _upload_part(self, is_last=False):
- part_num = self._total_parts + 1
-
- #
- # Here we upload the largest amount possible given GCS's restriction
- # of parts being multiples of 256kB, except for the last one.
- #
- # A final upload of 0 bytes does not work, so we need to guard against
- # this edge case. This results in occasionally keeping an additional
- # 256kB in the buffer after uploading a part, but until this is fixed
- # on Google's end there is no other option.
- #
- # https://stackoverflow.com/questions/60230631/upload-zero-size-final-part-to-google-cloud-storage-resumable-upload
- #
- content_length = self._current_part.tell()
- remainder = content_length % self._min_part_size
- if is_last:
- end = self._bytes_uploaded + content_length
- elif remainder == 0:
- content_length -= _REQUIRED_CHUNK_MULTIPLE
- end = None
- else:
- content_length -= remainder
- end = None
-
- range_stop = self._bytes_uploaded + content_length - 1
- content_range = _make_range_string(self._bytes_uploaded, range_stop, end=end)
- headers = {
- 'Content-Length': str(content_length),
- 'Content-Range': content_range,
- }
- logger.info(
- "uploading part #%i, %i bytes (total %.3fGB) headers %r",
- part_num, content_length, range_stop / 1024.0 ** 3, headers,
- )
- self._current_part.seek(0)
- response = self._session.put(
- self._resumable_upload_url,
- data=self._current_part.read(content_length),
- headers=headers,
- )
-
- if is_last:
- expected = _UPLOAD_COMPLETE_STATUS_CODES
- else:
- expected = _UPLOAD_INCOMPLETE_STATUS_CODES
- if response.status_code not in expected:
- _fail(response, part_num, content_length, self._total_size, headers)
- logger.debug("upload of part #%i finished" % part_num)
-
- self._total_parts += 1
- self._bytes_uploaded += content_length
-
- #
- # For the last part, the below _current_part handling is a NOOP.
- #
- self._current_part = io.BytesIO(self._current_part.read())
- self._current_part.seek(0, io.SEEK_END)
-
- def _upload_empty_part(self):
- logger.debug("creating empty file")
- headers = {'Content-Length': '0'}
- response = self._session.put(self._resumable_upload_url, headers=headers)
- if response.status_code not in _UPLOAD_COMPLETE_STATUS_CODES:
- _fail(response, self._total_parts + 1, 0, self._total_size, headers)
-
- self._total_parts += 1
-
- def __enter__(self):
- return self
-
- def __exit__(self, exc_type, exc_val, exc_tb):
- if exc_type is not None:
- self.terminate()
- else:
- self.close()
-
- def __str__(self):
- return "(%s, %r, %r)" % (self.__class__.__name__, self._blob.bucket.name, self._blob.name)
-
- def __repr__(self):
- return "%s(bucket=%r, blob=%r, min_part_size=%r)" % (
- self.__class__.__name__, self._blob.bucket.name, self._blob.name, self._min_part_size,
- )
+ for k, v in blob_properties.items():
+ setattr(g_blob, k, v)
+
+ _blob = g_blob.open('wb', **blob_open_kwargs)
+
+    # backwards-compatibility: terminate() is kept as a no-op; it was deprecated upstream, see https://cloud.google.com/storage/docs/resumable-uploads
+ _blob.terminate = lambda: None
+
+ return _blob
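
For illustration, the rewritten GCS submodule now delegates all I/O to google-cloud-storage's own Blob.open(), so extra behaviour is passed through via the blob_properties and blob_open_kwargs transport parameters. A minimal sketch (the bucket name is hypothetical; content_type mirrors the property-passthrough test further down in this diff):

    import smart_open

    # forwarded to setattr() on the blob before blob.open('wb', ...) is called
    tp = {'blob_properties': {'content_type': 'text/plain'}}
    with smart_open.open('gs://my-bucket/example.txt', 'wb', transport_params=tp) as fout:
        fout.write(b'hello, GCS')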
diff --git a/smart_open/hdfs.py b/smart_open/hdfs.py
index a4d892c..a247d3e 100644
--- a/smart_open/hdfs.py
+++ b/smart_open/hdfs.py
@@ -23,24 +23,26 @@ from smart_open import utils
logger = logging.getLogger(__name__)
-SCHEME = 'hdfs'
+SCHEMES = ('hdfs', 'viewfs')
URI_EXAMPLES = (
'hdfs:///path/file',
'hdfs://path/file',
+ 'viewfs:///path/file',
+ 'viewfs://path/file',
)
def parse_uri(uri_as_string):
split_uri = urllib.parse.urlsplit(uri_as_string)
- assert split_uri.scheme == SCHEME
+ assert split_uri.scheme in SCHEMES
uri_path = split_uri.netloc + split_uri.path
uri_path = "/" + uri_path.lstrip("/")
if not uri_path:
raise RuntimeError("invalid HDFS URI: %r" % uri_as_string)
- return dict(scheme=SCHEME, uri_path=uri_path)
+ return dict(scheme=split_uri.scheme, uri_path=uri_path)
def open_uri(uri, mode, transport_params):
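
Since parse_uri() now preserves whichever scheme it was given, viewfs URIs route through the same CLI reader as hdfs ones. The result below follows directly from the function above:

    from smart_open.hdfs import parse_uri

    print(parse_uri('viewfs://nameservice/path/file'))
    # {'scheme': 'viewfs', 'uri_path': '/nameservice/path/file'}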
diff --git a/smart_open/http.py b/smart_open/http.py
index e4439bd..7bbbe6f 100644
--- a/smart_open/http.py
+++ b/smart_open/http.py
@@ -49,7 +49,8 @@ def open_uri(uri, mode, transport_params):
return open(uri, mode, **kwargs)
-def open(uri, mode, kerberos=False, user=None, password=None, headers=None, timeout=None):
+def open(uri, mode, kerberos=False, user=None, password=None, cert=None,
+ headers=None, timeout=None, buffer_size=DEFAULT_BUFFER_SIZE):
"""Implement streamed reader from a web site.
Supports Kerberos and Basic HTTP authentication.
@@ -66,10 +67,14 @@ def open(uri, mode, kerberos=False, user=None, password=None, headers=None, time
The username for authenticating over HTTP
password: str, optional
The password for authenticating over HTTP
+ cert: str/tuple, optional
+        If a string, the path to an SSL client cert file (.pem). If a tuple, a ('cert', 'key') pair.
headers: dict, optional
Any headers to send in the request. If ``None``, the default headers are sent:
``{'Accept-Encoding': 'identity'}``. To use no headers at all,
set this variable to an empty dict, ``{}``.
+ buffer_size: int, optional
+ The buffer size to use when performing I/O.
Note
----
@@ -79,8 +84,9 @@ def open(uri, mode, kerberos=False, user=None, password=None, headers=None, time
"""
if mode == constants.READ_BINARY:
fobj = SeekableBufferedInputBase(
- uri, mode, kerberos=kerberos,
- user=user, password=password, headers=headers, timeout=timeout,
+ uri, mode, buffer_size=buffer_size, kerberos=kerberos,
+ user=user, password=password, cert=cert,
+ headers=headers, timeout=timeout,
)
fobj.name = os.path.basename(urllib.parse.urlparse(uri).path)
return fobj
@@ -90,7 +96,8 @@ def open(uri, mode, kerberos=False, user=None, password=None, headers=None, time
class BufferedInputBase(io.BufferedIOBase):
def __init__(self, url, mode='r', buffer_size=DEFAULT_BUFFER_SIZE,
- kerberos=False, user=None, password=None, headers=None, timeout=None):
+ kerberos=False, user=None, password=None, cert=None,
+ headers=None, timeout=None):
if kerberos:
import requests_kerberos
auth = requests_kerberos.HTTPKerberosAuth()
@@ -112,6 +119,7 @@ class BufferedInputBase(io.BufferedIOBase):
self.response = requests.get(
url,
auth=auth,
+ cert=cert,
stream=True,
headers=self.headers,
timeout=self.timeout,
@@ -204,13 +212,15 @@ class BufferedInputBase(io.BufferedIOBase):
class SeekableBufferedInputBase(BufferedInputBase):
"""
Implement seekable streamed reader from a web site.
- Supports Kerberos and Basic HTTP authentication.
+ Supports Kerberos, client certificate and Basic HTTP authentication.
"""
def __init__(self, url, mode='r', buffer_size=DEFAULT_BUFFER_SIZE,
- kerberos=False, user=None, password=None, headers=None, timeout=None):
+ kerberos=False, user=None, password=None, cert=None,
+ headers=None, timeout=None):
"""
If Kerberos is True, will attempt to use the local Kerberos credentials.
+    If cert is set, will try to use a client certificate.
Otherwise, will try to use "basic" HTTP authentication via username/password.
If none of those are set, will connect unauthenticated.
@@ -230,6 +240,7 @@ class SeekableBufferedInputBase(BufferedInputBase):
else:
self.headers = headers
+ self.cert = cert
self.timeout = timeout
self.buffer_size = buffer_size
@@ -325,6 +336,7 @@ class SeekableBufferedInputBase(BufferedInputBase):
self.url,
auth=self.auth,
stream=True,
+ cert=self.cert,
headers=self.headers,
timeout=self.timeout,
)
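
The new cert parameter is handed straight to requests, so it accepts the same forms requests does: a path to a combined .pem file, or a (cert, key) tuple. A sketch with hypothetical paths:

    import smart_open

    tp = {'cert': ('/path/to/client.pem', '/path/to/client.key')}
    with smart_open.open('https://example.com/data.txt', 'rb', transport_params=tp) as fin:
        data = fin.read()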
diff --git a/smart_open/s3.py b/smart_open/s3.py
index b7470dd..b211163 100644
--- a/smart_open/s3.py
+++ b/smart_open/s3.py
@@ -1189,18 +1189,31 @@ def iter_bucket(
with smart_open.concurrency.create_pool(processes=workers) as pool:
result_iterator = pool.imap_unordered(download_key, key_iterator)
- for key_no, (key, content) in enumerate(result_iterator):
- if True or key_no % 1000 == 0:
- logger.info(
- "yielding key #%i: %s, size %i (total %.1fMB)",
- key_no, key, len(content), total_size / 1024.0 ** 2
- )
- yield key, content
- total_size += len(content)
-
- if key_limit is not None and key_no + 1 >= key_limit:
- # we were asked to output only a limited number of keys => we're done
+ key_no = 0
+ while True:
+ try:
+ (key, content) = result_iterator.__next__()
+ if key_no % 1000 == 0:
+ logger.info(
+ "yielding key #%i: %s, size %i (total %.1fMB)",
+ key_no, key, len(content), total_size / 1024.0 ** 2
+ )
+ yield key, content
+ total_size += len(content)
+ if key_limit is not None and key_no + 1 >= key_limit:
+ # we were asked to output only a limited number of keys => we're done
+ break
+ except botocore.exceptions.ClientError as err:
+ #
+ # ignore 404 not found errors: they mean the object was deleted
+ # after we listed the contents of the bucket, but before we
+ # downloaded the object.
+ #
+ if not ('Error' in err.response and err.response['Error'].get('Code') == '404'):
+ raise err
+ except StopIteration:
break
+ key_no += 1
logger.info("processed %i keys, total size %i" % (key_no + 1, total_size))
@@ -1240,7 +1253,7 @@ def _download_key(key_name, bucket_name=None, retries=3, **session_kwargs):
raise ValueError('bucket_name may not be None')
#
- # https://boto3.amazonaws.com/v1/documentation/api/latest/guide/resources.html#multithreading-and-multiprocessing
+ # https://boto3.amazonaws.com/v1/documentation/api/latest/guide/resources.html#multithreading-or-multiprocessing-with-resources
#
session = boto3.session.Session(**session_kwargs)
s3 = session.resource('s3')
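
With the reworked loop, iter_bucket() now skips objects that were deleted between listing and download instead of aborting the whole iteration. Calling code is unchanged; a sketch with a hypothetical bucket and prefix:

    from smart_open.s3 import iter_bucket

    for key, content in iter_bucket('my-bucket', prefix='logs/', workers=8):
        print(key, len(content))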
diff --git a/smart_open/smart_open_lib.py b/smart_open/smart_open_lib.py
index cd80f47..b827a22 100644
--- a/smart_open/smart_open_lib.py
+++ b/smart_open/smart_open_lib.py
@@ -106,8 +106,7 @@ def open(
newline=None,
closefd=True,
opener=None,
- ignore_ext=False,
- compression=None,
+ compression=so_compression.INFER_FROM_EXTENSION,
transport_params=None,
):
r"""Open the URI object, returning a file-like object.
@@ -138,11 +137,8 @@ def open(
Mimicks built-in open parameter of the same name. Ignored.
opener: object, optional
Mimicks built-in open parameter of the same name. Ignored.
- ignore_ext: boolean, optional
- Disable transparent compression/decompression based on the file extension.
compression: str, optional (see smart_open.compression.get_supported_compression_types)
Explicitly specify the compression/decompression behavior.
- If you specify this parameter, then ignore_ext must not be specified.
transport_params: dict, optional
Additional parameters for the transport layer (see notes below).
@@ -172,15 +168,8 @@ def open(
if not isinstance(mode, str):
raise TypeError('mode should be a string')
- if compression and ignore_ext:
- raise ValueError('ignore_ext and compression parameters are mutually exclusive')
- elif compression and compression not in so_compression.get_supported_compression_types():
+ if compression not in so_compression.get_supported_compression_types():
raise ValueError(f'invalid compression type: {compression}')
- elif ignore_ext:
- compression = so_compression.NO_COMPRESSION
- warnings.warn("'ignore_ext' will be deprecated in a future release", PendingDeprecationWarning)
- elif compression is None:
- compression = so_compression.INFER_FROM_EXTENSION
if transport_params is None:
transport_params = {}
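
With ignore_ext gone, callers opt out of extension-based (de)compression through the compression argument instead; the values exercised by this diff's tests include 'disable' and 'infer_from_extension' (the new default). For example:

    import smart_open

    # read a .gz file as raw bytes, without transparent decompression
    with smart_open.open('example.txt.gz', 'rb', compression='disable') as fin:
        raw = fin.read()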
@@ -246,6 +235,19 @@ def open(
else:
decoded = decompressed
+ #
+ # There are some useful methods in the binary readers, e.g. to_boto3, that get
+ # hidden by the multiple layers of wrapping we just performed. Promote
+ # them so they are visible to the user.
+ #
+ if decoded != binary:
+ promoted_attrs = ['to_boto3']
+ for attr in promoted_attrs:
+ try:
+ setattr(decoded, attr, getattr(binary, attr))
+ except AttributeError:
+ pass
+
return decoded
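
The promotion above keeps transport-specific helpers such as to_boto3 reachable even after the binary stream has been wrapped in decompression and text layers. A sketch, assuming boto3 is installed, the bucket exists, and to_boto3 takes a boto3 resource as in recent smart_open releases:

    import boto3
    import smart_open

    with smart_open.open('s3://my-bucket/data.txt', 'r') as fin:  # text layer wraps the S3 reader
        obj = fin.to_boto3(boto3.Session().resource('s3'))  # promoted from the binary layer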
@@ -379,16 +381,16 @@ def _open_binary_stream(uri, mode, transport_params):
#
raise NotImplementedError('unsupported mode: %r' % mode)
- if hasattr(uri, 'read'):
- # simply pass-through if already a file-like
- # we need to return something as the file name, but we don't know what
- # so we probe for uri.name (e.g., this works with open() or tempfile.NamedTemporaryFile)
- # if the value ends with COMPRESSED_EXT, we will note it in compression_wrapper()
- # if there is no such an attribute, we return "unknown" - this
- # effectively disables any compression
- if not hasattr(uri, 'name'):
- uri.name = getattr(uri, 'name', 'unknown')
- return uri
+ if isinstance(uri, int):
+ #
+ # We're working with a file descriptor. If we open it, its name is
+ # just the integer value, which isn't helpful. Unfortunately, there's
+ # no easy cross-platform way to go from a file descriptor to the filename,
+ # so we just give up here. The user will have to handle their own
+ # compression, etc. explicitly.
+ #
+ fobj = _builtin_open(uri, mode, closefd=False)
+ return fobj
if not isinstance(uri, str):
raise TypeError("don't know how to handle uri %s" % repr(uri))
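
The file-like pass-through branch is replaced by explicit file-descriptor support: an int is opened with the builtin open(..., closefd=False), so the descriptor's lifetime stays with the caller. A sketch:

    import os
    import smart_open

    fd = os.open('example.txt', os.O_RDONLY)  # assumes the file exists
    with smart_open.open(fd, 'rb') as fin:
        print(fin.read())
    os.close(fd)  # closefd=False above, so we still own the descriptor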
@@ -481,7 +483,7 @@ def smart_open(
# For completeness, the main differences of the old smart_open function:
#
# 1. Default mode was read binary (mode='rb')
- # 2. ignore_ext parameter was called ignore_extension
+ # 2. compression parameter was called ignore_extension
# 3. Transport parameters were passed directly as kwargs
#
url = 'https://github.com/RaRe-Technologies/smart_open/blob/develop/MIGRATING_FROM_OLDER_VERSIONS.rst'
@@ -493,7 +495,10 @@ def smart_open(
message = 'This function is deprecated. See %s for more information' % url
warnings.warn(message, category=DeprecationWarning)
- ignore_ext = ignore_extension
+ if ignore_extension:
+ compression = so_compression.NO_COMPRESSION
+ else:
+ compression = so_compression.INFER_FROM_EXTENSION
del kwargs, url, message, ignore_extension
return open(**locals())
diff --git a/smart_open/ssh.py b/smart_open/ssh.py
index fa762eb..4b31893 100644
--- a/smart_open/ssh.py
+++ b/smart_open/ssh.py
@@ -25,7 +25,11 @@ Similarly, from a command line::
import getpass
import logging
import urllib.parse
-import warnings
+
+try:
+ import paramiko
+except ImportError:
+ MISSING_DEPS = True
import smart_open.utils
@@ -74,29 +78,15 @@ def open_uri(uri, mode, transport_params):
return open(uri_path, mode, transport_params=transport_params, **parsed_uri)
-def _connect(hostname, username, port, password, transport_params):
- try:
- import paramiko
- except ImportError:
- warnings.warn(
- 'paramiko missing, opening SSH/SCP/SFTP paths will be disabled. '
- '`pip install paramiko` to suppress'
- )
- raise
-
- key = (hostname, username)
- ssh = _SSH.get(key)
- if ssh is None:
- ssh = _SSH[key] = paramiko.client.SSHClient()
- ssh.load_system_host_keys()
- ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
- kwargs = transport_params.get('connect_kwargs', {}).copy()
- # if 'key_filename' is present in transport_params, then I do not
- # overwrite the credentials.
- if 'key_filename' not in kwargs:
- kwargs.setdefault('password', password)
- kwargs.setdefault('username', username)
- ssh.connect(hostname, port, **kwargs)
+def _connect_ssh(hostname, username, port, password, transport_params):
+ ssh = paramiko.SSHClient()
+ ssh.load_system_host_keys()
+ ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
+ kwargs = transport_params.get('connect_kwargs', {}).copy()
+ if 'key_filename' not in kwargs:
+ kwargs.setdefault('password', password)
+ kwargs.setdefault('username', username)
+ ssh.connect(hostname, port, **kwargs)
return ssh
@@ -142,8 +132,30 @@ def open(path, mode='r', host=None, user=None, password=None, port=DEFAULT_PORT,
if not transport_params:
transport_params = {}
- conn = _connect(host, user, port, password, transport_params)
- sftp_client = conn.get_transport().open_sftp_client()
+ key = (host, user)
+
+ attempts = 2
+ for attempt in range(attempts):
+ try:
+ ssh = _SSH[key]
+ except KeyError:
+ ssh = _SSH[key] = _connect_ssh(host, user, port, password, transport_params)
+
+ try:
+ transport = ssh.get_transport()
+ sftp_client = transport.open_sftp_client()
+ break
+ except paramiko.SSHException as ex:
+ connection_timed_out = ex.args and ex.args[0] == 'SSH session not active'
+ if attempt == attempts - 1 or not connection_timed_out:
+ raise
+
+ #
+ # Try again. Delete the connection from the cache to force a
+ # reconnect in the next attempt.
+ #
+ del _SSH[key]
+
fobj = sftp_client.open(path, mode)
fobj.name = path
return fobj
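
Connections are still cached per (host, user), but a stale cached session is now retried once with a fresh connection before the error propagates. Credentials flow through connect_kwargs as before; the host, user and key path below are hypothetical:

    import smart_open

    tp = {'connect_kwargs': {'key_filename': '/home/user/.ssh/id_rsa'}}
    with smart_open.open('ssh://user@host/remote/file.txt', 'rb', transport_params=tp) as fin:
        print(fin.read())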
diff --git a/smart_open/tests/test_azure.py b/smart_open/tests/test_azure.py
index 1cce381..3c668aa 100644
--- a/smart_open/tests/test_azure.py
+++ b/smart_open/tests/test_azure.py
@@ -51,10 +51,12 @@ class FakeBlobClient(object):
self.__contents = io.BytesIO()
self._staged_contents = {}
- def commit_block_list(self, block_list):
+ def commit_block_list(self, block_list, metadata=None):
data = b''.join([self._staged_contents[block_blob['id']] for block_blob in block_list])
self.__contents = io.BytesIO(data)
- self.set_blob_metadata(dict(size=len(data)))
+ metadata = metadata or {}
+ metadata.update({"size": len(data)})
+ self.set_blob_metadata(metadata)
self._container_client.register_blob_client(self)
def delete_blob(self):
@@ -526,6 +528,31 @@ class ReaderTest(unittest.TestCase):
self.assertEqual(data, content)
+ def test_read_container_client(self):
+ content = "spirits in the material world".encode("utf-8")
+ blob_name = "test_read_container_client_%s" % BLOB_NAME
+ put_to_container(blob_name, contents=content)
+
+ container_client = CLIENT.get_container_client(CONTAINER_NAME)
+
+ with smart_open.azure.Reader(CONTAINER_NAME, blob_name, container_client) as fin:
+ data = fin.read(100)
+
+ assert data == content
+
+ def test_read_blob_client(self):
+ content = "walking on the moon".encode("utf-8")
+ blob_name = "test_read_blob_client_%s" % BLOB_NAME
+ put_to_container(blob_name, contents=content)
+
+ container_client = CLIENT.get_container_client(CONTAINER_NAME)
+ blob_client = container_client.get_blob_client(blob_name)
+
+ with smart_open.azure.Reader(CONTAINER_NAME, blob_name, blob_client) as fin:
+ data = fin.read(100)
+
+ assert data == content
+
class WriterTest(unittest.TestCase):
"""Test writing into Azure Blob files."""
@@ -548,6 +575,49 @@ class WriterTest(unittest.TestCase):
))
self.assertEqual(output, [test_string])
+ def test_write_container_client(self):
+ """Does writing into Azure Blob Storage work correctly?"""
+ test_string = u"Hiszékeny Öngyilkos Vasárnap".encode('utf8')
+ blob_name = "test_write_container_client_%s" % BLOB_NAME
+
+ container_client = CLIENT.get_container_client(CONTAINER_NAME)
+
+ with smart_open.azure.Writer(CONTAINER_NAME, blob_name, container_client) as fout:
+ fout.write(test_string)
+
+ output = list(smart_open.open(
+ "azure://%s/%s" % (CONTAINER_NAME, blob_name),
+ "rb",
+ transport_params=dict(client=container_client),
+ ))
+ assert output == [test_string]
+
+ def test_write_blob_client(self):
+ """Does writing into Azure Blob Storage work correctly?"""
+ test_string = u"žluťoučký koníček".encode('utf8')
+ blob_name = "test_write_blob_client_%s" % BLOB_NAME
+
+ container_client = CLIENT.get_container_client(CONTAINER_NAME)
+ blob_client = container_client.get_blob_client(blob_name)
+
+ with smart_open.open(
+ "azure://%s/%s" % (CONTAINER_NAME, blob_name),
+ "wb",
+ transport_params={
+ "client": blob_client, "blob_kwargs": {"metadata": {"name": blob_name}}
+ },
+ ) as fout:
+ fout.write(test_string)
+
+ self.assertEqual(blob_client.get_blob_properties()["name"], blob_name)
+
+ output = list(smart_open.open(
+ "azure://%s/%s" % (CONTAINER_NAME, blob_name),
+ "rb",
+ transport_params=dict(client=CLIENT),
+ ))
+ self.assertEqual(output, [test_string])
+
def test_incorrect_input(self):
"""Does azure write fail on incorrect input?"""
blob_name = "test_incorrect_input_%s" % BLOB_NAME
diff --git a/smart_open/tests/test_compression.py b/smart_open/tests/test_compression.py
new file mode 100644
index 0000000..4b93f63
--- /dev/null
+++ b/smart_open/tests/test_compression.py
@@ -0,0 +1,39 @@
+# -*- coding: utf-8 -*-
+#
+# Copyright (C) 2020 Radim Rehurek <me@radimrehurek.com>
+#
+# This code is distributed under the terms and conditions
+# from the MIT License (MIT).
+#
+import io
+import gzip
+import pytest
+
+import smart_open.compression
+
+
+plain = 'доброе утро планета!'.encode()
+
+
+def label(thing, name):
+ setattr(thing, 'name', name)
+ return thing
+
+
+@pytest.mark.parametrize(
+ 'fileobj,compression,filename',
+ [
+ (io.BytesIO(plain), 'disable', None),
+ (io.BytesIO(plain), 'disable', ''),
+ (io.BytesIO(plain), 'infer_from_extension', 'file.txt'),
+ (io.BytesIO(plain), 'infer_from_extension', 'file.TXT'),
+ (io.BytesIO(plain), '.unknown', ''),
+ (io.BytesIO(gzip.compress(plain)), 'infer_from_extension', 'file.gz'),
+ (io.BytesIO(gzip.compress(plain)), 'infer_from_extension', 'file.GZ'),
+ (label(io.BytesIO(gzip.compress(plain)), 'file.gz'), 'infer_from_extension', ''),
+ (io.BytesIO(gzip.compress(plain)), '.gz', 'file.gz'),
+ ]
+)
+def test_compression_wrapper_read(fileobj, compression, filename):
+ wrapped = smart_open.compression.compression_wrapper(fileobj, 'rb', compression, filename)
+ assert wrapped.read() == plain
diff --git a/smart_open/tests/test_gcs.py b/smart_open/tests/test_gcs.py
index f528ca7..d63b20c 100644
--- a/smart_open/tests/test_gcs.py
+++ b/smart_open/tests/test_gcs.py
@@ -5,12 +5,10 @@
# This code is distributed under the terms and conditions
# from the MIT License (MIT).
#
-import gzip
-import inspect
+
import io
import logging
import os
-import time
import uuid
import unittest
from unittest import mock
@@ -28,13 +26,6 @@ BLOB_NAME = 'test-blob'
WRITE_BLOB_NAME = 'test-write-blob'
DISABLE_MOCKS = os.environ.get('SO_DISABLE_GCS_MOCKS') == "1"
-RESUMABLE_SESSION_URI_TEMPLATE = (
- 'https://www.googleapis.com/upload/storage/v1/b/'
- '%(bucket)s'
- '/o?uploadType=resumable&upload_id='
- '%(upload_id)s'
-)
-
logger = logging.getLogger(__name__)
@@ -54,8 +45,8 @@ class FakeBucket(object):
#
self.client.register_bucket(self)
- def blob(self, blob_id):
- return self.blobs.get(blob_id, FakeBlob(blob_id, self))
+ def blob(self, blob_id, **kwargs):
+ return self.blobs.get(blob_id, FakeBlob(blob_id, self, **kwargs))
def delete(self):
self.client.delete_bucket(self)
@@ -133,46 +124,28 @@ class FakeBucketTest(unittest.TestCase):
class FakeBlob(object):
- def __init__(self, name, bucket):
+ def __init__(self, name, bucket, **kwargs):
self.name = name
self._bucket = bucket # type: FakeBucket
self._exists = False
self.__contents = io.BytesIO()
-
+ self.__contents.close = lambda: None
self._create_if_not_exists()
- def create_resumable_upload_session(self):
- resumeable_upload_url = RESUMABLE_SESSION_URI_TEMPLATE % dict(
- bucket=self._bucket.name,
- upload_id=str(uuid.uuid4()),
- )
- upload = FakeBlobUpload(resumeable_upload_url, self)
- self._bucket.register_upload(upload)
- return resumeable_upload_url
+ self.open = mock.Mock(side_effect=self._mock_open)
+
+ def _mock_open(self, mode, *args, **kwargs):
+ if mode.startswith('r'):
+ self.__contents.seek(0)
+ return self.__contents
def delete(self):
self._bucket.delete_blob(self)
self._exists = False
- def download_as_bytes(self, start=0, end=None):
- # mimics Google's API by returning bytes
- # https://googleapis.dev/python/storage/latest/blobs.html#google.cloud.storage.blob.Blob.download_as_bytes
- if end is None:
- end = self.__contents.tell()
- self.__contents.seek(start)
- return self.__contents.read(end - start)
-
def exists(self, client=None):
return self._exists
- def upload_from_string(self, data):
- # mimics Google's API by accepting bytes or str, despite the method name
- # https://googleapis.dev/python/storage/latest/blobs.html#google.cloud.storage.blob.Blob.upload_from_string
- if isinstance(data, str):
- data = bytes(data, 'utf8')
- self.__contents = io.BytesIO(data)
- self.__contents.seek(0, io.SEEK_END)
-
def write(self, data):
self.upload_from_string(data)
@@ -191,52 +164,8 @@ class FakeBlob(object):
self._exists = True
-class FakeBlobTest(unittest.TestCase):
- def setUp(self):
- self.client = FakeClient()
- self.bucket = FakeBucket(self.client, 'test-bucket')
-
- def test_create_resumable_upload_session(self):
- blob = FakeBlob('fake-blob', self.bucket)
- resumable_upload_url = blob.create_resumable_upload_session()
- self.assertTrue(resumable_upload_url in self.client.uploads)
-
- def test_delete(self):
- blob = FakeBlob('fake-blob', self.bucket)
- blob.delete()
- self.assertFalse(blob.exists())
- self.assertEqual(self.bucket.list_blobs(), [])
-
- def test_upload_download(self):
- blob = FakeBlob('fake-blob', self.bucket)
- contents = b'test'
- blob.upload_from_string(contents)
- self.assertEqual(blob.download_as_bytes(), b'test')
- self.assertEqual(blob.download_as_bytes(start=2), b'st')
- self.assertEqual(blob.download_as_bytes(end=2), b'te')
- self.assertEqual(blob.download_as_bytes(start=2, end=3), b's')
-
- def test_size(self):
- blob = FakeBlob('fake-blob', self.bucket)
- self.assertEqual(blob.size, None)
- blob.upload_from_string(b'test')
- self.assertEqual(blob.size, 4)
-
-
-class FakeCredentials(object):
- def __init__(self, client):
- self.client = client # type: FakeClient
-
- def before_request(self, *args, **kwargs):
- pass
-
-
class FakeClient(object):
- def __init__(self, credentials=None):
- if credentials is None:
- credentials = FakeCredentials(self)
- self._credentials = credentials # type: FakeCredentials
- self.uploads = OrderedDict()
+ def __init__(self):
self.__buckets = OrderedDict()
def bucket(self, bucket_id):
@@ -291,617 +220,46 @@ class FakeClientTest(unittest.TestCase):
self.assertEqual(actual, bucket)
-class FakeBlobUpload(object):
- def __init__(self, url, blob):
- self.url = url
- self.blob = blob # type: FakeBlob
- self._finished = False
- self.__contents = io.BytesIO()
-
- def write(self, data):
- self.__contents.write(data)
-
- def finish(self):
- if not self._finished:
- self.__contents.seek(0)
- data = self.__contents.read()
- self.blob.upload_from_string(data)
- self._finished = True
-
- def terminate(self):
- self.blob.delete()
- self.__contents = None
-
-
-class FakeResponse(object):
- def __init__(self, status_code=200, text=None):
- self.status_code = status_code
- self.text = text
-
+def get_test_bucket(client):
+ return client.bucket(BUCKET_NAME)
-class FakeAuthorizedSession(object):
- def __init__(self, credentials):
- self._credentials = credentials # type: FakeCredentials
- def delete(self, upload_url):
- upload = self._credentials.client.uploads.pop(upload_url)
- upload.terminate()
-
- def put(self, url, data=None, headers=None):
- upload = self._credentials.client.uploads[url]
-
- if data is not None:
- if hasattr(data, 'read'):
- upload.write(data.read())
- else:
- upload.write(data)
- if not headers.get('Content-Range', '').endswith(smart_open.gcs._UNKNOWN):
- upload.finish()
- return FakeResponse(200)
- return FakeResponse(smart_open.gcs._UPLOAD_INCOMPLETE_STATUS_CODES[0])
-
- @staticmethod
- def _blob_with_url(url, client):
- # type: (str, FakeClient) -> FakeBlobUpload
- return client.uploads.get(url)
-
-
-class FakeAuthorizedSessionTest(unittest.TestCase):
- def setUp(self):
- self.client = FakeClient()
- self.credentials = FakeCredentials(self.client)
- self.session = FakeAuthorizedSession(self.credentials)
- self.bucket = FakeBucket(self.client, 'test-bucket')
- self.blob = FakeBlob('test-blob', self.bucket)
- self.upload_url = self.blob.create_resumable_upload_session()
-
- def test_delete(self):
- self.session.delete(self.upload_url)
- self.assertFalse(self.blob.exists())
- self.assertDictEqual(self.client.uploads, {})
-
- def test_unfinished_put_does_not_write_to_blob(self):
- data = io.BytesIO(b'test')
- headers = {
- 'Content-Range': 'bytes 0-3/*',
- 'Content-Length': str(4),
- }
- response = self.session.put(self.upload_url, data, headers=headers)
- self.assertIn(response.status_code, smart_open.gcs._UPLOAD_INCOMPLETE_STATUS_CODES)
- self.session._blob_with_url(self.upload_url, self.client)
- blob_contents = self.blob.download_as_bytes()
- self.assertEqual(blob_contents, b'')
-
- def test_finished_put_writes_to_blob(self):
- data = io.BytesIO(b'test')
- headers = {
- 'Content-Range': 'bytes 0-3/4',
- 'Content-Length': str(4),
- }
- response = self.session.put(self.upload_url, data, headers=headers)
- self.assertEqual(response.status_code, 200)
- self.session._blob_with_url(self.upload_url, self.client)
- blob_contents = self.blob.download_as_bytes()
- data.seek(0)
- self.assertEqual(blob_contents, data.read())
-
-
-if DISABLE_MOCKS:
- storage_client = google.cloud.storage.Client()
-else:
- storage_client = FakeClient()
-
-
-def get_bucket():
- return storage_client.bucket(BUCKET_NAME)
-
-
-def get_blob():
- bucket = get_bucket()
- return bucket.blob(BLOB_NAME)
-
-
-def cleanup_bucket():
- bucket = get_bucket()
+def cleanup_test_bucket(client):
+ bucket = get_test_bucket(client)
blobs = bucket.list_blobs()
for blob in blobs:
blob.delete()
-def put_to_bucket(contents, num_attempts=12, sleep_time=5):
- logger.debug('%r', locals())
-
- #
- # In real life, it can take a few seconds for the bucket to become ready.
- # If we try to write to the key while the bucket while it isn't ready, we
- # will get a StorageError: NotFound.
- #
- for attempt in range(num_attempts):
- try:
- blob = get_blob()
- blob.upload_from_string(contents)
- return
- except google.cloud.exceptions.NotFound as err:
- logger.error('caught %r, retrying', err)
- time.sleep(sleep_time)
-
- assert False, 'failed to create bucket %s after %d attempts' % (BUCKET_NAME, num_attempts)
-
-
-def mock_gcs(class_or_func):
- """Mock all methods of a class or a function."""
- if inspect.isclass(class_or_func):
- for attr in class_or_func.__dict__:
- if callable(getattr(class_or_func, attr)):
- setattr(class_or_func, attr, mock_gcs_func(getattr(class_or_func, attr)))
- return class_or_func
- else:
- return mock_gcs_func(class_or_func)
-
-
-def mock_gcs_func(func):
- """Mock the function and provide additional required arguments."""
- assert callable(func), '%r is not a callable function' % func
-
- def inner(*args, **kwargs):
- #
- # Is it a function or a method? The latter requires a self parameter.
- #
- signature = inspect.signature(func)
-
- fake_session = FakeAuthorizedSession(storage_client._credentials)
- patched_client = mock.patch(
- 'google.cloud.storage.Client',
- return_value=storage_client,
- )
- patched_session = mock.patch(
- 'google.auth.transport.requests.AuthorizedSession',
- return_value=fake_session,
- )
-
- with patched_client, patched_session:
- if not hasattr(signature, 'self'):
- return func(*args, **kwargs)
- else:
- return func(signature.self, *args, **kwargs)
-
- return inner
-
-
-def maybe_mock_gcs(func):
- if DISABLE_MOCKS:
- return func
- else:
- return mock_gcs(func)
-
-
-@maybe_mock_gcs
-def setUpModule(): # noqa
- """Called once by unittest when initializing this module. Set up the
- test GCS bucket.
- """
- storage_client.create_bucket(BUCKET_NAME)
-
-
-@maybe_mock_gcs
-def tearDownModule(): # noqa
- """Called once by unittest when tearing down this module. Empty and
- removes the test GCS bucket.
- """
- try:
- bucket = get_bucket()
- bucket.delete()
- except google.cloud.exceptions.NotFound:
- pass
-
-
-@maybe_mock_gcs
-class ReaderTest(unittest.TestCase):
+class OpenTest(unittest.TestCase):
def setUp(self):
- # lower the multipart upload size, to speed up these tests
- self.old_min_buffer_size = smart_open.gcs.DEFAULT_BUFFER_SIZE
- smart_open.gcs.DEFAULT_BUFFER_SIZE = 5 * 1024**2
-
- ignore_resource_warnings()
-
- def tearDown(self):
- cleanup_bucket()
-
- def test_iter(self):
- """Are GCS files iterated over correctly?"""
- expected = u"hello wořld\nhow are you?".encode('utf8')
- put_to_bucket(contents=expected)
-
- # connect to fake GCS and read from the fake key we filled above
- fin = smart_open.gcs.Reader(BUCKET_NAME, BLOB_NAME)
- output = [line.rstrip(b'\n') for line in fin]
- self.assertEqual(output, expected.split(b'\n'))
-
- def test_iter_context_manager(self):
- # same thing but using a context manager
- expected = u"hello wořld\nhow are you?".encode('utf8')
- put_to_bucket(contents=expected)
- with smart_open.gcs.Reader(BUCKET_NAME, BLOB_NAME) as fin:
- output = [line.rstrip(b'\n') for line in fin]
- self.assertEqual(output, expected.split(b'\n'))
-
- def test_read(self):
- """Are GCS files read correctly?"""
- content = u"hello wořld\nhow are you?".encode('utf8')
- put_to_bucket(contents=content)
- logger.debug('content: %r len: %r', content, len(content))
-
- fin = smart_open.gcs.Reader(BUCKET_NAME, BLOB_NAME)
- self.assertEqual(content[:6], fin.read(6))
- self.assertEqual(content[6:14], fin.read(8)) # ř is 2 bytes
- self.assertEqual(content[14:], fin.read()) # read the rest
-
- def test_seek_beginning(self):
- """Does seeking to the beginning of GCS files work correctly?"""
- content = u"hello wořld\nhow are you?".encode('utf8')
- put_to_bucket(contents=content)
-
- fin = smart_open.gcs.Reader(BUCKET_NAME, BLOB_NAME)
- self.assertEqual(content[:6], fin.read(6))
- self.assertEqual(content[6:14], fin.read(8)) # ř is 2 bytes
-
- fin.seek(0)
- self.assertEqual(content, fin.read()) # no size given => read whole file
-
- fin.seek(0)
- self.assertEqual(content, fin.read(-1)) # same thing
-
- def test_seek_start(self):
- """Does seeking from the start of GCS files work correctly?"""
- content = u"hello wořld\nhow are you?".encode('utf8')
- put_to_bucket(contents=content)
-
- fin = smart_open.gcs.Reader(BUCKET_NAME, BLOB_NAME)
- seek = fin.seek(6)
- self.assertEqual(seek, 6)
- self.assertEqual(fin.tell(), 6)
- self.assertEqual(fin.read(6), u'wořld'.encode('utf-8'))
-
- def test_seek_current(self):
- """Does seeking from the middle of GCS files work correctly?"""
- content = u"hello wořld\nhow are you?".encode('utf8')
- put_to_bucket(contents=content)
-
- fin = smart_open.gcs.Reader(BUCKET_NAME, BLOB_NAME)
- self.assertEqual(fin.read(5), b'hello')
- seek = fin.seek(1, whence=smart_open.constants.WHENCE_CURRENT)
- self.assertEqual(seek, 6)
- self.assertEqual(fin.read(6), u'wořld'.encode('utf-8'))
-
- def test_seek_end(self):
- """Does seeking from the end of GCS files work correctly?"""
- content = u"hello wořld\nhow are you?".encode('utf8')
- put_to_bucket(contents=content)
-
- fin = smart_open.gcs.Reader(BUCKET_NAME, BLOB_NAME)
- seek = fin.seek(-4, whence=smart_open.constants.WHENCE_END)
- self.assertEqual(seek, len(content) - 4)
- self.assertEqual(fin.read(), b'you?')
-
- def test_detect_eof(self):
- content = u"hello wořld\nhow are you?".encode('utf8')
- put_to_bucket(contents=content)
-
- fin = smart_open.gcs.Reader(BUCKET_NAME, BLOB_NAME)
- fin.read()
- eof = fin.tell()
- self.assertEqual(eof, len(content))
- fin.seek(0, whence=smart_open.constants.WHENCE_END)
- self.assertEqual(eof, fin.tell())
-
- def test_read_gzip(self):
- expected = u'раcцветали яблони и груши, поплыли туманы над рекой...'.encode('utf-8')
- buf = io.BytesIO()
- buf.close = lambda: None # keep buffer open so that we can .getvalue()
- with gzip.GzipFile(fileobj=buf, mode='w') as zipfile:
- zipfile.write(expected)
- put_to_bucket(contents=buf.getvalue())
-
- #
- # Make sure we're reading things correctly.
- #
- with smart_open.gcs.Reader(BUCKET_NAME, BLOB_NAME) as fin:
- self.assertEqual(fin.read(), buf.getvalue())
-
- #
- # Make sure the buffer we wrote is legitimate gzip.
- #
- sanity_buf = io.BytesIO(buf.getvalue())
- with gzip.GzipFile(fileobj=sanity_buf) as zipfile:
- self.assertEqual(zipfile.read(), expected)
-
- logger.debug('starting actual test')
- with smart_open.gcs.Reader(BUCKET_NAME, BLOB_NAME) as fin:
- with gzip.GzipFile(fileobj=fin) as zipfile:
- actual = zipfile.read()
-
- self.assertEqual(expected, actual)
-
- def test_readline(self):
- content = b'englishman\nin\nnew\nyork\n'
- put_to_bucket(contents=content)
-
- with smart_open.gcs.Reader(BUCKET_NAME, BLOB_NAME) as fin:
- fin.readline()
- self.assertEqual(fin.tell(), content.index(b'\n')+1)
-
- fin.seek(0)
- actual = list(fin)
- self.assertEqual(fin.tell(), len(content))
-
- expected = [b'englishman\n', b'in\n', b'new\n', b'york\n']
- self.assertEqual(expected, actual)
-
- def test_readline_tiny_buffer(self):
- content = b'englishman\nin\nnew\nyork\n'
- put_to_bucket(contents=content)
-
- with smart_open.gcs.Reader(BUCKET_NAME, BLOB_NAME, buffer_size=8) as fin:
- actual = list(fin)
-
- expected = [b'englishman\n', b'in\n', b'new\n', b'york\n']
- self.assertEqual(expected, actual)
-
- def test_read0_does_not_return_data(self):
- content = b'englishman\nin\nnew\nyork\n'
- put_to_bucket(contents=content)
-
- with smart_open.gcs.Reader(BUCKET_NAME, BLOB_NAME) as fin:
- data = fin.read(0)
-
- self.assertEqual(data, b'')
-
- def test_read_past_end(self):
- content = b'englishman\nin\nnew\nyork\n'
- put_to_bucket(contents=content)
-
- with smart_open.gcs.Reader(BUCKET_NAME, BLOB_NAME) as fin:
- data = fin.read(100)
-
- self.assertEqual(data, content)
-
+ if DISABLE_MOCKS:
+ self.client = google.cloud.storage.Client()
+ else:
+ self.client = FakeClient()
+ self.mock_gcs = mock.patch('smart_open.gcs.google.cloud.storage.Client').start()
+ self.mock_gcs.return_value = self.client
-@maybe_mock_gcs
-class WriterTest(unittest.TestCase):
- """
- Test writing into GCS files.
+ self.client.create_bucket(BUCKET_NAME)
- """
- def setUp(self):
ignore_resource_warnings()
def tearDown(self):
- cleanup_bucket()
-
- def test_write_01(self):
- """Does writing into GCS work correctly?"""
- test_string = u"žluťoučký koníček".encode('utf8')
-
- with smart_open.gcs.Writer(BUCKET_NAME, WRITE_BLOB_NAME) as fout:
- fout.write(test_string)
-
- with smart_open.open("gs://{}/{}".format(BUCKET_NAME, WRITE_BLOB_NAME), "rb") as fin:
- output = list(fin)
-
- self.assertEqual(output, [test_string])
-
- def test_incorrect_input(self):
- """Does gcs write fail on incorrect input?"""
- try:
- with smart_open.gcs.Writer(BUCKET_NAME, WRITE_BLOB_NAME) as fin:
- fin.write(None)
- except TypeError:
- pass
- else:
- self.fail()
-
- def test_write_02(self):
- """Does gcs write unicode-utf8 conversion work?"""
- smart_open_write = smart_open.gcs.Writer(BUCKET_NAME, WRITE_BLOB_NAME)
- smart_open_write.tell()
- logger.info("smart_open_write: %r", smart_open_write)
- with smart_open_write as fout:
- fout.write(u"testžížáč".encode("utf-8"))
- self.assertEqual(fout.tell(), 14)
-
- def test_write_03(self):
- """Do multiple writes less than the min_part_size work correctly?"""
- # write
- min_part_size = 256 * 1024
- smart_open_write = smart_open.gcs.Writer(
- BUCKET_NAME, WRITE_BLOB_NAME, min_part_size=min_part_size
- )
- local_write = io.BytesIO()
-
- with smart_open_write as fout:
- first_part = b"t" * 262141
- fout.write(first_part)
- local_write.write(first_part)
- self.assertEqual(fout._current_part.tell(), 262141)
-
- second_part = b"t\n"
- fout.write(second_part)
- local_write.write(second_part)
- self.assertEqual(fout._current_part.tell(), 262143)
- self.assertEqual(fout._total_parts, 0)
-
- third_part = b"t"
- fout.write(third_part)
- local_write.write(third_part)
- self.assertEqual(fout._current_part.tell(), 262144)
- self.assertEqual(fout._total_parts, 0)
-
- fourth_part = b"t" * 1
- fout.write(fourth_part)
- local_write.write(fourth_part)
- self.assertEqual(fout._current_part.tell(), 1)
- self.assertEqual(fout._total_parts, 1)
-
- # read back the same key and check its content
- output = list(smart_open.open("gs://{}/{}".format(BUCKET_NAME, WRITE_BLOB_NAME)))
- local_write.seek(0)
- actual = [line.decode("utf-8") for line in list(local_write)]
- self.assertEqual(output, actual)
-
- def test_write_03a(self):
- """Do multiple writes greater than the min_part_size work correctly?"""
- min_part_size = 256 * 1024
- smart_open_write = smart_open.gcs.Writer(
- BUCKET_NAME, WRITE_BLOB_NAME, min_part_size=min_part_size
- )
- local_write = io.BytesIO()
-
- with smart_open_write as fout:
- for i in range(1, 4):
- part = b"t" * (min_part_size + 1)
- fout.write(part)
- local_write.write(part)
- self.assertEqual(fout._current_part.tell(), i)
- self.assertEqual(fout._total_parts, i)
-
- # read back the same key and check its content
- output = list(smart_open.open("gs://{}/{}".format(BUCKET_NAME, WRITE_BLOB_NAME)))
- local_write.seek(0)
- actual = [line.decode("utf-8") for line in list(local_write)]
- self.assertEqual(output, actual)
-
- def test_write_03b(self):
- """Does writing a last chunk size equal to a multiple of the min_part_size work?"""
- min_part_size = 256 * 1024
- smart_open_write = smart_open.gcs.Writer(
- BUCKET_NAME, WRITE_BLOB_NAME, min_part_size=min_part_size
- )
- expected = b"t" * min_part_size * 2
-
- with smart_open_write as fout:
- fout.write(expected)
- self.assertEqual(fout._current_part.tell(), 262144)
- self.assertEqual(fout._total_parts, 1)
-
- # read back the same key and check its content
- with smart_open.open("gs://{}/{}".format(BUCKET_NAME, WRITE_BLOB_NAME)) as fin:
- output = fin.read().encode('utf-8')
-
- self.assertEqual(output, expected)
-
- def test_write_04(self):
- """Does writing no data cause key with an empty value to be created?"""
- smart_open_write = smart_open.gcs.Writer(BUCKET_NAME, WRITE_BLOB_NAME)
- with smart_open_write as fout: # noqa
- pass
-
- # read back the same key and check its content
- output = list(smart_open.open("gs://{}/{}".format(BUCKET_NAME, WRITE_BLOB_NAME)))
-
- self.assertEqual(output, [])
-
- def test_write_05(self):
- """Do blob_properties get applied?"""
- smart_open_write = smart_open.gcs.Writer(BUCKET_NAME, WRITE_BLOB_NAME,
- blob_properties={
- "content_type": "random/x-test",
- "content_encoding": "coded"
- }
- )
- with smart_open_write as fout: # noqa
- assert fout._blob.content_type == "random/x-test"
- assert fout._blob.content_encoding == "coded"
-
- def test_gzip(self):
- expected = u'а не спеть ли мне песню... о любви'.encode('utf-8')
- with smart_open.gcs.Writer(BUCKET_NAME, WRITE_BLOB_NAME) as fout:
- with gzip.GzipFile(fileobj=fout, mode='w') as zipfile:
- zipfile.write(expected)
-
- with smart_open.gcs.Reader(BUCKET_NAME, WRITE_BLOB_NAME) as fin:
- with gzip.GzipFile(fileobj=fin) as zipfile:
- actual = zipfile.read()
-
- self.assertEqual(expected, actual)
-
- def test_buffered_writer_wrapper_works(self):
- """
- Ensure that we can wrap a smart_open gcs stream in a BufferedWriter, which
- passes a memoryview object to the underlying stream in python >= 2.7
- """
- expected = u'не думай о секундах свысока'
-
- with smart_open.gcs.Writer(BUCKET_NAME, WRITE_BLOB_NAME) as fout:
- with io.BufferedWriter(fout) as sub_out:
- sub_out.write(expected.encode('utf-8'))
-
- with smart_open.open("gs://{}/{}".format(BUCKET_NAME, WRITE_BLOB_NAME), 'rb') as fin:
- with io.TextIOWrapper(fin, encoding='utf-8') as text:
- actual = text.read()
-
- self.assertEqual(expected, actual)
-
- def test_binary_iterator(self):
- expected = u"выйду ночью в поле с конём".encode('utf-8').split(b' ')
- put_to_bucket(contents=b"\n".join(expected))
- with smart_open.gcs.open(BUCKET_NAME, BLOB_NAME, 'rb') as fin:
- actual = [line.rstrip() for line in fin]
- self.assertEqual(expected, actual)
-
- def test_nonexisting_bucket(self):
- expected = u"выйду ночью в поле с конём".encode('utf-8')
- with self.assertRaises(google.api_core.exceptions.NotFound):
- with smart_open.gcs.open('thisbucketdoesntexist', 'mykey', 'wb') as fout:
- fout.write(expected)
-
- def test_read_nonexisting_key(self):
- with self.assertRaises(google.api_core.exceptions.NotFound):
- with smart_open.gcs.open(BUCKET_NAME, 'my_nonexisting_key', 'rb') as fin:
- fin.read()
-
- def test_double_close(self):
- text = u'там за туманами, вечными, пьяными'.encode('utf-8')
- fout = smart_open.gcs.open(BUCKET_NAME, 'key', 'wb')
- fout.write(text)
- fout.close()
- fout.close()
-
- def test_flush_close(self):
- text = u'там за туманами, вечными, пьяными'.encode('utf-8')
- fout = smart_open.gcs.open(BUCKET_NAME, 'key', 'wb')
- fout.write(text)
- fout.flush()
- fout.close()
-
- def test_terminate(self):
- text = u'там за туманами, вечными, пьяными'.encode('utf-8')
- fout = smart_open.gcs.open(BUCKET_NAME, 'key', 'wb')
- fout.write(text)
- fout.terminate()
-
- with self.assertRaises(google.api_core.exceptions.NotFound):
- with smart_open.gcs.open(BUCKET_NAME, 'key', 'rb') as fin:
- fin.read()
-
-
-@maybe_mock_gcs
-class OpenTest(unittest.TestCase):
- def setUp(self):
- ignore_resource_warnings()
+ cleanup_test_bucket(self.client)
+ bucket = get_test_bucket(self.client)
+ bucket.delete()
- def tearDown(self):
- cleanup_bucket()
+ if not DISABLE_MOCKS:
+ self.mock_gcs.stop()
def test_read_never_returns_none(self):
"""read should never return None."""
test_string = u"ветер по морю гуляет..."
with smart_open.gcs.open(BUCKET_NAME, BLOB_NAME, "wb") as fout:
- self.assertEqual(fout.name, BLOB_NAME)
fout.write(test_string.encode('utf8'))
r = smart_open.gcs.open(BUCKET_NAME, BLOB_NAME, "rb")
- self.assertEqual(r.name, BLOB_NAME)
self.assertEqual(r.read(), test_string.encode("utf-8"))
self.assertEqual(r.read(), b"")
self.assertEqual(r.read(), b"")
@@ -918,14 +276,47 @@ class OpenTest(unittest.TestCase):
self.assertEqual(test_string, actual)
-class MakeRangeStringTest(unittest.TestCase):
- def test_no_stop(self):
- start, stop = 1, None
- self.assertEqual(smart_open.gcs._make_range_string(start, stop), 'bytes 1-/*')
+class WriterTest(unittest.TestCase):
+ def setUp(self):
+ self.client = FakeClient()
+ self.mock_gcs = mock.patch('smart_open.gcs.google.cloud.storage.Client').start()
+ self.mock_gcs.return_value = self.client
+
+ self.client.create_bucket(BUCKET_NAME)
- def test_stop(self):
- start, stop = 1, 2
- self.assertEqual(smart_open.gcs._make_range_string(start, stop), 'bytes 1-2/*')
+ def tearDown(self):
+ cleanup_test_bucket(self.client)
+ bucket = get_test_bucket(self.client)
+ bucket.delete()
+ self.mock_gcs.stop()
+
+ def test_property_passthrough(self):
+ blob_properties = {'content_type': 'text/utf-8'}
+
+ smart_open.gcs.Writer(BUCKET_NAME, BLOB_NAME, blob_properties=blob_properties)
+
+ b = self.client.bucket(BUCKET_NAME).get_blob(BLOB_NAME)
+
+ for k, v in blob_properties.items():
+ self.assertEqual(getattr(b, k), v)
+
+ def test_default_open_kwargs(self):
+ smart_open.gcs.Writer(BUCKET_NAME, BLOB_NAME)
+
+ self.client.bucket(BUCKET_NAME).get_blob(BLOB_NAME) \
+ .open.assert_called_once_with('wb', **smart_open.gcs._DEFAULT_WRITE_OPEN_KWARGS)
+
+ def test_open_kwargs_passthrough(self):
+ open_kwargs = {'ignore_flush': True, 'property': 'value', 'something': 2}
+
+ smart_open.gcs.Writer(BUCKET_NAME, BLOB_NAME, blob_open_kwargs=open_kwargs)
+
+ self.client.bucket(BUCKET_NAME).get_blob(BLOB_NAME) \
+ .open.assert_called_once_with('wb', **open_kwargs)
+
+ def test_non_existing_bucket(self):
+ with self.assertRaises(google.cloud.exceptions.NotFound):
+ smart_open.gcs.Writer('unknown_bucket', BLOB_NAME)
if __name__ == '__main__':
diff --git a/smart_open/tests/test_hdfs.py b/smart_open/tests/test_hdfs.py
index c38f8bd..24c8d5f 100644
--- a/smart_open/tests/test_hdfs.py
+++ b/smart_open/tests/test_hdfs.py
@@ -9,20 +9,18 @@ import gzip
import os
import os.path as P
import subprocess
-import unittest
from unittest import mock
import sys
-import smart_open.hdfs
+import pytest
-#
-# Workaround for https://bugs.python.org/issue37380
-#
-if sys.version_info[:2] == (3, 6):
- subprocess._cleanup = lambda: None
+import smart_open.hdfs
CURR_DIR = P.dirname(P.abspath(__file__))
+if sys.platform.startswith("win"):
+ pytest.skip("these tests don't work under Windows", allow_module_level=True)
+
#
# We want our mocks to emulate the real implementation as close as possible,
@@ -43,100 +41,97 @@ def cat(path=None):
return subprocess.Popen(command, stdin=subprocess.PIPE, stdout=subprocess.PIPE)
-class CliRawInputBaseTest(unittest.TestCase):
- def setUp(self):
- self.path = P.join(CURR_DIR, 'test_data', 'crime-and-punishment.txt')
+CAP_PATH = P.join(CURR_DIR, 'test_data', 'crime-and-punishment.txt')
+with open(CAP_PATH, encoding='utf-8') as fin:
+ CRIME_AND_PUNISHMENT = fin.read()
- #
- # We have to specify the encoding explicitly, because different
- # platforms like Windows may be using something other than unicode
- # by default.
- #
- with open(self.path, encoding='utf-8') as fin:
- self.expected = fin.read()
- self.cat = cat(self.path)
- def test_read(self):
- with mock.patch('subprocess.Popen', return_value=self.cat):
- reader = smart_open.hdfs.CliRawInputBase('hdfs://dummy/url')
- as_bytes = reader.read()
+def test_sanity_read_bytes():
+ with open(CAP_PATH, 'rb') as fin:
+ lines = [line for line in fin]
+ assert len(lines) == 3
- #
- # Not 100% sure why this is necessary on Windows platforms, but the
- # tests fail without it. It may be a bug, but I don't have time to
- # investigate right now.
- #
- as_text = as_bytes.decode('utf-8').replace(os.linesep, '\n')
- assert as_text == self.expected
- def test_read_75(self):
- with mock.patch('subprocess.Popen', return_value=self.cat):
- reader = smart_open.hdfs.CliRawInputBase('hdfs://dummy/url')
- as_bytes = reader.read(75)
+def test_sanity_read_text():
+ with open(CAP_PATH, 'r', encoding='utf-8') as fin:
+ text = fin.read()
- as_text = as_bytes.decode('utf-8').replace(os.linesep, '\n')
- assert as_text == self.expected[:len(as_text)]
+ expected = 'В начале июля, в чрезвычайно жаркое время'
+ assert text[:len(expected)] == expected
- def test_unzip(self):
- path = P.join(CURR_DIR, 'test_data', 'crime-and-punishment.txt.gz')
- with mock.patch('subprocess.Popen', return_value=cat(path)):
- with gzip.GzipFile(fileobj=smart_open.hdfs.CliRawInputBase('hdfs://dummy/url')) as fin:
- as_bytes = fin.read()
+@pytest.mark.parametrize('schema', ['hdfs', 'viewfs'])
+def test_read(schema):
+ with mock.patch('subprocess.Popen', return_value=cat(CAP_PATH)):
+ reader = smart_open.hdfs.CliRawInputBase(f'{schema}://dummy/url')
+ as_bytes = reader.read()
- as_text = as_bytes.decode('utf-8')
- assert as_text == self.expected
+ #
+ # Not 100% sure why this is necessary on Windows platforms, but the
+ # tests fail without it. It may be a bug, but I don't have time to
+ # investigate right now.
+ #
+ as_text = as_bytes.decode('utf-8').replace(os.linesep, '\n')
+ assert as_text == CRIME_AND_PUNISHMENT
- def test_context_manager(self):
- with mock.patch('subprocess.Popen', return_value=self.cat):
- with smart_open.hdfs.CliRawInputBase('hdfs://dummy/url') as fin:
- as_bytes = fin.read()
- as_text = as_bytes.decode('utf-8').replace('\r\n', '\n')
- assert as_text == self.expected
+@pytest.mark.parametrize('schema', ['hdfs', 'viewfs'])
+def test_read_75(schema):
+ with mock.patch('subprocess.Popen', return_value=cat(CAP_PATH)):
+ reader = smart_open.hdfs.CliRawInputBase(f'{schema}://dummy/url')
+ as_bytes = reader.read(75)
+ as_text = as_bytes.decode('utf-8').replace(os.linesep, '\n')
+ assert as_text == CRIME_AND_PUNISHMENT[:len(as_text)]
-class SanityTest(unittest.TestCase):
- def test_read_bytes(self):
- path = P.join(CURR_DIR, 'test_data', 'crime-and-punishment.txt')
- with open(path, 'rb') as fin:
- lines = [line for line in fin]
- assert len(lines) == 3
- def test_read_text(self):
- path = P.join(CURR_DIR, 'test_data', 'crime-and-punishment.txt')
- with open(path, 'r', encoding='utf-8') as fin:
- text = fin.read()
+@pytest.mark.parametrize('schema', ['hdfs', 'viewfs'])
+def test_unzip(schema):
+ with mock.patch('subprocess.Popen', return_value=cat(CAP_PATH + '.gz')):
+ with gzip.GzipFile(fileobj=smart_open.hdfs.CliRawInputBase(f'{schema}://dummy/url')) as fin:
+ as_bytes = fin.read()
+
+ as_text = as_bytes.decode('utf-8')
+ assert as_text == CRIME_AND_PUNISHMENT
+
+
+@pytest.mark.parametrize('schema', ['hdfs', 'viewfs'])
+def test_context_manager(schema):
+ with mock.patch('subprocess.Popen', return_value=cat(CAP_PATH)):
+ with smart_open.hdfs.CliRawInputBase(f'{schema}://dummy/url') as fin:
+ as_bytes = fin.read()
+
+ as_text = as_bytes.decode('utf-8').replace('\r\n', '\n')
+ assert as_text == CRIME_AND_PUNISHMENT
- expected = 'В начале июля, в чрезвычайно жаркое время'
- assert text[:len(expected)] == expected
+@pytest.mark.parametrize('schema', ['hdfs', 'viewfs'])
+def test_write(schema):
+ expected = 'мы в ответе за тех, кого приручили'
+ mocked_cat = cat()
-class CliRawOutputBaseTest(unittest.TestCase):
- def test_write(self):
- expected = 'мы в ответе за тех, кого приручили'
- mocked_cat = cat()
+ with mock.patch('subprocess.Popen', return_value=mocked_cat):
+ with smart_open.hdfs.CliRawOutputBase(f'{schema}://dummy/url') as fout:
+ fout.write(expected.encode('utf-8'))
- with mock.patch('subprocess.Popen', return_value=mocked_cat):
- with smart_open.hdfs.CliRawOutputBase('hdfs://dummy/url') as fout:
- fout.write(expected.encode('utf-8'))
+ actual = mocked_cat.stdout.read().decode('utf-8')
+ assert actual == expected
- actual = mocked_cat.stdout.read().decode('utf-8')
- assert actual == expected
- def test_zip(self):
- expected = 'мы в ответе за тех, кого приручили'
- mocked_cat = cat()
+@pytest.mark.parametrize('schema', ['hdfs', 'viewfs'])
+def test_write_zip(schema):
+ expected = 'мы в ответе за тех, кого приручили'
+ mocked_cat = cat()
- with mock.patch('subprocess.Popen', return_value=mocked_cat):
- with smart_open.hdfs.CliRawOutputBase('hdfs://dummy/url') as fout:
- with gzip.GzipFile(fileobj=fout, mode='wb') as gz_fout:
- gz_fout.write(expected.encode('utf-8'))
+ with mock.patch('subprocess.Popen', return_value=mocked_cat):
+ with smart_open.hdfs.CliRawOutputBase(f'{schema}://dummy/url') as fout:
+ with gzip.GzipFile(fileobj=fout, mode='wb') as gz_fout:
+ gz_fout.write(expected.encode('utf-8'))
- with gzip.GzipFile(fileobj=mocked_cat.stdout) as fin:
- actual = fin.read().decode('utf-8')
+ with gzip.GzipFile(fileobj=mocked_cat.stdout) as fin:
+ actual = fin.read().decode('utf-8')
- assert actual == expected
+ assert actual == expected
def main():
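
The rewritten hdfs tests all rely on the same trick: CliRawInputBase and CliRawOutputBase shell out through subprocess.Popen, so patching Popen to return a live cat process gives them a real pipe to read from or write to without an hdfs binary anywhere. A condensed sketch of the wiring, assuming a POSIX cat on the PATH:

import subprocess
from unittest import mock


def cat():
    """A real process whose stdin/stdout stand in for the `hdfs dfs` pipes."""
    return subprocess.Popen(['cat'], stdin=subprocess.PIPE, stdout=subprocess.PIPE)


fake_hdfs = cat()
fake_hdfs.stdin.write(b'hello from the fake cluster\n')
fake_hdfs.stdin.close()

with mock.patch('subprocess.Popen', return_value=fake_hdfs):
    # Anything that launches `hdfs dfs -cat ...` via subprocess.Popen now
    # receives the cat process above and reads our bytes back instead.
    proc = subprocess.Popen(['hdfs', 'dfs', '-cat', 'hdfs://dummy/url'],
                            stdout=subprocess.PIPE)
    assert proc.stdout.read() == b'hello from the fake cluster\n'
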
diff --git a/smart_open/tests/test_s3.py b/smart_open/tests/test_s3.py
index ab4b6d8..a91a731 100644
--- a/smart_open/tests/test_s3.py
+++ b/smart_open/tests/test_s3.py
@@ -6,12 +6,12 @@
# from the MIT License (MIT).
#
from collections import defaultdict
+import functools
import gzip
import io
import logging
import os
import tempfile
-import time
import unittest
import warnings
from contextlib import contextmanager
@@ -45,42 +45,9 @@ ENABLE_MOTO_SERVER = os.environ.get("SO_ENABLE_MOTO_SERVER") == "1"
os.environ["AWS_ACCESS_KEY_ID"] = "test"
os.environ["AWS_SECRET_ACCESS_KEY"] = "test"
-
logger = logging.getLogger(__name__)
-
-@moto.mock_s3
-def setUpModule():
- '''Called once by unittest when initializing this module. Sets up the
- test S3 bucket.
-
- '''
- bucket = boto3.resource('s3').create_bucket(Bucket=BUCKET_NAME)
- bucket.wait_until_exists()
-
-
-def cleanup_bucket():
- for key in boto3.resource('s3').Bucket(BUCKET_NAME).objects.all():
- key.delete()
-
-
-def put_to_bucket(contents, num_attempts=12, sleep_time=5):
- logger.debug('%r', locals())
-
- #
- # In real life, it can take a few seconds for the bucket to become ready.
- # If we try to write to the key while the bucket while it isn't ready, we
- # will get a ClientError: NoSuchBucket.
- #
- for attempt in range(num_attempts):
- try:
- boto3.resource('s3').Object(BUCKET_NAME, KEY_NAME).put(Body=contents)
- return
- except botocore.exceptions.ClientError as err:
- logger.error('caught %r, retrying', err)
- time.sleep(sleep_time)
-
- assert False, 'failed to write to bucket %s after %d attempts' % (BUCKET_NAME, num_attempts)
+_resource = functools.partial(boto3.resource, region_name='us-east-1')
def ignore_resource_warnings():
@@ -231,66 +198,57 @@ class ReaderTest(BaseTest):
super().setUp()
+ s3 = _resource('s3')
+ s3.create_bucket(Bucket=BUCKET_NAME).wait_until_exists()
+
+ self.body = u"hello wořld\nhow are you?".encode('utf8')
+ s3.Object(BUCKET_NAME, KEY_NAME).put(Body=self.body)
+
def tearDown(self):
smart_open.s3.DEFAULT_MIN_PART_SIZE = self.old_min_part_size
- cleanup_bucket()
def test_iter(self):
"""Are S3 files iterated over correctly?"""
- # a list of strings to test with
- expected = u"hello wořld\nhow are you?".encode('utf8')
- put_to_bucket(contents=expected)
-
# connect to fake s3 and read from the fake key we filled above
with self.assertApiCalls(GetObject=1):
fin = smart_open.s3.Reader(BUCKET_NAME, KEY_NAME)
output = [line.rstrip(b'\n') for line in fin]
- self.assertEqual(output, expected.split(b'\n'))
+ self.assertEqual(output, self.body.split(b'\n'))
def test_iter_context_manager(self):
# same thing but using a context manager
- expected = u"hello wořld\nhow are you?".encode('utf8')
- put_to_bucket(contents=expected)
+ _resource('s3').create_bucket(Bucket=BUCKET_NAME).wait_until_exists()
+
with self.assertApiCalls(GetObject=1):
with smart_open.s3.Reader(BUCKET_NAME, KEY_NAME) as fin:
output = [line.rstrip(b'\n') for line in fin]
- self.assertEqual(output, expected.split(b'\n'))
+ self.assertEqual(output, self.body.split(b'\n'))
def test_read(self):
"""Are S3 files read correctly?"""
- content = u"hello wořld\nhow are you?".encode('utf8')
- put_to_bucket(contents=content)
- logger.debug('content: %r len: %r', content, len(content))
-
with self.assertApiCalls(GetObject=1):
fin = smart_open.s3.Reader(BUCKET_NAME, KEY_NAME)
- self.assertEqual(content[:6], fin.read(6))
- self.assertEqual(content[6:14], fin.read(8)) # ř is 2 bytes
- self.assertEqual(content[14:], fin.read()) # read the rest
+ self.assertEqual(self.body[:6], fin.read(6))
+ self.assertEqual(self.body[6:14], fin.read(8)) # ř is 2 bytes
+ self.assertEqual(self.body[14:], fin.read()) # read the rest
def test_seek_beginning(self):
"""Does seeking to the beginning of S3 files work correctly?"""
- content = u"hello wořld\nhow are you?".encode('utf8')
- put_to_bucket(contents=content)
-
with self.assertApiCalls(GetObject=1):
fin = smart_open.s3.Reader(BUCKET_NAME, KEY_NAME)
- self.assertEqual(content[:6], fin.read(6))
- self.assertEqual(content[6:14], fin.read(8)) # ř is 2 bytes
+ self.assertEqual(self.body[:6], fin.read(6))
+ self.assertEqual(self.body[6:14], fin.read(8)) # ř is 2 bytes
with self.assertApiCalls(GetObject=1):
fin.seek(0)
- self.assertEqual(content, fin.read()) # no size given => read whole file
+ self.assertEqual(self.body, fin.read()) # no size given => read whole file
with self.assertApiCalls(GetObject=1):
fin.seek(0)
- self.assertEqual(content, fin.read(-1)) # same thing
+ self.assertEqual(self.body, fin.read(-1)) # same thing
def test_seek_start(self):
"""Does seeking from the start of S3 files work correctly?"""
- content = u"hello wořld\nhow are you?".encode('utf8')
- put_to_bucket(contents=content)
-
with self.assertApiCalls(GetObject=1):
fin = smart_open.s3.Reader(BUCKET_NAME, KEY_NAME, defer_seek=True)
seek = fin.seek(6)
@@ -300,9 +258,6 @@ class ReaderTest(BaseTest):
def test_seek_current(self):
"""Does seeking from the middle of S3 files work correctly?"""
- content = u"hello wořld\nhow are you?".encode('utf8')
- put_to_bucket(contents=content)
-
with self.assertApiCalls(GetObject=1):
fin = smart_open.s3.Reader(BUCKET_NAME, KEY_NAME)
self.assertEqual(fin.read(5), b'hello')
@@ -314,33 +269,24 @@ class ReaderTest(BaseTest):
def test_seek_end(self):
"""Does seeking from the end of S3 files work correctly?"""
- content = u"hello wořld\nhow are you?".encode('utf8')
- put_to_bucket(contents=content)
-
with self.assertApiCalls(GetObject=1):
fin = smart_open.s3.Reader(BUCKET_NAME, KEY_NAME, defer_seek=True)
seek = fin.seek(-4, whence=smart_open.constants.WHENCE_END)
- self.assertEqual(seek, len(content) - 4)
+ self.assertEqual(seek, len(self.body) - 4)
self.assertEqual(fin.read(), b'you?')
def test_seek_past_end(self):
- content = u"hello wořld\nhow are you?".encode('utf8')
- put_to_bucket(contents=content)
-
- with self.assertApiCalls(GetObject=1), patch_invalid_range_response(str(len(content))):
+ with self.assertApiCalls(GetObject=1), patch_invalid_range_response(str(len(self.body))):
fin = smart_open.s3.Reader(BUCKET_NAME, KEY_NAME, defer_seek=True)
seek = fin.seek(60)
- self.assertEqual(seek, len(content))
+ self.assertEqual(seek, len(self.body))
def test_detect_eof(self):
- content = u"hello wořld\nhow are you?".encode('utf8')
- put_to_bucket(contents=content)
-
with self.assertApiCalls(GetObject=1):
fin = smart_open.s3.Reader(BUCKET_NAME, KEY_NAME)
fin.read()
eof = fin.tell()
- self.assertEqual(eof, len(content))
+ self.assertEqual(eof, len(self.body))
fin.seek(0, whence=smart_open.constants.WHENCE_END)
self.assertEqual(eof, fin.tell())
fin.seek(eof)
@@ -352,7 +298,8 @@ class ReaderTest(BaseTest):
buf.close = lambda: None # keep buffer open so that we can .getvalue()
with gzip.GzipFile(fileobj=buf, mode='w') as zipfile:
zipfile.write(expected)
- put_to_bucket(contents=buf.getvalue())
+
+ _resource('s3').Object(BUCKET_NAME, KEY_NAME).put(Body=buf.getvalue())
#
# Make sure we're reading things correctly.
@@ -377,7 +324,7 @@ class ReaderTest(BaseTest):
def test_readline(self):
content = b'englishman\nin\nnew\nyork\n'
- put_to_bucket(contents=content)
+ _resource('s3').Object(BUCKET_NAME, KEY_NAME).put(Body=content)
with self.assertApiCalls(GetObject=2):
with smart_open.s3.Reader(BUCKET_NAME, KEY_NAME) as fin:
@@ -393,7 +340,7 @@ class ReaderTest(BaseTest):
def test_readline_tiny_buffer(self):
content = b'englishman\nin\nnew\nyork\n'
- put_to_bucket(contents=content)
+ _resource('s3').Object(BUCKET_NAME, KEY_NAME).put(Body=content)
with self.assertApiCalls(GetObject=1):
with smart_open.s3.Reader(BUCKET_NAME, KEY_NAME, buffer_size=8) as fin:
@@ -403,9 +350,6 @@ class ReaderTest(BaseTest):
self.assertEqual(expected, actual)
def test_read0_does_not_return_data(self):
- content = b'englishman\nin\nnew\nyork\n'
- put_to_bucket(contents=content)
-
with self.assertApiCalls():
# set defer_seek to verify that read(0) doesn't trigger an unnecessary API call
with smart_open.s3.Reader(BUCKET_NAME, KEY_NAME, defer_seek=True) as fin:
@@ -414,20 +358,18 @@ class ReaderTest(BaseTest):
self.assertEqual(data, b'')
def test_to_boto3(self):
- contents = b'the spice melange\n'
- put_to_bucket(contents=contents)
-
with self.assertApiCalls():
# set defer_seek to verify that to_boto3() doesn't trigger an unnecessary API call
with smart_open.s3.Reader(BUCKET_NAME, KEY_NAME, defer_seek=True) as fin:
- returned_obj = fin.to_boto3(boto3.resource('s3'))
+ returned_obj = fin.to_boto3(_resource('s3'))
boto3_body = returned_obj.get()['Body'].read()
- self.assertEqual(contents, boto3_body)
+ self.assertEqual(self.body, boto3_body)
def test_binary_iterator(self):
expected = u"выйду ночью в поле с конём".encode('utf-8').split(b' ')
- put_to_bucket(contents=b"\n".join(expected))
+ _resource('s3').Object(BUCKET_NAME, KEY_NAME).put(Body=b'\n'.join(expected))
+
with self.assertApiCalls(GetObject=1):
with smart_open.s3.open(BUCKET_NAME, KEY_NAME, 'rb') as fin:
actual = [line.rstrip() for line in fin]
@@ -435,7 +377,7 @@ class ReaderTest(BaseTest):
def test_defer_seek(self):
content = b'englishman\nin\nnew\nyork\n'
- put_to_bucket(contents=content)
+ _resource('s3').Object(BUCKET_NAME, KEY_NAME).put(Body=content)
with self.assertApiCalls():
fin = smart_open.s3.Reader(BUCKET_NAME, KEY_NAME, defer_seek=True)
@@ -449,7 +391,7 @@ class ReaderTest(BaseTest):
self.assertEqual(fin.read(), content[10:])
def test_read_empty_file(self):
- put_to_bucket(contents=b'')
+ _resource('s3').Object(BUCKET_NAME, KEY_NAME).put(Body=b'')
with self.assertApiCalls(GetObject=1), patch_invalid_range_response('0'):
with smart_open.s3.Reader(BUCKET_NAME, KEY_NAME) as fin:
@@ -467,8 +409,7 @@ class MultipartWriterTest(unittest.TestCase):
def setUp(self):
ignore_resource_warnings()
- def tearDown(self):
- cleanup_bucket()
+ _resource('s3').create_bucket(Bucket=BUCKET_NAME).wait_until_exists()
def test_write_01(self):
"""Does writing into s3 work correctly?"""
@@ -595,7 +536,7 @@ class MultipartWriterTest(unittest.TestCase):
with smart_open.s3.open(BUCKET_NAME, KEY_NAME, 'wb') as fout:
fout.write(contents)
- returned_obj = fout.to_boto3(boto3.resource('s3'))
+ returned_obj = fout.to_boto3(_resource('s3'))
boto3_body = returned_obj.get()['Body'].read()
self.assertEqual(contents, boto3_body)
@@ -623,8 +564,7 @@ class SinglepartWriterTest(unittest.TestCase):
def setUp(self):
ignore_resource_warnings()
- def tearDown(self):
- cleanup_bucket()
+ _resource('s3').create_bucket(Bucket=BUCKET_NAME).wait_until_exists()
def test_write_01(self):
"""Does writing into s3 work correctly?"""
@@ -729,9 +669,7 @@ ARBITRARY_CLIENT_ERROR = botocore.client.ClientError(error_response={}, operatio
class IterBucketTest(unittest.TestCase):
def setUp(self):
ignore_resource_warnings()
-
- def tearDown(self):
- cleanup_bucket()
+ _resource('s3').create_bucket(Bucket=BUCKET_NAME).wait_until_exists()
@pytest.mark.skipif(condition=sys.platform == 'win32', reason="does not run on windows")
@pytest.mark.xfail(
@@ -743,6 +681,38 @@ class IterBucketTest(unittest.TestCase):
results = list(smart_open.s3.iter_bucket(BUCKET_NAME))
self.assertEqual(len(results), 10)
+ @pytest.mark.skipif(condition=sys.platform == 'win32', reason="does not run on windows")
+ @pytest.mark.xfail(
+ condition=sys.platform == 'darwin',
+ reason="MacOS uses spawn rather than fork for multiprocessing",
+ )
+ def test_iter_bucket_404(self):
+ populate_bucket()
+
+ def throw_404_error_for_key_4(*args):
+ if args[1] == "key_4":
+ raise botocore.exceptions.ClientError(
+ error_response={"Error": {"Code": "404", "Message": "Not Found"}},
+ operation_name="HeadObject",
+ )
+ else:
+ return [0]
+
+ with mock.patch("smart_open.s3._download_fileobj", side_effect=throw_404_error_for_key_4):
+ results = list(smart_open.s3.iter_bucket(BUCKET_NAME))
+ self.assertEqual(len(results), 9)
+
+ @pytest.mark.skipif(condition=sys.platform == 'win32', reason="does not run on windows")
+ @pytest.mark.xfail(
+ condition=sys.platform == 'darwin',
+ reason="MacOS uses spawn rather than fork for multiprocessing",
+ )
+ def test_iter_bucket_non_404(self):
+ populate_bucket()
+ with mock.patch("smart_open.s3._download_fileobj", side_effect=ARBITRARY_CLIENT_ERROR):
+ with pytest.raises(botocore.exceptions.ClientError):
+ list(smart_open.s3.iter_bucket(BUCKET_NAME))
+
def test_deprecated_top_level_s3_iter_bucket(self):
populate_bucket()
with self.assertLogs(smart_open.logger.name, level='WARN') as cm:
@@ -762,7 +732,7 @@ class IterBucketTest(unittest.TestCase):
)
def test_accepts_boto3_bucket(self):
populate_bucket()
- bucket = boto3.resource('s3').Bucket(BUCKET_NAME)
+ bucket = _resource('s3').Bucket(BUCKET_NAME)
results = list(smart_open.s3.iter_bucket(bucket))
self.assertEqual(len(results), 10)
@@ -801,9 +771,10 @@ class IterBucketConcurrentFuturesTest(unittest.TestCase):
smart_open.concurrency._MULTIPROCESSING = False
ignore_resource_warnings()
+ _resource('s3').create_bucket(Bucket=BUCKET_NAME).wait_until_exists()
+
def tearDown(self):
smart_open.concurrency._MULTIPROCESSING = self.old_flag_multi
- cleanup_bucket()
def test(self):
num_keys = 101
@@ -831,9 +802,10 @@ class IterBucketMultiprocessingTest(unittest.TestCase):
smart_open.concurrency._CONCURRENT_FUTURES = False
ignore_resource_warnings()
+ _resource('s3').create_bucket(Bucket=BUCKET_NAME).wait_until_exists()
+
def tearDown(self):
smart_open.concurrency._CONCURRENT_FUTURES = self.old_flag_concurrent
- cleanup_bucket()
def test(self):
num_keys = 101
@@ -855,10 +827,11 @@ class IterBucketSingleProcessTest(unittest.TestCase):
ignore_resource_warnings()
+ _resource('s3').create_bucket(Bucket=BUCKET_NAME).wait_until_exists()
+
def tearDown(self):
smart_open.concurrency._MULTIPROCESSING = self.old_flag_multi
smart_open.concurrency._CONCURRENT_FUTURES = self.old_flag_concurrent
- cleanup_bucket()
def test(self):
num_keys = 101
@@ -877,6 +850,7 @@ class IterBucketSingleProcessTest(unittest.TestCase):
@moto.mock_s3
class IterBucketCredentialsTest(unittest.TestCase):
def test(self):
+ _resource('s3').create_bucket(Bucket=BUCKET_NAME).wait_until_exists()
num_keys = 10
populate_bucket(num_keys=num_keys)
result = list(
@@ -895,28 +869,26 @@ class DownloadKeyTest(unittest.TestCase):
def setUp(self):
ignore_resource_warnings()
- def tearDown(self):
- cleanup_bucket()
+ s3 = _resource('s3')
+ bucket = s3.create_bucket(Bucket=BUCKET_NAME)
+ bucket.wait_until_exists()
+
+ self.body = b'hello'
+ s3.Object(BUCKET_NAME, KEY_NAME).put(Body=self.body)
def test_happy(self):
- contents = b'hello'
- put_to_bucket(contents=contents)
- expected = (KEY_NAME, contents)
+ expected = (KEY_NAME, self.body)
actual = smart_open.s3._download_key(KEY_NAME, bucket_name=BUCKET_NAME)
self.assertEqual(expected, actual)
def test_intermittent_error(self):
- contents = b'hello'
- put_to_bucket(contents=contents)
- expected = (KEY_NAME, contents)
- side_effect = [ARBITRARY_CLIENT_ERROR, ARBITRARY_CLIENT_ERROR, contents]
+ expected = (KEY_NAME, self.body)
+ side_effect = [ARBITRARY_CLIENT_ERROR, ARBITRARY_CLIENT_ERROR, self.body]
with mock.patch('smart_open.s3._download_fileobj', side_effect=side_effect):
actual = smart_open.s3._download_key(KEY_NAME, bucket_name=BUCKET_NAME)
self.assertEqual(expected, actual)
def test_persistent_error(self):
- contents = b'hello'
- put_to_bucket(contents=contents)
side_effect = [ARBITRARY_CLIENT_ERROR, ARBITRARY_CLIENT_ERROR,
ARBITRARY_CLIENT_ERROR, ARBITRARY_CLIENT_ERROR]
with mock.patch('smart_open.s3._download_fileobj', side_effect=side_effect):
@@ -924,18 +896,14 @@ class DownloadKeyTest(unittest.TestCase):
KEY_NAME, bucket_name=BUCKET_NAME)
def test_intermittent_error_retries(self):
- contents = b'hello'
- put_to_bucket(contents=contents)
- expected = (KEY_NAME, contents)
+ expected = (KEY_NAME, self.body)
side_effect = [ARBITRARY_CLIENT_ERROR, ARBITRARY_CLIENT_ERROR,
- ARBITRARY_CLIENT_ERROR, ARBITRARY_CLIENT_ERROR, contents]
+ ARBITRARY_CLIENT_ERROR, ARBITRARY_CLIENT_ERROR, self.body]
with mock.patch('smart_open.s3._download_fileobj', side_effect=side_effect):
actual = smart_open.s3._download_key(KEY_NAME, bucket_name=BUCKET_NAME, retries=4)
self.assertEqual(expected, actual)
def test_propagates_other_exception(self):
- contents = b'hello'
- put_to_bucket(contents=contents)
with mock.patch('smart_open.s3._download_fileobj', side_effect=ValueError):
self.assertRaises(ValueError, smart_open.s3._download_key,
KEY_NAME, bucket_name=BUCKET_NAME)
@@ -945,9 +913,7 @@ class DownloadKeyTest(unittest.TestCase):
class OpenTest(unittest.TestCase):
def setUp(self):
ignore_resource_warnings()
-
- def tearDown(self):
- cleanup_bucket()
+ _resource('s3').create_bucket(Bucket=BUCKET_NAME).wait_until_exists()
def test_read_never_returns_none(self):
"""read should never return None."""
@@ -962,7 +928,7 @@ class OpenTest(unittest.TestCase):
def populate_bucket(num_keys=10):
- s3 = boto3.resource('s3')
+ s3 = _resource('s3')
for key_number in range(num_keys):
key_name = 'key_%d' % key_number
s3.Object(BUCKET_NAME, key_name).put(Body=str(key_number))
@@ -993,9 +959,7 @@ def test_client_propagation_singlepart():
# have done that for us by now.
#
session = boto3.Session()
- resource = session.resource('s3')
- bucket = resource.create_bucket(Bucket=BUCKET_NAME)
- bucket.wait_until_exists()
+ _resource('s3').create_bucket(Bucket=BUCKET_NAME).wait_until_exists()
client = session.client('s3')
@@ -1014,9 +978,7 @@ def test_client_propagation_singlepart():
def test_client_propagation_multipart():
"""Does the resource parameter make it from the caller to Boto3?"""
session = boto3.Session()
- resource = session.resource('s3')
- bucket = resource.create_bucket(Bucket=BUCKET_NAME)
- bucket.wait_until_exists()
+ _resource('s3').create_bucket(Bucket=BUCKET_NAME).wait_until_exists()
client = session.client('s3')
@@ -1035,7 +997,7 @@ def test_client_propagation_multipart():
def test_resource_propagation_reader():
"""Does the resource parameter make it from the caller to Boto3?"""
session = boto3.Session()
- resource = session.resource('s3')
+ resource = session.resource('s3', region_name='us-east-1')
bucket = resource.create_bucket(Bucket=BUCKET_NAME)
bucket.wait_until_exists()
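
Most of the churn in test_s3.py reduces to two related changes: fake credentials plus a region pinned once via functools.partial (newer moto and boto3 combinations refuse bucket creation without a resolvable region), and per-test bucket setup replacing the removed module-level helpers. A sketch of the resulting pattern, assuming moto is installed:

import functools
import os

import boto3
import moto

# moto never validates these, but boto3 refuses to sign requests without them.
os.environ.setdefault('AWS_ACCESS_KEY_ID', 'test')
os.environ.setdefault('AWS_SECRET_ACCESS_KEY', 'test')

# Pin the region in one place instead of at every boto3.resource call site.
_resource = functools.partial(boto3.resource, region_name='us-east-1')


@moto.mock_s3
def demo():
    s3 = _resource('s3')
    s3.create_bucket(Bucket='example-bucket').wait_until_exists()
    s3.Object('example-bucket', 'key').put(Body=b'hello')
    assert s3.Object('example-bucket', 'key').get()['Body'].read() == b'hello'


demo()
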
diff --git a/smart_open/tests/test_s3_version.py b/smart_open/tests/test_s3_version.py
index 907187b..6f9584d 100644
--- a/smart_open/tests/test_s3_version.py
+++ b/smart_open/tests/test_s3_version.py
@@ -1,7 +1,9 @@
# -*- coding: utf-8 -*-
+import functools
import logging
import unittest
import uuid
+import time
import boto3
import moto
@@ -16,32 +18,7 @@ KEY_NAME = 'test-key'
logger = logging.getLogger(__name__)
-@moto.mock_s3
-def setUpModule():
- '''Called once by unittest when initializing this module. Sets up the
- test S3 bucket.
-
- '''
- bucket = boto3.resource('s3').create_bucket(Bucket=BUCKET_NAME)
- bucket.wait_until_exists()
- boto3.resource('s3').BucketVersioning(BUCKET_NAME).enable()
-
-
-@moto.mock_s3
-def tearDownModule():
- '''Called once by unittest when tearing down this module. Empties and
- removes the test S3 bucket.
-
- '''
- s3 = boto3.resource('s3')
- bucket = s3.Bucket(BUCKET_NAME)
- try:
- bucket.object_versions.delete()
- bucket.delete()
- except s3.meta.client.exceptions.NoSuchBucket:
- pass
-
- bucket.wait_until_not_exists()
+_resource = functools.partial(boto3.resource, region_name='us-east-1')
def get_versions(bucket, key):
@@ -49,7 +26,7 @@ def get_versions(bucket, key):
return [
v.id
for v in sorted(
- boto3.resource('s3').Bucket(bucket).object_versions.filter(Prefix=key),
+ _resource('s3').Bucket(bucket).object_versions.filter(Prefix=key),
key=lambda version: version.last_modified,
)
]
@@ -57,24 +34,27 @@ def get_versions(bucket, key):
@moto.mock_s3
class TestVersionId(unittest.TestCase):
-
def setUp(self):
#
# Each run of this test reuses the BUCKET_NAME, but works with a
# different key for isolation.
#
+ resource = _resource('s3')
+ resource.create_bucket(Bucket=BUCKET_NAME).wait_until_exists()
+ resource.BucketVersioning(BUCKET_NAME).enable()
+
self.key = 'test-write-key-{}'.format(uuid.uuid4().hex)
self.url = "s3://%s/%s" % (BUCKET_NAME, self.key)
self.test_ver1 = u"String version 1.0".encode('utf8')
self.test_ver2 = u"String version 2.0".encode('utf8')
- bucket = boto3.resource('s3').Bucket(BUCKET_NAME)
+ bucket = resource.Bucket(BUCKET_NAME)
bucket.put_object(Key=self.key, Body=self.test_ver1)
-
logging.critical('versions after first write: %r', get_versions(BUCKET_NAME, self.key))
- bucket.put_object(Key=self.key, Body=self.test_ver2)
+ time.sleep(3)
+ bucket.put_object(Key=self.key, Body=self.test_ver2)
self.versions = get_versions(BUCKET_NAME, self.key)
logging.critical('versions after second write: %r', get_versions(BUCKET_NAME, self.key))
@@ -112,7 +92,7 @@ class TestVersionId(unittest.TestCase):
actual = fin.read()
self.assertEqual(actual, self.test_ver2)
- def test_oldset_version(self):
+ def test_oldest_version(self):
"""Passing in the oldest version gives the oldest content?"""
params = {'version_id': self.versions[0]}
with open(self.url, mode='rb', transport_params=params) as fin:
@@ -124,7 +104,7 @@ class TestVersionId(unittest.TestCase):
self.versions = get_versions(BUCKET_NAME, self.key)
params = {'version_id': self.versions[0]}
with open(self.url, mode='rb', transport_params=params) as fin:
- returned_obj = fin.to_boto3(boto3.resource('s3'))
+ returned_obj = fin.to_boto3(_resource('s3'))
boto3_body = returned_obj.get()['Body'].read()
self.assertEqual(boto3_body, self.test_ver1)
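
TestVersionId exercises smart_open's version_id transport parameter against a versioned moto bucket; the time.sleep(3) added above works around moto handing out near-identical last_modified timestamps, which would make the version ordering ambiguous. A condensed sketch of the happy path, with illustrative bucket and key names:

import os

import boto3
import moto
import smart_open

os.environ.setdefault('AWS_ACCESS_KEY_ID', 'test')
os.environ.setdefault('AWS_SECRET_ACCESS_KEY', 'test')


@moto.mock_s3
def read_oldest_version():
    s3 = boto3.resource('s3', region_name='us-east-1')
    s3.create_bucket(Bucket='vbucket').wait_until_exists()
    s3.BucketVersioning('vbucket').enable()

    bucket = s3.Bucket('vbucket')
    bucket.put_object(Key='key', Body=b'version 1')
    bucket.put_object(Key='key', Body=b'version 2')

    # Oldest first; production code should guard against timestamp ties.
    versions = sorted(bucket.object_versions.filter(Prefix='key'),
                      key=lambda v: v.last_modified)

    tp = {'version_id': versions[0].id}
    with smart_open.open('s3://vbucket/key', 'rb', transport_params=tp) as fin:
        assert fin.read() == b'version 1'


read_oldest_version()
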
diff --git a/smart_open/tests/test_sanity.py b/smart_open/tests/test_sanity.py
deleted file mode 100644
index da0c3fe..0000000
--- a/smart_open/tests/test_sanity.py
+++ /dev/null
@@ -1,41 +0,0 @@
-import unittest
-
-import boto3
-import moto
-
-
-@moto.mock_s3()
-def setUpModule():
- bucket = boto3.resource('s3').create_bucket(Bucket='mybucket')
-
- bucket.wait_until_exists()
-
-
-@moto.mock_s3()
-def tearDownModule():
- resource = boto3.resource('s3')
- bucket = resource.Bucket('mybucket')
- try:
- bucket.delete()
- except resource.meta.client.exceptions.NoSuchBucket:
- pass
- bucket.wait_until_not_exists()
-
-
-@moto.mock_s3()
-class Test(unittest.TestCase):
-
- def test(self):
- resource = boto3.resource('s3')
-
- bucket = resource.Bucket('mybucket')
- self.assertEqual(bucket.name, 'mybucket')
-
- expected = b'hello'
- resource.Object('mybucket', 'mykey').put(Body=expected)
-
- actual = resource.Object('mybucket', 'mykey').get()['Body'].read()
- self.assertEqual(expected, actual)
-
- def tearDown(self):
- boto3.resource('s3').Object('mybucket', 'mykey').delete()
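
The deleted test_sanity.py, like the setUpModule/tearDownModule pairs removed above, tripped over the same moto behavior: each @moto.mock_s3-decorated callable starts its own freshly reset backend, so a bucket created inside a decorated setUpModule is gone by the time a separately decorated test runs. A small sketch of that isolation, assuming moto is installed:

import os

import boto3
import moto

os.environ.setdefault('AWS_ACCESS_KEY_ID', 'test')
os.environ.setdefault('AWS_SECRET_ACCESS_KEY', 'test')


@moto.mock_s3
def create_bucket():
    # This mock starts and stops within this call only.
    boto3.resource('s3', region_name='us-east-1').create_bucket(Bucket='b')


@moto.mock_s3
def list_buckets():
    client = boto3.client('s3', region_name='us-east-1')
    return [b['Name'] for b in client.list_buckets()['Buckets']]


create_bucket()
assert list_buckets() == []  # the bucket did not survive its own mock

This is why the diff moves bucket creation into each class's setUp, which runs inside the class-level mock and therefore shares state with the tests.
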
diff --git a/smart_open/tests/test_smart_open.py b/smart_open/tests/test_smart_open.py
index 4532fd1..367e994 100644
--- a/smart_open/tests/test_smart_open.py
+++ b/smart_open/tests/test_smart_open.py
@@ -9,6 +9,7 @@
import bz2
import csv
import contextlib
+import functools
import io
import gzip
import hashlib
@@ -21,8 +22,7 @@ from unittest import mock
import warnings
import boto3
-from moto import mock_s3
-import parameterizedtestcase
+import moto
import pytest
import responses
@@ -38,6 +38,8 @@ CURR_DIR = os.path.abspath(os.path.dirname(__file__))
SAMPLE_TEXT = 'Hello, world!'
SAMPLE_BYTES = SAMPLE_TEXT.encode('utf-8')
+_resource = functools.partial(boto3.resource, region_name='us-east-1')
+
#
# For Windows platforms, under which tempfile.NamedTemporaryFile has some
@@ -446,23 +448,31 @@ class SmartOpenHttpTest(unittest.TestCase):
self.assertTrue('Authorization' in actual_request.headers)
self.assertTrue(actual_request.headers['Authorization'].startswith('Basic '))
+ @responses.activate
+ def test_http_cert(self):
+ """Does cert parameter get passed to requests"""
+ responses.add(responses.GET, "http://127.0.0.1/index.html",
+ body='line1\nline2', stream=True)
+ cert_path = '/path/to/my/cert.pem'
+ tp = dict(cert=cert_path)
+ smart_open.open("http://127.0.0.1/index.html", transport_params=tp)
+ self.assertEqual(len(responses.calls), 1)
+ actual_request = responses.calls[0].request
+ self.assertEqual(cert_path, actual_request.req_kwargs['cert'])
+
@responses.activate
def _test_compressed_http(self, suffix, query):
"""Can open <suffix> via http?"""
+ assert suffix in ('.gz', '.bz2')
+
raw_data = b'Hello World Compressed.' * 10000
- buf = make_buffer(name='data' + suffix)
- with smart_open.open(buf, 'wb') as outfile:
- outfile.write(raw_data)
- compressed_data = buf._value_when_closed
- # check that the string was actually compressed
- self.assertNotEqual(compressed_data, raw_data)
+ compressed_data = gzip_compress(raw_data) if suffix == '.gz' else bz2.compress(raw_data)
responses.add(responses.GET, 'http://127.0.0.1/data' + suffix, body=compressed_data, stream=True)
url = 'http://127.0.0.1/data%s%s' % (suffix, '?some_param=some_val' if query else '')
smart_open_object = smart_open.open(url, 'rb')
- # decompress the file and get the same md5 hash
- self.assertEqual(smart_open_object.read(), raw_data)
+ assert smart_open_object.read() == raw_data
def test_http_gz(self):
"""Can open gzip via http?"""
@@ -481,33 +491,6 @@ class SmartOpenHttpTest(unittest.TestCase):
self._test_compressed_http(".bz2", True)
-def make_buffer(cls=io.BytesIO, initial_value=None, name=None, noclose=False):
- """
- Construct a new in-memory file object aka "buf".
-
- :param cls: Class of the file object. Meaningful values are BytesIO and StringIO.
- :param initial_value: Passed directly to the constructor, this is the content of the returned buffer.
- :param name: Associated file path. Not assigned if is None (default).
- :param noclose: If True, disables the .close function.
- :return: Instance of `cls`.
- """
- buf = cls(initial_value) if initial_value else cls()
- if name is not None:
- buf.name = name
-
- buf._value_when_closed = None
- orig_close = buf.close
-
- def close():
- if buf.close.call_count == 1:
- buf._value_when_closed = buf.getvalue()
- if not noclose:
- orig_close()
-
- buf.close = mock.Mock(side_effect=close)
- return buf
-
-
class RealFileSystemTests(unittest.TestCase):
"""Tests that touch the file system via temporary files."""
@@ -565,124 +548,6 @@ class RealFileSystemTests(unittest.TestCase):
self.assertEqual(text, SAMPLE_TEXT * 2)
-class SmartOpenFileObjTest(unittest.TestCase):
- """
- Test passing raw file objects.
- """
-
- def test_read_bytes(self):
- """Can we read bytes from a byte stream?"""
- buf = make_buffer(initial_value=SAMPLE_BYTES)
- with smart_open.open(buf, 'rb') as sf:
- data = sf.read()
- self.assertEqual(data, SAMPLE_BYTES)
-
- def test_write_bytes(self):
- """Can we write bytes to a byte stream?"""
- buf = make_buffer()
- with smart_open.open(buf, 'wb') as sf:
- sf.write(SAMPLE_BYTES)
- self.assertEqual(buf.getvalue(), SAMPLE_BYTES)
-
- def test_read_text_stream_fails(self):
- """Attempts to read directly from a text stream should fail.
-
- This is because smart_open.open expects a byte stream as input.
- If you have a text stream, there's no point passing it to smart_open:
- you can read from it directly.
- """
- buf = make_buffer(io.StringIO, initial_value=SAMPLE_TEXT)
- with smart_open.open(buf, 'r') as sf:
- self.assertRaises(TypeError, sf.read) # we expect binary mode
-
- def test_write_text_stream_fails(self):
- """Attempts to write directly to a text stream should fail."""
- buf = make_buffer(io.StringIO)
- with smart_open.open(buf, 'w') as sf:
- with self.assertRaises(TypeError):
- sf.write(SAMPLE_TEXT) # we expect binary mode
- # Need to flush because TextIOWrapper may buffer and we need
- # to write to the underlying StringIO to get the TypeError.
- sf.flush()
-
- def test_read_text_from_bytestream(self):
- buf = make_buffer(initial_value=SAMPLE_BYTES)
- with smart_open.open(buf, 'r') as sf:
- data = sf.read()
- self.assertEqual(data, SAMPLE_TEXT)
-
- def test_read_text_from_bytestream_rt(self):
- buf = make_buffer(initial_value=SAMPLE_BYTES)
- with smart_open.open(buf, 'rt') as sf:
- data = sf.read()
- self.assertEqual(data, SAMPLE_TEXT)
-
- def test_read_text_from_bytestream_rtplus(self):
- buf = make_buffer(initial_value=SAMPLE_BYTES)
- with smart_open.open(buf, 'rt+') as sf:
- data = sf.read()
- self.assertEqual(data, SAMPLE_TEXT)
-
- def test_write_text_to_bytestream(self):
- """Can we write strings to a byte stream?"""
- buf = make_buffer(noclose=True)
- with smart_open.open(buf, 'w') as sf:
- sf.write(SAMPLE_TEXT)
-
- self.assertEqual(buf.getvalue(), SAMPLE_BYTES)
-
- def test_write_text_to_bytestream_wt(self):
- """Can we write strings to a byte stream?"""
- buf = make_buffer(noclose=True)
- with smart_open.open(buf, 'wt') as sf:
- sf.write(SAMPLE_TEXT)
-
- self.assertEqual(buf.getvalue(), SAMPLE_BYTES)
-
- def test_write_text_to_bytestream_wtplus(self):
- """Can we write strings to a byte stream?"""
- buf = make_buffer(noclose=True)
- with smart_open.open(buf, 'wt+') as sf:
- sf.write(SAMPLE_TEXT)
-
- self.assertEqual(buf.getvalue(), SAMPLE_BYTES)
-
- def test_name_read(self):
- """Can we use the "name" attribute to decompress on the fly?"""
- data = SAMPLE_BYTES * 1000
- buf = make_buffer(initial_value=bz2.compress(data), name='data.bz2')
- with smart_open.open(buf, 'rb') as sf:
- data = sf.read()
- self.assertEqual(data, data)
-
- def test_name_write(self):
- """Can we use the "name" attribute to compress on the fly?"""
- data = SAMPLE_BYTES * 1000
- buf = make_buffer(name='data.bz2')
- with smart_open.open(buf, 'wb') as sf:
- sf.write(data)
- self.assertEqual(bz2.decompress(buf._value_when_closed), data)
-
- def test_open_side_effect(self):
- """
- Does our detection of the `name` attribute work with wrapped open()-ed streams?
-
- We `open()` a file with ".bz2" extension, pass the file object to `smart_open()` and check that
- we read decompressed data. This behavior is driven by detecting the `name` attribute in
- `_open_binary_stream()`.
- """
- data = SAMPLE_BYTES * 1000
- with named_temporary_file(prefix='smart_open_tests_', suffix=".bz2", delete=False) as tmpf:
- tmpf.write(bz2.compress(data))
- try:
- with open(tmpf.name, 'rb') as openf:
- with smart_open.open(openf, 'rb') as smartf:
- smart_data = smartf.read()
- self.assertEqual(data, smart_data)
- finally:
- os.unlink(tmpf.name)
-
-
#
# What exactly to patch here differs on _how_ we're opening the file.
# See the _shortcut_open function for details.
@@ -739,10 +604,10 @@ class SmartOpenReadTest(unittest.TestCase):
actual = fin.read()
self.assertEqual(expected, actual)
- @mock_s3
+ @moto.mock_s3
def test_read_never_returns_none(self):
"""read should never return None."""
- s3 = boto3.resource('s3')
+ s3 = _resource('s3')
s3.create_bucket(Bucket='mybucket')
test_string = u"ветер по морю гуляет..."
@@ -754,12 +619,12 @@ class SmartOpenReadTest(unittest.TestCase):
self.assertEqual(r.read(), b"")
self.assertEqual(r.read(), b"")
- @mock_s3
+ @moto.mock_s3
def test_read_newline_none(self):
"""Does newline open() parameter for reading work according to
https://docs.python.org/3/library/functions.html#open-newline-parameter
"""
- boto3.resource('s3').create_bucket(Bucket='mybucket')
+ _resource('s3').create_bucket(Bucket='mybucket')
# Unicode line separator and various others must never split lines
test_file = u"line\u2028 LF\nline\x1c CR\rline\x85 CRLF\r\nlast line"
with smart_open.open("s3://mybucket/mykey", "wb") as fout:
@@ -774,9 +639,9 @@ class SmartOpenReadTest(unittest.TestCase):
u"last line"
])
- @mock_s3
+ @moto.mock_s3
def test_read_newline_empty(self):
- boto3.resource('s3').create_bucket(Bucket='mybucket')
+ _resource('s3').create_bucket(Bucket='mybucket')
test_file = u"line\u2028 LF\nline\x1c CR\rline\x85 CRLF\r\nlast line"
with smart_open.open("s3://mybucket/mykey", "wb") as fout:
fout.write(test_file.encode("utf-8"))
@@ -789,9 +654,9 @@ class SmartOpenReadTest(unittest.TestCase):
u"last line"
])
- @mock_s3
+ @moto.mock_s3
def test_read_newline_cr(self):
- boto3.resource('s3').create_bucket(Bucket='mybucket')
+ _resource('s3').create_bucket(Bucket='mybucket')
test_file = u"line\u2028 LF\nline\x1c CR\rline\x85 CRLF\r\nlast line"
with smart_open.open("s3://mybucket/mykey", "wb") as fout:
fout.write(test_file.encode("utf-8"))
@@ -803,9 +668,9 @@ class SmartOpenReadTest(unittest.TestCase):
u"\nlast line"
])
- @mock_s3
+ @moto.mock_s3
def test_read_newline_lf(self):
- boto3.resource('s3').create_bucket(Bucket='mybucket')
+ _resource('s3').create_bucket(Bucket='mybucket')
test_file = u"line\u2028 LF\nline\x1c CR\rline\x85 CRLF\r\nlast line"
with smart_open.open("s3://mybucket/mykey", "wb") as fout:
fout.write(test_file.encode("utf-8"))
@@ -817,9 +682,9 @@ class SmartOpenReadTest(unittest.TestCase):
u"last line"
])
- @mock_s3
+ @moto.mock_s3
def test_read_newline_crlf(self):
- boto3.resource('s3').create_bucket(Bucket='mybucket')
+ _resource('s3').create_bucket(Bucket='mybucket')
test_file = u"line\u2028 LF\nline\x1c CR\rline\x85 CRLF\r\nlast line"
with smart_open.open("s3://mybucket/mykey", "wb") as fout:
fout.write(test_file.encode("utf-8"))
@@ -830,9 +695,9 @@ class SmartOpenReadTest(unittest.TestCase):
u"last line"
])
- @mock_s3
+ @moto.mock_s3
def test_read_newline_slurp(self):
- boto3.resource('s3').create_bucket(Bucket='mybucket')
+ _resource('s3').create_bucket(Bucket='mybucket')
test_file = u"line\u2028 LF\nline\x1c CR\rline\x85 CRLF\r\nlast line"
with smart_open.open("s3://mybucket/mykey", "wb") as fout:
fout.write(test_file.encode("utf-8"))
@@ -843,9 +708,9 @@ class SmartOpenReadTest(unittest.TestCase):
u"line\u2028 LF\nline\x1c CR\nline\x85 CRLF\nlast line"
)
- @mock_s3
+ @moto.mock_s3
def test_read_newline_binary(self):
- boto3.resource('s3').create_bucket(Bucket='mybucket')
+ _resource('s3').create_bucket(Bucket='mybucket')
test_file = u"line\u2028 LF\nline\x1c CR\rline\x85 CRLF\r\nlast line"
with smart_open.open("s3://mybucket/mykey", "wb") as fout:
fout.write(test_file.encode("utf-8"))
@@ -857,12 +722,12 @@ class SmartOpenReadTest(unittest.TestCase):
u"last line".encode('utf-8')
])
- @mock_s3
+ @moto.mock_s3
def test_write_newline_none(self):
"""Does newline open() parameter for writing work according to
https://docs.python.org/3/library/functions.html#open-newline-parameter
"""
- boto3.resource('s3').create_bucket(Bucket='mybucket')
+ _resource('s3').create_bucket(Bucket='mybucket')
# Unicode line separator and various others must never split lines
test_file = u"line\u2028 LF\nline\x1c CR\rline\x85 CRLF\r\nlast line"
# No newline parameter means newline=None, all LF are translated to os.linesep
@@ -876,9 +741,9 @@ class SmartOpenReadTest(unittest.TestCase):
+ u"last line"
)
- @mock_s3
+ @moto.mock_s3
def test_write_newline_empty(self):
- boto3.resource('s3').create_bucket(Bucket='mybucket')
+ _resource('s3').create_bucket(Bucket='mybucket')
test_file = u"line\u2028 LF\nline\x1c CR\rline\x85 CRLF\r\nlast line"
# If newline='' nothing is changed
with smart_open.open("s3://mybucket/mykey", "w", encoding='utf-8', newline='') as fout:
@@ -889,9 +754,9 @@ class SmartOpenReadTest(unittest.TestCase):
u"line\u2028 LF\nline\x1c CR\rline\x85 CRLF\r\nlast line"
)
- @mock_s3
+ @moto.mock_s3
def test_write_newline_lf(self):
- boto3.resource('s3').create_bucket(Bucket='mybucket')
+ _resource('s3').create_bucket(Bucket='mybucket')
test_file = u"line\u2028 LF\nline\x1c CR\rline\x85 CRLF\r\nlast line"
# If newline='\n' nothing is changed
with smart_open.open("s3://mybucket/mykey", "w", encoding='utf-8', newline='\n') as fout:
@@ -902,9 +767,9 @@ class SmartOpenReadTest(unittest.TestCase):
u"line\u2028 LF\nline\x1c CR\rline\x85 CRLF\r\nlast line"
)
- @mock_s3
+ @moto.mock_s3
def test_write_newline_cr(self):
- boto3.resource('s3').create_bucket(Bucket='mybucket')
+ _resource('s3').create_bucket(Bucket='mybucket')
test_file = u"line\u2028 LF\nline\x1c CR\rline\x85 CRLF\r\nlast line"
# If newline='\r' all LF are replaced by CR
with smart_open.open("s3://mybucket/mykey", "w", encoding='utf-8', newline='\r') as fout:
@@ -915,9 +780,9 @@ class SmartOpenReadTest(unittest.TestCase):
u"line\u2028 LF\rline\x1c CR\rline\x85 CRLF\r\rlast line"
)
- @mock_s3
+ @moto.mock_s3
def test_write_newline_crlf(self):
- boto3.resource('s3').create_bucket(Bucket='mybucket')
+ _resource('s3').create_bucket(Bucket='mybucket')
test_file = u"line\u2028 LF\nline\x1c CR\rline\x85 CRLF\r\nlast line"
# If newline='\r\n' all LF are replaced by CRLF
with smart_open.open("s3://mybucket/mykey", "w", encoding='utf-8', newline='\r\n') as fout:
@@ -928,10 +793,10 @@ class SmartOpenReadTest(unittest.TestCase):
u"line\u2028 LF\r\nline\x1c CR\rline\x85 CRLF\r\r\nlast line"
)
- @mock_s3
+ @moto.mock_s3
def test_readline(self):
"""Does readline() return the correct file content?"""
- s3 = boto3.resource('s3')
+ s3 = _resource('s3')
s3.create_bucket(Bucket='mybucket')
test_string = u"hello žluťoučký\u2028world!\nhow are you?".encode('utf8')
with smart_open.open("s3://mybucket/mykey", "wb") as fout:
@@ -940,10 +805,10 @@ class SmartOpenReadTest(unittest.TestCase):
reader = smart_open.open("s3://mybucket/mykey", "rb")
self.assertEqual(reader.readline(), u"hello žluťoučký\u2028world!\n".encode("utf-8"))
- @mock_s3
+ @moto.mock_s3
def test_readline_iter(self):
"""Does __iter__ return the correct file content?"""
- s3 = boto3.resource('s3')
+ s3 = _resource('s3')
s3.create_bucket(Bucket='mybucket')
lines = [u"всем\u2028привет!\n", u"что нового?"]
with smart_open.open("s3://mybucket/mykey", "wb") as fout:
@@ -956,10 +821,10 @@ class SmartOpenReadTest(unittest.TestCase):
self.assertEqual(lines[0], actual_lines[0])
self.assertEqual(lines[1], actual_lines[1])
- @mock_s3
+ @moto.mock_s3
def test_readline_eof(self):
"""Does readline() return empty string on EOF?"""
- s3 = boto3.resource('s3')
+ s3 = _resource('s3')
s3.create_bucket(Bucket='mybucket')
with smart_open.open("s3://mybucket/mykey", "wb"):
pass
@@ -971,11 +836,11 @@ class SmartOpenReadTest(unittest.TestCase):
self.assertEqual(reader.readline(), b"")
self.assertEqual(reader.readline(), b"")
- @mock_s3
+ @moto.mock_s3
def test_s3_iter_lines(self):
"""Does s3_iter_lines give correct content?"""
# create fake bucket and fake key
- s3 = boto3.resource('s3')
+ s3 = _resource('s3')
s3.create_bucket(Bucket='mybucket')
test_string = u"hello žluťoučký\u2028world!\nhow are you?".encode('utf8')
with smart_open.open("s3://mybucket/mykey", "wb") as fin:
@@ -1095,14 +960,14 @@ class SmartOpenReadTest(unittest.TestCase):
smart_open_object = smart_open.open("webhdfs://127.0.0.1:8440/path/file", 'rb')
self.assertEqual(smart_open_object.read().decode("utf-8"), "line1\nline2")
- @mock_s3
+ @moto.mock_s3
def test_s3_iter_moto(self):
"""Are S3 files iterated over correctly?"""
# a list of strings to test with
expected = [b"*" * 5 * 1024**2] + [b'0123456789'] * 1024 + [b"test"]
# create fake bucket and fake key
- s3 = boto3.resource('s3')
+ s3 = _resource('s3')
s3.create_bucket(Bucket='mybucket')
tp = dict(s3_min_part_size=5 * 1024**2)
@@ -1127,10 +992,10 @@ class SmartOpenReadTest(unittest.TestCase):
output = [line.rstrip(b'\n') for line in smart_open_object]
self.assertEqual(output, expected)
- @mock_s3
+ @moto.mock_s3
def test_s3_read_moto(self):
"""Are S3 files read correctly?"""
- s3 = boto3.resource('s3')
+ s3 = _resource('s3')
s3.create_bucket(Bucket='mybucket')
# write some bogus key so we can check it below
@@ -1144,10 +1009,10 @@ class SmartOpenReadTest(unittest.TestCase):
self.assertEqual(content[14:], smart_open_object.read()) # read the rest
- @mock_s3
+ @moto.mock_s3
def test_s3_seek_moto(self):
"""Does seeking in S3 files work correctly?"""
- s3 = boto3.resource('s3')
+ s3 = _resource('s3')
s3.create_bucket(Bucket='mybucket')
# write some bogus key so we can check it below
@@ -1165,10 +1030,10 @@ class SmartOpenReadTest(unittest.TestCase):
smart_open_object.seek(0)
self.assertEqual(content, smart_open_object.read(-1)) # same thing
- @mock_s3
+ @moto.mock_s3
def test_s3_tell(self):
"""Does tell() work when S3 file is opened for text writing? """
- s3 = boto3.resource('s3')
+ s3 = _resource('s3')
s3.create_bucket(Bucket='mybucket')
with smart_open.open("s3://mybucket/mykey", "w") as fout:
@@ -1358,11 +1223,11 @@ class SmartOpenTest(unittest.TestCase):
["hdfs", "dfs", "-put", "-f", "-", "/tmp/test.txt"], stdin=mock_subprocess.PIPE
)
- @mock_s3
+ @moto.mock_s3
def test_s3_modes_moto(self):
"""Do s3:// open modes work correctly?"""
# fake bucket and key
- s3 = boto3.resource('s3')
+ s3 = _resource('s3')
s3.create_bucket(Bucket='mybucket')
raw_data = b"second test"
@@ -1377,7 +1242,7 @@ class SmartOpenTest(unittest.TestCase):
self.assertEqual(output, [raw_data])
- @mock_s3
+ @moto.mock_s3
def test_s3_metadata_write(self):
# Read local file fixture
path = os.path.join(CURR_DIR, 'test_data/crime-and-punishment.txt.gz')
@@ -1386,7 +1251,7 @@ class SmartOpenTest(unittest.TestCase):
data = fd.read()
# Create a test bucket
- s3 = boto3.resource('s3')
+ s3 = _resource('s3')
s3.create_bucket(Bucket='mybucket')
tp = {
@@ -1410,7 +1275,7 @@ class SmartOpenTest(unittest.TestCase):
self.assertIn('text/plain', key.content_type)
self.assertEqual(key.content_encoding, 'gzip')
- @mock_s3
+ @moto.mock_s3
def test_write_bad_encoding_strict(self):
"""Should abort on encoding error."""
text = u'欲しい気持ちが成長しすぎて'
@@ -1420,7 +1285,7 @@ class SmartOpenTest(unittest.TestCase):
with smart_open.open(infile.name, 'w', encoding='koi8-r', errors='strict') as fout:
fout.write(text)
- @mock_s3
+ @moto.mock_s3
def test_write_bad_encoding_replace(self):
"""Should replace characters that failed to encode."""
text = u'欲しい気持ちが成長しすぎて'
@@ -1521,18 +1386,19 @@ def gzip_compress(data, filename=None):
return buf.getvalue()
-class CompressionFormatTest(parameterizedtestcase.ParameterizedTestCase):
+class CompressionFormatTest(unittest.TestCase):
"""Test transparent (de)compression."""
def write_read_assertion(self, suffix):
- test_file = make_buffer(name='file' + suffix)
- with smart_open.open(test_file, 'wb') as fout:
- fout.write(SAMPLE_BYTES)
- self.assertNotEqual(SAMPLE_BYTES, test_file._value_when_closed)
- # we have to recreate the buffer because it is closed
- test_file = make_buffer(initial_value=test_file._value_when_closed, name=test_file.name)
- with smart_open.open(test_file, 'rb') as fin:
- self.assertEqual(fin.read(), SAMPLE_BYTES)
+ with named_temporary_file(suffix=suffix) as tmp:
+ with smart_open.open(tmp.name, 'wb') as fout:
+ fout.write(SAMPLE_BYTES)
+
+ with open(tmp.name, 'rb') as fin:
+ assert fin.read() != SAMPLE_BYTES # is the content really compressed? (built-in fails)
+
+ with smart_open.open(tmp.name, 'rb') as fin:
+ assert fin.read() == SAMPLE_BYTES # ... smart_open correctly opens and decompresses
def test_open_gz(self):
"""Can open gzip?"""
@@ -1542,20 +1408,6 @@ class CompressionFormatTest(parameterizedtestcase.ParameterizedTestCase):
m = hashlib.md5(data)
assert m.hexdigest() == '18473e60f8c7c98d29d65bf805736a0d', 'Failed to read gzip'
- @parameterizedtestcase.ParameterizedTestCase.parameterize(
- ("extension", "compressed"),
- [
- (".gz", gzip_compress(_DECOMPRESSED_DATA, 'key')),
- (".bz2", bz2.compress(_DECOMPRESSED_DATA)),
- ],
- )
- def test_closes_compressed_stream(self, extension, compressed):
- """Transparent compression closes the compressed stream?"""
- compressed_stream = make_buffer(initial_value=compressed, name=f"file{extension}")
- with smart_open.open(compressed_stream, encoding="utf-8"):
- pass
- assert compressed_stream.close.call_count == 1
-
def test_write_read_gz(self):
"""Can write and read gzip?"""
self.write_read_assertion('.gz')
@@ -1649,12 +1501,12 @@ class MultistreamsBZ2Test(unittest.TestCase):
class S3OpenTest(unittest.TestCase):
- @mock_s3
+ @moto.mock_s3
def test_r(self):
"""Reading a UTF string should work."""
text = u"физкульт-привет!"
- s3 = boto3.resource('s3')
+ s3 = _resource('s3')
s3.create_bucket(Bucket='bucket')
key = s3.Object('bucket', 'key')
key.put(Body=text.encode('utf-8'))
@@ -1670,10 +1522,10 @@ class S3OpenTest(unittest.TestCase):
uri = smart_open_lib._parse_uri("s3://bucket/key")
self.assertRaises(NotImplementedError, smart_open.open, uri, "x")
- @mock_s3
+ @moto.mock_s3
def test_rw_encoding(self):
"""Should read and write text, respecting encodings, etc."""
- s3 = boto3.resource('s3')
+ s3 = _resource('s3')
s3.create_bucket(Bucket='bucket')
key = "s3://bucket/key"
@@ -1694,10 +1546,10 @@ class S3OpenTest(unittest.TestCase):
with smart_open.open(key, "r", encoding="euc-jp", errors="replace") as fin:
fin.read()
- @mock_s3
+ @moto.mock_s3
def test_rw_gzip(self):
"""Should read/write gzip files, implicitly and explicitly."""
- s3 = boto3.resource('s3')
+ s3 = _resource('s3')
s3.create_bucket(Bucket='bucket')
key = "s3://bucket/key.gz"
@@ -1708,7 +1560,7 @@ class S3OpenTest(unittest.TestCase):
#
# Check that what we've created is a gzip.
#
- with smart_open.open(key, "rb", ignore_ext=True) as fin:
+ with smart_open.open(key, "rb", compression='disable') as fin:
gz = gzip.GzipFile(fileobj=fin)
self.assertEqual(gz.read().decode("utf-8"), text)
@@ -1718,22 +1570,22 @@ class S3OpenTest(unittest.TestCase):
with smart_open.open(key, "rb") as fin:
self.assertEqual(fin.read().decode("utf-8"), text)
- @mock_s3
+ @moto.mock_s3
@mock.patch('smart_open.smart_open_lib._inspect_kwargs', mock.Mock(return_value={}))
def test_gzip_write_mode(self):
"""Should always open in binary mode when writing through a codec."""
- s3 = boto3.resource('s3')
+ s3 = _resource('s3')
s3.create_bucket(Bucket='bucket')
with mock.patch('smart_open.s3.open', return_value=open(__file__, 'rb')) as mock_open:
smart_open.open("s3://bucket/key.gz", "wb")
mock_open.assert_called_with('bucket', 'key.gz', 'wb')
- @mock_s3
+ @moto.mock_s3
@mock.patch('smart_open.smart_open_lib._inspect_kwargs', mock.Mock(return_value={}))
def test_gzip_read_mode(self):
"""Should always open in binary mode when reading through a codec."""
- s3 = boto3.resource('s3')
+ s3 = _resource('s3')
s3.create_bucket(Bucket='bucket')
key = "s3://bucket/key.gz"
@@ -1745,10 +1597,10 @@ class S3OpenTest(unittest.TestCase):
smart_open.open(key, "r")
mock_open.assert_called_with('bucket', 'key.gz', 'rb')
- @mock_s3
+ @moto.mock_s3
def test_read_encoding(self):
"""Should open the file with the correct encoding, explicit text read."""
- s3 = boto3.resource('s3')
+ s3 = _resource('s3')
s3.create_bucket(Bucket='bucket')
key = "s3://bucket/key.txt"
text = u'это знала ева, это знал адам, колеса любви едут прямо по нам'
@@ -1758,10 +1610,10 @@ class S3OpenTest(unittest.TestCase):
actual = fin.read()
self.assertEqual(text, actual)
- @mock_s3
+ @moto.mock_s3
def test_read_encoding_implicit_text(self):
"""Should open the file with the correct encoding, implicit text read."""
- s3 = boto3.resource('s3')
+ s3 = _resource('s3')
s3.create_bucket(Bucket='bucket')
key = "s3://bucket/key.txt"
text = u'это знала ева, это знал адам, колеса любви едут прямо по нам'
@@ -1771,10 +1623,10 @@ class S3OpenTest(unittest.TestCase):
actual = fin.read()
self.assertEqual(text, actual)
- @mock_s3
+ @moto.mock_s3
def test_write_encoding(self):
"""Should open the file for writing with the correct encoding."""
- s3 = boto3.resource('s3')
+ s3 = _resource('s3')
s3.create_bucket(Bucket='bucket')
key = "s3://bucket/key.txt"
text = u'какая боль, какая боль, аргентина - ямайка, 5-0'
@@ -1785,10 +1637,10 @@ class S3OpenTest(unittest.TestCase):
actual = fin.read()
self.assertEqual(text, actual)
- @mock_s3
+ @moto.mock_s3
def test_write_bad_encoding_strict(self):
"""Should open the file for writing with the correct encoding."""
- s3 = boto3.resource('s3')
+ s3 = _resource('s3')
s3.create_bucket(Bucket='bucket')
key = "s3://bucket/key.txt"
text = u'欲しい気持ちが成長しすぎて'
@@ -1797,10 +1649,10 @@ class S3OpenTest(unittest.TestCase):
with smart_open.open(key, 'w', encoding='koi8-r', errors='strict') as fout:
fout.write(text)
- @mock_s3
+ @moto.mock_s3
def test_write_bad_encoding_replace(self):
"""Should open the file for writing with the correct encoding."""
- s3 = boto3.resource('s3')
+ s3 = _resource('s3')
s3.create_bucket(Bucket='bucket')
key = "s3://bucket/key.txt"
text = u'欲しい気持ちが成長しすぎて'
@@ -1812,10 +1664,10 @@ class S3OpenTest(unittest.TestCase):
actual = fin.read()
self.assertEqual(expected, actual)
- @mock_s3
+ @moto.mock_s3
def test_write_text_gzip(self):
"""Should open the file for writing with the correct encoding."""
- s3 = boto3.resource('s3')
+ s3 = _resource('s3')
s3.create_bucket(Bucket='bucket')
key = "s3://bucket/key.txt.gz"
text = u'какая боль, какая боль, аргентина - ямайка, 5-0'
@@ -1880,146 +1732,138 @@ class CheckKwargsTest(unittest.TestCase):
self.assertEqual(expected, actual)
-@mock_s3
+def initialize_bucket():
+ s3 = _resource("s3")
+ bucket = s3.create_bucket(Bucket="bucket")
+ bucket.wait_until_exists()
+
+ bucket.Object('gzipped').put(Body=gzip_compress(_DECOMPRESSED_DATA))
+ bucket.Object('bzipped').put(Body=bz2.compress(_DECOMPRESSED_DATA))
+
+
+@moto.mock_s3
@mock.patch('time.time', _MOCK_TIME)
-class S3CompressionTestCase(parameterizedtestcase.ParameterizedTestCase):
+def test_s3_gzip_compress_sanity():
+ """Does our gzip_compress function actually produce gzipped data?"""
+ initialize_bucket()
+ assert gzip.decompress(gzip_compress(_DECOMPRESSED_DATA)) == _DECOMPRESSED_DATA
- def setUp(self):
- s3 = boto3.resource("s3")
- bucket = s3.create_bucket(Bucket="bucket")
- bucket.wait_until_exists()
-
- bucket.Object('gzipped').put(Body=gzip_compress(_DECOMPRESSED_DATA))
- bucket.Object('bzipped').put(Body=bz2.compress(_DECOMPRESSED_DATA))
-
- def test_gzip_compress_sanity(self):
- """Does our gzip_compress function actually produce gzipped data?"""
- assert gzip.decompress(gzip_compress(_DECOMPRESSED_DATA)) == _DECOMPRESSED_DATA
-
- @parameterizedtestcase.ParameterizedTestCase.parameterize(
- ("url", "_compression"),
- [
- ("s3://bucket/gzipped", ".gz"),
- ("s3://bucket/bzipped", ".bz2"),
- ]
- )
- def test_read_explicit(self, url, _compression):
- """Can we read using the explicitly specified compression?"""
- with smart_open.open(url, 'rb', compression=_compression) as fin:
- assert fin.read() == _DECOMPRESSED_DATA
-
- @parameterizedtestcase.ParameterizedTestCase.parameterize(
- ("_compression", "expected"),
- [
- (".gz", gzip_compress(_DECOMPRESSED_DATA, 'key')),
- (".bz2", bz2.compress(_DECOMPRESSED_DATA)),
- ],
- )
- def test_write_explicit(self, _compression, expected):
- """Can we write using the explicitly specified compression?"""
- with smart_open.open("s3://bucket/key", "wb", compression=_compression) as fout:
- fout.write(_DECOMPRESSED_DATA)
-
- with smart_open.open("s3://bucket/key", "rb", compression=NO_COMPRESSION) as fin:
- assert fin.read() == expected
-
- @parameterizedtestcase.ParameterizedTestCase.parameterize(
- ("url", "_compression", "expected"),
- [
- ("s3://bucket/key.gz", ".gz", gzip_compress(_DECOMPRESSED_DATA, 'key.gz')),
- ("s3://bucket/key.bz2", ".bz2", bz2.compress(_DECOMPRESSED_DATA)),
- ],
- )
- def test_write_implicit(self, url, _compression, expected):
- """Can we determine the compression from the file extension?"""
- with smart_open.open(url, "wb", compression=INFER_FROM_EXTENSION) as fout:
- fout.write(_DECOMPRESSED_DATA)
-
- with smart_open.open(url, "rb", compression=NO_COMPRESSION) as fin:
- assert fin.read() == expected
-
- @parameterizedtestcase.ParameterizedTestCase.parameterize(
- ("url", "_compression", "expected"),
- [
- ("s3://bucket/key.gz", ".gz", gzip_compress(_DECOMPRESSED_DATA, 'key.gz')),
- ("s3://bucket/key.bz2", ".bz2", bz2.compress(_DECOMPRESSED_DATA)),
- ],
- )
- def test_ignore_ext(self, url, _compression, expected):
- """Can we handle the deprecated ignore_ext parameter when reading/writing?"""
- with smart_open.open(url, "wb") as fout:
- fout.write(_DECOMPRESSED_DATA)
-
- with smart_open.open(url, "rb", ignore_ext=True) as fin:
- assert fin.read() == expected
-
- @parameterizedtestcase.ParameterizedTestCase.parameterize(
- ("extension", "kwargs", "error"),
- [
- ("", dict(compression="foo"), ValueError),
- ("", dict(compression="foo", ignore_ext=True), ValueError),
- ("", dict(compression=NO_COMPRESSION, ignore_ext=True), ValueError),
- (
- ".gz",
- dict(compression=INFER_FROM_EXTENSION, ignore_ext=True),
- ValueError,
- ),
- (
- ".bz2",
- dict(compression=INFER_FROM_EXTENSION, ignore_ext=True),
- ValueError,
- ),
- ("", dict(compression=".gz", ignore_ext=True), ValueError),
- ("", dict(compression=".bz2", ignore_ext=True), ValueError),
- ],
- )
- def test_compression_invalid(self, extension, kwargs, error):
- """Should detect and error on these invalid inputs"""
- with pytest.raises(error):
- smart_open.open(f"s3://bucket/key{extension}", "wb", **kwargs)
-
- with pytest.raises(error):
- smart_open.open(f"s3://bucket/key{extension}", "rb", **kwargs)
-
-
-class GetBinaryModeTest(parameterizedtestcase.ParameterizedTestCase):
- @parameterizedtestcase.ParameterizedTestCase.parameterize(
- ('mode', 'expected'),
- [
- ('r', 'rb'),
- ('r+', 'rb+'),
- ('rt', 'rb'),
- ('rt+', 'rb+'),
- ('r+t', 'rb+'),
- ('w', 'wb'),
- ('w+', 'wb+'),
- ('wt', 'wb'),
- ('wt+', 'wb+'),
- ('w+t', 'wb+'),
- ('a', 'ab'),
- ('a+', 'ab+'),
- ('at', 'ab'),
- ('at+', 'ab+'),
- ('a+t', 'ab+'),
- ]
- )
- def test(self, mode, expected):
- actual = smart_open.smart_open_lib._get_binary_mode(mode)
- assert actual == expected
- @parameterizedtestcase.ParameterizedTestCase.parameterize(
- ('mode', ),
- [
- ('rw', ),
- ('rwa', ),
- ('rbt', ),
- ('r++', ),
- ('+', ),
- ('x', ),
- ]
- )
- def test_bad(self, mode):
- self.assertRaises(ValueError, smart_open.smart_open_lib._get_binary_mode, mode)
+@moto.mock_s3
+@mock.patch('time.time', _MOCK_TIME)
+@pytest.mark.parametrize(
+ "url,_compression",
+ [
+ ("s3://bucket/gzipped", ".gz"),
+ ("s3://bucket/bzipped", ".bz2"),
+ ]
+)
+def test_s3_read_explicit(url, _compression):
+ """Can we read using the explicitly specified compression?"""
+ initialize_bucket()
+ with smart_open.open(url, 'rb', compression=_compression) as fin:
+ assert fin.read() == _DECOMPRESSED_DATA
+
+
+@moto.mock_s3
+@mock.patch('time.time', _MOCK_TIME)
+@pytest.mark.parametrize(
+ "_compression,expected",
+ [
+ (".gz", gzip_compress(_DECOMPRESSED_DATA, 'key')),
+ (".bz2", bz2.compress(_DECOMPRESSED_DATA)),
+ ],
+)
+def test_s3_write_explicit(_compression, expected):
+ """Can we write using the explicitly specified compression?"""
+ initialize_bucket()
+
+ with smart_open.open("s3://bucket/key", "wb", compression=_compression) as fout:
+ fout.write(_DECOMPRESSED_DATA)
+
+ with smart_open.open("s3://bucket/key", "rb", compression=NO_COMPRESSION) as fin:
+ assert fin.read() == expected
+
+
+@moto.mock_s3
+@mock.patch('time.time', _MOCK_TIME)
+@pytest.mark.parametrize(
+ "url,_compression,expected",
+ [
+ ("s3://bucket/key.gz", ".gz", gzip_compress(_DECOMPRESSED_DATA, 'key.gz')),
+ ("s3://bucket/key.bz2", ".bz2", bz2.compress(_DECOMPRESSED_DATA)),
+ ],
+)
+def test_s3_write_implicit(url, _compression, expected):
+ """Can we determine the compression from the file extension?"""
+ initialize_bucket()
+
+ with smart_open.open(url, "wb", compression=INFER_FROM_EXTENSION) as fout:
+ fout.write(_DECOMPRESSED_DATA)
+
+ with smart_open.open(url, "rb", compression=NO_COMPRESSION) as fin:
+ assert fin.read() == expected
+
+
+@moto.mock_s3
+@mock.patch('time.time', _MOCK_TIME)
+@pytest.mark.parametrize(
+ "url,_compression,expected",
+ [
+ ("s3://bucket/key.gz", ".gz", gzip_compress(_DECOMPRESSED_DATA, 'key.gz')),
+ ("s3://bucket/key.bz2", ".bz2", bz2.compress(_DECOMPRESSED_DATA)),
+ ],
+)
+def test_s3_disable_compression(url, _compression, expected):
+ """Can we handle the compression parameter when reading/writing?"""
+ initialize_bucket()
+
+ with smart_open.open(url, "wb") as fout:
+ fout.write(_DECOMPRESSED_DATA)
+
+ with smart_open.open(url, "rb", compression='disable') as fin:
+ assert fin.read() == expected
+
+
+@pytest.mark.parametrize(
+ 'mode,expected',
+ [
+ ('r', 'rb'),
+ ('r+', 'rb+'),
+ ('rt', 'rb'),
+ ('rt+', 'rb+'),
+ ('r+t', 'rb+'),
+ ('w', 'wb'),
+ ('w+', 'wb+'),
+ ('wt', 'wb'),
+ ('wt+', 'wb+'),
+ ('w+t', 'wb+'),
+ ('a', 'ab'),
+ ('a+', 'ab+'),
+ ('at', 'ab'),
+ ('at+', 'ab+'),
+ ('a+t', 'ab+'),
+ ]
+)
+def test_get_binary_mode(mode, expected):
+ actual = smart_open.smart_open_lib._get_binary_mode(mode)
+ assert actual == expected
+
+
+@pytest.mark.parametrize(
+ 'mode',
+ [
+ ('rw', ),
+ ('rwa', ),
+ ('rbt', ),
+ ('r++', ),
+ ('+', ),
+ ('x', ),
+ ]
+)
+def test_get_binary_mode_bad(mode):
+ with pytest.raises(ValueError):
+ smart_open.smart_open_lib._get_binary_mode(mode)
def test_backwards_compatibility_wrapper():
@@ -2039,6 +1883,44 @@ def test_backwards_compatibility_wrapper():
smart_open.smart_open(fpath, unsupported_keyword_param=123)
+@pytest.mark.skipif(os.name == "nt", reason="this test does not work on Windows")
+def test_read_file_descriptor():
+ with smart_open.open(__file__) as fin:
+ expected = fin.read()
+
+ fd = os.open(__file__, os.O_RDONLY)
+ with smart_open.open(fd) as fin:
+ actual = fin.read()
+
+ assert actual == expected
+
+
+@pytest.mark.skipif(os.name == "nt", reason="this test does not work on Windows")
+def test_write_file_descriptor():
+ with named_temporary_file() as tmp:
+ with smart_open.open(os.open(tmp.name, os.O_WRONLY), 'wt') as fout:
+ fout.write("hello world")
+
+ with smart_open.open(tmp.name, 'rt') as fin:
+ assert fin.read() == "hello world"
+
+
+@moto.mock_s3()
+def test_to_boto3():
+ resource = _resource('s3')
+ resource.create_bucket(Bucket='mybucket')
+ #
+ # If we don't specify encoding explicitly, the platform-dependent encoding
+ # will be used, and it may not necessarily support Unicode, breaking this
+ # test under Windows on github actions.
+ #
+ with smart_open.open('s3://mybucket/key.txt', 'wt', encoding='utf-8') as fout:
+ fout.write('я бегу по вызженной земле, гермошлем захлопнув на ходу')
+ obj = fout.to_boto3(resource)
+ assert obj.bucket_name == 'mybucket'
+ assert obj.key == 'key.txt'
+
+
if __name__ == '__main__':
logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.DEBUG)
unittest.main()
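The compression tests above are flattened from a parameterizedtestcase class into standalone pytest functions, and the deprecated ignore_ext flag is gone: its job is now done by passing compression='disable'. A minimal sketch of the resulting caller-facing behaviour, assuming smart_open 6.x and an existing bucket (the URLs and the /etc/hostname path are illustrative only):

import os
import smart_open

# An explicit compression argument overrides extension-based inference.
with smart_open.open('s3://bucket/blob', 'rb', compression='.gz') as fin:
    data = fin.read()  # transparently decompressed

# 'disable' replaces the removed ignore_ext=True: the stream passes through raw.
with smart_open.open('s3://bucket/key.gz', 'rb', compression='disable') as fin:
    raw = fin.read()  # still gzip-compressed bytes

# Also new in this release: plain file descriptors are accepted.
fd = os.open('/etc/hostname', os.O_RDONLY)
with smart_open.open(fd) as fin:
    text = fin.read()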
diff --git a/smart_open/tests/test_ssh.py b/smart_open/tests/test_ssh.py
index 18a8f16..a4b91eb 100644
--- a/smart_open/tests/test_ssh.py
+++ b/smart_open/tests/test_ssh.py
@@ -4,6 +4,8 @@ import logging
import unittest
from unittest import mock
+from paramiko import SSHException
+
import smart_open.ssh
@@ -12,8 +14,8 @@ def mock_ssh(func):
smart_open.ssh._SSH.clear()
return func(*args, **kwargs)
- return mock.patch("paramiko.client.SSHClient.get_transport")(
- mock.patch("paramiko.client.SSHClient.connect")(wrapper)
+ return mock.patch("paramiko.SSHClient.get_transport")(
+ mock.patch("paramiko.SSHClient.connect")(wrapper)
)
@@ -49,6 +51,23 @@ class SSHOpen(unittest.TestCase):
)
mock_connect.assert_called_with("some-host", 22, username="user", key_filename="key")
+ @mock_ssh
+ def test_reconnect_after_session_timeout(self, mock_connect, get_transp_mock):
+ mock_sftp = get_transp_mock().open_sftp_client()
+ get_transp_mock().open_sftp_client.reset_mock()
+
+ def mocked_open_sftp():
+ if len(mock_connect.call_args_list) < 2: # simulate timeout until second connect()
+ yield SSHException('SSH session not active')
+ while True:
+ yield mock_sftp
+
+ get_transp_mock().open_sftp_client.side_effect = mocked_open_sftp()
+
+ smart_open.open("ssh://user:pass@some-host/")
+ mock_connect.assert_called_with("some-host", 22, username="user", password="pass")
+ mock_sftp.open.assert_called_once()
+
if __name__ == "__main__":
logging.basicConfig(format="%(asctime)s : %(levelname)s : %(message)s", level=logging.DEBUG)
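The reconnect test leans on a unittest.mock detail: when side_effect is an iterator, any exception instance it yields is raised instead of returned, so the generator above produces exactly one simulated "SSH session not active" failure followed by a working SFTP client. A hedged sketch of the retry behaviour being verified (the helper name and arguments are illustrative, not smart_open's actual internals):

import paramiko

def open_sftp_with_retry(client, host, **connect_kwargs):
    # A cached session may have timed out since it was last used.
    try:
        return client.get_transport().open_sftp_client()
    except paramiko.SSHException:
        # Re-establish the connection and retry exactly once.
        client.connect(host, **connect_kwargs)
        return client.get_transport().open_sftp_client()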
diff --git a/smart_open/transport.py b/smart_open/transport.py
index 00fb27d..086ea2b 100644
--- a/smart_open/transport.py
+++ b/smart_open/transport.py
@@ -54,15 +54,15 @@ def register_transport(submodule):
# Save only the last module name piece
module_name = module_name.rsplit(".")[-1]
- if hasattr(submodule, 'SCHEME'):
+ if hasattr(submodule, "SCHEME"):
schemes = [submodule.SCHEME]
- elif hasattr(submodule, 'SCHEMES'):
+ elif hasattr(submodule, "SCHEMES"):
schemes = submodule.SCHEMES
else:
- raise ValueError('%r does not have a .SCHEME or .SCHEMES attribute' % submodule)
+ raise ValueError("%r does not have a .SCHEME or .SCHEMES attribute" % submodule)
- for f in ('open', 'open_uri', 'parse_uri'):
- assert hasattr(submodule, f), '%r is missing %r' % (submodule, f)
+ for f in ("open", "open_uri", "parse_uri"):
+ assert hasattr(submodule, f), "%r is missing %r" % (submodule, f)
for scheme in schemes:
assert scheme not in _REGISTRY
@@ -80,7 +80,9 @@ def get_transport(scheme):
"""
global _ERRORS, _MISSING_DEPS_ERROR, _REGISTRY, SUPPORTED_SCHEMES
expected = SUPPORTED_SCHEMES
- readme_url = 'https://github.com/RaRe-Technologies/smart_open/blob/master/README.rst'
+ readme_url = (
+ "https://github.com/RaRe-Technologies/smart_open/blob/master/README.rst"
+ )
message = (
"Unable to handle scheme %(scheme)r, expected one of %(expected)r. "
"Extra dependencies required by %(scheme)r may be missing. "
@@ -94,13 +96,14 @@ def get_transport(scheme):
register_transport(smart_open.local_file)
-register_transport('smart_open.azure')
-register_transport('smart_open.gcs')
-register_transport('smart_open.hdfs')
-register_transport('smart_open.http')
-register_transport('smart_open.s3')
-register_transport('smart_open.ssh')
-register_transport('smart_open.webhdfs')
+register_transport("smart_open.azure")
+register_transport("smart_open.ftp")
+register_transport("smart_open.gcs")
+register_transport("smart_open.hdfs")
+register_transport("smart_open.http")
+register_transport("smart_open.s3")
+register_transport("smart_open.ssh")
+register_transport("smart_open.webhdfs")
SUPPORTED_SCHEMES = tuple(sorted(_REGISTRY.keys()))
"""The transport schemes that the local installation of ``smart_open`` supports."""
diff --git a/smart_open/version.py b/smart_open/version.py
index b70d87d..676f190 100644
--- a/smart_open/version.py
+++ b/smart_open/version.py
@@ -1,4 +1,4 @@
-__version__ = '5.2.1'
+__version__ = '6.3.0'
if __name__ == '__main__':
diff --git a/tox.ini b/tox.ini
deleted file mode 100644
index f81f4af..0000000
--- a/tox.ini
+++ /dev/null
@@ -1,85 +0,0 @@
-[tox]
-minversion = 2.0
-envlist = py{36,37,38,39}-{test,doctest,integration,benchmark}, sdist, flake8
-
-[pytest]
-addopts = -rfxEXs --durations=20 --showlocals
-
-[flake8]
-ignore = E12, W503, E226
-max-line-length = 110
-show-source = True
-
-[testenv]
-passenv = SO_* AWS_* COVERALLS_* RUN_BENCHMARKS GITHUB_*
-recreate = True
-whitelist_externals =
- sh
- bash
-
-deps =
- .[all]
- .[test]
-
- integration: numpy
-
- benchmark: pytest_benchmark
-
- benchmark: awscli
-
-commands =
- test: pytest smart_open -v
-
- integration: python tox_helpers/run_integration_tests.py
-
- benchmark: python tox_helpers/run_benchmarks.py
-
- doctest: python tox_helpers/doctest.py
-
-
-[testenv:sdist]
-whitelist_externals = rm
-recreate = True
-commands =
- rm -rf dist/
- python setup.py sdist
-
-
-[testenv:flake8]
-skip_install = True
-recreate = True
-deps = flake8
-commands = flake8 smart_open/ {posargs}
-
-
-[testenv:check_keys]
-skip_install = True
-recreate = True
-deps = boto3
-commands = python tox_helpers/check_keys.py
-
-
-[testenv:enable_moto_server]
-skip_install = True
-recreate = False
-deps = moto[server]
-commands = bash tox_helpers/helpers.sh enable_moto_server
-
-
-[testenv:disable_moto_server]
-skip_install = True
-recreate = False
-deps =
-commands = bash tox_helpers/helpers.sh disable_moto_server
-
-
-[testenv:test_coverage]
-skip_install = True
-recreate = True
-deps =
- .
- pytest-cov
-commands =
- python tox_helpers/test_missing_dependencies.py
- pip install .[all,test]
- pytest smart_open -v --cov smart_open --cov-report term-missing --cov-append
diff --git a/tox_helpers/README.txt b/tox_helpers/README.txt
deleted file mode 100644
index f8a1cee..0000000
--- a/tox_helpers/README.txt
+++ /dev/null
@@ -1,3 +0,0 @@
-This subdirectory contains helper scripts for our tox.ini file.
-
-They are designed to be platform-independent: they run on both Linux and Windows.
diff --git a/tox_helpers/check_keys.py b/tox_helpers/check_keys.py
deleted file mode 100644
index 6fd52ed..0000000
--- a/tox_helpers/check_keys.py
+++ /dev/null
@@ -1,49 +0,0 @@
-"""Check that the environment variables contain valid boto3 credentials."""
-import logging
-import os
-import boto3
-import boto3.session
-
-
-def check(session):
- client = session.client('s3')
- try:
- response = client.list_buckets()
- except Exception as e:
- logging.exception(e)
- return None
- else:
- return [b['Name'] for b in response['Buckets']]
-
-
-def check_implicit():
- session = boto3.session.Session()
- buckets = check(session)
- if buckets:
- print('implicit check OK: %r' % buckets)
- else:
- print('implicit check failed')
-
-
-def check_explicit():
- key_id = os.environ.get('AWS_ACCESS_KEY_ID')
- secret_key = os.environ.get('AWS_SECRET_ACCESS_KEY')
- if not (key_id and secret_key):
- print('no credentials found in os.environ, skipping explicit check')
- return
-
- session = boto3.session.Session(aws_access_key_id=key_id, aws_secret_access_key=secret_key)
- buckets = check(session)
- if buckets:
- print('explicit check OK: %r' % buckets)
- else:
- print('explicit check failed')
-
-
-def main():
- check_implicit()
- check_explicit()
-
-
-if __name__ == '__main__':
- main()
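The deleted check_keys.py validated credentials by listing buckets, which requires the s3:ListAllMyBuckets permission. Should an equivalent ever be needed again, a hedged one-call alternative that needs no IAM permissions at all (not part of the upstream change):

import boto3

def credentials_valid():
    """Return True if the environment carries usable AWS credentials."""
    try:
        boto3.client("sts").get_caller_identity()
    except Exception:
        return False
    return True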
diff --git a/tox_helpers/helpers.sh b/tox_helpers/helpers.sh
deleted file mode 100755
index 0230193..0000000
--- a/tox_helpers/helpers.sh
+++ /dev/null
@@ -1,14 +0,0 @@
-#!/bin/bash
-
-set -e
-set -x
-
-enable_moto_server(){
- moto_server -p5000 2>/dev/null&
-}
-
-disable_moto_server(){
- lsof -i tcp:5000 | tail -n1 | cut -f2 -d" " | xargs kill -9
-}
-
-"$@"
Historical runs
- patch-application-failed: Patch application failed: exclude-gcs.patch
- bad-gateway: Failed to push result branch: Unexpected HTTP status 502 for https://janitor.debian.net/git/smart-open/info/refs?service=git-upload-pack: Unable to handle http code: Bad Gateway
- push-failed: Failed to push result branch: Connection closed: Connection closed early The remote server unexpectedly closed the connection.