diff --git a/.env.example b/.env.example new file mode 100644 index 0000000..ac9b65b --- /dev/null +++ b/.env.example @@ -0,0 +1,8 @@ +REPLICATE_API_TOKEN= + +SUPABASE_KEY= +SUPABASE_URL= + +GOOGLE_CLIENT_ID= +GOOGLE_CLIENT_SECRET= +SECRET_KEY= \ No newline at end of file diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..6d17870 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +__pycache__ +.env \ No newline at end of file diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..19ca54f --- /dev/null +++ b/Dockerfile @@ -0,0 +1,9 @@ +FROM python:3.8-slim + +WORKDIR /usr/src/app +COPY . . +RUN pip install --no-cache-dir gradio +EXPOSE 7860 +ENV GRADIO_SERVER_NAME="0.0.0.0" + +CMD ["python", "app.py"] \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..b3ddb54 --- /dev/null +++ b/README.md @@ -0,0 +1,3 @@ +test card +4242 4242 4242 4242 +wip \ No newline at end of file diff --git a/assets/logo.jpg b/assets/logo.jpg new file mode 100644 index 0000000..f941660 Binary files /dev/null and b/assets/logo.jpg differ diff --git a/auth.py b/auth.py new file mode 100644 index 0000000..31e2bc2 --- /dev/null +++ b/auth.py @@ -0,0 +1,12 @@ +from starlette.config import Config +from authlib.integrations.starlette_client import OAuth +from config import GOOGLE_CLIENT_ID, GOOGLE_CLIENT_SECRET + +config_data = {'GOOGLE_CLIENT_ID': GOOGLE_CLIENT_ID, 'GOOGLE_CLIENT_SECRET': GOOGLE_CLIENT_SECRET} +starlette_config = Config(environ=config_data) +oauth = OAuth(starlette_config) +oauth.register( + name='google', + server_metadata_url='https://accounts.google.com/.well-known/openid-configuration', + client_kwargs={'scope': 'openid email profile'}, +) \ No newline at end of file diff --git a/config.py b/config.py new file mode 100644 index 0000000..5053f76 --- /dev/null +++ b/config.py @@ -0,0 +1,18 @@ +import os +from dotenv import load_dotenv +import logging + +load_dotenv() + +SUPABASE_URL = os.getenv("SUPABASE_URL") +SUPABASE_KEY = os.getenv("SUPABASE_KEY") +SECRET_KEY = os.getenv("SECRET_KEY") + +GOOGLE_CLIENT_ID = os.getenv("GOOGLE_CLIENT_ID") +GOOGLE_CLIENT_SECRET = os.getenv("GOOGLE_CLIENT_SECRET") + +STRIPE_API_KEY = os.getenv("STRIPE_API_KEY") +STRIPE_WEBHOOK_SECRET = os.getenv("STRIPE_WEBHOOK_SECRET") + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) \ No newline at end of file diff --git a/database.py b/database.py new file mode 100644 index 0000000..93e7f9a --- /dev/null +++ b/database.py @@ -0,0 +1,45 @@ +import json +from supabase import create_client, Client +from config import SUPABASE_URL, SUPABASE_KEY + +supabase: Client = create_client(SUPABASE_URL, SUPABASE_KEY) + +def get_user_credits(user_id): + user = supabase.table("users").select("generation_credits, train_credits").eq("id", user_id).execute() + if user.data: + return user.data[0]["generation_credits"], user.data[0]["train_credits"] + return 0, 0 + +def update_user_credits(user_id, generation_credits, train_credits): + supabase.table("users").update({ + "generation_credits": generation_credits, + "train_credits": train_credits + }).eq("id", user_id).execute() + +def get_or_create_user(google_id, email, name, given_name, profile_picture): + user = supabase.table("users").select("*").eq("google_id", google_id).execute() + + if not user.data: + new_user = { + "google_id": google_id, + "email": email, + "name": name, + "profile_picture": profile_picture, + "generation_credits": 2, + "train_credits": 1, + "given_name": given_name + } + result = supabase.table("users").insert(new_user).execute() + return result.data[0] + else: + return user.data[0] + +def get_lora_models_info(): + lora_models = supabase.table("lora_models").select("*").execute() + return lora_models.data + +def get_user_by_id(user_id): + user = supabase.table("users").select("*").eq("id", user_id).execute() + if user.data: + return user.data[0] + return None \ No newline at end of file diff --git a/gradio_app.py b/gradio_app.py new file mode 100644 index 0000000..091124d --- /dev/null +++ b/gradio_app.py @@ -0,0 +1,266 @@ +import gradio as gr + +import os +import json +import zipfile +from pathlib import Path + +from database import get_user_credits, update_user_credits, get_lora_models_info +from services.image_generation import generate_image +from services.train_lora import lora_pipeline +from utils.image_utils import url_to_pil_image + +lora_models = get_lora_models_info() + + +if not isinstance(lora_models, list): + raise ValueError("Expected loras_models to be a list of dictionaries.") + +login_css_path = Path(__file__).parent / 'static/css/login.css' +main_css_path = Path(__file__).parent / 'static/css/main.css' +landing_html_path = Path(__file__).parent / 'static/html/landing.html' +main_header_path = Path(__file__).parent / 'static/html/main_header.html' + +if login_css_path.is_file(): # Check if the file exists + with login_css_path.open() as file: + login_css = file.read() + +if main_css_path.is_file(): # Check if the file exists + with main_css_path.open() as file: + main_css = file.read() + +if landing_html_path.is_file(): + with landing_html_path.open() as file: + landin_page = file.read() + +if main_header_path.is_file(): + with main_header_path.open() as file: + main_header = file.read() + +def update_selection(evt: gr.SelectData, width, height): + selected_lora = lora_models[evt.index] + new_placeholder = f"Ingresa un prompt para tu modelo {selected_lora['lora_name']}" + trigger_word = selected_lora["trigger_word"] + updated_text = f"#### Palabra clave: {trigger_word} ✨" + + if "aspect" in selected_lora: + if selected_lora["aspect"] == "portrait": + width, height = 768, 1024 + elif selected_lora["aspect"] == "landscape": + width, height = 1024, 768 + + return gr.update(placeholder=new_placeholder), updated_text, evt.index, width, height + +def compress_and_train(files, model_name, trigger_word, train_steps, lora_rank, batch_size, learning_rate): + if not files: + return "No images uploaded. Please upload images before training." + + # Create a directory in the user's home folder + output_dir = os.path.expanduser("~/gradio_training_data") + os.makedirs(output_dir, exist_ok=True) + + # Create a zip file in the output directory + zip_path = os.path.join(output_dir, "training_data.zip") + + with zipfile.ZipFile(zip_path, 'w') as zipf: + for file_info in files: + file_path = file_info[0] # The first element of the tuple is the file path + file_name = os.path.basename(file_path) + zipf.write(file_path, file_name) + + print(f"Zip file created at: {zip_path}") + + print(f'[INFO] Procesando {trigger_word}') + # Now call the train_lora function with the zip file path + result = lora_pipeline(zip_path, + model_name, + trigger_word=trigger_word, + steps=train_steps, + lora_rank=lora_rank, + batch_size=batch_size, + autocaption=True, + learning_rate=learning_rate) + + return f"{result}\n\nZip file saved at: {zip_path}" + +def run_lora(request: gr.Request, prompt, cfg_scale, steps, selected_index, randomize_seed, width, height, lora_scale, progress=gr.Progress(track_tqdm=True)): + user = request.session.get('user') + if not user: + raise gr.Error("User not authenticated. Please log in.") + + generation_credits, _ = get_user_credits(user['id']) + + if generation_credits <= 0: + raise gr.Error("Ya no tienes creditos disponibles. Compra para continuar.") + + image_url = generate_image(prompt, steps, cfg_scale, width, height, lora_scale, progress) + image = url_to_pil_image(image_url) + + # Update user's credits + new_generation_credits = generation_credits - 1 + update_user_credits(user['id'], new_generation_credits, user['train_credits']) + + # Update session data + user['generation_credits'] = new_generation_credits + request.session['user'] = user + + print(f"Generation credits remaining: {new_generation_credits}") + + return image, new_generation_credits + +def display_credits(request: gr.Request): + user = request.session.get('user') + if user: + generation_credits, train_credits = get_user_credits(user['id']) + return generation_credits, train_credits + return 0, 0 + +def load_greet_and_credits(request: gr.Request): + greeting = greet(request) + generation_credits, train_credits = display_credits(request) + return greeting, generation_credits, train_credits + +def greet(request: gr.Request): + user = request.session.get('user') + if user: + with gr.Column(): + with gr.Row(): + greeting = f"Hola 👋 {user['given_name']}!" + return f"{greeting}\n" + return "OBTU AI. Please log in." + +with gr.Blocks(theme=gr.themes.Soft(), css=login_css) as login_demo: + with gr.Column(elem_id="google-btn-container", elem_classes="google-btn-container svelte-vt1mxs gap"): + btn = gr.Button("Iniciar Sesion con Google", elem_classes="login-with-google-btn") + _js_redirect = """ + () => { + url = '/login' + window.location.search; + window.open(url, '_blank'); + } + """ + btn.click(None, js=_js_redirect) + gr.HTML(landin_page) + + +header = '' + +with gr.Blocks(theme=gr.themes.Soft(), head=header, css=main_css) as main_demo: + title = gr.HTML(main_header) + + with gr.Column(elem_id="logout-btn-container"): + gr.Button("Salir", link="/logout", elem_id="logout_btn") + + + greetings = gr.Markdown("Loading user information...") + gr.Button("Comprar Creditos", link="/buy_credits", elem_id="buy_credits_btn") + + selected_index = gr.State(None) + + with gr.Row(): + with gr.Column(): + generation_credits_display = gr.Number(label="Generation Credits", precision=0, interactive=False) + with gr.Column(): + train_credits_display = gr.Number(label="Training Credits", precision=0, interactive=False) + + + with gr.Tabs(): + with gr.TabItem('Generacion'): + with gr.Row(): + with gr.Column(scale=3): + prompt = gr.Textbox(label="Prompt", lines=1, placeholder="Ingresa un prompt para empezar a crear") + with gr.Column(scale=1, elem_id="gen_column"): + generate_button = gr.Button("Generate", variant="primary", elem_id="gen_btn") + + with gr.Row(): + with gr.Column(scale=4): + result = gr.Image(label="Imagen Generada") + + with gr.Column(scale=3): + with gr.Accordion("Tus Modelos"): + user_model_gallery = gr.Gallery( + label="Galeria de Modelos", + allow_preview=False, + columns=3, + elem_id="galley" + ) + + with gr.Accordion("Modelos Publicos", open=False): + selected_info = gr.Markdown("") + gallery = gr.Gallery( + [(item["image_url"], item["lora_name"]) for item in lora_models], + label="Galeria de Modelos Publicos", + allow_preview=False, + columns=3, + elem_id="gallery" + ) + + + with gr.Accordion("Configuracion Avanzada", open=False): + with gr.Row(): + cfg_scale = gr.Slider(label="CFG Scale", minimum=1, maximum=20, step=0.5, value=3.5) + steps = gr.Slider(label="Steps", minimum=1, maximum=50, step=1, value=28) + with gr.Row(): + width = gr.Slider(label="Width", minimum=256, maximum=1536, step=64, value=1024) + height = gr.Slider(label="Height", minimum=256, maximum=1536, step=64, value=1024) + with gr.Row(): + randomize_seed = gr.Checkbox(True, label="Randomize seed") + lora_scale = gr.Slider(label="LoRA Scale", minimum=0, maximum=1, step=0.01, value=0.95) + + gallery.select( + update_selection, + inputs=[width, height], + outputs=[prompt, selected_info, selected_index, width, height] + ) + + gr.on( + triggers=[generate_button.click, prompt.submit], + fn=run_lora, + inputs=[prompt, cfg_scale, steps, selected_index, randomize_seed, width, height, lora_scale], + outputs=[result, generation_credits_display] + ) + + with gr.TabItem("Training"): + gr.Markdown("# Entrena tu propio modelo 🧠") + gr.Markdown("En esta seccion podes entrenar tu propio modelo a partir de tus imagenes.") + with gr.Row(): + with gr.Column(): + train_dataset = gr.Gallery(columns=4, interactive=True, label="Tus Imagenes") + model_name = gr.Textbox(label="Nombre del Modelo",) + trigger_word = gr.Textbox(label="Palabra clave", + info="Esta seria una palabra clave para luego indicar al modelo cuando debe usar estas nuevas capacidad es que le vamos a ensenar", + ) + train_button = gr.Button("Comenzar Training") + with gr.Accordion("Configuracion Avanzada", open=False): + train_steps = gr.Slider(label="Training Steps", minimum=100, maximum=10000, step=100, value=1000) + lora_rank = gr.Number(label='lora_rank', value=16) + batch_size = gr.Number(label='batch_size', value=1) + learning_rate = gr.Number(label='learning_rate', value=0.0004) + training_status = gr.Textbox(label="Training Status") + + + + train_button.click( + compress_and_train, + inputs=[train_dataset, model_name, trigger_word, train_steps, lora_rank, batch_size, learning_rate], + outputs=training_status + ) + + + #main_demo.load(greet, None, title) + #main_demo.load(greet, None, greetings) + #main_demo.load((greet, display_credits), None, [greetings, generation_credits_display, train_credits_display]) + main_demo.load(load_greet_and_credits, None, [greetings, generation_credits_display, train_credits_display]) + + + +# TODO: +''' +- Galeria Modelos Propios (si existe alguno del user, si no, mostrar un mensaje para entrenar) +- Galeria Modelos Open Source (accordion) +- Training con creditos. +- Stripe(?) +- Mejorar boton de login/logout +- Retoque landing page +''' + + diff --git a/main.py b/main.py new file mode 100644 index 0000000..b6fcb5c --- /dev/null +++ b/main.py @@ -0,0 +1,31 @@ +import uvicorn +from fastapi import FastAPI +from fastapi.staticfiles import StaticFiles +from starlette.middleware.sessions import SessionMiddleware +from config import SECRET_KEY +from routes import router, get_user +from gradio_app import login_demo, main_demo +import gradio as gr +from pathlib import Path +from fastapi.middleware.cors import CORSMiddleware + +app = FastAPI() + +login_demo.queue() +main_demo.queue() + +static_dir = Path("./static") +app.mount("/static", StaticFiles(directory=static_dir, html=True), name="static") +#app.mount("/assets", StaticFiles(directory="assets", html=True), name="assets") + +app.add_middleware(SessionMiddleware, secret_key=SECRET_KEY) + +app.include_router(router) + +app = gr.mount_gradio_app(app, login_demo, path="/main") +app = gr.mount_gradio_app(app, main_demo, path="/gradio", auth_dependency=get_user, show_error=True) + +if __name__ == "__main__": + uvicorn.run(app) + + \ No newline at end of file diff --git a/models.py b/models.py new file mode 100644 index 0000000..6d4ac76 --- /dev/null +++ b/models.py @@ -0,0 +1,11 @@ +from pydantic import BaseModel + +class User(BaseModel): + id: str + google_id: str + email: str + name: str + given_name: str + profile_picture: str + generation_credits: int + train_credits: int \ No newline at end of file diff --git a/routes.py b/routes.py new file mode 100644 index 0000000..36c2309 --- /dev/null +++ b/routes.py @@ -0,0 +1,154 @@ +# routes.py +from fastapi import APIRouter, Depends, Request +from starlette.responses import RedirectResponse +from auth import oauth +from database import get_or_create_user, update_user_credits, get_user_by_id +from authlib.integrations.starlette_client import OAuthError +import gradio as gr +from utils.stripe_utils import create_checkout_session, verify_webhook, retrieve_stripe_session + +router = APIRouter() + +def get_user(request: Request): + user = request.session.get('user') + return user['name'] if user else None + +@router.get('/') +def public(request: Request, user = Depends(get_user)): + root_url = gr.route_utils.get_root_url(request, "/", None) + print(f'Root URL: {root_url}') + if user: + return RedirectResponse(url=f'{root_url}/gradio/') + else: + return RedirectResponse(url=f'{root_url}/main/') + +@router.route('/logout') +async def logout(request: Request): + request.session.pop('user', None) + return RedirectResponse(url='/') + +@router.route('/login') +async def login(request: Request): + root_url = gr.route_utils.get_root_url(request, "/login", None) + redirect_uri = f"{root_url}/auth" + return await oauth.google.authorize_redirect(request, redirect_uri) + +@router.route('/auth') +async def auth(request: Request): + try: + token = await oauth.google.authorize_access_token(request) + user_info = token.get('userinfo') + if user_info: + google_id = user_info['sub'] + email = user_info['email'] + name = user_info['name'] + given_name = user_info['given_name'] + profile_picture = user_info.get('picture', '') + + user = get_or_create_user(google_id, email, name, given_name, profile_picture) + request.session['user'] = user + + return RedirectResponse(url='/gradio') + else: + return RedirectResponse(url='/main') + except OAuthError as e: + print(f"OAuth Error: {str(e)}") + return RedirectResponse(url='/main') + +# Handle Stripe payments +@router.get("/buy_credits") +async def buy_credits(request: Request): + user = request.session.get('user') + if not user: + return {"error": "User not authenticated"} + session = create_checkout_session(100, 50, user['id']) # $1 for 50 credits + + # Store the session ID and user ID in the session + request.session['stripe_session_id'] = session['id'] + request.session['user_id'] = user['id'] + print(f"Stripe session created: {session['id']} for user {user['id']}") + + return RedirectResponse(session['url']) + +@router.post("/webhook") +async def stripe_webhook(request: Request): + payload = await request.body() + sig_header = request.headers.get("Stripe-Signature") + + event = verify_webhook(payload, sig_header) + + if event is None: + return {"error": "Invalid payload or signature"} + + if event['type'] == 'checkout.session.completed': + session = event['data']['object'] + user_id = session.get('client_reference_id') + + if user_id: + # Fetch the user from the database + user = get_user_by_id(user_id) # You'll need to implement this function + if user: + # Update user's credits + new_credits = user['generation_credits'] + 50 # Assuming 50 credits were purchased + update_user_credits(user['id'], new_credits, user['train_credits']) + print(f"Credits updated for user {user['id']}") + else: + print(f"User not found for ID: {user_id}") + else: + print("No client_reference_id found in the session") + + return {"status": "success"} + +# @router.get("/success") +# async def payment_success(request: Request): +# print("Payment successful") +# user = request.session.get('user') +# print(user) +# if user: +# updated_user = get_user_by_id(user['id']) +# if updated_user: +# request.session['user'] = updated_user +# return RedirectResponse(url='/gradio', status_code=303) +# return RedirectResponse(url='/login', status_code=303) + +@router.get("/cancel") +async def payment_cancel(request: Request): + print("Payment cancelled") + user = request.session.get('user') + print(user) + if user: + return RedirectResponse(url='/gradio', status_code=303) + return RedirectResponse(url='/login', status_code=303) + +@router.get("/success") +async def payment_success(request: Request): + print("Payment successful") + stripe_session_id = request.session.get('stripe_session_id') + user_id = request.session.get('user_id') + + print(f"Session data: stripe_session_id={stripe_session_id}, user_id={user_id}") + + if stripe_session_id and user_id: + # Retrieve the Stripe session + stripe_session = retrieve_stripe_session(stripe_session_id) + + if stripe_session.get('payment_status') == 'paid': + user = get_user_by_id(user_id) + if user: + # Update the session with the latest user data + request.session['user'] = user + print(f"User session updated: {user}") + + # Clear the stripe_session_id and user_id from the session + request.session.pop('stripe_session_id', None) + request.session.pop('user_id', None) + + return RedirectResponse(url='/gradio', status_code=303) + else: + print(f"User not found for ID: {user_id}") + else: + print(f"Payment not completed for session: {stripe_session_id}") + else: + print("No Stripe session ID or user ID found in the session") + + return RedirectResponse(url='/login', status_code=303) \ No newline at end of file diff --git a/services/get_stripe.py b/services/get_stripe.py new file mode 100644 index 0000000..bb3f786 --- /dev/null +++ b/services/get_stripe.py @@ -0,0 +1,2 @@ +import stripe + diff --git a/services/image_generation.py b/services/image_generation.py new file mode 100644 index 0000000..a4f3707 --- /dev/null +++ b/services/image_generation.py @@ -0,0 +1,21 @@ +import replicate +from PIL import Image +import requests +from io import BytesIO + +#model_custom_test = "josebenitezg/flux-dev-ruth-estilo-1:c7ff81b58007c7cee3f69416e1e999192dafd8d1b1f269ea6cae137f04b34172" +flux_pro = "black-forest-labs/flux-pro" +def generate_image(prompt, steps, cfg_scale, width, height, lora_scale, progress, trigger_word='hi'): + print(f"Generating image for prompt: {prompt}") + img_url = replicate.run( + flux_pro, + input={ + "steps": steps, + "prompt": prompt, + "guidance": cfg_scale, + "interval": 2, + "aspect_ratio": "1:1", + "safety_tolerance": 2 + } + ) + return img_url diff --git a/services/train_lora.py b/services/train_lora.py new file mode 100644 index 0000000..2db18fa --- /dev/null +++ b/services/train_lora.py @@ -0,0 +1,46 @@ +import replicate +import os +from huggingface_hub import create_repo + +REPLICATE_OWNER = "josebenitezg" + +def lora_pipeline(zip_path, model_name, trigger_word="TOK", steps=1000, lora_rank=16, batch_size=1, autocaption=True, learning_rate=0.0004): + print(f'Creating dataset for {model_name}') + repo_name = f"joselobenitezg/flux-dev-{model_name}" + create_repo(repo_name, repo_type='model') + + lora_name = f"flux-dev-{model_name}" + + model = replicate.models.create( + owner=REPLICATE_OWNER, + name=lora_name, + visibility="public", # or "private" if you prefer + hardware="gpu-t4", # Replicate will override this for fine-tuned models + description="A fine-tuned FLUX.1 model" + ) + + print(f"Model created: {model.name}") + print(f"Model URL: https://replicate.com/{model.owner}/{model.name}") + + # Now use this model as the destination for your training + print(f"[INFO] Starting training") + + print(f'\n[INFO] Parametros a entrenar: \n Trigger word: {trigger_word}\n steps: {steps} \n lora_rank: {lora_rank}\n autocaption: {autocaption}\n learning_rate: {learning_rate}\n') + training = replicate.trainings.create( + version="ostris/flux-dev-lora-trainer:1296f0ab2d695af5a1b5eeee6e8ec043145bef33f1675ce1a2cdb0f81ec43f02", + input={ + "input_images": open(zip_path, "rb"), + "steps": steps, + "lora_rank": lora_rank, + "batch_size": batch_size, + "autocaption": autocaption, + "trigger_word": trigger_word, + "learning_rate": learning_rate, + "hf_token": os.getenv('HF_TOKEN'), # optional + "hf_repo_id": repo_name, # optional + }, + destination=f"{model.owner}/{model.name}" + ) + + print(f"Training started: {training.status}") + print(f"Training URL: https://replicate.com/p/{training.id}") diff --git a/static/css/login.css b/static/css/login.css new file mode 100644 index 0000000..50aef78 --- /dev/null +++ b/static/css/login.css @@ -0,0 +1,76 @@ + +.login-with-google-btn { + display: inline-block; + width: 220px; /* Ancho fijo */ + max-width: 100%; /* Para asegurar responsividad */ + transition: background-color .3s, box-shadow .3s; + padding: 8px 12px 8px 35px; + border: none; + border-radius: 3px; + box-shadow: 0 -1px 0 rgba(0, 0, 0, .04), 0 1px 1px rgba(0, 0, 0, .25); + color: #757575; + font-size: 12px; + font-weight: 500; + font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Oxygen, Ubuntu, Cantarell, "Fira Sans", "Droid Sans", "Helvetica Neue", sans-serif; + background-image: url(); + background-color: white; + background-repeat: no-repeat; + background-position: 10px 50%; + background-size: 15px 15px; + text-align: center; +} + +/* Contenedor para centrar el botón */ + /* Estilos adicionales para el botón de Google y su contenedor */ +.google-btn-container { + display: flex; + justify-content: flex-end; + width: 100%; + padding-right: 20px; + box-sizing: border-box; + position: absolute; + top: 20px; + right: 0; +} +.svelte-vt1mxs.gap { + position: static !important; + margin-top: 0 !important; +} + +.login-with-google-btn:active { + background-color: #eeeeee; +} + +.login-with-google-btn:focus { + outline: none; + box-shadow: + 0 -1px 0 rgba(0, 0, 0, .04), + 0 2px 4px rgba(0, 0, 0, .25), + 0 0 0 3px #c8dafc; +} + +.login-with-google-btn:disabled { + filter: grayscale(100%); + background-color: #ebebeb; + box-shadow: 0 -1px 0 rgba(0, 0, 0, .04), 0 1px 1px rgba(0, 0, 0, .25); + cursor: not-allowed; +} + +/* Estilos específicos para trabajar con las clases de Gradio */ +.svelte-vt1mxs.gap { + position: absolute; + top: 20px; + right: 20px; + z-index: 1000; +} +@media(max-width: 768px) { + .feature-grid { + grid-template-columns: 1fr; + } + .google-btn-container { + position: static; + justify-content: center; + padding-right: 0; + margin-top: 20px; + } +} \ No newline at end of file diff --git a/static/css/main.css b/static/css/main.css new file mode 100644 index 0000000..5de2ab3 --- /dev/null +++ b/static/css/main.css @@ -0,0 +1,60 @@ +#gen_btn { + height: 100% +} + +#title { + text-align: center +} + +#title h1 { + font-size: 3em; + display: inline-flex; + align-items: center +} + +#title img { + width: 100px; + margin-right: 0.5em +} + +#gallery .grid-wrap { + height: 10vh +} + +/* Estilo para el contenedor del botón */ +#logout-btn-container.svelte-vt1mxs.gap { + position: absolute; + top: 10px; + right: 10px; + z-index: 1000; + display: flex; + justify-content: flex-end; + width: auto; +} + +/* Estilo para el botón de logout */ +#logout_btn.lg.secondary.svelte-cmf5ev { + width: auto; + min-width: 80px; + background-color: #f44336; + color: white; + border: none; + padding: 5px 10px; + border-radius: 5px; + cursor: pointer; + font-size: 0.9em; + transition: background-color 0.3s; + text-align: center; + text-decoration: none; + display: inline-block; + margin-left: auto; /* Empuja el botón hacia la derecha */ +} + +#logout_btn.lg.secondary.svelte-cmf5ev:hover { + background-color: #d32f2f; +} + +/* Ajuste del layout principal si es necesario */ +.gradio-container { + position: relative; +} \ No newline at end of file diff --git a/static/html/landing.html b/static/html/landing.html new file mode 100644 index 0000000..317c582 --- /dev/null +++ b/static/html/landing.html @@ -0,0 +1,189 @@ + + + + + + ObtuAI - Creación Visual con IA + + + +
+
+
+ +
+ +
+
+
+
+ +
+
+
+

