#!/usr/bin/env python3
# ******************************************************************************
#  $Id: pct2rgb.py 13087 2007-11-26 20:56:29Z hobu $
#
#  Name:     gdal_lut
#  Project:  GDAL Python Interface
#  Purpose:  Utility to apply a lookup table provided in a text file.
#  Author:   Frank Warmerdam, warmerdam@pobox.com
#
# ******************************************************************************
#  Copyright (c) 2008, Frank Warmerdam
#
# SPDX-License-Identifier: MIT
# ******************************************************************************

import sys

import numpy as np

from osgeo import gdal

gdal.TermProgress = gdal.TermProgress_nocb

# =============================================================================
# read_lut()
#
# Read and parse the LUT file.
# =============================================================================


def read_lut(filename):
    """Read a LUT text file and return its entries as a list of ints.

    Line i of the file (0-based) holds the output value for input pixel
    value i.  Blank lines at the end of the file are ignored; any other
    line that does not parse as an integer raises ValueError.
    """
    # Context manager so the file handle is closed promptly.
    with open(filename) as lut_file:
        lines = lut_file.readlines()

    # Tolerate trailing blank lines (e.g. an extra newline at end of file),
    # which would otherwise make int("") raise.  Blank lines in the middle
    # still raise: silently skipping them would shift every later mapping.
    while lines and not lines[-1].strip():
        lines.pop()

    return [int(line) for line in lines]


# =============================================================================
# _build_lookup()
#
# Turn the parsed LUT into a numpy table usable with np.take().
# =============================================================================


def _build_lookup(lut, input_range=256):
    """Build the lookup table applied to every scanline.

    :param lut: output values as returned by read_lut(); entry i is the
        output for input pixel value i.
    :param input_range: number of input values the table must cover
        (256 for 8-bit sources, 65536 for wider ones).  LUT entries at or
        beyond this index are ignored.
    :return: ``(lookup, gdal_type)`` where ``lookup`` is a uint8 or uint16
        numpy array of length ``input_range`` and ``gdal_type`` is the
        matching GDAL data type for a newly created output band.
    """
    # Start from the identity mapping so inputs not covered by the LUT are
    # left unaltered, then overwrite the covered prefix with the LUT values.
    lookup = np.arange(input_range)
    count = min(input_range, len(lut))
    lookup[:count] = lut[:count]

    # Choose the output type from the whole reachable table, not just the LUT
    # entries: identity entries above 255 must not be wrapped into uint8.
    if lookup.max() > 255:
        return lookup.astype(np.uint16), gdal.GDT_UInt16
    return lookup.astype(np.uint8), gdal.GDT_Byte


# =============================================================================
# Usage()
# =============================================================================


def Usage():
    """Print command-line help and return exit code 2."""
    print(
        """
Usage: gdal_lut.py src_file [-srcband n] [dst_file] [-dstband n] -lutfile filename
                   [-of format] [-co name=value]...

If dst_file is not specified, the result will be applied back to src_file.

The text file specified with -lutfile should have one line per LUT entry
with just the output pixel value.  Thus a LUT file like:

  0
  5
  11
  12
  12
  13

would map input pixel values 0,1,2,3,4,5 to 0,5,11,12,12,13 respectively.
Values not mapped by the lut file (for instance values 6-255 in the above
case) will be left unaltered.

Sixteen bit (UInt16) output values are supported as well as luts of more
than 256 input values.
"""
    )
    return 2


