diff --git a/src/masoniteorm/connections/MySQLConnection.py b/src/masoniteorm/connections/MySQLConnection.py index 5d0c2175..355f2da8 100644 --- a/src/masoniteorm/connections/MySQLConnection.py +++ b/src/masoniteorm/connections/MySQLConnection.py @@ -31,10 +31,16 @@ def __init__( if str(port).isdigit(): self.port = int(self.port) self.database = database + self.user = user self.password = password self.prefix = prefix self.full_details = full_details or {} + self.connection_pool_size = ( + full_details.get( + "connection_pooling_max_size", 100 + ) + ) self.options = options or {} self._cursor = None self.open = 0 @@ -48,42 +54,80 @@ def make_connection(self): if self._dry: return + if self.has_global_connection(): + return self.get_global_connection() + + # Check if there is an available connection in the pool + self._connection = self.create_connection() + self.enable_disable_foreign_keys() + + return self + + def close_connection(self): + if ( + self.full_details.get("connection_pooling_enabled") + and len(CONNECTION_POOL) < self.connection_pool_size + ): + CONNECTION_POOL.append(self._connection) + self.open = 0 + self._connection = None + + def create_connection(self, autocommit=True): + try: import pymysql except ModuleNotFoundError: raise DriverNotFound( - "You must have the 'pymysql' package installed to make a connection to MySQL. Please install it using 'pip install pymysql'" + "You must have the 'pymysql' package " + "installed to make a connection to MySQL. " + "Please install it using 'pip install pymysql'" ) + import pendulum + import pymysql.converters - try: - import pendulum - import pymysql.converters - - pymysql.converters.conversions[ - pendulum.DateTime - ] = pymysql.converters.escape_datetime - except ImportError: - pass - - if self.has_global_connection(): - return self.get_global_connection() - - self._connection = pymysql.connect( - cursorclass=pymysql.cursors.DictCursor, - autocommit=True, - host=self.host, - user=self.user, - password=self.password, - port=self.port, - db=self.database, - **self.options + pymysql.converters.conversions[pendulum.DateTime] = ( + pymysql.converters.escape_datetime ) - self.enable_disable_foreign_keys() + # Initialize the connection pool if the option is set + initialize_size = self.full_details.get("connection_pooling_min_size") + if initialize_size and len(CONNECTION_POOL) < initialize_size: + for _ in range(initialize_size - len(CONNECTION_POOL)): + connection = pymysql.connect( + cursorclass=pymysql.cursors.DictCursor, + autocommit=autocommit, + host=self.host, + user=self.user, + password=self.password, + port=self.port, + database=self.database, + **self.options + ) + CONNECTION_POOL.append(connection) + + if ( + self.full_details.get("connection_pooling_enabled") + and CONNECTION_POOL + and len(CONNECTION_POOL) > 0 + ): + connection = CONNECTION_POOL.pop() + else: + connection = pymysql.connect( + cursorclass=pymysql.cursors.DictCursor, + autocommit=autocommit, + host=self.host, + user=self.user, + password=self.password, + port=self.port, + database=self.database, + **self.options + ) + + connection.close = self.close_connection self.open = 1 - return self + return connection def reconnect(self): self._connection.connect() @@ -139,15 +183,19 @@ def get_cursor(self): return self._cursor def query(self, query, bindings=(), results="*"): - """Make the actual query that will reach the database and come back with a result. + """Make the actual query that + will reach the database and come back with a result. Arguments: - query {string} -- A string query. This could be a qmarked string or a regular query. + query {string} -- A string query. + This could be a qmarked string or a regular query. bindings {tuple} -- A tuple of bindings Keyword Arguments: - results {str|1} -- If the results is equal to an asterisks it will call 'fetchAll' - else it will return 'fetchOne' and return a single record. (default: {"*"}) + results {str|1} -- If the results is equal to an + asterisks it will call 'fetchAll' + else it will return 'fetchOne' and + return a single record. (default: {"*"}) Returns: dict|None -- Returns a dictionary of results or None @@ -156,7 +204,10 @@ def query(self, query, bindings=(), results="*"): if self._dry: return {} - if not self._connection.open: + if not self.open: + if self._connection is None: + self._connection = self.create_connection() + self._connection.connect() self._cursor = self._connection.cursor() diff --git a/src/masoniteorm/connections/PostgresConnection.py b/src/masoniteorm/connections/PostgresConnection.py index 19919d95..0bbfe172 100644 --- a/src/masoniteorm/connections/PostgresConnection.py +++ b/src/masoniteorm/connections/PostgresConnection.py @@ -34,8 +34,10 @@ def __init__( self.database = database self.user = user self.password = password + self.prefix = prefix self.full_details = full_details or {} + self.connection_pool_size = full_details.get("connection_pooling_max_size", 100) self.options = options or {} self._cursor = None self.transaction_level = 0 @@ -56,16 +58,7 @@ def make_connection(self): if self.has_global_connection(): return self.get_global_connection() - schema = self.schema or self.full_details.get("schema") - - self._connection = psycopg2.connect( - database=self.database, - user=self.user, - password=self.password, - host=self.host, - port=self.port, - options=f"-c search_path={schema}" if schema else "", - ) + self._connection = self.create_connection() self._connection.autocommit = True @@ -75,6 +68,53 @@ def make_connection(self): return self + def create_connection(self): + import psycopg2 + + # Initialize the connection pool if the option is set + initialize_size = self.full_details.get("connection_pooling_min_size") + if ( + self.full_details.get("connection_pooling_enabled") + and initialize_size + and len(CONNECTION_POOL) < initialize_size + ): + for _ in range(initialize_size - len(CONNECTION_POOL)): + connection = psycopg2.connect( + database=self.database, + user=self.user, + password=self.password, + host=self.host, + port=self.port, + options=( + f"-c search_path={self.schema or self.full_details.get('schema')}" + if self.schema or self.full_details.get("schema") + else "" + ), + ) + CONNECTION_POOL.append(connection) + + if ( + self.full_details.get("connection_pooling_enabled") + and CONNECTION_POOL + and len(CONNECTION_POOL) > 0 + ): + connection = CONNECTION_POOL.pop() + else: + connection = psycopg2.connect( + database=self.database, + user=self.user, + password=self.password, + host=self.host, + port=self.port, + options=( + f"-c search_path={self.schema or self.full_details.get('schema')}" + if self.schema or self.full_details.get("schema") + else "" + ), + ) + + return connection + def get_database_name(self): return self.database @@ -93,6 +133,17 @@ def get_default_post_processor(cls): def reconnect(self): pass + def close_connection(self): + if ( + self.full_details.get("connection_pooling_enabled") + and len(CONNECTION_POOL) < self.connection_pool_size + ): + CONNECTION_POOL.append(self._connection) + else: + self._connection.close() + + self._connection = None + def commit(self): """Transaction""" if self.get_transaction_level() == 1: @@ -140,7 +191,7 @@ def query(self, query, bindings=(), results="*"): dict|None -- Returns a dictionary of results or None """ try: - if self._connection.closed: + if not self._connection or self._connection.closed: self.make_connection() self.set_cursor() @@ -164,4 +215,5 @@ def query(self, query, bindings=(), results="*"): finally: if self.get_transaction_level() <= 0: self.open = 0 - self._connection.close() + self.close_connection() + # self._connection.close() diff --git a/tests/integrations/config/database.py b/tests/integrations/config/database.py index fe4473fa..ed1fd02e 100644 --- a/tests/integrations/config/database.py +++ b/tests/integrations/config/database.py @@ -24,6 +24,7 @@ They can be named whatever you want. """ + DATABASES = { "default": "mysql", "mysql": { @@ -37,6 +38,9 @@ "options": {"charset": "utf8mb4"}, "log_queries": True, "propagate": False, + "connection_pooling_enabled": True, + "connection_pooling_max_size": 10, + "connection_pooling_min_size": None, }, "t": {"driver": "sqlite", "database": "orm.sqlite3", "log_queries": True, "foreign_keys": True}, "devprod": { @@ -69,6 +73,9 @@ "password": os.getenv("POSTGRES_DATABASE_PASSWORD"), "database": os.getenv("POSTGRES_DATABASE_DATABASE"), "port": os.getenv("POSTGRES_DATABASE_PORT"), + "connection_pooling_enabled": True, + "connection_pooling_max_size": 10, + "connection_pooling_min_size": 2, "prefix": "", "log_queries": True, "propagate": False, @@ -101,6 +108,8 @@ "authentication": "ActiveDirectoryPassword", "driver": "ODBC Driver 17 for SQL Server", "connection_timeout": 15, + "connection_pooling": False, + "connection_pooling_size": 100, }, }, }