#include <windows.h>
#include <commctrl.h>
#include <stdio.h>
#include <vd2/VDXFrame/VideoFilter.h>
#include <vd2/VDXFrame/VideoFilterDialog.h>
#include "resource.h"

#include <limits.h>
#include <math.h>
#include <omp.h>

extern int g_VFVAPIVersion;

// User-tunable settings of the bird-removal filter. Edited by
// BirdRemovalFilterDialog and consumed by BirdRemovalFilter::DetectBirds2().
class BirdRemovalFilterConfig {
public:
	BirdRemovalFilterConfig()
		: threshold(6.5f), searchRadius(7.0f), paddingRadius(1.0f), blendRadius(2.0f), debugMode(0)
	{}

	// Detection threshold, as a percentage of the largest possible RGB
	// distance between two pixels (sqrt(3) * 255).
	float threshold;
	// Half-size, in pixels, of the window searched in the previous and next
	// frames for a colour close to the current pixel (truncated to int).
	float searchRadius;
	// Pixels within this distance of a detected pixel are replaced completely.
	float paddingRadius;
	// Width of the band beyond paddingRadius over which the replacement fades out.
	float blendRadius;
	// Output mode: 0 = remove artifacts, 1 = show detected pixels,
	// 2 = show the blend mask, 3 = outline the changed regions.
	int debugMode;
};

// Modal settings dialog. It edits the caller's config in place (for live
// preview) and restores the snapshot taken in OnInit() when cancelled.
class BirdRemovalFilterDialog : public VDXVideoFilterDialog {
public:
	// `config` is modified directly; `ifp` may be NULL (no preview available).
	BirdRemovalFilterDialog(BirdRemovalFilterConfig& config, IVDXFilterPreview *ifp)
		: lock(false)
		, mConfig(config)
		, mifp(ifp)
	{
	}

	// Runs the dialog; true when the dialog ended with a non-zero result (IDOK).
	bool Show(HWND parent) {
		return VDXVideoFilterDialog::Show(NULL, MAKEINTRESOURCE(IDD_FILTER_BIRDREMOVAL), parent) != 0;
	}

	virtual INT_PTR DlgProc(UINT msg, WPARAM wParam, LPARAM lParam);

protected:
	bool OnInit();
	bool OnCommand(int cmd);
	void OnDestroy();

	// Copy between mConfig and the dialog controls.
	void LoadFromConfig();
	bool SaveToConfig();

	// Set while the dialog itself is writing controls or refreshing the
	// preview, so the resulting change notifications are ignored.
	bool lock;

	BirdRemovalFilterConfig& mConfig;     // settings being edited (owned by the filter)
	BirdRemovalFilterConfig mOldConfig;   // snapshot restored on Cancel
	IVDXFilterPreview *const mifp;        // preview interface, possibly NULL
};

// Dialog message handler. Edits to any text box or the debug-mode combo are
// committed immediately so the preview can follow them.
INT_PTR BirdRemovalFilterDialog::DlgProc(UINT msg, WPARAM wParam, LPARAM lParam) {
	if (msg == WM_INITDIALOG) {
		// The dialog result is the negation of OnInit()'s return value.
		return !OnInit();
	}

	if (msg == WM_DESTROY) {
		OnDestroy();
		return FALSE;
	}

	if (msg == WM_COMMAND) {
		const UINT notifyCode = HIWORD(wParam);
		const bool valueEdited = (notifyCode == EN_CHANGE) || (notifyCode == CBN_SELCHANGE);

		if (valueEdited && !lock) {
			// Guard against notifications generated while we are busy here.
			lock = true;
			if (mifp && SaveToConfig())
				mifp->RedoFrame();   // ask the preview to redo the current frame
			lock = false;
			return TRUE;
		}

		if (OnCommand(LOWORD(wParam)))
			return TRUE;
	}

	return FALSE;
}

