Skip to content

Support dreamzero model & Add Numerical Alignment Workflow for skill#1142

Open
helloyongyang wants to merge 1 commit into
mainfrom
dreamzero
Open

Support dreamzero model & Add Numerical Alignment Workflow for skill#1142
helloyongyang wants to merge 1 commit into
mainfrom
dreamzero

Conversation

@helloyongyang

Copy link
Copy Markdown
Contributor

No description provided.

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request adds native support for the DreamZero model, introducing its model structure, inference pipelines, weights loading, a flow scheduler, and a runner for DROID image-to-video-action (I2VA) tasks. The review feedback identifies critical runtime issues, including the use of the non-existent torch.linalg.solve_ex function in the scheduler, a key-filtering bug in the checkpoint loading logic that incorrectly discards valid weights, and a potential crash in the runner when processing state directories that lack the expected files.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

if order == 2:
rhos_p = torch.full((1,), 0.5, dtype=x.dtype, device=self.sigmas.device)
else:
rhos_p = torch.linalg.solve_ex(matrix_r[:-1, :-1], vector_b[:-1])[0].to(x.dtype)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

torch.linalg.solve_ex is not a valid PyTorch function and will raise an AttributeError at runtime. Use the standard torch.linalg.solve function instead, which directly returns the solution tensor.

Suggested change
rhos_p = torch.linalg.solve_ex(matrix_r[:-1, :-1], vector_b[:-1])[0].to(x.dtype)
rhos_p = torch.linalg.solve(matrix_r[:-1, :-1], vector_b[:-1]).to(x.dtype)

if order == 1:
rhos_c = torch.full((1,), 0.5, dtype=x.dtype, device=self.sigmas.device)
else:
rhos_c = torch.linalg.solve_ex(matrix_r, vector_b)[0].to(x.dtype)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

torch.linalg.solve_ex is not a valid PyTorch function and will raise an AttributeError at runtime. Use the standard torch.linalg.solve function instead, which directly returns the solution tensor.

Suggested change
rhos_c = torch.linalg.solve_ex(matrix_r, vector_b)[0].to(x.dtype)
rhos_c = torch.linalg.solve(matrix_r, vector_b).to(x.dtype)

Comment on lines +43 to +48
def _keep_dreamzero_key(key):
if key.startswith("action_head.model."):
return True
if key.startswith(("blocks.", "patch_embedding.", "text_embedding.", "time_embedding.", "time_projection.", "img_emb.", "head.", "action_encoder.", "state_encoder.", "action_decoder.")):
return True
return False

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The _keep_dreamzero_key method checks the original key before stripping its prefix. This causes all keys starting with "model.diffusion_model." or "diffusion_model." to be filtered out and not loaded, as they do not match the allowed prefixes in the tuple. We should strip the prefix first before checking if the key should be kept.

Suggested change
def _keep_dreamzero_key(key):
if key.startswith("action_head.model."):
return True
if key.startswith(("blocks.", "patch_embedding.", "text_embedding.", "time_embedding.", "time_projection.", "img_emb.", "head.", "action_encoder.", "state_encoder.", "action_decoder.")):
return True
return False
def _keep_dreamzero_key(key):
for prefix in ("action_head.model.", "model.diffusion_model.", "diffusion_model."):
if key.startswith(prefix):
key = key[len(prefix) :]
if key.startswith(("blocks.", "patch_embedding.", "text_embedding.", "time_embedding.", "time_projection.", "img_emb.", "head.", "action_encoder.", "state_encoder.", "action_decoder.")):
return True
return False

Comment on lines +256 to +261
if os.path.isdir(state_path):
for name in ("state.json", "state.npy"):
candidate = os.path.join(state_path, name)
if os.path.exists(candidate):
state_path = candidate
break

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

If state_path is a directory but does not contain state.json or state.npy, the loop completes without finding a file, and state_path remains a directory. This will cause np.loadtxt to crash with an IsADirectoryError or PermissionError. We should raise a clear FileNotFoundError if none of the expected files are found in the directory.

Suggested change
if os.path.isdir(state_path):
for name in ("state.json", "state.npy"):
candidate = os.path.join(state_path, name)
if os.path.exists(candidate):
state_path = candidate
break
if os.path.isdir(state_path):
for name in ("state.json", "state.npy"):
candidate = os.path.join(state_path, name)
if os.path.exists(candidate):
state_path = candidate
break
else:
raise FileNotFoundError(f"Could not find state.json or state.npy in directory: {state_path}")

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant