Kernels
danieldk HF Staff committed on
Commit
e786cb7
·
verified ·
1 Parent(s): 5959b5c

Build uploaded using `kernels`.

Browse files
Files changed (47) hide show
  1. build/torch27-cxx11-cu128-aarch64-linux/causal_conv1d/__pycache__/__init__.cpython-313.pyc +0 -0
  2. build/torch27-cxx11-cu128-aarch64-linux/causal_conv1d/__pycache__/_ops.cpython-313.pyc +0 -0
  3. build/torch27-cxx11-cu128-aarch64-linux/causal_conv1d/__pycache__/causal_conv1d_interface.cpython-313.pyc +0 -0
  4. build/torch27-cxx11-cu128-aarch64-linux/causal_conv1d/__pycache__/causal_conv1d_varlen.cpython-313.pyc +0 -0
  5. build/torch27-cxx11-cu128-aarch64-linux/causal_conv1d/__pycache__/cpp_functions.cpython-313.pyc +0 -0
  6. build/torch27-cxx11-cu128-aarch64-linux/causal_conv1d/{_causal_conv1d_5b5a6b7.abi3.so → _causal_conv1d_306ae84.abi3.so} +2 -2
  7. build/torch27-cxx11-cu128-aarch64-linux/causal_conv1d/_ops.py +3 -3
  8. build/torch28-cxx11-cu129-aarch64-linux/causal_conv1d/__pycache__/__init__.cpython-313.pyc +0 -0
  9. build/torch28-cxx11-cu129-aarch64-linux/causal_conv1d/__pycache__/_ops.cpython-313.pyc +0 -0
  10. build/torch28-cxx11-cu129-aarch64-linux/causal_conv1d/__pycache__/causal_conv1d_interface.cpython-313.pyc +0 -0
  11. build/torch28-cxx11-cu129-aarch64-linux/causal_conv1d/__pycache__/causal_conv1d_varlen.cpython-313.pyc +0 -0
  12. build/torch28-cxx11-cu129-aarch64-linux/causal_conv1d/__pycache__/cpp_functions.cpython-313.pyc +0 -0
  13. build/torch28-cxx11-cu129-aarch64-linux/causal_conv1d/{_causal_conv1d_5b5a6b7.abi3.so → _causal_conv1d_306ae84.abi3.so} +2 -2
  14. build/torch28-cxx11-cu129-aarch64-linux/causal_conv1d/_ops.py +3 -3
  15. build/torch29-cxx11-cu126-aarch64-linux/causal_conv1d/__init__.py +4 -0
  16. build/torch29-cxx11-cu126-aarch64-linux/causal_conv1d/__pycache__/__init__.cpython-313.pyc +0 -0
  17. build/torch29-cxx11-cu126-aarch64-linux/causal_conv1d/__pycache__/_ops.cpython-313.pyc +0 -0
  18. build/torch29-cxx11-cu126-aarch64-linux/causal_conv1d/__pycache__/causal_conv1d_interface.cpython-313.pyc +0 -0
  19. build/torch29-cxx11-cu126-aarch64-linux/causal_conv1d/__pycache__/causal_conv1d_varlen.cpython-313.pyc +0 -0
  20. build/torch29-cxx11-cu126-aarch64-linux/causal_conv1d/__pycache__/cpp_functions.cpython-313.pyc +0 -0
  21. build/torch29-cxx11-cu126-aarch64-linux/causal_conv1d/_causal_conv1d_306ae84.abi3.so +3 -0
  22. build/torch29-cxx11-cu126-aarch64-linux/causal_conv1d/_ops.py +9 -0
  23. build/torch29-cxx11-cu126-aarch64-linux/causal_conv1d/causal_conv1d_interface.py +242 -0
  24. build/torch29-cxx11-cu126-aarch64-linux/causal_conv1d/causal_conv1d_varlen.py +86 -0
  25. build/torch29-cxx11-cu126-aarch64-linux/causal_conv1d/cpp_functions.py +96 -0
  26. build/torch29-cxx11-cu128-aarch64-linux/causal_conv1d/__init__.py +4 -0
  27. build/torch29-cxx11-cu128-aarch64-linux/causal_conv1d/__pycache__/__init__.cpython-313.pyc +0 -0
  28. build/torch29-cxx11-cu128-aarch64-linux/causal_conv1d/__pycache__/_ops.cpython-313.pyc +0 -0
  29. build/torch29-cxx11-cu128-aarch64-linux/causal_conv1d/__pycache__/causal_conv1d_interface.cpython-313.pyc +0 -0
  30. build/torch29-cxx11-cu128-aarch64-linux/causal_conv1d/__pycache__/causal_conv1d_varlen.cpython-313.pyc +0 -0
  31. build/torch29-cxx11-cu128-aarch64-linux/causal_conv1d/__pycache__/cpp_functions.cpython-313.pyc +0 -0
  32. build/torch29-cxx11-cu128-aarch64-linux/causal_conv1d/_causal_conv1d_306ae84.abi3.so +3 -0
  33. build/torch29-cxx11-cu128-aarch64-linux/causal_conv1d/_ops.py +9 -0
  34. build/torch29-cxx11-cu128-aarch64-linux/causal_conv1d/causal_conv1d_interface.py +242 -0
  35. build/torch29-cxx11-cu128-aarch64-linux/causal_conv1d/causal_conv1d_varlen.py +86 -0
  36. build/torch29-cxx11-cu128-aarch64-linux/causal_conv1d/cpp_functions.py +96 -0
  37. build/torch29-cxx11-cu130-aarch64-linux/causal_conv1d/__init__.py +4 -0
  38. build/torch29-cxx11-cu130-aarch64-linux/causal_conv1d/__pycache__/__init__.cpython-313.pyc +0 -0
  39. build/torch29-cxx11-cu130-aarch64-linux/causal_conv1d/__pycache__/_ops.cpython-313.pyc +0 -0
  40. build/torch29-cxx11-cu130-aarch64-linux/causal_conv1d/__pycache__/causal_conv1d_interface.cpython-313.pyc +0 -0
  41. build/torch29-cxx11-cu130-aarch64-linux/causal_conv1d/__pycache__/causal_conv1d_varlen.cpython-313.pyc +0 -0
  42. build/torch29-cxx11-cu130-aarch64-linux/causal_conv1d/__pycache__/cpp_functions.cpython-313.pyc +0 -0
  43. build/torch29-cxx11-cu130-aarch64-linux/causal_conv1d/_causal_conv1d_306ae84.abi3.so +3 -0
  44. build/torch29-cxx11-cu130-aarch64-linux/causal_conv1d/_ops.py +9 -0
  45. build/torch29-cxx11-cu130-aarch64-linux/causal_conv1d/causal_conv1d_interface.py +242 -0
  46. build/torch29-cxx11-cu130-aarch64-linux/causal_conv1d/causal_conv1d_varlen.py +86 -0
  47. build/torch29-cxx11-cu130-aarch64-linux/causal_conv1d/cpp_functions.py +96 -0
build/torch27-cxx11-cu128-aarch64-linux/causal_conv1d/__pycache__/__init__.cpython-313.pyc CHANGED
Binary files a/build/torch27-cxx11-cu128-aarch64-linux/causal_conv1d/__pycache__/__init__.cpython-313.pyc and b/build/torch27-cxx11-cu128-aarch64-linux/causal_conv1d/__pycache__/__init__.cpython-313.pyc differ
 
build/torch27-cxx11-cu128-aarch64-linux/causal_conv1d/__pycache__/_ops.cpython-313.pyc CHANGED
Binary files a/build/torch27-cxx11-cu128-aarch64-linux/causal_conv1d/__pycache__/_ops.cpython-313.pyc and b/build/torch27-cxx11-cu128-aarch64-linux/causal_conv1d/__pycache__/_ops.cpython-313.pyc differ
 
build/torch27-cxx11-cu128-aarch64-linux/causal_conv1d/__pycache__/causal_conv1d_interface.cpython-313.pyc CHANGED
Binary files a/build/torch27-cxx11-cu128-aarch64-linux/causal_conv1d/__pycache__/causal_conv1d_interface.cpython-313.pyc and b/build/torch27-cxx11-cu128-aarch64-linux/causal_conv1d/__pycache__/causal_conv1d_interface.cpython-313.pyc differ
 
build/torch27-cxx11-cu128-aarch64-linux/causal_conv1d/__pycache__/causal_conv1d_varlen.cpython-313.pyc CHANGED
Binary files a/build/torch27-cxx11-cu128-aarch64-linux/causal_conv1d/__pycache__/causal_conv1d_varlen.cpython-313.pyc and b/build/torch27-cxx11-cu128-aarch64-linux/causal_conv1d/__pycache__/causal_conv1d_varlen.cpython-313.pyc differ
 
build/torch27-cxx11-cu128-aarch64-linux/causal_conv1d/__pycache__/cpp_functions.cpython-313.pyc CHANGED
Binary files a/build/torch27-cxx11-cu128-aarch64-linux/causal_conv1d/__pycache__/cpp_functions.cpython-313.pyc and b/build/torch27-cxx11-cu128-aarch64-linux/causal_conv1d/__pycache__/cpp_functions.cpython-313.pyc differ
 