bool BirdRemovalFilterDialog::OnInit() {
	lock = true;
	mOldConfig = mConfig;

	HWND hwndMap = GetDlgItem(mhdlg, IDC_DEBUG);  
	SendMessage(hwndMap, CB_ADDSTRING, 0, (LPARAM)TEXT("None"));
	SendMessage(hwndMap, CB_ADDSTRING, 0, (LPARAM)TEXT("Detected pixels"));
	SendMessage(hwndMap, CB_ADDSTRING, 0, (LPARAM)TEXT("Mask or Alpha of changes"));
	SendMessage(hwndMap, CB_ADDSTRING, 0, (LPARAM)TEXT("Outline changed regions"));
	//SendMessage(hwndMap, CB_ADDSTRING, 0, (LPARAM)TEXT("Movement over threshold"));
	
	LoadFromConfig();

	HWND hwndFirst = GetDlgItem(mhdlg, IDC_EDIT2);
	if (hwndFirst)
		SendMessage(mhdlg, WM_NEXTDLGCTL, (WPARAM)hwndFirst, TRUE);

	HWND hwndPreview = GetDlgItem(mhdlg, IDC_PREVIEW);
	if (hwndPreview && mifp) {
		EnableWindow(hwndPreview, TRUE);
		mifp->InitButton((VDXHWND)hwndPreview);
	}

	lock = false;
	return false;
}

// WM_DESTROY: counterpart to the InitButton() call made in OnInit().
void BirdRemovalFilterDialog::OnDestroy() {
	// Hand the preview NULL so it presumably stops referring to our button.
	if (mifp != NULL)
		mifp->InitButton(NULL);
}

// Handles button commands; returns true when `cmd` was one of ours.
bool BirdRemovalFilterDialog::OnCommand(int cmd) {
	if (cmd == IDOK) {
		// Commit whatever is in the controls and close with a "true" result.
		SaveToConfig();
		EndDialog(mhdlg, true);
	} else if (cmd == IDCANCEL) {
		// Undo any live-preview edits.
		mConfig = mOldConfig;
		EndDialog(mhdlg, false);
	} else if (cmd == IDC_PREVIEW) {
		if (mifp)
			mifp->Toggle((VDXHWND)mhdlg);
	} else {
		return false;
	}
	return true;
}

#define STR_BUFFER 1024

// Formats `f` with "%f" and returns the text, cast to LPARAM, for use as the
// lParam of WM_SETTEXT.
//
// The buffer is static. The previous version returned the address of a local
// array, i.e. a pointer into a stack frame that no longer existed by the time
// SendDlgItemMessage read it. The returned text stays valid until the next
// call, which covers the immediate use as a message argument; the function is
// not re-entrant. "%f" of any float is under 50 characters, so sprintf cannot
// overflow STR_BUFFER.
// NOTE(review): assumes a non-UNICODE build, where SendDlgItemMessage is the
// ANSI variant and expects char text — confirm.
LPARAM toText(float f) {
	static char str[STR_BUFFER];
	sprintf(str, "%f", f);
	return (LPARAM)str;
}

void BirdRemovalFilterDialog::LoadFromConfig() {
	SendDlgItemMessage(mhdlg, IDC_EDIT2, WM_SETTEXT, 0, toText(mConfig.threshold));
	SendDlgItemMessage(mhdlg, IDC_EDIT3, WM_SETTEXT, 0, toText(mConfig.searchRadius));
	SendDlgItemMessage(mhdlg, IDC_EDIT5, WM_SETTEXT, 0, toText(mConfig.paddingRadius));
	SendDlgItemMessage(mhdlg, IDC_EDIT4, WM_SETTEXT, 0, toText(mConfig.blendRadius));	
	SendMessage(GetDlgItem(mhdlg, IDC_DEBUG), CB_SETCURSEL, (WPARAM)mConfig.debugMode, 0);		
}

