diff --git a/src/stdlib_experimental_stats.fypp b/src/stdlib_experimental_stats.fypp
index b19bb2178..0edb2ba9e 100644
--- a/src/stdlib_experimental_stats.fypp
+++ b/src/stdlib_experimental_stats.fypp
@@ -15,8 +15,9 @@ module stdlib_experimental_stats
 
     #:for k1, t1 in REAL_KINDS_TYPES
       #:for rank in RANKS
-        module function mean_${rank}$_all_${k1}$_${k1}$(x) result(res)
+        module function mean_${rank}$_all_${k1}$_${k1}$(x, mask) result(res)
           ${t1}$, intent(in) :: x${ranksuffix(rank)}$
+          logical, intent(in), optional :: mask
           ${t1}$ :: res
         end function mean_${rank}$_all_${k1}$_${k1}$
       #:endfor
@@ -24,8 +25,9 @@ module stdlib_experimental_stats
 
     #:for k1, t1 in INT_KINDS_TYPES
       #:for rank in RANKS
-        module function mean_${rank}$_all_${k1}$_dp(x) result(res)
+        module function mean_${rank}$_all_${k1}$_dp(x, mask) result(res)
           ${t1}$, intent(in) :: x${ranksuffix(rank)}$
+          logical, intent(in), optional :: mask
           real(dp) :: res
         end function mean_${rank}$_all_${k1}$_dp
       #:endfor
@@ -33,9 +35,10 @@ module stdlib_experimental_stats
 
     #:for k1, t1 in REAL_KINDS_TYPES
       #:for rank in RANKS
-        module function mean_${rank}$_${k1}$_${k1}$(x, dim) result(res)
+        module function mean_${rank}$_${k1}$_${k1}$(x, dim, mask) result(res)
           ${t1}$, intent(in) :: x${ranksuffix(rank)}$
           integer, intent(in) :: dim
+          logical, intent(in), optional :: mask
           ${t1}$ :: res${reduced_shape('x', rank, 'dim')}$
         end function mean_${rank}$_${k1}$_${k1}$
       #:endfor
@@ -43,14 +46,61 @@ module stdlib_experimental_stats
 
     #:for k1, t1 in INT_KINDS_TYPES
       #:for rank in RANKS
-        module function mean_${rank}$_${k1}$_dp(x, dim) result(res)
+        module function mean_${rank}$_${k1}$_dp(x, dim, mask) result(res)
           ${t1}$, intent(in) :: x${ranksuffix(rank)}$
           integer, intent(in) :: dim
+          logical, intent(in), optional :: mask
           real(dp) :: res${reduced_shape('x', rank, 'dim')}$
         end function mean_${rank}$_${k1}$_dp
       #:endfor
     #:endfor
 
+
+    #:for k1, t1 in REAL_KINDS_TYPES
+      #:for rank in RANKS
+        module function mean_${rank}$_mask_all_${k1}$_${k1}$(x, mask) result(res)
+          ${t1}$, intent(in) :: x${ranksuffix(rank)}$
+          logical, intent(in) :: mask${ranksuffix(rank)}$
+          ${t1}$ :: res
+        end function mean_${rank}$_mask_all_${k1}$_${k1}$
+      #:endfor
+    #:endfor
+
+
+    #:for k1, t1 in INT_KINDS_TYPES
+      #:for rank in RANKS
+        module function mean_${rank}$_mask_all_${k1}$_dp(x, mask) result(res)
+          ${t1}$, intent(in) :: x${ranksuffix(rank)}$
+          logical, intent(in) :: mask${ranksuffix(rank)}$
+          real(dp) :: res
+        end function mean_${rank}$_mask_all_${k1}$_dp
+      #:endfor
+    #:endfor
+
+
+    #:for k1, t1 in REAL_KINDS_TYPES
+      #:for rank in RANKS
+        module function mean_${rank}$_mask_${k1}$_${k1}$(x, dim, mask) result(res)
+          ${t1}$, intent(in) :: x${ranksuffix(rank)}$
+          integer, intent(in) :: dim
+          logical, intent(in) :: mask${ranksuffix(rank)}$
+          ${t1}$ :: res${reduced_shape('x', rank, 'dim')}$
+        end function mean_${rank}$_mask_${k1}$_${k1}$
+      #:endfor
+    #:endfor
+
+
+    #:for k1, t1 in INT_KINDS_TYPES
+      #:for rank in RANKS
+        module function mean_${rank}$_mask_${k1}$_dp(x, dim, mask) result(res)
+          ${t1}$, intent(in) :: x${ranksuffix(rank)}$
+          integer, intent(in) :: dim
+          logical, intent(in) :: mask${ranksuffix(rank)}$
+          real(dp) :: res${reduced_shape('x', rank, 'dim')}$
+        end function mean_${rank}$_mask_${k1}$_dp
+      #:endfor
+    #:endfor
+
   end interface mean
 
 end module stdlib_experimental_stats
