Home > samsung-code > detSingleGP.m

detSingleGP

PURPOSE ^

DETSINGLEGP performs GP-based Fine-grained Target Search (FTS) on a single image with initial detection outputs

SYNOPSIS ^

function [newBs, newSs] = detSingleGP( I, bboxes, S, det_model, gp_thresh,CategIdOfInterest, displayProposal )

DESCRIPTION ^

 DETSINGLEGP performs GP-based Fine-grained Target Search (FTS) on a single image with initial detection outputs
 
 Usage:

   [newBs, newSs] = detSingleGP( I, bboxes, S, det_model, gp_thresh, CategIdOfInterest, displayProposal )

 Input:

   I: an image matrix loaded by imread (e.g., I = imread('000220.jpg');) 

   det_model: is the detection model loaded by detInit(...)

   bboxes: an M*4 matrix of initial bounding box coordinates, where 
     M is the number of initial bounding boxes. Each row should be in the
     form of [ymin, xmin, ymax, xmax].

   gp_thresh: threshold for candidate region for the GP-based FTS.
       Default value: -1

   CategIdOfInterest: a numeric vector of category indices (positions in
     det_model.categ_list) indicating which categories FTS should be
     applied to. In the calling pipeline it is only useful when
     gp_enable==1. By default FTS is applied to all the categories.
     E.g. CategIdOfInterest = [1 10]

   displayProposal: can be 0 (default) or 1 to indicate whether to show
     step-by-step FTS proposals

CROSS-REFERENCE INFORMATION ^

This function calls: This function is called by:

SOURCE CODE ^

