"""
Abovo: A C++ neural network engine with Python bindings for educational performance optimization.
"""
__version__ = "0.1.4"
from _abovo import (
Matrix as _Matrix,
DenseLayer as _DenseLayer,
Sequential as _Sequential,
LossType,
ActivationType,
MatMulType,
)
[docs]
class Matrix(_Matrix):
"""
A 2D matrix supporting basic operations like transpose, min/max lookup, and element access.
Args:
rows (int): Number of rows.
cols (int): Number of columns.
"""
[docs]
def get_rows(self) -> int:
"""Returns the number of rows in the matrix."""
return super().get_rows()
[docs]
def get_cols(self) -> int:
"""Returns the number of columns in the matrix."""
return super().get_cols()
[docs]
def transpose(self) -> "Matrix":
"""Returns a new matrix that is the transpose of this one."""
return super().transpose()
[docs]
def get_min(self) -> float:
"""Returns the minimum value in the matrix."""
return super().get_min()
[docs]
def get_max(self) -> float:
"""Returns the maximum value in the matrix."""
return super().get_max()
[docs]
def print(self):
"""Prints the matrix to stdout."""
return super().print()
[docs]
class DenseLayer(_DenseLayer):
"""
A fully connected neural network layer with optional quantization and activation.
Args:
input_size (int): Number of input features.
output_size (int): Number of output features.
activation (ActivationType): Activation function to apply.
"""
[docs]
def forward(self, input: Matrix) -> Matrix:
"""Performs a forward pass through the layer."""
return super().forward(input)
[docs]
def quantize(self):
"""Quantizes the layer weights (useful for efficient inference)."""
return super().quantize()
[docs]
def dequantize(self):
"""Dequantizes the layer weights."""
return super().dequantize()
[docs]
def is_quantized(self) -> bool:
"""Returns whether the layer is currently quantized."""
return super().is_quantized()
[docs]
def print(self):
"""Prints the layer's weights and configuration."""
return super().print()
[docs]
class Sequential(_Sequential):
"""
A sequential model for stacking layers and training.
Methods:
add(layer): Add a DenseLayer to the model.
forward(X): Perform a forward pass.
train(...): Train the model with given parameters.
evaluate(...): Evaluate model accuracy.
quantize_all(): Quantize all layers.
dequantize_all(): Revert quantization.
enable_qat(): Enable quantization-aware training.
enable_adam(...): Use Adam optimizer.
"""
[docs]
def add(self, layer: DenseLayer):
"""Adds a new DenseLayer to the model."""
return super().add(layer)
[docs]
def forward(self, X: Matrix) -> Matrix:
"""Feeds input X through the model and returns the output."""
return super().forward(X)
[docs]
def train(self, X, y, epochs, batch_size, learning_rate, loss_type=LossType.MSE):
"""
Trains the model on data using gradient descent.
Args:
X (Matrix): Input data.
y (Matrix): Target labels.
epochs (int): Number of training epochs.
batch_size (int): Batch size.
learning_rate (float): Learning rate.
loss_type (LossType): Loss function to use.
"""
return super().train(X, y, epochs, batch_size, learning_rate, loss_type)
[docs]
def evaluate(self, X, y):
"""Evaluates model performance on test data."""
return super().evaluate(X, y)
[docs]
def quantize_all(self, per_channel=True):
"""Quantizes all layers in the model."""
return super().quantize_all(per_channel)
[docs]
def dequantize_all(self):
"""Dequantizes all layers."""
return super().dequantize_all()
[docs]
def enable_qat(self):
"""Enables quantization-aware training."""
return super().enable_qat()
[docs]
def enable_adam(self, enable=True, beta1=0.9, beta2=0.999, epsilon=1e-8):
"""
Enables Adam optimizer with given hyperparameters.
Args:
enable (bool): Whether to use Adam.
beta1 (float): Exponential decay rate for 1st moment.
beta2 (float): Exponential decay rate for 2nd moment.
epsilon (float): Small value to avoid division by zero.
"""
return super().enable_adam(enable, beta1, beta2, epsilon)
[docs]
def print(self):
"""Prints the model structure and parameters."""
return super().print()
__all__ = [
"Matrix",
"DenseLayer",
"Sequential",
"LossType",
"ActivationType",
"MatMulType",
]