diff options
author | Zhongheng Liu <z.liu@outlook.com.gr> | 2025-01-25 14:36:32 +0200 |
---|---|---|
committer | Zhongheng Liu <z.liu@outlook.com.gr> | 2025-01-25 14:36:32 +0200 |
commit | 3c2dbe0d01bcf6d7597d9d418f0734c3e3605bcf (patch) | |
tree | e735d9112b6f8faa43bbeadc131cf4aa4f97af0f | |
parent | dd32b816d3680e0882177064c07955dbf491d2bc (diff) | |
download | matrix-rs-3c2dbe0d01bcf6d7597d9d418f0734c3e3605bcf.tar.gz matrix-rs-3c2dbe0d01bcf6d7597d9d418f0734c3e3605bcf.tar.bz2 matrix-rs-3c2dbe0d01bcf6d7597d9d418f0734c3e3605bcf.zip |
feat: support float and better test case
tests: use function to generate data, in future can use JSON test data
record to do it
matrix: use f32 type
-rw-r--r-- | src/lib.rs | 2 | ||||
-rw-r--r-- | src/matrix.rs | 105 | ||||
-rw-r--r-- | src/tests/matrix_test_ops.rs | 56 | ||||
-rw-r--r-- | src/tests/matrix_test_parse.rs | 2 |
4 files changed, 124 insertions, 41 deletions
@@ -28,7 +28,7 @@ pub use matrix::{Matrix, MatrixMath}; pub fn test() { println!("Testing code here"); - let m = Matrix::from(vec![1,2,3,4,5]); + let m = Matrix::from(vec![1.0,2.0,3.0,4.0,5.0]); m.transpose(); m.determinant(); } diff --git a/src/matrix.rs b/src/matrix.rs index 04a01bf..c324c36 100644 --- a/src/matrix.rs +++ b/src/matrix.rs @@ -10,14 +10,14 @@ //! println!("m1 + m2 =\n{}", m_add); //! ``` //! TODO:: Create matrix multiplication method - +use core::ops::AddAssign; use crate::error::{MatrixSetValueError, ParseMatrixError}; use std::{ fmt::Display, ops::{Add, Mul, Sub}, str::FromStr, }; -#[derive(Debug, PartialEq, Eq)] +#[derive(Debug, PartialEq)] pub struct Matrix { /// Number of rows in matrix. pub nrows: usize, @@ -26,12 +26,16 @@ pub struct Matrix { pub ncols: usize, /// Data stored in the matrix, you should not access this directly - data: Vec<Vec<i32>>, + data: Vec<Vec<f32>>, } pub trait MatrixMath { - fn inverse(&self) -> Matrix { - (1 / (self.determinant())) * &self.adjoint() + fn inverse(&self) -> Option<Matrix> { + let det_m = self.determinant(); + if det_m == 0.0 { + return None; + } + Some((1.0 / det_m) * &self.adjoint()) } /// Finds the matrix of cofactors for any N-by-N matrix fn cofactor(&self) -> Matrix { @@ -42,7 +46,7 @@ pub trait MatrixMath { todo!(); } /// Finds the determinant of any N-by-N matrix. - fn determinant(&self) -> i32 { + fn determinant(&self) -> f32 { todo!(); } /// Finds the transpose of any matrix. @@ -55,32 +59,58 @@ pub trait MatrixMath { } } impl MatrixMath for Matrix { + fn cofactor(&self) -> Matrix { + let mut d: Vec<Vec<f32>> = Vec::new(); + for (i, r) in self.data.iter().enumerate() { + let mut nr: Vec<f32> = Vec::new(); + for (j, v) in r.iter().enumerate() { + let count = self.ncols * i + j; + let nv = if count % 2 == 0 { -*v } else { *v }; + nr.push(nv); + } + d.push(nr); + } + Matrix::new(d) + } + fn minor(&self) -> Matrix { + let mut d: Vec<Vec<f32>> = Vec::new(); + for (i, r) in self.data.iter().enumerate() { + let mut nr: Vec<f32> = Vec::new(); + for (j, v) in r.iter().enumerate() { + let count = self.ncols * i + j; + let nv = if count % 2 == 0 { -*v } else { *v }; + nr.push(nv * self.splice(j, i).determinant()); + } + d.push(nr); + } + Matrix::new(d) + } /// Evaluates any N-by-N matrix. /// /// This function panics if the matrix is not square! - fn determinant(&self) -> i32 { + fn determinant(&self) -> f32 { if !self.is_square() { panic!() }; if self.nrows == 2 && self.ncols == 2 { return self.data[0][0] * self.data[1][1] - self.data[0][1] * self.data[1][0]; } - let mut tmp = 0; + let mut tmp: f32 = 0.0; for (i, n) in self.data[0].iter().enumerate() { let mult = if i % 2 == 0 { -*n } else { *n }; - let eval = self.splice(i).determinant(); - tmp += mult * eval; + let eval = self.splice(i, 0).determinant(); + tmp = tmp + mult * f32::from(eval); } - tmp + tmp.into() } /// Evaluates the tranpose of the matrix. /// /// Each row becomes a column, each column becomes a row. fn transpose(&self) -> Matrix { - let mut new_data = Vec::<Vec<i32>>::new(); + let mut new_data = Vec::<Vec<f32>>::new(); for i in 0..self.nrows { - let mut new_row = Vec::<i32>::new(); + let mut new_row = Vec::<f32>::new(); for j in 0..self.ncols { new_row.push(self.data[j][i]); } @@ -100,17 +130,26 @@ impl Matrix { /// /// TODOs /// - Add row length check - pub fn new(data: Vec<Vec<i32>>) -> Matrix { + pub fn new(data: Vec<Vec<f32>>) -> Matrix { + let mut d: Vec<Vec<f32>> = Vec::new(); + + for r in &data { + let mut nr = vec![]; + for x in r { + nr.push(*x); + } + d.push(nr); + } Matrix { nrows: data.len(), ncols: data[0].len(), - data, + data: d, } } /// Query one element at selected position. /// /// Returns `None` if index is out of bounds. - pub fn get(&self, row_index: usize, column_index: usize) -> Option<i32> { + pub fn get(&self, row_index: usize, column_index: usize) -> Option<f32> { let r = self.data.get(row_index)?; let n = r.get(column_index)?; Some(*n) @@ -123,7 +162,7 @@ impl Matrix { &mut self, row_index: usize, column_index: usize, - new_data: i32, + new_data: f32, ) -> Result<(), MatrixSetValueError> { self.data[row_index][column_index] = new_data; Ok(()) @@ -133,13 +172,13 @@ impl Matrix { pub fn is_square(&self) -> bool { self.nrows == self.ncols } - pub fn splice(&self, at_index: usize) -> Matrix { - let mut data: Vec<Vec<i32>> = Vec::new(); + fn splice(&self, at_index: usize, at_row: usize) -> Matrix { + let mut data: Vec<Vec<f32>> = Vec::new(); for i in 0..self.data.len() { - if i == 0 { + if i == at_row { continue; } - let mut r: Vec<i32> = Vec::new(); + let mut r: Vec<f32> = Vec::new(); for j in 0..self.data[i].len() { if j == at_index { continue; @@ -154,12 +193,12 @@ impl Matrix { impl FromStr for Matrix { type Err = ParseMatrixError; fn from_str(s: &str) -> Result<Self, Self::Err> { - let mut d: Vec<Vec<i32>> = Vec::new(); + let mut d: Vec<Vec<f32>> = Vec::new(); let rows_iter = s.split('\n'); for txt in rows_iter { - let mut r: Vec<i32> = Vec::new(); + let mut r: Vec<f32> = Vec::new(); for ch in txt.split(',') { - let parsed = match i32::from_str(ch) { + let parsed = match f32::from_str(ch) { Ok(n) => Ok(n), Err(_e) => Err(ParseMatrixError), }; @@ -218,12 +257,12 @@ impl<'a, 'b> Sub<&'b Matrix> for &'a Matrix { todo!() } } -impl<'a> Mul<&'a Matrix> for i32 { +impl<'a> Mul<&'a Matrix> for f32 { type Output = Matrix; fn mul(self, rhs: &'a Matrix) -> Self::Output { - let mut d: Vec<Vec<i32>> = Vec::new(); + let mut d: Vec<Vec<f32>> = Vec::new(); for r in &rhs.data { - let mut nr: Vec<i32> = Vec::new(); + let mut nr: Vec<f32> = Vec::new(); for v in r { nr.push(self * v); } @@ -235,21 +274,21 @@ impl<'a> Mul<&'a Matrix> for i32 { impl<'a, 'b> Mul<&'b Matrix> for &'a Matrix { type Output = Matrix; fn mul(self, rhs: &'b Matrix) -> Self::Output { - fn reduce(lhs: &Matrix, rhs: &Matrix, at_r: usize, at_c: usize) -> i32 { - let mut tmp = 0; + fn reduce(lhs: &Matrix, rhs: &Matrix, at_r: usize, at_c: usize) -> f32 { + let mut tmp = 0.0; for i in 0..lhs.ncols { tmp += lhs.get(at_r, i).unwrap() * rhs.get(i, at_c).unwrap(); } tmp } - let mut d: Vec<Vec<i32>> = Vec::new(); + let mut d: Vec<Vec<f32>> = Vec::new(); if self.ncols != rhs.nrows { println!("LHS: \n{}RHS: \n{}", self, rhs); println!("LHS nrows: {} ;; RHS ncols: {}", self.nrows, rhs.ncols); panic!() } for i in 0..self.nrows { - let mut r: Vec<i32> = Vec::new(); + let mut r: Vec<f32> = Vec::new(); for j in 0..rhs.ncols { r.push(reduce(self, rhs, i, j)); } @@ -259,8 +298,8 @@ impl<'a, 'b> Mul<&'b Matrix> for &'a Matrix { } } -impl From<Vec<i32>> for Matrix { - fn from(value: Vec<i32>) -> Self { +impl From<Vec<f32>> for Matrix { + fn from(value: Vec<f32>) -> Self { Matrix { nrows: value.len(), ncols: 1, diff --git a/src/tests/matrix_test_ops.rs b/src/tests/matrix_test_ops.rs index e207c3e..2737ac4 100644 --- a/src/tests/matrix_test_ops.rs +++ b/src/tests/matrix_test_ops.rs @@ -2,23 +2,67 @@ use std::str::FromStr; use crate::{error::ParseMatrixError, matrix::Matrix, MatrixMath}; +enum TestCaseType { + Add, + Mul, + Inv, + CmpErr, +} + +struct TestCase { + test_type: TestCaseType, + test_data: Vec<Matrix>, +} +fn build_add_test_cases() -> Vec<TestCase> { + let mut v = vec![]; + let from_strs = vec![ + "1,2,3\n4,5,6\n7,8,9", + "1,1,1\n1,1,1\n1,1,1", + "2,3,4\n5,6,7\n8,9,10", + + + "1,1,1\n1,1,1\n1,1,1", + "0,0,0\n0,0,0\n0,0,0", + "1,1,1\n1,1,1\n1,1,1", + ]; + let mut i = 0; + while i < from_strs.len() { + let m1 = Matrix::from_str(from_strs[i]).unwrap(); + let m2 = Matrix::from_str(from_strs[i+1]).unwrap(); + let mr = Matrix::from_str(from_strs[i+2]).unwrap(); + v.push(TestCase { + test_type: TestCaseType::Add, + test_data: vec![m1, m2, mr], + }); + i += 3; + } + v +} #[test] pub fn test_matrix_add() -> Result<(), ParseMatrixError> { - let m1 = Matrix::from_str("1,2,3\n4,5,6\n7,8,9")?; - let m2 = Matrix::from_str("1,1,1\n1,1,1\n1,1,1")?; - let t = Matrix::from_str("2,3,4\n5,6,7\n8,9,10")?; - assert_eq!(&m1 + &m2, t); + let cases = build_add_test_cases(); + for case in cases { + assert_eq!(&case.test_data[0] + &case.test_data[1], case.test_data[2]); + } Ok(()) } #[test] pub fn test_matrix_determinate() -> Result<(), ParseMatrixError> { let m = Matrix::from_str("3,4\n5,6")?; - let det = 3 * 6 - 4 * 5; + let det = 3.0 * 6.0 - 4.0 * 5.0; assert_eq!(m.determinant(), det); Ok(()) } #[test] -pub fn test_matrix_transposition() -> Result<(), ParseMatrixError> { +pub fn test_matrix_inverse_on_singular() -> Result<(), ()> { + let m = Matrix::new(vec![vec![1.0,2.0,3.0], vec![4.0,5.0,6.0], vec![7.0,8.0,9.0]]); + match m.inverse() { + Some(_inverse) => Err(()), + None => Ok(()), + } +} +#[test] +pub fn test_matrix_transpose() -> Result<(), ParseMatrixError> { let m = Matrix::from_str("1,2,3\n4,5,6\n7,8,9")?; let t = Matrix::from_str("1,4,7\n2,5,8\n3,6,9")?; assert_eq!(m.transpose(), t); diff --git a/src/tests/matrix_test_parse.rs b/src/tests/matrix_test_parse.rs index a31851a..771eea7 100644 --- a/src/tests/matrix_test_parse.rs +++ b/src/tests/matrix_test_parse.rs @@ -4,7 +4,7 @@ use crate::{matrix::Matrix, error::ParseMatrixError}; #[test] pub fn test_matrix_init_from_string() -> Result<(), ParseMatrixError> { - let data_target = vec![vec![1, 2, 3], vec![4, 5, 6], vec![7, 8, 9]]; + let data_target = vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0], vec![7.0, 8.0, 9.0]]; let target = Matrix::new(data_target); let test = Matrix::from_str("1,2,3\n4,5,6\n7,8,9")?; assert_eq!(target, test); |