/* distort.cpp
*
* Command-line tool that applies a radial distortion to an image.
*/
#include <limits.h>
#include <math.h>
#include <stdio.h>
#include "image.h"

/*
* Image_Scale converts between the pixel indices (i, j) of an image and
* centred floating-point coordinates (x, y):
*    x =  (i-xcen)*dx
*    y =  (ycen-j)*dy
* so y grows as the row index j decreases. nx()/ny() are the inverse maps
* and return the index of the nearest pixel.
*/
class Image_Scale {
public:
	explicit Image_Scale(IMAGE *ip);
	double dx, dy;	// coordinate extent of one pixel along x / y
	int xcen, ycen;	// pixel indices of the coordinate origin
	double x(int i) const { return (i-xcen)*dx; }
	double y(int j) const { return (ycen-j)*dy; }
	// Inverse maps. They return the nearest pixel index, or -1 when that
	// index is negative or not representable (NaN, infinite, > INT_MAX).
	int nx(double x) const { return to_index(xcen + x/dx); }
	int ny(double y) const { return to_index(ycen - y/dy); }
private:
	// Round to nearest. A coordinate produced by x()/y() then maps back to
	// the same index despite floating-point error. Positions in (-1, -0.5)
	// also stay off index 0, which truncation toward zero would fold onto it.
	// Range-check before converting, because converting a NaN or
	// out-of-range double to int is undefined behaviour.
	static int to_index(double v)
	{
		const double r = floor(v + 0.5);
		// Both comparisons are false for NaN, so NaN also yields -1.
		if (r >= 0.0 && r <= static_cast<double>(INT_MAX))
			return static_cast<int>(r);
		return -1;
	}
};

/* Derive the scale from the image size. The origin sits at pixel
 * (hlen/2, vlen/2), using integer division. One pixel spans 4/(hlen+vlen)
 * units along both axes, so a square N x N image covers roughly [-1, 1)
 * along each axis. */
Image_Scale::Image_Scale(IMAGE *ip)
	: dx(4.0 / (ip->hlen + ip->vlen)),
	  dy(dx),
	  xcen(ip->hlen / 2),
	  ycen(ip->vlen / 2)
{
}

/* Remap `in` into `out` using a radial distortion of the given strength;
 * defined below. */
void distort(IMAGE *in, IMAGE *out, double distortion);

// defined in copyimg.cpp
// Presumably copies the pixels of src into dst (main uses it to fill a
// scratch copy of the input image) — confirm against copyimg.cpp.
void copy_image(IMAGE *src, IMAGE *dst);


/* Command-line entry point.
 *
 *     usage: distortion infile outfile distortion
 *
 * Steps:
 *   1. Open infile.
 *   2. Copy it into a scratch image created with make_image(NULL, ...),
 *      reported as "local memory".
 *   3. Create outfile with the same size and type.
 *   4. Fill outfile with distort().
 *
 * Progress is reported on stdout, errors on stderr.
 * Returns 0 on success, or -1 on a usage error, an unparsable distortion
 * value, or a failed image open/create.
 */
int main(int argc, char *argv[])
{
	IMAGE *in, *out, *tmp;
	char *infile,*outfile;
	double distortion;

	if (argc<4) {
		printf("usage: distortion infile outfile distortion\n");
		return -1;
	}
	infile=argv[1];
	outfile=argv[2];
	/* sscanf leaves `distortion` unassigned when argv[3] is not a number.
	 * Continuing would read an uninitialised variable. */
	if (sscanf(argv[3]," %lf",&distortion) != 1) {
		fprintf(stderr,"invalid distortion value: %s\n",argv[3]);
		return -1;
	}

	in = open_image(infile);
	if (!in) {
		fprintf(stderr,"cannot open input image: %s\n",infile);
		return -1;
	}
	printf("input image: %s\n",infile);
	printf("image size: %d x %d\n",in->hlen,in->vlen);

	tmp = make_image(NULL, in->hlen, in->vlen,in->type);
	/* Must test tmp itself: copy_image() below writes into it. */
	if (!tmp) {
		fprintf(stderr,"cannot create working image\n");
		return -1;
	}
	printf("copying input to local memory . . .\n");
	copy_image(in,tmp);

	printf("distortion: %g\n",distortion);

	out = make_image(outfile,in->hlen,in->vlen,in->type);
	if (!out) {
		fprintf(stderr,"cannot create output image: %s\n",outfile);
		return -1;
	}
	printf("distorting image . . .\n");
	distort(tmp,out,distortion);
	printf("output image: %s\n",outfile);
	printf("image size: %d x %d\n",out->hlen,out->vlen);
	/* NOTE(review): no image is closed or freed here. Confirm that image.h
	 * does not need an explicit close for `out` to reach disk. */
	return 0;
}



/* Remap `in` into `out` with a radial distortion.
 *
 * For every output pixel (i, j):
 *   1. Convert it to centred coordinates v with out's Image_Scale.
 *   2. Divide v by
 *          mag = base + distortion * |v|^2
 *      where
 *          base = 1 - distortion      if distortion < 0
 *                 1 - 2*distortion    otherwise.
 *   3. Copy the input pixel that in's Image_Scale::nx/ny select for the
 *      scaled position.
 * Output pixels whose source position falls outside `in` (or is undefined
 * because mag == 0) are set to 0.
 *
 * Arguments:
 *   in          source image, read with get_pixel()
 *   out         destination image, written one row at a time with put_line()
 *   distortion  strength; 0 gives mag == 1, i.e. no radial scaling
 */
void distort(IMAGE *in, IMAGE *out, double distortion)
{
	Image_Scale uv(out), xy(in);	// output / input index <-> coordinate maps
	const double base = (distortion<0.0? 1.0 - distortion: 1.0 - 2.0*distortion);

	/* Coordinate extent of the input image, padded by one pixel on every
	 * side. Positions outside it can never select a valid input pixel.
	 * Rejecting them here, in double, keeps huge, infinite or NaN values
	 * (mag tiny or 0) away from the double -> int conversion inside
	 * nx()/ny(), which is undefined for out-of-range values. */
	const double xlo = xy.x(0) - xy.dx;
	const double xhi = xy.x(in->hlen - 1) + xy.dx;
	const double ylo = xy.y(in->vlen - 1) - xy.dy;	// y decreases as j grows
	const double yhi = xy.y(0) + xy.dy;

	pixel *buf = make_buffer(out);
	for (int j=0;  j<out->vlen;  j++) {
		const double vy = uv.y(j);
		pixel *bp = buf;
		for (int i=0;  i<out->hlen;  i++, bp++) {
			*bp = 0;	// value used when there is no source pixel
			const double vx = uv.x(i);
			const double hsq = vx*vx + vy*vy;	// squared radius
			// A negative mag flips the sign of v, folding the map through
			// the centre; mag == 0 has no defined source position.
			const double mag = base + distortion*hsq;
			if (mag == 0.0) continue;
			const double sx = vx/mag;
			const double sy = vy/mag;
			// Negated form so that NaN positions are rejected too.
			if (!(sx >= xlo && sx <= xhi && sy >= ylo && sy <= yhi)) continue;
			const int nx = xy.nx(sx);
			if (nx<0 || nx>=in->hlen) continue;
			const int ny = xy.ny(sy);
			if (ny<0 || ny>=in->vlen) continue;
			*bp = in->get_pixel(nx, ny, in->type);
		}
		put_line(out,j,buf,in->type);
	}
	free_buffer(buf);
}