// Reads the text of control `idc` in `dialog` and parses it with atof();
// unparsable or missing text yields 0.
// NOTE(review): assumes a non-UNICODE build (ANSI SendDlgItemMessage).
float getFloat(HWND dialog, int idc) {
	char text[STR_BUFFER] = { 0 };
	SendDlgItemMessage(dialog, idc, WM_GETTEXT, (WPARAM)STR_BUFFER, (LPARAM)text);
	const double parsed = atof(text);
	return (float)parsed;
}

bool BirdRemovalFilterDialog::SaveToConfig() {
	float threshold = getFloat(mhdlg, IDC_EDIT2);
	float searchRadius = getFloat(mhdlg, IDC_EDIT3);
	float paddingRadius = getFloat(mhdlg, IDC_EDIT5);
	float blendRadius = getFloat(mhdlg, IDC_EDIT4);
	int debugMode = SendMessage(GetDlgItem(mhdlg, IDC_DEBUG), CB_GETCURSEL, 0, 0);		
	
	if (mConfig.debugMode != debugMode 
		|| mConfig.threshold != threshold 
		|| mConfig.searchRadius != searchRadius 
		|| mConfig.paddingRadius != paddingRadius 
		|| mConfig.blendRadius != blendRadius) {

		mConfig.threshold = threshold;
		mConfig.searchRadius = searchRadius;
		mConfig.paddingRadius = paddingRadius;
		mConfig.blendRadius = blendRadius;
		mConfig.debugMode = debugMode;
		return true;
	}

	return false;
}

// VirtualDub video filter that removes short-lived artifacts (e.g. birds in
// timelapse footage) by comparing each frame with its neighbouring frames.
class BirdRemovalFilter : public VDXVideoFilter {
public:
	virtual uint32 GetParams();
	virtual void Start();
	virtual void Run();
	virtual bool Configure(VDXHWND hwnd);
	virtual void GetSettingString(char *buf, int maxlen);
//	virtual void GetScriptString(char *buf, int maxlen);
	virtual bool Prefetch2(sint64 frame, IVDXVideoPrefetcher *prefetcher);

//	VDXVF_DECLARE_SCRIPT_METHODS();

protected:
	// Processes one XRGB8888 frame in place using the previous/next frames.
	// (Declared without the former "BirdRemovalFilter::" prefix: an extra
	// qualification on a member declaration is ill-formed in standard C++ —
	// GCC/Clang reject it, MSVC only accepts it as an extension, warning C4596.)
	void DetectBirds2(void *dst, void *prev, void *next, ptrdiff_t dstpitch, uint32 w, uint32 h);

	BirdRemovalFilterConfig mConfig;
};

//VDXVF_BEGIN_SCRIPT_METHODS(BirdRemovalFilter)
//VDXVF_DEFINE_SCRIPT_METHOD(BirdRemovalFilter, ScriptConfig, "iii")
//VDXVF_END_SCRIPT_METHODS()

// Negotiates the frame format: only XRGB8888 is accepted when the pixmap
// layout is available (API version 12+). The output keeps the input's offset.
uint32 BirdRemovalFilter::GetParams() {
	if (g_VFVAPIVersion >= 12
		&& fa->src.mpPixmapLayout->format != nsVDXPixmap::kPixFormat_XRGB8888)
		return FILTERPARAM_NOT_SUPPORTED;

	fa->dst.offset = fa->src.offset;
	return FILTERPARAM_SUPPORTS_ALTFORMATS;
}

void BirdRemovalFilter::Start() {
	// Intentionally empty: DetectBirds2() allocates its working buffers per
	// frame, so there is nothing to prepare when processing starts.
}

// Requests the current frame plus its two neighbours from source 0.
// Run() reads mpSourceFrames[1] as the previous and [2] as the next frame,
// matching the order of the requests below.
bool BirdRemovalFilter::Prefetch2(sint64 frame, IVDXVideoPrefetcher *prefetcher) {
	static const int kFrameOffsets[] = { 0, -1, +1 };
	for (int i = 0; i < 3; ++i)
		prefetcher->PrefetchFrame(0, frame + kFrameOffsets[i], 0);
	return true;
}