diff --git a/src/stdlib_experimental_stats.md b/src/stdlib_experimental_stats.md
index 420d55580..1059ef65c 100644
--- a/src/stdlib_experimental_stats.md
+++ b/src/stdlib_experimental_stats.md
@@ -8,13 +8,13 @@
 
 ### Description
 
-Returns the mean of all the elements of `array`, or of the elements of `array` along dimension `dim` if provided.
+Returns the mean of all the elements of `array`, or of the elements of `array` along dimension `dim` if provided, and if the corresponding element in `mask` is `true`.
 
 ### Syntax
 
-`result = mean(array)`
+`result = mean(array [, mask])`
 
-`result = mean(array, dim)`
+`result = mean(array, dim [, mask])`
 
 ### Arguments
 
@@ -22,6 +22,8 @@ Returns the mean of all the elements of `array`, or of the elements of `array` a
 
 `dim`: Shall be a scalar of type `integer` with a value in the range from 1 to n, where n is the rank of `array`.
 
+`mask` (optional): Shall be of type `logical` and either by a scalar or an array of the same shape as `array`.
+
 ### Return value
 
 If `array` is of type `real`, the result is of the same type as `array`.
@@ -29,6 +31,8 @@ If `array` is of type `integer`, the result is of type `double precision`.
 
 If `dim` is absent, a scalar with the mean of all elements in `array` is returned. Otherwise, an array of rank n-1, where n equals the rank of `array`, and a shape similar to that of `array` with dimension `dim` dropped is returned.
 
+If `mask` is specified, the result is the mean of all elements of `array` corresponding to `true` elements of `mask`. If every element of `mask` is `false`, the result is IEEE `NaN`.
+
 ### Example
 
 ```fortran
@@ -36,8 +40,10 @@ program demo_mean
     use stdlib_experimental_stats, only: mean
     implicit none
     real :: x(1:6) = [ 1., 2., 3., 4., 5., 6. ]
-    print *, mean(x)                            !returns 21.
-    print *, mean( reshape(x, [ 2, 3 ] ))       !returns 21.
-    print *, mean( reshape(x, [ 2, 3 ] ), 1)    !returns [ 3., 7., 11. ]
+    print *, mean(x)                            !returns 3.5
+    print *, mean( reshape(x, [ 2, 3 ] ))       !returns 3.5
+    print *, mean( reshape(x, [ 2, 3 ] ), 1)    !returns [ 1.5, 3.5, 5.5 ]
+    print *, mean( reshape(x, [ 2, 3 ] ), 1,&
+                   reshape(x, [ 2, 3 ] ) > 3.)  !returns [ NaN, 4.0, 5.5 ]
 end program demo_mean
 ```
diff --git a/src/stdlib_experimental_stats_mean.fypp b/src/stdlib_experimental_stats_mean.fypp
index e16a781ed..26b5d47ba 100644
--- a/src/stdlib_experimental_stats_mean.fypp
+++ b/src/stdlib_experimental_stats_mean.fypp
@@ -5,17 +5,25 @@
 
 submodule (stdlib_experimental_stats) stdlib_experimental_stats_mean
 
+  use, intrinsic:: ieee_arithmetic, only: ieee_value, ieee_quiet_nan
   use stdlib_experimental_error, only: error_stop
+  use stdlib_experimental_optval, only: optval
   implicit none
 
 contains
 
   #:for k1, t1 in REAL_KINDS_TYPES
     #:for rank in RANKS
-      module function mean_${rank}$_all_${k1}$_${k1}$(x) result(res)
+      module function mean_${rank}$_all_${k1}$_${k1}$(x, mask) result(res)
         ${t1}$, intent(in) :: x${ranksuffix(rank)}$
+        logical, intent(in), optional :: mask
         ${t1}$ :: res
 
+        if (.not.optval(mask, .true.)) then
+          res = ieee_value(res, ieee_quiet_nan)
+          return
+        end if
+
         res = sum(x) / real(size(x, kind = int64), ${k1}$)
 
       end function mean_${rank}$_all_${k1}$_${k1}$
@@ -25,10 +33,16 @@ contains
 
   #:for k1, t1 in INT_KINDS_TYPES
     #:for rank in RANKS