build/torch27-cxx11-cu128-aarch64-linux/causal_conv1d/{_causal_conv1d_5b5a6b7.abi3.so → _causal_conv1d_306ae84.abi3.so} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:f53abe29129caaa23e5e9d730b873cd426835679a362b8927dce9b52880a3bc0
3
- size 102592040
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7ec3c41413afbb69d499eae6a432fa9d41a580e7b1c6ee83d09e8dab51f91803
3
+ size 90795560
build/torch27-cxx11-cu128-aarch64-linux/causal_conv1d/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _causal_conv1d_5b5a6b7
3
- ops = torch.ops._causal_conv1d_5b5a6b7
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_causal_conv1d_5b5a6b7::{op_name}"
 
1
  import torch
2
+ from . import _causal_conv1d_306ae84
3
+ ops = torch.ops._causal_conv1d_306ae84
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_causal_conv1d_306ae84::{op_name}"
build/torch28-cxx11-cu129-aarch64-linux/causal_conv1d/__pycache__/__init__.cpython-313.pyc CHANGED
Binary files a/build/torch28-cxx11-cu129-aarch64-linux/causal_conv1d/__pycache__/__init__.cpython-313.pyc and b/build/torch28-cxx11-cu129-aarch64-linux/causal_conv1d/__pycache__/__init__.cpython-313.pyc differ
 
build/torch28-cxx11-cu129-aarch64-linux/causal_conv1d/__pycache__/_ops.cpython-313.pyc CHANGED
Binary files a/build/torch28-cxx11-cu129-aarch64-linux/causal_conv1d/__pycache__/_ops.cpython-313.pyc and b/build/torch28-cxx11-cu129-aarch64-linux/causal_conv1d/__pycache__/_ops.cpython-313.pyc differ
 
build/torch28-cxx11-cu129-aarch64-linux/causal_conv1d/__pycache__/causal_conv1d_interface.cpython-313.pyc CHANGED
Binary files a/build/torch28-cxx11-cu129-aarch64-linux/causal_conv1d/__pycache__/causal_conv1d_interface.cpython-313.pyc and b/build/torch28-cxx11-cu129-aarch64-linux/causal_conv1d/__pycache__/causal_conv1d_interface.cpython-313.pyc differ
 
build/torch28-cxx11-cu129-aarch64-linux/causal_conv1d/__pycache__/causal_conv1d_varlen.cpython-313.pyc CHANGED
Binary files a/build/torch28-cxx11-cu129-aarch64-linux/causal_conv1d/__pycache__/causal_conv1d_varlen.cpython-313.pyc and b/build/torch28-cxx11-cu129-aarch64-linux/causal_conv1d/__pycache__/causal_conv1d_varlen.cpython-313.pyc differ
 
build/torch28-cxx11-cu129-aarch64-linux/causal_conv1d/__pycache__/cpp_functions.cpython-313.pyc CHANGED
Binary files a/build/torch28-cxx11-cu129-aarch64-linux/causal_conv1d/__pycache__/cpp_functions.cpython-313.pyc and b/build/torch28-cxx11-cu129-aarch64-linux/causal_conv1d/__pycache__/cpp_functions.cpython-313.pyc differ
 
build/torch28-cxx11-cu129-aarch64-linux/causal_conv1d/{_causal_conv1d_5b5a6b7.abi3.so → _causal_conv1d_306ae84.abi3.so} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:6896f3385fb9d876c29eec9e015666a760ba6e6ba3f3b4ae2c3c9a800fde4be2
3
- size 110261680
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4c3535b795cbc5baf363b0cd8636649153b017599c9992ea6c08aa4ab23ceae0
3
+ size 97678768
build/torch28-cxx11-cu129-aarch64-linux/causal_conv1d/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _causal_conv1d_5b5a6b7
3
- ops = torch.ops._causal_conv1d_5b5a6b7
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_causal_conv1d_5b5a6b7::{op_name}"
 
1
  import torch
2
+ from . import _causal_conv1d_306ae84
3
+ ops = torch.ops._causal_conv1d_306ae84
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_causal_conv1d_306ae84::{op_name}"
build/torch29-cxx11-cu126-aarch64-linux/causal_conv1d/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from .causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update
2
+ from .causal_conv1d_varlen import causal_conv1d_varlen_states
3
+
4
+ __all__ = ["causal_conv1d_fn", "causal_conv1d_update", "causal_conv1d_varlen_states"]
build/torch29-cxx11-cu126-aarch64-linux/causal_conv1d/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (391 Bytes). View file
 
build/torch29-cxx11-cu126-aarch64-linux/causal_conv1d/__pycache__/_ops.cpython-313.pyc ADDED
Binary file (536 Bytes). View file
 
build/torch29-cxx11-cu126-aarch64-linux/causal_conv1d/__pycache__/causal_conv1d_interface.cpython-313.pyc ADDED
Binary file (9.94 kB). View file
 
build/torch29-cxx11-cu126-aarch64-linux/causal_conv1d/__pycache__/causal_conv1d_varlen.cpython-313.pyc ADDED
Binary file (4.92 kB). View file
 
build/torch29-cxx11-cu126-aarch64-linux/causal_conv1d/__pycache__/cpp_functions.cpython-313.pyc ADDED
Binary file (3.62 kB). View file
 
