#include "png.h"
#include <stdio.h>

/* number of leading file bytes read and compared against the png signature */
const int PNG_BYTES_TO_CHECK = 4;
/* status code returned by write_png when writing fails */
const int ERROR = -1;

/* read in some signature bytes, check for match
 * returns nonzero if match, otherwise zero
 */
int check_if_png(FILE *fp) {
  unsigned char buf[PNG_BYTES_TO_CHECK];
  if (fread(buf, 1, PNG_BYTES_TO_CHECK, fp) != PNG_BYTES_TO_CHECK)
    return 0;
  return(!png_sig_cmp(&buf[0], (png_size_t)0, PNG_BYTES_TO_CHECK));
}

/* reads the png file named fname with libpng's high-level interface and
 * stores its dimensions and bit depth in width, height and bit_depth.
 * returns the array of row pointers of the decoded image, or NULL if the
 * file cannot be opened, does not start with a png signature, or libpng
 * reports an error (the out-parameters are then not meaningful).
 *
 * the returned rows are owned by the libpng read/info structs created
 * here. those structs are deliberately left alive on success so the rows
 * stay valid; this function hands the caller no way to release them.
 */
png_bytep* read_png(char *fname, int &width, int &height, int &bit_depth) {
  png_structp png_ptr;
  png_infop info_ptr;

  FILE *fp = fopen(fname, "rb");
  if (fp == NULL)
    return NULL;

  // not a png: close the file before bailing out so the handle is not leaked
  if (check_if_png(fp) == 0) {
    fclose(fp);
    return NULL;
  }

  png_ptr = png_create_read_struct(PNG_LIBPNG_VER_STRING, NULL, NULL, NULL);
  if (png_ptr == NULL) {
    fclose(fp);
    return NULL;
  }

  info_ptr = png_create_info_struct(png_ptr);
  if (info_ptr == NULL) {
    fclose(fp);
    png_destroy_read_struct(&png_ptr, NULL, NULL);
    return NULL;
  }

  // libpng reports errors by longjmp-ing back to this point
  if (setjmp(png_jmpbuf(png_ptr))) {
    png_destroy_read_struct(&png_ptr, &info_ptr, NULL);
    fclose(fp);
    return NULL;
  }

  png_init_io(png_ptr, fp);
  // check_if_png already consumed the signature bytes
  png_set_sig_bytes(png_ptr, PNG_BYTES_TO_CHECK);

  // high-level read: decodes the whole image into info_ptr
  png_read_png(png_ptr, info_ptr, PNG_TRANSFORM_IDENTITY, NULL);
  width = (int)png_get_image_width(png_ptr, info_ptr);
  height = (int)png_get_image_height(png_ptr, info_ptr);
  bit_depth = png_get_bit_depth(png_ptr, info_ptr);

  printf("done\n");
  fclose(fp);

  png_bytep *rows = png_get_rows(png_ptr, info_ptr);
  if (rows == NULL) {
    // nothing to hand back, so there is no reason to keep the structs
    png_destroy_read_struct(&png_ptr, &info_ptr, NULL);
    return NULL;
  }
  return rows;
}

/* writes the grayscale image with the given width, height, bit_depth, and bytes
 */
int write_png(char *file_name, int width, int height, int bit_depth, png_bytep *row_pointers) {
  FILE *fp;
  png_structp png_ptr;
  png_infop info_ptr;

  fp = fopen(file_name, "wb");
  if (fp == NULL)
    return ERROR;

  png_ptr = png_create_write_struct(PNG_LIBPNG_VER_STRING, NULL, NULL, NULL);
  if (png_ptr == NULL) {
    fclose(fp);
    return ERROR;
  }

  info_ptr = png_create_info_struct(png_ptr);
  if (info_ptr == NULL) {
    fclose(fp);
    png_destroy_read_struct(&png_ptr, png_infopp_NULL, png_infopp_NULL);
    return ERROR;
  }

  if (setjmp(png_jmpbuf(png_ptr))) {
    png_destroy_read_struct(&png_ptr, &info_ptr, png_infopp_NULL);
    fclose(fp);
    return ERROR;
  }

  png_init_io(png_ptr, fp);
  png_set_IHDR(png_ptr, info_ptr, width, height, bit_depth, PNG_COLOR_TYPE_GRAY,
	       PNG_INTERLACE_NONE, PNG_COMPRESSION_TYPE_DEFAULT, PNG_FILTER_TYPE_DEFAULT);
  png_set_rows(png_ptr, info_ptr, row_pointers);
  png_write_png(png_ptr, info_ptr, PNG_TRANSFORM_IDENTITY, png_voidp_NULL);

  return 0;
}

/* allocates new storage for an 8-bit image of the specified dimensions:
 * an array of height row pointers, each pointing at its own buffer of
 * width bytes (one byte per pixel). pixel contents are uninitialized.
 * returns NULL if either dimension is not positive or if any allocation
 * fails; in the failure case nothing is left allocated.
 */
png_bytep* make_pixels(int width, int height) {
  const size_t pixel_size = 1;

  if (width <= 0 || height <= 0)
    return NULL;

  // libpng's allocator needs a png_struct; it serves only as a handle for
  // png_malloc_warn/png_free and is destroyed before returning
  png_structp png_ptr = png_create_read_struct(PNG_LIBPNG_VER_STRING, NULL, NULL, NULL);
  if (png_ptr == NULL) {
    return NULL;
  }

  // png_malloc_warn returns NULL on failure instead of raising png_error,
  // which would longjmp with no handler installed here
  png_bytep *row_pointers =
      (png_bytep *)png_malloc_warn(png_ptr, (size_t)height * sizeof(png_bytep));
  if (row_pointers != NULL) {
    for (int i = 0; i < height; i++) {
      row_pointers[i] = (png_bytep)png_malloc_warn(png_ptr, (size_t)width * pixel_size);
      if (row_pointers[i] == NULL) {
        // undo the rows allocated so far, then the pointer array itself
        while (i-- > 0)
          png_free(png_ptr, row_pointers[i]);
        png_free(png_ptr, row_pointers);
        row_pointers = NULL;
        break;
      }
    }
  }

  // the buffers are independent allocations and outlive the helper struct
  png_destroy_read_struct(&png_ptr, NULL, NULL);
  return row_pointers;
}

/* inverts an 8-bit png: reads the file named by the single command-line
 * argument, xors every pixel byte with 0xff, and writes the result to
 * "output.png". returns 0 on success and -1 on a usage error or any
 * read, allocation or write failure.
 */
int main(int argc, char **argv) {
  if(argc != 2) {
    fprintf(stderr, "usage: %s filename.png\n", argv[0]);
    return -1;
  }
  char *fname = argv[1];
  png_bytep *src_rows;
  int width, height, bit_depth;

  // open the image and get the image bytes
  if((src_rows = read_png(fname, width, height, bit_depth)) == NULL) {
    fprintf(stderr, "png read failed: %s\n", fname);
    return -1;
  }

  // be sure we're working in 8-bit grayscale
  // NOTE(review): only the bit depth is tested; read_png does not report the
  // color type, so an 8-bit image that is not grayscale is not rejected here
  if(bit_depth != 8) {
    fprintf(stderr, "image not grayscale, aborting\n");
    return -1;
  }

  // make new array of inverted pixels
  png_bytep *dest_rows = make_pixels(width, height);
  if(dest_rows == NULL) {
    // make_pixels returns NULL on failure; writing through it would crash
    fprintf(stderr, "pixel allocation failed\n");
    return -1;
  }
  for(int row = 0; row < height; row++) {
    for(int col = 0; col < width; col++) {
      dest_rows[row][col] = src_rows[row][col] ^ 0xff;
    }
  }

  // save the result
  char oname[] = "output.png";
  if(write_png(oname, width, height, bit_depth, dest_rows) == ERROR) {
    fprintf(stderr, "png write failed: %s\n", oname);
    return -1;
  }

  return 0;
}