-      module function mean_${rank}$_all_${k1}$_dp(x) result(res)
+      module function mean_${rank}$_all_${k1}$_dp(x, mask) result(res)
         ${t1}$, intent(in) :: x${ranksuffix(rank)}$
+        logical, intent(in), optional :: mask
         real(dp) :: res
 
+        if (.not.optval(mask, .true.)) then
+          res = ieee_value(res, ieee_quiet_nan)
+          return
+        end if
+
         res = sum(real(x, dp)) / real(size(x, kind = int64), dp)
 
       end function mean_${rank}$_all_${k1}$_dp
@@ -38,11 +52,17 @@ contains
 
   #:for k1, t1 in REAL_KINDS_TYPES
     #:for rank in RANKS
-      module function mean_${rank}$_${k1}$_${k1}$(x, dim) result(res)
+      module function mean_${rank}$_${k1}$_${k1}$(x, dim, mask) result(res)
         ${t1}$, intent(in) :: x${ranksuffix(rank)}$
         integer, intent(in) :: dim
+        logical, intent(in), optional :: mask
         ${t1}$ :: res${reduced_shape('x', rank, 'dim')}$
 
+        if (.not.optval(mask, .true.)) then
+          res = ieee_value(res, ieee_quiet_nan)
+          return
+        end if
+
         if (dim >= 1 .and. dim <= ${rank}$) then
           res = sum(x, dim) / real(size(x, dim), ${k1}$)
         else
@@ -56,13 +76,19 @@ contains
 
   #:for k1, t1 in INT_KINDS_TYPES
     #:for rank in RANKS
-      module function mean_${rank}$_${k1}$_dp(x, dim) result(res)
+      module function mean_${rank}$_${k1}$_dp(x, dim, mask) result(res)
         ${t1}$, intent(in) :: x${ranksuffix(rank)}$
         integer, intent(in) :: dim
+        logical, intent(in), optional :: mask
         real(dp) :: res${reduced_shape('x', rank, 'dim')}$
 
+        if (.not.optval(mask, .true.)) then
+          res = ieee_value(res, ieee_quiet_nan)
+          return
+        end if
+
         if (dim >= 1 .and. dim <= ${rank}$) then
-          res = sum(x, dim) / real(size(x, dim), dp)
+          res = sum(real(x, dp), dim) / real(size(x, dim), dp)
         else
           call error_stop("ERROR (mean): wrong dimension")
         end if
@@ -71,4 +97,70 @@ contains
     #:endfor
   #:endfor
 
+
+  #:for k1, t1 in REAL_KINDS_TYPES
+    #:for rank in RANKS
+      module function mean_${rank}$_mask_all_${k1}$_${k1}$(x, mask) result(res)
+        ${t1}$, intent(in) :: x${ranksuffix(rank)}$
+        logical, intent(in) :: mask${ranksuffix(rank)}$
+        ${t1}$ :: res
+
+        res = sum(x, mask) / real(count(mask, kind = int64), ${k1}$)
+
+      end function mean_${rank}$_mask_all_${k1}$_${k1}$
+    #:endfor
+  #:endfor
+
+
+  #:for k1, t1 in INT_KINDS_TYPES
+    #:for rank in RANKS
+      module function mean_${rank}$_mask_all_${k1}$_dp(x, mask) result(res)
+        ${t1}$, intent(in) :: x${ranksuffix(rank)}$
+        logical, intent(in) :: mask${ranksuffix(rank)}$
+        real(dp) :: res
+
+        res = sum(real(x, dp), mask) / real(count(mask, kind = int64), dp)
+
+      end function mean_${rank}$_mask_all_${k1}$_dp
+    #:endfor
+  #:endfor
+
+
+  #:for k1, t1 in REAL_KINDS_TYPES
+    #:for rank in RANKS
+      module function mean_${rank}$_mask_${k1}$_${k1}$(x, dim, mask) result(res)
+        ${t1}$, intent(in) :: x${ranksuffix(rank)}$
+        integer, intent(in) :: dim
+        logical, intent(in) :: mask${ranksuffix(rank)}$
+        ${t1}$ :: res${reduced_shape('x', rank, 'dim')}$
+
+        if (dim >= 1 .and. dim <= ${rank}$) then
+          res = sum(x, dim, mask) / real(count(mask, dim), ${k1}$)
+        else
+          call error_stop("ERROR (mean): wrong dimension")
+        end if
+
+      end function mean_${rank}$_mask_${k1}$_${k1}$
+    #:endfor
+  #:endfor
+
+
+  #:for k1, t1 in INT_KINDS_TYPES
+    #:for rank in RANKS
+      module function mean_${rank}$_mask_${k1}$_dp(x, dim, mask) result(res)
+        ${t1}$, intent(in) :: x${ranksuffix(rank)}$
+        integer, intent(in) :: dim
+        logical, intent(in) :: mask${ranksuffix(rank)}$
+        real(dp) :: res${reduced_shape('x', rank, 'dim')}$
+
+        if (dim >= 1 .and. dim <= ${rank}$) then
+          res = sum(real(x, dp), dim, mask) / real(count(mask, dim), dp)
+        else
+          call error_stop("ERROR (mean): wrong dimension")
+        end if
+
+      end function mean_${rank}$_mask_${k1}$_dp
+    #:endfor
+  #:endfor
+
 end submodule
