From db59660522737542b7b7bbdfdf8cec726c59c5af Mon Sep 17 00:00:00 2001 From: James McKinney <26463+jpmckinney@users.noreply.github.com> Date: Fri, 12 Jul 2024 21:23:32 -0400 Subject: [PATCH] fix: Add glob, user dir and env var support on Windows #634 --- CHANGELOG.rst | 1 + csvkit/cli.py | 29 +++++++++++++++++++++++++++ tests/test_utilities/test_csvjoin.py | 10 +++++++++ tests/test_utilities/test_csvsql.py | 18 +++++++++++++++++ tests/test_utilities/test_csvstack.py | 14 +++++++++++++ 5 files changed, 72 insertions(+) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 64466fe1..fe283a1e 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -6,6 +6,7 @@ Unreleased - feat: Add a Docker image. - feat: Add man pages to the sdist and wheel distributions. - fix: :doc:`/scripts/csvstat` no longer errors when a column is a time delta and :code:`--json` is set. +- fix: When taking arguments from ``sys.argv`` on Windows, glob patterns, user directories, and environment variables are expanded. 2.0.0 - May 1, 2024 ------------------- diff --git a/csvkit/cli.py b/csvkit/cli.py index 7d3554c4..a972cdb8 100644 --- a/csvkit/cli.py +++ b/csvkit/cli.py @@ -7,8 +7,11 @@ import gzip import itertools import lzma +import os +import re import sys import warnings +from glob import glob from os.path import splitext import agate @@ -73,6 +76,11 @@ def __init__(self, args=None, output_file=None, error_file=None): """ Perform argument processing and other setup for a CSVKitUtility. """ + if args is None: + args = sys.argv[1:] + if os.name == 'nt': + args = _expand_args(args) + self._init_common_parser() self.add_arguments() self.args = self.argparser.parse_args(args) @@ -550,3 +558,24 @@ def parse_column_identifiers(ids, column_names, column_offset=1, excluded_column excludes.append(match_column_identifier(column_names, x, column_offset)) return [c for c in columns if c not in excludes] + + +# Adapted from https://github.com/pallets/click/blame/main/src/click/utils.py +def _expand_args(args): + out = [] + + for arg in args: + arg = os.path.expanduser(arg) + arg = os.path.expandvars(arg) + + try: + matches = glob(arg, recursive=True) + except re.error: + matches = [] + + if matches: + out.extend(matches) + else: + out.append(arg) + + return out diff --git a/tests/test_utilities/test_csvjoin.py b/tests/test_utilities/test_csvjoin.py index 3faf6b9c..47a2fbe9 100644 --- a/tests/test_utilities/test_csvjoin.py +++ b/tests/test_utilities/test_csvjoin.py @@ -1,4 +1,6 @@ +import os import sys +import unittest from unittest.mock import patch from csvkit.utilities.csvjoin import CSVJoin, launch_new_instance @@ -34,6 +36,14 @@ def test_join_options(self): 'You must provide join column names when performing an outer join.', ) + @unittest.skipIf(os.name != 'nt', 'Windows only') + def test_glob(self): + self.assertRows(['--no-inference', '-c', 'a', 'examples/dummy?.csv'], [ + ['a', 'b', 'c', 'b2', 'c2'], + ['1', '2', '3', '2', '3'], + ['1', '2', '3', '4', '5'], + ]) + def test_sequential(self): output = self.get_output_as_io(['examples/join_a.csv', 'examples/join_b.csv']) self.assertEqual(len(output.readlines()), 4) diff --git a/tests/test_utilities/test_csvsql.py b/tests/test_utilities/test_csvsql.py index e7aab926..c6d13ca7 100644 --- a/tests/test_utilities/test_csvsql.py +++ b/tests/test_utilities/test_csvsql.py @@ -1,6 +1,7 @@ import io import os import sys +import unittest from textwrap import dedent from unittest.mock import patch @@ -63,6 +64,23 @@ def tearDown(self): if os.path.exists(self.db_file): os.remove(self.db_file) + @unittest.skipIf(os.name != 'nt', 'Windows only') + def test_glob(self): + sql = self.get_output(['examples/dummy?.csv']) + + self.assertEqual(sql.replace('\t', ' '), dedent('''\ + CREATE TABLE dummy2 ( + a BOOLEAN NOT NULL, + b DECIMAL NOT NULL, + c DECIMAL NOT NULL + ); + CREATE TABLE dummy3 ( + a BOOLEAN NOT NULL, + b DECIMAL NOT NULL, + c DECIMAL NOT NULL + ); + ''')) # noqa: W291 + def test_create_table(self): sql = self.get_output(['--tables', 'foo', 'examples/testfixed_converted.csv']) diff --git a/tests/test_utilities/test_csvstack.py b/tests/test_utilities/test_csvstack.py index cb1336a8..3e3898b9 100644 --- a/tests/test_utilities/test_csvstack.py +++ b/tests/test_utilities/test_csvstack.py @@ -1,4 +1,6 @@ +import os import sys +import unittest from unittest.mock import patch from csvkit.utilities.csvstack import CSVStack, launch_new_instance @@ -20,6 +22,18 @@ def test_options(self): 'The number of grouping values must be equal to the number of CSV files being stacked.', ) + @unittest.skipIf(os.name != 'nt', 'Windows only') + def test_glob(self): + self.assertRows(['examples/dummy*.csv'], [ + ['a', 'b', 'c', 'd'], + ['1', '2', '3', ''], + ['1', '2', '3', ''], + ['1', '2', '3', ''], + ['1', '4', '5', ''], + ['1', '2', '3', ''], + ['1', '2', '3', '4'], + ]) + def test_skip_lines(self): self.assertRows(['--skip-lines', '3', 'examples/test_skip_lines.csv', 'examples/test_skip_lines.csv'], [ ['a', 'b', 'c'],