0001 function [newBs, newSs] = detSingleGP( I, bboxes, S, det_model, gp_thresh, ...
0002     CategIdOfInterest, displayProposal )
0003 % DETSINGLEGP performs GP-based Fine-grained Target Search (FTS) on a single image with initial detection outputs
0004 %
0005 % Usage:
0006 %
0007 %   [newBs, newSs] = detSingleGP( I, bboxes, S, det_model, gp_thresh, CategIdOfInterest, displayProposal )
0008 %
0009 % Input:
0010 %
0011 %   I: an image matrix loaded by imread (e.g., I = imread('000220.jpg');)
0012 %
0013 %   bboxes: an M*4 matrix of initial bounding box coordinates, where
0014 %     M is the number of initial bounding boxes. Each row should be in the
0015 %     form of [ymin, xmin, ymax, xmax].
0016 %
0017 %   S: initial scores; S(r,m) is the score of bboxes(m,:) in row r.
0018 %     When S has one row per entry of det_model.categ_list, rows are
0019 %     picked by category id; otherwise row r is assumed to belong to
0020 %     CategIdOfInterest(r) -- NOTE(review): confirm against the caller.
0021 %
0022 %   det_model: is the detection model loaded by detInit(...)
0023 %
0024 %   gp_thresh: score threshold for candidate regions of the GP-based FTS.
0025 %     A scalar applies to every category; a vector follows the same
0026 %     per-category convention as the rows of S.
0027 %       Default value: -1
0028 %
0029 %   CategIdOfInterest: numeric vector of indices into
0030 %     det_model.categ_list selecting the categories FTS is applied to.
0031 %     By default FTS is applied to all the categories.
0032 %     E.g. CategIdOfInterest = [1 10]
0033 %
0034 %   displayProposal: can be 0 (default) or 1 to indicate whether to show
0035 %     step-by-step FTS proposals (execution pauses after each proposal)
0036 %
0037 % Output:
0038 %
0039 %   newBs: length(det_model.categ_list)*1 cell array. newBs{id} holds the
0040 %     boxes added by FTS for category id, one box per row; cells of
0041 %     categories outside CategIdOfInterest are left empty.
0042 %
0043 %   newSs: cell array of the same size. newSs{id} holds the classifier
0044 %     scores of the boxes in newBs{id}.
0045 %
0046 
0047 totalCategNum = length( det_model.categ_list );
0048 
0049 % gp_thresh is optional: fall back to -1 when omitted or empty
0050 if ~exist('gp_thresh','var') || isempty(gp_thresh)
0051     gp_thresh = -1;
0052 end
0053 if isscalar( gp_thresh )
0054     gp_thresh = repmat(gp_thresh, 1, totalCategNum);
0055 end
0056 
0057 if ~exist('CategIdOfInterest','var') || isempty(CategIdOfInterest)
0058     CategIdOfInterest = 1:totalCategNum;
0059 end
0060 CategIdOfInterest = reshape(CategIdOfInterest,1,numel(CategIdOfInterest));
0061 categNum = length(CategIdOfInterest);
0062 
0063 if ~exist('displayProposal','var') || isempty(displayProposal)
0064     displayProposal = 0;
0065 end
0066 
0067 % From here on the loop index c is a position in CategIdOfInterest, so every
0068 % per-category input is re-indexed to that local order once.
0069 if numel(gp_thresh) == totalCategNum
0070     scoreThreshold = gp_thresh(CategIdOfInterest);
0071 else
0072     scoreThreshold = gp_thresh; % assumed to be aligned with CategIdOfInterest already
0073 end
0074 if size(S,1) == totalCategNum
0075     S = S(CategIdOfInterest,:);
0076 end
0077 
0078 newBs = cell(totalCategNum,1);
0079 newSs = cell(totalCategNum,1);
0080 
0081 %% Prepare models
0082 classifier_model = cell(categNum,1);
0083 GPmodel = cell(categNum,1);
0084 for c = 1:categNum
0085     c1 = CategIdOfInterest(c); % category id in det_model for local index c
0086     classifier_model{c} = struct( ...
0087         'w',     {det_model.classifier.w(c1,:)}, ...
0088         'bias',  {det_model.classifier.bias(c1)}, ...
0089         'type',  {det_model.classifier.type} );
0090     GPmodel{c} = sgp_model_from_general( det_model.gp(c1).hyp );
0091 end
0092 classifier_model = cell2mat(classifier_model);
0093 GPmodel = cell2mat(GPmodel);
0094 
0095 feat_func = @(varargin) ind( det_model.cnn.feat_func(varargin{:}), 1 );
0096 
0097 %% set up solver
0098 
0099 Solver_Timeout = 10;
0100 minFunc_Method = 'lbfgs';
0101 minFuncX_OPTS = struct();
0102 minFuncX_OPTS.timeout = Solver_Timeout;
0103 minFuncX_OPTS.Display = 'off';
0104 minFuncX_OPTS.Method  = minFunc_Method;
0105 min_func = @(func,x0) minFuncX( func,x0, minFuncX_OPTS);
0106 
0107 
0108 %% GPSearch
0109 bboxParamType = det_model.gp(1).BBoxParamType;
0110 
0111 maxGPIter              = 32;  % hard cap on the number of FTS iterations
0112 maxGPGapNum            = 8;   % a category is dropped after this many iterations in a row without a new score above its local best
0113 maxLocalityNumPerImage = inf; % cap on the number of anchors kept per category
0114 
0115 nmsThreshold       = 0.3;
0116 localIoUThresholds = sort([0.3,0.5,0.7], 'ascend' );
0117 
0118 baggingNum = length(localIoUThresholds);
0119 
0120 % set up FTS input
0121 
0122 % Drop duplicated initial boxes. ia indexes the kept rows of bboxes, so the
0123 % same index selects the matching columns of S.
0124 curScores = cell(categNum,1);
0125 [curBoxes, ia] = unique(bboxes,'rows','last');
0126 initBoxNum = numel(ia);
0127 
0128 for c = 1:categNum
0129     curScores{c} = vec(S(c,ia));
0130 end
0131 
0132 imHeight = size(I,1);
0133 imWidth  = size(I,2);
0134 
0135 % ============ FTS procedure START
0136 
0137 curBParams = bbox_ltrb2param( curBoxes, bboxParamType );
0138 curBoxes   = repmat( {curBoxes}, categNum, 1 );
0139 curBParams = repmat( {curBParams}, categNum, 1 );
0140 
0141 gapIterNum = zeros(categNum,1);
0142 activeClassIdx = 1:categNum;
0143 
0144 for iter=1:maxGPIter
0145 
0146     nextBParams = cell(categNum,1);
0147     nextOrigin  = cell(categNum,1);
0148     bestScores  = cell(categNum,1);
0149 
0150     % The loop range is fixed when the loop starts, so removing c from
0151     % activeClassIdx below only affects the later loops of this iteration.
0152     for c = activeClassIdx
0153 
0154         % anchors: boxes above the threshold that survive NMS
0155         goodIdxB = ( curScores{c}>=scoreThreshold(c) );
0156         goodIdx  = find( goodIdxB );
0157         furtherGoodIdx  = nms( [curBoxes{c}(goodIdxB,:), ...
0158             curScores{c}(goodIdxB)], nmsThreshold, 'iou' );
0159         anchorIdx = goodIdx(furtherGoodIdx);
0160 
0161         if isempty(anchorIdx)
0162             activeClassIdx = setdiff( activeClassIdx, c );
0163             continue;
0164         end
0165 
0166         anchorBoxes  = curBoxes{c}(anchorIdx,:);
0167         anchorScores = curScores{c}(anchorIdx);
0168 
0169         [~,further2BestIdx] = sort( anchorScores, 'descend' );
0170         further2BestIdx = further2BestIdx( ...
0171             1:min( length(further2BestIdx), maxLocalityNumPerImage ) );
0172         anchorScores  = anchorScores(further2BestIdx);
0173         anchorBoxes   = anchorBoxes(further2BestIdx,:);
0174         anchorBParams = bbox_ltrb2param( anchorBoxes, bboxParamType );
0175         [~,~,anchorScales,~]  = bbox_ltrb2param( anchorBoxes, 'yxsal' );
0176 
0177         % local gp
0178         cur_IoU = PairedIoU( curBoxes{c}, anchorBoxes );
0179 
0180         nextBParams{c} = cell(size(anchorBoxes,1),baggingNum);
0181         nextOrigin{c}  = cell(size(anchorBoxes,1),baggingNum);
0182         bestScores{c}  = cell(size(anchorBoxes,1),baggingNum);
0183 
0184         for j=1:size(anchorBoxes, 1)
0185             for bag_id = 1:baggingNum
0186                 localIdxB = (cur_IoU(:,j)>localIoUThresholds(bag_id));
0187                 if sum(localIdxB)<3
0188                     break; % note that localIoUThresholds is in ascending order
0189                 end
0190 
0191                 PsiN1 = curBParams{c}(localIdxB,:).';
0192                 fN    = curScores{c}(localIdxB);
0193                 if bag_id == 1
0194                     % best score in the widest neighbourhood, reused by
0195                     % the tighter bags of the same anchor
0196                     fN_hat = max(fN);
0197                 end
0198 
0199                 latent_obj = @(z) sgp_negloglik( GPmodel(c), z, PsiN1, fN );
0200                 z0 = anchorScales(j);
0201                 try
0202                     z_hat = min_func( latent_obj, z0);
0203                 catch
0204                     % best effort: fall back to the anchor scale
0205                     z_hat = anchorScales(j);
0206                 end
0207 
0208                 expnz = exp(-z_hat);
0209 
0210                 PsiN = PsiN1;
0211                 PsiN(GPmodel(c).idxbScaleEnabled,:)  = PsiN(GPmodel(c).idxbScaleEnabled,:)*expnz;
0212                 KN = sgp_cov( GPmodel(c), 0, PsiN );
0213 
0214                 search_obj = @(psiNp1) sgp_neg_acquisition_ei( GPmodel(c), ...
0215                     psiNp1, PsiN, fN, fN_hat, KN );
0216 
0217                 psiNp1_0 = anchorBParams(j,:).';
0218                 psiNp1_0(GPmodel(c).idxbScaleEnabled) = psiNp1_0(GPmodel(c).idxbScaleEnabled)*expnz;
0219                 try
0220                     psiNp1_hat = min_func( search_obj, psiNp1_0 );
0221                 catch
0222                     warning( 'Optimization on psiNp1_hat is failed' );
0223                     continue;
0224                 end
0225                 % undo the scaling applied before the search
0226                 psiNp1_hat_1 = psiNp1_hat;
0227                 psiNp1_hat_1(GPmodel(c).idxbScaleEnabled) = psiNp1_hat_1(GPmodel(c).idxbScaleEnabled) / expnz;
0228 
0229                 if displayProposal
0230                     if isempty(I)
0231                         error( 'detSingleGP:emptyImage', ...
0232                             'displayProposal requires a non-empty image I' );
0233                     end
0234                     pbox = bbox_param2ltrb( psiNp1_hat_1.', bboxParamType );
0235                     show_bboxes(I,anchorBoxes(j,:),[],'green');
0236                     show_bboxes([],pbox,[],'yellow');
0237                     keyboard
0238                 end
0239 
0240                 bestScores{c}{j, bag_id} = fN_hat;
0241                 nextBParams{c}{j,bag_id} = psiNp1_hat_1.';
0242                 nextOrigin{c}{j,bag_id}  = [c,j,bag_id];
0243             end
0244         end
0245 
0246     end
0247 
0248     % put things together
0249     bestScores_noncell  = cell(categNum,1);
0250     nextBParams_noncell = cell(categNum,1);
0251     nextOrigin_noncell  = cell(categNum,1);
0252     for c = activeClassIdx
0253         bestScores_noncell{c}  = cat(1,bestScores{c}{:});
0254         nextBParams_noncell{c} = cat(1,nextBParams{c}{:});
0255         nextOrigin_noncell{c}  = cat(1,nextOrigin{c}{:});
0256     end
0257     bestScores_noncell  = cat(1,bestScores_noncell{:});
0258     nextBParams_noncell = cat(1,nextBParams_noncell{:});
0259     nextOrigin_noncell  = cat(1,nextOrigin_noncell{:});
0260 
0261     if isempty(bestScores_noncell)
0262         fprintf('x');
0263         break;
0264     end
0265 
0266     nextBoxes_noncell = round( bbox_param2ltrb( nextBParams_noncell, bboxParamType ) );
0267 
0268     % pruning apparent bad solutions: NaN / huge coordinates, inverted
0269     % boxes, and boxes lying entirely outside the image
0270     prunedIdxB              = any( isnan(nextBoxes_noncell), 2 ) | any( abs(nextBoxes_noncell)>1e5, 2 );
0271     prunedIdxB(~prunedIdxB) = any( nextBoxes_noncell(~prunedIdxB,[3 4])<nextBoxes_noncell(~prunedIdxB,[1 2]),2);
0272     prunedIdxB(~prunedIdxB) = ...
0273         nextBoxes_noncell(~prunedIdxB,3)<1 | ...
0274         nextBoxes_noncell(~prunedIdxB,1)>imHeight | ...
0275         nextBoxes_noncell(~prunedIdxB,4)<1 | ...
0276         nextBoxes_noncell(~prunedIdxB,2)>imWidth;
0277 
0278     bestScores_noncell(prunedIdxB)    = [];
0279     nextOrigin_noncell(prunedIdxB,:)  = [];
0280     nextBoxes_noncell(prunedIdxB,:)   = [];
0281 
0282     if isempty(bestScores_noncell)
0283         fprintf('x');
0284         break;
0285     end
0286 
0287     % pruning duplicated bboxes (same box proposed for the same category)
0288     [bestScores_noncell, sorted_idx] = sort(bestScores_noncell,'ascend');
0289     nextBoxes_noncell = nextBoxes_noncell(sorted_idx,:);
0290     nextOrigin_noncell = nextOrigin_noncell(sorted_idx,:);
0291     [nextBoxes_noncell, uq_idx] = unique([nextBoxes_noncell,nextOrigin_noncell(:,1)],'rows','first','legacy');
0292     nextBoxes_noncell  = nextBoxes_noncell(:,1:4);
0293     bestScores_noncell = bestScores_noncell(uq_idx); % use the lowest best score
0294     nextOrigin_noncell = nextOrigin_noncell(uq_idx,:);
0295 
0296     % drop proposals that are already in the category's box pool
0297     dupIdxB = false( size(bestScores_noncell) );
0298     for c = activeClassIdx
0299         thisIdxB          = (nextOrigin_noncell(:,1) == c);
0300         dupIdxB(thisIdxB) = ismember( nextBoxes_noncell(thisIdxB,:), curBoxes{c}, 'rows' );
0301     end
0302     bestScores_noncell(dupIdxB)    = [];
0303     nextOrigin_noncell(dupIdxB,:)  = [];
0304     nextBoxes_noncell(dupIdxB,:)   = [];
0305 
0306     if isempty(bestScores_noncell)
0307         fprintf('x');
0308         break;
0309     end
0310 
0311     % parameters are recomputed from the rounded boxes
0312     nextBParams_noncell = bbox_ltrb2param( nextBoxes_noncell, bboxParamType );
0313 
0314     % extract features (once per distinct box, shared across categories)
0315     [uqBoxes, ~, ci] = unique(nextBoxes_noncell,'rows');
0316 
0317     uqF = features_from_bboxes( I, uqBoxes, ...
0318         det_model.cnn.canonical_patchsize, ...
0319         det_model.cnn.padding, feat_func, ...
0320         det_model.cnn.max_batch_num * det_model.cnn.batch_size  );
0321     uqF = cell2mat( uqF );
0322     nextF_noncell = uqF(:,ci);
0323 
0324     % compute scores
0325     gapIterNum = gapIterNum + 1;
0326     for c = activeClassIdx
0327         thisIdxB = (nextOrigin_noncell(:,1) == c);
0328         if any(thisIdxB)
0329             nextScores_c = ApplyClassifier( nextF_noncell(:,thisIdxB), classifier_model(c) ).';
0330             if any( nextScores_c>bestScores_noncell(thisIdxB) )
0331                 gapIterNum(c) = 0;
0332             end
0333             curScores{c} = [curScores{c};nextScores_c];
0334             curBoxes{c}  = [curBoxes{c};nextBoxes_noncell(thisIdxB,:)];
0335             curBParams{c}= [curBParams{c};nextBParams_noncell(thisIdxB,:)];
0336         end
0337     end
0338 
0339     % progress mark: '*' when at least one category improved this iteration
0340     if ~all(gapIterNum)
0341         fprintf('*');
0342     else
0343         fprintf('.');
0344     end
0345 
0346     activeClassIdxB = false(1,categNum);
0347     activeClassIdxB(activeClassIdx) = true;
0348     activeClassIdxB(gapIterNum>=maxGPGapNum) = false;
0349     activeClassIdx = find(activeClassIdxB);
0350 
0351     if isempty(activeClassIdx), break; end
0352 
0353 end
0354 
0355 % Everything appended after the de-duplicated initial boxes is new
0356 for c = 1:categNum
0357     c1 = CategIdOfInterest(c);
0358     newBs{c1} = curBoxes{c}(initBoxNum+1:end,:);
0359     newSs{c1} = curScores{c}(initBoxNum+1:end);
0360 end
0361 

Generated on Thu 18-Dec-2014 22:27:44 by m2html © 2005