diff --git a/src/tests/stats/test_mean.f90 b/src/tests/stats/test_mean.f90
index 2471fb67a..49e5a98fc 100644
--- a/src/tests/stats/test_mean.f90
+++ b/src/tests/stats/test_mean.f90
@@ -37,6 +37,17 @@ program test_mean
 call assert( sum( abs( mean(d,2) - sum(d,2)/real(size(d,2), dp) )) < dptol)
 
 
+! check mask = .false.
+
+call assert( isnan(mean(d, .false.)))
+call assert( any(isnan(mean(d, 1, .false.))))
+call assert( any(isnan(mean(d, 2, .false.))))
+
+! check mask of the same shape as input
+call assert( abs(mean(d, d > 0) - sum(d, d > 0)/real(count(d > 0), dp)) < dptol)
+call assert( sum(abs(mean(d, 1, d > 0) - sum(d, 1, d > 0)/real(count(d > 0, 1), dp))) < dptol)
+call assert( sum(abs(mean(d, 2, d > 0) - sum(d, 2, d > 0)/real(count(d > 0, 2), dp))) < dptol)
+
 !int32
 call loadtxt("array3.dat", d)
 
@@ -56,8 +67,8 @@ program test_mean
 !dp rank 3
 allocate(d3(size(d,1),size(d,2),3))
 d3(:,:,1)=d;
-d3(:,:,2)=d*1.5_dp;
-d3(:,:,3)=d*4._dp;
+d3(:,:,2)=d*1.5;
+d3(:,:,3)=d*4;
 
 call assert( abs(mean(d3) - sum(d3)/real(size(d3), dp)) < dptol)
 call assert( sum( abs( mean(d3,1) - sum(d3,1)/real(size(d3,1), dp) )) < dptol)
@@ -67,11 +78,11 @@ program test_mean
 
 !dp rank 4
 allocate(d4(size(d,1),size(d,2),3,9))
-d4 = 1.
+d4 = -1
 d4(:,:,1,1)=d;
-d4(:,:,2,1)=d*1.5_dp;
-d4(:,:,3,1)=d*4._dp;
-d4(:,:,3,9)=d*4._dp;
+d4(:,:,2,1)=d*1.5;
+d4(:,:,3,1)=d*4;
+d4(:,:,3,9)=d*4;
 
 call assert( abs(mean(d4) - sum(d4)/real(size(d4), dp)) < dptol)
 call assert( sum( abs( mean(d4,1) - sum(d4,1)/real(size(d4,1), dp) )) < dptol)
@@ -79,4 +90,20 @@ program test_mean
 call assert( sum( abs( mean(d4,3) - sum(d4,3)/real(size(d4,3), dp) )) < dptol)
 call assert( sum( abs( mean(d4,4) - sum(d4,4)/real(size(d4,4), dp) )) < dptol)
 
+! check mask = .false.
+
+call assert( isnan(mean(d4, .false.)))
+call assert( any(isnan(mean(d4, 1, .false.))))
+call assert( any(isnan(mean(d4, 2, .false.))))
+call assert( any(isnan(mean(d4, 3, .false.))))
+call assert( any(isnan(mean(d4, 4, .false.))))
+
+
+! check mask of the same shape as input
+call assert( abs(mean(d4, d4 > 0) - sum(d4, d4 > 0)/real(count(d4 > 0), dp)) < dptol)
+call assert( any(isnan(mean(d4, 1, d4 > 0))) )
+call assert( any(isnan(mean(d4, 2, d4 > 0))) )
+call assert( any(isnan(mean(d4, 3, d4 > 0))) )
+call assert( sum(abs(mean(d4, 4, d4 > 0) - sum(d4, 4, d4 > 0)/real(count(d4 > 0, 4), dp))) < dptol)
+
 end program