Skip to content

Commit

Permalink
updaete roundtriper and client -> need modify dial_early
Browse files Browse the repository at this point in the history
  • Loading branch information
ElNiak committed Jan 4, 2024
1 parent f13bd96 commit ea4d1d3
Show file tree
Hide file tree
Showing 3 changed files with 201 additions and 157 deletions.
259 changes: 123 additions & 136 deletions py-ssh3/client_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,137 +326,8 @@ async def main():
logging.debug(f"Dialing QUIC host at {hostname}:{port}")

conv = None # uninitialized conversation

async def establish_client_connection(client):
log.info(f"Establishing client connection with {client}")
if not client or client == -1:
return exit(-1)

tls_state = client._quic.tls.state
log.info(f"TLS state is {tls_state}")

conv = await new_client_conversation(30000,10, tls_state)

log.info(f"Conversation is {conv}")

# HTTP request over QUIC
# perform request
new_url = URL(url_from_param.replace("https","ssh3")) # TODO -> should replace Proto
log.info(f"New URL is {new_url}")
req = HttpRequest(method="CONNECT", url=new_url)
# req.headers[b'user-agent'] = get_current_version()
req.headers[':protocol'] = "ssh3" # TODO -> should replace Proto
log.info(f"Request is {req}")
# TODO seems not totally correct and secure
log.info(f"Request is {req}")

# await asyncio.gather(*coros)
# req.Proto = "ssh3" # TODO
# process http pushes
# process_http_pushes(client=client)

# Handle authentication methods
auth_methods = []
priv_key_file = args.privkey
if not args.privkey:
priv_key_file = '~/.ssh/id_rsa'
pubkey_for_agent = '' # TODO

if not args.useOidc:
# Private key and agent authentication
if priv_key_file:
# Add private key auth method
auth_methods.append(PrivkeyFileAuthMethod(priv_key_file)) # Implement based on your application logic

if pubkey_for_agent:
agent = paramiko.Agent()
agent_keys = agent.get_keys()
# Compare and add agent keys to auth methods
# TODO
pass # Implement based on your application logic

if args.usePassword:
# Add password auth method
auth_methods.append(PasswordAuthMethod()) # Implement based on your application logic
else:
# OIDC authentication
# TODO
issuer_url = args.useOidc
if issuer_url:
# Add OIDC auth method based on issuer URL
for issuer_config in oidc_config:
if issuer_url == issuer_config.issuer_url:
auth_methods.append(OIDCAuthMethod(args.doPkce,issuer_config))

else:
log.error("OIDC was asked explicitly but did not find suitable issuer URL")
exit(-1)

auth_methods.append(config_auth_methods)

if oidc_config:
for issuer_config in oidc_config:
if issuer_url == issuer_config.issuer_url:
auth_methods.append(OIDCAuthMethod(args.doPkce,issuer_config))

log.debug(f"Try the following auth methods: {auth_methods}")

identity = None
for method in auth_methods:
if isinstance(method, PasswordAuthMethod):
password = input(f"Password: ")
identity = method.into_identity(password)
elif isinstance(method, PrivkeyFileAuthMethod):
try:
identity = method.into_identity_without_passphrase()
except Exception as e: # Replace with specific passphrase missing exception
# Handle passphrase protected key
passphrase = input(f"Passphrase for private key stored in {method.filename()}: ")
identity = method.into_identity_passphrase(passphrase)
if identity is None:
log.error("Could not load private key with passphrase")
elif isinstance(method, AgentAuthMethod):
# Assuming an SSH agent is already connected
# identity = method.into_identity(agent_client)
pass # TODO
elif isinstance(method, OIDCAuthMethod):
# Assuming an OIDC connection method
# TODO
# token, err = oicd_connect(method.oidc_config(), method.oidc_config().issuer_url, method.do_pkce)
token, err = None, None
if err:
log.error(f"Could not get token: {err}")
else:
identity = method.into_identity(token)

if identity:
break # Exit the loop once an identity is found

if identity is None:
log.error("No suitable identity found")
# Handle the error or exit
exit(-1)

log.debug(f"Try the following Identity: {identity}")

try:
identity.set_authorization_header(req, username, conv)
except Exception as e:
log.error(f"Could not set authorization header in HTTP request: {e}")
exit(-1)

log.debug(f"Send CONNECT request to the server {req}")

try:
ret, err = await conv.establish_client_conversation(req, client)
if ret == "Unauthorized": # Replace with your specific error class
log.error("Access denied from the server: unauthorized")
exit(-1)
except Exception as e:
log.error(f"Could not open channel: {e}")
exit(-1)



# TODO no need to connect in dial since connect() is called in RoundTripper
async def dial_quic_host(hostname, port, quic_config, known_hosts_path):
try:
# Check if hostname is an IP address and format it appropriately
Expand All @@ -478,8 +349,9 @@ async def dial_quic_host(hostname, port, quic_config, known_hosts_path):
# Connection established
client = cast(HttpClient, client)
log.info(f"Connected to {hostname}:{port} with client {client}")
return client
# await establish_client_connection(client)
return client

