diff options
author | Zhongheng Liu <z.liu@outlook.com.gr> | 2025-02-14 22:25:21 +0200 |
---|---|---|
committer | Zhongheng Liu <z.liu@outlook.com.gr> | 2025-02-14 22:25:21 +0200 |
commit | 65ecac02461d810fa851b80807289d3ae7c1495d (patch) | |
tree | adf411090238cbc5cb5ea75bed5294a4505b37dd | |
parent | ef9f07b1963e4a4b86f657d3981d38fc378f4073 (diff) | |
download | matrix-rs-65ecac02461d810fa851b80807289d3ae7c1495d.tar.gz matrix-rs-65ecac02461d810fa851b80807289d3ae7c1495d.tar.bz2 matrix-rs-65ecac02461d810fa851b80807289d3ae7c1495d.zip |
feat: create subtraction of matrices
-rw-r--r-- | src/matrix.rs | 28 | ||||
-rw-r--r-- | src/tests/matrix_test_ops.rs | 46 |
2 files changed, 58 insertions, 16 deletions
diff --git a/src/matrix.rs b/src/matrix.rs index c324c36..a81fcf0 100644 --- a/src/matrix.rs +++ b/src/matrix.rs @@ -10,8 +10,8 @@ //! println!("m1 + m2 =\n{}", m_add); //! ``` //! TODO:: Create matrix multiplication method -use core::ops::AddAssign; use crate::error::{MatrixSetValueError, ParseMatrixError}; +use core::ops::AddAssign; use std::{ fmt::Display, ops::{Add, Mul, Sub}, @@ -146,6 +146,9 @@ impl Matrix { data: d, } } + pub fn same_dimensions(&self, other: &Matrix) -> bool { + self.nrows == other.nrows && self.ncols == other.ncols + } /// Query one element at selected position. /// /// Returns `None` if index is out of bounds. @@ -172,6 +175,21 @@ impl Matrix { pub fn is_square(&self) -> bool { self.nrows == self.ncols } + pub fn permute_op(&self, op: fn(f32) -> f32) -> Matrix { + let mut i = 0; + let mut j = 0; + let mut rs: Vec<Vec<f32>> = Vec::new(); + while i < self.nrows { + let mut r: Vec<f32> = Vec::new(); + while j < self.ncols { + r.push(op(self.get(i, j).unwrap())); + j += 1; + } + rs.push(r); + i += 1; + } + Matrix::new(rs) + } fn splice(&self, at_index: usize, at_row: usize) -> Matrix { let mut data: Vec<Vec<f32>> = Vec::new(); for i in 0..self.data.len() { @@ -235,7 +253,7 @@ impl Display for Matrix { impl<'a, 'b> Add<&'b Matrix> for &'a Matrix { type Output = Matrix; fn add(self, rhs: &'b Matrix) -> Self::Output { - if (self.nrows != rhs.nrows) || (self.ncols != rhs.ncols) { + if !self.same_dimensions(rhs) { panic!("Cannot add two matrices with different dimensions"); } let mut x = Matrix { @@ -254,7 +272,11 @@ impl<'a, 'b> Add<&'b Matrix> for &'a Matrix { impl<'a, 'b> Sub<&'b Matrix> for &'a Matrix { type Output = Matrix; fn sub(self, rhs: &'b Matrix) -> Self::Output { - todo!() + if !self.same_dimensions(rhs) { + panic!("Cannot subtract two matrices with different dimensions"); + } + let neg = rhs.permute_op(|x| -x); + self - &neg } } impl<'a> Mul<&'a Matrix> for f32 { diff --git a/src/tests/matrix_test_ops.rs b/src/tests/matrix_test_ops.rs index 2737ac4..6933ac2 100644 --- a/src/tests/matrix_test_ops.rs +++ b/src/tests/matrix_test_ops.rs @@ -13,14 +13,35 @@ struct TestCase { test_type: TestCaseType, test_data: Vec<Matrix>, } +fn build_mul_test_cases() -> Vec<TestCase> { + let mut v = vec![]; + let from_strs = vec![ + "1,2\n1,2", + "1,3\n2,4", + "5,11\n5,11", + "1,2\n3,4", + "1\n2", + "5\n11", + ]; + let mut i = 0; + while i < from_strs.len() { + let m1 = Matrix::from_str(from_strs[i]).unwrap(); + let m2 = Matrix::from_str(from_strs[i + 1]).unwrap(); + let re = Matrix::from_str(from_strs[i + 2]).unwrap(); + v.push(TestCase { + test_type: TestCaseType::Mul, + test_data: vec![m1, m2, re], + }); + i += 3; + } + v +} fn build_add_test_cases() -> Vec<TestCase> { let mut v = vec![]; let from_strs = vec![ "1,2,3\n4,5,6\n7,8,9", "1,1,1\n1,1,1\n1,1,1", "2,3,4\n5,6,7\n8,9,10", - - "1,1,1\n1,1,1\n1,1,1", "0,0,0\n0,0,0\n0,0,0", "1,1,1\n1,1,1\n1,1,1", @@ -28,8 +49,8 @@ fn build_add_test_cases() -> Vec<TestCase> { let mut i = 0; while i < from_strs.len() { let m1 = Matrix::from_str(from_strs[i]).unwrap(); - let m2 = Matrix::from_str(from_strs[i+1]).unwrap(); - let mr = Matrix::from_str(from_strs[i+2]).unwrap(); + let m2 = Matrix::from_str(from_strs[i + 1]).unwrap(); + let mr = Matrix::from_str(from_strs[i + 2]).unwrap(); v.push(TestCase { test_type: TestCaseType::Add, test_data: vec![m1, m2, mr], @@ -55,7 +76,11 @@ pub fn test_matrix_determinate() -> Result<(), ParseMatrixError> { } #[test] pub fn test_matrix_inverse_on_singular() -> Result<(), ()> { - let m = Matrix::new(vec![vec![1.0,2.0,3.0], vec![4.0,5.0,6.0], vec![7.0,8.0,9.0]]); + let m = Matrix::new(vec![ + vec![1.0, 2.0, 3.0], + vec![4.0, 5.0, 6.0], + vec![7.0, 8.0, 9.0], + ]); match m.inverse() { Some(_inverse) => Err(()), None => Ok(()), @@ -70,14 +95,9 @@ pub fn test_matrix_transpose() -> Result<(), ParseMatrixError> { } #[test] pub fn test_matrix_mul() -> Result<(), ParseMatrixError> { - let m1 = Matrix::from_str("1,2\n1,2")?; - let m2 = Matrix::from_str("1,3\n2,4")?; - let m3 = Matrix::from_str("1,2\n3,4")?; - let m4 = Matrix::from_str("1\n2")?; - let t1 = Matrix::from_str("5,11\n5,11")?; - let t2 = Matrix::from_str("5\n11")?; - assert_eq!(&m1 * &m2, t1); - assert_eq!(&m3 * &m4, t2); + for case in build_mul_test_cases() { + assert_eq!(&case.test_data[0] * &case.test_data[1], case.test_data[2]); + } Ok(()) } #[test] |