Skip to content

Commit

Permalink
refactor(back): connection logic
Browse files Browse the repository at this point in the history
  • Loading branch information
kareemmahlees committed Dec 15, 2023
1 parent 7b9570e commit 697d3bc
Show file tree
Hide file tree
Showing 16 changed files with 183 additions and 116 deletions.
31 changes: 28 additions & 3 deletions src-tauri/src/connection.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
use crate::utils::{read_from_connections_file, write_into_connections_file, Drivers};
use sqlx::{AnyConnection, Connection};
use std::time::Duration;

use crate::{
utils::{read_from_connections_file, write_into_connections_file, Drivers},
DbInstance,
};
use sqlx::{any::AnyPoolOptions, AnyConnection, Connection};
use tauri::State;

#[tauri::command]
pub async fn test_connection(conn_string: String) -> Result<String, String> {
Expand All @@ -22,16 +28,35 @@ pub fn create_connection_record(
app: tauri::AppHandle,
conn_string: String,
conn_name: String,
driver: Drivers,
) -> Result<(), String> {
write_into_connections_file(
app.path_resolver().app_config_dir(),
Drivers::SQLITE,
driver,
conn_string,
conn_name,
);
Ok(())
}

#[tauri::command]
pub async fn establish_connection(
connection: State<'_, DbInstance>,
conn_string: String,
driver: Drivers,
) -> Result<(), String> {
sqlx::any::install_default_drivers();
let pool = AnyPoolOptions::new()
.acquire_timeout(Duration::new(5, 0))
.test_before_acquire(true)
.connect(&conn_string)
.await
.map_err(|_| "Couldn't establish connection to db".to_string())?;
*connection.pool.lock().await = Some(pool);
*connection.driver.lock().await = Some(driver);
Ok(())
}

#[tauri::command]
pub fn connections_exist(app: tauri::AppHandle) -> Result<bool, String> {
let (_, connections) = read_from_connections_file(app.path_resolver().app_config_dir());
Expand Down
18 changes: 10 additions & 8 deletions src-tauri/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,36 +3,38 @@

mod connection;
mod sqlite;
mod table;
mod utils;

use connection::{
connections_exist, create_connection_record, get_connection_details, get_connections,
test_connection,
};
use sqlite::{
connect_sqlite, create_row, delete_row, get_columns, get_columns_definition, get_rows,
get_tables, update_row,
connections_exist, create_connection_record, establish_connection, get_connection_details,
get_connections, test_connection,
};
use sqlite::{create_row, delete_row, get_columns, get_columns_definition, get_rows, update_row};
use sqlx::Pool;
use table::get_tables;
use tokio::sync::Mutex;
use utils::Drivers;

#[derive(Default, Debug)]
#[derive(Default)]
pub struct DbInstance {
pool: Mutex<Option<Pool<sqlx::any::Any>>>,
driver: Mutex<Option<Drivers>>,
}

fn main() {
tauri::Builder::default()
.manage(DbInstance {
pool: Default::default(),
driver: Default::default(),
})
.invoke_handler(tauri::generate_handler![
test_connection,
create_connection_record,
establish_connection,
connections_exist,
get_connections,
get_connection_details,
connect_sqlite,
get_tables,
get_rows,
get_columns,
Expand Down
33 changes: 0 additions & 33 deletions src-tauri/src/sqlite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,39 +8,6 @@ use std::result::Result::Ok;
use std::vec;
use tauri::State;

#[tauri::command]
pub async fn connect_sqlite(
connection: State<'_, DbInstance>,
conn_string: String,
) -> Result<(), String> {
sqlx::any::install_default_drivers();
*connection.pool.lock().await = Some(sqlx::AnyPool::connect(&conn_string).await.unwrap());
Ok(())
}

#[tauri::command]
pub async fn get_tables(connection: State<'_, DbInstance>) -> Result<Option<Vec<String>>, ()> {
let long_lived = connection.pool.lock().await;
let conn = long_lived.as_ref().unwrap();
let rows = sqlx::query(
"SELECT name
FROM sqlite_schema
WHERE type ='table'
AND name NOT LIKE 'sqlite_%';",
)
.fetch_all(conn)
.await
.unwrap();
if rows.len() == 0 {
()
}
let mut result: Vec<String> = vec![];
for (_, row) in rows.iter().enumerate() {
result.push(row.get::<String, &str>("name"))
}
Ok(Some(result))
}

#[tauri::command]
pub async fn get_rows(
connection: State<'_, DbInstance>,
Expand Down
26 changes: 26 additions & 0 deletions src-tauri/src/table.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
use crate::DbInstance;
use sqlx::Row;
use tauri::State;

#[tauri::command]
pub async fn get_tables(connection: State<'_, DbInstance>) -> Result<Option<Vec<String>>, ()> {
let long_lived = connection.pool.lock().await;
let conn = long_lived.as_ref().unwrap();
let rows = sqlx::query(
"SELECT name
FROM sqlite_schema
WHERE type ='table'
AND name NOT LIKE 'sqlite_%';",
)
.fetch_all(conn)
.await
.unwrap();
if rows.len() == 0 {
()
}
let mut result: Vec<String> = vec![];
for (_, row) in rows.iter().enumerate() {
result.push(row.get::<String, &str>("name"))
}
Ok(Some(result))
}
11 changes: 7 additions & 4 deletions src-tauri/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,20 @@ use std::io::{BufReader, BufWriter, Write};
use std::path::PathBuf;
use uuid::Uuid;

#[derive(Serialize, Deserialize)]
#[derive(Serialize, Deserialize, Debug)]
#[serde(rename_all = "lowercase")]
pub enum Drivers {
SQLITE,
PSQL,
MYSQL,
SQLite,
PostgreSQL,
MySQL,
}

#[derive(Serialize, Deserialize)]
pub struct ConnConfig {
driver: Drivers,
#[serde(rename = "connString")]
conn_string: String,
#[serde(rename = "connName")]
conn_name: String,
}

Expand Down
9 changes: 6 additions & 3 deletions src/app/connect/_components/conn-params.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@ import {
FormMessage
} from "@/components/ui/form"
import { Input } from "@/components/ui/input"
import { SupportedDrivers } from "@/lib/types"
import { Drivers, DriversValues } from "@/lib/types"
import { constructConnectionString } from "@/lib/utils"
import { zodResolver } from "@hookform/resolvers/zod"
import { useRouter } from "next/navigation"
import { type FC } from "react"
import { useForm } from "react-hook-form"
import { z } from "zod"
Expand All @@ -27,10 +28,11 @@ const formSchema = z.object({
})

interface ConnectionParamsFormProps {
driver: SupportedDrivers.MYSQL | SupportedDrivers.PSQL
driver: Exclude<DriversValues, typeof Drivers.SQLite>
}

const ConnectionParamsForm: FC<ConnectionParamsFormProps> = ({ driver }) => {
const router = useRouter()
const form = useForm<z.infer<typeof formSchema>>({
resolver: zodResolver(formSchema)
})
Expand All @@ -50,7 +52,8 @@ const ConnectionParamsForm: FC<ConnectionParamsFormProps> = ({ driver }) => {
port,
db
})
await createConnectionRecord(connName, connString)
await createConnectionRecord(connName, connString, driver)
router.push("/connections")
}

const onClickTest = async ({
Expand Down
6 changes: 3 additions & 3 deletions src/app/connect/_components/conn-radio.tsx
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import { Label } from "@/components/ui/label"
import { RadioGroup, RadioGroupItem } from "@/components/ui/radio-group"
import { SupportedDrivers } from "@/lib/types"
import { Drivers, type DriversValues } from "@/lib/types"
import { FC, useState } from "react"
import ConnectionParamsForm from "./conn-params"
import ConnectionStringForm from "./conn-string"

interface ConnectionParamsProps {
driver: SupportedDrivers.PSQL | SupportedDrivers.MYSQL
driver: Exclude<DriversValues, typeof Drivers.SQLite>
}

const ConnectionRadio: FC<ConnectionParamsProps> = ({ driver }) => {
Expand Down Expand Up @@ -40,7 +40,7 @@ const ConnectionRadio: FC<ConnectionParamsProps> = ({ driver }) => {
</div>
</RadioGroup>
{radioValue === "conn_params" && <ConnectionParamsForm driver={driver} />}
{radioValue === "conn_string" && <ConnectionStringForm />}
{radioValue === "conn_string" && <ConnectionStringForm driver={driver} />}
</>
)
}
Expand Down
28 changes: 21 additions & 7 deletions src/app/connect/_components/conn-string.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@ import {
FormMessage
} from "@/components/ui/form"
import { Input } from "@/components/ui/input"
import { Drivers, DriversValues } from "@/lib/types"
import { zodResolver } from "@hookform/resolvers/zod"
import { useRouter } from "next/navigation"
import { FC } from "react"
import { useForm } from "react-hook-form"
import { z } from "zod"
import { createConnectionRecord, testConnection } from "../actions"
Expand All @@ -19,20 +22,28 @@ const formSchema = z.object({
connString: z.string()
})

const ConnectionStringForm = () => {
interface ConnectionParamsFormProps {
driver: Exclude<DriversValues, typeof Drivers.SQLite>
}

const ConnectionStringForm: FC<ConnectionParamsFormProps> = ({ driver }) => {
const router = useRouter()
const form = useForm<z.infer<typeof formSchema>>({
resolver: zodResolver(formSchema)
})
const onSubmit = async (values: z.infer<typeof formSchema>) => {
await createConnectionRecord(values.connName, values.connString)

const onClickConnect = async (values: z.infer<typeof formSchema>) => {
await createConnectionRecord(values.connName, values.connString, driver)
router.push("/connections")
}

const onTest = async (values: z.infer<typeof formSchema>) => {
const onClickTest = async (values: z.infer<typeof formSchema>) => {
await testConnection(values.connString)
}

return (
<Form {...form}>
<form onSubmit={form.handleSubmit(onSubmit)} className="space-y-10">
<form className="space-y-10">
<FormField
control={form.control}
name="connName"
Expand Down Expand Up @@ -64,12 +75,15 @@ const ConnectionStringForm = () => {
)}
/>
<div className="col-span-2 flex justify-center items-center gap-x-4">
<Button variant={"secondary"} type="submit">
<Button
variant={"secondary"}
onClick={form.handleSubmit(onClickConnect)}
>
Connect
</Button>
<Button
className="bg-green-500 hover:bg-green-700"
onClick={form.handleSubmit(onTest)}
onClick={form.handleSubmit(onClickTest)}
>
Test
</Button>
Expand Down
8 changes: 4 additions & 4 deletions src/app/connect/_components/sqlite-connection.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import {
FormMessage
} from "@/components/ui/form"
import { Input } from "@/components/ui/input"
import { SupportedDrivers } from "@/lib/types"
import { Drivers } from "@/lib/types"
import { constructConnectionString } from "@/lib/utils"
import { zodResolver } from "@hookform/resolvers/zod"
import { open } from "@tauri-apps/api/dialog"
Expand Down Expand Up @@ -62,16 +62,16 @@ const ConnectionForm: FC<ConnectionFormProps> = ({ selectedPath }) => {
})
const onClickConnect = (values: z.infer<typeof formSchema>) => {
const connString = constructConnectionString({
driver: SupportedDrivers.SQLITE,
driver: Drivers.SQLite,
filePath: selectedPath
})
createConnectionRecord(values.connName, connString)
createConnectionRecord(values.connName, connString, Drivers.SQLite)
router.push("/connections")
}

const onClickTest = async () => {
const connString = constructConnectionString({
driver: SupportedDrivers.SQLITE,
driver: Drivers.SQLite,
filePath: selectedPath
})
await testConnection(connString)
Expand Down
7 changes: 5 additions & 2 deletions src/app/connect/actions.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import { type DriversValues } from "@/lib/types"
import { invoke } from "@tauri-apps/api/tauri"
import toast from "react-hot-toast"

Expand All @@ -20,10 +21,12 @@ export const testConnection = async (connString: string) => {

export const createConnectionRecord = async (
connName: string,
connString: string
connString: string,
driver: DriversValues
) => {
await invoke("create_connection_record", {
connString,
connName
connName,
driver
})
}
Loading

0 comments on commit 697d3bc

Please sign in to comment.