Commit 6008ad21 authored by Davis King's avatar Davis King

Made the resizable_tensor's assignment operator work in a more sensible way.

parent 7272bc74
...@@ -332,31 +332,12 @@ namespace dlib ...@@ -332,31 +332,12 @@ namespace dlib
const matrix_exp<EXP>& item const matrix_exp<EXP>& item
) )
{ {
set_size(item.nr(), item.nc()); if (!(num_samples() == item.nr() && k()*nr()*nc() == item.nc()))
set_size(item.nr(), item.nc());
tensor::operator=(item); tensor::operator=(item);
return *this; return *this;
} }
template <typename EXP>
resizable_tensor& operator+= (
const matrix_exp<EXP>& item
)
{
set_size(item.nr(), item.nc());
tensor::operator+=(item);
return *this;
}
template <typename EXP>
resizable_tensor& operator-= (
const matrix_exp<EXP>& item
)
{
set_size(item.nr(), item.nc());
tensor::operator-=(item);
return *this;
}
void set_size( void set_size(
long n_, long k_ = 1, long nr_ = 1, long nc_ = 1 long n_, long k_ = 1, long nr_ = 1, long nc_ = 1
) )
......
...@@ -495,45 +495,16 @@ namespace dlib ...@@ -495,45 +495,16 @@ namespace dlib
requires requires
- item contains float values - item contains float values
ensures ensures
- #num_samples() == item.nr() - if (num_samples() == item.nr() && k()*nr()*nc() == item.nc()) then
- #k() == item.nc() - the dimensions of this tensor are not changed
- #nr() == 1 - else
- #nc() == 1 - #num_samples() == item.nr()
- #k() == item.nc()
- #nr() == 1
- #nc() == 1
- Assigns item to *this tensor by performing: - Assigns item to *this tensor by performing:
set_ptrm(host(), num_samples(), k()*nr()*nc()) = item; set_ptrm(host(), num_samples(), k()*nr()*nc()) = item;
!*/ !*/
template <typename EXP>
resizable_tensor& operator+= (
const matrix_exp<EXP>& item
);
/*!
requires
- item contains float values
ensures
- #num_samples() == item.nr()
- #k() == item.nc()
- #nr() == 1
- #nc() == 1
- Adds item to *this tensor by performing:
set_ptrm(host(), num_samples(), k()*nr()*nc()) += item;
!*/
template <typename EXP>
resizable_tensor& operator-= (
const matrix_exp<EXP>& item
);
/*!
requires
- item contains float values
ensures
- #num_samples() == item.nr()
- #k() == item.nc()
- #nr() == 1
- #nc() == 1
- Subtracts item from *this tensor by performing:
set_ptrm(host(), num_samples(), k()*nr()*nc()) -= item;
!*/
}; };
void serialize(const tensor& item, std::ostream& out); void serialize(const tensor& item, std::ostream& out);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment