Add Visualization Options #231

Open
4 tasks
jacobbieker opened this issue Oct 23, 2023 · 8 comments
Labels
enhancement (New feature or request), good first issue (Good for newcomers), help wanted (Extra attention is needed)

Comments

@jacobbieker
Member

jacobbieker commented Oct 23, 2023

Detailed Description

We want to be able to easily see what our batches look like and have utilities that plot them to help with debugging and ensuring that our pipelines are doing what we expect.

We have had multiple one-off visualization scripts before, but the goal here is to build them into datapipes, keep them up to date, and possibly run them on PRs to give a quick, automatic view whenever any of the datapipes are changed or updated.

I think the steps would be:

  • Make visualization module in datapipes
  • Add visualizing a whole example of all modalities as an image
  • Add visualizing examples as little videos (to see the timeseries in the videos)
  • Add an option to save out batches in a more interpretable format (e.g. NetCDF, or something else that keeps coordinates and similar metadata, rather than PyTorch tensors)
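As a rough illustration of the second step, a whole example could be drawn as one matplotlib figure with a panel per modality. This is a minimal sketch, not existing ocf_datapipes code: `plot_example` is a hypothetical helper, and it assumes each modality has already been pulled out of the batch as a NumPy array with dims (time, y, x).

```python
import matplotlib

matplotlib.use("Agg")  # headless backend, e.g. for running on CI
import matplotlib.pyplot as plt
import numpy as np


def plot_example(example: dict, t_idx: int = 0) -> plt.Figure:
    """Draw frame ``t_idx`` of every modality side by side in one figure.

    ``example`` maps a modality name to an array with dims (time, y, x).
    This layout is an assumption for illustration, not the real batch format.
    """
    n = len(example)
    fig, axes = plt.subplots(1, n, figsize=(4 * n, 4), squeeze=False)
    for ax, (name, data) in zip(axes[0], example.items()):
        im = ax.imshow(data[t_idx], origin="lower")
        ax.set_title(f"{name} (t={t_idx})")
        fig.colorbar(im, ax=ax)
    return fig


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    # Made-up shapes standing in for satellite and NWP crops
    example = {"satellite": rng.random((12, 24, 24)), "nwp": rng.random((4, 8, 8))}
    plot_example(example, t_idx=0).savefig("example.png")
```

Returning the `Figure` rather than calling `plt.show()` means it can be written to disk with `fig.savefig(...)` or handed to Streamlit's `st.pyplot(fig)`.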

Possible Implementation

Satip used to have a step in its workflows that ran visualization code on the outputs of some processing steps on PRs. It was quite helpful for knowing whether changes broke the end-to-end processing pipelines, and whether the images coming out still looked correct.

Notes

Goal:

  • To show what is in the batches right before the model runs
  • To show what is going into the model at any timestep during training
  • Users can step through periods
  • Time and space are aligned

Users:

  • ML team only
  • Prototype examples
  • NWP data wasn't aligned with GSP data; James found this when plotting these out
  • Early on, Jacob found the satellite data was 500 km off

Effort to build:

  • Make it so people don't need to rebuild anything from scratch
  • Build something a bit less ad hoc than before

Effort to run:

  • Hopefully takes someone <1 min to run this from Datapipes
  • It would be useful for training & production use cases
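For stepping through periods (and the "little videos" step), one option is matplotlib's `FuncAnimation`, saved as a GIF with `PillowWriter`. A minimal sketch, again assuming a single modality as a NumPy array with dims (time, y, x); `animate_timeseries` is a hypothetical name.

```python
import matplotlib

matplotlib.use("Agg")  # headless backend, no display needed
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.animation import FuncAnimation, PillowWriter


def animate_timeseries(data: np.ndarray, title: str = "") -> FuncAnimation:
    """Return an animation with one frame per timestep of a (time, y, x) array."""
    fig, ax = plt.subplots()
    # Fix the colour scale across all frames so changes over time are comparable
    im = ax.imshow(data[0], origin="lower", vmin=data.min(), vmax=data.max())
    fig.colorbar(im, ax=ax)

    def update(t: int):
        # Swap in the next timestep's image and label which step it is
        im.set_data(data[t])
        ax.set_title(f"{title} t={t}")
        return (im,)

    return FuncAnimation(fig, update, frames=data.shape[0], interval=200)


if __name__ == "__main__":
    frames = np.random.default_rng(0).random((12, 24, 24))
    anim = animate_timeseries(frames, title="satellite")
    anim.save("satellite.gif", writer=PillowWriter(fps=5))
```

Fixing `vmin`/`vmax` over the whole array means differences between frames reflect the data rather than per-frame colour rescaling.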
jacobbieker added the enhancement, good first issue, and help wanted labels on Oct 23, 2023
@jacobbieker
Member Author

@dfulu @peterdudfield does this sound right to you? I was thinking matplotlib plots by default, as they can be saved out to disk easily, or opened in streamlit with st.pyplot for the dashboard.

@peterdudfield
Contributor


This looks really great, I just have a few comments:

  • Have a function that takes in a file_location and makes this visualisation
  • I think it's good to put this in ocf_datapipes, but perhaps it doesn't have to be a datapipe itself; it could be more of a function that performs some actions

I would encourage using plotly, which can also be easily saved. I think it is just a better graphing library.

@dfulu
Member

dfulu commented Nov 7, 2023

Yeah, this sounds good. I had also been wondering whether it might generally be a good idea to save out batches in something like NetCDF. Do you think it would be slower to load, or larger on disk, to use a NetCDF file for each batch compared to a PyTorch tensor?

@jacobbieker
Member Author

It would be a bit slower to load, as you'd have to convert it to a PyTorch tensor before putting it into the model, but it would make the batches a lot easier to visualize, since we could mostly just call xarray's built-in plotting. I would probably lean towards saving them out as NetCDFs and doing the conversion on the fly. I don't think they would be much larger: they'd still carry the metadata, which might make some difference, but I think it should be fine.
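The "save as NetCDF, convert on the fly" approach could look roughly like this with xarray. This is a sketch under assumptions: the (step, y, x) layout, the variable and coordinate names, and both helper names are made up. The loader returns a NumPy array; calling `torch.from_numpy` on it would give the tensor for the model.

```python
import numpy as np
import xarray as xr


def batch_to_dataset(
    nwp: np.ndarray, lats: np.ndarray, lons: np.ndarray, steps: np.ndarray
) -> xr.Dataset:
    """Wrap a (step, y, x) array with the coordinates a bare tensor would lose."""
    return xr.Dataset(
        {"nwp": (("step", "y", "x"), nwp)},
        coords={"step": steps, "latitude": ("y", lats), "longitude": ("x", lons)},
    )


def load_for_model(path: str) -> np.ndarray:
    """Read a saved batch back and return the bare array for the model.

    Wrapping the result with ``torch.from_numpy`` gives the PyTorch tensor.
    """
    with xr.open_dataset(path) as ds:
        # .values forces the lazy read before the file is closed
        return ds["nwp"].values


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    ds = batch_to_dataset(
        rng.random((4, 3, 2)).astype("float32"),
        lats=np.array([50.0, 51.0, 52.0]),
        lons=np.array([-1.0, 0.0]),
        steps=np.arange(4) * 60,  # minutes after init time (assumed units)
    )
    ds.to_netcdf("batch_0.nc")
    print(load_for_model("batch_0.nc").shape)
```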

@jacobbieker
Member Author

@peterdudfield sounds good for having it just be a function. If we did move to saving NetCDF files to disk, I would probably stick with matplotlib, as that is what xarray uses for its built-in plotting methods, and it would reduce the work needed to do this.
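To show how little code the xarray route needs: once a batch is an `xr.DataArray`, its `.plot()` method draws with matplotlib and labels the axes from the coordinate names. A sketch assuming a (step, latitude, longitude) array; `plot_nwp_step` is a hypothetical helper.

```python
import matplotlib

matplotlib.use("Agg")  # headless backend
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr


def plot_nwp_step(da: xr.DataArray, step_idx: int = 0) -> plt.Figure:
    """Plot one forecast step of a (step, latitude, longitude) DataArray."""
    fig, ax = plt.subplots()
    # For 2-D data .plot() uses pcolormesh and takes axis labels from the coords
    da.isel(step=step_idx).plot(ax=ax)
    return fig


if __name__ == "__main__":
    da = xr.DataArray(
        np.random.default_rng(0).random((3, 4, 5)),
        dims=("step", "latitude", "longitude"),
        coords={
            "step": [0, 60, 120],
            "latitude": np.linspace(50, 53, 4),
            "longitude": np.linspace(-3, 1, 5),
        },
        name="nwp",
    )
    plot_nwp_step(da, step_idx=1).savefig("nwp_step_1.png")
```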

@reticent-roklimber

reticent-roklimber commented Apr 2, 2024

Hi, I am quite familiar with plotly and am currently working with weather data, handling visualisation and building ML models for flood inundation. I came across this issue while looking through issues as part of GSoC, and I am interested in contributing to it.

@reticent-roklimber

reticent-roklimber commented Apr 4, 2024

Hi, I ended up not applying for GSoC due to the time constraints at work, but I am still interested in contributing here. Can you let me know how to proceed?

@peterdudfield
Contributor

peterdudfield commented Jun 11, 2024

Here's my very small attempt, which takes a batch and spits out some sort of markdown file. It only handles wind and NWP, and is pretty delicate.

""" The idea is to visualize one of the batches """
from contextlib import redirect_stdout

import numpy as np
import pandas as pd
import plotly.graph_objects as go
import torch

from ocf_datapipes.batch import BatchKey, NumpyBatch, NWPBatchKey


def print_summary(key, value):
    """Print shape/max/min for tensors and time indexes, otherwise the raw value."""
    if isinstance(value, (torch.Tensor, pd.DatetimeIndex)):
        print(f"{key} {value.shape=}")
        print(f"Max {value.max()}")
        print(f"Min {value.min()}")
    else:
        print(f"{key} {value}")


def visualize_batch(batch: NumpyBatch):

    # Wind
    print('# Batch visualization')
    print('## Wind \n')
    keys = [
        BatchKey.wind,
        BatchKey.wind_t0_idx,
        BatchKey.wind_time_utc,
        BatchKey.wind_id,
        BatchKey.wind_observed_capacity_mwp,
        BatchKey.wind_nominal_capacity_mwp,
        BatchKey.wind_latitude,
        BatchKey.wind_longitude,
        BatchKey.wind_solar_azimuth,
        BatchKey.wind_solar_elevation,
    ]
    for key in keys:
        if key in batch:
            print('\n')
            print_summary(key, batch[key])

    # NWP
    print('## NWP \n')

    keys = [
        NWPBatchKey.nwp,
        NWPBatchKey.nwp_target_time_utc,
        NWPBatchKey.nwp_channel_names,
        NWPBatchKey.nwp_step,
        NWPBatchKey.nwp_t0_idx,
        NWPBatchKey.nwp_init_time_utc,
    ]

    # NWP data is nested: one sub-batch per provider
    nwp = batch[BatchKey.nwp]

    for provider, nwp_provider in nwp.items():
        print('\n')
        print(f"Provider {provider}")

        # plot nwp main data, averaged over the spatial dims (y, x)
        nwp_data = torch.as_tensor(nwp_provider[NWPBatchKey.nwp]).mean(dim=(3, 4))
        # target times of the first example, stored as seconds since the epoch
        time = pd.to_datetime(
            np.asarray(nwp_provider[NWPBatchKey.nwp_target_time_utc][0]), unit='s'
        )
        fig = go.Figure()
        for i, channel in enumerate(nwp_provider[NWPBatchKey.nwp_channel_names]):
            # first example in the batch, all timesteps, channel i
            nwp_data_one_channel = nwp_data[0, :, i].numpy()
            fig.add_trace(
                go.Scatter(x=time, y=nwp_data_one_channel, mode='lines', name=str(channel))
            )

        fig.update_layout(title=f'{provider} NWP', xaxis_title='Time', yaxis_title='Value')
        fig.show(renderer='browser')
        # write_image needs the kaleido package installed
        name = f'{provider}_nwp.png'
        fig.write_image(name)
        print(f'![]({name})')
        print('\n')

        for key in keys:
            if key not in nwp_provider:
                continue
            print('\n')
            value = nwp_provider[key]
            if 'time' in key.name:
                # convert the first example's epoch seconds to datetimes
                value = pd.to_datetime(np.asarray(value[0]), unit='s')
            print_summary(key, value)


if __name__ == '__main__':
    # Send everything printed to batch.md; stdout is restored afterwards
    with open('batch.md', 'w') as f, redirect_stdout(f):
        # map_location='cpu' in case the batch was saved from a GPU.
        # On torch >= 2.6 you may also need weights_only=False to unpickle the batch.
        d = torch.load("device_batch_0.pt", map_location='cpu')
        visualize_batch(d)