void BirdRemovalFilter::Run() {
	if (g_VFVAPIVersion >= 12) {
		if (fa->mSourceFrameCount < 3)
			return;		
		const VDXPixmap& pxdst = *fa->dst.mpPixmap;
		//const VDXPixmap& pxsrc = *fa->src.mpPixmap;
		const VDXPixmap& prev = *fa->mpSourceFrames[1]->mpPixmap;
		const VDXPixmap& next = *fa->mpSourceFrames[2]->mpPixmap;
		if (pxdst.format ==  nsVDXPixmap::kPixFormat_XRGB8888) {
			DetectBirds2(pxdst.data, prev.data, next.data, pxdst.pitch, pxdst.w, pxdst.h);
		}
	}
}

// Shows the settings dialog. The dialog edits mConfig directly and restores
// the previous values if cancelled.
bool BirdRemovalFilter::Configure(VDXHWND hwnd) {
	BirdRemovalFilterDialog dialog(mConfig, fa->ifp);
	const bool accepted = dialog.Show((HWND)hwnd);
	return accepted;
}

void BirdRemovalFilter::GetSettingString(char *buf, int maxlen) {
	SafePrintf(buf, maxlen, " (T%.1f%%, S%.1f, P%.1f, B%.1f, D%d)"
			, mConfig.threshold
			, mConfig.searchRadius
			, mConfig.paddingRadius
			, mConfig.blendRadius
			, mConfig.debugMode
			);
}

/*void BirdRemovalFilter::GetScriptString(char *buf, int maxlen) {
	SafePrintf(buf, maxlen, "Config(%u, %u, %u)"
			, (unsigned)(mConfig.mYScale * 1000.0f + 0.5f)
			, (unsigned)(mConfig.mUScale * 1000.0f + 0.5f)
			, (unsigned)(mConfig.mVScale * 1000.0f + 0.5f)
			);
}

void BirdRemovalFilter::ScriptConfig(IVDXScriptInterpreter *isi, const VDXScriptValue *argv, int argc) {
	mConfig.mYScale = argv[0].asInt() / 1000.0f;
	mConfig.mUScale = argv[1].asInt() / 1000.0f;
	mConfig.mVScale = argv[2].asInt() / 1000.0f;
}*/

// Squared Euclidean distance between the colours (r, g, b) and (x, y, z).
int distanceSqr(int r, int g, int b, int x, int y, int z) {
	const int d0 = r - x;
	const int d1 = g - y;
	const int d2 = b - z;
	return d0 * d0 + d1 * d1 + d2 * d2;
}

// Squared Euclidean distance between the 2-D points (a, b) and (x, y).
int distanceSqr(int a, int b, int x, int y) {
	const int dx = a - x, dy = b - y;
	return dx * dx + dy * dy;
}

// Maps a weight in [0, 1] to a byte value in [0, 255] by truncation
// (e.g. 0.5 -> 127). Inputs <= 0 give 0; inputs > 1 give 255.
inline int floatTo256(float f) {
	// Written as !(f > 0) so that NaN (every ordered comparison is false)
	// also maps to 0. Previously NaN reached the float->int conversion below,
	// which is undefined behaviour.
	if (!(f > 0.f))
		return 0;
	if (f > 1.f)
		return 255;
	return (int)(255.f * f);
}

// Returns (f + scale / 2) / scale: f divided by scale after adding half a step.
// Not referenced by the code visible in this file.
inline float scaleBy(float f, float scale) {
	const float halfStep = scale * .5f;
	const float shifted = f + halfStep;
	return shifted / scale;
}

// Clamps i into [min, max]. The lower bound is tested first, so when
// min > max the result is min for i < min and max otherwise.
int limit(int i, int min, int max) {
	return (i < min) ? min : ((i > max) ? max : i);
}

