Wan Animate Pipeline #367

Open
csgoogle wants to merge 1 commit into main from sagarchapara/wananimate-pipeline

Conversation


@csgoogle csgoogle commented Mar 28, 2026

Wan Animate Pipeline

This CL adds the Wan Animate pipeline.

  • Reused the existing Wan attention operator for face encoder cross attention.
  • Swept Flash Attention block-size configurations to identify the best inference setting.

Links

Performance

  • compile_time: 292.74
  • generation_time: 157.69

Configuration

  • cp: 8 (v6e8)
  • cfg: 1.0
  • prev_segments: 5
  • resolution: 1280x720
  • fps: 24
  • generated_frames: 77


@csgoogle csgoogle marked this pull request as ready for review April 6, 2026 16:33
@csgoogle csgoogle requested a review from entrpn as a code owner April 6, 2026 16:33
@csgoogle csgoogle force-pushed the sagarchapara/wananimate-pipeline branch 2 times, most recently from e281524 to 349d080 Compare April 13, 2026 09:10
Comment thread src/maxdiffusion/schedulers/scheduling_unipc_multistep_flax.py
Comment thread src/maxdiffusion/schedulers/scheduling_unipc_multistep_flax.py
Comment thread assets/wan_animate/src_face.mp4 Outdated
Comment thread src/maxdiffusion/pipelines/wan/wan_pipeline_animate.py
Comment thread src/maxdiffusion/schedulers/scheduling_unipc_multistep_flax.py
Comment thread src/maxdiffusion/generate_wan_animate.py
@Perseus14
Collaborator

Please resolve the conflicts and enable support for diagnostics and profiling, as in this PR.

Comment thread .gitignore Outdated
Comment thread src/maxdiffusion/generate_wan_animate.py
Comment thread src/maxdiffusion/configs/base_wan_animate.yml

github-actions Bot commented May 7, 2026

🤖 Hi @Perseus14, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.


@github-actions github-actions Bot left a comment


📋 Review Summary

The Pull Request introduces the Wan Animate pipeline, which includes the transformer model architecture, inference entry point, and necessary utilities. The implementation is comprehensive and follows the established patterns in the repository, including support for segment-based inference and parity with Diffusers.

🔍 General Feedback

  • Performance Optimization: The current implementation of the transformer re-encodes the face video frames during every denoising step. Since the face video is static throughout the inference process, this encoding can be pre-computed once per segment to significantly reduce redundant computation and speed up generation.
  • Compilation Efficiency: The generation script performs two full inference passes. For high-resolution video generation, this double work is expensive. Consider reducing the number of steps in the first (compile) pass.
  • Robustness: Add checks for optional inputs in the transformer to prevent runtime errors when face_pixel_values is not provided.
  • Code Quality: The reuse of the Wan attention operator and the integration with the existing configuration system is well-done. The use of nnx.scan for transformer blocks ensures memory efficiency during inference.
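The first point above — hoisting a loop-invariant encoding out of the denoising loop — can be sketched as follows. This is a minimal illustration, not the pipeline's actual code: `encode_face`, `denoise_step`, and `generate_segment` are hypothetical names, and NumPy stands in for `jnp`.

```python
import numpy as np

def encode_face(face_frames):
    # Hypothetical stand-in for the face encoder; its output is
    # loop-invariant across denoising steps within a segment.
    return face_frames.mean(axis=0)

def denoise_step(latents, face_embed, step):
    # Hypothetical denoising step that consumes the cached embedding.
    return latents - 0.1 * step * face_embed

def generate_segment(latents, face_frames, num_steps=4):
    # Encode the face video once per segment, outside the loop,
    # instead of re-encoding it on every denoising step.
    face_embed = encode_face(face_frames)
    for step in range(num_steps):
        latents = denoise_step(latents, face_embed, step)
    return latents
```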

Comment thread src/maxdiffusion/generate_wan_animate.py
Comment thread src/maxdiffusion/models/wan/transformers/transformer_wan_animate.py
@csgoogle csgoogle force-pushed the sagarchapara/wananimate-pipeline branch 3 times, most recently from ef88d04 to f6b4c22 Compare May 11, 2026 12:22

🤖 I'm sorry @Perseus14, but I was unable to process your request. Please see the logs for more details.


@github-actions github-actions Bot left a comment


📋 Review Summary

This PR successfully implements the Wan Animate pipeline in MaxDiffusion, including the 3D transformer architecture, motion and face encoders, and the inference pipeline with temporal tiling. The implementation is robust, follows established patterns for sharding and JIT optimization, and is backed by comprehensive parity tests.

🔍 General Feedback

  • Efficiency: The use of pre-computed motion vectors once per segment and the implementation of scan_layers show good attention to performance on TPU/GPU hardware.
  • Testing: The parity tests are exhaustive and provide high confidence in the implementation's correctness relative to the reference Diffusers implementation.
  • Modularity: Reusing the Wan attention operator for face conditioning is a clean approach.
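The temporal tiling mentioned above processes a long video as a sequence of segments, with each segment reusing a few frames of the previous one as conditioning context. A minimal sketch of how segment boundaries might be computed — `seg_len` and `overlap` are illustrative parameters, not the pipeline's actual configuration keys:

```python
def split_segments(num_frames, seg_len, overlap):
    # Partition frame indices into segments of up to seg_len frames, where
    # each segment reuses `overlap` trailing frames of the previous segment
    # as conditioning context. Returns (start, end) index pairs.
    stride = seg_len - overlap
    starts = range(0, max(num_frames - overlap, 1), stride)
    return [(s, min(s + seg_len, num_frames)) for s in starts]
```

For example, 10 frames with 4-frame segments and a 1-frame overlap yields segments (0, 4), (3, 7), (6, 10), covering every frame exactly once apart from the shared boundary frames.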


@github-actions github-actions Bot left a comment


📋 Review Summary

Additional technical suggestions for the Wan Animate transformer implementation.

🔍 General Feedback

  • Optimization: Minor optimizations in MotionConv2d to avoid repeated array conversions.
  • Robustness: Defensive check in WanAnimateFaceBlockCrossAttention for sequence length consistency.
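The defensive check suggested for the face cross-attention can be illustrated as follows. This is a plain-NumPy sketch with hypothetical names (`face_cross_attention`, `hidden_states`, `face_kv`), not the WanAnimateFaceBlockCrossAttention module itself:

```python
import numpy as np

def face_cross_attention(hidden_states, face_kv):
    # Defensive consistency check before attention: the feature dimension of
    # the face key/value sequence must match the query states; otherwise the
    # matmul below fails with an opaque shape error deep inside the kernel.
    if face_kv.shape[-1] != hidden_states.shape[-1]:
        raise ValueError(
            f"face_kv feature dim {face_kv.shape[-1]} does not match "
            f"hidden_states feature dim {hidden_states.shape[-1]}")
    scale = 1.0 / np.sqrt(hidden_states.shape[-1])
    scores = hidden_states @ face_kv.T * scale          # (q_len, kv_len)
    weights = np.exp(scores - scores.max(-1, keepdims=True))
    weights = weights / weights.sum(-1, keepdims=True)  # softmax over kv_len
    return weights @ face_kv                            # (q_len, dim)
```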


@github-actions github-actions Bot left a comment


📋 Review Summary

The Pull Request introduces the Wan Animate pipeline and its associated 3D transformer model to the maxdiffusion library. The implementation is comprehensive, including checkpointing support, sharding-aware attention, and extensive parity tests against the reference implementation. The code follows the established architectural patterns of the project and integrates well with existing Wan and NNX utilities.

🔍 General Feedback

  • Parity Testing: Excellent inclusion of detailed parity tests (wan_animate_module_parity_test.py) which ensures the JAX/Flax implementation matches the reference torch model.
  • NNX Migration: The use of flax.nnx for the new models is consistent with the project's direction.
  • Robustness: Some minor improvements suggested for robustness (e.g., division by zero checks) and consistency in parameter access within nnx modules.
  • Documentation: The docstrings are informative and follow the project's style.
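The parity tests mentioned above compare module outputs between the JAX/Flax implementation and the reference torch model. The core assertion pattern can be sketched like this — the tolerances and the helper name are illustrative, not those used in wan_animate_module_parity_test.py:

```python
import numpy as np

def assert_parity(flax_out, reference_out, rtol=1e-4, atol=1e-5):
    # Convert both outputs to NumPy and compare elementwise; raises
    # AssertionError with a mismatch report if they diverge beyond tolerance.
    np.testing.assert_allclose(
        np.asarray(flax_out), np.asarray(reference_out), rtol=rtol, atol=atol)
```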

Comment thread src/maxdiffusion/models/wan/transformers/transformer_wan_animate.py
Comment thread src/maxdiffusion/models/wan/transformers/transformer_wan_animate.py
Comment thread src/maxdiffusion/generate_wan_animate.py
Comment thread src/maxdiffusion/models/wan/transformers/transformer_wan_animate.py
Comment thread src/maxdiffusion/generate_wan_animate.py
@csgoogle csgoogle force-pushed the sagarchapara/wananimate-pipeline branch 2 times, most recently from b2ea208 to a2d4356 Compare May 12, 2026 09:52
- Add WanAnimateTransformer3DModel with motion encoder, face encoder,
  and face adapter cross-attention blocks
- Add WanAnimatePipeline supporting animate and replace modes with
  multi-segment temporal stitching
- Add generate_wan_animate.py inference entrypoint
- Add base_wan_animate.yml config for 720p inference
- Pre-compute face motion vectors once per segment instead of every
  denoising step for faster inference
- Simplify face block cross-attention forward pass: replace einops
  with jnp.reshape, remove redundant sharding constraints
- Add parity tests for animate modules and diffusers comparison
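The einops-to-reshape simplification in the commit message can be illustrated with the standard head-splitting pattern. A sketch using NumPy in place of `jnp` (the array APIs are identical here); `split_heads` is a hypothetical name:

```python
import numpy as np

def split_heads(x, num_heads):
    # Plain reshape/transpose equivalent of
    # einops.rearrange(x, "b s (h d) -> b h s d", h=num_heads).
    b, s, dim = x.shape
    head_dim = dim // num_heads
    return x.reshape(b, s, num_heads, head_dim).transpose(0, 2, 1, 3)
```

Dropping the einops dependency in the forward pass keeps the hot path to primitive ops that XLA fuses directly.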
@csgoogle csgoogle force-pushed the sagarchapara/wananimate-pipeline branch from a2d4356 to eddfd4d Compare May 12, 2026 09:55
@Perseus14
Collaborator

LGTM!

The failing test case is unrelated to this change.
