Regularization¶
from noxton.nn import StochasticDepth
StochasticDepth¶
Stochastic Depth (DropPath) regularisation layer (Huang et al., 2016). During training, randomly drops the entire input tensor (mode="batch") or individual rows (mode="row") with probability p. Surviving elements are rescaled by 1 / (1 - p) to preserve expected values. At inference the layer is a no-op.
Constructor¶
StochasticDepth(
p: float,
mode: str, # "batch" or "row"
inference: bool = False,
)
| Parameter | Description |
|---|---|
p |
Drop probability in [0, 1]. |
mode |
"batch" — single mask broadcast over whole tensor; "row" — independent mask per row. |
inference |
Always pass through unchanged. Default False. |
__call__¶
sd(input: Array, key: PRNGKeyArray) -> Array
Example¶
import jax
import jax.numpy as jnp
from noxton.nn import StochasticDepth
key = jax.random.PRNGKey(0)
sd = StochasticDepth(p=0.2, mode="row")
x = jnp.ones((4, 8))
out = sd(x, key)
# out.shape -> (4, 8)
# during training some rows may be zeroed out
Tip
To use in a residual block, apply stochastic depth to the residual branch:
python
x = x + sd(residual_branch(x), key)