int shortestDistanceSqr(void *dst, int x2, int y2, int radius, int r, int g, int b, ptrdiff_t dstpitch, uint32 w, uint32 h) {
	int shortest = 1000;

	int xf = limit(x2 - radius, 0, w-1);
	int xt = limit(x2 + radius, 0, w-1);
	int yf = limit(y2 - radius, 0, h-1);
	int yt = limit(y2 + radius, 0, h-1);
	
	for(uint32 y=yf; y<=yt; ++y) {
		for(uint32 x=xf; x<=xt; ++x) {
			uint8 *p = (uint8 *)dst + y * dstpitch + x * 4;

			int dist = distanceSqr(r, g, b, p[0], p[1], p[2]);
			if (dist < shortest) {
				shortest = dist;
			}
		}
	}
	return shortest;
}

// Returns the squared Euclidean distance, in pixels, from (x2, y2) to the
// nearest set entry of the w x h row-major mask `detected`. Only a square
// window of half-size ceil(r) around (x2, y2), clipped to the mask, is
// searched. Returns 0 when (x2, y2) itself is set, and INT_MAX when nothing in
// the window is set.
//
// The "nothing found" value used to be (int)(r*r*r). For r <= 1 that is 0 or
// 1, which is <= r, so a caller comparing the result against r treated
// pixels with no detection nearby as inside the mask. For large r it
// overflowed int. The window is now also capped at the mask size so ceil(r)
// always fits in an int; r <= 0 (or NaN) searches only (x2, y2).
int closest(bool *detected, int x2, int y2, float r, int w, int h) {
	if (detected[w*y2+x2]) {
		return 0;
	}

	// A window larger than the mask gains nothing.
	const float maxRadius = (float)(w > h ? w : h);
	int radius = 0;
	if (r > 0.f)
		radius = (r < maxRadius) ? (int)ceil(r) : (int)maxRadius;

	const int xf = limit(x2 - radius, 0, w-1);
	const int xt = limit(x2 + radius, 0, w-1);
	const int yf = limit(y2 - radius, 0, h-1);
	const int yt = limit(y2 + radius, 0, h-1);

	int shortest = INT_MAX;
	for (int y3 = yf; y3 <= yt; ++y3) {
		const bool *row = detected + w*y3;
		for (int x3 = xf; x3 <= xt; ++x3) {
			if (row[x3]) {
				const int dist = distanceSqr(x3, y3, x2, y2);
				if (dist < shortest)
					shortest = dist;
			}
		}
	}
	return shortest;
}

// Returns true if any entry of the w x h row-major mask `detected` within the
// 3x3 neighbourhood of (x2, y2), including (x2, y2) itself, is set.
// Loop counters are int. They used to be uint32 compared against int bounds;
// for a zero-width or zero-height mask the -1 upper bound then became
// UINT_MAX and the loop read out of bounds.
bool nextTo(bool *detected, int x2, int y2, int w, int h) {
	const int radius = 1;
	const int xf = limit(x2 - radius, 0, w - 1);
	const int xt = limit(x2 + radius, 0, w - 1);
	const int yf = limit(y2 - radius, 0, h - 1);
	const int yt = limit(y2 + radius, 0, h - 1);

	for (int y3 = yf; y3 <= yt; ++y3) {
		const bool *row = detected + w * y3;
		for (int x3 = xf; x3 <= xt; ++x3) {
			if (row[x3])
				return true;
		}
	}
	return false;
}

// Linear interpolation weighted towards `a`: returns b when blend == 0 and
// a when blend == 1. `blend` is not clamped here.
float lerp(float a, float b, float blend) {
	const float span = a - b;
	return b + blend * span;
}

// Clamps f into [0, 1]. NaN maps to 0: previously it passed straight
// through, breaking the [0, 1] guarantee that blend-weight callers rely on.
float clamp(float f) {
	// !(f > 0) is true for f <= 0 and for NaN.
	if (!(f > 0.f))
		return 0.f;
	if (f > 1.f)
		return 1.f;
	return f;
}

