diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index ec8951ea..480a85c6 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -31,6 +31,7 @@ jobs: sudo -u monetdb monetdbd set control=yes ${{ env.DBFARM }} sudo -u monetdb monetdbd set passphrase=testdb ${{ env.DBFARM }} sudo -u monetdb monetdbd start ${{ env.DBFARM }} + sudo -u monetdb chmod o+rwx /tmp/.s.mero* - name: Create MonetDB test database run: | sudo -u monetdb monetdb create demo diff --git a/pymonetdb/mapi.py b/pymonetdb/mapi.py index 7c35f0ba..e4902767 100644 --- a/pymonetdb/mapi.py +++ b/pymonetdb/mapi.py @@ -382,6 +382,10 @@ def _challenge_response(self, challenge: str, password: str): # noqa: C901 def _getblock_and_transfer_files(self): """ read one mapi encoded block and take care of any file transfers the server requests""" + if self.language == 'control' and not self.hostname: + # control connections do not use the blocking protocol and do not transfer files + return self._recv_to_end() + buffer = self._get_buffer() offset = 0 @@ -405,6 +409,9 @@ def _getblock_and_transfer_files(self): def _getblock(self) -> str: """ read one mapi encoded block """ + if self.language == 'control' and not self.hostname: + # control connections do not use the blocking protocol + return self._recv_to_end() buf = self._get_buffer() end = self._getblock_raw(buf, 0) ret = str(memoryview(buf)[:end], 'utf-8') @@ -450,6 +457,19 @@ def _getbytes(self, buffer: bytearray, offset: int, count: int) -> int: offset += n return end + def _recv_to_end(self) -> str: + """ + Read bytes from the socket until the server closes the connection + """ + parts = [] + while True: + assert self.socket + received = self.socket.recv(4096) + if not received: + break + parts.append(received) + return str(b"".join(parts).strip(), 'utf-8') + def _get_buffer(self) -> bytearray: """Retrieve a previously stashed buffer for reuse, or create a new one""" if self.stashed_buffer: @@ -466,13 +486,15 @@ def _stash_buffer(self, buffer): def _putblock(self, block): """ wrap the line in mapi format and put it into the socket """ - self._putblock_inet_raw(block.encode(), True) + data = block.encode('utf-8') + if self.language == 'control' and not self.hostname: + # control does not use the blocking protocol + return self._send_all_and_shutdown(data) + else: + self._putblock_raw(block.encode(), True) - def _putblock_raw(self, block, finish: bool): + def _putblock_raw(self, block, finish): """ put the data into the socket """ - self._putblock_inet_raw(block, finish) - - def _putblock_inet_raw(self, block, finish): pos = 0 last = 0 while not last: @@ -485,6 +507,17 @@ def _putblock_inet_raw(self, block, finish): self.socket.send(data) pos += length + def _send_all_and_shutdown(self, block): + """ put the data into the socket """ + pos = 0 + end = len(block) + block = memoryview(block) + while pos < end: + data = block[pos:pos + 8192] + nsent = self.socket.send(data) + pos += nsent + self.socket.shutdown(socket.SHUT_WR) + def __del__(self): if self.socket: self.socket.close() diff --git a/tests/test_control.py b/tests/test_control.py index 812211c2..74a2569c 100644 --- a/tests/test_control.py +++ b/tests/test_control.py @@ -30,12 +30,12 @@ class TestControl(unittest.TestCase): Where /var/lib/monetdb is the path to your dbfarm. Don't forget to restart the db after setting the credentials. """ - def setUp(self): + def setUpControl(self): # use tcp - self.control = Control(test_hostname, test_port, test_passphrase) + return Control(hostname=test_hostname, port=test_port, passphrase=test_passphrase) - # use socket - # self.control = Control() + def setUp(self): + self.control = self.setUpControl() do_without_fail(lambda: self.control.stop(database_name)) do_without_fail(lambda: self.control.destroy(database_name)) @@ -139,3 +139,9 @@ def test_defaults(self): @unittest.skipUnless(test_full, "full test disabled") def test_neighbours(self): self.control.neighbours() + + +class TestLocalControl(TestControl): + def setUpControl(self): + # use unix domain socket + return Control(port=test_port, passphrase=test_passphrase)