build/torch29-cxx11-cu126-aarch64-linux/causal_conv1d/_causal_conv1d_306ae84.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d069be88c48d9a42dd8cceeabe1043f99f53196ada6760a85c2fc6b3bee22cfe
3
+ size 64340360
build/torch29-cxx11-cu126-aarch64-linux/causal_conv1d/_ops.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from . import _causal_conv1d_306ae84
3
+ ops = torch.ops._causal_conv1d_306ae84
4
+
5
+ def add_op_namespace_prefix(op_name: str):
6
+ """
7
+ Prefix op by namespace.
8
+ """
9
+ return f"_causal_conv1d_306ae84::{op_name}"
build/torch29-cxx11-cu126-aarch64-linux/causal_conv1d/causal_conv1d_interface.py ADDED
@@ -0,0 +1,242 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024, Tri Dao.
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+
6
+ from .cpp_functions import causal_conv1d_fwd_function, causal_conv1d_bwd_function, causal_conv1d_update_function
7
+
8
+
9
+ class CausalConv1dFn(torch.autograd.Function):
10
+ @staticmethod
11
+ def forward(
12
+ ctx,
13
+ x,
14
+ weight,
15
+ bias=None,
16
+ seq_idx=None,
17
+ initial_states=None,
18
+ return_final_states=False,
19
+ final_states_out=None,
20
+ activation=None,
21
+ ):
22
+ if activation not in [None, "silu", "swish"]:
23
+ raise NotImplementedError("activation must be None, silu, or swish")
24
+ if x.stride(2) != 1 and x.stride(1) != 1:
25
+ x = x.contiguous()
26
+ bias = bias.contiguous() if bias is not None else None
27
+ if seq_idx is not None:
28
+ assert (
29
+ initial_states is None
30
+ ), "initial_states must be None if seq_idx is not None"
31
+ assert (
32
+ not return_final_states
33
+ ), "If seq_idx is not None, we don't return final_states_out"
34
+ seq_idx = seq_idx.contiguous() if seq_idx is not None else None
35
+ if initial_states is not None and (
36
+ initial_states.stride(2) != 1 and initial_states.stride(1) != 1
37
+ ):
38
+ initial_states = initial_states.contiguous()
39
+ if return_final_states:
40
+ assert (
41
+ x.stride(1) == 1
42
+ ), "Only channel-last layout support returning final_states_out"
43
+ if final_states_out is not None:
44
+ assert (
45
+ final_states_out.stride(2) == 1 or final_states_out.stride(1) == 1
46
+ )
47
+ else:
48
+ batch, dim, seqlen = x.shape
49
+ width = weight.shape[1]
50
+ final_states_out = torch.empty(
51
+ batch, width - 1, dim, device=x.device, dtype=x.dtype
52
+ ).transpose(1, 2)
53
+ else:
54
+ final_states_out = None
55
+ ctx.activation = activation in ["silu", "swish"]
56
+ out = causal_conv1d_fwd_function(
57
+ x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation
58
+ )
59
+ ctx.save_for_backward(x, weight, bias, seq_idx, initial_states)
60
+ ctx.return_final_states = return_final_states
61
+ ctx.return_dinitial_states = (
62
+ initial_states is not None and initial_states.requires_grad
63
+ )
64
+ return out if not return_final_states else (out, final_states_out)
65
+
66
+ @staticmethod
67
+ def backward(ctx, dout, *args):
68
+ x, weight, bias, seq_idx, initial_states = ctx.saved_tensors
69
+ dfinal_states = args[0] if ctx.return_final_states else None
70
+ if dout.stride(2) != 1 and dout.stride(1) != 1:
71
+ dout = dout.contiguous()
72
+ # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
73
+ # backward of conv1d with the backward of chunk).
74
+ # Here we just pass in None and dx will be allocated by causal_conv1d_bwd_function.
75
+ dx, dweight, dbias, dinitial_states = causal_conv1d_bwd_function(
76
+ x,
77
+ weight,
78
+ bias,
79
+ dout,
80
+ seq_idx,
81
+ initial_states,
82
+ dfinal_states,
83
+ None,
84
+ ctx.return_dinitial_states,
85
+ ctx.activation,
86
+ )
87
+ return (
88
+ dx,
89
+ dweight,
90
+ dbias if bias is not None else None,
91
+ None,
92
+ dinitial_states if initial_states is not None else None,
93
+ None,
94
+ None,
95
+ None,
96
+ )
97
+
98
+
99
+ def causal_conv1d_fn(
100
+ x,
101
+ weight,
102
+ bias=None,
103
+ seq_idx=None,
104
+ initial_states=None,
105
+ return_final_states=False,
106
+ final_states_out=None,
107
+ activation=None,
108
+ ):
109
+ """
110
+ x: (batch, dim, seqlen)
111
+ weight: (dim, width)
112
+ bias: (dim,)
113
+ seq_idx: (batch, seqlen)
114
+ initial_states: (batch, dim, width - 1)
115
+ final_states_out: (batch, dim, width - 1), to be written to
116
+ activation: either None or "silu" or "swish"
117
+
118
+ out: (batch, dim, seqlen)
119
+ """
120
+ return CausalConv1dFn.apply(
121
+ x,
122
+ weight,
123
+ bias,
124
+ seq_idx,
125
+ initial_states,
126
+ return_final_states,
127
+ final_states_out,
128
+ activation,
129
+ )
130
+
131
+
132
+ def causal_conv1d_ref(
133
+ x,
134
+ weight,
135
+ bias=None,
136
+ initial_states=None,
137
+ return_final_states=False,
138
+ final_states_out=None,
139
+ activation=None,
140
+ ):
141
+ """
142
+ x: (batch, dim, seqlen)
143
+ weight: (dim, width)
144
+ bias: (dim,)
145
+ initial_states: (batch, dim, width - 1)
146
+ final_states_out: (batch, dim, width - 1)
147
+
148
+ out: (batch, dim, seqlen)
149
+ """
150
+ if activation not in [None, "silu", "swish"]:
151
+ raise NotImplementedError("activation must be None, silu, or swish")
152
+ dtype_in = x.dtype
153
+ x = x.to(weight.dtype)
154
+ seqlen = x.shape[-1]
155
+ dim, width = weight.shape
156
+ if initial_states is None:
157
+ out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)
158
+ else:
159
+ x = torch.cat([initial_states, x], dim=-1)
160
+ out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim)
161
+ out = out[..., :seqlen]
162
+ if return_final_states:
163
+ final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to(
164
+ dtype_in
165
+ ) # (batch, dim, width - 1)
166
+ if final_states_out is not None:
167
+ final_states_out.copy_(final_states)
168
+ else:
169
+ final_states_out = final_states
170
+ out = (out if activation is None else F.silu(out)).to(dtype=dtype_in)
171
+ return out if not return_final_states else (out, final_states_out)
172
+
173
+
174
+ def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None, conv_state_indices=None):
175
+ """
176
+ x: (batch, dim) or (batch, dim, seqlen)
177
+ conv_state: (batch, dim, state_len), where state_len >= width - 1
178
+ weight: (dim, width)
179
+ bias: (dim,)
180
+ cache_seqlens: (batch,), dtype int32.
181
+ If not None, the conv_state is treated as a circular buffer.
182
+ The conv_state will be updated by copying x to the conv_state starting at the index
183
+ @cache_seqlens % state_len.
184
+ conv_state_indices: (batch,), dtype int32
185
+ If not None, the conv_state is a larger tensor along the batch dim,
186
+ and we are selecting the batch coords specified by conv_state_indices.
187
+ Useful for a continuous batching scenario.
188
+
189
+ out: (batch, dim) or (batch, dim, seqlen)
190
+ """
191
+ if activation not in [None, "silu", "swish"]:
192
+ raise NotImplementedError("activation must be None, silu, or swish")
193
+ activation = activation in ["silu", "swish"]
194
+ unsqueeze = x.dim() == 2
195
+ if unsqueeze:
196
+ x = x.unsqueeze(-1)
197
+ out = causal_conv1d_update_function(
198
+ x, conv_state, weight, bias, activation, cache_seqlens, conv_state_indices
199
+ )
200
+ if unsqueeze:
201
+ out = out.squeeze(-1)
202
+ return out
203
+
204
+
205
+ def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None):
206
+ """
207
+ x: (batch, dim) or (batch, dim, seqlen)
208
+ conv_state: (batch, dim, state_len), where state_len >= width - 1
209
+ weight: (dim, width)
210
+ bias: (dim,)
211
+ cache_seqlens: (batch,), dtype int32.
212
+ If not None, the conv_state is treated as a circular buffer.
213
+ The conv_state will be updated by copying x to the conv_state starting at the index
214
+ @cache_seqlens % state_len before performing the convolution.
215
+
216
+ out: (batch, dim) or (batch, dim, seqlen)
217
+ """
218
+ if activation not in [None, "silu", "swish"]:
219
+ raise NotImplementedError("activation must be None, silu, or swish")
220
+ dtype_in = x.dtype
221
+ unsqueeze = x.dim() == 2
222
+ if unsqueeze:
223
+ x = x.unsqueeze(-1)
224
+ batch, dim, seqlen = x.shape
225
+ width = weight.shape[1]
226
+ state_len = conv_state.shape[-1]
227
+ assert conv_state.shape == (batch, dim, state_len)
228
+ assert weight.shape == (dim, width)
229
+ if cache_seqlens is None:
230
+ x_new = torch.cat([conv_state, x], dim=-1).to(weight.dtype) # (batch, dim, state_len + seqlen)
231
+ conv_state.copy_(x_new[:, :, -state_len:])
232
+ else:
233
+ width_idx = torch.arange(-(width - 1), 0, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1)
234
+ width_idx = torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
235
+ x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype)
236
+ copy_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1)
237
+ copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
238
+ conv_state.scatter_(2, copy_idx, x)
239
+ out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[:, :, -seqlen:]
240
+ if unsqueeze:
241
+ out = out.squeeze(-1)
242
+ return (out if activation is None else F.silu(out)).to(dtype=dtype_in)
build/torch29-cxx11-cu126-aarch64-linux/causal_conv1d/causal_conv1d_varlen.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import Tensor
3
+
4
+ import triton
5
+ import triton.language as tl
6
+
7
+
8
+ @triton.jit
9
+ def _causal_conv1d_varlen_states(
10
+ X,
11
+ CU_SEQLENS,
12
+ STATES,
13
+ state_len,
14
+ dim,
15
+ stride_x_seqlen, stride_x_dim,
16
+ stride_states_batch, stride_states_seqlen, stride_states_dim,
17
+ BLOCK_M: tl.constexpr,
18
+ BLOCK_N: tl.constexpr
19
+ ):
20
+ batch_idx = tl.program_id(2)
21
+ STATES += batch_idx * stride_states_batch
22
+ end_idx = tl.load(CU_SEQLENS + batch_idx + 1)
23
+ start_idx = tl.maximum(tl.load(CU_SEQLENS + batch_idx), end_idx - state_len)
24
+ rows = end_idx - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M)
25
+ cols = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N)
26
+ x = tl.load(X + rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_dim,
27
+ mask=(rows[:, None] >= start_idx) & (cols[None, :] < dim),
28
+ other=0)
29
+ rows_states = state_len - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M)
30
+ tl.store(STATES + rows_states[:, None] * stride_states_seqlen + cols[None, :] * stride_states_dim,
31
+ x,
32
+ mask=(rows_states[:, None] >= 0) & (cols[None, :] < dim))
33
+
34
+
35
+ def causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor:
36
+ """
37
+ Forward pass only, does not support backward pass.
38
+ Parameters:
39
+ x: (total_tokens, dim)
40
+ cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0.
41
+ state_len: int. For each cu_seqlens, how many elements from x should be copied to the state.
42
+ If some of those elements belong to a different sequence, the value of the states will be zero.
43
+ Return:
44
+ states: (batch, dim, state_len)
45
+ """
46
+ _, dim = x.shape
47
+ batch = cu_seqlens.shape[0] - 1
48
+ cu_seqlens = cu_seqlens.contiguous()
49
+ states = torch.empty(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2)
50
+ BLOCK_M = min(triton.next_power_of_2(state_len), 16)
51
+ BLOCK_N = min(triton.next_power_of_2(dim), 256)
52
+ grid = (triton.cdiv(dim, BLOCK_N), triton.cdiv(state_len, BLOCK_M), batch)
53
+ with torch.cuda.device(x.device.index):
54
+ _causal_conv1d_varlen_states[grid](
55
+ x,
56
+ cu_seqlens,
57
+ states,
58
+ state_len,
59
+ dim,
60
+ x.stride(0), x.stride(1),
61
+ states.stride(0), states.stride(2), states.stride(1),
62
+ BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N
63
+ )
64
+ return states
65
+
66
+
67
+ def causal_conv1d_varlen_states_ref(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor:
68
+ """
69
+ Forward pass only, does not support backward pass.
70
+ Parameters:
71
+ x: (total_tokens, dim)
72
+ cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0.
73
+ state_len: int. For each cu_seqlens, how many elements from x should be copied to the state.
74
+ If some of those elements belong to a different sequence, the value of the states will be zero.
75
+ Return:
76
+ states: (batch, dim, state_len)
77
+ """
78
+ _, dim = x.shape
79
+ batch = cu_seqlens.shape[0] - 1
80
+ cu_seqlens = cu_seqlens.contiguous()
81
+ states = torch.zeros(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2)
82
+ for i in range(batch):
83
+ end_idx = cu_seqlens[i + 1]
84
+ start_idx = torch.maximum(cu_seqlens[i], end_idx - state_len)
85
+ states[i, :, -(end_idx - start_idx):] = x[start_idx:end_idx].T
86
+ return states
build/torch29-cxx11-cu126-aarch64-linux/causal_conv1d/cpp_functions.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024, Tri Dao.
2
+
3
+ import torch
4
+
5
+ from ._ops import ops
6
+
7
+ def causal_conv1d_fwd_function(
8
+ x: torch.Tensor,
9
+ weight: torch.Tensor,
10
+ bias: torch.Tensor | None,
11
+ seq_idx: torch.Tensor | None,
12
+ initial_states: torch.Tensor | None,
13
+ final_states_out: torch.Tensor | None,
14
+ silu_activation: bool,
15
+ ) -> torch.Tensor:
16
+ out = torch.empty_like(x)
17
+ ops.causal_conv1d_fwd(
18
+ x=x,
19
+ weight=weight,
20
+ bias=bias,
21
+ seq_idx=seq_idx,
22
+ initial_states=initial_states,
23
+ out=out,
24
+ final_states_out=final_states_out,
25
+ silu_activation=silu_activation,
26
+ )
27
+ return out
28
+
29
+
30
+ def causal_conv1d_bwd_function(
31
+ x: torch.Tensor,
32
+ weight: torch.Tensor,
33
+ bias: torch.Tensor | None,
34
+ dout: torch.Tensor,
35
+ seq_idx: torch.Tensor | None,
36
+ initial_states: torch.Tensor | None,
37
+ dfinal_states: torch.Tensor | None,
38
+ dx: torch.Tensor | None,
39
+ return_dinitial_states: torch.Tensor,
40
+ silu_activation: bool,
41
+ ) -> tuple[torch.Tensor | None]:
42
+ batch_size, dim = x.size()[:2]
43
+ width = weight.size(-1)
44
+
45
+ if dx is None:
46
+ dx = torch.empty_like(x)
47
+ dweight = torch.zeros_like(weight, dtype=torch.float32)
48
+ dbias = None
49
+ if bias is not None:
50
+ dbias = torch.zeros_like(bias, dtype=torch.float32)
51
+ dinitial_states = None
52
+ if return_dinitial_states:
53
+ dinitial_states = torch.empty(batch_size, width - 1, dim, device=x.device, dtype=x.dtype).transpose(1, 2)
54
+
55
+ ops.causal_conv1d_bwd(
56
+ x=x,
57
+ weight=weight,
58
+ bias=bias,
59
+ dout=dout,
60
+ seq_idx=seq_idx,
61
+ initial_states=initial_states,
62
+ dfinal_states=dfinal_states,
63
+ dx=dx,
64
+ dweight=dweight,
65
+ dbias=dbias,
66
+ dinitial_states=dinitial_states,
67
+ silu_activation=silu_activation,
68
+ )
69
+
70
+ dweight = dweight.type_as(weight)
71
+ if dbias is not None:
72
+ dbias = dbias.type_as(bias)
73
+ return dx, dweight, dbias, dinitial_states
74
+
75
+
76
+ def causal_conv1d_update_function(
77
+ x: torch.Tensor,
78
+ conv_state: torch.Tensor,
79
+ weight: torch.Tensor,
80
+ bias: torch.Tensor | None,
81
+ silu_activation: bool,
82
+ cache_seqlens: torch.Tensor | None,
83
+ conv_state_indices: torch.Tensor | None,
84
+ ) -> torch.Tensor:
85
+ out = torch.empty_like(x)
86
+ ops.causal_conv1d_update(
87
+ x=x,
88
+ conv_state=conv_state,
89
+ weight=weight,
90
+ bias=bias,
91
+ out=out,
92
+ silu_activation=silu_activation,
93
+ cache_seqlens=cache_seqlens,
94
+ conv_state_indices=conv_state_indices,
95
+ )
96
+ return out
build/torch29-cxx11-cu128-aarch64-linux/causal_conv1d/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from .causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update
2
+ from .causal_conv1d_varlen import causal_conv1d_varlen_states
3
+
4
+ __all__ = ["causal_conv1d_fn", "causal_conv1d_update", "causal_conv1d_varlen_states"]
build/torch29-cxx11-cu128-aarch64-linux/causal_conv1d/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (391 Bytes). View file
 
build/torch29-cxx11-cu128-aarch64-linux/causal_conv1d/__pycache__/_ops.cpython-313.pyc ADDED
Binary file (536 Bytes). View file
 
build/torch29-cxx11-cu128-aarch64-linux/causal_conv1d/__pycache__/causal_conv1d_interface.cpython-313.pyc ADDED
Binary file (9.94 kB). View file
 
build/torch29-cxx11-cu128-aarch64-linux/causal_conv1d/__pycache__/causal_conv1d_varlen.cpython-313.pyc ADDED
Binary file (4.92 kB). View file
 
build/torch29-cxx11-cu128-aarch64-linux/causal_conv1d/__pycache__/cpp_functions.cpython-313.pyc ADDED
Binary file (3.62 kB). View file
 
build/torch29-cxx11-cu128-aarch64-linux/causal_conv1d/_causal_conv1d_306ae84.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b252d8148e4fdb48015fbf2cfc3a5c2683c414066b7b825d32a9f5578e3869d6
3
+ size 90795712
build/torch29-cxx11-cu128-aarch64-linux/causal_conv1d/_ops.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from . import _causal_conv1d_306ae84
3
+ ops = torch.ops._causal_conv1d_306ae84
4
+
5
+ def add_op_namespace_prefix(op_name: str):
6
+ """
7
+ Prefix op by namespace.
8
+ """
9
+ return f"_causal_conv1d_306ae84::{op_name}"
build/torch29-cxx11-cu128-aarch64-linux/causal_conv1d/causal_conv1d_interface.py ADDED
@@ -0,0 +1,242 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024, Tri Dao.
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+
6
+ from .cpp_functions import causal_conv1d_fwd_function, causal_conv1d_bwd_function, causal_conv1d_update_function
7
+
8
+
9
class CausalConv1dFn(torch.autograd.Function):
    """
    Autograd function that routes the causal conv1d forward and backward passes
    through ``causal_conv1d_fwd_function`` / ``causal_conv1d_bwd_function``.
    """

    @staticmethod
    def forward(
        ctx,
        x,
        weight,
        bias=None,
        seq_idx=None,
        initial_states=None,
        return_final_states=False,
        final_states_out=None,
        activation=None,
    ):
        """
        Validate and normalize the inputs, run the forward kernel wrapper and
        stash what ``backward`` needs on ``ctx``.

        Returns ``out`` alone, or ``(out, final_states_out)`` when
        ``return_final_states`` is true.
        Raises NotImplementedError for an unsupported ``activation``.
        """
        if activation not in (None, "silu", "swish"):
            raise NotImplementedError("activation must be None, silu, or swish")

        # Copy only when neither of the last two dims has unit stride.
        if not (x.stride(2) == 1 or x.stride(1) == 1):
            x = x.contiguous()
        if bias is not None:
            bias = bias.contiguous()

        if seq_idx is not None:
            assert (
                initial_states is None
            ), "initial_states must be None if seq_idx is not None"
            assert (
                not return_final_states
            ), "If seq_idx is not None, we don't return final_states_out"
            seq_idx = seq_idx.contiguous()

        if initial_states is not None:
            # Same layout rule as for x.
            if initial_states.stride(2) != 1 and initial_states.stride(1) != 1:
                initial_states = initial_states.contiguous()

        if not return_final_states:
            # A caller-provided buffer is ignored when final states are not requested.
            final_states_out = None
        else:
            assert (
                x.stride(1) == 1
            ), "Only channel-last layout support returning final_states_out"
            if final_states_out is None:
                batch, dim, _ = x.shape
                width = weight.shape[1]
                # Allocated as (batch, width - 1, dim) and viewed as
                # (batch, dim, width - 1), so dim is the unit-stride axis.
                buffer = torch.empty(
                    batch, width - 1, dim, device=x.device, dtype=x.dtype
                )
                final_states_out = buffer.transpose(1, 2)
            else:
                assert (
                    final_states_out.stride(2) == 1 or final_states_out.stride(1) == 1
                )

        # The kernel wrapper takes a boolean flag rather than the activation name.
        ctx.activation = activation in ("silu", "swish")
        out = causal_conv1d_fwd_function(
            x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation
        )
        ctx.save_for_backward(x, weight, bias, seq_idx, initial_states)
        ctx.return_final_states = return_final_states
        ctx.return_dinitial_states = (
            initial_states is not None and initial_states.requires_grad
        )
        if return_final_states:
            return out, final_states_out
        return out

    @staticmethod
    def backward(ctx, dout, *args):
        """
        Compute gradients for ``x``, ``weight``, ``bias`` and ``initial_states``.

        ``args[0]`` is used as the gradient of the final states when the forward
        pass returned them. The result has one entry per forward input; inputs
        that are not differentiated get ``None``.
        """
        x, weight, bias, seq_idx, initial_states = ctx.saved_tensors
        dfinal_states = args[0] if ctx.return_final_states else None
        if not (dout.stride(2) == 1 or dout.stride(1) == 1):
            dout = dout.contiguous()
        # The kernel supports a pre-allocated dx (e.g. to fuse the backward of
        # conv1d with the backward of chunk). None is passed here, so the
        # wrapper allocates dx itself.
        grads = causal_conv1d_bwd_function(
            x,
            weight,
            bias,
            dout,
            seq_idx,
            initial_states,
            dfinal_states,
            None,
            ctx.return_dinitial_states,
            ctx.activation,
        )
        dx, dweight, dbias, dinitial_states = grads
        if bias is None:
            dbias = None
        if initial_states is None:
            dinitial_states = None
        # Order: x, weight, bias, seq_idx, initial_states,
        # return_final_states, final_states_out, activation.
        return dx, dweight, dbias, None, dinitial_states, None, None, None
97
+
98
+
99
def causal_conv1d_fn(
    x,
    weight,
    bias=None,
    seq_idx=None,
    initial_states=None,
    return_final_states=False,
    final_states_out=None,
    activation=None,
):
    """
    Functional entry point that forwards all arguments to ``CausalConv1dFn.apply``.

    x: (batch, dim, seqlen)
    weight: (dim, width)
    bias: (dim,)
    seq_idx: (batch, seqlen)
    initial_states: (batch, dim, width - 1)
    final_states_out: (batch, dim, width - 1), to be written to
    activation: either None or "silu" or "swish"

    out: (batch, dim, seqlen)
    """
    # Passed positionally, in the order of CausalConv1dFn.forward (after ctx).
    forward_args = (
        x,
        weight,
        bias,
        seq_idx,
        initial_states,
        return_final_states,
        final_states_out,
        activation,
    )
    return CausalConv1dFn.apply(*forward_args)
130
+
131
+
132
def causal_conv1d_ref(
    x,
    weight,
    bias=None,
    initial_states=None,
    return_final_states=False,
    final_states_out=None,
    activation=None,
):
    """
    Reference causal depthwise conv1d built on ``F.conv1d``.

    x: (batch, dim, seqlen)
    weight: (dim, width)
    bias: (dim,)
    initial_states: (batch, dim, width - 1)
    final_states_out: (batch, dim, width - 1); filled in place when given

    out: (batch, dim, seqlen), or (out, final_states_out) when
    return_final_states is true. Raises NotImplementedError for an
    unsupported activation.
    """
    if activation not in (None, "silu", "swish"):
        raise NotImplementedError("activation must be None, silu, or swish")

    dtype_in = x.dtype
    x = x.to(weight.dtype)
    seqlen = x.shape[-1]
    dim, width = weight.shape
    depthwise_kernel = weight.unsqueeze(1)

    if initial_states is not None:
        # The initial states take the place of the left zero padding.
        x = torch.cat([initial_states, x], dim=-1)
        padding = 0
    else:
        padding = width - 1
    out = F.conv1d(x, depthwise_kernel, bias, padding=padding, groups=dim)
    out = out[..., :seqlen]

    if return_final_states:
        # A negative left pad crops, a positive one zero-fills, so the result
        # always has width - 1 trailing positions.
        left_pad = width - 1 - x.shape[-1]
        tail = F.pad(x, (left_pad, 0)).to(dtype_in)  # (batch, dim, width - 1)
        if final_states_out is None:
            final_states_out = tail
        else:
            final_states_out.copy_(tail)

    if activation is not None:
        out = F.silu(out)
    out = out.to(dtype=dtype_in)

    if return_final_states:
        return out, final_states_out
    return out
172
+
173
+
174
def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None, conv_state_indices=None):
    """
    Run one update step through ``causal_conv1d_update_function``.

    x: (batch, dim) or (batch, dim, seqlen)
    conv_state: (batch, dim, state_len), where state_len >= width - 1
    weight: (dim, width)
    bias: (dim,)
    cache_seqlens: (batch,), dtype int32.
        If not None, the conv_state is treated as a circular buffer.
        The conv_state will be updated by copying x to the conv_state starting at the index
        @cache_seqlens % state_len.
    conv_state_indices: (batch,), dtype int32
        Presumably, when given, conv_state is a larger tensor along the batch dim
        and these indices select the batch coords to use (useful for continuous
        batching). The selection happens inside the kernel, so confirm there.

    out: (batch, dim) or (batch, dim, seqlen), matching the rank of x.
    Raises NotImplementedError for an unsupported activation.
    """
    if activation not in (None, "silu", "swish"):
        raise NotImplementedError("activation must be None, silu, or swish")
    use_silu = activation in ("silu", "swish")

    # The wrapper is always called with a 3-D x; a 2-D input gets a trailing
    # axis of length one that is removed again from the result.
    was_2d = x.dim() == 2
    x_3d = x.unsqueeze(-1) if was_2d else x
    out = causal_conv1d_update_function(
        x_3d, conv_state, weight, bias, use_silu, cache_seqlens, conv_state_indices
    )
    return out.squeeze(-1) if was_2d else out