void BirdRemovalFilter::DetectBirds2(void *dst, void *prev, void *next, ptrdiff_t dstpitch, uint32 w, uint32 h) {
	int SEARCH_RADIUS = mConfig.searchRadius;	
	float maskRadius = mConfig.blendRadius + mConfig.paddingRadius;
	maskRadius *= maskRadius;
	float paddingSqr = mConfig.paddingRadius * mConfig.paddingRadius;


	int DEBUG_LEVEL = mConfig.debugMode;

	// sqrt (a^2 + b^2 + c^2), a=1, b=1, c=1
	int DETECTION_THRESHOLD = (int)(mConfig.threshold / 100.f * sqrt(3.f) * 255.f);
	DETECTION_THRESHOLD *= DETECTION_THRESHOLD;

	omp_set_num_threads(omp_get_num_procs());

	uint8 *dst8 = (uint8 *)dst;
	uint8 *prev8 = (uint8 *)prev;
	uint8 *next8 = (uint8 *)next;	
	bool* detected = new bool[w*h];	
	//bool* movement = new bool[w*h];	

	signed int x = 0;

	/*#pragma omp parallel for private(x)
	for(signed int y=0; y<h; ++y) {
		for(x=0; x<w; ++x) {					
			uint8 *p = (uint8 *)dst + y * dstpitch + x * 4;
			p[0] = floatTo256((float)y / (float)h);
			p[1] = floatTo256((float)x / (float)w);
			p[2] = 0;
		}
	}
	delete detected;
	return;*/

	// STATIC PIXELS
	/*
	if (DEBUG_LEVEL == 4) {
		#pragma omp parallel for private(x)
		for(signed int y=0; y<h; ++y) {
			for(x=0; x<w; ++x) {
				uint8 *p = (uint8 *)dst + y * dstpitch + x * 4;
				uint8 *pp = (uint8 *)prev + y * dstpitch + x * 4;
				uint8 *pn = (uint8 *)next + y * dstpitch + x * 4;			

				int dist = distanceSqr(pp[0], pp[1], pp[2], p[0], p[1], p[2]);
				int dist2 = distanceSqr(pn[0], pn[1], pn[2], p[0], p[1], p[2]);
				int dist3 = distanceSqr(pn[0], pn[1], pp[2], pp[0], pp[1], pp[2]);

				if (dist > DETECTION_THRESHOLD && dist2 > DETECTION_THRESHOLD && dist3 > DETECTION_THRESHOLD) {
					movement[w*y+x] = true;
				} else {
					movement[w*y+x] = false;
				}				
			}
		}
	}
	*/

	// ARTIFACT DETECTION
	if (DEBUG_LEVEL != 4) {
		#pragma omp parallel for private(x)
		for(signed int y=0; y<h; ++y) {
			for(x=0; x<w; ++x) {
				uint8 *p = (uint8 *)dst + y * dstpitch + x * 4;
				uint8 *pn = (uint8 *)prev + y * dstpitch + x * 4;

				int dist = distanceSqr(pn[0], pn[1], pn[2], p[0], p[1], p[2]);

				if (dist > DETECTION_THRESHOLD) {
					int shortest = shortestDistanceSqr(prev, x, y, SEARCH_RADIUS, p[0], p[1], p[2], dstpitch, w, h);
					int shortest2 = shortestDistanceSqr(next, x, y, SEARCH_RADIUS, p[0], p[1], p[2], dstpitch, w, h);

					if (shortest > DETECTION_THRESHOLD && shortest2 > DETECTION_THRESHOLD) {
						detected[w*y+x] = true;
					} else {
						detected[w*y+x] = false;
					}				
				} else {
					detected[w*y+x] = false;
				}

			}
		}
	}

	// FINAL RESULT
	if (DEBUG_LEVEL == 3) {
		bool* outline = new bool[w*h];	
		#pragma omp parallel for private(x)
		for(signed int y=0; y<h; ++y) {
			for(x=0; x<w; ++x) {		
				float c = closest(detected, x, y, maskRadius, w, h);
				if (c <= maskRadius) {				
					outline[w*y+x] = true;
				} else {
					outline[w*y+x] = false;
				}
			}
		}
		#pragma omp parallel for private(x)
		for(signed int y=0; y<h; ++y) {
			for(x=0; x<w; ++x) {					
				if (outline[w*y+x] == false && nextTo(outline, x, y, w, h) ) {
					uint8 *p = (uint8 *)dst + y * dstpitch + x * 4;
					p[0] = (uint8)0;
					p[1] = (uint8)0;
					p[2] = (uint8)255;
				}
			}
		}
		delete outline;
	} else {
		#pragma omp parallel for private(x)
		for(signed int y=0; y<h; y++) {
			for(x=0; x<w; x++) {
				uint8 *p = (uint8 *)dst + y * dstpitch + x * 4;
				if (DEBUG_LEVEL == 1) {
					if ( detected[w*y+x] ) {
						p[0] = (uint8)0;
						p[1] = (uint8)0;
						p[2] = (uint8)255;
					}
					continue;
				}
				/*if (DEBUG_LEVEL == 4) {
					if ( movement[w*y+x] ) {
						p[0] = (uint8)0;
						p[1] = (uint8)0;
						p[2] = (uint8)255;
					}
					continue;
				}*/
				float c = closest(detected, x, y, maskRadius, w, h);
				if (c < paddingSqr) {
					c = 1.f;
				} else {
					c = sqrt(c);
					c -= mConfig.paddingRadius;
					c = clamp((mConfig.blendRadius - c) / mConfig.blendRadius);
				}

				if (DEBUG_LEVEL == 2) {			
					int val = floatTo256(c);
					p[0] = val;
					p[1] = val;
					p[2] = val;
				} else {
					uint8 *pp = (uint8 *)prev + y * dstpitch + x * 4;
					uint8 *pn = (uint8 *)next + y * dstpitch + x * 4;
					for (int i=0; i<3; i++)
						p[i] = lerp(((pp[i] + pn[i]) / 2), p[i], c);
				}
			}
		}
	}
	delete detected;
	//delete movement;
}

