Skip to content
Snippets Groups Projects
Select Git revision
  • da280b09d8d44d103e39bd209ad179d633f0e81c
  • master default protected
2 results

stats_integration.py

Blame
  • stats_integration.py 16.78 KiB
    import datetime
    from pathlib import Path
    import re
    from collections.abc import Iterable
    
    import numpy as np
    import matplotlib.pyplot as plt
    from scipy.optimize import curve_fit
    import pandas as pd
    
    import allantools
    import forcag_tools
    import Rb
    
    # Reference epoch (1900-01-01 00:00 UTC) to which cvitime2datetime adds its offsets
    cvi_epoch = datetime.datetime(1900, 1, 1, 0, 0, tzinfo=datetime.timezone.utc)
    
    # A difference in these keys is not significant to distinguish configs
    allowed_key_diff = {'dds2_f1', 
                        'integrator_center', 
                        'integrator_Rfluctuations',
                        'integrator_relaxed_Rfluctuations',
                        'standby',
                        'resize'
                        } # TODO: take into account the irrelevant fields of the bloch config
    
    # A difference in these keys is significant only if they are the only ones that differ
    optionnal_key_diff = {'integrator_gain', 
                          'integrator_sweep', 
    
                          }
    
    # Reference value drawn as a dotted horizontal line on the Bloch frequency axes of plot_freq
    # (presumably the expected Bloch frequency in Hz -- TODO confirm)
    nu_g = 568.509003
    
    # Use matplotlib's constrained layout for every figure created afterwards
    plt.rcParams['figure.constrained_layout.use'] = True
    
    def cvitime2datetime(time):
        """
        Convert a time given as an offset in seconds from `cvi_epoch` to a date.

        The value is wrapped in a numpy array, turned into a numpy timedelta
        with second resolution, and added to `cvi_epoch`.
        """
        seconds = np.array(time)
        return cvi_epoch + np.timedelta64(seconds, 's')
    
    def sqrt_fit(x, A):
        """Evaluate the model `A / sqrt(x)` used to fit the Allan deviation."""
        root = np.sqrt(x)
        return A / root
    
    def fit_adev(allan):
        """
        Fit `sqrt_fit` on the tail of an Allan deviation and return the amplitude.

        Only the points selected by `slice(-4, -1)` are used: the three points
        preceding the last one, which is itself left out of the fit.
        """
        tail = slice(-4, -1)
        taus = allan['tau'][tail]
        devs = allan['adev'][tail]

        popt, _pcov = curve_fit(sqrt_fit, taus, devs)
        return popt[0]
    
    def couple_bloch_config(config1, config2):
        """
        Tell whether two configs can be paired for a differential measurement.

        Return False when both arguments are the same object. Otherwise return
        True only if every key whose value differs between the two configs
        belongs to `allowed_key_diff`. A differing 'bloch_config' is ignored
        when `forcag_tools.compute_bloch` gives results less than 10e-9 apart
        for the two configs.
        """
        if config1 is config2:
            return False

        # Keys whose (key, value) pair is present in only one of the two configs
        differing = {key for key, _ in config1.items() ^ config2.items()}

        if "bloch_config" in differing:
            first, second = config1['bloch_config'], config2['bloch_config']
            gap = (forcag_tools.compute_bloch(first.param_seqs[0], first.freq)
                   - forcag_tools.compute_bloch(second.param_seqs[0], second.freq))
            if np.abs(gap) < 10e-9:
                differing.remove("bloch_config")

        for key in differing:
            if key not in allowed_key_diff:
                return False
        return True
    
    def nb_bloch_couple_config(config1, config2):
        """
        Return the number of Bloch frequencies separating two paired configs.

        The frequency difference is read from 'integrator_center' when
        `config1['integrator_type']` is 1 (Raman integration) and from 'dds2_f1'
        when it is 2 (MW integration), then divided by 568 and rounded. Any
        other integrator type gives 1.
        """
        bloch_f = 568

        # Which field carries the frequency, for each known integrator type
        freq_key_by_type = {
            1: 'integrator_center',  # Raman integration
            2: 'dds2_f1',            # MW integration
        }

        freq_key = freq_key_by_type.get(config1['integrator_type'])
        if freq_key is None:
            return 1

        return round((config1[freq_key] - config2[freq_key]) / bloch_f)
    
    def process_data(data_dir, misc=None, plot=True, couples=True):
        """
        Load integration data, pair the configs and compute frequency statistics.

        Parameters
        ----------
        data_dir : iterable of path-like
            Directories holding the `lock<N>_data.dat`, `lock<N>_monitoring.dat`
            and `config<N>` files.
        misc : sequence, optional
            One description per entry of `data_dir`, stored in the 'misc' column
            of the data and of the result.
        plot : bool
            When true, draw the figures (`plot_Nat2`, `plot_correction`,
            `plot_freq`, `plot_allan`).
        couples : bool
            When false, no pairing is searched and every config is handled as a
            'single' one.

        Returns
        -------
        result : pandas.DataFrame
            One row per couple and frequency type ('bloch f', 'center f') or per
            unpaired config ('single'), with the output of `compute_stats`.
        bloch_f : pandas.DataFrame
            Joined data of each couple (or config), with the 'bloch f' and
            'center f' columns.
        data : pandas.DataFrame
            Monitoring and lock data of every config, indexed by
            ('config', 'number').

        Raises
        ------
        ValueError
            If `misc` is given with a length different from `data_dir`, or if
            no `lock*_data.dat` file is found.
        """
        data_dir = [Path(data_d) for data_d in data_dir]

        if misc:
            if len(misc) != len(data_dir):
                raise ValueError(f'misc if given must be of equal length ({len(misc)}) than data_dir ({len(data_dir)})')

        else:
            misc = [None] * len(data_dir)

        # Find which configs were used, from the names of the lock data files
        lock_re = re.compile(r'^lock(\d+)_')
        config_used = set()

        for data_d, misc_desc in zip(data_dir, misc):
            for file in data_d.glob('lock*_data.dat'):
                if match := lock_re.match(file.name):
                    config_used.add((data_d, int(match.group(1)), misc_desc))

        if not config_used:
            raise ValueError(f"No integration in path {data_dir}")

        config_used = sorted(config_used)

        # Parse the monitoring and lock files of every config
        data = []
        configs = []
        for data_d, n_config, misc_desc in config_used:
            # sep=r'\s+' is the documented equivalent of the deprecated delim_whitespace=True
            m = pd.read_csv(data_d / f"lock{n_config}_monitoring.dat", index_col='number', sep=r'\s+')
            m['Nat'] = m['aire0'] + m['aire1']
            m['utc'] = pd.to_datetime(m['utc'], unit='s', origin=pd.Timestamp('1900-01-01'), utc=True)

            d = pd.read_csv(data_d / f"lock{n_config}_data.dat", sep=r'\s+',
                            index_col=0, names=['correction'])

            config = forcag_tools.parse_bin_config(data_d / f"config{n_config}")[0]._asdict()
            configs.append(config)

            d['freq'] = config['integrator_center'] + d['correction']
            if config['integrator_type'] == 2: # MW
                d['freq'] = config['dds2_f1'] - d['freq'] # TODO take into account phase shift due to pi/2 pulses
            d['misc'] = misc_desc

            data.append(m.join(d, how='inner'))

        # Name of each used config, used as first index level of `data`
        list_config_used = tuple(f"{data_d}/config{c}" for data_d, c, misc_desc in config_used)

        data = pd.concat(data, keys=list_config_used)
        data.index.rename(['config', 'number'], inplace=True)

        # Test every pairing to find the config couples
        config_couples = [[couple_bloch_config(c1, c2) for c2 in configs] for c1 in configs]
        list_couples = []

        if couples:
            for i in range(len(configs)):
                # Keep a couple only when each config matches exactly one other config
                if sum(config_couples[i]) == 1:
                    j = config_couples[i].index(True)
                    if j > i and sum(config_couples[j]) == 1:
                        list_couples.append((i, j))

        bloch_f = []
        result = []
        for k, (i, j) in enumerate(list_couples):
            c1, c2 = configs[i], configs[j]

            n_bloch = nb_bloch_couple_config(c2, c1)

            d = data.loc[list_config_used[i]].join(data.loc[list_config_used[j]], how='inner', lsuffix='_c1', rsuffix='_c2')
            d['n bloch'] = n_bloch
            bloch_f.append(d)

            # Small assumption: the misc description is the same for the two configs
            result.append((k, c1, c2, 'bloch f', list_config_used[i], list_config_used[j], config_used[i][2], i, j))
            result.append((k, c1, c2, 'center f', list_config_used[i], list_config_used[j], config_used[i][2], i, j))

        # Handle the configs that belong to no couple; they are joined with
        # themselves so that they get the same columns as the couples.
        config_in_couples = {i for couple in list_couples for i in couple}
        unpaired = sorted(set(range(len(list_config_used))) - config_in_couples)  # sorted: deterministic numbering
        for k, i in enumerate(unpaired, start=len(list_couples)):
            result.append((k, configs[i], None, 'single', list_config_used[i], None, config_used[i][2], i, 0))

            d = data.loc[list_config_used[i]].join(data.loc[list_config_used[i]], how='inner', lsuffix='_c1', rsuffix='_c2')
            d['n bloch'] = 1
            bloch_f.append(d)

        # The keys match the 'n_couple' value stored in `result`
        bloch_f = pd.concat(bloch_f, keys=range(len(bloch_f)))

        bloch_f['bloch f'] = (bloch_f['freq_c2'] - bloch_f['freq_c1']) / bloch_f['n bloch']
        bloch_f['center f'] = (bloch_f['freq_c2'] + bloch_f['freq_c1']) / 2

        result = pd.DataFrame(result, columns=('n_couple', 'config1', 'config2', 'type', 'config used 1', 'config used 2', 'misc', 'n config 1', 'n config 2'))

        config_names = result['config1'].apply(config_name, args=[configs])
        result['config_name'] = config_names
        # If some config names are duplicated, compute slightly more detailed ones
        if not all(s.is_unique for t, s in result.groupby('type')['config_name']):
            config_names = result['config1'].apply(config_name, args=[configs, False])
            result['config_name'] = config_names

        result.sort_values(by=['config_name'], inplace=True)

        result = result.join(result.apply(compute_stats, axis=1, args=[bloch_f]))

        if plot:
            plot_Nat2(data)
            plot_correction(data)

            plot_freq(bloch_f, result)
            plot_allan(result)

        return result, bloch_f, data
    
    
    def compute_stats(row, bloch_f):
        """
        Compute the mean and the Allan deviation for one row of the result table.

        For a 'single' row the 'center f' column is analysed at a rate of one
        point per cycle; otherwise the column named by the row type is analysed
        at half that rate. The cycle duration is read from
        `row['config1']['cycle']` and scaled by 1e-3.

        Return the series given by `allan_dev`, extended with 'mean' and
        'std_fit' (amplitude fitted by `fit_adev`).
        """
        is_single = row['type'] == 'single'
        column = 'center f' if is_single else row['type']

        samples = bloch_f.loc[row['n_couple'], column]
        mean = samples.mean()

        cycle_time = row['config1']['cycle'] * 1e-3
        if is_single:
            rate = 1/(cycle_time)
        else:
            rate = 1/(2 * cycle_time)

        allan = allan_dev(samples, rate=rate)
        allan['mean'] = mean
        allan['std_fit'] = fit_adev(allan)

        return allan
    
    def allan_dev(df, rate=1):
        """
        Run `allantools.oadev` on `df`, treated as frequency data sampled at `rate`.

        The output of `oadev` is returned as a Series labelled
        'tau', 'adev', 'adeverr' and 'n'.
        """
        labels = ('tau', 'adev', 'adeverr', 'n')
        samples = df.to_numpy()

        oadev_out = allantools.oadev(samples, data_type="freq", rate=rate)
        return pd.Series(oadev_out, index=labels)
    
    def config_name(config, list_configs, skip_otionnal=True):
        """
        Find a human readable and pertinent name for a given config.

        The name lists, as sorted `key=value` pairs joined by ', ', the fields
        of `config` that differ from at least one config of `list_configs`.
        Keys of `allowed_key_diff` are always ignored. Keys of
        `optionnal_key_diff` are ignored when `skip_otionnal` is true and some
        other key differs. The result is cut to at most 60 characters.

        Tuple valued fields are expanded recursively (by field name for named
        tuples, by position otherwise), and a differing 'bloch_config' is
        reported as 'd mir'.
        """
        def find_field_name(config, list_configs):
            # Every key whose value differs from at least one other config
            diff_keys = set()
            for conf in list_configs:
                diff_keys.update(key for key, _ in config.items() ^ conf.items())
            diff_keys -= allowed_key_diff

            # Remove optional key differences only if there is something else
            if skip_otionnal and diff_keys - optionnal_key_diff:
                diff_keys -= optionnal_key_diff

            diff_fields = {}
            for key in diff_keys:
                if key not in config:
                    continue

                value = config[key]
                if key == 'bloch_config':
                    launch_h = forcag_tools.compute_bloch(value.param_seqs[0], value.freq)
                    diff_fields['d mir'] = (forcag_tools.ref_miroir - launch_h) * 1e6
                elif not isinstance(value, tuple):
                    diff_fields[key] = value
                else:
                    # Difference inside a tuple: recurse to find which elements differ
                    try:
                        key_config = value._asdict()
                        list_key_configs = [c[key]._asdict() for c in list_configs]
                    except AttributeError:
                        # Plain tuple: use the positions as keys
                        key_config = dict(enumerate(value))
                        list_key_configs = [dict(enumerate(c[key])) for c in list_configs]
                    diff_fields.update(find_field_name(key_config, list_key_configs))
            return diff_fields

        diff_fields = find_field_name(config, list_configs)
        name = ', '.join(sorted(f'{k}={v}' for k, v in diff_fields.items()))

        # NOTE(review): the previous code appended an empty string after the cut;
        # an ellipsis suffix may have been intended -- confirm.
        return name[:60]
    
    
    
    def plot_Nat2(monitor):
        """
        Plot the 'Nat' and 'ratio' columns of each config against 'utc'.

        Figure 4 is cleared and rebuilt with two stacked axes sharing their
        x axis; `monitor` is grouped on its first index level, one curve per
        group, labelled with the last three '/' separated parts of the group key.
        """
        # TODO: handle failed shots.
        plt.figure(4)
        plt.clf()
        fig, axes = plt.subplots(2, 1, num=4, constrained_layout=True, sharex=True)

        for axis, column in zip(axes, ('Nat', 'ratio')):
            for n_config, m in monitor.groupby(level=0):
                conf_name = '/'.join(n_config.split('/')[-3:])
                axis.plot(m['utc'], m[column], '.', label=conf_name)

            axis.legend()

            axis.set_xlabel('N coup')
            axis.set_ylabel(column)
    
    def plot_correction(data):
        """
        Plot the 'correction' column of each config, with its mean as a dashed line.

        Figure 6 is cleared and rebuilt with two axes: the first uses the
        'number' index level as abscissa, the second the 'utc' column.
        """
        fig, ax = plt.subplots(num=6)
        plt.clf()

        fig, (ax_number, ax_utc) = plt.subplots(2, 1, num=6)

        for i, (n_config, d) in enumerate(data.groupby(level=0)):
            label = f"config {n_config}"
            color = f'C{i}'
            correction = d['correction']
            mean_correction = correction.mean()

            ax_number.plot(d.index.get_level_values('number'), correction, f'-C{i}', label=label)
            ax_number.axhline(mean_correction, ls='--', c=color)

            ax_utc.plot(d['utc'], correction, f'-C{i}', label=label)
            ax_utc.axhline(mean_correction, ls='--', c=color)

        for axis in (ax_number, ax_utc):
            axis.legend()

            axis.set_xlabel('N coup')
            axis.set_ylabel('Correction apportée')
    
    def plot_freq(data, result):
        """
        Plot the measured frequencies against the 'number' index level.

        When `data` has rows with a non-zero 'bloch f', figure 5 shows the
        'bloch f' of every 'bloch f' row of `result` (with its mean and the
        `nu_g` reference line) and the 'center f' of every 'center f' row.
        When `data` has rows with a zero 'bloch f', figure 8 shows the
        'center f' and mean of every 'single' row.
        """
        has_couple = not data[data['bloch f'] != 0].empty
        has_single = not data[data['bloch f'] == 0].empty

        def draw(axis, kind, column, ylabel, with_mean):
            # One curve per row of `result` of the given type
            rows = result[result['type'] == kind]
            for i, row in enumerate(rows.itertuples()):
                d = data.loc[row.n_couple]
                axis.plot(d.index.get_level_values('number'), d[column], '-', label=row.config_name)
                if with_mean:
                    axis.axhline(row.mean, ls='--', color=f'C{i}')

            axis.legend()

            axis.set_xlabel('N coup')
            axis.set_ylabel(ylabel)

        if has_couple:
            fig, ax = plt.subplots(num=5)
            plt.clf()

            fig, (ax1, ax2) = plt.subplots(2, 1, num=5)

            ax1.axhline(nu_g, ls=':',  c='k')
            draw(ax1, 'bloch f', 'bloch f', 'Bloch frequency (Hz)', True)
            draw(ax2, 'center f', 'center f', 'Center frequency (Hz)', False)

        if has_single:
            fig, ax = plt.subplots(num=8)
            plt.clf()

            fig, ax1 = plt.subplots(1, 1, num=8)

            draw(ax1, 'single', 'center f', 'Center frequency (Hz)', True)
    
    def plot_allan(result):
        """
        Plot the overlapping Allan deviations stored in `result`.

        Figure 7 (drawn when `result` has 'bloch f' rows) shows the 'bloch f'
        rows on its first axes and the 'center f' rows on the second one.
        Figure 9 (drawn when `result` has 'single' rows) shows the 'single'
        rows. Each row is drawn with its error bars and the fitted
        `sqrt_fit` curve of amplitude `std_fit`.

        The fitted curves span the tau range of the rows drawn on the same
        figure only.
        """
        bloch_f = result[result['type'] == 'bloch f']
        center_f = result[result['type'] == 'center f']
        single_f = result[result['type'] == 'single']

        if not bloch_f.empty:
            # Select figure 7 and wipe it before building the axes
            fig, ax = plt.subplots(num=7)
            plt.clf()
            fig, axs = plt.subplots(2, 1, num=7, constrained_layout=True, sharex=True)

            # Only the rows drawn on this figure define the span of the fitted curves
            coupled = result[result['type'].isin(['bloch f', 'center f'])]
            t = _adev_tau_axis(coupled['tau'])
            for rows, ax in zip((bloch_f, center_f), axs):
                _draw_adev_rows(ax, rows, t)

            axs[0].set_ylabel('Bloch frequency (Hz)')
            axs[1].set_ylabel('Shift frequency (Hz)')
            axs[1].set_xlabel('Averaging time (s)')

            fig.suptitle('Overlapping Allan deviation')

        if not single_f.empty:
            # Select figure 9 and wipe it before building the axes
            fig, ax = plt.subplots(num=9)
            plt.clf()
            fig, ax = plt.subplots(1, 1, num=9, constrained_layout=True)

            _draw_adev_rows(ax, single_f, _adev_tau_axis(single_f['tau']))

            ax.set_ylabel('Peak frequency (Hz)')
            ax.set_xlabel('Averaging time (s)')

            fig.suptitle('Overlapping Allan deviation')


    def _adev_tau_axis(taus):
        """Return points linearly spaced between the smallest and largest value found in `taus`."""
        min_tau = min(min(r) for r in taus)
        max_tau = max(max(r) for r in taus)
        return np.linspace(min_tau, max_tau)


    def _draw_adev_rows(ax, rows, t):
        """Draw on `ax` the Allan deviation of each row of `rows` and its fitted curve sampled on `t`."""
        for k, row in enumerate(rows.itertuples()):
            color = f'C{k}'
            ax.errorbar(row.tau, row.adev, row.adeverr, capsize=3, label=row.config_name, color=color)
            ax.plot(t, sqrt_fit(t, row.std_fit), lw=.4, c=color)

        ax.set_xscale('log')
        ax.set_yscale('log')

        ax.grid(True, which="minor", ls="-", color='0.65')
        ax.grid(True, which="major", ls="-", color='0.25')

        ax.legend()
    
    
    if __name__ == '__main__':
        # Command line entry point: process the given directories and show the figures
        import argparse

        cli = argparse.ArgumentParser(description="Processing of ForcaG integration data")
        cli.add_argument("dir", nargs='+', help="integration data path")
        cli_args = cli.parse_args()

        r = process_data(cli_args.dir, plot=True)
        plt.show()