203
+
204
+
205
def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None):
    """
    Reference update step built on ``F.conv1d``; ``conv_state`` is modified in place.

    x: (batch, dim) or (batch, dim, seqlen)
    conv_state: (batch, dim, state_len), where state_len >= width - 1
    weight: (dim, width)
    bias: (dim,)
    cache_seqlens: (batch,), dtype int32.
        If not None, the conv_state is treated as a circular buffer.
        The previous width - 1 entries are gathered first, then x is written to
        the conv_state starting at index @cache_seqlens % state_len.

    out: (batch, dim) or (batch, dim, seqlen), matching the rank of x.
    Raises NotImplementedError for an unsupported activation.
    """
    if activation not in (None, "silu", "swish"):
        raise NotImplementedError("activation must be None, silu, or swish")
    dtype_in = x.dtype
    was_2d = x.dim() == 2
    if was_2d:
        x = x.unsqueeze(-1)
    batch, dim, seqlen = x.shape
    width = weight.shape[1]
    state_len = conv_state.shape[-1]
    assert conv_state.shape == (batch, dim, state_len)
    assert weight.shape == (dim, width)

    if cache_seqlens is not None:
        offsets = cache_seqlens.unsqueeze(1)
        # Positions of the width - 1 entries that precede the write position.
        history_idx = torch.arange(-(width - 1), 0, dtype=torch.long, device=x.device).unsqueeze(0) + offsets
        history_idx = torch.remainder(history_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
        # Gather the history before the scatter below overwrites any of it.
        x_new = torch.cat([conv_state.gather(2, history_idx), x], dim=-1).to(weight.dtype)
        write_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze(0) + offsets
        write_idx = torch.remainder(write_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
        conv_state.scatter_(2, write_idx, x)
    else:
        # Sliding window: append x and keep the last state_len positions.
        x_new = torch.cat([conv_state, x], dim=-1).to(weight.dtype)  # (batch, dim, state_len + seqlen)
        conv_state.copy_(x_new[:, :, -state_len:])

    conv_out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)
    out = conv_out[:, :, -seqlen:]
    if was_2d:
        out = out.squeeze(-1)
    if activation is not None:
        out = F.silu(out)
    return out.to(dtype=dtype_in)
build/torch29-cxx11-cu128-aarch64-linux/causal_conv1d/causal_conv1d_varlen.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import Tensor
3
+
4
+ import triton
5
+ import triton.language as tl
6
+
7
+
8
@triton.jit
def _causal_conv1d_varlen_states(
    X,
    CU_SEQLENS,
    STATES,
    state_len,
    dim,
    stride_x_seqlen, stride_x_dim,
    stride_states_batch, stride_states_seqlen, stride_states_dim,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr
):
    # Copies the trailing rows of one sequence of X into STATES[batch_idx].
    # Grid axes: 0 -> column tile, 1 -> row tile counted backwards from the
    # end of the sequence, 2 -> sequence (batch) index.
    pid_col = tl.program_id(0)
    pid_row = tl.program_id(1)
    batch_idx = tl.program_id(2)

    seq_end = tl.load(CU_SEQLENS + batch_idx + 1)
    # First row that may be copied: never before the start of this sequence
    # and never more than state_len rows before its end.
    first_row = tl.maximum(tl.load(CU_SEQLENS + batch_idx), seq_end - state_len)

    cols = pid_col * BLOCK_N + tl.arange(0, BLOCK_N)
    col_in_range = cols[None, :] < dim

    # Rows outside [first_row, seq_end) are masked off and read as 0.
    src_rows = seq_end - (pid_row + 1) * BLOCK_M + tl.arange(0, BLOCK_M)
    src_ptrs = X + src_rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_dim
    tile = tl.load(src_ptrs, mask=(src_rows[:, None] >= first_row) & col_in_range, other=0)

    # The same tile, addressed relative to the end of the state buffer.
    dst_rows = state_len - (pid_row + 1) * BLOCK_M + tl.arange(0, BLOCK_M)
    dst_ptrs = (
        STATES
        + batch_idx * stride_states_batch
        + dst_rows[:, None] * stride_states_seqlen
        + cols[None, :] * stride_states_dim
    )
    tl.store(dst_ptrs, tile, mask=(dst_rows[:, None] >= 0) & col_in_range)
33
+
34
+
35
def causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor:
    """
    Launch the Triton kernel that collects per-sequence states.
    Forward pass only, does not support backward pass.
    Parameters:
        x: (total_tokens, dim)
        cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0.
        state_len: int. For each cu_seqlens, how many elements from x should be copied to the state.
            If some of those elements belong to a different sequence, the value of the states will be zero.
    Return:
        states: (batch, dim, state_len)
    """
    _, dim = x.shape
    num_seqs = cu_seqlens.shape[0] - 1
    cu_seqlens = cu_seqlens.contiguous()
    # Allocated as (batch, state_len, dim), returned as a (batch, dim, state_len) view.
    states = torch.empty(num_seqs, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2)

    block_m = min(triton.next_power_of_2(state_len), 16)
    block_n = min(triton.next_power_of_2(dim), 256)
    grid = (triton.cdiv(dim, block_n), triton.cdiv(state_len, block_m), num_seqs)

    # On the transposed view, axis 2 is the state_len axis and axis 1 is the dim axis.
    states_batch_stride = states.stride(0)
    states_seqlen_stride = states.stride(2)
    states_dim_stride = states.stride(1)
    with torch.cuda.device(x.device.index):
        _causal_conv1d_varlen_states[grid](
            x,
            cu_seqlens,
            states,
            state_len,
            dim,
            x.stride(0), x.stride(1),
            states_batch_stride, states_seqlen_stride, states_dim_stride,
            BLOCK_M=block_m, BLOCK_N=block_n
        )
    return states
65
+
66
+
67
def causal_conv1d_varlen_states_ref(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor:
    """
    Pure PyTorch reference for collecting per-sequence states.
    Forward pass only, does not support backward pass.
    Parameters:
        x: (total_tokens, dim)
        cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0.
        state_len: int. For each cu_seqlens, how many elements from x should be copied to the state.
            If some of those elements belong to a different sequence, the value of the states will be zero.
    Return:
        states: (batch, dim, state_len). Sequences shorter than state_len are
        right-aligned and zero-filled on the left; an empty sequence yields an
        all-zero state.
    """
    _, dim = x.shape
    batch = cu_seqlens.shape[0] - 1
    # Read the boundaries as plain Python ints once.
    boundaries = cu_seqlens.tolist()
    states = torch.zeros(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2)
    for i in range(batch):
        end_idx = boundaries[i + 1]
        start_idx = max(boundaries[i], end_idx - state_len)
        num_copied = end_idx - start_idx
        if num_copied <= 0:
            # Nothing to copy. A negative-index slice would not work here:
            # "-0:" selects the whole state and the assignment would fail.
            continue
        states[i, :, state_len - num_copied:] = x[start_idx:end_idx].T
    return states
build/torch29-cxx11-cu128-aarch64-linux/causal_conv1d/cpp_functions.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024, Tri Dao.
2
+
3
+ import torch
4
+
5
+ from ._ops import ops
6
+
7
def causal_conv1d_fwd_function(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor | None,
    seq_idx: torch.Tensor | None,
    initial_states: torch.Tensor | None,
    final_states_out: torch.Tensor | None,
    silu_activation: bool,
) -> torch.Tensor:
    """
    Allocate the output buffer and call the ``causal_conv1d_fwd`` op.

    The op's return value is not used; it presumably writes its result into
    the buffer passed as ``out`` (and into ``final_states_out`` when given).
    Returns that buffer, which has the same shape, dtype and layout as ``x``.
    """
    result = torch.empty_like(x)
    kernel_kwargs = dict(
        x=x,
        weight=weight,
        bias=bias,
        seq_idx=seq_idx,
        initial_states=initial_states,
        out=result,
        final_states_out=final_states_out,
        silu_activation=silu_activation,
    )
    ops.causal_conv1d_fwd(**kernel_kwargs)
    return result
28
+
29
+
30
def causal_conv1d_bwd_function(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor | None,
    dout: torch.Tensor,
    seq_idx: torch.Tensor | None,
    initial_states: torch.Tensor | None,
    dfinal_states: torch.Tensor | None,
    dx: torch.Tensor | None,
    return_dinitial_states: bool,
    silu_activation: bool,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
    """
    Allocate the gradient buffers and call the ``causal_conv1d_bwd`` op.

    Parameters:
        dx: optional pre-allocated buffer for the input gradient; a new one
            shaped like ``x`` is created when None.
        return_dinitial_states: flag; when true a buffer for the gradient of
            ``initial_states`` is allocated and handed to the op.
        The remaining arguments are forwarded to the op unchanged.
    Return:
        (dx, dweight, dbias, dinitial_states). ``dbias`` is None when ``bias``
        is None; ``dinitial_states`` is None unless requested.
    """
    batch_size, dim = x.size()[:2]
    width = weight.size(-1)

    if dx is None:
        dx = torch.empty_like(x)
    # dweight and dbias are zero-initialized in float32 and cast back to the
    # parameter dtype after the op has run.
    dweight = torch.zeros_like(weight, dtype=torch.float32)
    dbias = None
    if bias is not None:
        dbias = torch.zeros_like(bias, dtype=torch.float32)
    dinitial_states = None
    if return_dinitial_states:
        # Allocated as (batch, width - 1, dim), viewed as (batch, dim, width - 1).
        dinitial_states = torch.empty(batch_size, width - 1, dim, device=x.device, dtype=x.dtype).transpose(1, 2)

    ops.causal_conv1d_bwd(
        x=x,
        weight=weight,
        bias=bias,
        dout=dout,
        seq_idx=seq_idx,
        initial_states=initial_states,
        dfinal_states=dfinal_states,
        dx=dx,
        dweight=dweight,
        dbias=dbias,
        dinitial_states=dinitial_states,
        silu_activation=silu_activation,
    )

    dweight = dweight.type_as(weight)
    if dbias is not None:
        dbias = dbias.type_as(bias)
    return dx, dweight, dbias, dinitial_states
74
+
75
+
76
def causal_conv1d_update_function(
    x: torch.Tensor,
    conv_state: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor | None,
    silu_activation: bool,
    cache_seqlens: torch.Tensor | None,
    conv_state_indices: torch.Tensor | None,
) -> torch.Tensor:
    """
    Allocate the output buffer and call the ``causal_conv1d_update`` op.

    The op's return value is not used; it presumably writes its result into
    the buffer passed as ``out``. Returns that buffer, which has the same
    shape, dtype and layout as ``x``.
    """
    result = torch.empty_like(x)
    kernel_kwargs = dict(
        x=x,
        conv_state=conv_state,
        weight=weight,
        bias=bias,
        out=result,
        silu_activation=silu_activation,
        cache_seqlens=cache_seqlens,
        conv_state_indices=conv_state_indices,
    )
    ops.causal_conv1d_update(**kernel_kwargs)
    return result
build/torch29-cxx11-cu130-aarch64-linux/causal_conv1d/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from .causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update
2
+ from .causal_conv1d_varlen import causal_conv1d_varlen_states
3
+
4
+ __all__ = ["causal_conv1d_fn", "causal_conv1d_update", "causal_conv1d_varlen_states"]
build/torch29-cxx11-cu130-aarch64-linux/causal_conv1d/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (391 Bytes). View file
 
build/torch29-cxx11-cu130-aarch64-linux/causal_conv1d/__pycache__/_ops.cpython-313.pyc ADDED
Binary file (536 Bytes). View file
 
build/torch29-cxx11-cu130-aarch64-linux/causal_conv1d/__pycache__/causal_conv1d_interface.cpython-313.pyc ADDED
Binary file (9.94 kB). View file
 
build/torch29-cxx11-cu130-aarch64-linux/causal_conv1d/__pycache__/causal_conv1d_varlen.cpython-313.pyc ADDED
Binary file (4.92 kB). View file
 
build/torch29-cxx11-cu130-aarch64-linux/causal_conv1d/__pycache__/cpp_functions.cpython-313.pyc ADDED
Binary file (3.62 kB). View file
 
build/torch29-cxx11-cu130-aarch64-linux/causal_conv1d/_causal_conv1d_306ae84.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a8f7ee76fbd178668417f36c383073d2aa2a20514ff05641f015ee9efaf6b1e6
3
+ size 58331720
build/torch29-cxx11-cu130-aarch64-linux/causal_conv1d/_ops.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from . import _causal_conv1d_306ae84
3
+ ops = torch.ops._causal_conv1d_306ae84
4
+
5
def add_op_namespace_prefix(op_name: str):
    """
    Return ``op_name`` qualified with this build's op namespace,
    in the form ``<namespace>::<op_name>``.
    """
    qualified_name = "_causal_conv1d_306ae84::{}".format(op_name)
    return qualified_name
build/torch29-cxx11-cu130-aarch64-linux/causal_conv1d/causal_conv1d_interface.py ADDED
@@ -0,0 +1,242 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024, Tri Dao.
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+
6
+ from .cpp_functions import causal_conv1d_fwd_function, causal_conv1d_bwd_function, causal_conv1d_update_function
7
+
8
+
9
class CausalConv1dFn(torch.autograd.Function):
    """
    Autograd function that routes the causal conv1d forward and backward passes
    through ``causal_conv1d_fwd_function`` / ``causal_conv1d_bwd_function``.
    """

    @staticmethod
    def forward(
        ctx,
        x,
        weight,
        bias=None,
        seq_idx=None,
        initial_states=None,
        return_final_states=False,
        final_states_out=None,
        activation=None,
    ):
        """
        Validate and normalize the inputs, run the forward kernel wrapper and
        stash what ``backward`` needs on ``ctx``.

        Returns ``out`` alone, or ``(out, final_states_out)`` when
        ``return_final_states`` is true.
        Raises NotImplementedError for an unsupported ``activation``.
        """
        if activation not in (None, "silu", "swish"):
            raise NotImplementedError("activation must be None, silu, or swish")

        # Copy only when neither of the last two dims has unit stride.
        if not (x.stride(2) == 1 or x.stride(1) == 1):
            x = x.contiguous()
        if bias is not None:
            bias = bias.contiguous()

        if seq_idx is not None:
            assert (
                initial_states is None
            ), "initial_states must be None if seq_idx is not None"
            assert (
                not return_final_states
            ), "If seq_idx is not None, we don't return final_states_out"
            seq_idx = seq_idx.contiguous()

        if initial_states is not None:
            # Same layout rule as for x.
            if initial_states.stride(2) != 1 and initial_states.stride(1) != 1:
                initial_states = initial_states.contiguous()

        if not return_final_states:
            # A caller-provided buffer is ignored when final states are not requested.
            final_states_out = None
        else:
            assert (
                x.stride(1) == 1
            ), "Only channel-last layout support returning final_states_out"
            if final_states_out is None:
                batch, dim, _ = x.shape
                width = weight.shape[1]
                # Allocated as (batch, width - 1, dim) and viewed as
                # (batch, dim, width - 1), so dim is the unit-stride axis.
                buffer = torch.empty(
                    batch, width - 1, dim, device=x.device, dtype=x.dtype
                )
                final_states_out = buffer.transpose(1, 2)
            else:
                assert (
                    final_states_out.stride(2) == 1 or final_states_out.stride(1) == 1
                )

        # The kernel wrapper takes a boolean flag rather than the activation name.
        ctx.activation = activation in ("silu", "swish")
        out = causal_conv1d_fwd_function(
            x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation
        )
        ctx.save_for_backward(x, weight, bias, seq_idx, initial_states)
        ctx.return_final_states = return_final_states
        ctx.return_dinitial_states = (
            initial_states is not None and initial_states.requires_grad
        )
        if return_final_states:
            return out, final_states_out
        return out

    @staticmethod
    def backward(ctx, dout, *args):
        """
        Compute gradients for ``x``, ``weight``, ``bias`` and ``initial_states``.

        ``args[0]`` is used as the gradient of the final states when the forward
        pass returned them. The result has one entry per forward input; inputs
        that are not differentiated get ``None``.
        """
        x, weight, bias, seq_idx, initial_states = ctx.saved_tensors
        dfinal_states = args[0] if ctx.return_final_states else None
        if not (dout.stride(2) == 1 or dout.stride(1) == 1):
            dout = dout.contiguous()
        # The kernel supports a pre-allocated dx (e.g. to fuse the backward of
        # conv1d with the backward of chunk). None is passed here, so the
        # wrapper allocates dx itself.
        grads = causal_conv1d_bwd_function(
            x,
            weight,
            bias,
            dout,
            seq_idx,
            initial_states,
            dfinal_states,
            None,
            ctx.return_dinitial_states,
            ctx.activation,
        )
        dx, dweight, dbias, dinitial_states = grads
        if bias is None:
            dbias = None
        if initial_states is None:
            dinitial_states = None
        # Order: x, weight, bias, seq_idx, initial_states,
        # return_final_states, final_states_out, activation.
        return dx, dweight, dbias, None, dinitial_states, None, None, None
97
+
98
+
99
def causal_conv1d_fn(
    x,
    weight,
    bias=None,
    seq_idx=None,
    initial_states=None,
    return_final_states=False,
    final_states_out=None,
    activation=None,
):
    """
    Functional entry point that forwards all arguments to ``CausalConv1dFn.apply``.

    x: (batch, dim, seqlen)
    weight: (dim, width)
    bias: (dim,)
    seq_idx: (batch, seqlen)
    initial_states: (batch, dim, width - 1)
    final_states_out: (batch, dim, width - 1), to be written to
    activation: either None or "silu" or "swish"

    out: (batch, dim, seqlen)
    """
    # Passed positionally, in the order of CausalConv1dFn.forward (after ctx).
    forward_args = (
        x,
        weight,
        bias,
        seq_idx,
        initial_states,
        return_final_states,
        final_states_out,
        activation,
    )
    return CausalConv1dFn.apply(*forward_args)
130
+
131
+
132
def causal_conv1d_ref(
    x,
    weight,
    bias=None,
    initial_states=None,
    return_final_states=False,
    final_states_out=None,
    activation=None,
):
    """
    Reference causal depthwise conv1d built on ``F.conv1d``.

    x: (batch, dim, seqlen)
    weight: (dim, width)
    bias: (dim,)
    initial_states: (batch, dim, width - 1)
    final_states_out: (batch, dim, width - 1); filled in place when given

    out: (batch, dim, seqlen), or (out, final_states_out) when
    return_final_states is true. Raises NotImplementedError for an
    unsupported activation.
    """
    if activation not in (None, "silu", "swish"):
        raise NotImplementedError("activation must be None, silu, or swish")

    dtype_in = x.dtype
    x = x.to(weight.dtype)
    seqlen = x.shape[-1]
    dim, width = weight.shape
    depthwise_kernel = weight.unsqueeze(1)

    if initial_states is not None:
        # The initial states take the place of the left zero padding.
        x = torch.cat([initial_states, x], dim=-1)
        padding = 0
    else:
        padding = width - 1
    out = F.conv1d(x, depthwise_kernel, bias, padding=padding, groups=dim)
    out = out[..., :seqlen]

    if return_final_states:
        # A negative left pad crops, a positive one zero-fills, so the result
        # always has width - 1 trailing positions.
        left_pad = width - 1 - x.shape[-1]
        tail = F.pad(x, (left_pad, 0)).to(dtype_in)  # (batch, dim, width - 1)
        if final_states_out is None:
            final_states_out = tail
        else:
            final_states_out.copy_(tail)

    if activation is not None:
        out = F.silu(out)
    out = out.to(dtype=dtype_in)

    if return_final_states:
        return out, final_states_out
    return out
172
+
173
+
174
def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None, conv_state_indices=None):
    """
    Run one update step through ``causal_conv1d_update_function``.

    x: (batch, dim) or (batch, dim, seqlen)
    conv_state: (batch, dim, state_len), where state_len >= width - 1
    weight: (dim, width)
    bias: (dim,)
    cache_seqlens: (batch,), dtype int32.
        If not None, the conv_state is treated as a circular buffer.
        The conv_state will be updated by copying x to the conv_state starting at the index
        @cache_seqlens % state_len.
    conv_state_indices: (batch,), dtype int32
        Presumably, when given, conv_state is a larger tensor along the batch dim
        and these indices select the batch coords to use (useful for continuous
        batching). The selection happens inside the kernel, so confirm there.

    out: (batch, dim) or (batch, dim, seqlen), matching the rank of x.
    Raises NotImplementedError for an unsupported activation.
    """
    if activation not in (None, "silu", "swish"):
        raise NotImplementedError("activation must be None, silu, or swish")
    use_silu = activation in ("silu", "swish")

    # The wrapper is always called with a 3-D x; a 2-D input gets a trailing
    # axis of length one that is removed again from the result.
    was_2d = x.dim() == 2
    x_3d = x.unsqueeze(-1) if was_2d else x
    out = causal_conv1d_update_function(
        x_3d, conv_state, weight, bias, use_silu, cache_seqlens, conv_state_indices
    )
    return out.squeeze(-1) if was_2d else out
203
+
204
+
205
def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None):
    """
    Reference update step built on ``F.conv1d``; ``conv_state`` is modified in place.

    x: (batch, dim) or (batch, dim, seqlen)
    conv_state: (batch, dim, state_len), where state_len >= width - 1
    weight: (dim, width)
    bias: (dim,)
    cache_seqlens: (batch,), dtype int32.
        If not None, the conv_state is treated as a circular buffer.
        The previous width - 1 entries are gathered first, then x is written to
        the conv_state starting at index @cache_seqlens % state_len.

    out: (batch, dim) or (batch, dim, seqlen), matching the rank of x.
    Raises NotImplementedError for an unsupported activation.
    """
    if activation not in (None, "silu", "swish"):
        raise NotImplementedError("activation must be None, silu, or swish")
    dtype_in = x.dtype
    was_2d = x.dim() == 2
    if was_2d:
        x = x.unsqueeze(-1)
    batch, dim, seqlen = x.shape
    width = weight.shape[1]
    state_len = conv_state.shape[-1]
    assert conv_state.shape == (batch, dim, state_len)
    assert weight.shape == (dim, width)

    if cache_seqlens is not None:
        offsets = cache_seqlens.unsqueeze(1)
        # Positions of the width - 1 entries that precede the write position.
        history_idx = torch.arange(-(width - 1), 0, dtype=torch.long, device=x.device).unsqueeze(0) + offsets
        history_idx = torch.remainder(history_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
        # Gather the history before the scatter below overwrites any of it.
        x_new = torch.cat([conv_state.gather(2, history_idx), x], dim=-1).to(weight.dtype)
        write_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze(0) + offsets
        write_idx = torch.remainder(write_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
        conv_state.scatter_(2, write_idx, x)
    else:
        # Sliding window: append x and keep the last state_len positions.
        x_new = torch.cat([conv_state, x], dim=-1).to(weight.dtype)  # (batch, dim, state_len + seqlen)
        conv_state.copy_(x_new[:, :, -state_len:])

    conv_out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)
    out = conv_out[:, :, -seqlen:]
    if was_2d:
        out = out.squeeze(-1)
    if activation is not None:
        out = F.silu(out)
    return out.to(dtype=dtype_in)
build/torch29-cxx11-cu130-aarch64-linux/causal_conv1d/causal_conv1d_varlen.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import Tensor
3
+
4
+ import triton
5
+ import triton.language as tl
6
+
7
+
8
@triton.jit
def _causal_conv1d_varlen_states(
    X,
    CU_SEQLENS,
    STATES,
    state_len,
    dim,
    stride_x_seqlen, stride_x_dim,
    stride_states_batch, stride_states_seqlen, stride_states_dim,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr
):
    # Copies the trailing rows of one sequence of X into STATES[batch_idx].
    # Grid axes: 0 -> column tile, 1 -> row tile counted backwards from the
    # end of the sequence, 2 -> sequence (batch) index.
    pid_col = tl.program_id(0)
    pid_row = tl.program_id(1)
    batch_idx = tl.program_id(2)

    seq_end = tl.load(CU_SEQLENS + batch_idx + 1)
    # First row that may be copied: never before the start of this sequence
    # and never more than state_len rows before its end.
    first_row = tl.maximum(tl.load(CU_SEQLENS + batch_idx), seq_end - state_len)

    cols = pid_col * BLOCK_N + tl.arange(0, BLOCK_N)
    col_in_range = cols[None, :] < dim

    # Rows outside [first_row, seq_end) are masked off and read as 0.
    src_rows = seq_end - (pid_row + 1) * BLOCK_M + tl.arange(0, BLOCK_M)
    src_ptrs = X + src_rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_dim
    tile = tl.load(src_ptrs, mask=(src_rows[:, None] >= first_row) & col_in_range, other=0)

    # The same tile, addressed relative to the end of the state buffer.
    dst_rows = state_len - (pid_row + 1) * BLOCK_M + tl.arange(0, BLOCK_M)
    dst_ptrs = (
        STATES
        + batch_idx * stride_states_batch
        + dst_rows[:, None] * stride_states_seqlen
        + cols[None, :] * stride_states_dim
    )
    tl.store(dst_ptrs, tile, mask=(dst_rows[:, None] >= 0) & col_in_range)
33
+
34
+
35
def causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor:
    """
    Launch the Triton kernel that collects per-sequence states.
    Forward pass only, does not support backward pass.
    Parameters:
        x: (total_tokens, dim)
        cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0.
        state_len: int. For each cu_seqlens, how many elements from x should be copied to the state.
            If some of those elements belong to a different sequence, the value of the states will be zero.
    Return:
        states: (batch, dim, state_len)
    """
    _, dim = x.shape
    num_seqs = cu_seqlens.shape[0] - 1
    cu_seqlens = cu_seqlens.contiguous()
    # Allocated as (batch, state_len, dim), returned as a (batch, dim, state_len) view.
    states = torch.empty(num_seqs, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2)

    block_m = min(triton.next_power_of_2(state_len), 16)
    block_n = min(triton.next_power_of_2(dim), 256)
    grid = (triton.cdiv(dim, block_n), triton.cdiv(state_len, block_m), num_seqs)

    # On the transposed view, axis 2 is the state_len axis and axis 1 is the dim axis.
    states_batch_stride = states.stride(0)
    states_seqlen_stride = states.stride(2)
    states_dim_stride = states.stride(1)
    with torch.cuda.device(x.device.index):
        _causal_conv1d_varlen_states[grid](
            x,
            cu_seqlens,
            states,
            state_len,
            dim,
            x.stride(0), x.stride(1),
            states_batch_stride, states_seqlen_stride, states_dim_stride,
            BLOCK_M=block_m, BLOCK_N=block_n
        )
    return states
65
+
66
+
67
def causal_conv1d_varlen_states_ref(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor:
    """
    Pure PyTorch reference for collecting per-sequence states.
    Forward pass only, does not support backward pass.
    Parameters:
        x: (total_tokens, dim)
        cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0.
        state_len: int. For each cu_seqlens, how many elements from x should be copied to the state.
            If some of those elements belong to a different sequence, the value of the states will be zero.
    Return:
        states: (batch, dim, state_len). Sequences shorter than state_len are
        right-aligned and zero-filled on the left; an empty sequence yields an
        all-zero state.
    """
    _, dim = x.shape
    batch = cu_seqlens.shape[0] - 1
    # Read the boundaries as plain Python ints once.
    boundaries = cu_seqlens.tolist()
    states = torch.zeros(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2)
    for i in range(batch):
        end_idx = boundaries[i + 1]
        start_idx = max(boundaries[i], end_idx - state_len)
        num_copied = end_idx - start_idx
        if num_copied <= 0:
            # Nothing to copy. A negative-index slice would not work here:
            # "-0:" selects the whole state and the assignment would fail.
            continue
        states[i, :, state_len - num_copied:] = x[start_idx:end_idx].T
    return states
build/torch29-cxx11-cu130-aarch64-linux/causal_conv1d/cpp_functions.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024, Tri Dao.
2
+
3
+ import torch
4
+
5
+ from ._ops import ops
6
+
7
def causal_conv1d_fwd_function(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor | None,
    seq_idx: torch.Tensor | None,
    initial_states: torch.Tensor | None,
    final_states_out: torch.Tensor | None,
    silu_activation: bool,
) -> torch.Tensor:
    """
    Allocate the output buffer and call the ``causal_conv1d_fwd`` op.

    The op's return value is not used; it presumably writes its result into
    the buffer passed as ``out`` (and into ``final_states_out`` when given).
    Returns that buffer, which has the same shape, dtype and layout as ``x``.
    """
    result = torch.empty_like(x)
    kernel_kwargs = dict(
        x=x,
        weight=weight,
        bias=bias,
        seq_idx=seq_idx,
        initial_states=initial_states,
        out=result,
        final_states_out=final_states_out,
        silu_activation=silu_activation,
    )
    ops.causal_conv1d_fwd(**kernel_kwargs)
    return result
28
+
29
+
30
def causal_conv1d_bwd_function(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor | None,
    dout: torch.Tensor,
    seq_idx: torch.Tensor | None,
    initial_states: torch.Tensor | None,
    dfinal_states: torch.Tensor | None,
    dx: torch.Tensor | None,
    return_dinitial_states: bool,
    silu_activation: bool,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
    """
    Allocate the gradient buffers and call the ``causal_conv1d_bwd`` op.

    Parameters:
        dx: optional pre-allocated buffer for the input gradient; a new one
            shaped like ``x`` is created when None.
        return_dinitial_states: flag; when true a buffer for the gradient of
            ``initial_states`` is allocated and handed to the op.
        The remaining arguments are forwarded to the op unchanged.
    Return:
        (dx, dweight, dbias, dinitial_states). ``dbias`` is None when ``bias``
        is None; ``dinitial_states`` is None unless requested.
    """
    batch_size, dim = x.size()[:2]
    width = weight.size(-1)

    if dx is None:
        dx = torch.empty_like(x)
    # dweight and dbias are zero-initialized in float32 and cast back to the
    # parameter dtype after the op has run.
    dweight = torch.zeros_like(weight, dtype=torch.float32)
    dbias = None
    if bias is not None:
        dbias = torch.zeros_like(bias, dtype=torch.float32)
    dinitial_states = None
    if return_dinitial_states:
        # Allocated as (batch, width - 1, dim), viewed as (batch, dim, width - 1).
        dinitial_states = torch.empty(batch_size, width - 1, dim, device=x.device, dtype=x.dtype).transpose(1, 2)

    ops.causal_conv1d_bwd(
        x=x,
        weight=weight,
        bias=bias,
        dout=dout,
        seq_idx=seq_idx,
        initial_states=initial_states,
        dfinal_states=dfinal_states,
        dx=dx,
        dweight=dweight,
        dbias=dbias,
        dinitial_states=dinitial_states,
        silu_activation=silu_activation,
    )

    dweight = dweight.type_as(weight)
    if dbias is not None:
        dbias = dbias.type_as(bias)
    return dx, dweight, dbias, dinitial_states
74
+
75
+
76
def causal_conv1d_update_function(
    x: torch.Tensor,
    conv_state: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor | None,
    silu_activation: bool,
    cache_seqlens: torch.Tensor | None,
    conv_state_indices: torch.Tensor | None,
) -> torch.Tensor:
    """
    Allocate the output buffer and call the ``causal_conv1d_update`` op.

    The op's return value is not used; it presumably writes its result into
    the buffer passed as ``out``. Returns that buffer, which has the same
    shape, dtype and layout as ``x``.
    """
    result = torch.empty_like(x)
    kernel_kwargs = dict(
        x=x,
        conv_state=conv_state,
        weight=weight,
        bias=bias,
        out=result,
        silu_activation=silu_activation,
        cache_seqlens=cache_seqlens,
        conv_state_indices=conv_state_indices,
    )
    ops.causal_conv1d_update(**kernel_kwargs)
    return result