# coros = [
# perform_http_request(
# client=client,
Expand Down Expand Up @@ -537,8 +409,7 @@ async def dial_quic_host(hostname, port, quic_config, known_hosts_path):
except Exception as e:
log.error("Could not establish client QUIC connection: %s", e)
return -1



log.info(f"Starting client to {url_from_param}")
client = await dial_quic_host(
hostname=hostname,
Expand All @@ -548,14 +419,131 @@ async def dial_quic_host(hostname, port, quic_config, known_hosts_path):
)

if client == -1 or client == 0:
log.error(f"Could not establish client QUIC connection: {client}")
return

# // TODO: could be nice ?? dirty hack: ensure only one QUIC connection is used
def dial(addr:str, tls_config, quic_config):
return client, None
round_tripper.dial = dial

await establish_client_connection(client)
tls_state = client._quic.tls.state
log.info(f"TLS state is {tls_state}")

log.info(f"Creating client conversation with {client}")
conv = await new_client_conversation(30000,10, tls_state)
log.info(f"Conversation is {conv}")

# HTTP request over QUIC
# perform request
new_url = URL(url_from_param.replace("https","ssh3")) # TODO -> should replace Proto
# new_url = URL(url_from_param)
log.info(f"New URL is {new_url}")
req = HttpRequest(method="CONNECT", url=new_url)
# req.Proto = "ssh3" # TODO
req.headers[':protocol'] = "ssh3" # TODO -> should replace Proto
log.info(f"Request is {req}")

# Handle authentication methods
auth_methods = []
priv_key_file = args.privkey
if not args.privkey:
priv_key_file = '~/.ssh/id_rsa'
pubkey_for_agent = '' # TODO

if not args.useOidc:
# Private key and agent authentication
if priv_key_file:
# Add private key auth method
auth_methods.append(PrivkeyFileAuthMethod(priv_key_file)) # Implement based on your application logic

if pubkey_for_agent:
agent = paramiko.Agent()
agent_keys = agent.get_keys()
# Compare and add agent keys to auth methods
# TODO
pass # Implement based on your application logic

if args.usePassword:
# Add password auth method
auth_methods.append(PasswordAuthMethod()) # Implement based on your application logic
else:
# OIDC authentication
# TODO
issuer_url = args.useOidc
if issuer_url:
# Add OIDC auth method based on issuer URL
for issuer_config in oidc_config:
if issuer_url == issuer_config.issuer_url:
auth_methods.append(OIDCAuthMethod(args.doPkce,issuer_config))

else:
log.error("OIDC was asked explicitly but did not find suitable issuer URL")
exit(-1)

auth_methods.append(config_auth_methods)

if oidc_config:
for issuer_config in oidc_config:
if issuer_url == issuer_config.issuer_url:
auth_methods.append(OIDCAuthMethod(args.doPkce,issuer_config))

log.debug(f"Try the following auth methods: {auth_methods}")

identity = None
for method in auth_methods:
if isinstance(method, PasswordAuthMethod):
password = input(f"Password: ")
identity = method.into_identity(password)
elif isinstance(method, PrivkeyFileAuthMethod):
try:
identity = method.into_identity_without_passphrase()
except Exception as e: # Replace with specific passphrase missing exception
# Handle passphrase protected key
passphrase = input(f"Passphrase for private key stored in {method.filename()}: ")
identity = method.into_identity_passphrase(passphrase)
if identity is None:
log.error("Could not load private key with passphrase")
elif isinstance(method, AgentAuthMethod):
# Assuming an SSH agent is already connected
# identity = method.into_identity(agent_client)
pass # TODO
elif isinstance(method, OIDCAuthMethod):
# Assuming an OIDC connection method
# TODO
# token, err = oicd_connect(method.oidc_config(), method.oidc_config().issuer_url, method.do_pkce)
token, err = None, None
if err:
log.error(f"Could not get token: {err}")
else:
identity = method.into_identity(token)

if identity:
break # Exit the loop once an identity is found

if identity is None:
log.error("No suitable identity found")
# Handle the error or exit
exit(-1)

log.debug(f"Try the following Identity: {identity}")

try:
identity.set_authorization_header(req, username, conv)
except Exception as e:
log.error(f"Could not set authorization header in HTTP request: {e}")
exit(-1)

log.debug(f"Send CONNECT request to the server {req}")

try:
ret, err = await conv.establish_client_conversation(req, round_tripper)
if ret == "Unauthorized": # Replace with your specific error class
log.error("Access denied from the server: unauthorized")
exit(-1)
except Exception as e:
log.error(f"Could not open channel: {e}")
exit(-1)

try:
channel = conv.open_channel("session", 30000, 0)
Expand Down Expand Up @@ -662,7 +650,6 @@ async def transfer_stdin_to_channel(channel):

await transfer_stdin_to_channel(channel)


async def udp_forwarding(local_udp_addr, remote_udp_addr, conv):
log.debug(f"Start forwarding from {local_udp_addr} to {remote_udp_addr}")

Expand Down
Loading

0 comments on commit ea4d1d3

Please sign in to comment.