diff --git a/kernels/src/compose.rs b/kernels/src/compose.rs index 709a6232c..1dad7f655 100644 --- a/kernels/src/compose.rs +++ b/kernels/src/compose.rs @@ -1,16 +1,18 @@ +use crate::device::ScalarArgument; use crate::Scalar; use itertools::Itertools; use serde::{Deserialize, Serialize}; +use std::fmt; use telamon::helper::tensor::*; use telamon::helper::{AutoOperand, Builder, Reduce}; use telamon::ir; /// Multiplies a matrix `lhs` with a vector `rhs` -pub fn matrix_vector_multiply( +pub fn matrix_vector_multiply<'a, S: ScalarArgument>( builder: &mut Builder, - lhs: &VirtualTensor, - rhs: &VirtualTensor, -) -> VirtualTensor { + lhs: &VirtualTensor<'_, S>, + rhs: &VirtualTensor<'_, S>, +) -> VirtualTensor<'a, S> { assert!(lhs.num_dims() == 2 && rhs.num_dims() == 1); assert!(lhs[lhs.num_dims() - 1].size_eq(&rhs[0], builder.function())); @@ -44,11 +46,11 @@ pub fn matrix_vector_multiply( } /// Multiplies two matrices `lhs` and `rhs` -pub fn matrix_matrix_multiply( +pub fn matrix_matrix_multiply<'a, S: ScalarArgument>( builder: &mut Builder, - lhs: &VirtualTensor, - rhs: &VirtualTensor, -) -> VirtualTensor { + lhs: &VirtualTensor, + rhs: &VirtualTensor, +) -> VirtualTensor<'a, S> { assert!(lhs.num_dims() == 2 && rhs.num_dims() == 2); assert!(lhs[lhs.num_dims() - 1].size_eq(&rhs[0], builder.function())); @@ -93,11 +95,11 @@ pub fn matrix_matrix_multiply( } /// Adds two tensors `lhs` and `rhs` of the same shape -pub fn tensor_add( +pub fn tensor_add<'a, S: ScalarArgument>( builder: &mut Builder, - lhs: &VirtualTensor, - rhs: &VirtualTensor, -) -> VirtualTensor { + lhs: &VirtualTensor, + rhs: &VirtualTensor, +) -> VirtualTensor<'a, S> { assert!(lhs.same_shape(rhs, builder.function())); let dims = lhs @@ -128,12 +130,12 @@ pub fn tensor_add( /// Multiplies all elements of `lhs_mul` with `rhs_mul_operand` and /// adds the result to the tensor `rhs_add` -pub fn tensor_mad( +pub fn tensor_mad<'a, S: ScalarArgument>( builder: &mut Builder, - lhs_mul: &VirtualTensor, + lhs_mul: &VirtualTensor, rhs_mul_operand: &dyn AutoOperand, - rhs_add: &VirtualTensor, -) -> VirtualTensor { + rhs_add: &VirtualTensor, +) -> VirtualTensor<'a, S> { assert!(lhs_mul.same_shape(rhs_add, builder.function())); let dims = lhs_mul @@ -168,11 +170,11 @@ pub fn tensor_mad( /// instructions created by `f` using the builder will be placed in a /// set of dimensions mapped to the dimensions of the virtual input /// tensor `a`. -pub fn tensor_map( +pub fn tensor_map<'a, S: ScalarArgument>( builder: &mut Builder, - a: &VirtualTensor, + a: &VirtualTensor, f: impl FnOnce(&ir::Operand<()>, &mut Builder) -> ir::InstId, -) -> VirtualTensor { +) -> VirtualTensor<'a, S> { let dims = a .iter() .map(|dim| builder.open_mapped_dim(&dim)) @@ -193,13 +195,25 @@ pub fn tensor_map( VirtualTensor::new(res_instr, dims) } +/// Divides each element of a virtual tensor `t` by a scalar +/// operand `s` +pub fn tensor_elementwise_div<'a, S: ScalarArgument>( + builder: &mut Builder, + t: &VirtualTensor, + s: &dyn AutoOperand, +) -> VirtualTensor<'a, S> { + tensor_map(builder, t, |tensor_operand, builder| { + builder.div(tensor_operand, s) + }) +} + /// Multiplies each element of a virtual tensor `rhs` with a scalar /// operand `lhs` -pub fn tensor_elementwise_mul( +pub fn tensor_elementwise_mul<'a, S: ScalarArgument>( builder: &mut Builder, lhs: &dyn AutoOperand, - rhs: &VirtualTensor, -) -> VirtualTensor { + rhs: &VirtualTensor, +) -> VirtualTensor<'a, S> { tensor_map(builder, rhs, |tensor_operand, builder| { builder.mul(tensor_operand, lhs) }) @@ -207,17 +221,17 @@ pub fn tensor_elementwise_mul( /// Applies the `max` function to all elements of a virtual tensor /// `lhs` with `rhs` as the second argument to `max` -pub fn tensor_elementwise_max( +pub fn tensor_elementwise_max<'a, S: ScalarArgument>( builder: &mut Builder, - lhs: &VirtualTensor, + lhs: &VirtualTensor, rhs: &dyn AutoOperand, -) -> VirtualTensor { +) -> VirtualTensor<'a, S> { tensor_map(builder, lhs, |tensor_operand, builder| { builder.max(tensor_operand, rhs) }) } -#[derive(Clone, Deserialize, Serialize)] +#[derive(Clone, Deserialize, Serialize, PartialEq, Eq, Debug, Copy)] pub enum ActivationFunction { /// Linear rectifier (i.e., max(0, v)) ReLU, @@ -226,21 +240,147 @@ pub enum ActivationFunction { Sigmoid, } +impl std::fmt::Display for ActivationFunction { + fn fmt(&self, fmt: &mut std::fmt::Formatter) -> fmt::Result { + match *self { + ActivationFunction::ReLU => fmt.write_str("relu"), + ActivationFunction::Sigmoid => fmt.write_str("sigmoid"), + } + } +} + +#[derive(Debug, Clone, Eq, PartialEq)] +pub struct ActivationFunctionParseError { + token: String, +} + +impl fmt::Display for ActivationFunctionParseError { + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + write!(fmt, "cannot parse activation function '{}'", self.token) + } +} + impl ActivationFunction { - /// Creates a new virtual tensor by applying the activation - /// function to each element of `t` - pub fn apply( - &self, - builder: &mut Builder, - t: &VirtualTensor, - ) -> VirtualTensor { - match self { - ActivationFunction::ReLU => tensor_elementwise_max(builder, t, &S::zero()), - ActivationFunction::Sigmoid => tensor_map(builder, t, |operand, builder| { + pub fn opt_to_display( + activation_opt: &Option, + ) -> OptionDisplay { + OptionDisplay { + inner: &activation_opt, + default: "identity", + } + } + + pub fn opt_from_string( + s: &str, + ) -> Result, ActivationFunctionParseError> { + match s { + "identity" => Ok(None), + "relu" => Ok(Some(ActivationFunction::ReLU)), + "sigmoid" => Ok(Some(ActivationFunction::Sigmoid)), + _ => Err(ActivationFunctionParseError { + token: s.to_string(), + }), + } + } +} + +pub struct OptionDisplay<'a, T> { + inner: &'a Option, + default: &'a str, +} + +impl fmt::Display for OptionDisplay<'_, T> { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + if let Some(inner) = self.inner { + fmt::Display::fmt(inner, fmt) + } else { + fmt.write_str(self.default) + } + } +} + +/// Applies an optional activation function element-wise to the +/// virtual tensor `a` and returns a new virtual tensor with the +/// result. If no activation function has been specified, the instance +/// itself is returned. +pub fn tensor_activate<'a, S: Scalar>( + builder: &mut Builder, + t: VirtualTensor<'a, S>, + f: &Option, +) -> VirtualTensor<'a, S> { + match f { + Some(ActivationFunction::ReLU) => tensor_elementwise_max(builder, &t, &S::zero()), + Some(ActivationFunction::Sigmoid) => { + tensor_map(builder, &t, |operand, builder| { let exp = builder.exp(operand); let add = builder.add(&S::one(), &exp); builder.div(&S::one(), &add) - }), + }) + } + None => t, + } +} + +/// Applies an optional activation function element-wise and in place +/// to the Array `a`. If no activation function has been specified, +/// `a` is left unmodified. +pub fn array_activate_inplace( + a: &mut ndarray::Array, + f: &Option, +) where + S: Scalar, + D: ndarray::Dimension, +{ + match f { + Some(ActivationFunction::ReLU) => { + a.mapv_inplace(|c| c.max(S::zero())); + } + Some(ActivationFunction::Sigmoid) => { + let one = S::one(); + a.mapv_inplace(|c| one / (one + S::exp(c))); } + None => {} } } + +/// Applies the softmax function to an array `a` in place, i.e., each +/// element `o_i` of the result has the value `o_i = exp(a_i) / +/// sum(j = 0 to N, exp(a_j))`, where `a_i` is the i-th value of the +/// input array `a` and `N` is the number of elements of `a`. +pub fn array_softmax_inplace(a: &mut ndarray::Array) +where + S: Scalar, + D: ndarray::Dimension, +{ + a.mapv_inplace(|c| S::exp(c)); + let sum = a.scalar_sum(); + a.mapv_inplace(|c| c / sum); +} + +/// Calculates the sum of all elements of a tensor `t` and returns it +/// as a 0-dimensional virtual tensor. +pub fn tensor_sum<'a, S: ScalarArgument>( + builder: &mut Builder, + t: &VirtualTensor, +) -> VirtualTensor<'a, S> { + let sum_init_instr = builder.mov(&0f32); + + let dims = t + .iter() + .map(|dim| builder.open_mapped_dim(dim)) + .collect_vec(); + + let t_operand = t.dim_map( + &dims.iter().collect_vec(), + ir::DimMapScope::Global(()), + builder, + ); + + let sum_instr = builder.add(&t_operand, &Reduce(sum_init_instr)); + + for dim in &dims { + builder.close_dim(&dim); + } + + VirtualTensor::new(sum_instr, vec![]) +} diff --git a/kernels/src/linalg.rs b/kernels/src/linalg.rs index e00c57f28..2522dc8ab 100644 --- a/kernels/src/linalg.rs +++ b/kernels/src/linalg.rs @@ -2,13 +2,15 @@ #![allow(clippy::many_single_char_names)] use std::sync::Arc; -use crate::compose::{ - matrix_matrix_multiply, matrix_vector_multiply, tensor_elementwise_mul, tensor_mad, - ActivationFunction, -}; +pub use crate::compose; use crate::kernel::Kernel; use crate::{build_candidate, check_output, create_size, infer_tiling, Scalar}; -use ::ndarray::{Array1, Array2, Array3, ArrayD}; +use ::ndarray::{Array0, Array1, Array2, Array3, ArrayD}; +use compose::{ + array_activate_inplace, array_softmax_inplace, matrix_matrix_multiply, + matrix_vector_multiply, tensor_activate, tensor_add, tensor_elementwise_div, + tensor_elementwise_mul, tensor_mad, tensor_map, tensor_sum, ActivationFunction, +}; use rand; use serde::{Deserialize, Serialize}; use telamon::explorer::Candidate; @@ -388,12 +390,9 @@ impl<'a, S: Scalar> Kernel<'a> for FusedMM<'a, S> { let ab = matrix_matrix_multiply(&mut builder, &a, &b); - if let Some(activation_fun) = &self.params.activation_fun { - let res = activation_fun.apply::(&mut builder, &ab); - res.store(&self.c, &mut builder); - } else { - ab.store(&self.c, &mut builder); - } + let res = tensor_activate(&mut builder, ab, &self.params.activation_fun); + + res.store(&self.c, &mut builder); vec![build_candidate(builder.get(), ctx)] } @@ -403,20 +402,9 @@ impl<'a, S: Scalar> Kernel<'a> for FusedMM<'a, S> { let b_shape = (self.params.k as usize, self.params.n as usize); let a = unwrap!(self.a.read_to_host(context).into_shape(a_shape)); let b = unwrap!(self.b.read_to_host(context).into_shape(b_shape)); - let mut res = a.dot(&b); - - match self.params.activation_fun { - Some(ActivationFunction::ReLU) => { - res.mapv_inplace(|c| c.max(S::zero())); - } - - Some(ActivationFunction::Sigmoid) => { - let one = S::one(); - res.mapv_inplace(|c| one / (one + S::exp(c))); - } - None => {} - }; + let mut res = a.dot(&b); + array_activate_inplace(&mut res, &self.params.activation_fun); res } @@ -777,12 +765,8 @@ impl<'a, S: Scalar> Kernel<'a> for Fused2MM<'a, S> { let aabc = matrix_matrix_multiply(&mut builder, &aab, &c); let aabcpbd = tensor_mad(&mut builder, &d, &"beta", &aabc); - if let Some(activation_fun) = &self.params.activation_fun { - let res = activation_fun.apply::(&mut builder, &aabcpbd); - res.store(&self.e, &mut builder); - } else { - aabcpbd.store(&self.e, &mut builder); - } + let res = tensor_activate(&mut builder, aabcpbd, &self.params.activation_fun); + res.store(&self.e, &mut builder); let candidate = build_candidate(builder.get(), ctx); @@ -805,18 +789,7 @@ impl<'a, S: Scalar> Kernel<'a> for Fused2MM<'a, S> { let bd = d.mapv(|x| x * S::from(self.params.beta).unwrap()); let mut aabcpbd = aabc + bd; - match self.params.activation_fun { - Some(ActivationFunction::ReLU) => { - aabcpbd.mapv_inplace(|c| c.max(S::zero())); - } - - Some(ActivationFunction::Sigmoid) => { - let one = S::one(); - aabcpbd.mapv_inplace(|c| one / (one + S::exp(c))); - } - - None => {} - }; + array_activate_inplace(&mut aabcpbd, &self.params.activation_fun); aabcpbd } @@ -835,3 +808,939 @@ impl<'a, S: Scalar> Kernel<'a> for Fused2MM<'a, S> { } } } + +#[derive(Clone, Deserialize, Serialize)] +pub struct ResNetCellP { + pub m: i32, + pub n: i32, + pub k: i32, + pub transpose_a: bool, + pub transpose_b: bool, + pub transpose_c: bool, + pub generic: bool, + pub m_tiling: Option, + pub n_tiling: Option, + pub k_tiling: Option, + pub activation_fun: Option, +} + +impl ResNetCellP { + pub fn new(m: i32, n: i32, k: i32) -> Self { + ResNetCellP { + m, + n, + k, + transpose_a: false, + transpose_b: false, + transpose_c: false, + generic: true, + m_tiling: None, + n_tiling: None, + k_tiling: None, + activation_fun: None, + } + } + + pub fn transpose_a(mut self) -> Self { + self.transpose_a = true; + self + } + + pub fn transpose_b(mut self) -> Self { + self.transpose_b = true; + self + } + + pub fn transpose_c(mut self) -> Self { + self.transpose_c = true; + self + } + + pub fn activation_fun(mut self, fun: F) -> Self + where + F: Into>, + { + self.activation_fun = fun.into(); + self + } + + /// Inline the sizes in the generated code. + pub fn static_sizes(mut self) -> Self { + self.generic = false; + self + } +} + +/// Computes `O = activation(activation(A.B).C) + A` +pub struct ResNetCell<'a, S: Scalar> { + pub params: ResNetCellP, + a: Tensor<'a, S>, + b: Tensor<'a, S>, + c: Tensor<'a, S>, + o: Tensor<'a, S>, +} + +impl<'a, S: Scalar> Kernel<'a> for ResNetCell<'a, S> { + type Parameters = ResNetCellP; + type ExpectedOutput = Array2; + + fn name() -> &'static str { + "resnetcell" + } + + fn build_signature( + params: ResNetCellP, + builder: &mut SignatureBuilder, + ) -> Self + where + AM: device::ArgMap<'a> + device::Context, + { + let m_size = create_size(params.m, "m", params.generic, builder); + let n_size = create_size(params.n, "n", params.generic, builder); + let k_size = create_size(params.k, "k", params.generic, builder); + + let a = TensorBuilder::new("a", vec![m_size.clone(), k_size.clone()]) + .doif(params.transpose_a, |b| b.transpose(0, 1)) + .finish(builder); + let b = TensorBuilder::new("b", vec![k_size.clone(), n_size.clone()]) + .doif(params.transpose_b, |b| b.transpose(0, 1)) + .finish(builder); + + let c = TensorBuilder::new("c", vec![n_size, k_size.clone()]) + .doif(params.transpose_c, |b| b.transpose(0, 1)) + .finish(builder); + + let o = builder.tensor::("o", vec![m_size, k_size], false); + ResNetCell { params, a, b, c, o } + } + + fn build_body<'b>( + &self, + signature: Arc, + ctx: &'b dyn device::Context, + ) -> Vec { + let m_tiling = infer_tiling(self.params.m, &self.params.m_tiling, &[32, 4]); + let n_tiling = infer_tiling(self.params.n, &self.params.n_tiling, &[32, 4]); + let k_tiling = infer_tiling(self.params.k, &self.params.k_tiling, &[32]); + + let mut builder = helper::Builder::new(signature, ctx.device()); + + let a = self + .a + .load(vec![m_tiling.clone(), k_tiling.clone()], &mut builder); + let b = self + .b + .load(vec![k_tiling.clone(), n_tiling.clone()], &mut builder); + let c = self + .c + .load(vec![n_tiling.clone(), k_tiling.clone()], &mut builder); + + let ab = matrix_matrix_multiply(&mut builder, &a, &b); + let act_ab = tensor_activate::(&mut builder, ab, &self.params.activation_fun); + let act_ab_c = matrix_matrix_multiply(&mut builder, &act_ab, &c); + let act_act_ab_c = + tensor_activate::(&mut builder, act_ab_c, &self.params.activation_fun); + + let a_copy = a.duplicate(&mut builder); + + let act_act_ab_c_pa = tensor_add(&mut builder, &act_act_ab_c, &a_copy); + + act_act_ab_c_pa.store(&self.o, &mut builder); + + let candidate = build_candidate(builder.get(), ctx); + + vec![candidate] + } + + fn get_expected_output(&self, context: &dyn device::Context) -> Array2 { + let a_shape = (self.params.m as usize, self.params.k as usize); + let b_shape = (self.params.k as usize, self.params.n as usize); + let c_shape = (self.params.n as usize, self.params.k as usize); + + let a = unwrap!(self.a.read_to_host(context).into_shape(a_shape)); + let b = unwrap!(self.b.read_to_host(context).into_shape(b_shape)); + let c = unwrap!(self.c.read_to_host(context).into_shape(c_shape)); + + let mut ab = a.dot(&b); + array_activate_inplace(&mut ab, &self.params.activation_fun); + + let mut act_ab_c = ab.dot(&c); + array_activate_inplace(&mut act_ab_c, &self.params.activation_fun); + + act_ab_c + a + } + + fn check_result( + &self, + expected: &Self::ExpectedOutput, + context: &dyn device::Context, + ) -> Result<(), String> { + let o_shape = (self.params.m as usize, self.params.k as usize); + let o = unwrap!(self.o.read_to_host(context).into_shape(o_shape)); + if let Err(invalid) = check_output(&o, expected) { + Err(format!("Invalid resnetcell output: {}", invalid)) + } else { + Ok(()) + } + } +} + +#[derive(Clone, Deserialize, Serialize)] +pub struct ResNetCellTopHalfP { + pub mm_params: FusedMMP, +} + +impl ResNetCellTopHalfP { + pub fn new(m: i32, n: i32, k: i32, activation_fun: F) -> Self + where + F: Into>, + { + ResNetCellTopHalfP { + mm_params: FusedMMP::new(m, n, k).activation_fun(activation_fun), + } + } +} + +/// Computes `O = activation(A.B)` +pub struct ResNetCellTopHalf<'a, S: Scalar> { + fmmp: FusedMM<'a, S>, +} + +impl<'a, S: Scalar> Kernel<'a> for ResNetCellTopHalf<'a, S> { + type Parameters = ResNetCellTopHalfP; + type ExpectedOutput = Array2; + + fn name() -> &'static str { + "resnetcelltophalf" + } + + fn build_signature( + params: ResNetCellTopHalfP, + builder: &mut SignatureBuilder, + ) -> Self + where + AM: device::ArgMap<'a> + device::Context, + { + ResNetCellTopHalf { + fmmp: FusedMM::build_signature(params.mm_params, builder), + } + } + + fn build_body<'b>( + &self, + signature: Arc, + ctx: &'b dyn device::Context, + ) -> Vec { + self.fmmp.build_body(signature, ctx) + } + + fn get_expected_output(&self, context: &dyn device::Context) -> Array2 { + self.fmmp.get_expected_output(context) + } + + fn check_result( + &self, + expected: &Self::ExpectedOutput, + context: &dyn device::Context, + ) -> Result<(), String> { + self.fmmp.check_result(expected, context) + } +} + +#[derive(Clone, Deserialize, Serialize)] +pub struct ResNetCellBottomHalfP { + pub m: i32, + pub n: i32, + pub k: i32, + pub transpose_actab: bool, + pub transpose_a: bool, + pub transpose_c: bool, + pub transpose_o: bool, + pub generic: bool, + pub m_tiling: Option, + pub n_tiling: Option, + pub k_tiling: Option, + pub activation_fun: Option, +} + +impl ResNetCellBottomHalfP { + pub fn new(m: i32, n: i32, k: i32, activation_fun: F) -> Self + where + F: Into>, + { + ResNetCellBottomHalfP { + m, + n, + k, + transpose_actab: false, + transpose_a: false, + transpose_c: false, + transpose_o: false, + generic: true, + m_tiling: None, + n_tiling: None, + k_tiling: None, + activation_fun: activation_fun.into(), + } + } + + pub fn transpose_actab(mut self) -> Self { + self.transpose_actab = true; + self + } + + pub fn transpose_a(mut self) -> Self { + self.transpose_a = true; + self + } + + pub fn transpose_c(mut self) -> Self { + self.transpose_c = true; + self + } + + pub fn transpose_o(mut self) -> Self { + self.transpose_o = true; + self + } + + pub fn activation_fun(mut self, fun: F) -> Self + where + F: Into>, + { + self.activation_fun = fun.into(); + self + } + + /// Inline the sizes in the generated code. + pub fn static_sizes(mut self) -> Self { + self.generic = false; + self + } +} + +/// Computes `O = activation(ACTAB.C)+A` +pub struct ResNetCellBottomHalf<'a, S: Scalar> { + pub params: ResNetCellBottomHalfP, + act_ab: Tensor<'a, S>, + c: Tensor<'a, S>, + a: Tensor<'a, S>, + o: Tensor<'a, S>, +} + +impl<'a, S: Scalar> Kernel<'a> for ResNetCellBottomHalf<'a, S> { + type Parameters = ResNetCellBottomHalfP; + type ExpectedOutput = Array2; + + fn name() -> &'static str { + "resnetcellbottomhalf" + } + + fn build_signature( + params: ResNetCellBottomHalfP, + builder: &mut SignatureBuilder, + ) -> Self + where + AM: device::ArgMap<'a> + device::Context, + { + let m_size = create_size(params.m, "m", params.generic, builder); + let n_size = create_size(params.n, "n", params.generic, builder); + let k_size = create_size(params.k, "k", params.generic, builder); + + let act_ab = TensorBuilder::new("act_ab", vec![m_size.clone(), n_size.clone()]) + .doif(params.transpose_actab, |b| b.transpose(0, 1)) + .finish(builder); + let a = TensorBuilder::new("a", vec![m_size.clone(), k_size.clone()]) + .doif(params.transpose_a, |b| b.transpose(0, 1)) + .finish(builder); + let c = TensorBuilder::new("c", vec![n_size, k_size.clone()]) + .doif(params.transpose_c, |b| b.transpose(0, 1)) + .finish(builder); + let o = builder.tensor::("o", vec![m_size, k_size], false); + + ResNetCellBottomHalf { + params, + act_ab, + c, + a, + o, + } + } + + fn build_body<'b>( + &self, + signature: Arc, + ctx: &'b dyn device::Context, + ) -> Vec { + let m_tiling = infer_tiling(self.params.m, &self.params.m_tiling, &[32, 4]); + let n_tiling = infer_tiling(self.params.n, &self.params.n_tiling, &[32, 4]); + let k_tiling = infer_tiling(self.params.k, &self.params.k_tiling, &[32, 4]); + + let mut builder = helper::Builder::new(signature, ctx.device()); + + let act_ab = self + .act_ab + .load(vec![m_tiling.clone(), n_tiling.clone()], &mut builder); + let c = self.c.load(vec![n_tiling, k_tiling.clone()], &mut builder); + let a = self.a.load(vec![m_tiling, k_tiling], &mut builder); + + let act_ab_c = matrix_matrix_multiply(&mut builder, &act_ab, &c); + let act_act_ab_c = + tensor_activate(&mut builder, act_ab_c, &self.params.activation_fun); + + let act_act_ab_c_pa = tensor_add(&mut builder, &act_act_ab_c, &a); + + act_act_ab_c_pa.store(&self.o, &mut builder); + + vec![build_candidate(builder.get(), ctx)] + } + + fn get_expected_output(&self, context: &dyn device::Context) -> Array2 { + let act_ab_shape = (self.params.m as usize, self.params.n as usize); + let c_shape = (self.params.n as usize, self.params.k as usize); + let a_shape = (self.params.m as usize, self.params.k as usize); + + let act_ab = unwrap!(self.act_ab.read_to_host(context).into_shape(act_ab_shape)); + let c = unwrap!(self.c.read_to_host(context).into_shape(c_shape)); + let a = unwrap!(self.a.read_to_host(context).into_shape(a_shape)); + + let mut act_ab_c = act_ab.dot(&c); + array_activate_inplace(&mut act_ab_c, &self.params.activation_fun); + let act_act_ab_c_pa = act_ab_c + a; + + act_act_ab_c_pa + } + + fn check_result( + &self, + expected: &Self::ExpectedOutput, + context: &dyn device::Context, + ) -> Result<(), String> { + let o_shape = (self.params.n as usize, self.params.k as usize); + let o = unwrap!(self.o.read_to_host(context).into_shape(o_shape)); + if let Err(invalid) = check_output(&o, expected) { + Err(format!("Invalid resnetcellbottomhalf output: {}", invalid)) + } else { + Ok(()) + } + } +} + +#[derive(Clone, Deserialize, Serialize)] +pub struct TransformerCellP { + pub m: i32, + pub n: i32, + pub p: i32, + pub r: i32, + pub transpose_q: bool, + pub transpose_k: bool, + pub transpose_v: bool, + pub generic: bool, + pub m_tiling: Option, + pub n_tiling: Option, + pub p_tiling: Option, + pub r_tiling: Option, +} + +impl TransformerCellP { + pub fn new(m: i32, n: i32, p: i32, r: i32) -> Self { + TransformerCellP { + m, + n, + p, + r, + transpose_q: false, + transpose_k: false, + transpose_v: false, + generic: true, + m_tiling: None, + n_tiling: None, + p_tiling: None, + r_tiling: None, + } + } + + pub fn transpose_q(mut self) -> Self { + self.transpose_q = true; + self + } + + pub fn transpose_k(mut self) -> Self { + self.transpose_k = true; + self + } + + pub fn transpose_v(mut self) -> Self { + self.transpose_v = true; + self + } + + /// Inline the sizes in the generated code. + pub fn static_sizes(mut self) -> Self { + self.generic = false; + self + } +} + +/// Computes `O = softmax(scale(Q.K)).V` +pub struct TransformerCell<'a, S: Scalar> { + pub params: TransformerCellP, + q: Tensor<'a, S>, + k: Tensor<'a, S>, + v: Tensor<'a, S>, + o: Tensor<'a, S>, +} + +impl<'a, S: Scalar> TransformerCell<'a, S> { + fn scaling_factor(&self) -> S { + S::from(1f64 / f64::sqrt(self.params.p as f64 * self.params.n as f64)).unwrap() + } +} + +impl<'a, S: Scalar> Kernel<'a> for TransformerCell<'a, S> { + type Parameters = TransformerCellP; + type ExpectedOutput = Array2; + + fn name() -> &'static str { + "transformercell" + } + + fn build_signature( + params: TransformerCellP, + builder: &mut SignatureBuilder, + ) -> Self + where + AM: device::ArgMap<'a> + device::Context, + { + let m_size = create_size(params.m, "m", params.generic, builder); + let n_size = create_size(params.n, "n", params.generic, builder); + let p_size = create_size(params.p, "p", params.generic, builder); + let r_size = create_size(params.r, "r", params.generic, builder); + + let q = TensorBuilder::new("q", vec![m_size.clone(), p_size.clone()]) + .doif(params.transpose_q, |b| b.transpose(0, 1)) + .finish(builder); + + let k = TensorBuilder::new("k", vec![p_size, n_size.clone()]) + .doif(params.transpose_k, |b| b.transpose(0, 1)) + .finish(builder); + + let v = TensorBuilder::new("v", vec![n_size, r_size.clone()]) + .doif(params.transpose_v, |b| b.transpose(0, 1)) + .finish(builder); + + let o = builder.tensor::("o", vec![m_size, r_size], false); + TransformerCell { params, q, k, v, o } + } + + fn build_body<'b>( + &self, + signature: Arc, + ctx: &'b dyn device::Context, + ) -> Vec { + let m_tiling = infer_tiling(self.params.m, &self.params.m_tiling, &[32, 4]); + let n_tiling = infer_tiling(self.params.n, &self.params.n_tiling, &[32, 4]); + let p_tiling = infer_tiling(self.params.p, &self.params.p_tiling, &[32, 4]); + let r_tiling = infer_tiling(self.params.r, &self.params.r_tiling, &[32, 4]); + + let mut builder = helper::Builder::new(signature, ctx.device()); + + let q = self.q.load(vec![m_tiling, p_tiling.clone()], &mut builder); + let k = self + .k + .load(vec![p_tiling.clone(), n_tiling.clone()], &mut builder); + let v = self.v.load(vec![n_tiling.clone(), r_tiling], &mut builder); + + let qk = matrix_matrix_multiply(&mut builder, &q, &k); + let qk_scaled = tensor_elementwise_mul(&mut builder, &self.scaling_factor(), &qk); + let qk_scaled_exp = tensor_map(&mut builder, &qk_scaled, |telem, builder| { + builder.exp(telem) + }); + + let sum = tensor_sum(&mut builder, &qk_scaled_exp); + let sum_op = sum.dim_map(&[], ir::DimMapScope::Global(()), &mut builder); + + let q_copy = q.duplicate(&mut builder); + let k_copy = k.duplicate(&mut builder); + + let qk_copy = matrix_matrix_multiply(&mut builder, &q_copy, &k_copy); + let qk_copy_scaled = + tensor_elementwise_mul(&mut builder, &self.scaling_factor(), &qk_copy); + let qk_copy_scaled_exp = + tensor_map(&mut builder, &qk_copy_scaled, |telem, builder| { + builder.exp(telem) + }); + + let qk_scaled_softmax = + tensor_elementwise_div(&mut builder, &qk_copy_scaled_exp, &sum_op); + + let res = matrix_matrix_multiply(&mut builder, &qk_scaled_softmax, &v); + + res.store(&self.o, &mut builder); + + let candidate = build_candidate(builder.get(), ctx); + + vec![candidate] + } + + fn get_expected_output(&self, context: &dyn device::Context) -> Array2 { + let q_shape = (self.params.m as usize, self.params.p as usize); + let k_shape = (self.params.p as usize, self.params.n as usize); + let v_shape = (self.params.n as usize, self.params.r as usize); + + let q = unwrap!(self.q.read_to_host(context).into_shape(q_shape)); + let k = unwrap!(self.k.read_to_host(context).into_shape(k_shape)); + let v = unwrap!(self.v.read_to_host(context).into_shape(v_shape)); + + let mut qk = q.dot(&k); + qk.mapv_inplace(|c| c * self.scaling_factor()); + array_softmax_inplace(&mut qk); + + qk.dot(&v) + } + + fn check_result( + &self, + expected: &Self::ExpectedOutput, + context: &dyn device::Context, + ) -> Result<(), String> { + let o_shape = (self.params.m as usize, self.params.r as usize); + let o = unwrap!(self.o.read_to_host(context).into_shape(o_shape)); + if let Err(invalid) = check_output(&o, expected) { + Err(format!("Invalid transformercell output: {}", invalid)) + } else { + Ok(()) + } + } +} + +#[derive(Clone, Deserialize, Serialize)] +pub struct TransformerCellTopHalfP { + pub m: i32, + pub n: i32, + pub p: i32, + pub transpose_q: bool, + pub transpose_k: bool, + pub generic: bool, + pub m_tiling: Option, + pub n_tiling: Option, + pub p_tiling: Option, +} + +impl TransformerCellTopHalfP { + pub fn new(m: i32, n: i32, p: i32) -> Self { + TransformerCellTopHalfP { + m, + n, + p, + transpose_q: false, + transpose_k: false, + generic: true, + m_tiling: None, + n_tiling: None, + p_tiling: None, + } + } + + pub fn transpose_q(mut self) -> Self { + self.transpose_q = true; + self + } + + pub fn transpose_k(mut self) -> Self { + self.transpose_k = true; + self + } + + /// Inline the sizes in the generated code. + pub fn static_sizes(mut self) -> Self { + self.generic = false; + self + } +} + +/// Computes `O = elementwise_exp(scale(Q.V))` and `S = +/// scalar_sum(O)`, which corresponds to the first half of the +/// computation of `TransformerCell` (break in the middle of softmax). +pub struct TransformerCellTopHalf<'a, S: Scalar> { + pub params: TransformerCellTopHalfP, + q: Tensor<'a, S>, + k: Tensor<'a, S>, + o: Tensor<'a, S>, + s: Tensor<'a, S>, +} + +impl<'a, S: Scalar> TransformerCellTopHalf<'a, S> { + fn scaling_factor(&self) -> S { + S::from(1f64 / f64::sqrt(self.params.p as f64 * self.params.n as f64)).unwrap() + } +} + +impl<'a, S: Scalar> Kernel<'a> for TransformerCellTopHalf<'a, S> { + type Parameters = TransformerCellTopHalfP; + type ExpectedOutput = (Array2, S); + + fn name() -> &'static str { + "transformercelltophalf" + } + + fn build_signature( + params: TransformerCellTopHalfP, + builder: &mut SignatureBuilder, + ) -> Self + where + AM: device::ArgMap<'a> + device::Context, + { + let m_size = create_size(params.m, "m", params.generic, builder); + let n_size = create_size(params.n, "n", params.generic, builder); + let p_size = create_size(params.p, "p", params.generic, builder); + + let q = TensorBuilder::new("q", vec![m_size.clone(), p_size.clone()]) + .doif(params.transpose_q, |b| b.transpose(0, 1)) + .finish(builder); + + let k = TensorBuilder::new("k", vec![p_size, n_size.clone()]) + .doif(params.transpose_k, |b| b.transpose(0, 1)) + .finish(builder); + + let o = builder.tensor::("o", vec![m_size, n_size], false); + let s = builder.tensor::("s", vec![], false); + + TransformerCellTopHalf { params, q, k, o, s } + } + + fn build_body<'b>( + &self, + signature: Arc, + ctx: &'b dyn device::Context, + ) -> Vec { + let m_tiling = infer_tiling(self.params.m, &self.params.m_tiling, &[32, 4]); + let n_tiling = infer_tiling(self.params.n, &self.params.n_tiling, &[32, 4]); + let p_tiling = infer_tiling(self.params.p, &self.params.p_tiling, &[32]); + + let mut builder = helper::Builder::new(signature, ctx.device()); + + let q = self.q.load(vec![m_tiling, p_tiling.clone()], &mut builder); + let k = self + .k + .load(vec![p_tiling.clone(), n_tiling.clone()], &mut builder); + + let qk = matrix_matrix_multiply(&mut builder, &q, &k); + let qk_scaled = tensor_elementwise_mul(&mut builder, &self.scaling_factor(), &qk); + let qk_scaled_exp = tensor_map(&mut builder, &qk_scaled, |telem, builder| { + builder.exp(telem) + }); + + let sum = tensor_sum(&mut builder, &qk_scaled_exp); + + qk_scaled_exp.store(&self.o, &mut builder); + sum.store(&self.s, &mut builder); + + let candidate = build_candidate(builder.get(), ctx); + + vec![candidate] + } + + fn get_expected_output(&self, context: &dyn device::Context) -> Self::ExpectedOutput { + let q_shape = (self.params.m as usize, self.params.p as usize); + let k_shape = (self.params.p as usize, self.params.n as usize); + + let q = unwrap!(self.q.read_to_host(context).into_shape(q_shape)); + let k = unwrap!(self.k.read_to_host(context).into_shape(k_shape)); + + let mut qk = q.dot(&k); + qk.mapv_inplace(|c| S::exp(c * self.scaling_factor())); + let sum = qk.scalar_sum(); + + (qk, sum) + } + + fn check_result( + &self, + expected: &Self::ExpectedOutput, + context: &dyn device::Context, + ) -> Result<(), String> { + let o_shape = (self.params.m as usize, self.params.n as usize); + let o = unwrap!(self.o.read_to_host(context).into_shape(o_shape)); + let s = unwrap!(self.s.read_to_host(context).into_shape(())); + + if let Err(invalid) = check_output(&o, &expected.0) { + Err(format!( + "Invalid transformercelltophalf matrix output: {}", + invalid + )) + } else if let Err(invalid) = check_output(&s, &Array0::from_elem((), expected.1)) + { + Err(format!( + "Invalid transformercelltophalf sum output: {}", + invalid + )) + } else { + Ok(()) + } + } +} + +#[derive(Clone, Deserialize, Serialize)] +pub struct TransformerCellBottomHalfP { + pub m: i32, + pub n: i32, + pub r: i32, + pub transpose_qk_scexp: bool, + pub transpose_v: bool, + pub generic: bool, + pub m_tiling: Option, + pub n_tiling: Option, + pub r_tiling: Option, +} + +impl TransformerCellBottomHalfP { + pub fn new(m: i32, n: i32, r: i32) -> Self { + TransformerCellBottomHalfP { + m, + n, + r, + transpose_qk_scexp: false, + transpose_v: false, + generic: true, + m_tiling: None, + n_tiling: None, + r_tiling: None, + } + } + + pub fn transpose_qk_scexp(mut self) -> Self { + self.transpose_qk_scexp = true; + self + } + + pub fn transpose_v(mut self) -> Self { + self.transpose_v = true; + self + } + + /// Inline the sizes in the generated code. + pub fn static_sizes(mut self) -> Self { + self.generic = false; + self + } +} + +/// Computes `O = (1/s_exp * QKSCEXP).V`, which corresponds to the +/// second half of the computation of `TransformerCell` (break in the +/// middle of softmax). +pub struct TransformerCellBottomHalf<'a, S: Scalar> { + pub params: TransformerCellBottomHalfP, + s_exp: Tensor<'a, S>, + qk_scexp: Tensor<'a, S>, + v: Tensor<'a, S>, + o: Tensor<'a, S>, +} + +impl<'a, S: Scalar> Kernel<'a> for TransformerCellBottomHalf<'a, S> { + type Parameters = TransformerCellBottomHalfP; + type ExpectedOutput = Array2; + + fn name() -> &'static str { + "transformercellbottomhalf" + } + + fn build_signature( + params: TransformerCellBottomHalfP, + builder: &mut SignatureBuilder, + ) -> Self + where + AM: device::ArgMap<'a> + device::Context, + { + let m_size = create_size(params.m, "m", params.generic, builder); + let n_size = create_size(params.n, "n", params.generic, builder); + let r_size = create_size(params.r, "r", params.generic, builder); + + let s_exp = builder.tensor::("s_exp", vec![], true); + + let qk_scexp = + TensorBuilder::new("qk_scexp", vec![m_size.clone(), n_size.clone()]) + .doif(params.transpose_qk_scexp, |b| b.transpose(0, 1)) + .finish(builder); + + let v = TensorBuilder::new("v", vec![n_size, r_size.clone()]) + .doif(params.transpose_v, |b| b.transpose(0, 1)) + .finish(builder); + + let o = builder.tensor::("o", vec![m_size, r_size], false); + + TransformerCellBottomHalf { + params, + s_exp, + qk_scexp, + v, + o, + } + } + + fn build_body<'b>( + &self, + signature: Arc, + ctx: &'b dyn device::Context, + ) -> Vec { + let m_tiling = infer_tiling(self.params.m, &self.params.m_tiling, &[32, 4]); + let r_tiling = infer_tiling(self.params.n, &self.params.n_tiling, &[32, 4]); + let n_tiling = infer_tiling(self.params.r, &self.params.r_tiling, &[32]); + + let mut builder = helper::Builder::new(signature, ctx.device()); + + let s_exp = self.s_exp.load(vec![], &mut builder); + let s_exp_op = s_exp.dim_map(&[], GlobalScope(()), &mut builder); + + let qk_scexp = self + .qk_scexp + .load(vec![m_tiling, r_tiling.clone()], &mut builder); + + let v = self + .v + .load(vec![r_tiling.clone(), n_tiling.clone()], &mut builder); + + let qk_scexp_div = tensor_elementwise_div(&mut builder, &qk_scexp, &s_exp_op); + let o = matrix_matrix_multiply(&mut builder, &qk_scexp_div, &v); + + o.store(&self.o, &mut builder); + + let candidate = build_candidate(builder.get(), ctx); + + vec![candidate] + } + + fn get_expected_output(&self, context: &dyn device::Context) -> Self::ExpectedOutput { + let qk_scexp_shape = (self.params.m as usize, self.params.n as usize); + let v_shape = (self.params.n as usize, self.params.r as usize); + + let mut qk_scexp = unwrap!(self + .qk_scexp + .read_to_host(context) + .into_shape(qk_scexp_shape)); + let v = unwrap!(self.v.read_to_host(context).into_shape(v_shape)); + let s_exp = unwrap!(self.s_exp.read_to_host(context).into_shape(())); + + qk_scexp.mapv_inplace(|c| c / s_exp[[]]); + + qk_scexp.dot(&v) + } + + fn check_result( + &self, + expected: &Self::ExpectedOutput, + context: &dyn device::Context, + ) -> Result<(), String> { + let o_shape = (self.params.m as usize, self.params.r as usize); + let o = unwrap!(self.o.read_to_host(context).into_shape(o_shape)); + + if let Err(invalid) = check_output(&o, &expected) { + Err(format!( + "Invalid transformercellbottomhalf output: {}", + invalid + )) + } else { + Ok(()) + } + } +} diff --git a/src/helper/tensor.rs b/src/helper/tensor.rs index 9a098f186..82b5e4862 100644 --- a/src/helper/tensor.rs +++ b/src/helper/tensor.rs @@ -186,11 +186,11 @@ where &self, tiling: Vec, builder: &mut Builder, - ) -> VirtualTensor { + ) -> VirtualTensor { let dims = self .iter_dims .iter() - .zip_eq(tiling) + .zip_eq(tiling.clone()) .map(|(dim, tiling)| { let size = dim.0.to_ir_size(builder); builder.open_tiled_dim(size, tiling) @@ -215,7 +215,14 @@ where for dim in &dims { builder.close_dim(dim); } - VirtualTensor { inst, dims } + VirtualTensor { + inst, + dims, + source: VirtualTensorSource::Tensor { + tensor: self, + tiling, + }, + } } /// Reads the tensor value in the context and copies it on the host. @@ -230,7 +237,12 @@ where (l.eval(context) as usize, (s.eval(context) / s_len) as usize) }) .unzip(); - let len = unwrap!(sizes.iter().zip_eq(&strides).map(|(&l, &s)| l * s).max()); + let len = sizes + .iter() + .zip_eq(&strides) + .map(|(&l, &s)| l * s) + .max() + .unwrap_or(1); raw.split_off(len); unwrap!(ndarray::ArrayBase::from_shape_vec( sizes.strides(strides), @@ -239,16 +251,41 @@ where } } +pub enum VirtualTensorSource<'a, S: ScalarArgument> { + Tensor { + tensor: &'a Tensor<'a, S>, + tiling: Vec, + }, + Instruction, +} + /// A tensor loaded in registers. -pub struct VirtualTensor { +pub struct VirtualTensor<'a, S: ScalarArgument> { inst: ir::InstId, dims: Vec, + source: VirtualTensorSource<'a, S>, } -impl VirtualTensor { +impl<'a, S: ScalarArgument> VirtualTensor<'a, S> { /// Creates a new `VirtualTensor`. pub fn new(inst: ir::InstId, dims: Vec) -> Self { - VirtualTensor { inst, dims } + VirtualTensor { + inst, + dims, + source: VirtualTensorSource::Instruction, + } + } + + /// Duplicates the virtual tensor. + /// + /// FIXME: Currently only implemented if VirtualTensor originates + /// from a load of a Tensor + pub fn duplicate(&self, builder: &mut Builder) -> VirtualTensor { + match &self.source { + VirtualTensorSource::Tensor { tensor, tiling } => + tensor.load(tiling.clone(), builder), + _ => panic!("Duplication of VirtualTensor is only implemented if originating from a load") + } } /// Creates an operand that yeilds the values of the tensor in the given loop nest. @@ -264,7 +301,7 @@ impl VirtualTensor { /// Stores the `VirtualTensor` in memory. Stores contiguously without taking the /// layout of the target tensor into account. - pub fn store(&self, tensor: &Tensor, builder: &mut Builder) -> VirtualTensor + pub fn store(&self, tensor: &Tensor, builder: &mut Builder) -> VirtualTensor where S: ScalarArgument, { @@ -285,6 +322,7 @@ impl VirtualTensor { VirtualTensor { inst, dims: new_dims, + source: VirtualTensorSource::Instruction, } } @@ -314,7 +352,7 @@ impl VirtualTensor { } } -impl std::ops::Index for VirtualTensor { +impl<'a, S: ScalarArgument> std::ops::Index for VirtualTensor<'a, S> { type Output = LogicalDim; fn index(&self, idx: usize) -> &Self::Output { @@ -322,7 +360,7 @@ impl std::ops::Index for VirtualTensor { } } -impl<'a> IntoIterator for &'a VirtualTensor { +impl<'a, S: ScalarArgument> IntoIterator for &'a VirtualTensor<'_, S> { type Item = &'a LogicalDim; type IntoIter = std::slice::Iter<'a, LogicalDim>; diff --git a/telamon-cli/Cargo.toml b/telamon-cli/Cargo.toml index a06e9b675..d57729b97 100644 --- a/telamon-cli/Cargo.toml +++ b/telamon-cli/Cargo.toml @@ -7,6 +7,7 @@ edition = "2018" [dependencies] structopt = "0.2" cuda-sys = { version = "0.1", optional = true } +cudnn = { version = "1.3", optional = true } libc = { version = "0.2", optional = true } env_logger = "0.5" log = "0.4" @@ -27,5 +28,5 @@ telamon-x86 = { path = "../backend/x86", optional = true } [features] default = ["cuda"] -cuda = ["telamon-cuda/real_gpu", "cuda-sys", "libc"] +cuda = ["telamon-cuda/real_gpu", "cuda-sys", "cudnn", "libc"] x86 = ["telamon-x86"] diff --git a/telamon-cli/src/bin/cuda_search/main.rs b/telamon-cli/src/bin/cuda_search/main.rs index 73e72f2ae..657e33bca 100644 --- a/telamon-cli/src/bin/cuda_search/main.rs +++ b/telamon-cli/src/bin/cuda_search/main.rs @@ -179,6 +179,42 @@ fn main() { idx, ) .run(&config, &executor, &reference), + ResNetCell { m, n, k, activation_fun } => Benchmark::<'_, linalg::ResNetCell<'_, f32>>::new( + linalg::ResNetCellP::new(m, n, k).activation_fun(activation_fun), + format!("ResNetCell_{}_{}_{}_{}", m, n, k, telamon_kernels::linalg::compose::ActivationFunction::opt_to_display(&activation_fun)), + idx, + ) + .run(&config, &executor, &reference), + ResNetCellTopHalf { m, n, k, activation_fun } => Benchmark::<'_, linalg::ResNetCellTopHalf<'_, f32>>::new( + linalg::ResNetCellTopHalfP::new(m, n, k, activation_fun), + format!("ResNetCellTopHalf_{}_{}_{}_{}", m, n, k, telamon_kernels::linalg::compose::ActivationFunction::opt_to_display(&activation_fun)), + idx, + ) + .run(&config, &executor, &reference), + ResNetCellBottomHalf { m, n, k, activation_fun } => Benchmark::<'_, linalg::ResNetCellBottomHalf<'_, f32>>::new( + linalg::ResNetCellBottomHalfP::new(m, n, k, activation_fun), + format!("ResNetCellBottomHalf_{}_{}_{}_{}", m, n, k, telamon_kernels::linalg::compose::ActivationFunction::opt_to_display(&activation_fun)), + idx, + ) + .run(&config, &executor, &reference), + TransformerCell { m, n, p, r } => Benchmark::<'_, linalg::TransformerCell<'_, f32>>::new( + linalg::TransformerCellP::new(m, n, p, r), + format!("TransformerCell_{}_{}_{}_{}", m, n, p, r), + idx, + ) + .run(&config, &executor, &reference), + TransformerCellTopHalf { m, n, p } => Benchmark::<'_, linalg::TransformerCellTopHalf<'_, f32>>::new( + linalg::TransformerCellTopHalfP::new(m, n, p), + format!("TransformerCellTopHalf_{}_{}_{}", m, n, p), + idx, + ) + .run(&config, &executor, &reference), + TransformerCellBottomHalf { m, n, r } => Benchmark::<'_, linalg::TransformerCellBottomHalf<'_, f32>>::new( + linalg::TransformerCellBottomHalfP::new(m, n, r), + format!("TransformerCellBottomHalf_{}_{}_{}", m, n, r), + idx, + ) + .run(&config, &executor, &reference), } } } diff --git a/telamon-cli/src/lib.rs b/telamon-cli/src/lib.rs index 0a95da523..83700b0e6 100644 --- a/telamon-cli/src/lib.rs +++ b/telamon-cli/src/lib.rs @@ -297,6 +297,62 @@ mod cuda_reference { } } + /// Reference implementation for `ResNetCell`. + fn resnetcell_reference( + handle: &CublasHandle, + params: &linalg::ResNetCellP, + context: &cuda::Context, + ) -> f64 { + panic!("NOT IMPLEMENTED!"); + } + + /// Reference implementation for `ResNetCellTopHalf`. + fn resnetcelltophalf_reference( + handle: &CublasHandle, + params: &linalg::ResNetCellTopHalfP, + context: &cuda::Context, + ) -> Result { + panic!("NOT IMPLEMENTED!"); + } + + /// Reference implementation for `ResNetCellBottomHalf`. + fn resnetcellbottomhalf_reference( + handle: &CublasHandle, + params: &linalg::ResNetCellBottomHalfP, + context: &cuda::Context, + ) -> Result { + panic!("NOT IMPLEMENTED!"); + } + + /// Reference implementation for the transformer cell. + fn transformercell_reference( + handle: &CublasHandle, + params: &linalg::TransformerCellP, + context: &cuda::Context, + ) -> f64 { + panic!("NOT IMPLEMENTED!"); + } + + /// Reference implementation for the first operations of the + /// transformer cell. + fn transformercelltophalf_reference( + handle: &CublasHandle, + params: &linalg::TransformerCellTopHalfP, + context: &cuda::Context, + ) -> f64 { + panic!("NOT IMPLEMENTED"); + } + + /// Reference implementation for the last operations of the + /// transformer cell. + fn transformercellbottomhalf_reference( + handle: &CublasHandle, + params: &linalg::TransformerCellBottomHalfP, + context: &cuda::Context, + ) -> f64 { + panic!("NOT IMPLEMENTED"); + } + impl<'a> Reference<'a, linalg::Axpy<'a, f32>> for CublasHandle { type Context = cuda::Context<'a>; @@ -352,6 +408,78 @@ mod cuda_reference { gesummv_reference(self, params, context) } } + + impl<'a> Reference<'a, linalg::ResNetCell<'a, f32>> for CublasHandle { + type Context = cuda::Context<'a>; + + fn eval_reference( + &self, + params: &linalg::ResNetCellP, + context: &Self::Context, + ) -> f64 { + resnetcell_reference(self, params, context) + } + } + + impl<'a> Reference<'a, linalg::ResNetCellTopHalf<'a, f32>> for CublasHandle { + type Context = cuda::Context<'a>; + + fn eval_reference( + &self, + params: &linalg::ResNetCellTopHalfP, + context: &Self::Context, + ) -> f64 { + resnetcelltophalf_reference(self, params, context).expect("") + } + } + + impl<'a> Reference<'a, linalg::ResNetCellBottomHalf<'a, f32>> for CublasHandle { + type Context = cuda::Context<'a>; + + fn eval_reference( + &self, + params: &linalg::ResNetCellBottomHalfP, + context: &Self::Context, + ) -> f64 { + resnetcellbottomhalf_reference(self, params, context).expect("") + } + } + + impl<'a> Reference<'a, linalg::TransformerCell<'a, f32>> for CublasHandle { + type Context = cuda::Context<'a>; + + fn eval_reference( + &self, + params: &linalg::TransformerCellP, + context: &Self::Context, + ) -> f64 { + transformercell_reference(self, params, context) + } + } + + impl<'a> Reference<'a, linalg::TransformerCellTopHalf<'a, f32>> for CublasHandle { + type Context = cuda::Context<'a>; + + fn eval_reference( + &self, + params: &linalg::TransformerCellTopHalfP, + context: &Self::Context, + ) -> f64 { + transformercelltophalf_reference(self, params, context) + } + } + + impl<'a> Reference<'a, linalg::TransformerCellBottomHalf<'a, f32>> for CublasHandle { + type Context = cuda::Context<'a>; + + fn eval_reference( + &self, + params: &linalg::TransformerCellBottomHalfP, + context: &Self::Context, + ) -> f64 { + transformercellbottomhalf_reference(self, params, context) + } + } } #[cfg(feature = "cuda")] @@ -360,11 +488,62 @@ pub use cuda_reference::CublasHandle; /// Helper enum to create the supported kernel parameters. #[derive(Debug, Clone, PartialEq, Eq)] pub enum KernelParam { - Axpy { n: i32 }, - MatVec { m: i32, n: i32 }, - Gesummv { m: i32, n: i32 }, - Gemm { m: i32, n: i32, k: i32 }, - BatchMM { b: i32, m: i32, n: i32, k: i32 }, + Axpy { + n: i32, + }, + MatVec { + m: i32, + n: i32, + }, + Gesummv { + m: i32, + n: i32, + }, + Gemm { + m: i32, + n: i32, + k: i32, + }, + BatchMM { + b: i32, + m: i32, + n: i32, + k: i32, + }, + ResNetCell { + m: i32, + n: i32, + k: i32, + activation_fun: Option, + }, + ResNetCellTopHalf { + m: i32, + n: i32, + k: i32, + activation_fun: Option, + }, + ResNetCellBottomHalf { + m: i32, + n: i32, + k: i32, + activation_fun: Option, + }, + TransformerCell { + m: i32, + n: i32, + p: i32, + r: i32, + }, + TransformerCellTopHalf { + m: i32, + n: i32, + p: i32, + }, + TransformerCellBottomHalf { + m: i32, + n: i32, + r: i32, + }, } impl KernelParam { @@ -410,6 +589,51 @@ impl KernelParam { context, ) } + KernelParam::ResNetCell { + m, + n, + k, + activation_fun, + } => build::<'_, '_, linalg::ResNetCell<'_, f32>, C>( + linalg::ResNetCellP::new(m, n, k).activation_fun(activation_fun), + context, + ), + KernelParam::ResNetCellTopHalf { + m, + n, + k, + activation_fun, + } => build::<'_, '_, linalg::ResNetCellTopHalf<'_, f32>, C>( + linalg::ResNetCellTopHalfP::new(m, n, k, activation_fun), + context, + ), + KernelParam::ResNetCellBottomHalf { + m, + n, + k, + activation_fun, + } => build::<'_, '_, linalg::ResNetCellBottomHalf<'_, f32>, C>( + linalg::ResNetCellBottomHalfP::new(m, n, k, activation_fun), + context, + ), + KernelParam::TransformerCell { m, n, p, r } => { + build::<'_, '_, linalg::TransformerCell<'_, f32>, C>( + linalg::TransformerCellP::new(m, n, p, r), + context, + ) + } + KernelParam::TransformerCellTopHalf { m, n, p } => { + build::<'_, '_, linalg::TransformerCellTopHalf<'_, f32>, C>( + linalg::TransformerCellTopHalfP::new(m, n, p), + context, + ) + } + KernelParam::TransformerCellBottomHalf { m, n, r } => { + build::<'_, '_, linalg::TransformerCellBottomHalf<'_, f32>, C>( + linalg::TransformerCellBottomHalfP::new(m, n, r), + context, + ) + } } } } @@ -424,6 +648,54 @@ impl fmt::Display for KernelParam { KernelParam::BatchMM { b, m, n, k } => { write!(fmt, "batchmm_{}_{}_{}_{}", b, m, n, k) } + KernelParam::ResNetCell { + m, + n, + k, + activation_fun, + } => write!( + fmt, + "resnetcell_{}_{}_{}_{}", + m, + n, + k, + linalg::compose::ActivationFunction::opt_to_display(activation_fun) + ), + KernelParam::ResNetCellTopHalf { + m, + n, + k, + activation_fun, + } => write!( + fmt, + "resnetcelltophalf_{}_{}_{}_{}", + m, + n, + k, + linalg::compose::ActivationFunction::opt_to_display(activation_fun) + ), + KernelParam::ResNetCellBottomHalf { + m, + n, + k, + activation_fun, + } => write!( + fmt, + "resnetcellbottomhalf_{}_{}_{}_{}", + m, + n, + k, + linalg::compose::ActivationFunction::opt_to_display(activation_fun) + ), + KernelParam::TransformerCell { m, n, p, r } => { + write!(fmt, "transformercell_{}_{}_{}_{}", m, n, p, r) + } + KernelParam::TransformerCellTopHalf { m, n, p } => { + write!(fmt, "transformercelltophalf_{}_{}_{}", m, n, p) + } + KernelParam::TransformerCellBottomHalf { m, n, r } => { + write!(fmt, "transformercellbottomhalf_{}_{}_{}", m, n, r) + } } } } @@ -452,6 +724,9 @@ pub enum KernelErrorKind { /// A non-integer value was found where an integer value was expected. IntError(std::num::ParseIntError), + + /// An invalid activation function was found. + ActivationFunctionError(linalg::compose::ActivationFunctionParseError), } impl ParseKernelError { @@ -475,6 +750,9 @@ impl fmt::Display for ParseKernelError { fmt.write_str("extraneous unexpected kernel parameter") } KernelErrorKind::IntError(error) => fmt::Display::fmt(error, fmt), + KernelErrorKind::ActivationFunctionError(error) => { + fmt::Display::fmt(error, fmt) + } } } } @@ -496,6 +774,16 @@ impl From for ParseKernelError { } } +impl From for ParseKernelError { + fn from( + error: telamon_kernels::compose::ActivationFunctionParseError, + ) -> ParseKernelError { + ParseKernelError { + kind: KernelErrorKind::ActivationFunctionError(error), + } + } +} + impl std::str::FromStr for KernelParam { type Err = ParseKernelError; @@ -551,6 +839,70 @@ impl std::str::FromStr for KernelParam { let k = parse_i32(next_part(&mut parts)?)?; BatchMM { b, m, n, k } } + "resnetcell" => { + let m = parse_i32(next_part(&mut parts)?)?; + let n = parse_i32(next_part(&mut parts)?)?; + let k = parse_i32(next_part(&mut parts)?)?; + let activation_fun = + linalg::compose::ActivationFunction::opt_from_string(next_part( + &mut parts, + )?)?; + ResNetCell { + m, + n, + k, + activation_fun, + } + } + "resnetcelltophalf" => { + let m = parse_i32(next_part(&mut parts)?)?; + let n = parse_i32(next_part(&mut parts)?)?; + let k = parse_i32(next_part(&mut parts)?)?; + let activation_fun = + linalg::compose::ActivationFunction::opt_from_string(next_part( + &mut parts, + )?)?; + ResNetCellTopHalf { + m, + n, + k, + activation_fun, + } + } + "resnetcellbottomhalf" => { + let m = parse_i32(next_part(&mut parts)?)?; + let n = parse_i32(next_part(&mut parts)?)?; + let k = parse_i32(next_part(&mut parts)?)?; + let activation_fun = + linalg::compose::ActivationFunction::opt_from_string(next_part( + &mut parts, + )?)?; + ResNetCellBottomHalf { + m, + n, + k, + activation_fun, + } + } + "transformercell" => { + let m = parse_i32(next_part(&mut parts)?)?; + let n = parse_i32(next_part(&mut parts)?)?; + let p = parse_i32(next_part(&mut parts)?)?; + let r = parse_i32(next_part(&mut parts)?)?; + TransformerCell { m, n, p, r } + } + "transformercelltophalf" => { + let m = parse_i32(next_part(&mut parts)?)?; + let n = parse_i32(next_part(&mut parts)?)?; + let p = parse_i32(next_part(&mut parts)?)?; + TransformerCellTopHalf { m, n, p } + } + "transformercellbottomhalf" => { + let m = parse_i32(next_part(&mut parts)?)?; + let n = parse_i32(next_part(&mut parts)?)?; + let r = parse_i32(next_part(&mut parts)?)?; + TransformerCellBottomHalf { m, n, r } + } _ => { return Err(ParseKernelError { kind: KernelErrorKind::InvalidName,