/*void convertVDUBtoMAT(void *vdub, Mat3b& mat, ptrdiff_t dstpitch, uint32 w, uint32 h) {
	uint8 *dst8 = (uint8 *)vdub;
	for(uint32 y=0; y<h; y++) {
		uint8 *p = dst8;
		for(uint32 x=0; x<w; x++) {			
			mat(Point(x,y))[0] = p[0];
			mat(Point(x,y))[1] = p[1];
			mat(Point(x,y))[2] = p[2];
 			p += 4;
		}
		dst8 += dstpitch;
	}
	
}

void rainbow(Mat3b& mat,  uint32 w, uint32 h) {
	for(uint32 y=0; y<h; y++) {
		for(uint32 x=0; x<w; x++) {

			mat(Point(x,y))[0] = (float)x / (float)w * 255.f;
			mat(Point(x,y))[1] = 0;
			mat(Point(x,y))[2] = 0;	
		}
	}
}

void convertMATtoVDUB(void *vdub, Mat3b& mat, ptrdiff_t dstpitch, uint32 w, uint32 h) {
	uint8 *dst8 = (uint8 *)vdub;
	for(uint32 y=0; y<h; y++) {
		uint8 *p = dst8;
		for(uint32 x=0; x<w; x++) {
			Vec3b v = mat(Point(x,y));
			p[0] = v[0];
			p[1] = v[1];
			p[2] = v[2];
 			p += 4;
		}
		dst8 += dstpitch;		
	}	
}*/


///////////////////////////////////////////////////////////////////////////////

// Filter registration record: author, display name and description of
// BirdRemovalFilter.
extern VDXFilterDefinition filterDef_birdRemoval =
	VDXVideoFilterDefinition<BirdRemovalFilter>(
		"Jerry Ylilammi",
		"Bird Removal",
		"Filter for removing birds/artifacts from timelapse videos");