def main(argv=sys.argv):
    """Apply a LUT file to one band of a raster, in place or into dst_file.

    Returns a process exit code: 0 on success (or when
    GeneralCmdLineProcessor consumed the arguments), 1 when a file, band or
    driver cannot be obtained, and 2 (via Usage()) on bad arguments.
    """
    driver_name = "GTiff"
    src_filename = None
    dst_filename = None
    src_band_n = 1
    dst_band_n = 1
    lut_filename = None
    create_options = []

    gdal.AllRegister()
    argv = gdal.GeneralCmdLineProcessor(argv)
    if argv is None:
        return 0

    # Parse command line arguments.
    options_with_value = ("-of", "-co", "-lutfile", "-srcband", "-dstband")
    i = 1
    while i < len(argv):
        arg = argv[i]

        # An option given as the very last argument has no value to consume;
        # report usage instead of crashing with IndexError.
        if arg in options_with_value and i + 1 >= len(argv):
            return Usage()

        if arg == "-of":
            i = i + 1
            driver_name = argv[i]

        elif arg == "-co":
            i = i + 1
            create_options.append(argv[i])

        elif arg == "-lutfile":
            i = i + 1
            lut_filename = argv[i]

        elif arg == "-srcband":
            i = i + 1
            src_band_n = int(argv[i])

        elif arg == "-dstband":
            i = i + 1
            dst_band_n = int(argv[i])

        elif src_filename is None:
            src_filename = argv[i]

        elif dst_filename is None:
            dst_filename = argv[i]

        else:
            return Usage()

        i = i + 1

    if src_filename is None or lut_filename is None:
        return Usage()

    # ----------------------------------------------------------------------------
    # Load the LUT file.

    lut = read_lut(lut_filename)

    # ----------------------------------------------------------------------------
    # Open source file

    if dst_filename is None:
        src_ds = gdal.Open(src_filename, gdal.GA_Update)
        dst_ds = src_ds
    else:
        src_ds = gdal.Open(src_filename)
        dst_ds = None

    if src_ds is None:
        print("Unable to open ", src_filename)
        return 1

    src_band = src_ds.GetRasterBand(src_band_n)
    if src_band is None:
        print("Unable to access band %d of %s" % (src_band_n, src_filename))
        return 1

    # ----------------------------------------------------------------------------
    # Build the numpy lookup table.  It must cover every value the source band
    # can hold, otherwise np.take() raises IndexError on out-of-range pixels.
    # NOTE(review): only unsigned 8/16-bit sources map cleanly; signed or
    # floating point sources are not handled specially here.

    input_range = 256 if src_band.DataType == gdal.GDT_Byte else 65536
    lookup, gc = _build_lookup(lut, input_range)

    # ----------------------------------------------------------------------------
    # Open or create output file.

    dst_driver = gdal.GetDriverByName(driver_name)
    if dst_driver is None:
        print('"%s" driver not registered.' % driver_name)
        return 1

    if dst_ds is None:
        try:
            dst_ds = gdal.Open(dst_filename, gdal.GA_Update)
        except Exception:
            dst_ds = None

    if dst_ds is None:
        dst_ds = dst_driver.Create(
            dst_filename,
            src_ds.RasterXSize,
            src_ds.RasterYSize,
            1,
            gc,
            options=create_options,
        )
        dst_ds.SetProjection(src_ds.GetProjection())
        dst_ds.SetGeoTransform(src_ds.GetGeoTransform())

    dst_band = dst_ds.GetRasterBand(dst_band_n)
    if dst_band is None:
        print("Unable to access band %d of the output dataset" % dst_band_n)
        return 1

    # ----------------------------------------------------------------------------
    # Do the processing one scanline at a time.

    gdal.TermProgress(0.0)
    for iY in range(src_ds.RasterYSize):
        src_data = src_band.ReadAsArray(0, iY, src_ds.RasterXSize, 1)

        # Each pixel value is used as an index into the lookup table.
        dst_data = np.take(lookup, src_data)

        dst_band.WriteArray(dst_data, 0, iY)

        gdal.TermProgress((iY + 1.0) / src_ds.RasterYSize)

    # Drop the dataset references so GDAL flushes and closes the files.
    src_ds = None
    dst_ds = None

    return 0


if __name__ == "__main__":
    sys.exit(main(sys.argv))