🚀 Bienvenido al Futuro de la Creación Visual

+

Crea imágenes con IA en segundos. ¡Escribe tu idea y mira cómo se convierte en arte!

+
+
+ +
+

🌟 Descubre el Poder de la Generación de Imágenes por IA

+
+
+

Personaliza

+

Alimenta tu modelo con tus propias imágenes y estilos.

+
+
+

Entrena

+

Nuestra IA aprende de tus preferencias.

+
+
+

Crea

+

Genera imágenes que reflejen tu visión única.

+
+
+
+ +
+
+

💬 Lo Que Dicen Nuestros Usuarios

+
+

"ObtuAI ha revolucionado mi proceso creativo. ¡Ahora puedo visualizar mis ideas más locas en minutos!"

+

- Ana, Diseñadora Gráfica

+
+
+

"Entrenar mi propio modelo fue sorprendentemente fácil. Ahora hago fotografías mías y de mis clientes en segundos."

+

- Carlos, Fotógrafo Profesional

+
+
+
+
+ + + + \ No newline at end of file diff --git a/static/html/main_header.html b/static/html/main_header.html new file mode 100644 index 0000000..c9abdf8 --- /dev/null +++ b/static/html/main_header.html @@ -0,0 +1,82 @@ + + + + + + ObtuAI Header + + + +
+
+

