diff options
Diffstat (limited to 'src/matrix.rs')
-rw-r--r-- | src/matrix.rs | 105 |
1 files changed, 72 insertions, 33 deletions
diff --git a/src/matrix.rs b/src/matrix.rs index 04a01bf..c324c36 100644 --- a/src/matrix.rs +++ b/src/matrix.rs @@ -10,14 +10,14 @@ //! println!("m1 + m2 =\n{}", m_add); //! ``` //! TODO:: Create matrix multiplication method - +use core::ops::AddAssign; use crate::error::{MatrixSetValueError, ParseMatrixError}; use std::{ fmt::Display, ops::{Add, Mul, Sub}, str::FromStr, }; -#[derive(Debug, PartialEq, Eq)] +#[derive(Debug, PartialEq)] pub struct Matrix { /// Number of rows in matrix. pub nrows: usize, @@ -26,12 +26,16 @@ pub struct Matrix { pub ncols: usize, /// Data stored in the matrix, you should not access this directly - data: Vec<Vec<i32>>, + data: Vec<Vec<f32>>, } pub trait MatrixMath { - fn inverse(&self) -> Matrix { - (1 / (self.determinant())) * &self.adjoint() + fn inverse(&self) -> Option<Matrix> { + let det_m = self.determinant(); + if det_m == 0.0 { + return None; + } + Some((1.0 / det_m) * &self.adjoint()) } /// Finds the matrix of cofactors for any N-by-N matrix fn cofactor(&self) -> Matrix { @@ -42,7 +46,7 @@ pub trait MatrixMath { todo!(); } /// Finds the determinant of any N-by-N matrix. - fn determinant(&self) -> i32 { + fn determinant(&self) -> f32 { todo!(); } /// Finds the transpose of any matrix. @@ -55,32 +59,58 @@ pub trait MatrixMath { } } impl MatrixMath for Matrix { + fn cofactor(&self) -> Matrix { + let mut d: Vec<Vec<f32>> = Vec::new(); + for (i, r) in self.data.iter().enumerate() { + let mut nr: Vec<f32> = Vec::new(); + for (j, v) in r.iter().enumerate() { + let count = self.ncols * i + j; + let nv = if count % 2 == 0 { -*v } else { *v }; + nr.push(nv); + } + d.push(nr); + } + Matrix::new(d) + } + fn minor(&self) -> Matrix { + let mut d: Vec<Vec<f32>> = Vec::new(); + for (i, r) in self.data.iter().enumerate() { + let mut nr: Vec<f32> = Vec::new(); + for (j, v) in r.iter().enumerate() { + let count = self.ncols * i + j; + let nv = if count % 2 == 0 { -*v } else { *v }; + nr.push(nv * self.splice(j, i).determinant()); + } + d.push(nr); + } + Matrix::new(d) + } /// Evaluates any N-by-N matrix. /// /// This function panics if the matrix is not square! - fn determinant(&self) -> i32 { + fn determinant(&self) -> f32 { if !self.is_square() { panic!() }; if self.nrows == 2 && self.ncols == 2 { return self.data[0][0] * self.data[1][1] - self.data[0][1] * self.data[1][0]; } - let mut tmp = 0; + let mut tmp: f32 = 0.0; for (i, n) in self.data[0].iter().enumerate() { let mult = if i % 2 == 0 { -*n } else { *n }; - let eval = self.splice(i).determinant(); - tmp += mult * eval; + let eval = self.splice(i, 0).determinant(); + tmp = tmp + mult * f32::from(eval); } - tmp + tmp.into() } /// Evaluates the tranpose of the matrix. /// /// Each row becomes a column, each column becomes a row. fn transpose(&self) -> Matrix { - let mut new_data = Vec::<Vec<i32>>::new(); + let mut new_data = Vec::<Vec<f32>>::new(); for i in 0..self.nrows { - let mut new_row = Vec::<i32>::new(); + let mut new_row = Vec::<f32>::new(); for j in 0..self.ncols { new_row.push(self.data[j][i]); } @@ -100,17 +130,26 @@ impl Matrix { /// /// TODOs /// - Add row length check - pub fn new(data: Vec<Vec<i32>>) -> Matrix { + pub fn new(data: Vec<Vec<f32>>) -> Matrix { + let mut d: Vec<Vec<f32>> = Vec::new(); + + for r in &data { + let mut nr = vec![]; + for x in r { + nr.push(*x); + } + d.push(nr); + } Matrix { nrows: data.len(), ncols: data[0].len(), - data, + data: d, } } /// Query one element at selected position. /// /// Returns `None` if index is out of bounds. - pub fn get(&self, row_index: usize, column_index: usize) -> Option<i32> { + pub fn get(&self, row_index: usize, column_index: usize) -> Option<f32> { let r = self.data.get(row_index)?; let n = r.get(column_index)?; Some(*n) @@ -123,7 +162,7 @@ impl Matrix { &mut self, row_index: usize, column_index: usize, - new_data: i32, + new_data: f32, ) -> Result<(), MatrixSetValueError> { self.data[row_index][column_index] = new_data; Ok(()) @@ -133,13 +172,13 @@ impl Matrix { pub fn is_square(&self) -> bool { self.nrows == self.ncols } - pub fn splice(&self, at_index: usize) -> Matrix { - let mut data: Vec<Vec<i32>> = Vec::new(); + fn splice(&self, at_index: usize, at_row: usize) -> Matrix { + let mut data: Vec<Vec<f32>> = Vec::new(); for i in 0..self.data.len() { - if i == 0 { + if i == at_row { continue; } - let mut r: Vec<i32> = Vec::new(); + let mut r: Vec<f32> = Vec::new(); for j in 0..self.data[i].len() { if j == at_index { continue; @@ -154,12 +193,12 @@ impl Matrix { impl FromStr for Matrix { type Err = ParseMatrixError; fn from_str(s: &str) -> Result<Self, Self::Err> { - let mut d: Vec<Vec<i32>> = Vec::new(); + let mut d: Vec<Vec<f32>> = Vec::new(); let rows_iter = s.split('\n'); for txt in rows_iter { - let mut r: Vec<i32> = Vec::new(); + let mut r: Vec<f32> = Vec::new(); for ch in txt.split(',') { - let parsed = match i32::from_str(ch) { + let parsed = match f32::from_str(ch) { Ok(n) => Ok(n), Err(_e) => Err(ParseMatrixError), }; @@ -218,12 +257,12 @@ impl<'a, 'b> Sub<&'b Matrix> for &'a Matrix { todo!() } } -impl<'a> Mul<&'a Matrix> for i32 { +impl<'a> Mul<&'a Matrix> for f32 { type Output = Matrix; fn mul(self, rhs: &'a Matrix) -> Self::Output { - let mut d: Vec<Vec<i32>> = Vec::new(); + let mut d: Vec<Vec<f32>> = Vec::new(); for r in &rhs.data { - let mut nr: Vec<i32> = Vec::new(); + let mut nr: Vec<f32> = Vec::new(); for v in r { nr.push(self * v); } @@ -235,21 +274,21 @@ impl<'a> Mul<&'a Matrix> for i32 { impl<'a, 'b> Mul<&'b Matrix> for &'a Matrix { type Output = Matrix; fn mul(self, rhs: &'b Matrix) -> Self::Output { - fn reduce(lhs: &Matrix, rhs: &Matrix, at_r: usize, at_c: usize) -> i32 { - let mut tmp = 0; + fn reduce(lhs: &Matrix, rhs: &Matrix, at_r: usize, at_c: usize) -> f32 { + let mut tmp = 0.0; for i in 0..lhs.ncols { tmp += lhs.get(at_r, i).unwrap() * rhs.get(i, at_c).unwrap(); } tmp } - let mut d: Vec<Vec<i32>> = Vec::new(); + let mut d: Vec<Vec<f32>> = Vec::new(); if self.ncols != rhs.nrows { println!("LHS: \n{}RHS: \n{}", self, rhs); println!("LHS nrows: {} ;; RHS ncols: {}", self.nrows, rhs.ncols); panic!() } for i in 0..self.nrows { - let mut r: Vec<i32> = Vec::new(); + let mut r: Vec<f32> = Vec::new(); for j in 0..rhs.ncols { r.push(reduce(self, rhs, i, j)); } @@ -259,8 +298,8 @@ impl<'a, 'b> Mul<&'b Matrix> for &'a Matrix { } } -impl From<Vec<i32>> for Matrix { - fn from(value: Vec<i32>) -> Self { +impl From<Vec<f32>> for Matrix { + fn from(value: Vec<f32>) -> Self { Matrix { nrows: value.len(), ncols: 1, |