Potential to track more points in live_demo.py? #71

Open
weirenorweiren opened this issue Dec 5, 2023 · 10 comments

Comments

@weirenorweiren

I am curious about the potential to run live_demo.py with a better GPU in order to track more points in real time.

Have you ever done any testing running it in the cloud? If not, what would be the bottleneck? One thing I can think of is the streaming delay between the local machine and the cloud, but I am not sure whether it's a big problem.

Maybe switching to or stacking better GPUs locally would be more straightforward?

The high-level question is: have you thought about ways to scale the model to track more points?

Although it might be hard to answer without any experiments, it's always good to have some discussion in advance! I appreciate any comments and feedback!

@cdoersch
Collaborator

cdoersch commented Dec 7, 2023

Indeed, it's hard to answer without more specifics on your goals. For RoboTAP we tracked 128 points on 480p videos on an RTX 3090 and ran the entire controller at about 10 fps. Tracking fewer points at lower resolution would make that faster.

Whether you can fit in network latency and get a speedup depends heavily on the network. Doing the above would require latency of about 5 ms, which is pretty unlikely unless your datacenter is quite close. If you're tracking thousands of points, though, maybe the latency is worth it.

@weirenorweiren
Author

Now I have a more specific goal and want to estimate what GPU would be required!

I want to track 20 points locally on real-time 1080p (1920x1080) video at 10 fps. According to #69, live_demo.py is capable of ~17 fps for 8 points on 480x480 images with a Quadro RTX 4000. I have the following three questions:

  1. I am wondering whether everything scales linearly with the required GPU compute, so that I would need three to four RTX 4090s to achieve my goal***?
  2. If not, is there a way to estimate which GPU or combination of GPUs would be required?
  3. If a combination of GPUs is needed, would performance scale linearly with the number of GPUs? Is there anything else I need to consider?

***My reasoning is as follows. According to https://gpu.userbenchmark.com/Compare/Nvidia-RTX-4090-vs-Nvidia-Quadro-RTX-4000/4136vsm716215, an RTX 4090 is roughly 4x a Quadro RTX 4000 in terms of effective speed. If everything scales linearly with the computational requirement, then the number of GPUs required for my goal, in units of Quadro RTX 4000, is (10/17) * (20/8) * (1920/480) * (1080/480) ~ 0.59 * 2.5 * 4 * 2.25 ~ 13. Thus I would need three to four RTX 4090s.

=====

Update: I tried to verify my hypothesis with the example you gave in the previous reply. If linear scaling holds, then the number of GPUs required for RoboTAP, in units of Quadro RTX 4000, is (10/17) * (128/8) * (854/480) * (480/480) ~ 0.59 * 16 * 1.78 * 1 ~ 17, which implies I would need about four RTX 4090s.
There is a mismatch with your experiment, but maybe the aforementioned RoboTAP setup is based on causal_bootstapir_checkpoint.npy instead of causal_tapir_checkpoint.npy. Does the mismatch result from the difference in backbone/model, or is my linear-scaling hypothesis wrong?
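
(For concreteness, here is the back-of-the-envelope arithmetic written out as a small Python sketch; the UserBenchmark speed ratio and the linear-scaling assumption are exactly the hypotheses I am asking about, not established facts.)

```python
# Linear-scaling hypothesis, in units of "Quadro RTX 4000 equivalents".
QUADRO_FPS, QUADRO_POINTS, QUADRO_RES = 17, 8, (480, 480)  # baseline from #69
RTX4090_PER_QUADRO = 4  # rough UserBenchmark "effective speed" ratio

def quadro_equivalents(fps, n_points, res):
  """GPU requirement in Quadro RTX 4000 units if everything scales linearly."""
  return (
      (fps / QUADRO_FPS)
      * (n_points / QUADRO_POINTS)
      * (res[0] * res[1]) / (QUADRO_RES[0] * QUADRO_RES[1])
  )

# My goal: 20 points on 1920x1080 at 10 fps -> ~13 Quadro units -> ~3.3 RTX 4090s.
print(quadro_equivalents(10, 20, (1920, 1080)) / RTX4090_PER_QUADRO)

# RoboTAP: 128 points on 854x480 at 10 fps -> ~17 Quadro units -> ~4 RTX 4090s,
# yet it actually ran on a single RTX 3090, so linear scaling looks pessimistic.
print(quadro_equivalents(10, 128, (854, 480)) / RTX4090_PER_QUADRO)
```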

@cdoersch
Collaborator

With BootsTAPIR we don't really see benefits for tracking at resolution higher than 512x512 (we haven't trained at such high resolutions, and the larger search space tends to result in false positives). The only reason this might change is if the textures aren't visible at 512x512, but in this case you might want to try with custom resolutions. Actually getting improvements will be an empirical effort. Have you verified that it actually works better at 1080p?

Parallelizing the live demo across two GPUs would not be trivial: you would need to shard the feature computation across space, and then shard the PIPs updates across the point axis, and even then the communication overhead might outweigh the computational speedup. We have internally made PartIR work with this codebase and I can give some guidance, but it's a bit painful.

However, an RTX 4090 is substantially more powerful than a Quadro RTX 4000 (a 5-year-old mobile GPU), so you probably don't need it. I would be very surprised if you can't do 20 points at 512x512 at well over 10 fps on that hardware. Scaling is probably somewhat better than linear: the PIPs updates dominated the computation in our 128-point setup. Is there anything preventing you from simply trying it on a single 4090?
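
If it helps, here is roughly what "simply trying it" could look like: a minimal throughput sketch, assuming the model and the online_init_apply / online_predict_apply functions are set up as in live_demo.py. The random frame, point count, resolution, and step count are just stand-ins for your real setup.

```python
import time

import jax
import jax.numpy as jnp
import numpy as np

NUM_POINTS = 20
RESOLUTION = 512  # start at 512x512, per the discussion above

# Stand-in for a camera frame.
frame = np.random.randint(0, 255, (RESOLUTION, RESOLUTION, 3), dtype=np.uint8)
query_points = jnp.zeros([NUM_POINTS, 3], dtype=jnp.float32)

query_features = online_init_apply(
    frames=model_utils.preprocess_frames(frame[None, None]),
    points=query_points[None, :],
)
causal_state = tapir.construct_initial_causal_state(
    NUM_POINTS, len(query_features.resolutions) - 1
)

# First call triggers compilation; exclude it from the timing.
prediction, causal_state = online_predict_apply(
    frames=model_utils.preprocess_frames(frame[None, None]),
    features=query_features,
    causal_context=causal_state,
)
jax.block_until_ready(prediction["tracks"])

n_steps = 100
start = time.time()
for _ in range(n_steps):
  prediction, causal_state = online_predict_apply(
      frames=model_utils.preprocess_frames(frame[None, None]),
      features=query_features,
      causal_context=causal_state,
  )
jax.block_until_ready(prediction["tracks"])
print(f"{n_steps / (time.time() - start):.1f} frames per second")
```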

@weirenorweiren
Author

weirenorweiren commented Jul 4, 2024

I really appreciate your prompt response; let me elaborate on the current situation. I want to use tapnet for real-time tracking of several identical-looking black spheres under a clear (i.e., white) or patterned (i.e., alternating black and white) background with a scientific CMOS camera. I have tested some recorded videos through your Colab demo and it worked decently. Now I am working towards running live_demo.py with the CMOS camera on a PC. [vc = cv2.VideoCapture() in live_demo.py doesn't work for my CMOS camera, so I am looking into other packages for camera access. By the way, if you happen to have any experience with accessing a CMOS camera, please let me know!]
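
(In case it clarifies what I mean, this is the sort of fallback I have been trying: forcing different OpenCV capture backends before moving to the vendor SDK. The device index and the backend list are guesses for my setup.)

```python
import cv2

# Try explicit capture backends on Windows; many scientific cameras are only
# exposed through the vendor SDK, in which case none of these will open.
for backend in (cv2.CAP_DSHOW, cv2.CAP_MSMF, cv2.CAP_ANY):
  vc = cv2.VideoCapture(0, backend)  # device index 0 is a guess
  if vc.isOpened():
    ok, frame = vc.read()
    if ok:
      print(f"Backend {backend} works, frame shape: {frame.shape}")
      break
    vc.release()
```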

=====

The 1080p resolution is chosen based on the CMOS camera resolution and the field of view needed for my application. I haven't tried different resolutions yet, since I am still working on the camera issue above and figuring out the right GPU to use or buy.

With BootsTAPIR we don't really see benefits for tracking at resolution higher than 512x512 (we haven't trained at such high resolutions, and the larger search space tends to result in false positives).

Regarding the above statement, I have a few follow-ups:

  1. Are you saying that you haven't trained at resolutions of 512x512 and above, but BootsTAPIR would still work at those resolutions?
  2. Each frame is resized to 256x256 in the Colab demo for real-time tracking, but no resizing is performed in live_demo.py. Are there some considerations here? Should I resize or not for my application?
  3. Based on 2, I am wondering whether there are lower and upper bounds on the real-time video resolution for BootsTAPIR to function properly?

=====

Regarding the GPU, I currently have an RTX 2080, and that would be the first GPU for me to try after I solve the camera issue. I was trying to estimate the GPU requirement for my case and the difficulty of stacking GPUs; I would then adjust the number of points and the input resolution accordingly, since I have a limited budget for hardware and zero experience with GPU parallelization. After reading your previous reply, it seems I won't think about parallelization and will set my computational upper limit at a single RTX 4090. With your better-than-linear scaling guess and my previous hypothesis, quoted below:

My reasoning is as follows. According to https://gpu.userbenchmark.com/Compare/Nvidia-RTX-4090-vs-Nvidia-Quadro-RTX-4000/4136vsm716215, an RTX 4090 is roughly 4x a Quadro RTX 4000 in terms of effective speed. If everything scales linearly with the computational requirement, then the number of GPUs required for my goal, in units of Quadro RTX 4000, is (10/17) * (20/8) * (1920/480) * (1080/480) ~ 0.59 * 2.5 * 4 * 2.25 ~ 13. Thus I would need three to four RTX 4090s.

Do you think it's a conservative estimate that a single RTX 4090 should manage tracking 10 points at 10 fps on real-time 1080p video in a local setup?

=====

The last question is about scaling.

For RoboTAP we tracked 128 points on 480p videos on an RTX 3090 and ran the entire controller at about 10 fps.

I tried to verify my hypothesis with the example you gave in the line above. If linear scaling holds, then the number of GPUs required for RoboTAP, in units of Quadro RTX 4000, is (10/17) * (128/8) * (854/480) * (480/480) ~ 0.59 * 16 * 1.78 * 1 ~ 17, which implies I would need about four RTX 4090s.
There is a mismatch with your experiment, but maybe the aforementioned RoboTAP setup is based on causal_bootstapir_checkpoint.npy instead of causal_tapir_checkpoint.npy. Could you clarify whether you used BootsTAPIR or TAPIR for the above scenario? If it was BootsTAPIR, does that imply even better scaling than the TAPIR used in the live_demo.py testing?

@cdoersch
Collaborator

cdoersch commented Jul 9, 2024

Sorry for the slow response, there's a lot here.

Are you saying that you haven't trained at resolutions of 512x512 and above, but BootsTAPIR would still work at those resolutions?

Yes. It's a convolutional architecture, so it can run at a different resolution than what it was trained on, but whether it generalizes depends on boundary effects. Empirically generalization to higher resolution seems to be OK, but it doesn't really improve things: the boundary effects seem to cancel out the extra benefits from higher resolution.

Each frame is resized to 256x256 in the Colab demo for real-time tracking, but no resizing is performed in live_demo.py. Are there some considerations here? Should I resize or not for my application?

Again, this is an empirical question. live_demo.py hasn't been re-tested with bootstapir, and was really just the result of an afternoon of me playing with my webcam. The closest experiments we ran to your suggested setup are with the current bootstapir non-causal checkpoint, where we found the best performance at 512x512, which is the highest resolution that bootstapir was trained on, and higher resolutions performed about the same. I expect the same would be true for the causal checkpoint, but I haven't tested. I would encourage you to just record some data, dump it to mp4, and run causal tapir on whatever data you have. You can then see if it works well enough even if it doesn't work in real time.
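
If you do try resizing, it is just a preprocessing step before the model call. A rough sketch (the cv2.resize choice and the placeholder frame are mine; the only extra bookkeeping is mapping the returned tracks back to camera coordinates):

```python
import cv2
import numpy as np

TARGET = 512  # the resolution that worked best in our non-causal experiments

frame = np.zeros((1080, 1920, 3), dtype=np.uint8)  # stand-in for a camera frame

# Downscale before feeding the model; note cv2.resize takes (width, height).
frame_small = cv2.resize(frame, (TARGET, TARGET), interpolation=cv2.INTER_AREA)

# ...run online_init_apply / online_predict_apply on frame_small as usual...

# Tracks come back in the resized frame's pixel coordinates, so map them back.
scale_x = frame.shape[1] / TARGET
scale_y = frame.shape[0] / TARGET
# x_full, y_full = track[i, 0] * scale_x, track[i, 1] * scale_y
```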

Based on 2, I am wondering whether there are lower and upper bounds on the real-time video resolution for BootsTAPIR to function properly?

Lower bound is 256x256; I can't imagine it would work well at lower resolution. Whether it works best at 512x512 or higher resolution depends on the application.

Could you clarify whether you used BootsTAPIR or TAPIR for the above scenario?

It was TAPIR for RoboTAP. However, for 20 points I would expect the difference wouldn't be large. Unfortunately, I can't guess what the performance would be. The RTX 4090 has about 9x the peak FLOPS of the 2080, so that's an upper bound; however, a non-trivial amount of time in TAPIR is spent on gathers. But I wouldn't be surprised if you see 5x the framerate on a 4090 vs a 2080. What framerate do you get with your desired setup on the 2080?

TBH, the right answer may be to just find a cloud provider that lets you rent a 4090 so you can get an actual benchmark.

@weirenorweiren
Author

Hi Carl, thanks again for your detailed responses! They are always clear and insightful. Please see below for my follow-ups.

I would encourage you to just record some data, dump it to mp4, and run causal tapir on whatever data you have.

I have already tested some recorded videos in causal_tapir_demo.ipynb. The performance looks good, so I am now trying to run live_demo.py locally to verify the performance of TAPIR/BootsTAPIR in a real-time setting.

Currently there are two issues. The first is reading RGB data from our CMOS camera through Python, since it is not accessible with vc = cv2.VideoCapture(). I have made some progress and can output something; however, it doesn't look like what the camera software shows. Maybe there is a format issue, but overall I think I can fix it soon. The other issue is running JAX with a GPU on Windows. I tried to check the tracking performance of live_demo.py on Sunday with an RTX 4080 on a gaming laptop, but I failed to run the JAX-based demo on the GPU after several hours of attempts. I checked #25, but there are other problems, as mentioned in #49.
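
(As a side note, a minimal sanity check for whether JAX can see the GPU at all, independent of tapnet:)

```python
import jax

# A working CUDA install reports "gpu" here; "cpu" means JAX never picked up
# the GPU, regardless of what the demo itself does.
print(jax.default_backend())
print(jax.devices())
```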

Considering that there are installation and compatibility issues for JAX on Windows, would it be possible to provide a live demo in PyTorch, if it doesn't cost too much, as requested in #102? It would be super helpful for non-CS research teams who use Windows as their primary system to explore your work and check the potential for collaboration. I am from one of those teams in applied physics and am currently working on a demo for our setting (as quoted below) to persuade my PI to pursue an official collaboration.

I want to use tapnet for real-time tracking of several identical-looking black spheres under a clear (i.e., white) or patterned (i.e., alternating black and white) background with a scientific CMOS camera.

Our project could use not only TAPIR/BootsTAP for real-time tracking, which is only a starting point, but also RoboTAP for subsequent applications. Thus, I would really appreciate it if #102 could be addressed to help accelerate the implementation of your work and a potential collaboration!

@cdoersch
Collaborator

Checking #49 it seems the issue is that import tapnet is failing? We now support pip (although I haven't done a lot of testing), so it's possible that you can get around that error just by installing via pip. Are you still seeing AttributeError: module 'jax.numpy' has no attribute 'DeviceArray'?

@weirenorweiren
Author

weirenorweiren commented Aug 28, 2024

Hi Carl, I have some updates!

I studied the code base and managed to run live_demo.py on Windows with our CMOS camera. I followed the new installation guidance and didn't encounter the problems above.

Now I have some questions about live_demo.py:

  1. For jax.block_until_ready(query_features), would jax.block_until_ready(_) make more sense, so that we wait for compilation to finish before calling online_init_apply again? What is the purpose of jax.block_until_ready(prediction["tracks"])?

    tapnet/tapnet/live_demo.py

    Lines 121 to 142 in 58c5225

    query_points = jnp.zeros([NUM_POINTS, 3], dtype=jnp.float32)

    _ = online_init_apply(
        frames=model_utils.preprocess_frames(frame[None, None]),
        points=query_points[None, 0:1],
    )
    jax.block_until_ready(query_features)

    query_features = online_init_apply(
        frames=model_utils.preprocess_frames(frame[None, None]),
        points=query_points[None, :],
    )
    causal_state = tapir.construct_initial_causal_state(
        NUM_POINTS, len(query_features.resolutions) - 1
    )

    prediction, causal_state = online_predict_apply(
        frames=model_utils.preprocess_frames(frame[None, None]),
        features=query_features,
        causal_context=causal_state,
    )
    jax.block_until_ready(prediction["tracks"])
  2. I read tapir.update_query_features and it has the option causal_state=None. Do you think it would be a good idea to pass causal_state=None to save some computation, since causal_state has already been initialized earlier? Or does it not matter at all?

    tapnet/tapnet/live_demo.py

    Lines 179 to 184 in 58c5225

    query_features, causal_state = tapir.update_query_features(
        query_features=query_features,
        new_query_features=init_query_features,
        idx_to_update=np.array([next_query_idx]),
        causal_state=causal_state,
    )

Appreciate any comments and thanks in advance!

@cdoersch
Collaborator

Great to hear that it's working on Windows!

Regarding 1, yes, that block_until_ready must be a remnant from an earlier version of the code. query_features is None at that point, so I guess it's accomplishing nothing. Blocking on _ might make sense though, since it will force JAX to compile the online_init code.
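
Concretely, something like this is all I have in mind (same names as in live_demo.py; only the argument of the first block_until_ready changes):

```python
# Warm-up/compile pass for the single-point init path, blocking on its own
# output rather than on query_features, which hasn't been assigned yet.
_ = online_init_apply(
    frames=model_utils.preprocess_frames(frame[None, None]),
    points=query_points[None, 0:1],
)
jax.block_until_ready(_)
```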

Regarding 2, setting causal_state to None might save a tiny bit of computation the first time the function is called, but the result is that causal_state would have a different size every time you add another point (up to 8 points). That means online_model_predict would need to be recompiled every time you click a new point, which can take minutes.
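
For contrast, the pattern live_demo.py already uses keeps every shape fixed, which is what avoids those recompiles (a sketch using the same names as in the script):

```python
# Allocate causal state for all NUM_POINTS slots once, up front. Every leaf
# keeps the same shape for the whole session, so the jitted predict function
# is traced and compiled only once.
causal_state = tapir.construct_initial_causal_state(
    NUM_POINTS, len(query_features.resolutions) - 1
)

# When a point is clicked, only the slot next_query_idx is overwritten;
# no shapes change, so no recompilation is triggered.
query_features, causal_state = tapir.update_query_features(
    query_features=query_features,
    new_query_features=init_query_features,
    idx_to_update=np.array([next_query_idx]),
    causal_state=causal_state,
)
```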

@weirenorweiren
Author

weirenorweiren commented Sep 1, 2024

Following up on 1, I would like to clarify: the first use of jax.block_until_ready makes sense if we block on _, and the second use, jax.block_until_ready(prediction["tracks"]), is unnecessary. Is that right? Basically, I want to get rid of any unnecessary steps, since I am trying to save as much computation as I can.

Following up on 2, I am confused about why causal_state would have a different size and online_model_predict would need recompiling after each click. Based on my understanding, the size of causal_state is fixed after the initial construction at the very beginning. Passing causal_state=causal_state just overwrites that initial state for the newest click with another fresh initialization, which seems redundant to me.

tapnet/tapnet/live_demo.py

Lines 169 to 218 in a9ef766

while rval:
  rval, frame = get_frame(vc)
  if query_frame:
    query_points = jnp.array((0,) + pos, dtype=jnp.float32)
    init_query_features = online_init_apply(
        frames=model_utils.preprocess_frames(frame[None, None]),
        points=query_points[None, None],
    )
    query_frame = False
    query_features, causal_state = tapir.update_query_features(
        query_features=query_features,
        new_query_features=init_query_features,
        idx_to_update=np.array([next_query_idx]),
        causal_state=causal_state,
    )
    have_point[next_query_idx] = True
    next_query_idx = (next_query_idx + 1) % NUM_POINTS
  if pos:
    prediction, causal_state = online_predict_apply(
        frames=model_utils.preprocess_frames(frame[None, None]),
        features=query_features,
        causal_context=causal_state,
    )
    track = prediction["tracks"][0, :, 0]
    occlusion = prediction["occlusion"][0, :, 0]
    expected_dist = prediction["expected_dist"][0, :, 0]
    visibles = model_utils.postprocess_occlusions(occlusion, expected_dist)
    track = np.round(track)
    for i, _ in enumerate(have_point):
      if visibles[i] and have_point[i]:
        cv2.circle(
            frame, (int(track[i, 0]), int(track[i, 1])), 5, (255, 0, 0), -1
        )
        if track[i, 0] < 16 and track[i, 1] < 16:
          print((i, next_query_idx))
  cv2.imshow("Point Tracking", frame[:, ::-1])
  if pos:
    step_counter += 1
    if time.time() - t > 5:
      print(f"{step_counter/(time.time()-t)} frames per second")
      t = time.time()
      step_counter = 0
  else:
    t = time.time()
  key = cv2.waitKey(1)
  if key == 27:  # exit on ESC
    break

And the update_query_features implementation referenced above:

def update_query_features(
    self, query_features, new_query_features, idx_to_update, causal_state=None
):
  if isinstance(idx_to_update, int):
    idx_to_update = tuple([idx_to_update])
  idx_to_update = np.array(idx_to_update)

  def apply_update_idx(s1, s2):
    return s1.at[:, idx_to_update].set(s2)

  query_features = QueryFeatures(
      lowres=jax.tree_util.tree_map(
          apply_update_idx, query_features.lowres, new_query_features.lowres
      ),
      hires=jax.tree_util.tree_map(
          apply_update_idx, query_features.hires, new_query_features.hires
      ),
      resolutions=query_features.resolutions,
  )

  if causal_state is not None:
    init_causal_state = self.construct_initial_causal_state(
        len(idx_to_update), len(query_features.resolutions) - 1
    )
    causal_state = jax.tree_util.tree_map(
        apply_update_idx, causal_state, init_causal_state
    )
    return query_features, causal_state

  return query_features

I have three more questions:

  1. In causal_tapir_demo.ipynb, it mentions averaging across resolutions when running at higher resolution. Could you elaborate on how to do that? Or is there already an implementation somewhere in the code base?
    " # Take only the predictions for the final resolution.\n",
    " # For running on higher resolution, it's typically better to average across\n",
    " # resolutions.\n",
  2. My next step is to test the working code on CPU under our experimental setup. However, our camera has a resolution of 2048x2448, and our computer can only manage about 0.1 fps if I don't reduce the resolution; 1 fps would be preferable for this test. Do you have any suggestions on how to choose the resolution so as to balance frame rate and tracking performance? For extra context, I got about 1 fps at 360x360 when I tested with a webcam.
  3. After that test is done, I will start (actually, retry) running the code on the GPU. Since I have no experience (actually, some terrible experiences) with this process, I would like to check with you whether my plan is on the right track. I am thinking of first following jax requires jaxlib #25 (comment) by installing a wheel; if that doesn't work, then trying jax requires jaxlib #25 (comment) for WSL2. If you have extra information, I am always happy to hear it!