Obtu AI 📸

+
+ + + GPU🔥 +
+
+
+ + \ No newline at end of file diff --git a/utils/image_utils.py b/utils/image_utils.py new file mode 100644 index 0000000..ce9d5e3 --- /dev/null +++ b/utils/image_utils.py @@ -0,0 +1,23 @@ +import requests +from PIL import Image +from io import BytesIO + +def url_to_pil_image(url): + try: + # Ensure url is a string, not a list + if isinstance(url, list): + url = url[0] # Take the first URL if it's a list + + response = requests.get(url) + response.raise_for_status() + image = Image.open(BytesIO(response.content)) + + # Convert to RGB if the image is in RGBA mode (for transparency) + if image.mode == 'RGBA': + image = image.convert('RGB') + + return image + except Exception as e: + print(f"Error loading image from URL: {url}") + print(f"Error details: {str(e)}") + return None \ No newline at end of file diff --git a/utils/stripe_utils.py b/utils/stripe_utils.py new file mode 100644 index 0000000..db497d3 --- /dev/null +++ b/utils/stripe_utils.py @@ -0,0 +1,44 @@ +import stripe +from config import STRIPE_API_KEY, STRIPE_WEBHOOK_SECRET + +stripe.api_key = STRIPE_API_KEY + + +def create_checkout_session(amount, quantity, user_id): + session = stripe.checkout.Session.create( + payment_method_types=['card'], + line_items=[{ + 'price_data': { + 'currency': 'usd', + 'unit_amount': amount, + 'product_data': { + 'name': f'Buy {quantity} credits', + }, + }, + 'quantity': 1, + }], + mode='payment', + success_url='http://localhost:8000/success?session_id={CHECKOUT_SESSION_ID}&user_id=' + str(user_id), + cancel_url='http://localhost:8000/cancel?user_id=' + str(user_id), + + client_reference_id=str(user_id), # Add this line + ) + return session + +def verify_webhook(payload, signature): + try: + event = stripe.Webhook.construct_event( + payload, signature, STRIPE_WEBHOOK_SECRET + ) + return event + except ValueError as e: + return None + except stripe.error.SignatureVerificationError as e: + return None + +def retrieve_stripe_session(session_id): + try: + return stripe.checkout.Session.retrieve(session_id) + except stripe.error.StripeError as e: + print(f"Error retrieving Stripe session: {str(e)}") + return None \ No newline at end of file