Skip to content

Commit 24304f9

Browse files
committed
add ucache
1 parent d939f6e commit 24304f9

File tree

5 files changed

+530
-15
lines changed

5 files changed

+530
-15
lines changed

examples/cli/README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,4 +125,5 @@ Generation Options:
125125
--high-noise-skip-layers (high noise) layers to skip for SLG steps (default: [7,8,9])
126126
-r, --ref-image reference image for Flux Kontext models (can be used multiple times)
127127
--easycache enable EasyCache for DiT models with optional "threshold,start_percent,end_percent" (default: 0.2,0.15,0.95)
128+
--ucache enable UCache for UNET models with optional "threshold,start_percent,end_percent" (default: 1,0.15,0.95)
128129
```

examples/cli/main.cpp

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1064,6 +1064,9 @@ struct SDGenerationParams {
10641064
std::string easycache_option;
10651065
sd_easycache_params_t easycache_params;
10661066

1067+
std::string ucache_option;
1068+
sd_ucache_params_t ucache_params;
1069+
10671070
float moe_boundary = 0.875f;
10681071
int video_frames = 1;
10691072
int fps = 16;
@@ -1414,6 +1417,38 @@ struct SDGenerationParams {
14141417
return consumed;
14151418
};
14161419

1420+
auto on_ucache_arg = [&](int argc, const char** argv, int index) {
1421+
const std::string default_values = "1.0,0.15,0.95";
1422+
auto looks_like_value = [](const std::string& token) {
1423+
if (token.empty()) {
1424+
return false;
1425+
}
1426+
if (token[0] != '-') {
1427+
return true;
1428+
}
1429+
if (token.size() == 1) {
1430+
return false;
1431+
}
1432+
unsigned char next = static_cast<unsigned char>(token[1]);
1433+
return std::isdigit(next) || token[1] == '.';
1434+
};
1435+
1436+
std::string option_value;
1437+
int consumed = 0;
1438+
if (index + 1 < argc) {
1439+
std::string next_arg = argv[index + 1];
1440+
if (looks_like_value(next_arg)) {
1441+
option_value = argv_to_utf8(index + 1, argv);
1442+
consumed = 1;
1443+
}
1444+
}
1445+
if (option_value.empty()) {
1446+
option_value = default_values;
1447+
}
1448+
ucache_option = option_value;
1449+
return consumed;
1450+
};
1451+
14171452
options.manual_options = {
14181453
{"-s",
14191454
"--seed",
@@ -1449,6 +1484,10 @@ struct SDGenerationParams {
14491484
"--easycache",
14501485
"enable EasyCache for DiT models with optional \"threshold,start_percent,end_percent\" (default: 0.2,0.15,0.95)",
14511486
on_easycache_arg},
1487+
{"",
1488+
"--ucache",
1489+
"enable UCache for UNET models (SD1.x/SD2.x/SDXL) with optional \"threshold,start_percent,end_percent\" (default: 1.0,0.15,0.95)",
1490+
on_ucache_arg},
14521491

14531492
};
14541493

@@ -1614,6 +1653,59 @@ struct SDGenerationParams {
16141653
easycache_params.enabled = false;
16151654
}
16161655

1656+
if (!ucache_option.empty()) {
1657+
float values[3] = {0.0f, 0.0f, 0.0f};
1658+
std::stringstream ss(ucache_option);
1659+
std::string token;
1660+
int idx = 0;
1661+
while (std::getline(ss, token, ',')) {
1662+
auto trim = [](std::string& s) {
1663+
const char* whitespace = " \t\r\n";
1664+
auto start = s.find_first_not_of(whitespace);
1665+
if (start == std::string::npos) {
1666+
s.clear();
1667+
return;
1668+
}
1669+
auto end = s.find_last_not_of(whitespace);
1670+
s = s.substr(start, end - start + 1);
1671+
};
1672+
trim(token);
1673+
if (token.empty()) {
1674+
fprintf(stderr, "error: invalid ucache option '%s'\n", ucache_option.c_str());
1675+
return false;
1676+
}
1677+
if (idx >= 3) {
1678+
fprintf(stderr, "error: ucache expects exactly 3 comma-separated values (threshold,start,end)\n");
1679+
return false;
1680+
}
1681+
try {
1682+
values[idx] = std::stof(token);
1683+
} catch (const std::exception&) {
1684+
fprintf(stderr, "error: invalid ucache value '%s'\n", token.c_str());
1685+
return false;
1686+
}
1687+
idx++;
1688+
}
1689+
if (idx != 3) {
1690+
fprintf(stderr, "error: ucache expects exactly 3 comma-separated values (threshold,start,end)\n");
1691+
return false;
1692+
}
1693+
if (values[0] < 0.0f) {
1694+
fprintf(stderr, "error: ucache threshold must be non-negative\n");
1695+
return false;
1696+
}
1697+
if (values[1] < 0.0f || values[1] >= 1.0f || values[2] <= 0.0f || values[2] > 1.0f || values[1] >= values[2]) {
1698+
fprintf(stderr, "error: ucache start/end percents must satisfy 0.0 <= start < end <= 1.0\n");
1699+
return false;
1700+
}
1701+
ucache_params.enabled = true;
1702+
ucache_params.reuse_threshold = values[0];
1703+
ucache_params.start_percent = values[1];
1704+
ucache_params.end_percent = values[2];
1705+
} else {
1706+
ucache_params.enabled = false;
1707+
}
1708+
16171709
sample_params.guidance.slg.layers = skip_layers.data();
16181710
sample_params.guidance.slg.layer_count = skip_layers.size();
16191711
high_noise_sample_params.guidance.slg.layers = high_noise_skip_layers.data();
@@ -2292,6 +2384,7 @@ int main(int argc, const char* argv[]) {
22922384
}, // pm_params
22932385
ctx_params.vae_tiling_params,
22942386
gen_params.easycache_params,
2387+
gen_params.ucache_params,
22952388
};
22962389

22972390
results = generate_image(sd_ctx, &img_gen_params);
@@ -2317,6 +2410,7 @@ int main(int argc, const char* argv[]) {
23172410
gen_params.video_frames,
23182411
gen_params.vace_strength,
23192412
gen_params.easycache_params,
2413+
gen_params.ucache_params,
23202414
};
23212415

23222416
results = generate_video(sd_ctx, &vid_gen_params, &num_results);

0 commit comments

Comments
 (0)