Instructions to use kernels-community/flash-mla with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Kernels
How to use kernels-community/flash-mla with Kernels:
```python
# !pip install kernels
from kernels import get_kernel

kernel = get_kernel("kernels-community/flash-mla")
```
- Notebooks
- Google Colab
- Kaggle
| import torch | |
| import random | |
| import torch.nn.functional as F | |
| import flash_mla | |
| # TODO: revise to use the same test as the original code | |
def test_flash_mla():
    """Smoke-test ``flash_mla.get_mla_metadata`` on a small paged KV cache.

    Builds a paged (block-table) KV cache holding ``b`` sequences of length
    ``mean_sk`` (or randomly varying lengths when ``varlen`` is enabled).
    Every slot past each sequence's length is filled with NaN. The kernel is
    then asked for its tile-scheduler metadata, which is sanity-checked.

    Requires a CUDA device, because ``cache_seqlens`` is moved to ``"cuda"``
    before the kernel call.

    TODO: port the full reference comparison from the original FlashMLA test
    (run the attention kernel and compare against a PyTorch reference).
    """
    # Larger configuration (previously commented out), kept for reference:
    # b=128, s_q=4096, mean_sk=8192, h_q=16, h_kv=1, d=576, dv=512
    b = 16
    s_q = 16
    mean_sk = 16
    h_q = 16
    h_kv = 1
    d = 576
    dv = 512
    causal = True
    varlen = False
    print(f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {varlen=}")

    cache_seqlens = torch.full((b,), mean_sk, dtype=torch.int32)
    if varlen:
        for i in range(b):
            # The target tensor is int32, so the sampled float is truncated
            # either way. int() only makes that conversion explicit.
            cache_seqlens[i] = int(max(random.normalvariate(mean_sk, mean_sk / 2), s_q))
    total_seqlens = cache_seqlens.sum().item()
    mean_seqlens = cache_seqlens.float().mean().int().item()
    max_seqlen = cache_seqlens.max().item()
    print(f"{total_seqlens=}, {mean_seqlens=}, {max_seqlen=}")

    # Round up to a multiple of 256. This replaces triton.cdiv(max_seqlen, 256) * 256
    # so the test does not depend on triton. The parentheses make the precedence explicit.
    max_seqlen_pad = (max_seqlen + 255) & ~255

    q = torch.randn(b, s_q, h_q, d)
    block_size = 64
    # max_seqlen_pad is a multiple of 256, so it divides evenly into 64-slot blocks.
    blocks_per_seq = max_seqlen_pad // block_size
    block_table = torch.arange(b * blocks_per_seq, dtype=torch.int32).view(b, blocks_per_seq)
    blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d)
    print(blocked_k.shape)

    # View the paged cache as one contiguous (b, max_seqlen_pad, h_kv, d) buffer.
    # .view shares storage, so writes through it land in blocked_k.
    # Slots past each sequence's length are set to NaN, presumably so that
    # out-of-range reads would show up in a later output comparison.
    per_seq_k = blocked_k.view(b, max_seqlen_pad, h_kv, d)
    for i in range(b):
        per_seq_k[i, cache_seqlens[i].item():] = float("nan")
    blocked_v = blocked_k[..., :dv]
    print(blocked_k.shape, blocked_v.shape)

    cache_seqlens = cache_seqlens.to("cuda")
    tile_scheduler_metadata, num_splits = flash_mla.get_mla_metadata(
        seqlens_k=cache_seqlens,
        # Query length scaled by the number of query heads per KV head.
        s_q=s_q * h_q // h_kv,
        h_kv=h_kv,
    )
    print(tile_scheduler_metadata, num_splits)

    # Sanity checks that replace the former unconditional `assert False`.
    # NOTE(review): the expected dtypes/shapes follow the upstream FlashMLA
    # get_mla_metadata contract (int32 2-D scheduler metadata, num_splits of
    # length batch_size + 1). Confirm they hold for this kernel build.
    assert tile_scheduler_metadata.dtype == torch.int32
    assert tile_scheduler_metadata.dim() == 2
    assert num_splits.dtype == torch.int32
    assert num_splits.shape == (b + 1,)