ttnn.sampling support for Llama3 8b N150 #18143

Open
ipotkonjak-tt opened this issue Feb 21, 2025 · 0 comments
Labels
bug Something isn't working

Comments

@ipotkonjak-tt
Contributor

Describe the bug
`ttnn.sampling` is not supported for an input tensor of shape [1, 1, 32, 128256], which is the shape of the Llama3 8B output on N150 (32 users by a 128256-token vocabulary).

The call aborts with the following error:

```
Always | FATAL | TODO: add support for multi-paged buffer with page size > 64KB
```
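The failing shape differs from the passing one only in the last (vocab) dimension. One plausible explanation, which the issue does not confirm, is page size. If a row-major tensor stores one full row per page, the int32 indices row for 128256 entries is about 501 KiB. The tile-layout bfloat16 values tensor, by contrast, uses fixed 2 KiB pages (one 32x32 tile each). The sketch below only does that arithmetic. The page model and the names `page_size_bytes` and `PAGE_LIMIT_BYTES` are assumptions for illustration, not ttnn APIs.

```python
# Hypothetical helper: estimate the size of one buffer page in bytes.
# ASSUMPTION: a row-major interleaved tensor stores one row (last dim) per page,
# and a TILE_LAYOUT tensor stores one 32x32 tile per page.
ELEMENT_BYTES = {"bfloat16": 2, "int32": 4}
TILE_ELEMENTS = 32 * 32
PAGE_LIMIT_BYTES = 64 * 1024  # the 64KB threshold named in the error message


def page_size_bytes(shape, dtype, layout):
    elem = ELEMENT_BYTES[dtype]
    if layout == "ROW_MAJOR":
        return shape[-1] * elem  # one full row per page
    if layout == "TILE":
        return TILE_ELEMENTS * elem  # one tile per page, independent of shape
    raise ValueError(f"unknown layout: {layout}")


if __name__ == "__main__":
    for shape in ([1, 1, 32, 32 * 8], [1, 1, 32, 128256]):
        values = page_size_bytes(shape, "bfloat16", "TILE")
        indices = page_size_bytes(shape, "int32", "ROW_MAJOR")
        print(shape, "values:", values, "indices:", indices,
              "indices over limit:", indices > PAGE_LIMIT_BYTES)
    # Widest int32 row-major row that still fits in a single 64KB page
    print("max int32 row width:", PAGE_LIMIT_BYTES // ELEMENT_BYTES["int32"])
```

Under this model the passing shape needs 1024-byte index pages, the failing one needs 513024-byte pages, and any int32 row wider than 16384 entries would exceed the limit.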

To Reproduce

```python
import pytest
import torch
import ttnn

from tests.ttnn.utils_for_testing import assert_with_pcc


@pytest.mark.parametrize(
    "shape",
    [
        # [1, 1, 32, 32 * 8],  # PASS
        [1, 1, 32, 128256],  # Llama3 8B on N150 - FAIL
    ],
)
@pytest.mark.parametrize("k", [[1] * 32])  # per-user top-k; k=1 makes sampling equivalent to argmax
@pytest.mark.parametrize("p", [[1.0] * 32])  # per-user top-p
@pytest.mark.parametrize("seed", [2024])
@pytest.mark.parametrize(
    # 8x4 grid of cores, (0, 0) through (7, 3)
    "sub_core_grids", [ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(8 - 1, 4 - 1))})]
)
def test_argmax(shape, k, p, seed, device, sub_core_grids, use_program_cache):
    # `device` and `use_program_cache` are fixtures provided by the tt-metal test harness
    device.enable_async(True)
    torch.manual_seed(seed)

    # Input tensors: random logits and the vocabulary index of each column
    input_values = torch.randn(shape)
    input_indices = torch.arange(0, shape[-1], dtype=torch.int32).expand(shape)

    # TTNN input tensors: values in tile layout, indices in row-major layout
    input_values_tensor = ttnn.from_torch(input_values, device=device, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT)
    input_indices_tensor = ttnn.from_torch(input_indices, device=device, dtype=ttnn.int32, layout=ttnn.ROW_MAJOR_LAYOUT)

    # Reference argmax over the vocab dim: [1, 1, 32, 1] transposed to [1, 1, 1, 32]
    argmax_output = torch.argmax(input_values, -1, keepdim=True).transpose(-1, -2)

    # ttnn.sampling with k=1 and p=1.0 behaves as argmax
    ttnn_output = ttnn.sampling(
        input_values_tensor, input_indices_tensor, k=k, p=p, seed=seed, sub_core_grids=sub_core_grids
    )
    output = ttnn.to_torch(ttnn_output)

    assert_with_pcc(output, argmax_output, 0.9999)
```
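With `k = 1` for every user, top-k sampling keeps only the highest-scoring token, which is why the test can compare against `torch.argmax`. A host-side reference for per-user top-k/top-p sampling makes that equivalence explicit. This is a minimal sketch that assumes top-k is applied before top-p. `reference_top_k_top_p` is a hypothetical helper, and the issue does not describe how `ttnn.sampling` orders or implements these steps.

```python
import torch


def reference_top_k_top_p(logits, k, p, generator=None):
    """Hypothetical host reference: per-user top-k, then top-p, then sample.

    logits: [num_users, vocab] tensor; k, p: per-user lists of length num_users.
    Returns a [num_users] int64 tensor of sampled vocabulary indices.
    """
    num_users = logits.shape[0]
    out = torch.empty(num_users, dtype=torch.int64)
    for u in range(num_users):
        # Top-k: topk returns values sorted in descending order with their indices
        vals, idx = torch.topk(logits[u], k[u])
        probs = torch.softmax(vals.float(), dim=-1)
        # Top-p: keep the shortest prefix whose cumulative probability reaches p[u]
        cum = torch.cumsum(probs, dim=-1)
        keep = min(int((cum < p[u]).sum().item()) + 1, k[u])
        probs = probs[:keep] / probs[:keep].sum()
        # Sample one candidate from the renormalized distribution
        choice = torch.multinomial(probs, 1, generator=generator).item()
        out[u] = idx[choice]
    return out


if __name__ == "__main__":
    torch.manual_seed(2024)
    logits = torch.randn(32, 1024)
    sampled = reference_top_k_top_p(logits, k=[1] * 32, p=[1.0] * 32)
    # With k=1 the only surviving candidate is the argmax
    print(torch.equal(sampled, logits.argmax(dim=-1)))
```

Keeping at least one candidate after the top-p cut guarantees a valid distribution even when the largest probability alone already exceeds `p`.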

Additional context
Supporting this shape on device is needed to avoid falling back to host for sampling and to make data-parallel execution of Llama models performant.
