diff --git a/.github/workflows/windows.yml b/.github/workflows/windows.yml new file mode 100644 index 00000000..fd2c4dce --- /dev/null +++ b/.github/workflows/windows.yml @@ -0,0 +1,50 @@ +name: Windows Test +on: + push: + branches: [ master ] + pull_request: + branches: [ master ] +jobs: + runtests: + runs-on: windows-2019 + steps: + - uses: actions/checkout@v2 + + - uses: actions/setup-python@v1 + with: + python-version: '3.6' + + # Unfortunately msiexec does seem to work on the windows runner. + # 7zip is able to unpack .msi files but it loses the directory structure. + # We fix that in the next step. + - name: Download MonetDB + run: | + curl https://www.monetdb.org/downloads/Windows/Jan2022-SP1/MonetDB5-SQL-Installer-x86_64-20220207.msi -o ${{ runner.temp }}\monetdb.msi --no-progress-meter + dir ${{ runner.temp }} + 7z x ${{ runner.temp }}\monetdb.msi -o${{ runner.temp }}\staging + dir ${{ runner.temp }}\staging + + # Run a script to restore the directory structure and see if it works (a little) + - name: Install MonetDB + run: | + python tests/install_monetdb_from_msi_dir.py ${{ runner.temp }}\staging ${{ runner.temp }}\MONET + dir ${{ runner.temp }}\MONET + dir ${{ runner.temp }}\MONET\bin + ${{ runner.temp }}\MONET\bin\mserver5.exe --help + + - name: Setup virtual environment + run: | + python -m venv venv + venv\Scripts\Activate.ps1 + python -m pip install -r tests/requirements.txt + + # Script tests/windows_tests.py starts an mserver in the background + # and runs pytest, excluding the Control tests. + - name: run the tests + run: | + venv\Scripts\Activate.ps1 + mkdir ${{ runner.temp }}\dbfarm + python tests/windows_tests.py ${{ runner.temp }}\MONET ${{ runner.temp }}\dbfarm demo 50000 + echo ""; echo ""; echo "================ SERVER STDERR: ==================="; echo "" + type ${{ runner.temp }}\dbfarm\errlog + diff --git a/doc/api.rst b/doc/api.rst index 84bb6f5c..f1897146 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -38,6 +38,22 @@ MAPI :show-inheritance: +File Uploads and Downloads +========================== + +Classes related to file transfer requests as used by COPY INTO ON CLIENT. + +.. automodule:: pymonetdb.filetransfer + :members: Uploader, Downloader, SafeDirectoryHandler + :member-order: bysource + +.. automodule:: pymonetdb.filetransfer.uploads + :members: Upload + +.. automodule:: pymonetdb.filetransfer.downloads + :members: Download + + MonetDB remote control ====================== diff --git a/doc/development.rst b/doc/development.rst index 0de39578..2323e95b 100644 --- a/doc/development.rst +++ b/doc/development.rst @@ -1,4 +1,4 @@ -development +Development =========== Github diff --git a/doc/examples.rst b/doc/examples.rst index 63017912..c33547ea 100644 --- a/doc/examples.rst +++ b/doc/examples.rst @@ -1,7 +1,12 @@ Examples ======== -examples usage below:: +Here are some examples of how to use pymonetdb. + +Example session +--------------- + +:: > # import the SQL module > import pymonetdb @@ -49,6 +54,8 @@ examples usage below:: ('commit_action', 'smallint', 1, 1, None, None, None), ('temporary', 'tinyint', 1, 1, None, None, None)] +MAPI Connection +--------------- If you would like to communicate with the database at a lower level you can use the MAPI library:: @@ -60,3 +67,10 @@ you can use the MAPI library:: > server.cmd("sSELECT * FROM tables;") ... + +CSV Upload +-------------- + +This is an example script that uploads some csv data from the local file system: + +.. literalinclude:: examples/uploadcsv.py diff --git a/doc/examples/uploadcsv.py b/doc/examples/uploadcsv.py new file mode 100644 index 00000000..83f2b6ba --- /dev/null +++ b/doc/examples/uploadcsv.py @@ -0,0 +1,36 @@ +#!/usr/bin/env python3 + +import os +import pymonetdb + +# Create the data directory and the CSV file +try: + os.mkdir("datadir") +except FileExistsError: + pass +with open("datadir/data.csv", "w") as f: + for i in range(10): + print(f"{i},item{i + 1}", file=f) + +# Connect to MonetDB and register the upload handler +conn = pymonetdb.connect('demo') +handler = pymonetdb.SafeDirectoryHandler("datadir") +conn.set_uploader(handler) +cursor = conn.cursor() + +# Set up the table +cursor.execute("DROP TABLE foo") +cursor.execute("CREATE TABLE foo(i INT, t TEXT)") + +# Upload the data, this will ask the handler to upload data.csv +cursor.execute("COPY INTO foo FROM 'data.csv' ON CLIENT USING DELIMITERS ','") + +# Check that it has loaded +cursor.execute("SELECT t FROM foo WHERE i = 9") +row = cursor.fetchone() +assert row[0] == 'item10' + +# Goodbye +conn.commit() +cursor.close() +conn.close() \ No newline at end of file diff --git a/doc/examples/uploaddyn.py b/doc/examples/uploaddyn.py new file mode 100644 index 00000000..e562563b --- /dev/null +++ b/doc/examples/uploaddyn.py @@ -0,0 +1,27 @@ +#!/usr/bin/env python3 +import pymonetdb + +class MyUploader(pymonetdb.Uploader): + def handle_upload(self, upload, filename, text_mode, skip_amount): + tw = upload.text_writer() + for i in range(skip_amount, 1000): + print(f'{i},number{i}', file=tw) + +conn = pymonetdb.connect('demo') +conn.set_uploader(MyUploader()) + +cursor = conn.cursor() +cursor.execute("DROP TABLE foo") +cursor.execute("CREATE TABLE foo(i INT, t TEXT)") +cursor.execute("COPY 10 RECORDS OFFSET 7 INTO foo FROM 'data.csv' ON CLIENT USING DELIMITERS ','") +cursor.execute("SELECT COUNT(i), MIN(i), MAX(i) FROM foo") +row = cursor.fetchone() +print(row) +assert row[0] == 10 # ten records numbered +assert row[1] == 6 # offset 7 means skip first 6, that is, records 0, .., 5 +assert row[2] == 15 # 10 records: 6, 7,8, 9,10,11, 12,13,14, and 15 + +# Goodbye +conn.commit() +cursor.close() +conn.close() \ No newline at end of file diff --git a/doc/examples/uploadsafe.py b/doc/examples/uploadsafe.py new file mode 100644 index 00000000..d5dc5691 --- /dev/null +++ b/doc/examples/uploadsafe.py @@ -0,0 +1,38 @@ +#!/usr/bin/env python3 +import pathlib +import shutil +import pymonetdb + +class MyUploader(pymonetdb.Uploader): + def __init__(self, dir): + self.dir = pathlib.Path(dir) + + def handle_upload(self, upload, filename, text_mode, skip_amount): + # security check + path = self.dir.joinpath(filename).resolve() + if not str(path).startswith(str(self.dir.resolve())): + return upload.send_error('Forbidden') + # open + tw = upload.text_writer() + with open(path) as f: + # skip + for i in range(skip_amount): + f.readline() + # bulk upload + shutil.copyfileobj(f, tw) + +conn = pymonetdb.connect('demo') +conn.set_uploader(MyUploader('datadir')) + +cursor = conn.cursor() +cursor.execute("DROP TABLE foo") +cursor.execute("CREATE TABLE foo(i INT, t TEXT)") +cursor.execute("COPY 10 RECORDS OFFSET 7 INTO foo FROM 'data.csv' ON CLIENT USING DELIMITERS ','") +cursor.execute("SELECT COUNT(i), MIN(i), MAX(i) FROM foo") +row = cursor.fetchone() +print(row) + +# Goodbye +conn.commit() +cursor.close() +conn.close() \ No newline at end of file diff --git a/doc/filetransfers.rst b/doc/filetransfers.rst new file mode 100644 index 00000000..24331968 --- /dev/null +++ b/doc/filetransfers.rst @@ -0,0 +1,120 @@ +File Transfers +============== + +MonetDB supports the non-standard :code:`COPY INTO` statement to load a CSV-like +text file into a table or to dump a table to a text file. This statement has an +optional modifier :code:`ON CLIENT` to indicate that the server should not +try to open the file server-side, but should instead ask the client to open the +file on its behalf. + +For example:: + + COPY INTO mytable FROM 'data'.csv' ON CLIENT + USING DELIMITERS ',', E'\n', '"'; + +By default, if pymonetdb receives a file request from the server, it will refuse +it for security considerations. You do not want the server or a hacker pretending +to be the server to be able to request arbitrary files on your system and even +overwrite them. + +To enable file transfers, create a `pymonetdb.Uploader` and/or +`pymonetdb.Downloader` and register them with your connection:: + + transfer_handler = pymonetdb.SafeDirectoryHandler(datadir) + conn.set_uploader(transfer_handler) + conn.set_downloader(transfer_handler) + +With this in place, the COPY INTO ON CLIENT statement above will ask to open +file data.csv in the given `datadir` and upload its contents. As its name +suggests, :class:`SafeDirectoryHandler` will only allow access to the files in +that directory. + +Note that in this example we register the same handler object both as an +uploader and a downloader, but it is perfectly sensible to only register an +uploader, or only a downloader, or to use two separate handlers. + +See the API documentation for details. + + +Make up data as you go +---------------------- + +You can also write your own transfer handlers. And instead of opening a file, +such handlers can also make up the data on the fly, retrieve it from a remote +microservice, prompt the user interactively or do whatever else you come up +with: + +.. literalinclude:: examples/uploaddyn.py + :pyobject: MyUploader + +In this example we called `upload.text_writer()` which yields a text-mode +file-like object. There is also `upload.binary_writer()` which yields a +binary-mode file-like object. This works even if the server requested a text +mode object, but in that case you have to make sure the bytes you write are valid +utf-8 and delimited with Unix line endings rather than Windows line endings. + +If you want to refuse an up- or download, call `upload.send_error()` to send an +error message. This is only possible before any calls to `text_writer()` and +`binary_writer()`. + +For custom downloaders the situation is similar, except that instead of +`text_writer` and `binary_writer`, the `download` parameter offers +`download.text_reader()` and `download.text_writer()`. + + +Skip amount +----------- + +MonetDB's :code:`COPY INTO` statement allows you to skip for example the first +line in a file using the the modifier :code:`OFFSET 2`. In such a case, +the `skip_amount` parameter to `handle_upload` will be greater than zero. + +Note that the offset in the SQL statement is 1-based, whereas the `skip_amount` +parameter has already been converted to be 0-based. In the example above +this allowed us to write :code:`for i in range(skip_amount, 1000):` rather +than :code:`for i in range(1000):`. + + +Cancellation +------------ + +If the server does not need all uploaded data, for example if you did:: + + COPY 100 RECORDS INTO mytable FROM 'data.csv' ON CLIENT + +the server may at some point cancel the upload. This does not happen instantly, +from time to time pymonetdb explicitly asks the server if they are still +interested. By default this is after every MiB of data but that can be +configured using `upload.set_chunk_size()`. If the server answers that it is no +longer interested, pymonetdb will discard any further data written to the +writer. It is recommended to occasionally call `upload.is_cancelled()` to check +for this and exit early if the upload has been cancelled. + +Upload handlers also have an optional method `cancel()` that you can override. +This method is called when pymonetdb receives the cancellation request. + + +Copying data from or to a file-like object +------------------------------------------ + +If you are moving large amounts of data between pymonetdb and a file-like object +such as a file, Pythons `copyfileobj`_ function may come in handy: + +.. literalinclude:: examples/uploadsafe.py + :pyobject: MyUploader + +However, note that copyfileobj does not handle cancellations as described above. + +.. _copyfileobj: https://docs.python.org/3/library/shutil.html#shutil.copyfileobj + + +Security considerations +----------------------- + +If your handler accesses the file system or the network, it is absolutely critical +to carefully validate the file name you are given. Otherwise an attacker can take +over the server or the connection to the server and cause great damage. + +An example of how to validate file systems paths is given in the code sample above. +Similar considerations apply to text that is inserted into network urls and other +resource identifiers. diff --git a/doc/index.rst b/doc/index.rst index 7c44c4e8..e8ed1bff 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -12,6 +12,7 @@ Contents: :maxdepth: 2 introduction + filetransfers examples api development diff --git a/pymonetdb/__init__.py b/pymonetdb/__init__.py index 082591f3..c039bdab 100644 --- a/pymonetdb/__init__.py +++ b/pymonetdb/__init__.py @@ -20,6 +20,10 @@ from pymonetdb.sql.connections import Connection from pymonetdb.sql.pythonize import * from pymonetdb.exceptions import * +from pymonetdb.filetransfer import Downloader, Uploader +from pymonetdb.filetransfer.downloads import Download +from pymonetdb.filetransfer.uploads import Upload +from pymonetdb.filetransfer.directoryhandler import SafeDirectoryHandler try: __version__ = pkg_resources.require("pymonetdb")[0].version @@ -34,7 +38,7 @@ 'Timestamp', 'DateFromTicks', 'TimeFromTicks', 'TimestampFromTicks', 'DataError', 'DatabaseError', 'Error', 'IntegrityError', 'InterfaceError', 'InternalError', 'NUMBER', 'NotSupportedError', 'OperationalError', 'ProgrammingError', 'ROWID', 'STRING', 'TIME', 'Warning', 'apilevel', 'connect', 'paramstyle', - 'threadsafety'] + 'threadsafety', 'Download', 'Downloader', 'Upload', 'Uploader', 'SafeDirectoryHandler'] def connect(*args, **kwargs) -> Connection: diff --git a/pymonetdb/filetransfer/__init__.py b/pymonetdb/filetransfer/__init__.py new file mode 100644 index 00000000..0b4cc37c --- /dev/null +++ b/pymonetdb/filetransfer/__init__.py @@ -0,0 +1,159 @@ +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. +# +# Copyright 1997 - July 2008 CWI, August 2008 - 2016 MonetDB B.V. + + +from abc import ABC, abstractmethod +from pymonetdb import mapi as mapi_protocol +from pymonetdb.exceptions import ProgrammingError +import pymonetdb.filetransfer.uploads +import pymonetdb.filetransfer.downloads + + +def handle_file_transfer(mapi: "mapi_protocol.Connection", cmd: str): + if cmd.startswith("r "): + parts = cmd[2:].split(' ', 2) + if len(parts) == 2: + try: + n = int(parts[0]) + except ValueError: + pass + return handle_upload(mapi, parts[1], True, n) + elif cmd.startswith("rb "): + return handle_upload(mapi, cmd[2:], False, 0) + elif cmd.startswith("w "): + return handle_download(mapi, cmd[2:], True) + elif cmd.startswith("wb "): + return handle_download(mapi, cmd[3:], False) + else: + pass + # we only reach this if decoding the cmd went wrong: + mapi._putblock(f"Invalid file transfer command: {cmd!r}") + + +def handle_upload(mapi: "mapi_protocol.Connection", filename: str, text_mode: bool, offset: int): + if not mapi.uploader: + mapi._putblock("No upload handler has been registered with pymonetdb\n") + return + skip_amount = offset - 1 if offset > 0 else 0 + upload = pymonetdb.filetransfer.uploads.Upload(mapi) + try: + mapi.uploader.handle_upload(upload, filename, text_mode, skip_amount) + except Exception as e: + # We must make sure the server doesn't think this is a succesful upload. + # The protocol does not allow us to flag an error after the upload has started, + # so the only thing we can do is kill the connection + upload.error = True + mapi._sabotage() + raise e + finally: + upload.close() + if not upload.has_been_used(): + raise ProgrammingError("Upload handler didn't do anything") + + +def handle_download(mapi: "mapi_protocol.Connection", filename: str, text_mode: bool): + if not mapi.downloader: + mapi._putblock("No download handler has been registered with pymonetdb\n") + return + download = pymonetdb.filetransfer.downloads.Download(mapi) + try: + mapi.downloader.handle_download(download, filename, text_mode) + except Exception as e: + # For consistency we also drop the connection on these exceptions. + # + # # Alternatively we might just discard the incoming data and allow + # work to continue, but in 99% of the cases the application is about + # to crash and it makes no sense to delay that by first reading all + # the data. + # + # Also, if the download has not really started yet we might send + # an error message to the server but then you get inconsistent + # behaviour: if the download hadn't started yet, the transaction ends + # up in an aborted state and must be ROLLed BACK, but if the download + # has started we discard all data and allow it to continue without + # error. + # + # Bottom line is that it's easier to understand if we just always + # crash the connection. + download._shutdown() + mapi._sabotage() + raise e + finally: + download.close() + + +class Uploader(ABC): + """ + Base class for upload hooks. Instances of subclasses of this class can be + registered using pymonetdb.Connection.set_uploader(). Every time an upload + request is received, an Upload object is created and passed to this objects + .handle_upload() method. + + If the server cancels the upload halfway, the .cancel() methods is called + and all further data written is ignored. + """ + + @abstractmethod + def handle_upload(self, upload: "pymonetdb.filetransfer.uploads.Upload", filename: str, text_mode: bool, skip_amount: int): + """ + Called when an upload request is received. Implementations should either + send an error using upload.send_error(), or request a writer using + upload.text_writer() or upload.binary_writer(). All data written to the + writer will be sent to the server. + + Parameter 'filename' is the file name used in the COPY INTO statement. + Parameter 'text_mode' indicates whether the server requested a text file + or a binary file. In case of a text file, 'skip_amount' indicates the + number of lines to skip. In binary mode, 'skip_amount' is always 0. + + SECURITY NOTE! Make sure to carefully validate the file name before + opening files on the file system. Otherwise, if an adversary has taken + control of the network connection or of the server, they can use file + upload requests to read arbitrary files from your computer + (../../) + + """ + pass + + def cancel(self): + """Optional method called when the server cancels the upload.""" + pass + + +class Downloader(ABC): + """ + Base class for download hooks. Instances of subclasses of this class can be + registered using pymonetdb.Connection.set_downloader(). Every time a + download request arrives, a Download object is created and passed to this + objects .handle_download() method. + + SECURITY NOTE! Make sure to carefully validate the file name before opening + files on the file system. Otherwise, if an adversary has taken control of + the network connection or of the server, they can use download requests to + OVERWRITE ARBITRARY FILES on your computer + """ + + @abstractmethod + def handle_download(self, download: "pymonetdb.filetransfer.downloads.Download", filename: str, text_mode: bool): + """ + Called when a download request is received from the server. Implementations + should either refuse by sending an error using download.send_error(), or + request a reader using download.binary_reader() or download.text_reader(). + + Parameter 'filename' is the file name used in the COPY INTO statement. + Parameter 'text_mode' indicates whether the server requested to send a binary + file or a text file. + + SECURITY NOTE! Make sure to carefully validate the file name before opening + files on the file system. Otherwise, if an adversary has taken control of + the network connection or of the server, they can use download requests to + OVERWRITE ARBITRARY FILES on your computer + """ + pass + + +# Only import this at the end to avoid circular imports +from pymonetdb.filetransfer.directoryhandler import SafeDirectoryHandler \ No newline at end of file diff --git a/pymonetdb/filetransfer/directoryhandler.py b/pymonetdb/filetransfer/directoryhandler.py new file mode 100644 index 00000000..539d92d9 --- /dev/null +++ b/pymonetdb/filetransfer/directoryhandler.py @@ -0,0 +1,165 @@ +""" +Classes related to file transfer requests as used by COPY INTO ON CLIENT. +""" +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. +# +# Copyright 1997 - July 2008 CWI, August 2008 - 2016 MonetDB B.V. + + +import codecs +from importlib import import_module +from pathlib import Path +from shutil import copyfileobj +from typing import Optional +from pymonetdb.filetransfer import Uploader, Downloader +from pymonetdb.filetransfer.uploads import Upload +from pymonetdb.filetransfer.downloads import Download + + +class SafeDirectoryHandler(Uploader, Downloader): + """ + File transfer handler which uploads and downloads files from a given + directory, taking care not to allow access to files outside that directory. + Instances of this class can be registered using the pymonetb.Connection's + set_uploader() and set_downloader() methods. + + When downloading text files, the downloaded text is converted according to + the `encoding` and `newline` parameters, if present. Valid values for + `encoding` are any encoding known to Python, or None. Valid values for + `newline` are `"\\\\n"`, `"\\\\r\\\\n"` or None. None means to use the + system default. + + For binary up- and downloads, no conversions are applied. + + When uploading text files, the `encoding` parameter indicates how the text + is read and `newline` is mostly ignored: both `\\\\n` and `\\\\r\\\\n` are + valid line endings. The exception is that because the server expects its + input to be `\\\\n`-terminated UTF-8 text, if you set encoding to "utf-8" + and newline to "\\\\n", text mode transfers are performed as binary, which + improves performance. For uploads, only do this if you are absolutely, + positively sure that all files in the directory are actually valid UTF-8 + encoded and have Unix line endings. + + If `compression` is set to True, which is the default, the + SafeDirectoryHandler will automatically compress and decompress files with + extensions .gz, .bz2, .xz and .lz4. Note that the first three algorithms are + built into Python, but LZ4 only works if the lz4.frame module is available. + """ + + def __init__(self, dir, encoding: Optional[str] = None, newline: Optional[str] = None, compression=True): + self.dir = Path(dir).resolve() + self.encoding = encoding + self.is_utf8 = (self.encoding and (codecs.lookup('utf-8') == codecs.lookup(self.encoding))) + self.newline = newline + self.compression = compression + + def secure_resolve(self, filename: str) -> Optional[Path]: + p = self.dir.joinpath(filename).resolve() + if str(p).startswith(str(self.dir)): + return p + else: + return None + + def handle_upload(self, upload: Upload, filename: str, text_mode: bool, skip_amount: int): + """:meta private:""" # keep the API docs cleaner, this has already been documented on class Uploader. + + p = self.secure_resolve(filename) + if not p: + return upload.send_error("Forbidden") + + if self.is_utf8 and self.newline == "\n" and skip_amount == 0: + # optimization + text_mode = False + + # open + if text_mode: + mode = "rt" + encoding = self.encoding + newline = self.newline + else: + mode = "rb" + encoding = None + newline = None + try: + opener = lookup_compression_algorithm(filename) if self.compression else open + except ModuleNotFoundError as e: + return upload.send_error(str(e)) + try: + f = opener(p, mode=mode, encoding=encoding, newline=newline) + except IOError as e: + return upload.send_error(str(e)) + + with f: + if text_mode: + tw = upload.text_writer() + for _ in range(skip_amount): + if not f.readline(): + break + self._upload_data(upload, f, tw) + else: + bw = upload.binary_writer() + self._upload_data(upload, f, bw) + + def _upload_data(self, upload: Upload, src, dst): + # Due to duck typing this method works equally well in text- and binary mode + bufsize = 1024 * 1024 + while not upload.is_cancelled(): + data = src.read(bufsize) + if not data: + break + dst.write(data) + + def handle_download(self, download: Download, filename: str, text_mode: bool): + p = self.secure_resolve(filename) + if not p: + return download.send_error("Forbidden") + + if self.is_utf8 and self.newline == "\n": + # optimization + text_mode = False + + # open + mode = "w" if text_mode else "wb" + if text_mode: + mode = "wt" + encoding = self.encoding + newline = self.newline + else: + mode = "wb" + encoding = None + newline = None + try: + opener = lookup_compression_algorithm(filename) if self.compression else open + except ModuleNotFoundError as e: + return download.send_error(str(e)) + try: + f = opener(p, mode=mode, encoding=encoding, newline=newline) + except IOError as e: + return download.send_error(str(e)) + + with f: + if text_mode: + tr = download.text_reader() + copyfileobj(tr, f) + else: + br = download.binary_reader() + copyfileobj(br, f) + + +def lookup_compression_algorithm(filename: str): + lowercase = str(filename).lower() + if lowercase.endswith('.gz'): + mod = 'gzip' + elif lowercase.endswith('.bz2'): + mod = 'bz2' + elif lowercase.endswith('.xz'): + mod = 'lzma' + elif lowercase.endswith('.lz4'): + # not always available + mod = 'lz4.frame' + else: + return open + opener = import_module(mod).open + return opener \ No newline at end of file diff --git a/pymonetdb/filetransfer/downloads.py b/pymonetdb/filetransfer/downloads.py new file mode 100644 index 00000000..1b7e6deb --- /dev/null +++ b/pymonetdb/filetransfer/downloads.py @@ -0,0 +1,132 @@ +""" +Classes related to file transfer requests as used by COPY INTO ON CLIENT. +""" +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. +# +# Copyright 1997 - July 2008 CWI, August 2008 - 2016 MonetDB B.V. + + +from abc import ABC, abstractmethod +from io import BufferedIOBase, TextIOWrapper +from pymonetdb import mapi as mapi_protocol +from pymonetdb.exceptions import ProgrammingError + + +class Download: + """ + Represents a request from the server to download data from the server. It is + passed to the Downloader registered by the application, which for example + might write the data to a file on the client system. See + pymonetdb.Connection.set_downloader(). + + Use the method send_error() to refuse the download, binary_reader() to get a + binary file object to read bytes from, or text_reader() to get a text-mode + file object to read text from. + + Implementations should be EXTREMELY CAREFUL to validate the file name before + opening and writing to any files on the client system! + """ + + def __init__(self, mapi: "mapi_protocol.Connection"): + self.mapi = mapi + self.started = False + self.buffer = bytearray(8190) + self.pos = 0 + self.len = 0 + self.reader = None + self.treader = None + + def send_error(self, message: str) -> None: + """ + Tell the server the requested download is refused + """ + if self.started: + raise ProgrammingError("Cannot send error anymore") + if not self.mapi: + return + self.started = True + if not message.endswith("\n"): + message += "\n" + self.mapi._putblock(message) + self._shutdown() + + def binary_reader(self): + """Returns a binary file-like object to read the downloaded data from.""" + if not self.reader: + if not self.mapi: + raise ProgrammingError("download has already been closed") + self.started = True + self.mapi._putblock("\n") + self.reader = DownloadIO(self) + return self.reader + + def text_reader(self): + """Returns a text mode file-like object to read the downloaded data from.""" + if not self.treader: + self.treader = TextIOWrapper(self.binary_reader(), encoding='utf-8', newline='\n') + return self.treader + + def close(self): + """End the download succesfully. Any unconsumed data will be discarded.""" + while self.mapi: + self._fetch() + assert not self.mapi + + def _available(self) -> int: + return self.len - self.pos + + def _consume(self, n: int) -> memoryview: + end = min(self.pos + n, self.len) + ret = memoryview(self.buffer)[self.pos:end] + self.pos = end + return ret + + def _fetch(self): + if not self.mapi: + return + self.pos = 0 + self.len = 0 # safety in case of exceptions + self.len, last = self.mapi._get_minor_block(self.buffer, 0) + if last: + self._shutdown() + + def _shutdown(self): + self.started = True + self.mapi = None + + +class DownloadIO(BufferedIOBase): + + def __init__(self, download: Download): + self.download = download + + def readable(self): + return True + + def read(self, n=0): + if self.download._available() == 0: + self.download._fetch() + return bytes(self.download._consume(n)) + + def read1(self, n=0): + return self.read(n) + + +class Downloader(ABC): + """ + Base class for download hooks. Instances of subclasses of this class can be + registered using pymonetdb.Connection.set_downloader(). Every time a + download request arrives, a Download object is created and passed to this + objects .handle_download() method. + + SECURITY NOTE! Make sure to carefully validate the file name before opening + files on the file system. Otherwise, if an adversary has taken control of + the network connection or of the server, they can use download requests to + OVERWRITE ARBITRARY FILES on your computer + """ + + @abstractmethod + def handle_download(self, download: Download, filename: str, text_mode: bool): + pass diff --git a/pymonetdb/filetransfer/uploads.py b/pymonetdb/filetransfer/uploads.py new file mode 100644 index 00000000..9b581f14 --- /dev/null +++ b/pymonetdb/filetransfer/uploads.py @@ -0,0 +1,245 @@ +""" +Classes related to file transfer requests as used by COPY INTO ON CLIENT. +""" +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. +# +# Copyright 1997 - July 2008 CWI, August 2008 - 2016 MonetDB B.V. + + +from io import BufferedIOBase, BufferedWriter, RawIOBase, TextIOBase, TextIOWrapper +from typing import Any, Optional, Union +from pymonetdb import mapi as mapi_protocol +from pymonetdb.exceptions import ProgrammingError + + +class Upload: + """ + Represents a request from the server to upload data to the server. It is + passed to the Uploader registered by the application, which for example + might retrieve the data from a file on the client system. See + pymonetdb.Connection.set_uploader(). + + Use the method send_error() to refuse the upload, binary_writer() to get a + binary file object to write to, or text_writer() to get a text-mode file + object to write to. + + Implementations should be VERY CAREFUL to validate the file name before + opening any files on the client system! + """ + + mapi: Optional["mapi_protocol.Connection"] + error = False + cancelled = False + bytes_sent = 0 + chunk_size = 1024 * 1024 + chunk_used = 0 + rawio: Optional["UploadIO"] = None + writer: Optional[BufferedWriter] = None + twriter: Optional[TextIOBase] = None + + def __init__(self, mapi: "mapi_protocol.Connection"): + self.mapi = mapi + + def _check_usable(self): + if self.error: + raise ProgrammingError("Upload handle has had an error, cannot be used anymore") + if not self.mapi: + raise ProgrammingError("Upload handle has been closed, cannot be used anymore") + + def is_cancelled(self) -> bool: + """Returns true if the server has cancelled the upload.""" + return self.cancelled + + def has_been_used(self) -> bool: + """Returns true if .send_error(), .text_writer() or .binary_writer() have been called.""" + return self.error or (self.rawio is not None) + + def set_chunk_size(self, size: int): + """ + After every CHUNK_SIZE bytes, the server gets the opportunity to cancel + the rest of the upload. Defaults to 1 MiB. + """ + self.chunk_size = size + + def send_error(self, message: str) -> None: + """ + Tell the server the requested upload has been refused + """ + if self.cancelled: + return + self._check_usable() + if self.bytes_sent: + raise ProgrammingError("Cannot send error after data has been sent") + if not message.endswith("\n"): + message += "\n" + self.error = True + assert self.mapi + self.mapi._putblock(message) + self.mapi = None + + def _raw(self) -> "UploadIO": + if self.bytes_sent == 0: + # send the magic newline indicating we're ok with the upload + self._send(b'\n', False) + if not self.rawio: + self.rawio = UploadIO(self) + return self.rawio + + def binary_writer(self) -> BufferedIOBase: + """ + Returns a binary file-like object. All data written to it is uploaded + to the server. + """ + if not self.writer: + self.writer = BufferedWriter(self._raw()) + return self.writer + + def text_writer(self) -> TextIOBase: + r""" + Returns a text-mode file-like object. All text written to it is uploaded + to the server. DOS/Windows style line endings (CR LF, \\r \\n) are + automatically rewritten to single \\n's. + """ + if not self.twriter: + # Without the Any annotation there is no way I can convince the + # type checker that TextIOWrapper can accept a NormalizeCrLf + # object. Apparently being a subclass of BufferedIOBase is not enough. + w: Any = NormalizeCrLf(self._raw()) + self.twriter = TextIOWrapper(w, encoding='utf-8', newline='\n') + return self.twriter + + def _send_data(self, data: Union[bytes, memoryview]): + if self.cancelled: + return + self._check_usable() + assert self.mapi is not None + pos = 0 + end = len(data) + while pos < end: + n = min(end - pos, self.chunk_size - self.chunk_used) + chunk = memoryview(data)[pos:pos + n] + if n == self.chunk_size - self.chunk_used and self.chunk_size > 0: + server_wants_more = self._send_and_get_prompt(chunk) + if not server_wants_more: + self.cancelled = True + self.mapi.uploader.cancel() + self.mapi = None + break + else: + self._send(chunk, False) + pos += n + + def _send(self, data: Union[bytes, memoryview], finish: bool): + assert self.mapi + self.mapi._putblock_raw(data, finish) + self.bytes_sent += len(data) + self.chunk_used += len(data) + + def _send_and_get_prompt(self, data: Union[bytes, memoryview]) -> bool: + assert self.mapi + self._send(data, True) + prompt = self.mapi._getblock() + if prompt == mapi_protocol.MSG_MORE: + self.chunk_used = 0 + return True + elif prompt == mapi_protocol.MSG_FILETRANS: + # server says stop + return False + else: + raise ProgrammingError(f"Unexpected server response: {prompt[:50]!r}") + + def close(self): + """ + End the upload succesfully + """ + if self.error: + return + if self.twriter: + self.twriter.close() + if self.writer: + self.writer.close() + if self.mapi: + server_wants_more = False + if self.chunk_used != 0: + # finish the current block + server_wants_more = self._send_and_get_prompt(b'') + if server_wants_more: + # send empty block to indicate end of upload + self.mapi._putblock('') + # receive acknowledgement + resp = self.mapi._getblock() + if resp != mapi_protocol.MSG_FILETRANS: + raise ProgrammingError(f"Unexpected server response: {resp[:50]!r}") + self.mapi = None + + +class UploadIO(RawIOBase): + """IO adaptor for Upload. """ + + def __init__(self, upload: Upload): + self.upload = upload + + def writable(self): + return True + + def write(self, b): + n = len(b) + if self.upload.is_cancelled(): + return n + self.upload._send_data(b) + return n + + +class NormalizeCrLf(BufferedIOBase): + """ + Helper class used to normalize line endings before sending text to MonetDB. + + Existing normalization code mostly deals with normalizing after reading, + this one normalizes before writing. + """ + + def __init__(self, inner): + self.inner = inner + self.pending = False + + def writable(self): + return True + + def write(self, data) -> int: + if not data: + return 0 + + if self.pending: + if data.startswith(b"\n"): + # normalize by forgetting the pending \r + pass + else: + # pending \r not followed by \n, write it + self.inner.write(b"\r") + # do not take the above write into account in the return value, + # it was included last time + + normalized = data.replace(b"\r\n", b"\n") + + if normalized[-1] == 13: # \r + # not sure if it will be followed by \n, move it to pending + self.pending = True + normalized = memoryview(normalized)[:-1] + else: + self.pending = False + + n = self.inner.write(normalized) + assert n == len(normalized) + return len(data) + + def flush(self): + return self.inner.flush() + + def close(self): + if self.pending: + self.inner.write(b"\r") + self.pending = False + return self.inner.close() + diff --git a/pymonetdb/mapi.py b/pymonetdb/mapi.py index 6705ba15..6ab5a04d 100644 --- a/pymonetdb/mapi.py +++ b/pymonetdb/mapi.py @@ -13,12 +13,12 @@ import struct import hashlib import os -from typing import Optional -from io import BytesIO +from typing import Optional, Tuple from urllib.parse import urlparse, parse_qs from pymonetdb.exceptions import OperationalError, DatabaseError, \ ProgrammingError, NotSupportedError, IntegrityError +import pymonetdb.filetransfer logger = logging.getLogger(__name__) @@ -26,6 +26,7 @@ MSG_PROMPT = "" MSG_MORE = "\1\2\n" +MSG_FILETRANS = "\1\3\n" MSG_INFO = "#" MSG_ERROR = "!" MSG_Q = "&" @@ -41,6 +42,8 @@ MSG_REDIRECT = "^" MSG_OK = "=OK" +MSG_FILETRANS_B = bytes(MSG_FILETRANS, 'utf-8') + STATE_INIT = 0 STATE_READY = 1 @@ -94,6 +97,9 @@ def __init__(self): self.language = "" self.handshake_options = None self.connect_timeout = socket.getdefaulttimeout() + self.uploader = None + self.downloader = None + self.stashed_buffer = None def connect(self, database, username, password, language, hostname=None, port=None, unix_socket=None, connect_timeout=-1, handshake_options=None): @@ -252,15 +258,32 @@ def disconnect(self): self.state = STATE_INIT self.socket.close() + def _sabotage(self): + """ Kill the connection in a way that the server is sure to recognize as an error""" + sock = self.socket + self.socket = None + self.state = STATE_INIT + if not sock: + return + bad_header = struct.pack('= 7: + response += "FILETRANS:" options_level = 0 for part in challenges[6].split(","): if part.startswith("sql="): @@ -352,63 +376,103 @@ def _challenge_response(self, challenge): return response - def _getblock(self): - """ read one mapi encoded block """ - if self.language == 'control' and not self.hostname: - return self._getblock_socket() # control doesn't do block splitting when using a socket - else: - return self._getblock_inet() - - def _getblock_inet(self): - result = BytesIO() - last = 0 - while not last: - flag = self._getbytes(2) - unpacked = struct.unpack('> 1 - last = unpacked & 1 - result.write(self._getbytes(length)) - return result.getvalue().decode() - - def _getblock_socket(self): - buffer = BytesIO() + def _getblock_and_transfer_files(self): + """ read one mapi encoded block and take care of any file transfers the server requests""" + buffer = self._get_buffer() + offset = 0 while True: - x = self.socket.recv(1) - if len(x): - buffer.write(x) + old_offset = offset + offset = self._getblock_raw(buffer, old_offset) + i = buffer.rfind(b'\n', old_offset, offset - 1) + if i >= old_offset + 2 and buffer[i - 2: i + 1] == MSG_FILETRANS_B: + # File transfer request. Chop the cmd off the buffer by lowering the offset + cmd = str(buffer[i + 1: offset - 1], 'utf-8') + offset = i - 2 + pymonetdb.filetransfer.handle_file_transfer(self, cmd) + continue else: break - return buffer.getvalue().strip().decode() - - def _getbytes(self, bytes_): - """Read an amount of bytes from the socket""" - result = BytesIO() - count = bytes_ - while count > 0: - recv = self.socket.recv(count) - if len(recv) == 0: + self._stash_buffer(buffer) + return str(memoryview(buffer)[:offset], 'utf-8') + + def _getblock(self) -> str: + """ read one mapi encoded block """ + buf = self._get_buffer() + end = self._getblock_raw(buf, 0) + ret = str(memoryview(buf)[:end], 'utf-8') + self._stash_buffer(buf) + return ret + + def _getblock_raw(self, buffer: bytearray, offset: int) -> int: + """ + Read one mapi block into 'buffer' starting at 'offset', enlarging the buffer + as necessary and returning offset plus the number of bytes read. + """ + last = False + while not last: + offset, last = self._get_minor_block(buffer, offset) + return offset + + def _get_minor_block(self, buffer: bytearray, offset: int) -> Tuple[int, bool]: + self._getbytes(buffer, offset, 2) + unpacked = buffer[offset] + 256 * buffer[offset + 1] + length = unpacked >> 1 + last = unpacked & 1 + if length: + offset = self._getbytes(buffer, offset, length) + return (offset, bool(last)) + + def _getbytes(self, buffer: bytearray, offset: int, count: int) -> int: + """ + Read 'count' bytes from the socket into 'buffer' starting at 'offset'. + Enlarge buffer if necessary. + Return offset + count if all goes well. + """ + assert self.socket + end = count + offset + if len(buffer) < end: + # enlarge + nblocks = 1 + (end - len(buffer)) // 8192 + buffer += bytes(nblocks * 8192) + while offset < end: + view = memoryview(buffer)[offset:end] + n = self.socket.recv_into(view) + if n == 0: raise BrokenPipeError("Server closed connection") - count -= len(recv) - result.write(recv) - return result.getvalue() + offset += n + return end + + def _get_buffer(self) -> bytearray: + """Retrieve a previously stashed buffer for reuse, or create a new one""" + if self.stashed_buffer: + buffer = self.stashed_buffer + self.stashed_buffer = None + else: + buffer = bytearray(8192) + return buffer + + def _stash_buffer(self, buffer): + """Stash a used buffer for future reuse""" + if self.stashed_buffer is None or len(self.stashed_buffer) < len(buffer): + self.stashed_buffer = buffer def _putblock(self, block): """ wrap the line in mapi format and put it into the socket """ - if self.language == 'control' and not self.hostname: - return self.socket.send(block.encode()) # control doesn't do block splitting when using a socket - else: - self._putblock_inet(block) + self._putblock_inet_raw(block.encode(), True) + + def _putblock_raw(self, block, finish: bool): + """ put the data into the socket """ + self._putblock_inet_raw(block, finish) - def _putblock_inet(self, block): + def _putblock_inet_raw(self, block, finish): pos = 0 last = 0 - block = block.encode() while not last: - data = block[pos:pos + MAX_PACKAGE_LENGTH] + data = memoryview(block)[pos:pos + MAX_PACKAGE_LENGTH] length = len(data) if length < MAX_PACKAGE_LENGTH: last = 1 - flag = struct.pack(' ") +staging_dir = sys.argv[1] +dest_dir = sys.argv[2] + +TREE = """ +bin/ +bin/bat.dll +bin/bat.pdb +bin/bz2.dll +bin/charset-1.dll +bin/geos.dll +bin/geos_c.dll +bin/getopt.dll +bin/iconv-2.dll +bin/libxml2.dll +bin/lz4.dll +bin/lzma.dll +bin/mapi.dll +bin/mapi.pdb +bin/mclient.exe +bin/mclient.pdb +bin/monetdb5.dll +bin/monetdb5.pdb +bin/monetdbe.dll +bin/monetdbsql.dll +bin/mserver5.exe +bin/mserver5.pdb +bin/msqldump.exe +bin/msqldump.pdb +bin/pcre.dll +bin/stream.dll +bin/stream.pdb +bin/zlib1.dll +etc/ +etc/.monetdb +include/ +include/monetdb/ +include/monetdb/copybinary.h +include/monetdb/exception_buffer.h +include/monetdb/gdk.h +include/monetdb/gdk_atoms.h +include/monetdb/gdk_bbp.h +include/monetdb/gdk_calc.h +include/monetdb/gdk_cand.h +include/monetdb/gdk_delta.h +include/monetdb/gdk_hash.h +include/monetdb/gdk_posix.h +include/monetdb/gdk_strimps.h +include/monetdb/gdk_system.h +include/monetdb/gdk_time.h +include/monetdb/gdk_tracer.h +include/monetdb/gdk_utils.h +include/monetdb/mal.h +include/monetdb/mal_authorize.h +include/monetdb/mal_client.h +include/monetdb/mal_errors.h +include/monetdb/mal_exception.h +include/monetdb/mal_function.h +include/monetdb/mal_import.h +include/monetdb/mal_instruction.h +include/monetdb/mal_linker.h +include/monetdb/mal_listing.h +include/monetdb/mal_module.h +include/monetdb/mal_namespace.h +include/monetdb/mal_prelude.h +include/monetdb/mal_resolve.h +include/monetdb/mal_stack.h +include/monetdb/mal_type.h +include/monetdb/mapi.h +include/monetdb/matomic.h +include/monetdb/mel.h +include/monetdb/monet_getopt.h +include/monetdb/monet_options.h +include/monetdb/monetdb_config.h +include/monetdb/monetdbe.h +include/monetdb/mstring.h +include/monetdb/opt_backend.h +include/monetdb/rel_basetable.h +include/monetdb/rel_distribute.h +include/monetdb/rel_dump.h +include/monetdb/rel_exp.h +include/monetdb/rel_optimizer.h +include/monetdb/rel_partition.h +include/monetdb/rel_prop.h +include/monetdb/rel_rel.h +include/monetdb/rel_semantic.h +include/monetdb/sql_atom.h +include/monetdb/sql_backend.h +include/monetdb/sql_catalog.h +include/monetdb/sql_hash.h +include/monetdb/sql_import.h +include/monetdb/sql_keyword.h +include/monetdb/sql_list.h +include/monetdb/sql_mem.h +include/monetdb/sql_mvc.h +include/monetdb/sql_parser.h +include/monetdb/sql_privileges.h +include/monetdb/sql_qc.h +include/monetdb/sql_query.h +include/monetdb/sql_relation.h +include/monetdb/sql_scan.h +include/monetdb/sql_semantic.h +include/monetdb/sql_stack.h +include/monetdb/sql_storage.h +include/monetdb/sql_string.h +include/monetdb/sql_symbol.h +include/monetdb/sql_tokens.h +include/monetdb/sql_types.h +include/monetdb/store_sequence.h +include/monetdb/stream.h +include/monetdb/stream_socket.h +lib/ +lib/bat.lib +lib/bz2.lib +lib/charset.lib +lib/getopt.lib +lib/iconv.lib +lib/libxml2.lib +lib/lz4.lib +lib/lzma.lib +lib/mapi.lib +lib/monetdb5.lib +lib/monetdb5/ +lib/monetdb5/_generator.dll +lib/monetdb5/_generator.pdb +lib/monetdb5/_geom.dll +lib/monetdb5/_geom.pdb +lib/monetdb5/_pyapi3.dll +lib/monetdb5/_pyapi3.pdb +lib/monetdbe.lib +lib/monetdbsql.lib +lib/pcre.lib +lib/stream.lib +lib/zlib.lib +license.rtf +M5server.bat +mclient.bat +msqldump.bat +MSQLserver.bat +pyapi_locatepython3.bat +share/ +share/doc/ +share/doc/MonetDB-SQL/ +share/doc/MonetDB-SQL/dump-restore.html +share/doc/MonetDB-SQL/dump-restore.txt +share/doc/MonetDB-SQL/website.html +System64/ +System64/concrt140.dll +System64/msvcp140.dll +System64/msvcp140_1.dll +System64/msvcp140_2.dll +System64/msvcp140_atomic_wait.dll +System64/msvcp140_codecvt_ids.dll +System64/vccorlib140.dll +System64/vcruntime140.dll +System64/vcruntime140_1.dll +""" + + +failures = 0 +for line in TREE.strip().splitlines(): + parts = line.split('/') + if not parts[-1]: + continue + src0 = parts[-1] + src0 = src0.replace('-', '_') + if src0.startswith('.'): + src0 = '_' + src0 + if '140' in src0 and '.dll' in src0: + src0 += '.DFEFC2FE_EEE6_424C_841B_D4E66F0C84A3' + src = os.path.join(staging_dir, src0) + tgt_dir = os.path.join(dest_dir, *parts[:-1]) + tgt = os.path.join(dest_dir, *parts) + if not os.path.isdir(tgt_dir): + print(f"Creating dir {tgt_dir}", flush=True) + os.makedirs(tgt_dir) + print(f"Copying [{src0}] {src} to {tgt}", flush=True) + try: + shutil.copyfile(src, tgt) + except Exception as e: + print(f" !! FAILED: {e}", flush=True) + failures += 1 + +if failures: + exit(f"Encountered {failures} failures") diff --git a/tests/requirements.txt b/tests/requirements.txt index f0342004..6f3b56b8 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -1,4 +1,5 @@ coveralls +future pycodestyle pytest Sphinx diff --git a/tests/test_filetransfer.py b/tests/test_filetransfer.py new file mode 100644 index 00000000..d182633e --- /dev/null +++ b/tests/test_filetransfer.py @@ -0,0 +1,813 @@ +""" +This is the python implementation of the mapi protocol. +""" +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. +# +# Copyright 1997 - July 2008 CWI, August 2008 - 2016 MonetDB B.V. + + +import codecs +from io import BufferedIOBase, StringIO +import os +from pathlib import Path +from shutil import copyfileobj, copyfile +import signal +import sys +from tempfile import mkdtemp +from threading import Condition, Thread +import time +from typing import Optional, Tuple +from unittest import TestCase, skipUnless + + +from pymonetdb import connect, Error as MonetError +from pymonetdb.exceptions import OperationalError, ProgrammingError +from pymonetdb import Download, Downloader, Upload, Uploader +from pymonetdb.filetransfer.directoryhandler import SafeDirectoryHandler, lookup_compression_algorithm +from pymonetdb.filetransfer.uploads import NormalizeCrLf +from tests.util import have_lz4, test_args, test_full + + +class MyException(Exception): + pass + + +class MyUploader(Uploader): + rows: int = 5_000 + error_at: Optional[int] = None + chunk_size: int = 10_000 + force_binary: bool = False + forget_to_return_after_error: bool = False + do_nothing_at_all: bool = False + ignore_cancel: bool = False + # + cancelled_at: Optional[int] = None + + def handle_upload(self, upload: Upload, filename: str, text_mode: bool, skip_amount: int): + if self.do_nothing_at_all: + return + elif filename.startswith("x"): + upload.send_error("filename must not start with 'x'") + if not self.forget_to_return_after_error: + return + + iter = range(skip_amount + 1, self.rows + 1) + upload.set_chunk_size(self.chunk_size) + if text_mode and not self.force_binary: + tw = upload.text_writer() + for i in iter: + if i == self.error_at: + raise MyException(f"Oops {i}") + if upload.is_cancelled() and not self.ignore_cancel: + self.cancelled_at = i + break + s = f"{i}\n" + tw.write(s) + else: + bw = upload.binary_writer() + for i in iter: + if i == self.error_at: + raise MyException(f"Oops {i}") + if upload.is_cancelled() and not self.ignore_cancel: + self.cancelled_at = i + break + s = f"{i}\n" + bw.write(bytes(s, 'ascii')) + + +class MyDownloader(Downloader): + lines: Optional[int] = None + error_at_line: Optional[int] = None + refuse: Optional[str] = None + forget_to_return_after_refusal: bool = False + buffer: StringIO + + def __init__(self): + self.buffer = StringIO() + + def handle_download(self, download: Download, filename: str, text_mode: bool): + if self.refuse: + download.send_error(self.refuse) + if not self.forget_to_return_after_refusal: + return + if self.lines is None: + copyfileobj(download.text_reader(), self.buffer) + else: + tr = download.text_reader() + i = 0 + while self.lines is None or self.lines < 0 or i < self.lines: + if i == self.error_at_line: + raise MyException("oopsie") + else: + line = tr.readline() + if not line: + break + self.buffer.write(line) + i += 1 + + def get(self): + return self.buffer.getvalue() + + +class DeadManHandle: + """ + Watchdog timer. Kills the process after a certain time. Great to detect + deadlocks, but can be inconvenient when running tests in the debugger. In + that case, temporarily call deadman.cancel() at the start of the test. + """ + def __init__(self): + self.cond = Condition() + self.deadline = None + self.message = None + self.thread = Thread(target=self.work, daemon=True) + self.thread.start() + + def set_timeout(self, t, msg): + self.set_deadline(time.time() + t, msg) + + def set_deadline(self, d, msg): + with self.cond: + self.deadline = d + if msg: + self.message = msg + self.cond.notify_all() + + def cancel(self): + with self.cond: + self.deadline = None + self.message = None + + def work(self): + with self.cond: + while True: + now = time.time() + if self.deadline: + delta = self.deadline - now + if delta <= 0: + print("\n\nTIMEOUT:", self.message, "\n\n", file=sys.stderr) + os.kill(os.getpid(), signal.SIGKILL) + else: + delta = None + self.cond.wait(timeout=delta) + + +deadman = DeadManHandle() + + +class Common: + first = True + tmpdir: Optional[Path] = None + + defaultencoding = None + + def file(self, filename): + """Resolve the given relative path within our temp directory.""" + if not self.tmpdir: + self.tmpdir = Path(mkdtemp(prefix="filetrans_")) + return self.tmpdir.joinpath(filename) + + def open(self, filename, mode, **kwargs): + """Open the given filename, resolved within our temp directory""" + fullname = self.file(filename) + return open(fullname, mode, **kwargs) + + def commonSetUp(self): + with self.open('checkencoding.txt', 'wt') as f: + self.defaultencoding = f.encoding + + self.conn = conn = connect(**test_args) + self.uploader = MyUploader() + conn.set_uploader(self.uploader) + self.downloader = MyDownloader() + conn.set_downloader(self.downloader) + + self.cursor = c = self.conn.cursor() + if self.first: + c.execute('DROP TABLE IF EXISTS foo') + c.execute('CREATE TABLE foo(i INT)') + c.execute('DROP TABLE IF EXISTS foo2') + c.execute('CREATE TABLE foo2(i INT, t VARCHAR(20))') + conn.commit() + self.first = False + + deadman.set_timeout(10, f"timeout in {self._testMethodName}()") + + def commonTearDown(self): + deadman.cancel() + try: + self.cursor.close() + self.conn.rollback() + self.conn.close() + except MonetError: + pass + + def fill_foo(self, nrows): + self.execute("INSERT INTO foo(i) SELECT * FROM sys.generate_series(1, %s + 1)", [nrows]) + + def execute(self, *args, **kwargs): + return self.cursor.execute(*args, **kwargs) + + def expect(self, expected_resultset): + actual_resultset = self.cursor.fetchall() + self.assertEqual(expected_resultset, actual_resultset) + + def expect1(self, value): + self.expect([(value,)]) + + def compression_prefix(self, scheme): + return {'gz': b'\x1F\x8B', 'bz2': b'\x42\x5A\x68', 'xz': b'\xFD\x37\x7A\x58\x5A\x00', 'lz4': b'\x04\x22\x4D\x18', None: None}[scheme] + + +class TestFileTransfer(TestCase, Common): + + def setUp(self): + super().setUp() + self.commonSetUp() + + def tearDown(self): + self.commonTearDown() + super().tearDown() + + def test_do_nothing_at_all(self): + self.uploader.do_nothing_at_all = True + with self.assertRaises(ProgrammingError): + # Handler must either refuse or create a writer. + # Not writing to the writer is not a problem, that's just an empty file + self.execute("COPY INTO foo FROM 'foo' ON CLIENT") + + def test_upload(self): + self.execute("COPY INTO foo FROM 'foo' ON CLIENT") + self.execute("SELECT COUNT(*) FROM foo") + self.expect1(self.uploader.rows) + + def test_upload_empty(self): + self.uploader.rows = 0 + self.execute("COPY INTO foo FROM 'foo' ON CLIENT") + self.execute("SELECT COUNT(*) FROM foo") + self.expect1(self.uploader.rows) + + # Also see test_NormalizeCrLf from the Java tests + def test_upload_crlf(self): + class CustomUploader(Uploader): + def handle_upload(self, upload: Upload, filename: str, text_mode: bool, skip_amount: int): + w = upload.text_writer() + w.write("1|A\r\n2|BB\r") + w.flush() + w.write("\n3|CCC\r\n") + + self.conn.set_uploader(CustomUploader()) + self.execute("COPY INTO foo2 FROM 'foo2' ON CLIENT") + self.execute("SELECT i, t FROM foo2") + self.expect([(1, "A"), (2, "BB"), (3, "CCC")]) + + def test_client_refuses_upload(self): + # our Uploader refuses filename that start with 'x' + with self.assertRaises(OperationalError): + self.execute("COPY INTO foo FROM 'xfoo' ON CLIENT") + + def test_upload_offset0(self): + # OFFSET 0 and OFFSET 1 behave the same, they do nothing + self.uploader.chunk_size = 100 + self.uploader.rows = 100 + self.execute("COPY OFFSET 0 INTO foo FROM 'foo' ON CLIENT") + self.execute("SELECT MIN(i) AS mi, MAX(i) AS ma FROM foo") + self.expect([(1, 100)]) + + def test_upload_offset1(self): + # OFFSET 0 and OFFSET 1 behave the same, they do nothing + self.uploader.chunk_size = 100 + self.uploader.rows = 100 + self.execute("COPY OFFSET 1 INTO foo FROM 'foo' ON CLIENT") + self.execute("SELECT MIN(i) AS mi, MAX(i) AS ma FROM foo") + self.expect([(1, 100)]) + + def test_upload_offset5(self): + self.uploader.chunk_size = 100 + self.uploader.rows = 100 + self.execute("COPY OFFSET 5 INTO foo FROM 'foo' ON CLIENT") + self.execute("SELECT MIN(i) AS mi, MAX(i) AS ma FROM foo") + self.expect([(5, 100)]) + + def test_server_cancels_upload(self): + # self.uploader.chunkSize = 100 + self.execute("COPY 10 RECORDS INTO foo FROM 'foo' ON CLIENT") + self.assertGreater(self.uploader.cancelled_at, 0) + self.execute("SELECT COUNT(*) FROM foo") + self.expect1(10) + + def test_download_refused(self): + self.downloader.refuse = 'no thanks' + with self.assertRaises(OperationalError): + self.execute("COPY (SELECT * FROM foo) INTO 'foo' ON CLIENT") + # connection still alive + self.conn.rollback() + self.execute("SELECT 42") + self.expect1(42) + + def test_download(self): + self.fill_foo(5) + self.execute("COPY (SELECT * FROM foo) INTO 'foo' ON CLIENT") + self.assertEqual("1\n2\n3\n4\n5\n", self.downloader.get()) + + def test_download_empty(self): + self.fill_foo(0) + self.execute("COPY (SELECT * FROM foo) INTO 'foo' ON CLIENT") + self.assertEqual("", self.downloader.get()) + + def test_download_lines(self): + self.downloader.lines = -1 + self.fill_foo(5) + self.execute("COPY (SELECT * FROM foo) INTO 'foo' ON CLIENT") + self.assertEqual("1\n2\n3\n4\n5\n", self.downloader.get()) + + def test_download_stop_reading_halfway(self): + self.fill_foo(10000) + self.downloader.lines = 10 + self.execute("COPY (SELECT * FROM foo) INTO 'foo' ON CLIENT") + self.assertEqual("1\n2\n3\n4\n5\n6\n7\n8\n9\n10\n", self.downloader.get()) + # connection still alive + self.execute("SELECT 42") + self.expect1(42) + + @skipUnless(test_full, "full test disabled") + def test_large_upload(self): + deadman.set_timeout(50, None) + n = 1_000_000 + self.uploader.rows = n + self.uploader.chunk_size = 1024 * 1024 + self.execute("COPY INTO foo FROM 'foo' ON CLIENT") + self.assertIsNone(self.uploader.cancelled_at) + self.execute("SELECT COUNT(*) FROM foo") + self.expect1(n) + + @skipUnless(test_full, "full test disabled") + def test_large_download(self): + deadman.set_timeout(50, None) + n = 1_000_000 + self.fill_foo(n) + self.execute("COPY (SELECT * FROM foo) INTO 'banana' ON CLIENT") + content = self.downloader.get() + nlines = len(content.splitlines()) + self.assertEqual(n, nlines) + + def test_upload_native_text_file(self): + self.upload_file('native.csv', {}, True) + + def test_upload_unix_text_file(self): + self.upload_file('unix.csv', dict(newline="\n"), True) + + def test_upload_dos_text_file(self): + self.upload_file('dos.csv', dict(newline="\r\n"), True) + + def test_upload_unix_binary_file(self): + self.upload_file('unix.csv', dict(newline="\n"), False) + + def upload_file(self, filename, write_opts, read_text): + encoding = self.defaultencoding if read_text else 'utf-8' + interesting_text = "Únïçøðε¡!÷" + encodable_text = "" + for c in interesting_text: + try: + # Not all of the following characters may be available in the system encoding + bytes(c, self.defaultencoding) + except UnicodeEncodeError: + continue + encodable_text += c + assert len(encodable_text) > 0 + n = 1000 + f = self.open(filename, 'wt', encoding=encoding, **write_opts) + assert f.encoding + for i in range(n): + print(f"{i}|{encodable_text}{i}", file=f) + f.close() + testcase = self + + class CustomUploader(Uploader): + def handle_upload(self, upload: Upload, filename: str, text_mode: bool, skip_amount: int): + if read_text: + f = testcase.open(filename, 'r') + tw = upload.text_writer() + copyfileobj(f, tw) + else: + f = testcase.open(filename, 'rb') + bw = upload.binary_writer() + copyfileobj(f, bw) + f.close() + + self.conn.set_uploader(CustomUploader()) + self.execute(f"COPY INTO foo2 FROM '{filename}' ON CLIENT") + self.execute("SELECT t FROM foo2 where i = 999") + self.expect1(f"{encodable_text}999") + + def test_fail_upload_late(self): + self.uploader.error_at = 99 + # If the handler raises an exception, .. + with self.assertRaises(MyException): + self.execute("COPY INTO foo FROM 'foo' ON CLIENT") + # .. the connection is dropped + with self.assertRaisesRegex(ProgrammingError, "ot connected"): + self.execute("SELECT COUNT(*) FROM foo") + + def test_download_immediate_exception(self): + + class CustomDownloader(Downloader): + def handle_download(self, download: Download, filename: str, text_mode: bool): + raise MyException("fail early") + + self.fill_foo(5000) + self.conn.set_downloader(CustomDownloader()) + # If the handler raises an exception, .. + with self.assertRaises(MyException): + self.execute("COPY (SELECT * FROM foo) INTO 'foo' ON CLIENT") + # .. the connection is dropped + with self.assertRaisesRegex(ProgrammingError, "ot connected"): + self.execute("SELECT COUNT(*) FROM foo") + + def test_fail_download_late(self): + self.fill_foo(5000) + self.downloader.lines = 6000 + self.downloader.error_at_line = 4000 + # If the handler raises an exception, .. + with self.assertRaises(MyException): + self.execute("COPY (SELECT * FROM foo) INTO 'foo' ON CLIENT") + # .. the connection is dropped + with self.assertRaisesRegex(ProgrammingError, "ot connected"): + self.execute("SELECT COUNT(*) FROM foo") + + +class TestSafeDirectoryHandler(TestCase, Common): + + def setUp(self): + super().setUp() + self.commonSetUp() + + def tearDown(self): + self.commonTearDown() + super().tearDown() + + def test_upload_handler_security(self): + f = self.open("foo.csv", "w") + f.write("1\n2\n3\n") + f.close() + outside = self.file('') + inside = self.file('inside') + inside.mkdir() + f = self.open(inside.joinpath("foo.csv"), "w") + f.write("10\n20\n30\n") + f.close() + # + handler = SafeDirectoryHandler(inside) + self.conn.set_uploader(handler) + # + testcases = [ + ('foo.csv', True), + ('./foo.csv', True), + (inside.joinpath('foo.csv'), True), + ('../foo.csv', False), + (outside.joinpath('foo.csv'), False), + ] + for path, valid in testcases: + with self.subTest(dir=str(inside), path=str(path), expect_valid=valid): + self.conn.rollback() + path = str(path) + if valid: + self.execute("COPY INTO foo FROM %s ON CLIENT", [path]) + else: + with self.assertRaises(OperationalError): + self.execute("COPY INTO foo FROM %s ON CLIENT", [path]) + continue + self.execute("SELECT MAX(i) FROM foo") + self.expect1(30) + + def test_download_handler_security(self): + self.execute("INSERT INTO foo SELECT * FROM sys.generate_series(0, 10)") + outside = self.file('') + inside = self.file('inside') + inside.mkdir() + # + handler = SafeDirectoryHandler(inside) + self.conn.set_downloader(handler) + # + testcases = [ + ('foo.csv', True), + ('./foo.csv', True), + (inside.joinpath('foo.csv'), True), + ('../foo.csv', False), + (outside.joinpath('foo.csv'), False), + ] + for path, valid in testcases: + with self.subTest(dir=str(inside), path=str(path), expect_valid=valid): + self.conn.rollback() + path = str(path) + if valid: + self.execute("COPY (SELECT * FROM foo) INTO %s ON CLIENT", [path]) + else: + with self.assertRaises(OperationalError): + self.execute("COPY (SELECT * FROM foo) INTO %s ON CLIENT", [path]) + continue + + def get_testdata_name(self, enc_name: str, newline: str, lines: int = None, compression=None) -> str: + newline_name = {None: "none", "\n": "lf", "\r\n": "crlf"}[newline] + file_name = f"{enc_name}_{newline_name}" + if lines is not None: + file_name += f"_{lines}lines" + file_name += ".txt" + if compression: + file_name += "." + compression + return file_name + + def get_testdata(self, enc_name: str, newline: str, lines: int, compression: str = None) -> str: + encoding = codecs.lookup(enc_name) if enc_name else None + fname = self.get_testdata_name(enc_name, newline, lines, compression) + p = self.file(fname) + if not p.exists(): + enc = encoding.name if encoding else None + opener = lookup_compression_algorithm(p) + f = opener(p, mode="wt", encoding=enc, newline=newline) + for n in range(lines): + i, t = self.line(n) + print(f"{i}|{t}", file=f) + f.close() + assert p.exists() + + return fname + + def line(self, i: int) -> Tuple[int, str]: + k = i + 1 + if i % 7 == 0: + s = "" + else: + # ÷ is interesting because it appears in all of UTF-8, Latin1 and + # Shift-JIS, but with different encodings. + s = f"÷{k}" + return (k, s) + + def test_utf8_uploads(self): + self.perform_line_endings_and_offsets_upload_tests('utf-8') + + def test_latin1_uploads(self): + self.perform_line_endings_and_offsets_upload_tests('latin1') + + def test_shiftjis_uploads(self): + self.perform_line_endings_and_offsets_upload_tests('shift-jis') + + def test_native_uploads(self): + self.perform_line_endings_and_offsets_upload_tests(None) + + def perform_line_endings_and_offsets_upload_tests(self, enc): + endings = [ + ('\n', '\n'), + ('\r\n', '\r\n'), + ('\n', None), + ('\r\n', None), + ] + offsets = { + None, + 0, + 1, + 2, + 5, + 15, + } + for file_ending, handler_ending in endings: + for offset in offsets: + with self.subTest(encoding=enc, file_ending=file_ending, handler_ending=handler_ending, offset=offset): + self.perform_upload_test(enc, file_ending, handler_ending, offset) + + def perform_upload_test(self, encoding, file_ending, handler_ending, offset=None, end=10, compression=None): + if offset is None: + offset_clause = '' + skip = 0 + else: + offset_clause = f" OFFSET {offset}" + skip = offset - 1 if offset else 0 + uploader = SafeDirectoryHandler(self.file(''), encoding, handler_ending) + self.conn.set_uploader(uploader) + fname = self.get_testdata(encoding, file_ending, end, compression=compression) + # Double check the compression, are we testing what we want tot test? + compression_prefix = self.compression_prefix(compression) + if compression_prefix: + f = self.open(fname, 'rb') + content = f.read() + content_prefix = content[:len(compression_prefix)] + f.close() + self.assertEqual(compression_prefix, content_prefix) + # Double check the testdata encoding, are we testing what we want tot test? + # These are the various encodings of the '÷' character as used by the + # .line() method above. + encmarker = {'utf-8': b'\xC3\xB7', 'latin1': b'\xF7', 'shift-jis': b'\x81\x80', None: None}[encoding] + if encmarker: + full_name = self.file(fname) + opener = lookup_compression_algorithm(full_name) + f = opener(full_name, 'rb') + content = f.read() + f.close() + self.assertTrue(content == b'' or encmarker in content) + # Run the test + # self.conn.rollback() + self.execute("DELETE FROM foo2") + self.execute("COPY" + offset_clause + " INTO foo2 FROM %s ON CLIENT", [fname]) + self.execute("SELECT * FROM foo2") + rows = self.cursor.fetchall() + expected = [self.line(i) for i in range(skip, end)] + self.assertEqual(expected, rows) + + def test_upload_utf8_lf_uses_binary(self): + class CustomHandler(SafeDirectoryHandler): + used_mode = None + + def __init__(self, dir): + super().__init__(dir, 'utf-8', '\n') + + def handle_upload(self, upload: Upload, filename: str, text_mode: bool, skip_amount: int): + super().handle_upload(upload, filename, text_mode, skip_amount) + # peek into the internals of the upload to see what was used. + if upload.writer: + self.used_mode = 'binary' + if upload.twriter: + # overwrite + self.used_mode = 'text' + + fname = self.get_testdata('utf-8', '\n', 10) + uploader = CustomHandler(self.file('')) + self.conn.set_uploader(uploader) + self.execute("COPY INTO foo2 FROM %s ON CLIENT", fname) + self.assertEqual('binary', uploader.used_mode) + + def test_download_encodings_and_line_endings(self): + compressions = [ + None, + 'gz', + 'bz2', + 'xz', + ] + if have_lz4: + compressions.append('lz4') + encodings = [ + 'utf-8', + 'latin1', + 'shift-jis', + None # means native + ] + file_endings = [ + "\n", + "\r\n", + None, + ] + for compression in compressions: + for encoding in encodings: + for handler_ending in file_endings + [None]: + with self.subTest(compression=compression, encoding=encoding, handler_ending=handler_ending): + self.perform_download_test(compression, encoding, handler_ending) + + def perform_download_test(self, compression, encoding, handler_ending): + # We want to check that when asked to use the given encoding and line endings, + # this happens. + n = 10 + downloader = SafeDirectoryHandler(self.file(''), encoding, handler_ending) + self.conn.set_downloader(downloader) + self.execute("DELETE FROM foo2") + # + enc = encoding or self.defaultencoding + expected = b"" + for k in range(n): + i, s = self.line(k) + expected += bytes(str(i), enc) + expected += b'|"' + expected += bytes(s, enc) + expected += b'"' + expected += bytes(handler_ending or os.linesep, enc) + self.execute("INSERT INTO foo2(i, t) VALUES (%s, %s)", [i, s]) + # + fname = self.get_testdata_name(encoding, handler_ending, compression=compression) + self.execute("COPY (SELECT * FROM foo2) INTO %s ON CLIENT", [fname]) + # check compression + compression_prefix = self.compression_prefix(compression) + if compression_prefix: + f = self.open(fname, 'rb') + content = f.read() + content_prefix = content[:len(compression_prefix)] + f.close() + self.assertEqual(compression_prefix, content_prefix) + # check contents + full_name = self.file(fname) + opener = lookup_compression_algorithm(full_name) + f = opener(full_name, 'rb') + content = f.read() + f.close() + # + self.assertEqual(expected, content) + + def test_download_utf8_lf_uses_binary(self): + class CustomHandler(SafeDirectoryHandler): + used_mode = None + + def __init__(self, dir): + super().__init__(dir, 'utf-8', '\n') + + def handle_download(self, download: Download, filename: str, text_mode: bool): + super().handle_download(download, filename, text_mode) + # peek into the internals of the download + if download.reader: + self.used_mode = 'binary' + if download.treader: + # overwrite + self.used_mode = 'text' + + fname = self.get_testdata_name('utf-8', '\n') + downloader = CustomHandler(self.file('')) + self.conn.set_downloader(downloader) + self.execute("COPY SELECT * FROM sys.generate_series(0,10) INTO %s ON CLIENT", fname) + self.assertEqual('binary', downloader.used_mode) + + def test_upload_with_compression_disabled(self): + fname = self.get_testdata('utf-8', '\n', 3, compression=None) + # give it a misleading name + misleading_name = 'banana.txt.gz' + copyfile(self.file(fname), self.file(misleading_name)) + # now upload it + handler = SafeDirectoryHandler(self.file(''), compression=False) + self.conn.set_uploader(handler) + self.execute("COPY INTO foo2 FROM %s ON CLIENT", misleading_name) + self.conn.commit() + self.execute("SELECT MAX(i) FROM foo2") + self.expect1(3) + + def test_download_with_compression_disabled(self): + fname = 'misleading.txt.gz' + handler = SafeDirectoryHandler(self.file(''), encoding='utf-8', newline='\n', compression=False) + self.conn.set_downloader(handler) + self.execute("COPY SELECT value FROM sys.generate_series(1,4) INTO %s ON CLIENT", fname) + # should not be gzipped + f = self.open(fname, 'rb') + content = f.read() + f.close() + self.assertEqual(b'1\n2\n3\n', content) + + +class TestNormalizeCrLf(TestCase): + + class Sink(BufferedIOBase): + def __init__(self): + self.written = b'' + + def writable(self) -> bool: + return True + + def write(self, buf): + self.written += bytes(buf) + return len(buf) + + def get_written(self): + res = self.written + self.written = b'' + return res + + def setUp(self): + self.sink = self.Sink() + self.normalizer = NormalizeCrLf(self.sink) + + def transaction(self, buf, expect_pending, expect_written): + n = self.normalizer.write(buf) + self.assertEqual(len(buf), n) + written = self.sink.get_written() + pending = self.normalizer.pending + self.assertEqual(expect_written, written) + self.assertEqual(expect_pending, pending) + + def test_normalizer(self): + self.assertEqual(False, self.normalizer.pending) + self.assertEqual(b'', self.sink.written) + + # can all be written through + self.transaction(b"\r\naaa\n\n\r\n", False, b"\naaa\n\n\n") + + # trailing CR pending + self.transaction(b"\n\r\naaa\r", True, b"\n\naaa") + + # LF consumes the pending CR + self.transaction(b"\n", False, b"\n") + + # a new pending CR + self.transaction(b"\r", True, b"") + + # a new pending CR + self.transaction(b"a", False, b"\ra") + + # CR after CR emits one CR and stays pending + self.transaction(b"\r", True, b"") + self.transaction(b"\r", True, b"\r") + + # empty write stays pending + self.transaction(b"", True, b"") + + # flushing the normalizer does not flush the pending CR + self.normalizer.flush() + self.assertTrue(self.normalizer.pending) + + # but closing it does + self.normalizer.close() + self.assertFalse(self.normalizer.pending) + self.assertEqual(b"\r", self.sink.get_written()) diff --git a/tests/test_multiline.py b/tests/test_multiline.py index 02504cd8..67e6d6b9 100644 --- a/tests/test_multiline.py +++ b/tests/test_multiline.py @@ -27,19 +27,24 @@ class MultilineResponseTest(unittest.TestCase): """ @patch('pymonetdb.mapi.Connection._putblock') - @patch('pymonetdb.mapi.Connection._getblock') - def test_failed_transactions(self, mock_getblock, _): - """This test is mocking 2 low level methods in the mapi.Connection class: - mapi.Connection._getblock + @patch('pymonetdb.mapi.Connection._getblock_raw') + def test_failed_transactions(self, mock_getblock_raw, _): + """This test mocks two low level methods in the mapi.Connection class: + mapi.Connection._getblock_raw mapi.Connection._putblock and tests mapi.Connection.cmd. Specifically we test for the event that a transaction has failed due to concurrency conflicts. """ query_text = 'sINSERT INTO tbl VALUES (1)' - response = "&2 1 -1\n!40000!COMMIT: transaction is aborted " \ - "because of concurrency conflicts, will ROLLBACK instead\n" - mock_getblock.return_value = response + response = b"&2 1 -1\n!40000!COMMIT: transaction is aborted " \ + b"because of concurrency conflicts, will ROLLBACK instead\n" + + def mocked_getblock_raw(buf, off): + buf[off:] = response + return off + len(response) + + mock_getblock_raw.side_effect = mocked_getblock_raw c = pymonetdb.mapi.Connection() # Simulate a connection diff --git a/tests/util.py b/tests/util.py index 72205bdc..7c01e0df 100644 --- a/tests/util.py +++ b/tests/util.py @@ -7,6 +7,7 @@ # # Copyright 1997 - July 2008 CWI, August 2008 - 2016 MonetDB B.V. +from importlib import import_module from os import environ test_port = int(environ.get('MAPIPORT', 50000)) @@ -24,3 +25,11 @@ 'username': test_username, 'password': test_password, } + + +try: + import_module('lz4.frame') + have_lz4 = True +except ModuleNotFoundError: + have_lz4 = False + diff --git a/tests/windows_tests.py b/tests/windows_tests.py new file mode 100644 index 00000000..ff5cfe32 --- /dev/null +++ b/tests/windows_tests.py @@ -0,0 +1,89 @@ +#!/usr/bin/env python3 + +import os +import platform +import subprocess +import sys +import time + +import pytest + + +def start_mserver(monetdbdir, farmdir, dbname, port, logfile): + exe = os.path.join(monetdbdir, 'bin', 'mserver5') + if platform.system() == 'Windows': + exe += '.exe' + dbpath = os.path.join(farmdir, dbname) + try: + os.mkdir(dbpath) + except FileExistsError: + pass + # + env = dict((k, v) for k, v in os.environ.items()) + path_components = [ + os.path.join(monetdbdir, "bin"), + os.path.join(monetdbdir, "lib", "monetdb5"), + env['PATH'], + ] + env['PATH'] = os.pathsep.join(path_components) + sets = dict( + prefix=monetdbdir, + exec_prefix=monetdbdir, + mapi_port=port, + ) + cmdline = [ + exe, + f'--dbpath={dbpath}', + ] + for k, v in sets.items(): + cmdline.append('--set') + cmdline.append(f'{k}={v}') + print() + print('-- Starting mserver') + print(f'-- PATH={env["PATH"]}') + print(f'-- cmdline: {cmdline!r}') + t0 = time.time() + awkward_silence = t0 + 2 + proc = subprocess.Popen(cmdline, env=env, stderr=open(logfile, 'wb')) + # + while True: + try: + code = proc.wait(timeout=0.1) + exit(f'mserver unexpectedly exited with code {code}') + except subprocess.TimeoutExpired: + if os.path.exists(os.path.join(dbpath, '.started')): + break + t = time.time() + if t >= awkward_silence: + print(f"-- Waited for {t - t0:.1f}s") + awkward_silence = t + 1 + if t > t0 + 30.1: + print("Starting mserver took too long, giving up") + proc.kill() + exit("given up") + print('-- mserver has started') + return proc + + +if len(sys.argv) != 5: + exit(f"Usage: {sys.argv[0]} MONETDIR FARMDIR DBNAME PORT") +monet_dir = sys.argv[1] +farm_dir = sys.argv[2] +db_name = sys.argv[3] +db_port = int(sys.argv[4]) + +proc = start_mserver(monet_dir, farm_dir, db_name, db_port, os.path.join(farm_dir, "errlog")) +try: + print('The reported default encoding is', sys.getdefaultencoding()) + with open(os.path.join(farm_dir, 'w.txt'), 'w') as f: + print("The encoding for 'w' files is", f.encoding) + with open(os.path.join(farm_dir, 'w.txt'), 'wt') as f: + print("The encoding for 'wt files is", f.encoding) + ret = pytest.main(args=['-k', 'not test_control']) + exit(ret) +finally: + if proc.returncode is None: + print('-- Killing the server') + proc.kill() + else: + print('-- Server has already terminated')