Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enhanced code of scatterPlot (refer issue #43) #46

Open
wants to merge 6 commits into
base: master
Choose a base branch
from

Conversation

Tanvi-Jain01
Copy link

@Tanvi-Jain01 Tanvi-Jain01 commented Jun 30, 2023

This PR solves #43.

Before

vayu/vayu/scatterPlot.py

Lines 22 to 27 in ef99aef

pm10 = df.pm10
o3 = df.o3
ws = df.ws
wd = df.wd
nox = df.nox
no2 = df.no2

vayu/vayu/scatterPlot.py

Lines 42 to 57 in ef99aef

if x == "nox":
x = nox
elif x == "no2":
x = no2
elif x == "o3":
x = o3
elif x == "pm10":
x = pm10
if y == "nox":
y = nox
elif y == "no2":
y = no2
elif y == "o3":
y = o3
elif y == "pm10":
y = pm10

After

def scatterPlot(df, x, y, **kwargs):
    import seaborn as sns
    import matplotlib.pyplot as plt
    from math import pi

   # df1 = pd.DataFrame({"speed": ws, "direction": wd})
    df["speed_x"] = df['ws'] * np.sin(df['wd'] * pi / 180.0)
    df["speed_y"] = df['ws'] * np.cos(df['wd'] * pi / 180.0)
    #print(df)
    fig, ax = plt.subplots(figsize=(8, 8), dpi=80)
    x0, x1 = ax.get_xlim()
    y0, y1 = ax.get_ylim()
    #ax.set_aspect("equal")
    _ = df.plot(kind="scatter", x="speed_x", y="speed_y", alpha=0.35, ax=ax)
    plt.show()
    
    sns.jointplot(x=df[x].values, y=df[y].values, kind="hex")
    #print(x,y)
    plt.xlabel(x)
    plt.ylabel(y)
    plt.show()

Usage

pivot_pollutant = "pm10"
other_pollutants=['o3','pm25']

for pollutant in other_pollutants:
    scatterPlot(df1, pivot_pollutant, pollutant)

Plots

sc1 sc2

@patel-zeel
Copy link
Member

I think currently, the function is confusing. I am thinking of following changes for this function.
Objective of this function: To visualize the correlation between two pollutants and their individual densities.

Changes suggested:

  • rename it from scatterPlot to scatter_plot for now (following PEP8 guideline: function names should be in small letters with snake case). We should discuss the final name change with Prof. @nipunbatra. I would suggest joint_plot or correlation_plot or correlation_scatter
  • This function can be exactly identical to sns.jointplot. Accept the hue argument also. We can add month in hue and it will be very useful. Checkout some beautiful visualizations here: https://seaborn.pydata.org/generated/seaborn.jointplot.html
  • Remove the wind speed code completely.
  • Use "typehints" in this function and all the functions we make in future

New function may look like:

import pandas as pd
from typing import Optional

def scatter_plot(df: pd.DataFrame, x: str, y: str, hue: Optional[str] = None):
    return sns.jointplot(data=df, x=x, y=y, hue=hue)

@Tanvi-Jain01
Copy link
Author

@patel-zeel
Now, is it good to go?

SAMPLE CODE:

import numpy as np
import pandas as pd
np.random.seed(42)  

start_date = pd.to_datetime('2022-01-01')
end_date = pd.to_datetime('2022-12-31')

dates = pd.date_range(start_date, end_date)

pm25_values = np.random.rand(365)  # Generate 365 random values
o3_values = np.random.rand(365) 
ws_values = np.random.rand(365)
wd_values = np.random.rand(365)
pm10_values = np.random.rand(365)

df1 = pd.DataFrame({
    'date': dates,
    'pm25': pm25_values,
    'o3':o3_values,
    'ws': ws_values,
    'wd': wd_values,
     'pm10': pm10_values
})

df1['date'] = df1['date'].dt.strftime('%Y-%m-%d')  # Convert date format to 'YYYY-MM-DD'
print(df1)

IMPROVED CODE:

 import seaborn as sns
 import matplotlib.pyplot as plt
 from math import pi
def scatter_plot(df:pd.DataFrame, x:str, y:str, **kwargs):
  
    sns.jointplot(x=df[x].values, y=df[y].values, kind="hex")
    #print(x,y)
    plt.xlabel(x)
    plt.ylabel(y)
    plt.savefig("scatterPlot.png", bbox_inches="tight",dpi=300)
    print("Your plots has also been saved")
    plt.show()
    
scatter_plot(df1, 'pm10', 'pm25')

OUTPUT:

scatterPlot

@patel-zeel
Copy link
Member

patel-zeel commented Jul 10, 2023

@Tanvi-Jain01 xlabel, ylabel are automatically added by sns.jointplot. Please remove the "saving plots" code. We can add a unified code later to save all the plots generated from vayu. Also, your PR has affected 3 files in total. It should only modify the file related to the issue.

@patel-zeel
Copy link
Member

@Tanvi-Jain01 I have added the quarto template in the master branch. So, please add a notebook showing the usage of scatter_plot function in the examples directory and add an entry for it in _quarto.yml.

Comment on lines 23 to +34
#########################################
# converts wind data to randians
df = pd.DataFrame({"speed": ws, "direction": wd})
df["speed_x"] = df["speed"] * np.sin(df["direction"] * pi / 180.0)
df["speed_y"] = df["speed"] * np.cos(df["direction"] * pi / 180.0)
#df1 = pd.DataFrame({"speed": ws, "direction": wd})
df["speed"+str(x)] = df['ws'] * np.sin(df['wd'] * pi / 180.0)
df["speed"+str(y)] = df['ws'] * np.cos(df['wd'] * pi / 180.0)
fig, ax = plt.subplots(figsize=(8, 8), dpi=80)
x0, x1 = ax.get_xlim()
y0, y1 = ax.get_ylim()
ax.set_aspect("equal")
_ = df.plot(kind="scatter", x="speed_x", y="speed_y", alpha=0.35, ax=ax)
#ax.set_aspect("equal")
_ = df.plot(kind="scatter", x="speed"+str(x), y="speed"+str(y), alpha=0.35, ax=ax)
plt.show()

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove this part of code.

Comment on lines +39 to 41
plt.xlabel(x)
plt.ylabel(y)
plt.show()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These lines are also not needed and can be removed.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants